// conspire/math/tensor/rank_2/list_2d/mod.rs

1#[cfg(test)]
2mod test;
3
4use crate::math::{Tensor, TensorRank2, TensorRank2List, tensor::list::TensorList};
5use std::ops::Mul;
6
7#[cfg(test)]
8use crate::math::{TensorRank0, tensor::test::ErrorTensor};
9
10pub type TensorRank2List2D<
11    const D: usize,
12    const I: usize,
13    const J: usize,
14    const W: usize,
15    const X: usize,
16> = TensorList<TensorRank2List<D, I, J, W>, X>;
17
18impl<const D: usize, const I: usize, const J: usize, const K: usize, const W: usize, const X: usize>
19    Mul<TensorRank2<D, J, K>> for TensorRank2List2D<D, I, J, W, X>
20{
21    type Output = TensorRank2List2D<D, I, K, W, X>;
22    fn mul(self, tensor_rank_2: TensorRank2<D, J, K>) -> Self::Output {
23        self.iter()
24            .map(|self_entry| {
25                self_entry
26                    .iter()
27                    .map(|self_tensor_rank_2| self_tensor_rank_2 * &tensor_rank_2)
28                    .collect()
29            })
30            .collect()
31    }
32}
33
34impl<const D: usize, const I: usize, const J: usize, const K: usize, const W: usize, const X: usize>
35    Mul<&TensorRank2<D, J, K>> for TensorRank2List2D<D, I, J, W, X>
36{
37    type Output = TensorRank2List2D<D, I, K, W, X>;
38    fn mul(self, tensor_rank_2: &TensorRank2<D, J, K>) -> Self::Output {
39        self.iter()
40            .map(|self_entry| {
41                self_entry
42                    .iter()
43                    .map(|self_tensor_rank_2| self_tensor_rank_2 * tensor_rank_2)
44                    .collect()
45            })
46            .collect()
47    }
48}
49
#[cfg(test)]
impl<const D: usize, const I: usize, const J: usize, const W: usize, const X: usize> ErrorTensor
    for TensorRank2List2D<D, I, J, W, X>
{
    /// Entry-wise comparison of `self` against `comparator` with tolerance `epsilon`.
    ///
    /// A pair of scalar entries `(a, b)` counts as an error when the relative
    /// difference `|a / b - 1|` is at least `epsilon` and at least one of `|a|`, `|b|`
    /// is at least `epsilon` (pairs where both magnitudes are below `epsilon` are
    /// ignored).
    ///
    /// Returns `None` when there are no error entries. Otherwise returns
    /// `Some((auxiliary, error_count))`, where `error_count` is the number of error
    /// entries and `auxiliary` is `true` if any error entry also has an absolute
    /// difference `|a - b|` of at least `epsilon`.
    fn error_fd(&self, comparator: &Self, epsilon: &TensorRank0) -> Option<(bool, usize)> {
        let tolerance = *epsilon;
        // Walk all four nesting levels once, pairing scalar entries of `self` and
        // `comparator`. Both counters are accumulated in the same pass, so the error
        // predicate is written only once and the auxiliary test refines it.
        let (error_count, auxiliary_count) = self
            .iter()
            .zip(comparator.iter())
            .flat_map(|(self_a, comparator_a)| self_a.iter().zip(comparator_a.iter()))
            .flat_map(|(self_ab, comparator_ab)| self_ab.iter().zip(comparator_ab.iter()))
            .flat_map(|(self_ab_i, comparator_ab_i)| {
                self_ab_i.iter().zip(comparator_ab_i.iter())
            })
            .fold(
                (0_usize, 0_usize),
                |(errors, auxiliaries), (&self_ab_ij, &comparator_ab_ij)| {
                    let is_error = (self_ab_ij / comparator_ab_ij - 1.0).abs() >= tolerance
                        && (self_ab_ij.abs() >= tolerance
                            || comparator_ab_ij.abs() >= tolerance);
                    if is_error {
                        let is_auxiliary = (self_ab_ij - comparator_ab_ij).abs() >= tolerance;
                        (errors + 1, auxiliaries + usize::from(is_auxiliary))
                    } else {
                        (errors, auxiliaries)
                    }
                },
            );
        if error_count > 0 {
            Some((auxiliary_count > 0, error_count))
        } else {
            None
        }
    }
}