conspire/math/integrate/
mod.rs

1#[cfg(feature = "doc")]
2pub mod doc;
3
4#[cfg(test)]
5mod test;
6
7mod backward_euler;
8mod bogacki_shampine;
9mod dormand_prince;
10mod verner_8;
11mod verner_9;
12
13pub use backward_euler::BackwardEuler;
14pub use bogacki_shampine::BogackiShampine;
15pub use dormand_prince::DormandPrince;
16pub use verner_8::Verner8;
17pub use verner_9::Verner9;
18
19/// Alias for [`BackwardEuler`].
20pub type Ode1be = BackwardEuler;
21
22/// Alias for [`BogackiShampine`].
23pub type Ode23 = BogackiShampine;
24
25/// Alias for [`DormandPrince`].
26pub type Ode45 = DormandPrince;
27
28/// Alias for [`Verner8`].
29pub type Ode78 = Verner8;
30
31/// Alias for [`Verner9`].
32pub type Ode89 = Verner9;
33
34use super::{
35    Scalar, Solution, Tensor, TensorArray, TensorVec, TestError, Vector, assert_eq_within_tols,
36    interpolate::{InterpolateSolution, InterpolateSolutionIV},
37    optimize::{FirstOrderRootFinding, ZerothOrderRootFinding},
38};
39use crate::defeat_message;
40use std::{
41    fmt::{self, Debug, Display, Formatter},
42    ops::{Div, Mul, Sub},
43};
44
45/// Base trait for ordinary differential equation solvers.
46pub trait OdeSolver<Y, U>
47where
48    Self: Debug,
49    Y: Tensor,
50    U: TensorVec<Item = Y>,
51{
52    /// Returns the absolute error tolerance.
53    fn abs_tol(&self) -> Scalar;
54    /// Returns the cut back factor for function errors.
55    fn dt_cut(&self) -> Scalar;
56    /// Returns the minimum value for the time step.
57    fn dt_min(&self) -> Scalar;
58}
59
60/// Base trait for explicit ordinary differential equation solvers.
61pub trait Explicit<Y, U>
62where
63    Self: InterpolateSolution<Y, U> + OdeSolver<Y, U>,
64    Y: Tensor,
65    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
66    U: TensorVec<Item = Y>,
67{
68    const SLOPES: usize;
69    /// Returns the multiplier for adaptive time steps.
70    fn dt_beta(&self) -> Scalar;
71    /// Returns the exponent for adaptive time steps.
72    fn dt_expn(&self) -> Scalar;
73    #[doc = include_str!("explicit.md")]
74    fn integrate(
75        &self,
76        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
77        time: &[Scalar],
78        initial_condition: Y,
79    ) -> Result<(Vector, U, U), IntegrationError> {
80        let t_0 = time[0];
81        let t_f = time[time.len() - 1];
82        if time.len() < 2 {
83            return Err(IntegrationError::LengthTimeLessThanTwo);
84        } else if t_0 >= t_f {
85            return Err(IntegrationError::InitialTimeNotLessThanFinalTime);
86        }
87        let mut t = t_0;
88        let mut dt = t_f - t_0;
89        let mut k = vec![Y::default(); Self::SLOPES];
90        k[0] = function(t, &initial_condition)?;
91        let mut t_sol = Vector::new();
92        t_sol.push(t_0);
93        let mut y = initial_condition.clone();
94        let mut y_sol = U::new();
95        y_sol.push(initial_condition.clone());
96        let mut dydt_sol = U::new();
97        dydt_sol.push(k[0].clone());
98        let mut y_trial = Y::default();
99        while t < t_f {
100            match self.slopes(&mut function, &y, t, dt, &mut k, &mut y_trial) {
101                Ok(e) => {
102                    if let Some(error) = self
103                        .step(
104                            &mut function,
105                            &mut y,
106                            &mut t,
107                            &mut y_sol,
108                            &mut t_sol,
109                            &mut dydt_sol,
110                            &mut dt,
111                            &mut k,
112                            &y_trial,
113                            e,
114                        )
115                        .err()
116                    {
117                        dt *= self.dt_cut();
118                        if dt < self.dt_min() {
119                            return Err(IntegrationError::MinimumStepSizeUpstream(
120                                self.dt_min(),
121                                error,
122                                format!("{self:?}"),
123                            ));
124                        }
125                    } else {
126                        dt = dt.min(t_f - t);
127                        if dt < self.dt_min() && t < t_f {
128                            return Err(IntegrationError::MinimumStepSizeReached(
129                                self.dt_min(),
130                                format!("{self:?}"),
131                            ));
132                        }
133                    }
134                }
135                Err(error) => {
136                    dt *= self.dt_cut();
137                    if dt < self.dt_min() {
138                        return Err(IntegrationError::MinimumStepSizeUpstream(
139                            self.dt_min(),
140                            error,
141                            format!("{self:?}"),
142                        ));
143                    }
144                }
145            }
146        }
147        if time.len() > 2 {
148            let t_int = Vector::from(time);
149            let (y_int, dydt_int) = self.interpolate(&t_int, &t_sol, &y_sol, function)?;
150            Ok((t_int, y_int, dydt_int))
151        } else {
152            Ok((t_sol, y_sol, dydt_sol))
153        }
154    }
155    #[doc(hidden)]
156    fn slopes(
157        &self,
158        function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
159        y: &Y,
160        t: Scalar,
161        dt: Scalar,
162        k: &mut [Y],
163        y_trial: &mut Y,
164    ) -> Result<Scalar, String>;
165    #[allow(clippy::too_many_arguments)]
166    #[doc(hidden)]
167    fn step(
168        &self,
169        function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
170        y: &mut Y,
171        t: &mut Scalar,
172        y_sol: &mut U,
173        t_sol: &mut Vector,
174        dydt_sol: &mut U,
175        dt: &mut Scalar,
176        k: &mut [Y],
177        y_trial: &Y,
178        e: Scalar,
179    ) -> Result<(), String>;
180    /// Provides the adaptive time step as a function of the error.
181    ///
182    /// ```math
183    /// h_{n+1} = \beta h \left(\frac{e_\mathrm{tol}}{e_{n+1}}\right)^{1/p}
184    /// ```
185    fn time_step(&self, error: Scalar, dt: &mut Scalar) {
186        if error > 0.0 {
187            *dt *= (self.dt_beta() * (self.abs_tol() / error).powf(1.0 / self.dt_expn()))
188                .max(self.dt_cut())
189        }
190    }
191}
192
193/// Base trait for explicit ordinary differential equation solvers with internal variables.
194pub trait ExplicitIV<Y, Z, U, V>
195where
196    Self: InterpolateSolutionIV<Y, Z, U, V> + OdeSolver<Y, U>,
197    Y: Tensor,
198    Z: Tensor,
199    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
200    U: TensorVec<Item = Y>,
201    V: TensorVec<Item = Z>,
202{
203    const SLOPES: usize;
204    #[doc = include_str!("explicit_iv.md")]
205    fn integrate(
206        &self,
207        mut function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
208        mut evaluate: impl FnMut(Scalar, &Y, &Z) -> Result<Z, String>,
209        time: &[Scalar],
210        initial_condition: Y,
211        initial_evaluation: Z,
212    ) -> Result<(Vector, U, U, V), IntegrationError> {
213        let t_0 = time[0];
214        let t_f = time[time.len() - 1];
215        if time.len() < 2 {
216            return Err(IntegrationError::LengthTimeLessThanTwo);
217        } else if t_0 >= t_f {
218            return Err(IntegrationError::InitialTimeNotLessThanFinalTime);
219        }
220        let mut t = t_0;
221        let mut dt = t_f - t_0;
222        let mut t_sol = Vector::new();
223        t_sol.push(t_0);
224        let mut y = initial_condition;
225        let mut z = initial_evaluation;
226        if assert_eq_within_tols(&evaluate(t, &y, &z)?, &z).is_err() {
227            return Err(IntegrationError::InconsistentInitialConditions);
228        }
229        let mut k = vec![Y::default(); Self::SLOPES];
230        k[0] = function(t, &y, &z)?;
231        let mut y_sol = U::new();
232        y_sol.push(y.clone());
233        let mut z_sol = V::new();
234        z_sol.push(z.clone());
235        let mut dydt_sol = U::new();
236        dydt_sol.push(k[0].clone());
237        let mut y_trial = Y::default();
238        let mut z_trial = Z::default();
239        while t < t_f {
240            match self.slopes(
241                &mut function,
242                &mut evaluate,
243                &y,
244                &z,
245                t,
246                dt,
247                &mut k,
248                &mut y_trial,
249                &mut z_trial,
250            ) {
251                Ok(e) => {
252                    if let Some(error) = self
253                        .step(
254                            &mut function,
255                            &mut y,
256                            &mut z,
257                            &mut t,
258                            &mut y_sol,
259                            &mut z_sol,
260                            &mut t_sol,
261                            &mut dydt_sol,
262                            &mut dt,
263                            &mut k,
264                            &y_trial,
265                            &z_trial,
266                            e,
267                        )
268                        .err()
269                    {
270                        dt *= self.dt_cut();
271                        if dt < self.dt_min() {
272                            return Err(IntegrationError::MinimumStepSizeUpstream(
273                                self.dt_min(),
274                                error,
275                                format!("{self:?}"),
276                            ));
277                        }
278                    } else {
279                        dt = dt.min(t_f - t);
280                        if dt < self.dt_min() && t < t_f {
281                            return Err(IntegrationError::MinimumStepSizeReached(
282                                self.dt_min(),
283                                format!("{self:?}"),
284                            ));
285                        }
286                    }
287                }
288                Err(error) => {
289                    dt *= self.dt_cut();
290                    if dt < self.dt_min() {
291                        return Err(IntegrationError::MinimumStepSizeUpstream(
292                            self.dt_min(),
293                            error,
294                            format!("{self:?}"),
295                        ));
296                    }
297                }
298            }
299        }
300        if time.len() > 2 {
301            let t_int = Vector::from(time);
302            let (y_int, dydt_int, z_int) =
303                self.interpolate(&t_int, &t_sol, &y_sol, &z_sol, function, evaluate)?;
304            Ok((t_int, y_int, dydt_int, z_int))
305        } else {
306            Ok((t_sol, y_sol, dydt_sol, z_sol))
307        }
308    }
309    #[allow(clippy::too_many_arguments)]
310    #[doc(hidden)]
311    fn slopes(
312        &self,
313        function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
314        evaluate: impl FnMut(Scalar, &Y, &Z) -> Result<Z, String>,
315        y: &Y,
316        z: &Z,
317        t: Scalar,
318        dt: Scalar,
319        k: &mut [Y],
320        y_trial: &mut Y,
321        z_trial: &mut Z,
322    ) -> Result<Scalar, String>;
323    #[allow(clippy::too_many_arguments)]
324    #[doc(hidden)]
325    fn step(
326        &self,
327        function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
328        y: &mut Y,
329        z: &mut Z,
330        t: &mut Scalar,
331        y_sol: &mut U,
332        z_sol: &mut V,
333        t_sol: &mut Vector,
334        dydt_sol: &mut U,
335        dt: &mut Scalar,
336        k: &mut [Y],
337        y_trial: &Y,
338        z_trial: &Z,
339        e: Scalar,
340    ) -> Result<(), String>;
341}
342
343/// Base trait for zeroth-order implicit ordinary differential equation solvers.
344pub trait ImplicitZerothOrder<Y, U>
345where
346    Self: InterpolateSolution<Y, U> + OdeSolver<Y, U>,
347    Y: Solution,
348    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
349    U: TensorVec<Item = Y>,
350{
351    #[doc = include_str!("implicit.md")]
352    fn integrate(
353        &self,
354        function: impl Fn(Scalar, &Y) -> Result<Y, IntegrationError>,
355        time: &[Scalar],
356        initial_condition: Y,
357        solver: impl ZerothOrderRootFinding<Y>,
358    ) -> Result<(Vector, U, U), IntegrationError>;
359}
360
361/// Base trait for first-order implicit ordinary differential equation solvers.
362pub trait ImplicitFirstOrder<Y, J, U>
363where
364    Self: InterpolateSolution<Y, U> + OdeSolver<Y, U>,
365    Y: Solution + Div<J, Output = Y>,
366    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
367    J: Tensor + TensorArray,
368    U: TensorVec<Item = Y>,
369{
370    #[doc = include_str!("implicit.md")]
371    fn integrate(
372        &self,
373        function: impl Fn(Scalar, &Y) -> Result<Y, IntegrationError>,
374        jacobian: impl Fn(Scalar, &Y) -> Result<J, IntegrationError>,
375        time: &[Scalar],
376        initial_condition: Y,
377        solver: impl FirstOrderRootFinding<Y, J, Y>,
378    ) -> Result<(Vector, U, U), IntegrationError>;
379}
380
381/// Possible errors encountered when integrating.
382pub enum IntegrationError {
383    InconsistentInitialConditions,
384    InitialTimeNotLessThanFinalTime,
385    Intermediate(String),
386    LengthTimeLessThanTwo,
387    MinimumStepSizeReached(Scalar, String),
388    MinimumStepSizeUpstream(Scalar, String, String),
389    Upstream(String, String),
390}
391
392impl From<String> for IntegrationError {
393    fn from(error: String) -> Self {
394        Self::Intermediate(error)
395    }
396}
397
398impl Debug for IntegrationError {
399    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
400        let error = match self {
401            Self::InconsistentInitialConditions => {
402                "\x1b[1;91mThe initial condition z_0 is not consistent with g(t_0, y_0)."
403                    .to_string()
404            }
405            Self::InitialTimeNotLessThanFinalTime => {
406                "\x1b[1;91mThe initial time must precede the final time.".to_string()
407            }
408            Self::Intermediate(message) => message.to_string(),
409            Self::LengthTimeLessThanTwo => {
410                "\x1b[1;91mThe time must contain at least two entries.".to_string()
411            }
412            Self::MinimumStepSizeReached(dt_min, integrator) => {
413                format!(
414                    "\x1b[1;91mMinimum time step ({dt_min:?}) reached.\x1b[0;91m\n\
415                    In integrator: {integrator}."
416                )
417            }
418            Self::MinimumStepSizeUpstream(dt_min, error, integrator) => {
419                format!(
420                    "{error}\x1b[0;91m\n\
421                    Causing error: \x1b[1;91mMinimum time step ({dt_min:?}) reached.\x1b[0;91m\n\
422                    In integrator: {integrator}."
423                )
424            }
425            Self::Upstream(error, integrator) => {
426                format!(
427                    "{error}\x1b[0;91m\n\
428                    In integrator: {integrator}."
429                )
430            }
431        };
432        write!(f, "\n{}\n\x1b[0;2;31m{}\x1b[0m\n", error, defeat_message())
433    }
434}
435
436impl Display for IntegrationError {
437    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
438        let error = match self {
439            Self::InconsistentInitialConditions => {
440                "\x1b[1;91mThe initial condition z_0 is not consistent with g(t_0, y_0)."
441                    .to_string()
442            }
443            Self::InitialTimeNotLessThanFinalTime => {
444                "\x1b[1;91mThe initial time must precede the final time.".to_string()
445            }
446            Self::Intermediate(message) => message.to_string(),
447            Self::LengthTimeLessThanTwo => {
448                "\x1b[1;91mThe time must contain at least two entries.".to_string()
449            }
450            Self::MinimumStepSizeReached(dt_min, integrator) => {
451                format!(
452                    "\x1b[1;91mMinimum time step ({dt_min:?}) reached.\x1b[0;91m\n\
453                    In integrator: {integrator}."
454                )
455            }
456            Self::MinimumStepSizeUpstream(dt_min, error, integrator) => {
457                format!(
458                    "{error}\x1b[0;91m\n\
459                    Causing error: \x1b[1;91mMinimum time step ({dt_min:?}) reached.\x1b[0;91m\n\
460                    In integrator: {integrator}."
461                )
462            }
463            Self::Upstream(error, integrator) => {
464                format!(
465                    "{error}\x1b[0;91m\n\
466                    In integrator: {integrator}."
467                )
468            }
469        };
470        write!(f, "{error}\x1b[0m")
471    }
472}
473
474impl From<IntegrationError> for String {
475    fn from(error: IntegrationError) -> Self {
476        format!("{}", error)
477    }
478}
479
480impl From<IntegrationError> for TestError {
481    fn from(error: IntegrationError) -> Self {
482        TestError {
483            message: error.to_string(),
484        }
485    }
486}