Skip to main content

conspire/math/tensor/
test.rs

1use super::{Scalar, Tensor, TensorError};
2use crate::{ABS_TOL, REL_TOL, defeat_message};
3use std::{
4    cmp::PartialEq,
5    fmt::{self, Debug, Display, Formatter},
6};
7
8#[cfg(test)]
9use crate::EPSILON;
10
11#[cfg(test)]
12use super::rank_1::{TensorRank1, list::TensorRank1List};
13
/// Comparison hook used by the finite-difference assertions in this module.
///
/// Only compiled for test builds.
#[cfg(test)]
pub trait ErrorTensor {
    /// Compares `self` against `comparator` using the tolerance `epsilon`.
    ///
    /// `assert_eq_from_fd_within` interprets the result as follows:
    /// - `None`: the assertion passes silently.
    /// - `Some((false, count))`: the assertion passes, but a warning is
    ///   printed saying it was weak in `count` places.
    /// - `Some((true, count))`: the assertion fails in `count` places.
    fn error_fd(&self, comparator: &Self, epsilon: Scalar) -> Option<(bool, usize)>;
}
18
19pub fn assert_eq<'a, T>(value_1: &'a T, value_2: &'a T) -> Result<(), TestError>
20where
21    T: Display + PartialEq,
22{
23    if value_1 == value_2 {
24        Ok(())
25    } else {
26        Err(TestError {
27            message: format!(
28                "\n\x1b[1;91mAssertion `left == right` failed.\n\x1b[0;91m  left: {value_1}\n right: {value_2}\x1b[0m"
29            ),
30        })
31    }
32}
33
/// Compares `value` with its finite-difference counterpart `value_fd` using
/// the default tolerance of `3.0 * EPSILON`.
///
/// This is a convenience wrapper around `assert_eq_from_fd_within`.
///
/// # Errors
///
/// Propagates the [`TestError`] returned by `assert_eq_from_fd_within`.
#[cfg(test)]
pub fn assert_eq_from_fd<'a, T>(value: &'a T, value_fd: &'a T) -> Result<(), TestError>
where
    T: Display + ErrorTensor + Tensor,
{
    let default_tol = 3.0 * EPSILON;
    assert_eq_from_fd_within(value, value_fd, default_tol)
}
41
/// Compares `value` with its finite-difference counterpart `value_fd` using
/// the tolerance `tol`.
///
/// The outcome is driven by `ErrorTensor::error_fd`:
/// - `None`: success.
/// - `Some((false, count))`: success, after printing a warning that the
///   comparison was weak in `count` places.
/// - `Some((true, count))`: failure.
///
/// # Errors
///
/// Returns a [`TestError`] when `error_fd` reports a failure. The message
/// lists both operands plus the results of `sub_abs` and `sub_rel`, labelled
/// `abs` and `rel`.
#[cfg(test)]
pub fn assert_eq_from_fd_within<'a, T>(
    value: &'a T,
    value_fd: &'a T,
    tol: Scalar,
) -> Result<(), TestError>
where
    T: Display + ErrorTensor + Tensor,
{
    match value.error_fd(value_fd, tol) {
        None => Ok(()),
        Some((false, count)) => {
            // A weak match is reported but does not fail the assertion.
            println!(
                "Warning: \n\x1b[1;93mAssertion `left ≈= right` was weak in {count} places.\x1b[0m"
            );
            Ok(())
        }
        Some((true, count)) => {
            let abs = value.sub_abs(value_fd);
            let rel = value.sub_rel(value_fd);
            Err(TestError {
                message: format!(
                    "\n\x1b[1;91mAssertion `left ≈= right` failed in {count} places.\n\x1b[0;91m  left: {value}\n right: {value_fd}\n   abs: {abs}\n   rel: {rel}\x1b[0m"
                ),
            })
        }
    }
}
70
71pub fn assert_eq_within<'a, T>(
72    value_1: &'a T,
73    value_2: &'a T,
74    tol_abs: Scalar,
75    tol_rel: Scalar,
76) -> Result<(), TestError>
77where
78    T: Display + Tensor,
79{
80    if let Some(count) = value_1.error_count(value_2, tol_abs, tol_rel) {
81        let abs = value_1.sub_abs(value_2);
82        let rel = value_1.sub_rel(value_2);
83        Err(TestError {
84            message: format!(
85                "\n\x1b[1;91mAssertion `left ≈= right` failed in {count} places.\n\x1b[0;91m  left: {value_1}\n right: {value_2}\n   abs: {abs}\n   rel: {rel}\x1b[0m"
86            ),
87        })
88    } else {
89        Ok(())
90    }
91}
92
93pub fn assert_eq_within_tols<'a, T>(value_1: &'a T, value_2: &'a T) -> Result<(), TestError>
94where
95    T: Display + Tensor,
96{
97    assert_eq_within(value_1, value_2, ABS_TOL, REL_TOL)
98}
99
/// Error returned by the assertion helpers in this module.
pub struct TestError {
    /// Text describing the failure; printed by the `Debug` implementation.
    pub message: String,
}
103
104impl Debug for TestError {
105    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
106        write!(
107            f,
108            "{}\n\x1b[0;2;31m{}\x1b[0m\n",
109            self.message,
110            defeat_message()
111        )
112    }
113}
114
115impl From<String> for TestError {
116    fn from(error: String) -> Self {
117        Self { message: error }
118    }
119}
120
121impl From<&str> for TestError {
122    fn from(error: &str) -> Self {
123        Self {
124            message: error.to_string(),
125        }
126    }
127}
128
129impl From<TensorError> for TestError {
130    fn from(error: TensorError) -> Self {
131        Self {
132            message: error.to_string(),
133        }
134    }
135}
136
/// An owned `String` converts into a `TestError` carrying the same text.
#[test]
fn test_error_from_string() {
    let source = "An error occurred".to_string();
    assert_eq!(TestError::from(source).message, "An error occurred");
}
144
/// A string slice converts into a `TestError` carrying the same text.
#[test]
fn test_error_from_str() {
    let converted = TestError::from("An error occurred");
    assert_eq!(converted.message, "An error occurred");
}
152
/// A `TensorError` converts into a `TestError` whose message is the tensor
/// error's `Display` text.
#[test]
fn test_error_from_tensor_error() {
    let tensor_error = TensorError::NotPositiveDefinite;
    // Exercise the `Debug` implementation as well as the conversion.
    let _ = format!("{tensor_error:?}");
    // Capture the expected text before the conversion consumes the error.
    let expected = tensor_error.to_string();
    assert_eq!(TestError::from(tensor_error).message, expected);
}
159
/// Unwrapping the error from `assert_eq` on unequal values panics with the
/// equality-failure message.
#[test]
#[should_panic(expected = "Assertion `left == right` failed.")]
fn assert_eq_fail() {
    assert_eq(&0.0, &1.0).unwrap()
}
165
/// Two rank-1 tensors that differ in their first and last entries fail the
/// finite-difference comparison in two places.
#[test]
#[should_panic(expected = "Assertion `left ≈= right` failed in 2 places.")]
fn assert_eq_from_fd_fail() {
    assert_eq_from_fd(
        &TensorRank1::<_, 1>::from([1.0, 2.0, 3.0]),
        &TensorRank1::<_, 1>::from([3.0, 2.0, 1.0]),
    )
    .unwrap()
}
175
/// Identical rank-1 tensors pass the finite-difference comparison.
#[test]
fn assert_eq_from_fd_success() -> Result<(), TestError> {
    assert_eq_from_fd(
        &TensorRank1::<_, 1>::from([1.0, 2.0, 3.0]),
        &TensorRank1::<_, 1>::from([1.0, 2.0, 3.0]),
    )
}
183
/// Entries on the order of `EPSILON` that differ slightly are expected to
/// pass; the test name suggests this exercises the "weak" warning path.
#[test]
fn assert_eq_from_fd_weak() -> Result<(), TestError> {
    assert_eq_from_fd(
        &TensorRank1List::<_, 1, 1>::from([[EPSILON * 1.01]]),
        &TensorRank1List::<_, 1, 1>::from([[EPSILON * 1.02]]),
    )
}
191
/// Two rank-1 tensors that differ in their first and last entries fail the
/// default-tolerance comparison in two places.
#[test]
#[should_panic(expected = "Assertion `left ≈= right` failed in 2 places.")]
fn assert_eq_within_tols_fail() {
    assert_eq_within_tols(
        &TensorRank1::<_, 1>::from([1.0, 2.0, 3.0]),
        &TensorRank1::<_, 1>::from([3.0, 2.0, 1.0]),
    )
    .unwrap()
}