conspire/math/tensor/
test.rs1use super::{TensorError, TensorRank0};
2use crate::{ABS_TOL, REL_TOL, defeat_message};
3use std::{
4 cmp::PartialEq,
5 fmt::{self, Debug, Display, Formatter},
6};
7
8#[cfg(test)]
9use crate::EPSILON;
10
11#[cfg(test)]
12use super::{
13 TensorArray,
14 rank_1::{TensorRank1, list::TensorRank1List},
15};
16
17pub trait ErrorTensor {
18 fn error(
19 &self,
20 comparator: &Self,
21 tol_abs: &TensorRank0,
22 tol_rel: &TensorRank0,
23 ) -> Option<usize>;
24 fn error_fd(&self, comparator: &Self, epsilon: &TensorRank0) -> Option<(bool, usize)>;
25}
26
27pub fn assert_eq<'a, T: Display + PartialEq + ErrorTensor>(
28 value_1: &'a T,
29 value_2: &'a T,
30) -> Result<(), TestError> {
31 if value_1 == value_2 {
32 Ok(())
33 } else {
34 Err(TestError {
35 message: format!(
36 "\n\x1b[1;91mAssertion `left == right` failed.\n\x1b[0;91m left: {}\n right: {}\x1b[0m",
37 value_1, value_2
38 ),
39 })
40 }
41}
42
/// Checks `value` against `value_fd` through `ErrorTensor::error_fd`, passing
/// the crate-level `EPSILON`.
///
/// A "weak" result (`error_fd` reports places but does not flag a failure)
/// prints a warning to stdout and still succeeds.
///
/// # Errors
///
/// Returns a `TestError` when `error_fd` flags a failure. The message includes
/// the reported number of places and both values.
#[cfg(test)]
pub fn assert_eq_from_fd<'a, T: Display + ErrorTensor>(
    value: &'a T,
    value_fd: &'a T,
) -> Result<(), TestError> {
    match value.error_fd(value_fd, &EPSILON) {
        // Nothing reported at all.
        None => Ok(()),
        // Hard failure: surface it to the caller.
        Some((true, error_count)) => Err(TestError {
            message: format!(
                "\n\x1b[1;91mAssertion `left ≈= right` failed in {} places.\n\x1b[0;91m left: {}\n right: {}\x1b[0m",
                error_count, value, value_fd
            ),
        }),
        // Weak agreement: warn, but do not fail.
        Some((false, error_count)) => {
            println!(
                "Warning: \n\x1b[1;93mAssertion `left ≈= right` was weak in {} places.\x1b[0m",
                error_count
            );
            Ok(())
        }
    }
}
67
68pub fn assert_eq_within_tols<'a, T: Display + ErrorTensor>(
69 value_1: &'a T,
70 value_2: &'a T,
71) -> Result<(), TestError> {
72 if let Some(error_count) = value_1.error(value_2, &ABS_TOL, &REL_TOL) {
73 Err(TestError {
74 message: format!(
75 "\n\x1b[1;91mAssertion `left ≈= right` failed in {} places.\n\x1b[0;91m left: {}\n right: {}\x1b[0m",
76 error_count, value_1, value_2
77 ),
78 })
79 } else {
80 Ok(())
81 }
82}
83
/// Error value returned by the assertion helpers in this module when a check
/// does not pass.
pub struct TestError {
    /// Text describing the failure; the `Debug` implementation prints it
    /// verbatim.
    pub message: String,
}
87
88impl Debug for TestError {
89 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
90 write!(
91 f,
92 "{}\n\x1b[0;2;31m{}\x1b[0m\n",
93 self.message,
94 defeat_message()
95 )
96 }
97}
98
99impl From<TensorError> for TestError {
100 fn from(error: TensorError) -> TestError {
101 Self {
102 message: error.to_string(),
103 }
104 }
105}
106
// Two different floats must fail the exact comparison with the expected message.
#[test]
#[should_panic(expected = "Assertion `left == right` failed.")]
fn assert_eq_fail() {
    let (left, right) = (0.0, 1.0);
    assert_eq(&left, &right).unwrap()
}
112
// The first and last components are swapped, so two places are expected to fail.
#[test]
#[should_panic(expected = "Assertion `left ≈= right` failed in 2 places.")]
fn assert_eq_from_fd_fail() {
    let value = TensorRank1::<3, 1>::new([1.0, 2.0, 3.0]);
    let value_fd = TensorRank1::<3, 1>::new([3.0, 2.0, 1.0]);
    assert_eq_from_fd(&value, &value_fd).unwrap()
}
122
// Identical inputs must pass the check.
#[test]
fn assert_eq_from_fd_success() -> Result<(), TestError> {
    let value = TensorRank1::<3, 1>::new([1.0, 2.0, 3.0]);
    let value_fd = TensorRank1::<3, 1>::new([1.0, 2.0, 3.0]);
    assert_eq_from_fd(&value, &value_fd)
}
130
// Inputs just above EPSILON that differ slightly must still return Ok.
// NOTE(review): presumably meant to exercise the warning-only path — confirm
// against the `error_fd` implementation.
#[test]
fn assert_eq_from_fd_weak() -> Result<(), TestError> {
    let value = TensorRank1List::<1, 1, 1>::new([[EPSILON * 1.01]]);
    let value_fd = TensorRank1List::<1, 1, 1>::new([[EPSILON * 1.02]]);
    assert_eq_from_fd(&value, &value_fd)
}
138
// The first and last components are swapped, so two places are expected to fail.
#[test]
#[should_panic(expected = "Assertion `left ≈= right` failed in 2 places.")]
fn assert_eq_within_tols_fail() {
    let value_1 = TensorRank1::<3, 1>::new([1.0, 2.0, 3.0]);
    let value_2 = TensorRank1::<3, 1>::new([3.0, 2.0, 1.0]);
    assert_eq_within_tols(&value_1, &value_2).unwrap()
}