conspire/math/tensor/rank_0/list/mod.rs

1pub mod vec;
2
3#[cfg(test)]
4mod test;
5
6use crate::math::{Tensor, TensorRank0, tensor::list::TensorList};
7use std::ops::Mul;
8
9#[cfg(test)]
10use crate::math::tensor::test::ErrorTensor;
11
12pub type TensorRank0List<const N: usize> = TensorList<TensorRank0, N>;
13
14impl<const N: usize> Mul for TensorRank0List<N> {
15    type Output = TensorRank0;
16    fn mul(self, tensor_rank_0_list: Self) -> Self::Output {
17        self.iter()
18            .zip(tensor_rank_0_list.iter())
19            .map(|(self_entry, entry)| self_entry * entry)
20            .sum()
21    }
22}
23
24impl<const N: usize> Mul<&Self> for TensorRank0List<N> {
25    type Output = TensorRank0;
26    fn mul(self, tensor_rank_0_list: &Self) -> Self::Output {
27        self.iter()
28            .zip(tensor_rank_0_list.iter())
29            .map(|(self_entry, entry)| self_entry * entry)
30            .sum()
31    }
32}
33
34impl<const N: usize> Mul<TensorRank0List<N>> for &TensorRank0List<N> {
35    type Output = TensorRank0;
36    fn mul(self, tensor_rank_0_list: TensorRank0List<N>) -> Self::Output {
37        self.iter()
38            .zip(tensor_rank_0_list.iter())
39            .map(|(self_entry, entry)| self_entry * entry)
40            .sum()
41    }
42}
43
44impl<const N: usize> Mul for &TensorRank0List<N> {
45    type Output = TensorRank0;
46    fn mul(self, tensor_rank_0_list: Self) -> Self::Output {
47        self.iter()
48            .zip(tensor_rank_0_list.iter())
49            .map(|(self_entry, entry)| self_entry * entry)
50            .sum()
51    }
52}
53
#[cfg(test)]
impl<const N: usize> ErrorTensor for TensorRank0List<N> {
    /// Counts entries of `self` that disagree with `comparator` beyond `epsilon`.
    ///
    /// An entry pair `(actual, expected)` is a mismatch when the relative difference
    /// `|actual / expected - 1|` is at least `epsilon` AND at least one of the two
    /// magnitudes is at least `epsilon` (so pairs of near-zero values are ignored).
    /// A `0 / 0` (NaN) ratio never counts, since comparing NaN with `epsilon` is false.
    ///
    /// Returns `Some((true, mismatch_count))` if any entry mismatches, otherwise `None`.
    fn error_fd(&self, comparator: &Self, epsilon: TensorRank0) -> Option<(bool, usize)> {
        let mismatches = self
            .iter()
            .zip(comparator.iter())
            .map(|(&actual, &expected)| {
                let outside_tolerance = (actual / expected - 1.0).abs() >= epsilon;
                outside_tolerance && (actual.abs() >= epsilon || expected.abs() >= epsilon)
            })
            .filter(|&is_mismatch| is_mismatch)
            .count();
        (mismatches > 0).then_some((true, mismatches))
    }
}