conspire/math/integrate/ode/explicit/variable_step/bogacki_shampine/mod.rs

1#[cfg(test)]
2mod test;
3
4use crate::math::{
5    Scalar, Tensor, TensorVec, Vector,
6    integrate::{
7        Explicit, ExplicitInternalVariables, IntegrationError, OdeSolver, VariableStep,
8        VariableStepExplicit, VariableStepExplicitInternalVariables,
9    },
10    interpolate::{InterpolateSolution, InterpolateSolutionInternalVariables},
11};
12use crate::{ABS_TOL, REL_TOL};
13use std::ops::{Mul, Sub};
14
15#[doc = include_str!("doc.md")]
16#[derive(Debug)]
17pub struct BogackiShampine {
18    /// Absolute error tolerance.
19    pub abs_tol: Scalar,
20    /// Relative error tolerance.
21    pub rel_tol: Scalar,
22    /// Multiplier for adaptive time steps.
23    pub dt_beta: Scalar,
24    /// Exponent for adaptive time steps.
25    pub dt_expn: Scalar,
26    /// Cut back factor for the time step.
27    pub dt_cut: Scalar,
28    /// Minimum value for the time step.
29    pub dt_min: Scalar,
30}
31
32impl Default for BogackiShampine {
33    fn default() -> Self {
34        Self {
35            abs_tol: ABS_TOL,
36            rel_tol: REL_TOL,
37            dt_beta: 0.9,
38            dt_expn: 3.0,
39            dt_cut: 0.5,
40            dt_min: ABS_TOL,
41        }
42    }
43}
44
45impl<Y, U> OdeSolver<Y, U> for BogackiShampine
46where
47    Y: Tensor,
48    U: TensorVec<Item = Y>,
49{
50}
51
52impl VariableStep for BogackiShampine {
53    fn abs_tol(&self) -> Scalar {
54        self.abs_tol
55    }
56    fn rel_tol(&self) -> Scalar {
57        self.rel_tol
58    }
59    fn dt_beta(&self) -> Scalar {
60        self.dt_beta
61    }
62    fn dt_expn(&self) -> Scalar {
63        self.dt_expn
64    }
65    fn dt_cut(&self) -> Scalar {
66        self.dt_cut
67    }
68    fn dt_min(&self) -> Scalar {
69        self.dt_min
70    }
71}
72
73impl<Y, U> Explicit<Y, U> for BogackiShampine
74where
75    Self: OdeSolver<Y, U>,
76    Y: Tensor,
77    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
78    U: TensorVec<Item = Y>,
79{
80    const SLOPES: usize = 4;
81    fn integrate(
82        &self,
83        function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
84        time: &[Scalar],
85        initial_condition: Y,
86    ) -> Result<(Vector, U, U), IntegrationError> {
87        self.integrate_variable_step(function, time, initial_condition)
88    }
89}
90
91pub fn slopes<Y>(
92    mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
93    y: &Y,
94    t: Scalar,
95    dt: Scalar,
96    k: &mut [Y],
97    y_trial: &mut Y,
98) -> Result<(), String>
99where
100    Y: Tensor,
101    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
102{
103    *y_trial = &k[0] * (0.5 * dt) + y;
104    k[1] = function(t + 0.5 * dt, y_trial)?;
105    *y_trial = &k[1] * (0.75 * dt) + y;
106    k[2] = function(t + 0.75 * dt, y_trial)?;
107    *y_trial = (&k[0] * 2.0 + &k[1] * 3.0 + &k[2] * 4.0) * (dt / 9.0) + y;
108    Ok(())
109}
110
111impl<Y, U> VariableStepExplicit<Y, U> for BogackiShampine
112where
113    Self: OdeSolver<Y, U>,
114    Y: Tensor,
115    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
116    U: TensorVec<Item = Y>,
117{
118    fn slopes(
119        &self,
120        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
121        y: &Y,
122        t: Scalar,
123        dt: Scalar,
124        k: &mut [Y],
125        y_trial: &mut Y,
126    ) -> Result<Scalar, String> {
127        slopes(&mut function, y, t, dt, k, y_trial)?;
128        k[3] = function(t + dt, y_trial)?;
129        Ok(((&k[0] * -5.0 + &k[1] * 6.0 + &k[2] * 8.0 + &k[3] * -9.0) * (dt / 72.0)).norm_inf())
130    }
131    fn step(
132        &self,
133        _function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
134        y: &mut Y,
135        t: &mut Scalar,
136        y_sol: &mut U,
137        t_sol: &mut Vector,
138        dydt_sol: &mut U,
139        dt: &mut Scalar,
140        k: &mut [Y],
141        y_trial: &Y,
142        e: Scalar,
143    ) -> Result<(), String> {
144        if e < self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
145            k[0] = k[3].clone();
146            *t += *dt;
147            *y = y_trial.clone();
148            t_sol.push(*t);
149            y_sol.push(y.clone());
150            dydt_sol.push(k[0].clone());
151        }
152        // self.time_step(e, dt); using below temporarily to pass test barely failing
153        if e > 0.0 {
154            *dt *= self.dt_beta() * (self.abs_tol() / e).powf(1.0 / self.dt_expn())
155        }
156        Ok(())
157    }
158}
159
160impl<Y, U> InterpolateSolution<Y, U> for BogackiShampine
161where
162    Y: Tensor,
163    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
164    U: TensorVec<Item = Y>,
165{
166    fn interpolate(
167        &self,
168        time: &Vector,
169        tp: &Vector,
170        yp: &U,
171        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
172    ) -> Result<(U, U), IntegrationError> {
173        let mut dt;
174        let mut i;
175        let mut k_1;
176        let mut k_2;
177        let mut k_3;
178        let mut t;
179        let mut y;
180        let mut y_int = U::new();
181        let mut dydt_int = U::new();
182        let mut y_trial;
183        for time_k in time.iter() {
184            i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
185            if time_k == &tp[i] {
186                t = tp[i];
187                y_trial = yp[i].clone();
188                dt = 0.0;
189            } else {
190                t = tp[i - 1];
191                y = yp[i - 1].clone();
192                dt = time_k - t;
193                k_1 = function(t, &y)?;
194                y_trial = &k_1 * (0.5 * dt) + &y;
195                k_2 = function(t + 0.5 * dt, &y_trial)?;
196                y_trial = &k_2 * (0.75 * dt) + &y;
197                k_3 = function(t + 0.75 * dt, &y_trial)?;
198                y_trial = (&k_1 * 2.0 + &k_2 * 3.0 + &k_3 * 4.0) * (dt / 9.0) + &y;
199            }
200            dydt_int.push(function(t + dt, &y_trial)?);
201            y_int.push(y_trial);
202        }
203        Ok((y_int, dydt_int))
204    }
205}
206
207impl<Y, Z, U, V> ExplicitInternalVariables<Y, Z, U, V> for BogackiShampine
208where
209    Self: OdeSolver<Y, U>,
210    Y: Tensor,
211    Z: Tensor,
212    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
213    U: TensorVec<Item = Y>,
214    V: TensorVec<Item = Z>,
215{
216    fn integrate_and_evaluate(
217        &self,
218        function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
219        evaluate: impl FnMut(Scalar, &Y, &Z) -> Result<Z, String>,
220        time: &[Scalar],
221        initial_condition: Y,
222        initial_evaluation: Z,
223    ) -> Result<(Vector, U, U, V), IntegrationError> {
224        self.integrate_and_evaluate_variable_step(
225            function,
226            evaluate,
227            time,
228            initial_condition,
229            initial_evaluation,
230        )
231    }
232}
233
234impl<Y, Z, U, V> VariableStepExplicitInternalVariables<Y, Z, U, V> for BogackiShampine
235where
236    Self: OdeSolver<Y, U>,
237    Y: Tensor,
238    Z: Tensor,
239    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
240    U: TensorVec<Item = Y>,
241    V: TensorVec<Item = Z>,
242{
243    fn slopes(
244        &self,
245        mut function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
246        mut evaluate: impl FnMut(Scalar, &Y, &Z) -> Result<Z, String>,
247        y: &Y,
248        z: &Z,
249        t: Scalar,
250        dt: Scalar,
251        k: &mut [Y],
252        y_trial: &mut Y,
253        z_trial: &mut Z,
254    ) -> Result<Scalar, String> {
255        *y_trial = &k[0] * (0.5 * dt) + y;
256        *z_trial = evaluate(t + 0.5 * dt, y_trial, z)?;
257        k[1] = function(t + 0.5 * dt, y_trial, z_trial)?;
258        *y_trial = &k[1] * (0.75 * dt) + y;
259        *z_trial = evaluate(t + 0.75 * dt, y_trial, z_trial)?;
260        k[2] = function(t + 0.75 * dt, y_trial, z_trial)?;
261        *y_trial = (&k[0] * 2.0 + &k[1] * 3.0 + &k[2] * 4.0) * (dt / 9.0) + y;
262        *z_trial = evaluate(t + dt, y_trial, z_trial)?;
263        k[3] = function(t + dt, y_trial, z_trial)?;
264        Ok(((&k[0] * -5.0 + &k[1] * 6.0 + &k[2] * 8.0 + &k[3] * -9.0) * (dt / 72.0)).norm_inf())
265    }
266    fn step(
267        &self,
268        _function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
269        y: &mut Y,
270        z: &mut Z,
271        t: &mut Scalar,
272        y_sol: &mut U,
273        z_sol: &mut V,
274        t_sol: &mut Vector,
275        dydt_sol: &mut U,
276        dt: &mut Scalar,
277        k: &mut [Y],
278        y_trial: &Y,
279        z_trial: &Z,
280        e: Scalar,
281    ) -> Result<(), String> {
282        if e < self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
283            k[0] = k[3].clone();
284            *t += *dt;
285            *y = y_trial.clone();
286            *z = z_trial.clone();
287            t_sol.push(*t);
288            y_sol.push(y.clone());
289            z_sol.push(z.clone());
290            dydt_sol.push(k[0].clone());
291        }
292        self.time_step(e, dt);
293        Ok(())
294    }
295}
296
297impl<Y, Z, U, V> InterpolateSolutionInternalVariables<Y, Z, U, V> for BogackiShampine
298where
299    Y: Tensor,
300    Z: Tensor,
301    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
302    U: TensorVec<Item = Y>,
303    V: TensorVec<Item = Z>,
304{
305    fn interpolate(
306        &self,
307        time: &Vector,
308        tp: &Vector,
309        yp: &U,
310        zp: &V,
311        mut function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
312        mut evaluate: impl FnMut(Scalar, &Y, &Z) -> Result<Z, String>,
313    ) -> Result<(U, U, V), IntegrationError> {
314        let mut dt;
315        let mut i;
316        let mut k_1;
317        let mut k_2;
318        let mut k_3;
319        let mut t;
320        let mut y;
321        let mut y_int = U::new();
322        let mut z_int = V::new();
323        let mut dydt_int = U::new();
324        let mut y_trial;
325        let mut z_trial;
326        for time_k in time.iter() {
327            i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
328            if time_k == &tp[i] {
329                t = tp[i];
330                y_trial = yp[i].clone();
331                z_trial = zp[i].clone();
332                dt = 0.0;
333            } else {
334                t = tp[i - 1];
335                y = yp[i - 1].clone();
336                z_trial = zp[i - 1].clone();
337                dt = time_k - t;
338                k_1 = function(t, &y, &z_trial)?;
339                y_trial = &k_1 * (0.5 * dt) + &y;
340                z_trial = evaluate(t + 0.5 * dt, &y_trial, &z_trial)?;
341                k_2 = function(t + 0.5 * dt, &y_trial, &z_trial)?;
342                y_trial = &k_2 * (0.75 * dt) + &y;
343                z_trial = evaluate(t + 0.75 * dt, &y_trial, &z_trial)?;
344                k_3 = function(t + 0.75 * dt, &y_trial, &z_trial)?;
345                y_trial = (&k_1 * 2.0 + &k_2 * 3.0 + &k_3 * 4.0) * (dt / 9.0) + &y;
346                z_trial = evaluate(t + dt, &y_trial, &z_trial)?;
347            }
348            dydt_int.push(function(t + dt, &y_trial, &z_trial)?);
349            y_int.push(y_trial);
350            z_int.push(z_trial);
351        }
352        Ok((y_int, dydt_int, z_int))
353    }
354}