conspire/math/tensor/rank_2/list_2d/
mod.rs1#[cfg(test)]
2mod test;
3
4use crate::math::{Tensor, TensorRank0, TensorRank2, TensorRank2List, tensor::list::TensorList};
5use std::ops::Mul;
6
7#[cfg(test)]
8use crate::math::tensor::test::ErrorTensor;
9
10pub type TensorRank2List2D<
11 const D: usize,
12 const I: usize,
13 const J: usize,
14 const M: usize,
15 const N: usize,
16> = TensorList<TensorRank2List<D, I, J, M>, N>;
17
18impl<const D: usize, const I: usize, const J: usize, const M: usize, const N: usize>
19 From<[[[[TensorRank0; D]; D]; M]; N]> for TensorRank2List2D<D, I, J, M, N>
20{
21 fn from(array: [[[[TensorRank0; D]; D]; M]; N]) -> Self {
22 array.into_iter().map(|entry| entry.into()).collect()
23 }
24}
25
26impl<const D: usize, const I: usize, const J: usize, const K: usize, const W: usize, const X: usize>
27 Mul<TensorRank2<D, J, K>> for TensorRank2List2D<D, I, J, W, X>
28{
29 type Output = TensorRank2List2D<D, I, K, W, X>;
30 fn mul(self, tensor_rank_2: TensorRank2<D, J, K>) -> Self::Output {
31 self.iter()
32 .map(|self_entry| {
33 self_entry
34 .iter()
35 .map(|self_tensor_rank_2| self_tensor_rank_2 * &tensor_rank_2)
36 .collect()
37 })
38 .collect()
39 }
40}
41
42impl<const D: usize, const I: usize, const J: usize, const K: usize, const W: usize, const X: usize>
43 Mul<&TensorRank2<D, J, K>> for TensorRank2List2D<D, I, J, W, X>
44{
45 type Output = TensorRank2List2D<D, I, K, W, X>;
46 fn mul(self, tensor_rank_2: &TensorRank2<D, J, K>) -> Self::Output {
47 self.iter()
48 .map(|self_entry| {
49 self_entry
50 .iter()
51 .map(|self_tensor_rank_2| self_tensor_rank_2 * tensor_rank_2)
52 .collect()
53 })
54 .collect()
55 }
56}
57
#[cfg(test)]
impl<const D: usize, const I: usize, const J: usize, const W: usize, const X: usize> ErrorTensor
    for TensorRank2List2D<D, I, J, W, X>
{
    /// Test-only, entry-by-entry comparison of `self` against `comparator`.
    ///
    /// A scalar pair `(a, b)` counts as an error when either value is NaN. It also counts
    /// as an error when both of these hold:
    /// - the relative difference `|a / b - 1|` is at least `epsilon`;
    /// - at least one of `|a|` or `|b|` is at least `epsilon`, so pairs of tiny values
    ///   are ignored.
    ///
    /// Returns `None` when no pair is in error. Otherwise returns
    /// `Some((auxiliary, error_count))`:
    /// - `error_count` is the number of erroneous pairs;
    /// - `auxiliary` is `true` if at least one pair also fails the stricter test, which
    ///   additionally requires the absolute difference `|a - b|` to be at least
    ///   `epsilon`. NaN pairs always satisfy the stricter test.
    fn error_fd(&self, comparator: &Self, epsilon: TensorRank0) -> Option<(bool, usize)> {
        // Yields every scalar pair of the two lists in lockstep:
        // outer list -> inner list -> tensor -> row -> scalar.
        // It is a factory so the traversal can be replayed for the auxiliary check.
        let entry_pairs = move || {
            self.iter()
                .zip(comparator.iter())
                .flat_map(|(self_a, comparator_a)| self_a.iter().zip(comparator_a.iter()))
                .flat_map(|(self_ab, comparator_ab)| self_ab.iter().zip(comparator_ab.iter()))
                .flat_map(|(self_ab_i, comparator_ab_i)| {
                    self_ab_i.iter().zip(comparator_ab_i.iter())
                })
        };
        // At least one of the two values is non-negligible in magnitude.
        let significant = |a: TensorRank0, b: TensorRank0| -> bool {
            a.abs() >= epsilon || b.abs() >= epsilon
        };
        let error_count = entry_pairs()
            .filter(|&(&a, &b)| {
                ((a / b - 1.0).abs() >= epsilon && significant(a, b)) || a.is_nan() || b.is_nan()
            })
            .count();
        if error_count > 0 {
            // Only existence matters here, so stop at the first pair that fails.
            let auxiliary = entry_pairs().any(|(&a, &b)| {
                ((a / b - 1.0).abs() >= epsilon
                    && (a - b).abs() >= epsilon
                    && significant(a, b))
                    || a.is_nan()
                    || b.is_nan()
            });
            Some((auxiliary, error_count))
        } else {
            None
        }
    }
}