conspire/math/tensor/rank_2/list_2d/
mod.rs

1#[cfg(test)]
2mod test;
3
4use crate::math::{Tensor, TensorRank0, TensorRank2, TensorRank2List, tensor::list::TensorList};
5use std::ops::Mul;
6
7#[cfg(test)]
8use crate::math::tensor::test::ErrorTensor;
9
/// Alias for `TensorList<TensorRank2List<D, I, J, M>, N>`: a `TensorList`
/// parameterized by `N` whose element type is `TensorRank2List<D, I, J, M>`.
///
/// NOTE(review): `D` is presumably the spatial dimension and `I`, `J` the
/// configuration indices of the underlying rank-2 tensors -- confirm against
/// the `TensorRank2` definition.
pub type TensorRank2List2D<
    const D: usize,
    const I: usize,
    const J: usize,
    const M: usize,
    const N: usize,
> = TensorList<TensorRank2List<D, I, J, M>, N>;
17
18impl<const D: usize, const I: usize, const J: usize, const M: usize, const N: usize>
19    From<[[[[TensorRank0; D]; D]; M]; N]> for TensorRank2List2D<D, I, J, M, N>
20{
21    fn from(array: [[[[TensorRank0; D]; D]; M]; N]) -> Self {
22        array.into_iter().map(|entry| entry.into()).collect()
23    }
24}
25
26impl<const D: usize, const I: usize, const J: usize, const K: usize, const W: usize, const X: usize>
27    Mul<TensorRank2<D, J, K>> for TensorRank2List2D<D, I, J, W, X>
28{
29    type Output = TensorRank2List2D<D, I, K, W, X>;
30    fn mul(self, tensor_rank_2: TensorRank2<D, J, K>) -> Self::Output {
31        self.iter()
32            .map(|self_entry| {
33                self_entry
34                    .iter()
35                    .map(|self_tensor_rank_2| self_tensor_rank_2 * &tensor_rank_2)
36                    .collect()
37            })
38            .collect()
39    }
40}
41
42impl<const D: usize, const I: usize, const J: usize, const K: usize, const W: usize, const X: usize>
43    Mul<&TensorRank2<D, J, K>> for TensorRank2List2D<D, I, J, W, X>
44{
45    type Output = TensorRank2List2D<D, I, K, W, X>;
46    fn mul(self, tensor_rank_2: &TensorRank2<D, J, K>) -> Self::Output {
47        self.iter()
48            .map(|self_entry| {
49                self_entry
50                    .iter()
51                    .map(|self_tensor_rank_2| self_tensor_rank_2 * tensor_rank_2)
52                    .collect()
53            })
54            .collect()
55    }
56}
57
#[cfg(test)]
impl<const D: usize, const I: usize, const J: usize, const W: usize, const X: usize> ErrorTensor
    for TensorRank2List2D<D, I, J, W, X>
{
    /// Compares `self` against `comparator` scalar by scalar.
    ///
    /// A scalar pair is counted as an error when either value is NaN, or when
    /// `|self / comparator - 1| >= epsilon` while at least one of the two
    /// magnitudes is `>= epsilon`.
    ///
    /// Returns `None` when no pair is counted. Otherwise returns
    /// `Some((auxiliary, error_count))`, where `error_count` is the number of
    /// counted pairs and `auxiliary` is `true` when at least one pair is NaN
    /// or additionally satisfies `|self - comparator| >= epsilon`.
    fn error_fd(&self, comparator: &Self, epsilon: TensorRank0) -> Option<(bool, usize)> {
        let reaches_epsilon = |quantity: TensorRank0| quantity.abs() >= epsilon;
        let mut error_count = 0usize;
        let mut auxiliary = false;
        for (lhs_list, rhs_list) in self.iter().zip(comparator.iter()) {
            for (lhs_tensor, rhs_tensor) in lhs_list.iter().zip(rhs_list.iter()) {
                for (lhs_row, rhs_row) in lhs_tensor.iter().zip(rhs_tensor.iter()) {
                    for (&value, &reference) in lhs_row.iter().zip(rhs_row.iter()) {
                        let has_nan = value.is_nan() || reference.is_nan();
                        let relative_mismatch = reaches_epsilon(value / reference - 1.0)
                            && (reaches_epsilon(value) || reaches_epsilon(reference));
                        if has_nan || relative_mismatch {
                            error_count += 1;
                        }
                        // Stricter check: the absolute difference must also be
                        // significant. It can only hold for pairs that were
                        // counted above.
                        if has_nan || (relative_mismatch && reaches_epsilon(value - reference)) {
                            auxiliary = true;
                        }
                    }
                }
            }
        }
        if error_count == 0 {
            None
        } else {
            Some((auxiliary, error_count))
        }
    }
}