conspire/math/tensor/
test.rs

1use super::{Scalar, Tensor, TensorError};
2use crate::{ABS_TOL, REL_TOL, defeat_message};
3use std::{
4    cmp::PartialEq,
5    fmt::{self, Debug, Display, Formatter},
6};
7
8#[cfg(test)]
9use crate::EPSILON;
10
11#[cfg(test)]
12use super::rank_1::{TensorRank1, list::TensorRank1List};
13
/// Comparison hook used by [`assert_eq_from_fd`] (test builds only).
#[cfg(test)]
pub trait ErrorTensor {
    /// Compares `self` against `comparator` using the tolerance `epsilon`.
    ///
    /// As consumed by [`assert_eq_from_fd`]:
    /// - `None` is treated as agreement,
    /// - `Some((true, count))` is reported as a failure in `count` places,
    /// - `Some((false, count))` only triggers a "weak" warning for `count` places.
    fn error_fd(&self, comparator: &Self, epsilon: Scalar) -> Option<(bool, usize)>;
}
18
19pub fn assert_eq<'a, T>(value_1: &'a T, value_2: &'a T) -> Result<(), TestError>
20where
21    T: Display + PartialEq,
22{
23    if value_1 == value_2 {
24        Ok(())
25    } else {
26        Err(TestError {
27            message: format!(
28                "\n\x1b[1;91mAssertion `left == right` failed.\n\x1b[0;91m  left: {value_1}\n right: {value_2}\x1b[0m"
29            ),
30        })
31    }
32}
33
/// Compares `value` against `value_fd` through [`ErrorTensor::error_fd`],
/// using `3.0 * EPSILON` as the tolerance passed to that method.
///
/// NOTE(review): `value_fd` is presumably a finite-difference approximation of
/// `value`, judging by the naming; confirm against callers.
///
/// A "weak" result (`error_fd` returned `Some((false, count))`) prints a
/// warning to stdout and still succeeds.
///
/// # Errors
///
/// Returns a [`TestError`] when `error_fd` returns `Some((true, count))`. The
/// message lists both operands together with the results of `sub_abs` and
/// `sub_rel`.
#[cfg(test)]
pub fn assert_eq_from_fd<'a, T>(value: &'a T, value_fd: &'a T) -> Result<(), TestError>
where
    T: Display + ErrorTensor + Tensor,
{
    match value.error_fd(value_fd, 3.0 * EPSILON) {
        // No discrepancy reported at all.
        None => Ok(()),
        // Discrepancies were counted but not flagged as failures: warn only.
        Some((false, count)) => {
            println!(
                "Warning: \n\x1b[1;93mAssertion `left ≈= right` was weak in {count} places.\x1b[0m"
            );
            Ok(())
        }
        // Flagged failure: build the full report.
        Some((true, count)) => {
            let abs = value.sub_abs(value_fd);
            let rel = value.sub_rel(value_fd);
            let message = format!(
                "\n\x1b[1;91mAssertion `left ≈= right` failed in {count} places.\n\x1b[0;91m  left: {value}\n right: {value_fd}\n   abs: {abs}\n   rel: {rel}\x1b[0m"
            );
            Err(TestError { message })
        }
    }
}
58
59pub fn assert_eq_within<'a, T>(
60    value_1: &'a T,
61    value_2: &'a T,
62    tol_abs: Scalar,
63    tol_rel: Scalar,
64) -> Result<(), TestError>
65where
66    T: Display + Tensor,
67{
68    if let Some(count) = value_1.error_count(value_2, tol_abs, tol_rel) {
69        let abs = value_1.sub_abs(value_2);
70        let rel = value_1.sub_rel(value_2);
71        Err(TestError {
72            message: format!(
73                "\n\x1b[1;91mAssertion `left ≈= right` failed in {count} places.\n\x1b[0;91m  left: {value_1}\n right: {value_2}\n   abs: {abs}\n   rel: {rel}\x1b[0m"
74            ),
75        })
76    } else {
77        Ok(())
78    }
79}
80
81pub fn assert_eq_within_tols<'a, T>(value_1: &'a T, value_2: &'a T) -> Result<(), TestError>
82where
83    T: Display + Tensor,
84{
85    assert_eq_within(value_1, value_2, ABS_TOL, REL_TOL)
86}
87
/// Error type returned by the assertion helpers in this module.
pub struct TestError {
    /// Text describing the failure; printed verbatim by the `Debug` impl.
    pub message: String,
}
91
92impl Debug for TestError {
93    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
94        write!(
95            f,
96            "{}\n\x1b[0;2;31m{}\x1b[0m\n",
97            self.message,
98            defeat_message()
99        )
100    }
101}
102
103impl From<String> for TestError {
104    fn from(error: String) -> Self {
105        Self { message: error }
106    }
107}
108
109impl From<&str> for TestError {
110    fn from(error: &str) -> Self {
111        Self {
112            message: error.to_string(),
113        }
114    }
115}
116
117impl From<TensorError> for TestError {
118    fn from(error: TensorError) -> Self {
119        Self {
120            message: error.to_string(),
121        }
122    }
123}
124
/// An owned `String` becomes the error message unchanged.
#[test]
fn test_error_from_string() {
    let source = "An error occurred".to_string();
    let error = TestError::from(source);
    assert_eq!(error.message, "An error occurred");
}
132
/// A string slice becomes the error message unchanged.
#[test]
fn test_error_from_str() {
    let error = TestError::from("An error occurred");
    assert_eq!(error.message, "An error occurred");
}
140
/// Exercises the `Debug` formatting of a `TensorError` and its conversion into
/// a `TestError`; only checks that neither step panics.
#[test]
fn test_error_from_tensor_error() {
    let tensor_error = TensorError::NotPositiveDefinite;
    let _ = format!("{tensor_error:?}");
    let _ = TestError::from(tensor_error);
}
147
/// Unequal scalars must make `assert_eq` return an error whose text surfaces
/// in the panic raised by `unwrap`.
#[test]
#[should_panic(expected = "Assertion `left == right` failed.")]
fn assert_eq_fail() {
    let outcome = assert_eq(&0.0, &1.0);
    outcome.unwrap()
}
153
/// Two vectors that differ in their first and last entries must be reported
/// as failing in 2 places.
#[test]
#[should_panic(expected = "Assertion `left ≈= right` failed in 2 places.")]
fn assert_eq_from_fd_fail() {
    let value = TensorRank1::<_, 1>::from([1.0, 2.0, 3.0]);
    let value_fd = TensorRank1::<_, 1>::from([3.0, 2.0, 1.0]);
    assert_eq_from_fd(&value, &value_fd).unwrap()
}
163
/// Identical vectors must pass the comparison.
#[test]
fn assert_eq_from_fd_success() -> Result<(), TestError> {
    let value = TensorRank1::<_, 1>::from([1.0, 2.0, 3.0]);
    let value_fd = TensorRank1::<_, 1>::from([1.0, 2.0, 3.0]);
    assert_eq_from_fd(&value, &value_fd)
}
171
/// Entries on the order of `EPSILON` that differ slightly must still return
/// `Ok`; per the test name this targets the "weak" warning path.
#[test]
fn assert_eq_from_fd_weak() -> Result<(), TestError> {
    let value = TensorRank1List::<_, 1, 1>::from([[EPSILON * 1.01]]);
    let value_fd = TensorRank1List::<_, 1, 1>::from([[EPSILON * 1.02]]);
    assert_eq_from_fd(&value, &value_fd)
}
179
/// Two vectors that differ in their first and last entries must fail the
/// default-tolerance comparison in 2 places.
#[test]
#[should_panic(expected = "Assertion `left ≈= right` failed in 2 places.")]
fn assert_eq_within_tols_fail() {
    let value_1 = TensorRank1::<_, 1>::from([1.0, 2.0, 3.0]);
    let value_2 = TensorRank1::<_, 1>::from([3.0, 2.0, 1.0]);
    assert_eq_within_tols(&value_1, &value_2).unwrap()
}
188}