1use super::{Scalar, Tensor, TensorError};
2use crate::{ABS_TOL, REL_TOL, defeat_message};
3use std::{
4 cmp::PartialEq,
5 fmt::{self, Debug, Display, Formatter},
6};
7
8#[cfg(test)]
9use crate::EPSILON;
10
11#[cfg(test)]
12use super::{
13 TensorArray,
14 rank_1::{TensorRank1, list::TensorRank1List},
15};
16
/// Comparison of a value against a finite-difference approximation of it.
///
/// Consumed by `assert_eq_from_fd`, which reads the result of `error_fd` as follows:
/// - `None`: the two tensors agree and the assertion passes.
/// - `Some((true, count))`: a hard failure in `count` places.
/// - `Some((false, count))`: a "weak" agreement in `count` places, which only prints a warning.
///
/// `epsilon` is the tolerance supplied by the caller. How implementors combine it with each
/// entry is defined in their impls, not here (NOTE(review): confirm per implementor).
#[cfg(test)]
pub trait ErrorTensor {
    fn error_fd(&self, comparator: &Self, epsilon: Scalar) -> Option<(bool, usize)>;
}
21
22pub fn assert_eq<'a, T>(value_1: &'a T, value_2: &'a T) -> Result<(), TestError>
23where
24 T: Display + PartialEq,
25{
26 if value_1 == value_2 {
27 Ok(())
28 } else {
29 Err(TestError {
30 message: format!(
31 "\n\x1b[1;91mAssertion `left == right` failed.\n\x1b[0;91m left: {value_1}\n right: {value_2}\x1b[0m"
32 ),
33 })
34 }
35}
36
/// Asserts that `value` agrees with its finite-difference counterpart `value_fd`.
///
/// Delegates to [`ErrorTensor::error_fd`] with a tolerance of `3.0 * EPSILON` and interprets
/// the outcome:
/// - `None`: the assertion passes silently.
/// - `Some((false, count))`: the agreement is "weak". A yellow warning naming `count` is
///   printed to stdout, and the assertion still passes.
/// - `Some((true, count))`: the assertion fails. The error message includes both tensors and
///   their `sub_abs` / `sub_rel` differences.
///
/// # Errors
///
/// Returns a [`TestError`] when `error_fd` reports a hard failure.
#[cfg(test)]
pub fn assert_eq_from_fd<'a, T>(value: &'a T, value_fd: &'a T) -> Result<(), TestError>
where
    T: Display + ErrorTensor + Tensor,
{
    match value.error_fd(value_fd, 3.0 * EPSILON) {
        None => Ok(()),
        Some((false, count)) => {
            println!(
                "Warning: \n\x1b[1;93mAssertion `left ≈= right` was weak in {count} places.\x1b[0m"
            );
            Ok(())
        }
        Some((true, count)) => {
            // The difference tensors are only needed to build the failure report.
            let abs = value.sub_abs(value_fd);
            let rel = value.sub_rel(value_fd);
            Err(TestError {
                message: format!(
                    "\n\x1b[1;91mAssertion `left ≈= right` failed in {count} places.\n\x1b[0;91m left: {value}\n right: {value_fd}\n abs: {abs}\n rel: {rel}\x1b[0m"
                ),
            })
        }
    }
}
61
62pub fn assert_eq_within<'a, T>(
63 value_1: &'a T,
64 value_2: &'a T,
65 tol_abs: Scalar,
66 tol_rel: Scalar,
67) -> Result<(), TestError>
68where
69 T: Display + Tensor,
70{
71 if let Some(count) = value_1.error_count(value_2, tol_abs, tol_rel) {
72 let abs = value_1.sub_abs(value_2);
73 let rel = value_1.sub_rel(value_2);
74 Err(TestError {
75 message: format!(
76 "\n\x1b[1;91mAssertion `left ≈= right` failed in {count} places.\n\x1b[0;91m left: {value_1}\n right: {value_2}\n abs: {abs}\n rel: {rel}\x1b[0m"
77 ),
78 })
79 } else {
80 Ok(())
81 }
82}
83
84pub fn assert_eq_within_tols<'a, T>(value_1: &'a T, value_2: &'a T) -> Result<(), TestError>
85where
86 T: Display + Tensor,
87{
88 assert_eq_within(value_1, value_2, ABS_TOL, REL_TOL)
89}
90
/// Error returned by the assertion helpers in this module.
///
/// It can also be built from a `String`, a `&str` or a `TensorError`, so tests returning
/// `Result<(), TestError>` can use `?` on those sources.
pub struct TestError {
    /// Human-readable description of the failed assertion or converted error.
    pub message: String,
}
94
95impl Debug for TestError {
96 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
97 write!(
98 f,
99 "{}\n\x1b[0;2;31m{}\x1b[0m\n",
100 self.message,
101 defeat_message()
102 )
103 }
104}
105
106impl From<String> for TestError {
107 fn from(error: String) -> Self {
108 Self { message: error }
109 }
110}
111
112impl From<&str> for TestError {
113 fn from(error: &str) -> Self {
114 Self {
115 message: error.to_string(),
116 }
117 }
118}
119
120impl From<TensorError> for TestError {
121 fn from(error: TensorError) -> Self {
122 Self {
123 message: error.to_string(),
124 }
125 }
126}
127
/// Converting an owned `String` keeps the text verbatim as the message.
#[test]
fn test_error_from_string() {
    let error = TestError::from(String::from("An error occurred"));
    assert_eq!(error.message, "An error occurred");
}
135
/// Converting a string slice (here via `Into`) keeps the text verbatim as the message.
#[test]
fn test_error_from_str() {
    let error: TestError = "An error occurred".into();
    assert_eq!(error.message, "An error occurred");
}
143
/// Converting a `TensorError` must carry the error's string form over as the message.
///
/// The previous version only performed the conversion and asserted nothing about the result.
#[test]
fn test_error_from_tensor_error() {
    let tensor_error = TensorError::NotPositiveDefinite;
    // Also exercises the `Debug` implementation of `TensorError`.
    let _ = format!("{tensor_error:?}");
    let expected = tensor_error.to_string();
    assert_eq!(TestError::from(tensor_error).message, expected);
}
150
/// Unequal scalars must make `assert_eq` return an error, which `unwrap` turns into a panic.
#[test]
#[should_panic(expected = "Assertion `left == right` failed.")]
fn assert_eq_fail() {
    let (left, right) = (0.0, 1.0);
    assert_eq(&left, &right).unwrap()
}
156
/// Reversed entries differ in two of three places and must fail the finite-difference check.
#[test]
#[should_panic(expected = "Assertion `left ≈= right` failed in 2 places.")]
fn assert_eq_from_fd_fail() {
    let value = TensorRank1::<_, 1>::new([1.0, 2.0, 3.0]);
    let value_fd = TensorRank1::<_, 1>::new([3.0, 2.0, 1.0]);
    assert_eq_from_fd(&value, &value_fd).unwrap()
}
166
/// Identical tensors must pass the finite-difference check.
#[test]
fn assert_eq_from_fd_success() -> Result<(), TestError> {
    let value = TensorRank1::<_, 1>::new([1.0, 2.0, 3.0]);
    let value_fd = TensorRank1::<_, 1>::new([1.0, 2.0, 3.0]);
    assert_eq_from_fd(&value, &value_fd)
}
174
/// Entries on the order of `EPSILON` exercise the "weak" warning path, which still passes.
#[test]
fn assert_eq_from_fd_weak() -> Result<(), TestError> {
    let value = TensorRank1List::<_, 1, 1>::new([[EPSILON * 1.01]]);
    let value_fd = TensorRank1List::<_, 1, 1>::new([[EPSILON * 1.02]]);
    assert_eq_from_fd(&value, &value_fd)
}
182
/// Reversed entries differ in two of three places and must exceed the default tolerances.
#[test]
#[should_panic(expected = "Assertion `left ≈= right` failed in 2 places.")]
fn assert_eq_within_tols_fail() {
    let left = TensorRank1::<_, 1>::new([1.0, 2.0, 3.0]);
    let right = TensorRank1::<_, 1>::new([3.0, 2.0, 1.0]);
    assert_eq_within_tols(&left, &right).unwrap()
}