conspire/math/integrate/ode/explicit/variable_step/dormand_prince/mod.rs

1#[cfg(test)]
2mod test;
3
4use crate::math::{
5    Scalar, Tensor, TensorVec, Vector,
6    integrate::{Explicit, IntegrationError, OdeSolver, VariableStep, VariableStepExplicit},
7    interpolate::InterpolateSolution,
8};
9use crate::{ABS_TOL, REL_TOL};
10use std::ops::{Mul, Sub};
11
12const C_44_45: Scalar = 44.0 / 45.0;
13const C_56_15: Scalar = 56.0 / 15.0;
14const C_32_9: Scalar = 32.0 / 9.0;
15const C_8_9: Scalar = 8.0 / 9.0;
16const C_19372_6561: Scalar = 19372.0 / 6561.0;
17const C_25360_2187: Scalar = 25360.0 / 2187.0;
18const C_64448_6561: Scalar = 64448.0 / 6561.0;
19const C_212_729: Scalar = 212.0 / 729.0;
20const C_9017_3168: Scalar = 9017.0 / 3168.0;
21const C_355_33: Scalar = 355.0 / 33.0;
22const C_46732_5247: Scalar = 46732.0 / 5247.0;
23const C_49_176: Scalar = 49.0 / 176.0;
24const C_5103_18656: Scalar = 5103.0 / 18656.0;
25const C_35_384: Scalar = 35.0 / 384.0;
26const C_500_1113: Scalar = 500.0 / 1113.0;
27const C_125_192: Scalar = 125.0 / 192.0;
28const C_2187_6784: Scalar = 2187.0 / 6784.0;
29const C_11_84: Scalar = 11.0 / 84.0;
30const C_71_57600: Scalar = 71.0 / 57600.0;
31const C_71_16695: Scalar = 71.0 / 16695.0;
32const C_71_1920: Scalar = 71.0 / 1920.0;
33const C_17253_339200: Scalar = 17253.0 / 339200.0;
34const C_22_525: Scalar = 22.0 / 525.0;
35
36#[doc = include_str!("doc.md")]
37#[derive(Debug)]
38pub struct DormandPrince {
39    /// Absolute error tolerance.
40    pub abs_tol: Scalar,
41    /// Relative error tolerance.
42    pub rel_tol: Scalar,
43    /// Multiplier for adaptive time steps.
44    pub dt_beta: Scalar,
45    /// Exponent for adaptive time steps.
46    pub dt_expn: Scalar,
47    /// Cut back factor for the time step.
48    pub dt_cut: Scalar,
49    /// Minimum value for the time step.
50    pub dt_min: Scalar,
51}
52
53impl Default for DormandPrince {
54    fn default() -> Self {
55        Self {
56            abs_tol: ABS_TOL,
57            rel_tol: REL_TOL,
58            dt_beta: 0.9,
59            dt_expn: 5.0,
60            dt_cut: 0.5,
61            dt_min: ABS_TOL,
62        }
63    }
64}
65
66impl<Y, U> OdeSolver<Y, U> for DormandPrince
67where
68    Y: Tensor,
69    U: TensorVec<Item = Y>,
70{
71}
72
73impl VariableStep for DormandPrince {
74    fn abs_tol(&self) -> Scalar {
75        self.abs_tol
76    }
77    fn rel_tol(&self) -> Scalar {
78        self.rel_tol
79    }
80    fn dt_beta(&self) -> Scalar {
81        self.dt_beta
82    }
83    fn dt_expn(&self) -> Scalar {
84        self.dt_expn
85    }
86    fn dt_cut(&self) -> Scalar {
87        self.dt_cut
88    }
89    fn dt_min(&self) -> Scalar {
90        self.dt_min
91    }
92}
93
94impl<Y, U> Explicit<Y, U> for DormandPrince
95where
96    Self: OdeSolver<Y, U>,
97    Y: Tensor,
98    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
99    U: TensorVec<Item = Y>,
100{
101    const SLOPES: usize = 7;
102    fn integrate(
103        &self,
104        function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
105        time: &[Scalar],
106        initial_condition: Y,
107    ) -> Result<(Vector, U, U), IntegrationError> {
108        self.integrate_variable_step(function, time, initial_condition)
109    }
110}
111
112pub fn slopes<Y>(
113    mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
114    y: &Y,
115    t: Scalar,
116    dt: Scalar,
117    k: &mut [Y],
118    y_trial: &mut Y,
119) -> Result<(), String>
120where
121    Y: Tensor,
122    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
123{
124    *y_trial = &k[0] * (0.2 * dt) + y;
125    k[1] = function(t + 0.2 * dt, y_trial)?;
126    *y_trial = &k[0] * (0.075 * dt) + &k[1] * (0.225 * dt) + y;
127    k[2] = function(t + 0.3 * dt, y_trial)?;
128    *y_trial = &k[0] * (C_44_45 * dt) - &k[1] * (C_56_15 * dt) + &k[2] * (C_32_9 * dt) + y;
129    k[3] = function(t + 0.8 * dt, y_trial)?;
130    *y_trial = &k[0] * (C_19372_6561 * dt) - &k[1] * (C_25360_2187 * dt)
131        + &k[2] * (C_64448_6561 * dt)
132        - &k[3] * (C_212_729 * dt)
133        + y;
134    k[4] = function(t + C_8_9 * dt, y_trial)?;
135    *y_trial = &k[0] * (C_9017_3168 * dt) - &k[1] * (C_355_33 * dt)
136        + &k[2] * (C_46732_5247 * dt)
137        + &k[3] * (C_49_176 * dt)
138        - &k[4] * (C_5103_18656 * dt)
139        + y;
140    k[5] = function(t + dt, y_trial)?;
141    *y_trial = (&k[0] * C_35_384 + &k[2] * C_500_1113 + &k[3] * C_125_192 - &k[4] * C_2187_6784
142        + &k[5] * C_11_84)
143        * dt
144        + y;
145    Ok(())
146}
147
148impl<Y, U> VariableStepExplicit<Y, U> for DormandPrince
149where
150    Self: OdeSolver<Y, U>,
151    Y: Tensor,
152    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
153    U: TensorVec<Item = Y>,
154{
155    fn slopes(
156        &self,
157        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
158        y: &Y,
159        t: Scalar,
160        dt: Scalar,
161        k: &mut [Y],
162        y_trial: &mut Y,
163    ) -> Result<Scalar, String> {
164        slopes(&mut function, y, t, dt, k, y_trial)?;
165        k[6] = function(t + dt, y_trial)?;
166        Ok(
167            ((&k[0] * C_71_57600 - &k[2] * C_71_16695 + &k[3] * C_71_1920
168                - &k[4] * C_17253_339200
169                + &k[5] * C_22_525
170                - &k[6] * 0.025)
171                * dt)
172                .norm_inf(),
173        )
174    }
175    fn step(
176        &self,
177        _function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
178        y: &mut Y,
179        t: &mut Scalar,
180        y_sol: &mut U,
181        t_sol: &mut Vector,
182        dydt_sol: &mut U,
183        dt: &mut Scalar,
184        k: &mut [Y],
185        y_trial: &Y,
186        e: Scalar,
187    ) -> Result<(), String> {
188        if e < self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
189            k[0] = k[6].clone();
190            *t += *dt;
191            *y = y_trial.clone();
192            t_sol.push(*t);
193            y_sol.push(y.clone());
194            dydt_sol.push(k[0].clone());
195        }
196        self.time_step(e, dt);
197        Ok(())
198    }
199}
200
201impl<Y, U> InterpolateSolution<Y, U> for DormandPrince
202where
203    Y: Tensor,
204    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
205    U: TensorVec<Item = Y>,
206{
207    fn interpolate(
208        &self,
209        time: &Vector,
210        tp: &Vector,
211        yp: &U,
212        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
213    ) -> Result<(U, U), IntegrationError> {
214        let mut dt;
215        let mut i;
216        let mut k_1;
217        let mut k_2;
218        let mut k_3;
219        let mut k_4;
220        let mut k_5;
221        let mut k_6;
222        let mut t;
223        let mut y;
224        let mut y_int = U::new();
225        let mut dydt_int = U::new();
226        let mut y_trial;
227        for time_k in time.iter() {
228            i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
229            if time_k == &tp[i] {
230                t = tp[i];
231                y_trial = yp[i].clone();
232                dt = 0.0;
233            } else {
234                t = tp[i - 1];
235                y = yp[i - 1].clone();
236                dt = time_k - t;
237                k_1 = function(t, &y)?;
238                y_trial = &k_1 * (0.2 * dt) + &y;
239                k_2 = function(t + 0.2 * dt, &y_trial)?;
240                y_trial = &k_1 * (0.075 * dt) + &k_2 * (0.225 * dt) + &y;
241                k_3 = function(t + 0.3 * dt, &y_trial)?;
242                y_trial = &k_1 * (C_44_45 * dt) - &k_2 * (C_56_15 * dt) + &k_3 * (C_32_9 * dt) + &y;
243                k_4 = function(t + 0.8 * dt, &y_trial)?;
244                y_trial = &k_1 * (C_19372_6561 * dt) - &k_2 * (C_25360_2187 * dt)
245                    + &k_3 * (C_64448_6561 * dt)
246                    - &k_4 * (C_212_729 * dt)
247                    + &y;
248                k_5 = function(t + C_8_9 * dt, &y_trial)?;
249                y_trial = &k_1 * (C_9017_3168 * dt) - &k_2 * (C_355_33 * dt)
250                    + &k_3 * (C_46732_5247 * dt)
251                    + &k_4 * (C_49_176 * dt)
252                    - &k_5 * (C_5103_18656 * dt)
253                    + &y;
254                k_6 = function(t + dt, &y_trial)?;
255                y_trial = (&k_1 * C_35_384 + &k_3 * C_500_1113 + &k_4 * C_125_192
256                    - &k_5 * C_2187_6784
257                    + &k_6 * C_11_84)
258                    * dt
259                    + &y;
260            }
261            dydt_int.push(function(t + dt, &y_trial)?);
262            y_int.push(y_trial);
263        }
264        Ok((y_int, dydt_int))
265    }
266}