conspire/math/tensor/rank_1/list/
mod.rs

1#[cfg(test)]
2mod test;
3
4use crate::math::{Tensor, TensorRank0, TensorRank1, TensorRank2, tensor::list::TensorList};
5use std::ops::Mul;
6
7#[cfg(test)]
8use crate::math::tensor::test::ErrorTensor;
9
10pub type TensorRank1List<const D: usize, const I: usize, const N: usize> =
11    TensorList<TensorRank1<D, I>, N>;
12
13impl<const D: usize, const I: usize, const N: usize> From<[[TensorRank0; D]; N]>
14    for TensorRank1List<D, I, N>
15{
16    fn from(array: [[TensorRank0; D]; N]) -> Self {
17        array.into_iter().map(|entry| entry.into()).collect()
18    }
19}
20
21impl<const D: usize, const N: usize> From<TensorRank1List<D, 9, N>> for TensorRank1List<D, 0, N> {
22    fn from(tensor_rank_1_list: TensorRank1List<D, 9, N>) -> Self {
23        tensor_rank_1_list
24            .into_iter()
25            .map(|entry| entry.into())
26            .collect()
27    }
28}
29
30impl<const D: usize, const N: usize> From<TensorRank1List<D, 0, N>> for TensorRank1List<D, 1, N> {
31    fn from(tensor_rank_1_list: TensorRank1List<D, 0, N>) -> Self {
32        tensor_rank_1_list
33            .into_iter()
34            .map(|entry| entry.into())
35            .collect()
36    }
37}
38
39impl<const D: usize, const I: usize, const J: usize, const W: usize> Mul<TensorRank1List<D, J, W>>
40    for TensorRank1List<D, I, W>
41{
42    type Output = TensorRank2<D, I, J>;
43    fn mul(self, tensor_rank_1_list: TensorRank1List<D, J, W>) -> Self::Output {
44        self.into_iter()
45            .zip(tensor_rank_1_list)
46            .map(|(self_entry, entry)| (self_entry, entry).into())
47            .sum()
48    }
49}
50
51impl<const D: usize, const I: usize, const J: usize, const W: usize> Mul<&TensorRank1List<D, J, W>>
52    for TensorRank1List<D, I, W>
53{
54    type Output = TensorRank2<D, I, J>;
55    fn mul(self, tensor_rank_1_list: &TensorRank1List<D, J, W>) -> Self::Output {
56        self.into_iter()
57            .zip(tensor_rank_1_list.iter())
58            .map(|(self_entry, entry)| (self_entry, entry).into())
59            .sum()
60    }
61}
62
63impl<const D: usize, const I: usize, const J: usize, const W: usize> Mul<TensorRank1List<D, J, W>>
64    for &TensorRank1List<D, I, W>
65{
66    type Output = TensorRank2<D, I, J>;
67    fn mul(self, tensor_rank_1_list: TensorRank1List<D, J, W>) -> Self::Output {
68        self.iter()
69            .zip(tensor_rank_1_list)
70            .map(|(self_entry, entry)| (self_entry, entry).into())
71            .sum()
72    }
73}
74
75impl<const D: usize, const I: usize, const J: usize, const W: usize> Mul<&TensorRank1List<D, J, W>>
76    for &TensorRank1List<D, I, W>
77{
78    type Output = TensorRank2<D, I, J>;
79    fn mul(self, tensor_rank_1_list: &TensorRank1List<D, J, W>) -> Self::Output {
80        self.iter()
81            .zip(tensor_rank_1_list.iter())
82            .map(|(self_entry, entry)| (self_entry, entry).into())
83            .sum()
84    }
85}
86
#[cfg(test)]
impl<const D: usize, const I: usize, const W: usize> ErrorTensor for TensorRank1List<D, I, W> {
    /// Compares `self` component-by-component against `comparator` with tolerance `epsilon`.
    ///
    /// A component pair `(a, b)` is counted as an error when both of these hold:
    /// - its relative error `|a / b - 1|` is at least `epsilon`;
    /// - at least one of `|a|`, `|b|` is at least `epsilon`, so pairs that are both
    ///   negligibly small are ignored.
    ///
    /// If `b == 0` and `a` is non-negligible, the relative error is infinite and the pair is
    /// counted. If `a == b == 0`, the relative error is NaN, which never compares `>=`, so the
    /// pair is not counted.
    ///
    /// Returns `None` when no component is in error. Otherwise returns
    /// `Some((auxiliary, error_count))`:
    /// - `error_count` is the number of erroneous components;
    /// - `auxiliary` is `true` if at least one of those components also has an absolute
    ///   difference `|a - b|` of at least `epsilon`.
    fn error_fd(&self, comparator: &Self, epsilon: TensorRank0) -> Option<(bool, usize)> {
        // One traversal produces both results. This way the error count and the auxiliary
        // flag share a single definition of an "erroneous component" and cannot drift apart.
        let (error_count, auxiliary) = self
            .iter()
            .zip(comparator.iter())
            .flat_map(|(entry, comparator_entry)| entry.iter().zip(comparator_entry.iter()))
            .filter(|&(&entry_i, &comparator_entry_i)| {
                (entry_i / comparator_entry_i - 1.0).abs() >= epsilon
                    && (entry_i.abs() >= epsilon || comparator_entry_i.abs() >= epsilon)
            })
            .fold(
                (0, false),
                |(count, auxiliary), (&entry_i, &comparator_entry_i)| {
                    (
                        count + 1,
                        auxiliary || (entry_i - comparator_entry_i).abs() >= epsilon,
                    )
                },
            );
        if error_count > 0 {
            Some((auxiliary, error_count))
        } else {
            None
        }
    }
}
126}