// conspire/math/tensor/rank_2/vec_2d/mod.rs

1use crate::math::{
2    Hessian, SquareMatrix, Tensor, TensorRank2, TensorRank2Vec, tensor::vec::TensorVector,
3};
4use std::ops::Mul;
5
6#[cfg(test)]
7use crate::math::{TensorRank0, tensor::test::ErrorTensor};
8
9pub type TensorRank2Vec2D<const D: usize, const I: usize, const J: usize> =
10    TensorVector<TensorRank2Vec<D, I, J>>;
11
12impl<const D: usize, const I: usize, const J: usize> TensorRank2Vec2D<D, I, J> {
13    pub fn zero(len: usize) -> Self {
14        (0..len).map(|_| TensorRank2Vec::zero(len)).collect()
15    }
16}
17
18impl<const D: usize, const I: usize, const J: usize> Hessian for TensorRank2Vec2D<D, I, J> {
19    fn fill_into(self, square_matrix: &mut SquareMatrix) {
20        self.into_iter().enumerate().for_each(|(a, entry_a)| {
21            entry_a.into_iter().enumerate().for_each(|(b, entry_ab)| {
22                entry_ab
23                    .into_iter()
24                    .enumerate()
25                    .for_each(|(i, entry_ab_i)| {
26                        entry_ab_i
27                            .into_iter()
28                            .enumerate()
29                            .for_each(|(j, entry_ab_ij)| {
30                                square_matrix[D * a + i][D * b + j] = entry_ab_ij
31                            })
32                    })
33            })
34        });
35    }
36    fn retain_from(self, retained: &[bool]) -> SquareMatrix {
37        SquareMatrix::from(self)
38            .into_iter()
39            .zip(retained.iter())
40            .filter(|(_, retained)| **retained)
41            .map(|(self_i, _)| {
42                self_i
43                    .into_iter()
44                    .zip(retained.iter())
45                    .filter(|(_, retained)| **retained)
46                    .map(|(self_ij, _)| self_ij)
47                    .collect()
48            })
49            .collect()
50    }
51}
52
53impl<const D: usize, const I: usize, const J: usize, const K: usize> Mul<TensorRank2<D, J, K>>
54    for TensorRank2Vec2D<D, I, J>
55{
56    type Output = TensorRank2Vec2D<D, I, K>;
57    fn mul(self, tensor_rank_2: TensorRank2<D, J, K>) -> Self::Output {
58        self.iter()
59            .map(|self_entry| {
60                self_entry
61                    .iter()
62                    .map(|self_tensor_rank_2| self_tensor_rank_2 * &tensor_rank_2)
63                    .collect()
64            })
65            .collect()
66    }
67}
68
69impl<const D: usize, const I: usize, const J: usize, const K: usize> Mul<&TensorRank2<D, J, K>>
70    for TensorRank2Vec2D<D, I, J>
71{
72    type Output = TensorRank2Vec2D<D, I, K>;
73    fn mul(self, tensor_rank_2: &TensorRank2<D, J, K>) -> Self::Output {
74        self.iter()
75            .map(|self_entry| {
76                self_entry
77                    .iter()
78                    .map(|self_tensor_rank_2| self_tensor_rank_2 * tensor_rank_2)
79                    .collect()
80            })
81            .collect()
82    }
83}
84
#[cfg(test)]
impl<const D: usize, const I: usize, const J: usize> ErrorTensor for TensorRank2Vec2D<D, I, J> {
    /// Compares `self` against `comparator` one scalar component at a time.
    ///
    /// A component pair `(x, y)` counts as an error when both of these hold:
    /// - the relative difference `|x / y - 1|` is at least `epsilon`;
    /// - at least one of `|x|` and `|y|` is at least `epsilon`.
    ///
    /// Returns `None` when no pair is in error. Otherwise returns `Some((auxiliary, error_count))`.
    /// Here `auxiliary` is `true` when some erroneous pair also has an absolute difference
    /// `|x - y|` of at least `epsilon`.
    fn error_fd(&self, comparator: &Self, epsilon: &TensorRank0) -> Option<(bool, usize)> {
        // Yields every scalar pair across the four nesting levels (a, b, i, j).
        // Each level is zipped, so mismatched lengths are truncated to the shorter one.
        let component_pairs = move || {
            self.iter()
                .zip(comparator.iter())
                .flat_map(|(lhs_a, rhs_a)| lhs_a.iter().zip(rhs_a.iter()))
                .flat_map(|(lhs_ab, rhs_ab)| lhs_ab.iter().zip(rhs_ab.iter()))
                .flat_map(|(lhs_abi, rhs_abi)| lhs_abi.iter().zip(rhs_abi.iter()))
                .map(|(&lhs, &rhs)| (lhs, rhs))
        };
        // The basic error criterion shared by both passes below.
        let relative_and_nonnegligible = |x: TensorRank0, y: TensorRank0| {
            &(x / y - 1.0).abs() >= epsilon && (&x.abs() >= epsilon || &y.abs() >= epsilon)
        };
        let error_count = component_pairs()
            .filter(|&(x, y)| relative_and_nonnegligible(x, y))
            .count();
        if error_count == 0 {
            return None;
        }
        let auxiliary = component_pairs()
            .any(|(x, y)| relative_and_nonnegligible(x, y) && &(x - y).abs() >= epsilon);
        Some((auxiliary, error_count))
    }
}