conspire/math/tensor/rank_2/list_2d/
mod.rs1#[cfg(test)]
2mod test;
3
4use crate::math::{Tensor, TensorRank2, TensorRank2List, tensor::list::TensorList};
5use std::ops::Mul;
6
7#[cfg(test)]
8use crate::math::{TensorRank0, tensor::test::ErrorTensor};
9
10pub type TensorRank2List2D<
11 const D: usize,
12 const I: usize,
13 const J: usize,
14 const W: usize,
15 const X: usize,
16> = TensorList<TensorRank2List<D, I, J, W>, X>;
17
18impl<const D: usize, const I: usize, const J: usize, const K: usize, const W: usize, const X: usize>
19 Mul<TensorRank2<D, J, K>> for TensorRank2List2D<D, I, J, W, X>
20{
21 type Output = TensorRank2List2D<D, I, K, W, X>;
22 fn mul(self, tensor_rank_2: TensorRank2<D, J, K>) -> Self::Output {
23 self.iter()
24 .map(|self_entry| {
25 self_entry
26 .iter()
27 .map(|self_tensor_rank_2| self_tensor_rank_2 * &tensor_rank_2)
28 .collect()
29 })
30 .collect()
31 }
32}
33
34impl<const D: usize, const I: usize, const J: usize, const K: usize, const W: usize, const X: usize>
35 Mul<&TensorRank2<D, J, K>> for TensorRank2List2D<D, I, J, W, X>
36{
37 type Output = TensorRank2List2D<D, I, K, W, X>;
38 fn mul(self, tensor_rank_2: &TensorRank2<D, J, K>) -> Self::Output {
39 self.iter()
40 .map(|self_entry| {
41 self_entry
42 .iter()
43 .map(|self_tensor_rank_2| self_tensor_rank_2 * tensor_rank_2)
44 .collect()
45 })
46 .collect()
47 }
48}
49
#[cfg(test)]
impl<const D: usize, const I: usize, const J: usize, const W: usize, const X: usize> ErrorTensor
    for TensorRank2List2D<D, I, J, W, X>
{
    /// Compares `self` against `comparator` scalar component by scalar component.
    ///
    /// A component pair `(value, reference)` is counted as an error when:
    /// - its relative error `|value / reference - 1|` is at least `epsilon`; and
    /// - at least one of `|value|`, `|reference|` is at least `epsilon`.
    ///
    /// Returns `None` when no component is counted. Otherwise returns
    /// `Some((auxiliary, error_count))`, where `auxiliary` is `true` if any counted pair
    /// also has absolute difference `|value - reference|` of at least `epsilon`.
    fn error_fd(&self, comparator: &Self, epsilon: TensorRank0) -> Option<(bool, usize)> {
        // Walk both operands in lockstep down all four nesting levels. The result yields
        // pairs of references to matching scalar components. This is a closure so the
        // traversal can be restarted for the second (auxiliary) pass.
        let component_pairs = move || {
            self.iter()
                .zip(comparator.iter())
                .flat_map(|(lhs_a, rhs_a)| lhs_a.iter().zip(rhs_a.iter()))
                .flat_map(|(lhs_ab, rhs_ab)| lhs_ab.iter().zip(rhs_ab.iter()))
                .flat_map(|(lhs_ab_i, rhs_ab_i)| lhs_ab_i.iter().zip(rhs_ab_i.iter()))
        };
        // Shared error criterion. Pairs where both magnitudes are below `epsilon` are
        // ignored, so near-zero noise does not count.
        let fails_relative = move |value: TensorRank0, reference: TensorRank0| {
            (value / reference - 1.0).abs() >= epsilon
                && (value.abs() >= epsilon || reference.abs() >= epsilon)
        };
        let error_count = component_pairs()
            .filter(|&(&value, &reference)| fails_relative(value, reference))
            .count();
        if error_count == 0 {
            return None;
        }
        let auxiliary = component_pairs().any(|(&value, &reference)| {
            fails_relative(value, reference) && (value - reference).abs() >= epsilon
        });
        Some((auxiliary, error_count))
    }
}