conspire/math/tensor/rank_1/list/
mod.rs1#[cfg(test)]
2mod test;
3
4use crate::math::{Tensor, TensorRank1, TensorRank2, tensor::list::TensorList};
5use std::{mem::transmute, ops::Mul};
6
7#[cfg(test)]
8use crate::math::{TensorRank0, tensor::test::ErrorTensor};
9
10pub type TensorRank1List<const D: usize, const I: usize, const W: usize> =
11    TensorList<TensorRank1<D, I>, W>;
12
13impl From<TensorRank1List<3, 0, 3>> for TensorRank1List<3, 1, 3> {
14    fn from(tensor_rank_1_list: TensorRank1List<3, 0, 3>) -> Self {
15        unsafe {
16            transmute::<TensorRank1List<3, 0, 3>, TensorRank1List<3, 1, 3>>(tensor_rank_1_list)
17        }
18    }
19}
20
21impl From<TensorRank1List<3, 0, 4>> for TensorRank1List<3, 1, 4> {
22    fn from(tensor_rank_1_list: TensorRank1List<3, 0, 4>) -> Self {
23        unsafe { transmute::<TensorRank1List<3, 0, 4>, Self>(tensor_rank_1_list) }
24    }
25}
26
27impl From<TensorRank1List<3, 0, 8>> for TensorRank1List<3, 1, 8> {
28    fn from(tensor_rank_1_list: TensorRank1List<3, 0, 8>) -> Self {
29        unsafe { transmute::<TensorRank1List<3, 0, 8>, Self>(tensor_rank_1_list) }
30    }
31}
32
33impl From<TensorRank1List<3, 0, 10>> for TensorRank1List<3, 1, 10> {
34    fn from(tensor_rank_1_list: TensorRank1List<3, 0, 10>) -> Self {
35        unsafe { transmute::<TensorRank1List<3, 0, 10>, Self>(tensor_rank_1_list) }
36    }
37}
38
39impl<const D: usize, const I: usize, const J: usize, const W: usize> Mul<TensorRank1List<D, J, W>>
40    for TensorRank1List<D, I, W>
41{
42    type Output = TensorRank2<D, I, J>;
43    fn mul(self, tensor_rank_1_list: TensorRank1List<D, J, W>) -> Self::Output {
44        self.iter()
45            .zip(tensor_rank_1_list.iter())
46            .map(|(self_entry, entry)| TensorRank2::dyad(self_entry, entry))
47            .sum()
48    }
49}
50
51impl<const D: usize, const I: usize, const J: usize, const W: usize> Mul<&TensorRank1List<D, J, W>>
52    for TensorRank1List<D, I, W>
53{
54    type Output = TensorRank2<D, I, J>;
55    fn mul(self, tensor_rank_1_list: &TensorRank1List<D, J, W>) -> Self::Output {
56        self.iter()
57            .zip(tensor_rank_1_list.iter())
58            .map(|(self_entry, entry)| TensorRank2::dyad(self_entry, entry))
59            .sum()
60    }
61}
62
63impl<const D: usize, const I: usize, const J: usize, const W: usize> Mul<TensorRank1List<D, J, W>>
64    for &TensorRank1List<D, I, W>
65{
66    type Output = TensorRank2<D, I, J>;
67    fn mul(self, tensor_rank_1_list: TensorRank1List<D, J, W>) -> Self::Output {
68        self.iter()
69            .zip(tensor_rank_1_list.iter())
70            .map(|(self_entry, entry)| TensorRank2::dyad(self_entry, entry))
71            .sum()
72    }
73}
74
75impl<const D: usize, const I: usize, const J: usize, const W: usize> Mul<&TensorRank1List<D, J, W>>
76    for &TensorRank1List<D, I, W>
77{
78    type Output = TensorRank2<D, I, J>;
79    fn mul(self, tensor_rank_1_list: &TensorRank1List<D, J, W>) -> Self::Output {
80        self.iter()
81            .zip(tensor_rank_1_list.iter())
82            .map(|(self_entry, entry)| TensorRank2::dyad(self_entry, entry))
83            .sum()
84    }
85}
86
#[cfg(test)]
impl<const D: usize, const I: usize, const W: usize> ErrorTensor for TensorRank1List<D, I, W> {
    /// Compares `self` with `comparator` component by component using tolerance `epsilon`.
    ///
    /// A component pair `(x, y)` is counted as an error when both of these hold:
    /// - its relative deviation `|x / y - 1|` is at least `epsilon`;
    /// - at least one of `|x|`, `|y|` is at least `epsilon`.
    ///
    /// For floating-point components, `x == y == 0` gives a NaN ratio and is never counted.
    ///
    /// Returns `None` when no component is in error. Otherwise it returns
    /// `Some((auxiliary, count))`:
    /// - `count` is the number of erroneous components;
    /// - `auxiliary` is `true` when at least one of them also has an absolute deviation
    ///   `|x - y|` of at least `epsilon`.
    fn error_fd(&self, comparator: &Self, epsilon: &TensorRank0) -> Option<(bool, usize)> {
        let mut error_count: usize = 0;
        let mut auxiliary = false;
        for (entry, comparator_entry) in self.iter().zip(comparator.iter()) {
            for (&entry_i, &comparator_entry_i) in entry.iter().zip(comparator_entry.iter()) {
                let relative_deviation = (entry_i / comparator_entry_i - 1.0).abs();
                let significant =
                    &entry_i.abs() >= epsilon || &comparator_entry_i.abs() >= epsilon;
                if &relative_deviation >= epsilon && significant {
                    error_count += 1;
                    // The auxiliary flag can only be raised by a component already in error.
                    auxiliary |= &(entry_i - comparator_entry_i).abs() >= epsilon;
                }
            }
        }
        (error_count > 0).then_some((auxiliary, error_count))
    }
}
127}