conspire/math/tensor/test.rs

1use super::{TensorError, TensorRank0};
2use crate::{ABS_TOL, REL_TOL, defeat_message};
3use std::{
4    cmp::PartialEq,
5    fmt::{self, Debug, Display, Formatter},
6};
7
8#[cfg(test)]
9use crate::EPSILON;
10
11#[cfg(test)]
12use super::{
13    TensorArray,
14    rank_1::{TensorRank1, list::TensorRank1List},
15};
16
17pub trait ErrorTensor {
18    fn error(
19        &self,
20        comparator: &Self,
21        tol_abs: &TensorRank0,
22        tol_rel: &TensorRank0,
23    ) -> Option<usize>;
24    fn error_fd(&self, comparator: &Self, epsilon: &TensorRank0) -> Option<(bool, usize)>;
25}
26
27pub fn assert_eq<'a, T: Display + PartialEq + ErrorTensor>(
28    value_1: &'a T,
29    value_2: &'a T,
30) -> Result<(), TestError> {
31    if value_1 == value_2 {
32        Ok(())
33    } else {
34        Err(TestError {
35            message: format!(
36                "\n\x1b[1;91mAssertion `left == right` failed.\n\x1b[0;91m  left: {}\n right: {}\x1b[0m",
37                value_1, value_2
38            ),
39        })
40    }
41}
42
/// Asserts that `value` agrees with its finite-difference approximation `value_fd`.
///
/// The comparison is delegated to [`ErrorTensor::error_fd`] with the crate's `EPSILON`:
/// - `None`: the assertion passes silently.
/// - `Some((false, count))`: the mismatch is only weak. A warning naming `count` is printed
///   to stdout and the assertion still passes.
/// - `Some((true, count))`: the assertion fails.
///
/// # Errors
///
/// Returns a [`TestError`] showing both values and the number of failing places when
/// `error_fd` reports a failing mismatch.
#[cfg(test)]
pub fn assert_eq_from_fd<T: Display + ErrorTensor>(
    value: &T,
    value_fd: &T,
) -> Result<(), TestError> {
    match value.error_fd(value_fd, &EPSILON) {
        None => Ok(()),
        Some((true, error_count)) => Err(TestError {
            message: format!(
                "\n\x1b[1;91mAssertion `left ≈= right` failed in {error_count} places.\n\x1b[0;91m  left: {value}\n right: {value_fd}\x1b[0m"
            ),
        }),
        Some((false, error_count)) => {
            // A weak mismatch is reported but deliberately does not fail the assertion.
            println!(
                "Warning: \n\x1b[1;93mAssertion `left ≈= right` was weak in {error_count} places.\x1b[0m"
            );
            Ok(())
        }
    }
}
67
68pub fn assert_eq_within_tols<'a, T: Display + ErrorTensor>(
69    value_1: &'a T,
70    value_2: &'a T,
71) -> Result<(), TestError> {
72    if let Some(error_count) = value_1.error(value_2, &ABS_TOL, &REL_TOL) {
73        Err(TestError {
74            message: format!(
75                "\n\x1b[1;91mAssertion `left ≈= right` failed in {} places.\n\x1b[0;91m  left: {}\n right: {}\x1b[0m",
76                error_count, value_1, value_2
77            ),
78        })
79    } else {
80        Ok(())
81    }
82}
83
/// Failure value returned by the assertion helpers in this module.
///
/// The `Debug` implementation prints `message` followed by `defeat_message()`, so a test
/// that returns `Err(TestError)` or unwraps one shows the full description.
pub struct TestError {
    /// Human-readable description of the failure. It may contain ANSI color escape sequences.
    pub message: String,
}
87
88impl Debug for TestError {
89    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
90        write!(
91            f,
92            "{}\n\x1b[0;2;31m{}\x1b[0m\n",
93            self.message,
94            defeat_message()
95        )
96    }
97}
98
99impl From<TensorError> for TestError {
100    fn from(error: TensorError) -> TestError {
101        Self {
102            message: error.to_string(),
103        }
104    }
105}
106
#[test]
#[should_panic(expected = "Assertion `left == right` failed.")]
fn assert_eq_fail() {
    // Distinct scalars must make `assert_eq` return an error; unwrapping it panics.
    let outcome = assert_eq(&0.0, &1.0);
    outcome.unwrap();
}
112
#[test]
#[should_panic(expected = "Assertion `left ≈= right` failed in 2 places.")]
fn assert_eq_from_fd_fail() {
    // The first and last entries differ while the middle entry matches, hence "2 places".
    let ascending = TensorRank1::<3, 1>::new([1.0, 2.0, 3.0]);
    let descending = TensorRank1::<3, 1>::new([3.0, 2.0, 1.0]);
    assert_eq_from_fd(&ascending, &descending).unwrap();
}
122
#[test]
fn assert_eq_from_fd_success() -> Result<(), TestError> {
    // Identical entries: the finite-difference comparison should report nothing.
    let value = TensorRank1::<3, 1>::new([1.0, 2.0, 3.0]);
    let value_fd = TensorRank1::<3, 1>::new([1.0, 2.0, 3.0]);
    assert_eq_from_fd(&value, &value_fd)?;
    Ok(())
}
130
#[test]
fn assert_eq_from_fd_weak() -> Result<(), TestError> {
    // The entries are on the order of EPSILON and differ by 1% of it. Per the test name,
    // this is expected to be a weak (non-failing) mismatch, so only `Ok` is checked here.
    // The printed warning is not asserted.
    let value = TensorRank1List::<1, 1, 1>::new([[EPSILON * 1.01]]);
    let value_fd = TensorRank1List::<1, 1, 1>::new([[EPSILON * 1.02]]);
    assert_eq_from_fd(&value, &value_fd)
}
138
#[test]
#[should_panic(expected = "Assertion `left ≈= right` failed in 2 places.")]
fn assert_eq_within_tols_fail() {
    // Reversing [1, 2, 3] changes only the first and last entries, hence "2 places".
    let forward = TensorRank1::<3, 1>::new([1.0, 2.0, 3.0]);
    let reversed = TensorRank1::<3, 1>::new([3.0, 2.0, 1.0]);
    let outcome = assert_eq_within_tols(&forward, &reversed);
    outcome.unwrap();
}