conspire/math/tensor/mod.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
#[cfg(test)]
pub mod test;
pub mod rank_0;
pub mod rank_1;
pub mod rank_2;
pub mod rank_3;
pub mod rank_4;
use rank_0::TensorRank0;
use std::{
fmt::{Debug, Display},
ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, Sub, SubAssign},
};
/// Conversion from a borrowed value into a value of type `T`.
///
/// Unlike [`Into`](https://doc.rust-lang.org/std/convert/trait.Into.html), which takes
/// ownership of the value being converted, `convert` only borrows `self`, so the original
/// value remains usable afterwards.
pub trait Convert<T> {
    /// Produces a `T` from `self` without consuming it; `T` is usually inferred from context.
    fn convert(&self) -> T;
}
/// Possible errors for tensors.
///
/// `PartialEq` is derived rather than hand-written so that equality stays correct
/// automatically when new variants are added.
#[derive(Debug, PartialEq)]
pub enum TensorError {
    /// The tensor is not positive-definite.
    ///
    /// This is currently the only variant, so it is the error that fallible operations such
    /// as [`Rank2::cholesky_decomposition`] report.
    NotPositiveDefinite,
}
/// Behavior shared by Hessian types.
pub trait Hessian {
    /// Returns `true` if the Hessian is positive-definite and `false` otherwise.
    fn is_positive_definite(&self) -> bool;
}
/// Behavior shared by rank-2 tensor types.
pub trait Rank2: Sized {
    /// The type produced by [`Rank2::transpose`].
    type Transpose;
    /// Returns the trace of the tensor.
    fn trace(&self) -> TensorRank0;
    /// Returns the trace of the tensor squared, i.e. tr(A·A).
    ///
    /// This is *not* the square of the trace (that is `self.trace().powi(2)`); the default
    /// [`Rank2::second_invariant`] relies on the difference between the two.
    fn squared_trace(&self) -> TensorRank0;
    /// Returns the second invariant, ½[(tr A)² − tr(A·A)].
    ///
    /// The default is built from [`Rank2::trace`] and [`Rank2::squared_trace`].
    fn second_invariant(&self) -> TensorRank0 {
        let trace = self.trace();
        let trace_of_square = self.squared_trace();
        0.5 * (trace.powi(2) - trace_of_square)
    }
    /// Returns the deviatoric (trace-free) part of the tensor.
    fn deviatoric(&self) -> Self;
    /// Returns the deviatoric part of the tensor together with its trace.
    fn deviatoric_and_trace(&self) -> (Self, TensorRank0);
    /// Returns `true` if the tensor is diagonal.
    fn is_diagonal(&self) -> bool;
    /// Returns `true` if the tensor is the identity tensor.
    fn is_identity(&self) -> bool;
    /// Returns the transpose of the tensor.
    fn transpose(&self) -> Self::Transpose;
    /// Returns the Cholesky decomposition of the tensor.
    ///
    /// # Errors
    ///
    /// Fails with a [`TensorError`] (currently always `NotPositiveDefinite`, its only variant)
    /// when the decomposition cannot be computed.
    fn cholesky_decomposition(&self) -> Result<Self, TensorError>;
}
/// Common methods for tensors.
///
/// A tensor is treated as a sequence of items, each itself a [`Tensor`], exposed through
/// [`Tensor::iter`] and [`Tensor::iter_mut`]. The default methods below are built on those
/// iterators, so implementors usually only need to supply them.
pub trait Tensor
where
    for<'a> Self: Sized
        + Debug
        + Display
        + Add<Self, Output = Self>
        + Add<&'a Self, Output = Self>
        + AddAssign
        + AddAssign<&'a Self>
        + Clone
        + Div<TensorRank0, Output = Self>
        + DivAssign<TensorRank0>
        + Mul<TensorRank0, Output = Self>
        + Sub<Self, Output = Self>
        + Sub<&'a Self, Output = Self>
        + SubAssign
        + SubAssign<&'a Self>,
    Self::Item: Tensor,
{
    /// The type of item encountered when iterating over the tensor.
    type Item;
    /// Returns the full contraction with another tensor.
    ///
    /// The default pairs the items of `self` and `tensor` in iteration order and sums their
    /// item-wise full contractions. Items are paired with [`Iterator::zip`], so if the two
    /// tensors yield different numbers of items, the surplus items are silently ignored.
    fn full_contraction(&self, tensor: &Self) -> TensorRank0 {
        self.iter()
            .zip(tensor.iter())
            .map(|(self_entry, tensor_entry)| self_entry.full_contraction(tensor_entry))
            .sum()
    }
    /// Returns a reference to the entry at the specified indices.
    ///
    /// # Panics
    ///
    /// The default implementation always panics; types that support indexed access must
    /// override it.
    fn get_at(&self, _indices: &[usize]) -> &TensorRank0 {
        panic!("Need to implement get_at() for {:?}.", self)
    }
    /// Returns a mutable reference to the entry at the specified indices.
    ///
    /// # Panics
    ///
    /// The default implementation always panics; types that support indexed access must
    /// override it.
    fn get_at_mut(&mut self, _indices: &[usize]) -> &mut TensorRank0 {
        panic!("Need to implement get_at_mut() for {:?}.", self)
    }
    /// Checks whether the tensor is the zero tensor.
    ///
    /// Returns `true` when every item reports itself as zero, which is vacuously `true` for a
    /// tensor with no items. Iteration stops at the first nonzero item.
    fn is_zero(&self) -> bool {
        self.iter().all(|entry| entry.is_zero())
    }
    /// Returns an iterator.
    ///
    /// The iterator yields all items from start to end. [Read more](https://doc.rust-lang.org/std/iter/)
    fn iter(&self) -> impl Iterator<Item = &Self::Item>;
    /// Returns an iterator that allows modifying each value.
    ///
    /// The iterator yields all items from start to end. [Read more](https://doc.rust-lang.org/std/iter/)
    fn iter_mut(&mut self) -> impl Iterator<Item = &mut Self::Item>;
    /// Returns the tensor norm, the square root of [`Tensor::norm_squared`].
    fn norm(&self) -> TensorRank0 {
        self.norm_squared().sqrt()
    }
    /// Returns the tensor norm squared, i.e. the full contraction of the tensor with itself.
    fn norm_squared(&self) -> TensorRank0 {
        self.full_contraction(self)
    }
    /// Normalizes the tensor in place by dividing it by its norm.
    ///
    /// The norm is not checked first, so a zero tensor is divided by zero.
    /// NOTE(review): with a floating-point `TensorRank0` this yields NaN entries; confirm
    /// whether callers need a guard.
    fn normalize(&mut self) {
        let norm = self.norm();
        *self /= norm
    }
    /// Returns the tensor divided by its norm, consuming `self`.
    ///
    /// Like [`Tensor::normalize`], this does not guard against a zero norm.
    fn normalized(self) -> Self {
        let norm = self.norm();
        self / norm
    }
}
/// Behavior shared by tensors that are derived from arrays.
pub trait TensorArray {
    /// The array type the tensor is built from and converted back into.
    type Array;
    /// The type of item encountered when iterating over the tensor.
    type Item;
    /// Builds a tensor from the given array.
    fn new(array: Self::Array) -> Self;
    /// Builds the zero tensor.
    fn zero() -> Self;
    /// Builds the identity tensor.
    fn identity() -> Self;
    /// Returns the tensor's entries as an array.
    fn as_array(&self) -> Self::Array;
}
/// Behavior shared by tensors that are derived from `Vec`.
///
/// Implementors must also be collectable from an iterator of items and indexable by
/// position, with each index yielding an item.
pub trait TensorVec
where
    Self: FromIterator<Self::Item> + Index<usize, Output = Self::Item> + IndexMut<usize>,
{
    /// The type of item encountered when iterating over the tensor.
    type Item;
    /// The borrowed slice type a tensor can be built from.
    type Slice<'a>;
    /// Builds a tensor from the given slice.
    fn new(slice: Self::Slice<'_>) -> Self;
    /// Builds the zero tensor of length `len`.
    fn zero(len: usize) -> Self;
    /// Returns how many items the tensor holds (its length).
    fn len(&self) -> usize;
    /// Returns `true` when the tensor holds no items.
    fn is_empty(&self) -> bool;
    /// Adds `item` to the end of the tensor.
    fn push(&mut self, item: Self::Item);
    /// Moves every item of `other` onto the end of `self`, leaving `other` empty.
    fn append(&mut self, other: &mut Self);
}