conspire/math/tensor/rank_2/vec_2d/
mod.rs

1use crate::math::{
2    Hessian, SquareMatrix, Tensor, TensorRank0, TensorRank2, TensorRank2Vec,
3    tensor::vec::TensorVector,
4};
5use std::ops::Mul;
6
7#[cfg(test)]
8use crate::math::tensor::test::ErrorTensor;
9
/// A [`TensorVector`] whose elements are [`TensorRank2Vec<D, I, J>`] values.
///
/// The code in this module walks it as a two-level collection: an outer
/// sequence of `TensorRank2Vec` entries, each of which is itself iterated to
/// reach the individual rank-2 tensors.
10pub type TensorRank2Vec2D<const D: usize, const I: usize, const J: usize> =
11    TensorVector<TensorRank2Vec<D, I, J>>;
12
13impl<const D: usize, const I: usize, const J: usize> TensorRank2Vec2D<D, I, J> {
14    pub fn zero(len: usize) -> Self {
15        (0..len).map(|_| TensorRank2Vec::zero(len)).collect()
16    }
17}
18
19impl<const D: usize, const I: usize, const J: usize> From<TensorRank2Vec2D<D, I, J>>
20    for Vec<TensorRank0>
21{
22    fn from(tensor_rank_2_vec_2d: TensorRank2Vec2D<D, I, J>) -> Self {
23        tensor_rank_2_vec_2d
24            .into_iter()
25            .flat_map(|tensor_rank_2_vec_1d| {
26                tensor_rank_2_vec_1d.into_iter().flat_map(|tensor_rank_2| {
27                    tensor_rank_2
28                        .into_iter()
29                        .flat_map(|tensor_rank_1| tensor_rank_1.into_iter())
30                })
31            })
32            .collect()
33    }
34}
35
36impl<const D: usize, const I: usize, const J: usize> Hessian for TensorRank2Vec2D<D, I, J> {
37    fn fill_into(self, square_matrix: &mut SquareMatrix) {
38        self.into_iter().enumerate().for_each(|(a, entry_a)| {
39            entry_a.into_iter().enumerate().for_each(|(b, entry_ab)| {
40                entry_ab
41                    .into_iter()
42                    .enumerate()
43                    .for_each(|(i, entry_ab_i)| {
44                        entry_ab_i
45                            .into_iter()
46                            .enumerate()
47                            .for_each(|(j, entry_ab_ij)| {
48                                square_matrix[D * a + i][D * b + j] = entry_ab_ij
49                            })
50                    })
51            })
52        });
53    }
54    fn retain_from(self, retained: &[bool]) -> SquareMatrix {
55        SquareMatrix::from(self)
56            .into_iter()
57            .zip(retained.iter())
58            .filter(|(_, retained)| **retained)
59            .map(|(self_i, _)| {
60                self_i
61                    .into_iter()
62                    .zip(retained.iter())
63                    .filter(|(_, retained)| **retained)
64                    .map(|(self_ij, _)| self_ij)
65                    .collect()
66            })
67            .collect()
68    }
69}
70
71impl<const D: usize, const I: usize, const J: usize, const K: usize> Mul<TensorRank2<D, J, K>>
72    for TensorRank2Vec2D<D, I, J>
73{
74    type Output = TensorRank2Vec2D<D, I, K>;
75    fn mul(self, tensor_rank_2: TensorRank2<D, J, K>) -> Self::Output {
76        self.iter()
77            .map(|self_entry| {
78                self_entry
79                    .iter()
80                    .map(|self_tensor_rank_2| self_tensor_rank_2 * &tensor_rank_2)
81                    .collect()
82            })
83            .collect()
84    }
85}
86
87impl<const D: usize, const I: usize, const J: usize, const K: usize> Mul<&TensorRank2<D, J, K>>
88    for TensorRank2Vec2D<D, I, J>
89{
90    type Output = TensorRank2Vec2D<D, I, K>;
91    fn mul(self, tensor_rank_2: &TensorRank2<D, J, K>) -> Self::Output {
92        self.iter()
93            .map(|self_entry| {
94                self_entry
95                    .iter()
96                    .map(|self_tensor_rank_2| self_tensor_rank_2 * tensor_rank_2)
97                    .collect()
98            })
99            .collect()
100    }
101}
102
#[cfg(test)]
impl<const D: usize, const I: usize, const J: usize> ErrorTensor for TensorRank2Vec2D<D, I, J> {
    /// Compares `self` against `comparator` entry by entry.
    ///
    /// An entry counts as an error when its relative deviation
    /// `|self / comparator - 1|` is at least `epsilon` and at least one of the
    /// two magnitudes is at least `epsilon`.
    ///
    /// Returns `None` when no entry counts as an error. Otherwise returns
    /// `Some((auxiliary, error_count))`, where `error_count` is the number of
    /// such entries and `auxiliary` is `true` when at least one of them also
    /// has an absolute difference of at least `epsilon`.
    fn error_fd(&self, comparator: &Self, epsilon: TensorRank0) -> Option<(bool, usize)> {
        /// Counts the positionally paired scalar entries for which `flagged`
        /// returns `true`.
        fn count_flagged<const D: usize, const I: usize, const J: usize>(
            tensor: &TensorRank2Vec2D<D, I, J>,
            comparator: &TensorRank2Vec2D<D, I, J>,
            flagged: impl Fn(TensorRank0, TensorRank0) -> bool,
        ) -> usize {
            tensor
                .iter()
                .zip(comparator.iter())
                .map(|(tensor_a, comparator_a)| {
                    tensor_a
                        .iter()
                        .zip(comparator_a.iter())
                        .map(|(tensor_ab, comparator_ab)| {
                            tensor_ab
                                .iter()
                                .zip(comparator_ab.iter())
                                .map(|(tensor_ab_i, comparator_ab_i)| {
                                    tensor_ab_i
                                        .iter()
                                        .zip(comparator_ab_i.iter())
                                        .filter(|&(&value, &reference)| flagged(value, reference))
                                        .count()
                                })
                                .sum::<usize>()
                        })
                        .sum::<usize>()
                })
                .sum()
        }

        // Relative deviation of `value` from `reference` reaches the tolerance.
        let relative = |value: TensorRank0, reference: TensorRank0| {
            (value / reference - 1.0).abs() >= epsilon
        };
        // At least one of the two magnitudes reaches the tolerance.
        let significant = |value: TensorRank0, reference: TensorRank0| {
            value.abs() >= epsilon || reference.abs() >= epsilon
        };

        let error_count = count_flagged(self, comparator, |value, reference| {
            relative(value, reference) && significant(value, reference)
        });
        if error_count == 0 {
            return None;
        }

        // Second pass: the same test plus an absolute-difference requirement.
        let auxiliary = count_flagged(self, comparator, |value, reference| {
            relative(value, reference)
                && (value - reference).abs() >= epsilon
                && significant(value, reference)
        }) > 0;
        Some((auxiliary, error_count))
    }
}