conspire/math/integrate/dormand_prince/
mod.rs

1#[cfg(test)]
2mod test;
3
4use super::{
5    super::{Tensor, TensorRank0, TensorVec, Vector, interpolate::InterpolateSolution},
6    Explicit, IntegrationError,
7};
8use crate::{ABS_TOL, REL_TOL};
9use std::ops::{Mul, Sub};
10
11const C_44_45: TensorRank0 = 44.0 / 45.0;
12const C_56_15: TensorRank0 = 56.0 / 15.0;
13const C_32_9: TensorRank0 = 32.0 / 9.0;
14const C_8_9: TensorRank0 = 8.0 / 9.0;
15const C_19372_6561: TensorRank0 = 19372.0 / 6561.0;
16const C_25360_2187: TensorRank0 = 25360.0 / 2187.0;
17const C_64448_6561: TensorRank0 = 64448.0 / 6561.0;
18const C_212_729: TensorRank0 = 212.0 / 729.0;
19const C_9017_3168: TensorRank0 = 9017.0 / 3168.0;
20const C_355_33: TensorRank0 = 355.0 / 33.0;
21const C_46732_5247: TensorRank0 = 46732.0 / 5247.0;
22const C_49_176: TensorRank0 = 49.0 / 176.0;
23const C_5103_18656: TensorRank0 = 5103.0 / 18656.0;
24const C_35_384: TensorRank0 = 35.0 / 384.0;
25const C_500_1113: TensorRank0 = 500.0 / 1113.0;
26const C_125_192: TensorRank0 = 125.0 / 192.0;
27const C_2187_6784: TensorRank0 = 2187.0 / 6784.0;
28const C_11_84: TensorRank0 = 11.0 / 84.0;
29const C_71_57600: TensorRank0 = 71.0 / 57600.0;
30const C_71_16695: TensorRank0 = 71.0 / 16695.0;
31const C_71_1920: TensorRank0 = 71.0 / 1920.0;
32const C_17253_339200: TensorRank0 = 17253.0 / 339200.0;
33const C_22_525: TensorRank0 = 22.0 / 525.0;
34
35/// Explicit, six-stage, fifth-order, variable-step, Runge-Kutta method.[^cite]
36///
37/// [^cite]: J.R. Dormand and P.J. Prince, [J. Comput. Appl. Math. **6**, 19 (1980)](https://doi.org/10.1016/0771-050X(80)90013-3).
38///
39/// ```math
40/// \frac{dy}{dt} = f(t, y)
41/// ```
42/// ```math
43/// t_{n+1} = t_n + h
44/// ```
45/// ```math
46/// k_1 = f(t_n, y_n)
47/// ```
48/// ```math
49/// k_2 = f(t_n + \tfrac{1}{5} h, y_n + \tfrac{1}{5} h k_1)
50/// ```
51/// ```math
52/// k_3 = f(t_n + \tfrac{3}{10} h, y_n + \tfrac{3}{40} h k_1 + \tfrac{9}{40} h k_2)
53/// ```
54/// ```math
55/// k_4 = f(t_n + \tfrac{4}{5} h, y_n + \tfrac{44}{45} h k_1 - \tfrac{56}{15} h k_2 + \tfrac{32}{9} h k_3)
56/// ```
57/// ```math
58/// k_5 = f(t_n + \tfrac{8}{9} h, y_n + \tfrac{19372}{6561} h k_1 - \tfrac{25360}{2187} h k_2 + \tfrac{64448}{6561} h k_3 - \tfrac{212}{729} h k_4)
59/// ```
60/// ```math
61/// k_6 = f(t_n + h, y_n + \tfrac{9017}{3168} h k_1 - \tfrac{355}{33} h k_2 - \tfrac{46732}{5247} h k_3 + \tfrac{49}{176} h k_4 - \tfrac{5103}{18656} h k_5)
62/// ```
63/// ```math
64/// y_{n+1} = y_n + h\left(\frac{35}{384}\,k_1 + \frac{500}{1113}\,k_3 + \frac{125}{192}\,k_4 - \frac{2187}{6784}\,k_5 + \frac{11}{84}\,k_6\right)
65/// ```
66/// ```math
67/// k_7 = f(t_{n+1}, y_{n+1})
68/// ```
69/// ```math
70/// e_{n+1} = \frac{h}{5}\left(\frac{71}{11520}\,k_1 - \frac{71}{3339}\,k_3 + \frac{71}{384}\,k_4 - \frac{17253}{67840}\,k_5 + \frac{22}{105}\,k_6 - \frac{1}{8}\,k_7\right)
71/// ```
72/// ```math
73/// h_{n+1} = \beta h \left(\frac{e_\mathrm{tol}}{e_{n+1}}\right)^{1/p}
74/// ```
75#[derive(Debug)]
76pub struct DormandPrince {
77    /// Absolute error tolerance.
78    pub abs_tol: TensorRank0,
79    /// Relative error tolerance.
80    pub rel_tol: TensorRank0,
81    /// Multiplier for adaptive time steps.
82    pub dt_beta: TensorRank0,
83    /// Exponent for adaptive time steps.
84    pub dt_expn: TensorRank0,
85}
86
87impl Default for DormandPrince {
88    fn default() -> Self {
89        Self {
90            abs_tol: ABS_TOL,
91            rel_tol: REL_TOL,
92            dt_beta: 0.9,
93            dt_expn: 5.0,
94        }
95    }
96}
97
98impl<Y, U> Explicit<Y, U> for DormandPrince
99where
100    Self: InterpolateSolution<Y, U>,
101    Y: Tensor,
102    for<'a> &'a Y: Mul<TensorRank0, Output = Y> + Sub<&'a Y, Output = Y>,
103    U: TensorVec<Item = Y>,
104{
105    fn integrate(
106        &self,
107        mut function: impl FnMut(TensorRank0, &Y) -> Result<Y, IntegrationError>,
108        time: &[TensorRank0],
109        initial_condition: Y,
110    ) -> Result<(Vector, U, U), IntegrationError> {
111        let t_0 = time[0];
112        let t_f = time[time.len() - 1];
113        if time.len() < 2 {
114            return Err(IntegrationError::LengthTimeLessThanTwo);
115        } else if t_0 >= t_f {
116            return Err(IntegrationError::InitialTimeNotLessThanFinalTime);
117        }
118        let mut t = t_0;
119        let mut dt = t_f;
120        let mut e;
121        let mut k_1 = function(t, &initial_condition)?;
122        let mut k_2;
123        let mut k_3;
124        let mut k_4;
125        let mut k_5;
126        let mut k_6;
127        let mut k_7;
128        let mut t_sol = Vector::zero(0);
129        t_sol.push(t_0);
130        let mut y = initial_condition.clone();
131        let mut y_sol = U::zero(0);
132        y_sol.push(initial_condition.clone());
133        let mut dydt_sol = U::zero(0);
134        dydt_sol.push(k_1.clone());
135        let mut y_trial;
136        while t < t_f {
137            k_2 = function(t + 0.2 * dt, &(&k_1 * (0.2 * dt) + &y))?;
138            k_3 = function(
139                t + 0.3 * dt,
140                &(&k_1 * (0.075 * dt) + &k_2 * (0.225 * dt) + &y),
141            )?;
142            k_4 = function(
143                t + 0.8 * dt,
144                &(&k_1 * (C_44_45 * dt) - &k_2 * (C_56_15 * dt) + &k_3 * (C_32_9 * dt) + &y),
145            )?;
146            k_5 = function(
147                t + C_8_9 * dt,
148                &(&k_1 * (C_19372_6561 * dt) - &k_2 * (C_25360_2187 * dt)
149                    + &k_3 * (C_64448_6561 * dt)
150                    - &k_4 * (C_212_729 * dt)
151                    + &y),
152            )?;
153            k_6 = function(
154                t + dt,
155                &(&k_1 * (C_9017_3168 * dt) - &k_2 * (C_355_33 * dt)
156                    + &k_3 * (C_46732_5247 * dt)
157                    + &k_4 * (C_49_176 * dt)
158                    - &k_5 * (C_5103_18656 * dt)
159                    + &y),
160            )?;
161            y_trial = (&k_1 * C_35_384 + &k_3 * C_500_1113 + &k_4 * C_125_192 - &k_5 * C_2187_6784
162                + &k_6 * C_11_84)
163                * dt
164                + &y;
165            k_7 = function(t + dt, &y_trial)?;
166            e = ((&k_1 * C_71_57600 - k_3 * C_71_16695 + k_4 * C_71_1920 - k_5 * C_17253_339200
167                + k_6 * C_22_525
168                - &k_7 * 0.025)
169                * dt)
170                .norm_inf();
171            if e < self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
172                k_1 = k_7;
173                t += dt;
174                y = y_trial;
175                t_sol.push(t);
176                y_sol.push(y.clone());
177                dydt_sol.push(k_1.clone());
178            }
179            if e > 0.0 {
180                dt *= self.dt_beta * (self.abs_tol / e).powf(1.0 / self.dt_expn)
181            }
182            dt = dt.min(t_f - t)
183        }
184        if time.len() > 2 {
185            let t_int = Vector::new(time);
186            let (y_int, dydt_int) = self.interpolate(&t_int, &t_sol, &y_sol, function)?;
187            Ok((t_int, y_int, dydt_int))
188        } else {
189            Ok((t_sol, y_sol, dydt_sol))
190        }
191    }
192}
193
194impl<Y, U> InterpolateSolution<Y, U> for DormandPrince
195where
196    Y: Tensor,
197    for<'a> &'a Y: Mul<TensorRank0, Output = Y> + Sub<&'a Y, Output = Y>,
198    U: TensorVec<Item = Y>,
199{
200    fn interpolate(
201        &self,
202        time: &Vector,
203        tp: &Vector,
204        yp: &U,
205        mut function: impl FnMut(TensorRank0, &Y) -> Result<Y, IntegrationError>,
206    ) -> Result<(U, U), IntegrationError> {
207        let mut dt;
208        let mut i;
209        let mut k_1;
210        let mut k_2;
211        let mut k_3;
212        let mut k_4;
213        let mut k_5;
214        let mut k_6;
215        let mut t;
216        let mut y;
217        let mut y_int = U::zero(0);
218        let mut dydt_int = U::zero(0);
219        let mut y_trial;
220        for time_k in time.iter() {
221            i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
222            if time_k == &tp[i] {
223                t = tp[i];
224                y_trial = yp[i].clone();
225                dt = 0.0;
226            } else {
227                t = tp[i - 1];
228                y = yp[i - 1].clone();
229                dt = time_k - t;
230                k_1 = function(t, &y)?;
231                k_2 = function(t + 0.2 * dt, &(&k_1 * (0.2 * dt) + &y))?;
232                k_3 = function(
233                    t + 0.3 * dt,
234                    &(&k_1 * (0.075 * dt) + &k_2 * (0.225 * dt) + &y),
235                )?;
236                k_4 = function(
237                    t + 0.8 * dt,
238                    &(&k_1 * (C_44_45 * dt) - &k_2 * (C_56_15 * dt) + &k_3 * (C_32_9 * dt) + &y),
239                )?;
240                k_5 = function(
241                    t + C_8_9 * dt,
242                    &(&k_1 * (C_19372_6561 * dt) - &k_2 * (C_25360_2187 * dt)
243                        + &k_3 * (C_64448_6561 * dt)
244                        - &k_4 * (C_212_729 * dt)
245                        + &y),
246                )?;
247                k_6 = function(
248                    t + dt,
249                    &(&k_1 * (C_9017_3168 * dt) - &k_2 * (C_355_33 * dt)
250                        + &k_3 * (C_46732_5247 * dt)
251                        + &k_4 * (C_49_176 * dt)
252                        - &k_5 * (C_5103_18656 * dt)
253                        + &y),
254                )?;
255                y_trial = (&k_1 * C_35_384 + &k_3 * C_500_1113 + &k_4 * C_125_192
256                    - &k_5 * C_2187_6784
257                    + &k_6 * C_11_84)
258                    * dt
259                    + &y;
260            }
261            dydt_int.push(function(t + dt, &y_trial)?);
262            y_int.push(y_trial);
263        }
264        Ok((y_int, dydt_int))
265    }
266}