conspire/math/integrate/bogacki_shampine/
mod.rs

1#[cfg(test)]
2mod test;
3
4use super::{
5    super::{Scalar, Tensor, TensorVec, Vector, interpolate::InterpolateSolution},
6    Explicit, IntegrationError,
7};
8use crate::{ABS_TOL, REL_TOL};
9use std::ops::{Mul, Sub};
10
11/// Explicit, three-stage, third-order, variable-step, Runge-Kutta method.[^cite]
12///
13/// [^cite]: P. Bogacki and L.F. Shampine, [Appl. Math. Lett. **2**, 321 (1989)](https://doi.org/10.1016/0893-9659(89)90079-7).
14///
15/// ```math
16/// \frac{dy}{dt} = f(t, y)
17/// ```
18/// ```math
19/// t_{n+1} = t_n + h
20/// ```
21/// ```math
22/// k_1 = f(t_n, y_n)
23/// ```
24/// ```math
25/// k_2 = f(t_n + \tfrac{1}{2} h, y_n + \tfrac{1}{2} h k_1)
26/// ```
27/// ```math
28/// k_3 = f(t_n + \tfrac{3}{4} h, y_n + \tfrac{3}{4} h k_2)
29/// ```
30/// ```math
31/// y_{n+1} = y_n + \frac{h}{9}\left(2k_1 + 3k_2 + 4k_3\right)
32/// ```
33/// ```math
34/// k_4 = f(t_{n+1}, y_{n+1})
35/// ```
36/// ```math
37/// e_{n+1} = \frac{h}{72}\left(-5k_1 + 6k_2 + 8k_3 - 9k_4\right)
38/// ```
39/// ```math
40/// h_{n+1} = \beta h \left(\frac{e_\mathrm{tol}}{e_{n+1}}\right)^{1/p}
41/// ```
42#[derive(Debug)]
43pub struct BogackiShampine {
44    /// Absolute error tolerance.
45    pub abs_tol: Scalar,
46    /// Relative error tolerance.
47    pub rel_tol: Scalar,
48    /// Multiplier for adaptive time steps.
49    pub dt_beta: Scalar,
50    /// Exponent for adaptive time steps.
51    pub dt_expn: Scalar,
52}
53
54impl Default for BogackiShampine {
55    fn default() -> Self {
56        Self {
57            abs_tol: ABS_TOL,
58            rel_tol: REL_TOL,
59            dt_beta: 0.9,
60            dt_expn: 3.0,
61        }
62    }
63}
64
65impl<Y, U> Explicit<Y, U> for BogackiShampine
66where
67    Self: InterpolateSolution<Y, U>,
68    Y: Tensor,
69    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
70    U: TensorVec<Item = Y>,
71{
72    const SLOPES: usize = 4;
73    fn slopes(
74        &self,
75        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
76        y: &Y,
77        t: &Scalar,
78        dt: &Scalar,
79        k: &mut [Y],
80        y_trial: &mut Y,
81    ) -> Result<Scalar, String> {
82        k[1] = function(t + 0.5 * dt, &(&k[0] * (0.5 * dt) + y))?;
83        k[2] = function(t + 0.75 * dt, &(&k[1] * (0.75 * dt) + y))?;
84        *y_trial = (&k[0] * 2.0 + &k[1] * 3.0 + &k[2] * 4.0) * (dt / 9.0) + y;
85        k[3] = function(t + dt, y_trial)?;
86        Ok(((&k[0] * -5.0 + &k[1] * 6.0 + &k[2] * 8.0 + &k[3] * -9.0) * (dt / 72.0)).norm_inf())
87    }
88    fn step(
89        &self,
90        _function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
91        y: &mut Y,
92        t: &mut Scalar,
93        y_sol: &mut U,
94        t_sol: &mut Vector,
95        dydt_sol: &mut U,
96        dt: &mut Scalar,
97        k: &mut [Y],
98        y_trial: &Y,
99        e: &Scalar,
100    ) -> Result<(), String> {
101        if e < &self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
102            k[0] = k[3].clone();
103            *t += *dt;
104            *y = y_trial.clone();
105            t_sol.push(*t);
106            y_sol.push(y.clone());
107            dydt_sol.push(k[0].clone());
108        }
109        if e > &0.0 {
110            *dt *= self.dt_beta * (self.abs_tol / e).powf(1.0 / self.dt_expn)
111        }
112        Ok(())
113    }
114}
115
116impl<Y, U> InterpolateSolution<Y, U> for BogackiShampine
117where
118    Y: Tensor,
119    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
120    U: TensorVec<Item = Y>,
121{
122    fn interpolate(
123        &self,
124        time: &Vector,
125        tp: &Vector,
126        yp: &U,
127        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
128    ) -> Result<(U, U), IntegrationError> {
129        let mut dt;
130        let mut i;
131        let mut k_1;
132        let mut k_2;
133        let mut k_3;
134        let mut t;
135        let mut y;
136        let mut y_int = U::zero(0);
137        let mut dydt_int = U::zero(0);
138        let mut y_trial;
139        for time_k in time.iter() {
140            i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
141            if time_k == &tp[i] {
142                t = tp[i];
143                y_trial = yp[i].clone();
144                dt = 0.0;
145            } else {
146                t = tp[i - 1];
147                y = yp[i - 1].clone();
148                dt = time_k - t;
149                k_1 = function(t, &y)?;
150                k_2 = function(t + 0.5 * dt, &(&k_1 * (0.5 * dt) + &y))?;
151                k_3 = function(t + 0.75 * dt, &(&k_2 * (0.75 * dt) + &y))?;
152                y_trial = (&k_1 * 2.0 + &k_2 * 3.0 + &k_3 * 4.0) * (dt / 9.0) + &y;
153            }
154            dydt_int.push(function(t + dt, &y_trial)?);
155            y_int.push(y_trial);
156        }
157        Ok((y_int, dydt_int))
158    }
159}