// conspire/math/integrate/dormand_prince/mod.rs

1#[cfg(test)]
2mod test;
3
4use super::{
5    super::{Scalar, Tensor, TensorVec, Vector, interpolate::InterpolateSolution},
6    Explicit, IntegrationError,
7};
8use crate::{ABS_TOL, REL_TOL};
9use std::ops::{Mul, Sub};
10
11const C_44_45: Scalar = 44.0 / 45.0;
12const C_56_15: Scalar = 56.0 / 15.0;
13const C_32_9: Scalar = 32.0 / 9.0;
14const C_8_9: Scalar = 8.0 / 9.0;
15const C_19372_6561: Scalar = 19372.0 / 6561.0;
16const C_25360_2187: Scalar = 25360.0 / 2187.0;
17const C_64448_6561: Scalar = 64448.0 / 6561.0;
18const C_212_729: Scalar = 212.0 / 729.0;
19const C_9017_3168: Scalar = 9017.0 / 3168.0;
20const C_355_33: Scalar = 355.0 / 33.0;
21const C_46732_5247: Scalar = 46732.0 / 5247.0;
22const C_49_176: Scalar = 49.0 / 176.0;
23const C_5103_18656: Scalar = 5103.0 / 18656.0;
24const C_35_384: Scalar = 35.0 / 384.0;
25const C_500_1113: Scalar = 500.0 / 1113.0;
26const C_125_192: Scalar = 125.0 / 192.0;
27const C_2187_6784: Scalar = 2187.0 / 6784.0;
28const C_11_84: Scalar = 11.0 / 84.0;
29const C_71_57600: Scalar = 71.0 / 57600.0;
30const C_71_16695: Scalar = 71.0 / 16695.0;
31const C_71_1920: Scalar = 71.0 / 1920.0;
32const C_17253_339200: Scalar = 17253.0 / 339200.0;
33const C_22_525: Scalar = 22.0 / 525.0;
34
35/// Explicit, six-stage, fifth-order, variable-step, Runge-Kutta method.[^cite]
36///
37/// [^cite]: J.R. Dormand and P.J. Prince, [J. Comput. Appl. Math. **6**, 19 (1980)](https://doi.org/10.1016/0771-050X(80)90013-3).
38///
39/// ```math
40/// \frac{dy}{dt} = f(t, y)
41/// ```
42/// ```math
43/// t_{n+1} = t_n + h
44/// ```
45/// ```math
46/// k_1 = f(t_n, y_n)
47/// ```
48/// ```math
49/// k_2 = f(t_n + \tfrac{1}{5} h, y_n + \tfrac{1}{5} h k_1)
50/// ```
51/// ```math
52/// k_3 = f(t_n + \tfrac{3}{10} h, y_n + \tfrac{3}{40} h k_1 + \tfrac{9}{40} h k_2)
53/// ```
54/// ```math
55/// k_4 = f(t_n + \tfrac{4}{5} h, y_n + \tfrac{44}{45} h k_1 - \tfrac{56}{15} h k_2 + \tfrac{32}{9} h k_3)
56/// ```
57/// ```math
58/// k_5 = f(t_n + \tfrac{8}{9} h, y_n + \tfrac{19372}{6561} h k_1 - \tfrac{25360}{2187} h k_2 + \tfrac{64448}{6561} h k_3 - \tfrac{212}{729} h k_4)
59/// ```
60/// ```math
61/// k_6 = f(t_n + h, y_n + \tfrac{9017}{3168} h k_1 - \tfrac{355}{33} h k_2 - \tfrac{46732}{5247} h k_3 + \tfrac{49}{176} h k_4 - \tfrac{5103}{18656} h k_5)
62/// ```
63/// ```math
64/// y_{n+1} = y_n + h\left(\frac{35}{384}\,k_1 + \frac{500}{1113}\,k_3 + \frac{125}{192}\,k_4 - \frac{2187}{6784}\,k_5 + \frac{11}{84}\,k_6\right)
65/// ```
66/// ```math
67/// k_7 = f(t_{n+1}, y_{n+1})
68/// ```
69/// ```math
70/// e_{n+1} = \frac{h}{5}\left(\frac{71}{11520}\,k_1 - \frac{71}{3339}\,k_3 + \frac{71}{384}\,k_4 - \frac{17253}{67840}\,k_5 + \frac{22}{105}\,k_6 - \frac{1}{8}\,k_7\right)
71/// ```
72/// ```math
73/// h_{n+1} = \beta h \left(\frac{e_\mathrm{tol}}{e_{n+1}}\right)^{1/p}
74/// ```
75#[derive(Debug)]
76pub struct DormandPrince {
77    /// Absolute error tolerance.
78    pub abs_tol: Scalar,
79    /// Relative error tolerance.
80    pub rel_tol: Scalar,
81    /// Multiplier for adaptive time steps.
82    pub dt_beta: Scalar,
83    /// Exponent for adaptive time steps.
84    pub dt_expn: Scalar,
85}
86
87impl Default for DormandPrince {
88    fn default() -> Self {
89        Self {
90            abs_tol: ABS_TOL,
91            rel_tol: REL_TOL,
92            dt_beta: 0.9,
93            dt_expn: 5.0,
94        }
95    }
96}
97
98impl<Y, U> Explicit<Y, U> for DormandPrince
99where
100    Self: InterpolateSolution<Y, U>,
101    Y: Tensor,
102    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
103    U: TensorVec<Item = Y>,
104{
105    const SLOPES: usize = 7;
106    fn slopes(
107        &self,
108        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
109        y: &Y,
110        t: &Scalar,
111        dt: &Scalar,
112        k: &mut [Y],
113        y_trial: &mut Y,
114    ) -> Result<Scalar, String> {
115        k[1] = function(t + 0.2 * dt, &(&k[0] * (0.2 * dt) + y))?;
116        k[2] = function(
117            t + 0.3 * dt,
118            &(&k[0] * (0.075 * dt) + &k[1] * (0.225 * dt) + y),
119        )?;
120        k[3] = function(
121            t + 0.8 * dt,
122            &(&k[0] * (C_44_45 * dt) - &k[1] * (C_56_15 * dt) + &k[2] * (C_32_9 * dt) + y),
123        )?;
124        k[4] = function(
125            t + C_8_9 * dt,
126            &(&k[0] * (C_19372_6561 * dt) - &k[1] * (C_25360_2187 * dt)
127                + &k[2] * (C_64448_6561 * dt)
128                - &k[3] * (C_212_729 * dt)
129                + y),
130        )?;
131        k[5] = function(
132            t + dt,
133            &(&k[0] * (C_9017_3168 * dt) - &k[1] * (C_355_33 * dt)
134                + &k[2] * (C_46732_5247 * dt)
135                + &k[3] * (C_49_176 * dt)
136                - &k[4] * (C_5103_18656 * dt)
137                + y),
138        )?;
139        *y_trial = (&k[0] * C_35_384 + &k[2] * C_500_1113 + &k[3] * C_125_192
140            - &k[4] * C_2187_6784
141            + &k[5] * C_11_84)
142            * *dt
143            + y;
144        k[6] = function(t + dt, y_trial)?;
145        Ok(
146            ((&k[0] * C_71_57600 - &k[2] * C_71_16695 + &k[3] * C_71_1920
147                - &k[4] * C_17253_339200
148                + &k[5] * C_22_525
149                - &k[6] * 0.025)
150                * *dt)
151                .norm_inf(),
152        )
153    }
154    fn step(
155        &self,
156        _function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
157        y: &mut Y,
158        t: &mut Scalar,
159        y_sol: &mut U,
160        t_sol: &mut Vector,
161        dydt_sol: &mut U,
162        dt: &mut Scalar,
163        k: &mut [Y],
164        y_trial: &Y,
165        e: &Scalar,
166    ) -> Result<(), String> {
167        if e < &self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
168            k[0] = k[6].clone();
169            *t += *dt;
170            *y = y_trial.clone();
171            t_sol.push(*t);
172            y_sol.push(y.clone());
173            dydt_sol.push(k[0].clone());
174        }
175        if e > &0.0 {
176            *dt *= self.dt_beta * (self.abs_tol / e).powf(1.0 / self.dt_expn)
177        }
178        Ok(())
179    }
180}
181
182impl<Y, U> InterpolateSolution<Y, U> for DormandPrince
183where
184    Y: Tensor,
185    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
186    U: TensorVec<Item = Y>,
187{
188    fn interpolate(
189        &self,
190        time: &Vector,
191        tp: &Vector,
192        yp: &U,
193        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
194    ) -> Result<(U, U), IntegrationError> {
195        let mut dt;
196        let mut i;
197        let mut k_1;
198        let mut k_2;
199        let mut k_3;
200        let mut k_4;
201        let mut k_5;
202        let mut k_6;
203        let mut t;
204        let mut y;
205        let mut y_int = U::new();
206        let mut dydt_int = U::new();
207        let mut y_trial;
208        for time_k in time.iter() {
209            i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
210            if time_k == &tp[i] {
211                t = tp[i];
212                y_trial = yp[i].clone();
213                dt = 0.0;
214            } else {
215                t = tp[i - 1];
216                y = yp[i - 1].clone();
217                dt = time_k - t;
218                k_1 = function(t, &y)?;
219                k_2 = function(t + 0.2 * dt, &(&k_1 * (0.2 * dt) + &y))?;
220                k_3 = function(
221                    t + 0.3 * dt,
222                    &(&k_1 * (0.075 * dt) + &k_2 * (0.225 * dt) + &y),
223                )?;
224                k_4 = function(
225                    t + 0.8 * dt,
226                    &(&k_1 * (C_44_45 * dt) - &k_2 * (C_56_15 * dt) + &k_3 * (C_32_9 * dt) + &y),
227                )?;
228                k_5 = function(
229                    t + C_8_9 * dt,
230                    &(&k_1 * (C_19372_6561 * dt) - &k_2 * (C_25360_2187 * dt)
231                        + &k_3 * (C_64448_6561 * dt)
232                        - &k_4 * (C_212_729 * dt)
233                        + &y),
234                )?;
235                k_6 = function(
236                    t + dt,
237                    &(&k_1 * (C_9017_3168 * dt) - &k_2 * (C_355_33 * dt)
238                        + &k_3 * (C_46732_5247 * dt)
239                        + &k_4 * (C_49_176 * dt)
240                        - &k_5 * (C_5103_18656 * dt)
241                        + &y),
242                )?;
243                y_trial = (&k_1 * C_35_384 + &k_3 * C_500_1113 + &k_4 * C_125_192
244                    - &k_5 * C_2187_6784
245                    + &k_6 * C_11_84)
246                    * dt
247                    + &y;
248            }
249            dydt_int.push(function(t + dt, &y_trial)?);
250            y_int.push(y_trial);
251        }
252        Ok((y_int, dydt_int))
253    }
254}