conspire/math/integrate/
mod.rs

1#[cfg(test)]
2mod test;
3
4mod backward_euler;
5mod bogacki_shampine;
6mod dormand_prince;
7mod verner_8;
8mod verner_9;
9
10pub use backward_euler::BackwardEuler;
11pub use bogacki_shampine::BogackiShampine;
12pub use dormand_prince::DormandPrince;
13pub use verner_8::Verner8;
14pub use verner_9::Verner9;
15
16pub type Ode1be = BackwardEuler;
17pub type Ode23 = BogackiShampine;
18pub type Ode45 = DormandPrince;
19pub type Ode78 = Verner8;
20pub type Ode89 = Verner9;
21
// TODO: consider adding symplectic integrators for dynamics problems.
23
24use super::{
25    Scalar, Solution, Tensor, TensorArray, TensorVec, TestError, Vector,
26    interpolate::InterpolateSolution,
27    optimize::{FirstOrderRootFinding, ZerothOrderRootFinding},
28};
29use crate::defeat_message;
30use std::{
31    fmt::{self, Debug, Display, Formatter},
32    ops::{Div, Mul, Sub},
33};
34
35/// Base trait for ordinary differential equation solvers.
36pub trait OdeSolver<Y, U>
37where
38    Self: Debug,
39    Y: Tensor,
40    U: TensorVec<Item = Y>,
41{
42}
43
44impl<A, Y, U> OdeSolver<Y, U> for A
45where
46    A: Debug,
47    Y: Tensor,
48    U: TensorVec<Item = Y>,
49{
50}
51
52/// Base trait for explicit ordinary differential equation solvers.
53pub trait Explicit<Y, U>: OdeSolver<Y, U>
54where
55    Self: InterpolateSolution<Y, U>,
56    Y: Tensor,
57    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
58    U: TensorVec<Item = Y>,
59{
60    const SLOPES: usize;
61    /// Solves an initial value problem by explicitly integrating a system of ordinary differential equations.
62    ///
63    /// ```math
64    /// \frac{dy}{dt} = f(t, y),\quad y(t_0) = y_0
65    /// ```
66    fn integrate(
67        &self,
68        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
69        time: &[Scalar],
70        initial_condition: Y,
71    ) -> Result<(Vector, U, U), IntegrationError> {
72        let t_0 = time[0];
73        let t_f = time[time.len() - 1];
74        if time.len() < 2 {
75            return Err(IntegrationError::LengthTimeLessThanTwo);
76        } else if t_0 >= t_f {
77            return Err(IntegrationError::InitialTimeNotLessThanFinalTime);
78        }
79        let mut t = t_0;
80        let mut dt = t_f;
81        let mut e;
82        let mut k = vec![Y::default(); Self::SLOPES];
83        k[0] = function(t, &initial_condition)?;
84        let mut t_sol = Vector::zero(0);
85        t_sol.push(t_0);
86        let mut y = initial_condition.clone();
87        let mut y_sol = U::zero(0);
88        y_sol.push(initial_condition.clone());
89        let mut dydt_sol = U::zero(0);
90        dydt_sol.push(k[0].clone());
91        let mut y_trial = Y::default();
92        while t < t_f {
93            e = match self.slopes(&mut function, &y, &t, &dt, &mut k, &mut y_trial) {
94                Ok(e) => e,
95                Err(error) => {
96                    return Err(IntegrationError::Upstream(error, format!("{self:?}")));
97                }
98            };
99            match self.step(
100                &mut function,
101                &mut y,
102                &mut t,
103                &mut y_sol,
104                &mut t_sol,
105                &mut dydt_sol,
106                &mut dt,
107                &mut k,
108                &y_trial,
109                &e,
110            ) {
111                Ok(e) => e,
112                Err(error) => {
113                    return Err(IntegrationError::Upstream(error, format!("{self:?}")));
114                }
115            };
116            dt = dt.min(t_f - t);
117        }
118        if time.len() > 2 {
119            let t_int = Vector::new(time);
120            let (y_int, dydt_int) = self.interpolate(&t_int, &t_sol, &y_sol, function)?;
121            Ok((t_int, y_int, dydt_int))
122        } else {
123            Ok((t_sol, y_sol, dydt_sol))
124        }
125    }
126    #[doc(hidden)]
127    fn slopes(
128        &self,
129        function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
130        y: &Y,
131        t: &Scalar,
132        dt: &Scalar,
133        k: &mut [Y],
134        y_trial: &mut Y,
135    ) -> Result<Scalar, String>;
136    #[allow(clippy::too_many_arguments)]
137    #[doc(hidden)]
138    fn step(
139        &self,
140        function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
141        y: &mut Y,
142        t: &mut Scalar,
143        y_sol: &mut U,
144        t_sol: &mut Vector,
145        dydt_sol: &mut U,
146        dt: &mut Scalar,
147        k: &mut [Y],
148        y_trial: &Y,
149        e: &Scalar,
150    ) -> Result<(), String>;
151}
152
153/// Base trait for zeroth-order implicit ordinary differential equation solvers.
154pub trait ImplicitZerothOrder<Y, U>: OdeSolver<Y, U>
155where
156    Self: InterpolateSolution<Y, U>,
157    Y: Solution,
158    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
159    U: TensorVec<Item = Y>,
160{
161    /// Solves an initial value problem by implicitly integrating a system of ordinary differential equations.
162    ///
163    /// ```math
164    /// \frac{dy}{dt} = f(t, y),\quad y(t_0) = y_0,\quad \frac{\partial f}{\partial y} = J(t, y)
165    /// ```
166    fn integrate(
167        &self,
168        function: impl Fn(Scalar, &Y) -> Result<Y, IntegrationError>,
169        time: &[Scalar],
170        initial_condition: Y,
171        solver: impl ZerothOrderRootFinding<Y>,
172    ) -> Result<(Vector, U, U), IntegrationError>;
173}
174
175/// Base trait for first-order implicit ordinary differential equation solvers.
176pub trait ImplicitFirstOrder<Y, J, U>: OdeSolver<Y, U>
177where
178    Self: InterpolateSolution<Y, U>,
179    Y: Solution + Div<J, Output = Y>,
180    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
181    J: Tensor + TensorArray,
182    U: TensorVec<Item = Y>,
183{
184    /// Solves an initial value problem by implicitly integrating a system of ordinary differential equations.
185    ///
186    /// ```math
187    /// \frac{dy}{dt} = f(t, y),\quad y(t_0) = y_0,\quad \frac{\partial f}{\partial y} = J(t, y)
188    /// ```
189    fn integrate(
190        &self,
191        function: impl Fn(Scalar, &Y) -> Result<Y, IntegrationError>,
192        jacobian: impl Fn(Scalar, &Y) -> Result<J, IntegrationError>,
193        time: &[Scalar],
194        initial_condition: Y,
195        solver: impl FirstOrderRootFinding<Y, J, Y>,
196    ) -> Result<(Vector, U, U), IntegrationError>;
197}
198
199/// Possible errors encountered when integrating.
200pub enum IntegrationError {
201    InitialTimeNotLessThanFinalTime,
202    Intermediate(String),
203    LengthTimeLessThanTwo,
204    Upstream(String, String),
205}
206
207impl From<String> for IntegrationError {
208    fn from(error: String) -> Self {
209        Self::Intermediate(error)
210    }
211}
212
213impl Debug for IntegrationError {
214    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
215        let error = match self {
216            Self::InitialTimeNotLessThanFinalTime => {
217                "\x1b[1;91mThe initial time must precede the final time.".to_string()
218            }
219            Self::Intermediate(message) => message.to_string(),
220            Self::LengthTimeLessThanTwo => {
221                "\x1b[1;91mThe time must contain at least two entries.".to_string()
222            }
223            Self::Upstream(error, integrator) => {
224                format!(
225                    "{error}\x1b[0;91m\n\
226                    In integrator: {integrator}."
227                )
228            }
229        };
230        write!(f, "\n{}\n\x1b[0;2;31m{}\x1b[0m\n", error, defeat_message())
231    }
232}
233
234impl Display for IntegrationError {
235    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
236        let error = match self {
237            Self::InitialTimeNotLessThanFinalTime => {
238                "\x1b[1;91mThe initial time must precede the final time.".to_string()
239            }
240            Self::Intermediate(message) => message.to_string(),
241            Self::LengthTimeLessThanTwo => {
242                "\x1b[1;91mThe time must contain at least two entries.".to_string()
243            }
244            Self::Upstream(error, integrator) => {
245                format!(
246                    "{error}\x1b[0;91m\n\
247                    In integrator: {integrator}."
248                )
249            }
250        };
251        write!(f, "{error}\x1b[0m")
252    }
253}
254
255impl From<IntegrationError> for String {
256    fn from(error: IntegrationError) -> Self {
257        format!("{}", error)
258    }
259}
260
261impl From<IntegrationError> for TestError {
262    fn from(error: IntegrationError) -> Self {
263        TestError {
264            message: error.to_string(),
265        }
266    }
267}