conspire/math/tensor/
mod.rs

1pub mod test;
2
3pub mod list;
4pub mod rank_0;
5pub mod rank_1;
6pub mod rank_2;
7pub mod rank_3;
8pub mod rank_4;
9pub mod tuple;
10pub mod vec;
11
12use super::{SquareMatrix, Vector};
13use crate::defeat_message;
14use rank_0::{
15    TensorRank0,
16    list::{TensorRank0List, vec::TensorRank0ListVec},
17};
18use std::{
19    fmt::{self, Debug, Display, Formatter},
20    iter::Sum,
21    ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign},
22};
23
24/// A scalar.
25pub type Scalar = TensorRank0;
26
27/// A list of scalars.
28pub type Scalars<const W: usize> = TensorRank0List<W>;
29
30/// A vector of lists of scalars.
31pub type ScalarsVec<const W: usize> = TensorRank0ListVec<W>;
32
33/// Possible errors for tensors.
34#[derive(PartialEq)]
35pub enum TensorError {
36    NotPositiveDefinite,
37    SymmetricMatrixComplexEigenvalues,
38}
39
40impl Debug for TensorError {
41    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
42        let error = match self {
43            Self::NotPositiveDefinite => "\x1b[1;91mResult is not positive definite.".to_string(),
44            Self::SymmetricMatrixComplexEigenvalues => {
45                "\x1b[1;91mSymmetric matrix produced complex eigenvalues".to_string()
46            }
47        };
48        write!(f, "\n{error}\n\x1b[0;2;31m{}\x1b[0m\n", defeat_message())
49    }
50}
51
52impl Display for TensorError {
53    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
54        let error = match self {
55            Self::NotPositiveDefinite => "\x1b[1;91mResult is not positive definite.".to_string(),
56            Self::SymmetricMatrixComplexEigenvalues => {
57                "\x1b[1;91mSymmetric matrix produced complex eigenvalues".to_string()
58            }
59        };
60        write!(f, "{error}\x1b[0m")
61    }
62}
63
64/// Common methods for solutions.
65pub trait Solution
66where
67    Self: From<Vector> + Tensor,
68{
69    /// Decrements the solution from another vector.
70    fn decrement_from(&mut self, other: &Vector);
71    /// Decrements the solution chained with a vector from another vector.
72    fn decrement_from_chained(&mut self, other: &mut Vector, vector: Vector);
73    /// Decrements the solution from another vector on retained entries.
74    fn decrement_from_retained(&mut self, _retained: &[bool], _other: &Vector) {
75        unimplemented!()
76    }
77}
78
79/// Common methods for Jacobians.
80pub trait Jacobian
81where
82    Self:
83        From<Vector> + Tensor + Sub<Vector, Output = Self> + for<'a> Sub<&'a Vector, Output = Self>,
84{
85    /// Fills the Jacobian into a vector.
86    fn fill_into(self, vector: &mut Vector);
87    /// Fills the Jacobian chained with a vector into another vector.
88    fn fill_into_chained(self, other: Vector, vector: &mut Vector);
89    /// Return only the retained indices.
90    fn retain_from(self, _retained: &[bool]) -> Vector {
91        unimplemented!()
92    }
93    /// Zero out the specified indices.
94    fn zero_out(&mut self, _indices: &[usize]) {
95        unimplemented!()
96    }
97}
98
99/// Common methods for Hessians.
100pub trait Hessian
101where
102    Self: Tensor,
103{
104    /// Fills the Hessian into a square matrix.
105    fn fill_into(self, square_matrix: &mut SquareMatrix);
106    /// Return only the retained indices.
107    fn retain_from(self, _retained: &[bool]) -> SquareMatrix {
108        unimplemented!()
109    }
110}
111
112/// Common methods for rank-2 tensors.
113pub trait Rank2
114where
115    Self: Sized,
116{
117    /// The type that is the transpose of the tensor.
118    type Transpose;
119    /// Returns the deviatoric component of the rank-2 tensor.
120    fn deviatoric(&self) -> Self;
121    /// Returns the deviatoric component and trace of the rank-2 tensor.
122    fn deviatoric_and_trace(&self) -> (Self, TensorRank0);
123    /// Checks whether the tensor is a diagonal tensor.
124    fn is_diagonal(&self) -> bool;
125    /// Checks whether the tensor is the identity tensor.
126    fn is_identity(&self) -> bool;
127    /// Checks whether the tensor is a symmetric tensor.
128    fn is_symmetric(&self) -> bool;
129    /// Returns the second invariant of the rank-2 tensor.
130    fn second_invariant(&self) -> TensorRank0 {
131        0.5 * (self.trace().powi(2) - self.squared_trace())
132    }
133    /// Returns the trace of the rank-2 tensor squared.
134    fn squared_trace(&self) -> TensorRank0;
135    /// Returns the trace of the rank-2 tensor.
136    fn trace(&self) -> TensorRank0;
137    /// Returns the transpose of the rank-2 tensor.
138    fn transpose(&self) -> Self::Transpose;
139}
140
141/// Common methods for tensors.
142#[allow(clippy::len_without_is_empty)]
143pub trait Tensor
144where
145    for<'a> Self: Sized
146        + Add<Self, Output = Self>
147        + Add<&'a Self, Output = Self>
148        + AddAssign
149        + AddAssign<&'a Self>
150        + Clone
151        + Debug
152        + Default
153        + Display
154        + Div<TensorRank0, Output = Self>
155        // + Div<&'a TensorRank0, Output = Self>
156        + DivAssign<TensorRank0>
157        + DivAssign<&'a TensorRank0>
158        + Mul<TensorRank0, Output = Self>
159        // + Mul<&'a TensorRank0, Output = Self>
160        + MulAssign<TensorRank0>
161        + MulAssign<&'a TensorRank0>
162        + Sub<Self, Output = Self>
163        + Sub<&'a Self, Output = Self>
164        + SubAssign
165        + SubAssign<&'a Self>
166        + Sum,
167    Self::Item: Tensor,
168{
169    /// The type of item encountered when iterating over the tensor.
170    type Item;
171    /// Returns number of different entries given absolute and relative tolerances.
172    fn error_count(&self, other: &Self, tol_abs: Scalar, tol_rel: Scalar) -> Option<usize> {
173        let error_count = self
174            .iter()
175            .zip(other.iter())
176            .filter_map(|(self_entry, other_entry)| {
177                self_entry.error_count(other_entry, tol_abs, tol_rel)
178            })
179            .sum();
180        if error_count > 0 {
181            Some(error_count)
182        } else {
183            None
184        }
185    }
186    /// Returns the full contraction with another tensor.
187    fn full_contraction(&self, tensor: &Self) -> TensorRank0 {
188        self.iter()
189            .zip(tensor.iter())
190            .map(|(self_entry, tensor_entry)| self_entry.full_contraction(tensor_entry))
191            .sum()
192    }
193    /// Checks whether the tensor is the zero tensor.
194    fn is_zero(&self) -> bool {
195        self.iter().filter(|entry| !entry.is_zero()).count() == 0
196    }
197    /// Returns an iterator.
198    ///
199    /// The iterator yields all items from start to end. [Read more](https://doc.rust-lang.org/std/iter/)
200    fn iter(&self) -> impl Iterator<Item = &Self::Item>;
201    /// Returns an iterator that allows modifying each value.
202    ///
203    /// The iterator yields all items from start to end. [Read more](https://doc.rust-lang.org/std/iter/)
204    fn iter_mut(&mut self) -> impl Iterator<Item = &mut Self::Item>;
205    /// Returns the number of elements, also referred to as the ‘length’.
206    fn len(&self) -> usize;
207    /// Returns the tensor norm.
208    fn norm(&self) -> TensorRank0 {
209        self.norm_squared().sqrt()
210    }
211    /// Returns the infinity norm.
212    fn norm_inf(&self) -> TensorRank0 {
213        self.iter()
214            .fold(0.0, |acc, entry| entry.norm_inf().max(acc))
215    }
216    /// Returns the tensor norm squared.
217    fn norm_squared(&self) -> TensorRank0 {
218        self.full_contraction(self)
219    }
220    /// Normalizes the tensor.
221    fn normalize(&mut self) {
222        *self /= self.norm()
223    }
224    /// Returns the tensor normalized.
225    fn normalized(self) -> Self {
226        let norm = self.norm();
227        self / norm
228    }
229    /// Returns the total number of entries.
230    fn size(&self) -> usize;
231    /// Returns the positive difference of the two tensors.
232    fn sub_abs(&self, other: &Self) -> Self {
233        let mut difference = self.clone();
234        difference
235            .iter_mut()
236            .zip(self.iter().zip(other.iter()))
237            .for_each(|(entry, (self_entry, other_entry))| {
238                *entry = self_entry.sub_abs(other_entry)
239            });
240        difference
241    }
242    /// Returns the relative difference of the two tensors.
243    fn sub_rel(&self, other: &Self) -> Self {
244        let mut difference = self.clone();
245        difference
246            .iter_mut()
247            .zip(self.iter().zip(other.iter()))
248            .for_each(|(entry, (self_entry, other_entry))| {
249                *entry = self_entry.sub_rel(other_entry)
250            });
251        difference
252    }
253}
254
/// Operations for tensors derived from arrays.
pub trait TensorArray {
    /// The array type corresponding to the tensor.
    type Array;
    /// The item type encountered when iterating over the tensor.
    type Item;
    /// Builds a tensor from `array`.
    fn new(array: Self::Array) -> Self;
    /// Returns the identity tensor.
    fn identity() -> Self;
    /// Returns the zero tensor.
    fn zero() -> Self;
    /// Returns the tensor as its corresponding array.
    fn as_array(&self) -> Self::Array;
}
270
/// Operations for tensors derived from `Vec`.
///
/// The bounds stay in a `where` clause: a supertrait list cannot mention `Self::Item`
/// without a cycle error.
pub trait TensorVec
where
    Self: FromIterator<Self::Item> + Index<usize, Output = Self::Item> + IndexMut<usize>,
{
    /// The element type encountered when iterating over the tensor.
    type Item;
    /// Creates an empty tensor that does not allocate until elements are pushed.
    fn new() -> Self;
    /// Returns how many elements the tensor can hold before it must reallocate.
    fn capacity(&self) -> usize;
    /// Returns `true` when the tensor holds no elements.
    fn is_empty(&self) -> bool;
    /// Appends `item` to the back of the tensor.
    fn push(&mut self, item: Self::Item);
    /// Moves every element of `other` onto the back of `self`, leaving `other` empty.
    fn append(&mut self, other: &mut Self);
    /// Removes and returns the element at `index`, shifting later elements to the left.
    fn remove(&mut self, index: usize) -> Self::Item;
    /// Removes and returns the element at `index`, filling the gap with the last element.
    fn swap_remove(&mut self, index: usize) -> Self::Item;
    /// Keeps only the elements for which `predicate` returns `true`.
    fn retain<P>(&mut self, predicate: P)
    where
        P: FnMut(&Self::Item) -> bool;
}