conspire/math/integrate/ode/explicit/variable_step/verner_9/
mod.rs

1#[cfg(test)]
2mod test;
3
4use crate::math::{
5    Scalar, Tensor, TensorVec, Vector,
6    integrate::{Explicit, IntegrationError, OdeSolver, VariableStep, VariableStepExplicit},
7    interpolate::InterpolateSolution,
8};
9use crate::{ABS_TOL, REL_TOL};
10use std::ops::{Mul, Sub};
11
// Nodes c_i of the Butcher tableau: slope i is evaluated at time `t + C_i * dt`.
// The first slope uses c_1 = 0 and the last two slopes (15 and 16) are evaluated at
// `t + dt` (c = 1), so those nodes have no named constant. Each node equals, to
// rounding, the sum of the A_i_j couplings of its row.
const C_2: Scalar = 0.03462;
const C_3: Scalar = 0.097_024_350_638_780_44;
const C_4: Scalar = 0.145_536_525_958_170_67;
const C_5: Scalar = 0.561;
const C_6: Scalar = 0.229_007_911_590_485;
const C_7: Scalar = 0.544_992_088_409_515;
const C_8: Scalar = 0.645;
const C_9: Scalar = 0.48375;
const C_10: Scalar = 0.06757;
const C_11: Scalar = 0.2500;
const C_12: Scalar = 0.659_065_061_873_099_9;
const C_13: Scalar = 0.8206;
const C_14: Scalar = 0.9012;
25
// Stage couplings A_i_j: the state used to evaluate slope i is
// `y + sum_j k_j * (A_i_j * dt)`, where k_j is slope j (stored at index j - 1).
// Couplings that are not listed are zero. From row 8 onward only k_1 and k_6 and
// later slopes contribute. Rows 15 and 16 define the two slopes evaluated at `t + dt`.
const A_2_1: Scalar = 0.03462;
const A_3_1: Scalar = -0.03893354388572875;
const A_3_2: Scalar = 0.13595789452450918;
const A_4_1: Scalar = 0.03638413148954267;
const A_4_3: Scalar = 0.10915239446862801;
const A_5_1: Scalar = 2.0257639143939694;
const A_5_3: Scalar = -7.638023836496291;
const A_5_4: Scalar = 6.173259922102322;
const A_6_1: Scalar = 0.05112275589406061;
const A_6_4: Scalar = 0.17708237945550218;
const A_6_5: Scalar = 0.0008027762409222536;
const A_7_1: Scalar = 0.13160063579752163;
const A_7_4: Scalar = -0.2957276252669636;
const A_7_5: Scalar = 0.08781378035642955;
const A_7_6: Scalar = 0.6213052975225274;
const A_8_1: Scalar = 0.07166666666666667;
const A_8_6: Scalar = 0.33055335789153195;
const A_8_7: Scalar = 0.2427799754418014;
const A_9_1: Scalar = 0.071806640625;
const A_9_6: Scalar = 0.3294380283228177;
const A_9_7: Scalar = 0.1165190029271823;
const A_9_8: Scalar = -0.034013671875;
const A_10_1: Scalar = 0.04836757646340646;
const A_10_6: Scalar = 0.03928989925676164;
const A_10_7: Scalar = 0.10547409458903446;
const A_10_8: Scalar = -0.021438652846483126;
const A_10_9: Scalar = -0.10412291746271944;
const A_11_1: Scalar = -0.026645614872014785;
const A_11_6: Scalar = 0.03333333333333333;
const A_11_7: Scalar = -0.1631072244872467;
const A_11_8: Scalar = 0.03396081684127761;
const A_11_9: Scalar = 0.1572319413814626;
const A_11_10: Scalar = 0.21522674780318796;
const A_12_1: Scalar = 0.03689009248708622;
const A_12_6: Scalar = -0.1465181576725543;
const A_12_7: Scalar = 0.2242577768172024;
const A_12_8: Scalar = 0.02294405717066073;
const A_12_9: Scalar = -0.0035850052905728597;
const A_12_10: Scalar = 0.08669223316444385;
const A_12_11: Scalar = 0.43838406519683376;
const A_13_1: Scalar = -0.4866012215113341;
const A_13_6: Scalar = -6.304602650282853;
const A_13_7: Scalar = -0.2812456182894729;
const A_13_8: Scalar = -2.679019236219849;
const A_13_9: Scalar = 0.5188156639241577;
const A_13_10: Scalar = 1.3653531876033418;
const A_13_11: Scalar = 5.8850910885039465;
const A_13_12: Scalar = 2.8028087862720628;
const A_14_1: Scalar = 0.4185367457753472;
const A_14_6: Scalar = 6.724547581906459;
const A_14_7: Scalar = -0.42544428016461133;
const A_14_8: Scalar = 3.3432791530012653;
const A_14_9: Scalar = 0.6170816631175374;
const A_14_10: Scalar = -0.9299661239399329;
const A_14_11: Scalar = -6.099948804751011;
const A_14_12: Scalar = -3.002206187889399;
const A_14_13: Scalar = 0.2553202529443446;
const A_15_1: Scalar = -0.7793740861228848;
const A_15_6: Scalar = -13.937342538107776;
const A_15_7: Scalar = 1.2520488533793563;
const A_15_8: Scalar = -14.691500408016868;
const A_15_9: Scalar = -0.494705058533141;
const A_15_10: Scalar = 2.2429749091462368;
const A_15_11: Scalar = 13.367893803828643;
const A_15_12: Scalar = 14.396650486650687;
const A_15_13: Scalar = -0.79758133317768;
const A_15_14: Scalar = 0.4409353709534278;
const A_16_1: Scalar = 2.0580513374668867;
const A_16_6: Scalar = 22.357937727968032;
const A_16_7: Scalar = 0.9094981099755646;
const A_16_8: Scalar = 35.89110098240264;
const A_16_9: Scalar = -3.442515027624454;
const A_16_10: Scalar = -4.865481358036369;
const A_16_11: Scalar = -18.909803813543427;
const A_16_12: Scalar = -34.26354448030452;
const A_16_13: Scalar = 1.2647565216956427;
102
// Weights B_j of the propagated solution: `y_new = sum_j (k_j * B_j) * dt + y`.
// Weights that are not listed (j = 2..=7 and j = 16) are zero, so slope 16 only
// feeds the error estimate. The listed weights sum to one (to rounding).
const B_1: Scalar = 0.014611976858423152;
const B_8: Scalar = -0.3915211862331339;
const B_9: Scalar = 0.23109325002895065;
const B_10: Scalar = 0.12747667699928525;
const B_11: Scalar = 0.2246434176204158;
const B_12: Scalar = 0.5684352689748513;
const B_13: Scalar = 0.058258715572158275;
const B_14: Scalar = 0.13643174034822156;
const B_15: Scalar = 0.030570139830827976;
112
// Weights D_j of the local error estimate: the estimate is the infinity norm of
// `sum_j (k_j * D_j) * dt`. Weights that are not listed (j = 2..=7) are zero.
// The weights sum to zero (to rounding), and D_14/D_15 equal B_14/B_15. This is
// consistent with D being B minus the weights of an embedded lower-order solution
// (assumption; those embedded weights are not stored in this file).
const D_1: Scalar = -0.005357988290444578;
const D_8: Scalar = -2.583020491182464;
const D_9: Scalar = 0.14252253154686625;
const D_10: Scalar = 0.013420653512688676;
const D_11: Scalar = -0.02867296291409493;
const D_12: Scalar = 2.624999655215792;
const D_13: Scalar = -0.2825509643291537;
const D_14: Scalar = 0.13643174034822156;
const D_15: Scalar = 0.030570139830827976;
const D_16: Scalar = -0.04834231373823958;
123
124#[doc = include_str!("doc.md")]
125#[derive(Debug)]
126pub struct Verner9 {
127    /// Absolute error tolerance.
128    pub abs_tol: Scalar,
129    /// Relative error tolerance.
130    pub rel_tol: Scalar,
131    /// Multiplier for adaptive time steps.
132    pub dt_beta: Scalar,
133    /// Exponent for adaptive time steps.
134    pub dt_expn: Scalar,
135    /// Cut back factor for the time step.
136    pub dt_cut: Scalar,
137    /// Minimum value for the time step.
138    pub dt_min: Scalar,
139}
140
141impl Default for Verner9 {
142    fn default() -> Self {
143        Self {
144            abs_tol: ABS_TOL,
145            rel_tol: REL_TOL,
146            dt_beta: 0.9,
147            dt_expn: 9.0,
148            dt_cut: 0.5,
149            dt_min: ABS_TOL,
150        }
151    }
152}
153
154impl<Y, U> OdeSolver<Y, U> for Verner9
155where
156    Y: Tensor,
157    U: TensorVec<Item = Y>,
158{
159}
160
161impl VariableStep for Verner9 {
162    fn abs_tol(&self) -> Scalar {
163        self.abs_tol
164    }
165    fn rel_tol(&self) -> Scalar {
166        self.rel_tol
167    }
168    fn dt_beta(&self) -> Scalar {
169        self.dt_beta
170    }
171    fn dt_expn(&self) -> Scalar {
172        self.dt_expn
173    }
174    fn dt_cut(&self) -> Scalar {
175        self.dt_cut
176    }
177    fn dt_min(&self) -> Scalar {
178        self.dt_min
179    }
180}
181
182impl<Y, U> Explicit<Y, U> for Verner9
183where
184    Self: OdeSolver<Y, U>,
185    Y: Tensor,
186    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
187    U: TensorVec<Item = Y>,
188{
189    const SLOPES: usize = 16;
190    fn integrate(
191        &self,
192        function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
193        time: &[Scalar],
194        initial_condition: Y,
195    ) -> Result<(Vector, U, U), IntegrationError> {
196        self.integrate_variable_step(function, time, initial_condition)
197    }
198}
199
200impl<Y, U> VariableStepExplicit<Y, U> for Verner9
201where
202    Self: OdeSolver<Y, U>,
203    Y: Tensor,
204    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
205    U: TensorVec<Item = Y>,
206{
207    fn slopes(
208        &self,
209        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
210        y: &Y,
211        t: Scalar,
212        dt: Scalar,
213        k: &mut [Y],
214        y_trial: &mut Y,
215    ) -> Result<Scalar, String> {
216        k[0] = function(t, y)?;
217        *y_trial = &k[0] * (A_2_1 * dt) + y;
218        k[1] = function(t + C_2 * dt, y_trial)?;
219        *y_trial = &k[0] * (A_3_1 * dt) + &k[1] * (A_3_2 * dt) + y;
220        k[2] = function(t + C_3 * dt, y_trial)?;
221        *y_trial = &k[0] * (A_4_1 * dt) + &k[2] * (A_4_3 * dt) + y;
222        k[3] = function(t + C_4 * dt, y_trial)?;
223        *y_trial = &k[0] * (A_5_1 * dt) + &k[2] * (A_5_3 * dt) + &k[3] * (A_5_4 * dt) + y;
224        k[4] = function(t + C_5 * dt, y_trial)?;
225        *y_trial = &k[0] * (A_6_1 * dt) + &k[3] * (A_6_4 * dt) + &k[4] * (A_6_5 * dt) + y;
226        k[5] = function(t + C_6 * dt, y_trial)?;
227        *y_trial = &k[0] * (A_7_1 * dt)
228            + &k[3] * (A_7_4 * dt)
229            + &k[4] * (A_7_5 * dt)
230            + &k[5] * (A_7_6 * dt)
231            + y;
232        k[6] = function(t + C_7 * dt, y_trial)?;
233        *y_trial = &k[0] * (A_8_1 * dt) + &k[5] * (A_8_6 * dt) + &k[6] * (A_8_7 * dt) + y;
234        k[7] = function(t + C_8 * dt, y_trial)?;
235        *y_trial = &k[0] * (A_9_1 * dt)
236            + &k[5] * (A_9_6 * dt)
237            + &k[6] * (A_9_7 * dt)
238            + &k[7] * (A_9_8 * dt)
239            + y;
240        k[8] = function(t + C_9 * dt, y_trial)?;
241        *y_trial = &k[0] * (A_10_1 * dt)
242            + &k[5] * (A_10_6 * dt)
243            + &k[6] * (A_10_7 * dt)
244            + &k[7] * (A_10_8 * dt)
245            + &k[8] * (A_10_9 * dt)
246            + y;
247        k[9] = function(t + C_10 * dt, y_trial)?;
248        *y_trial = &k[0] * (A_11_1 * dt)
249            + &k[5] * (A_11_6 * dt)
250            + &k[6] * (A_11_7 * dt)
251            + &k[7] * (A_11_8 * dt)
252            + &k[8] * (A_11_9 * dt)
253            + &k[9] * (A_11_10 * dt)
254            + y;
255        k[10] = function(t + C_11 * dt, y_trial)?;
256        *y_trial = &k[0] * (A_12_1 * dt)
257            + &k[5] * (A_12_6 * dt)
258            + &k[6] * (A_12_7 * dt)
259            + &k[7] * (A_12_8 * dt)
260            + &k[8] * (A_12_9 * dt)
261            + &k[9] * (A_12_10 * dt)
262            + &k[10] * (A_12_11 * dt)
263            + y;
264        k[11] = function(t + C_12 * dt, y_trial)?;
265        *y_trial = &k[0] * (A_13_1 * dt)
266            + &k[5] * (A_13_6 * dt)
267            + &k[6] * (A_13_7 * dt)
268            + &k[7] * (A_13_8 * dt)
269            + &k[8] * (A_13_9 * dt)
270            + &k[9] * (A_13_10 * dt)
271            + &k[10] * (A_13_11 * dt)
272            + &k[11] * (A_13_12 * dt)
273            + y;
274        k[12] = function(t + C_13 * dt, y_trial)?;
275        *y_trial = &k[0] * (A_14_1 * dt)
276            + &k[5] * (A_14_6 * dt)
277            + &k[6] * (A_14_7 * dt)
278            + &k[7] * (A_14_8 * dt)
279            + &k[8] * (A_14_9 * dt)
280            + &k[9] * (A_14_10 * dt)
281            + &k[10] * (A_14_11 * dt)
282            + &k[11] * (A_14_12 * dt)
283            + &k[12] * (A_14_13 * dt)
284            + y;
285        k[13] = function(t + C_14 * dt, y_trial)?;
286        *y_trial = &k[0] * (A_15_1 * dt)
287            + &k[5] * (A_15_6 * dt)
288            + &k[6] * (A_15_7 * dt)
289            + &k[7] * (A_15_8 * dt)
290            + &k[8] * (A_15_9 * dt)
291            + &k[9] * (A_15_10 * dt)
292            + &k[10] * (A_15_11 * dt)
293            + &k[11] * (A_15_12 * dt)
294            + &k[12] * (A_15_13 * dt)
295            + &k[13] * (A_15_14 * dt)
296            + y;
297        k[14] = function(t + dt, y_trial)?;
298        *y_trial = &k[0] * (A_16_1 * dt)
299            + &k[5] * (A_16_6 * dt)
300            + &k[6] * (A_16_7 * dt)
301            + &k[7] * (A_16_8 * dt)
302            + &k[8] * (A_16_9 * dt)
303            + &k[9] * (A_16_10 * dt)
304            + &k[10] * (A_16_11 * dt)
305            + &k[11] * (A_16_12 * dt)
306            + &k[12] * (A_16_13 * dt)
307            + y;
308        k[15] = function(t + dt, y_trial)?;
309        *y_trial = (&k[0] * B_1
310            + &k[7] * B_8
311            + &k[8] * B_9
312            + &k[9] * B_10
313            + &k[10] * B_11
314            + &k[11] * B_12
315            + &k[12] * B_13
316            + &k[13] * B_14
317            + &k[14] * B_15)
318            * dt
319            + y;
320        Ok(((&k[0] * D_1
321            + &k[7] * D_8
322            + &k[8] * D_9
323            + &k[9] * D_10
324            + &k[10] * D_11
325            + &k[11] * D_12
326            + &k[12] * D_13
327            + &k[13] * D_14
328            + &k[14] * D_15
329            + &k[15] * D_16)
330            * dt)
331            .norm_inf())
332    }
333    fn step(
334        &self,
335        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
336        y: &mut Y,
337        t: &mut Scalar,
338        y_sol: &mut U,
339        t_sol: &mut Vector,
340        dydt_sol: &mut U,
341        dt: &mut Scalar,
342        _k: &mut [Y],
343        y_trial: &Y,
344        e: Scalar,
345    ) -> Result<(), String> {
346        if e < self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
347            *t += *dt;
348            *y = y_trial.clone();
349            t_sol.push(*t);
350            y_sol.push(y.clone());
351            dydt_sol.push(function(*t, y)?);
352        }
353        self.time_step(e, dt);
354        Ok(())
355    }
356}
357
358impl<Y, U> InterpolateSolution<Y, U> for Verner9
359where
360    Y: Tensor,
361    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
362    U: TensorVec<Item = Y>,
363{
364    fn interpolate(
365        &self,
366        time: &Vector,
367        tp: &Vector,
368        yp: &U,
369        mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
370    ) -> Result<(U, U), IntegrationError> {
371        let mut dt;
372        let mut i;
373        let mut k_1;
374        let mut k_2;
375        let mut k_3;
376        let mut k_4;
377        let mut k_5;
378        let mut k_6;
379        let mut k_7;
380        let mut k_8;
381        let mut k_9;
382        let mut k_10;
383        let mut k_11;
384        let mut k_12;
385        let mut k_13;
386        let mut k_14;
387        let mut k_15;
388        let mut t;
389        let mut y;
390        let mut y_int = U::new();
391        let mut dydt_int = U::new();
392        let mut y_trial;
393        for time_k in time.iter() {
394            i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
395            if time_k == &tp[i] {
396                t = tp[i];
397                y_trial = yp[i].clone();
398                dt = 0.0;
399            } else {
400                t = tp[i - 1];
401                y = yp[i - 1].clone();
402                dt = time_k - t;
403                k_1 = function(t, &y)?;
404                y_trial = &k_1 * (A_2_1 * dt) + &y;
405                k_2 = function(t + C_2 * dt, &y_trial)?;
406                y_trial = &k_1 * (A_3_1 * dt) + &k_2 * (A_3_2 * dt) + &y;
407                k_3 = function(t + C_3 * dt, &y_trial)?;
408                y_trial = &k_1 * (A_4_1 * dt) + &k_3 * (A_4_3 * dt) + &y;
409                k_4 = function(t + C_4 * dt, &y_trial)?;
410                y_trial = &k_1 * (A_5_1 * dt) + &k_3 * (A_5_3 * dt) + &k_4 * (A_5_4 * dt) + &y;
411                k_5 = function(t + C_5 * dt, &y_trial)?;
412                y_trial = &k_1 * (A_6_1 * dt) + &k_4 * (A_6_4 * dt) + &k_5 * (A_6_5 * dt) + &y;
413                k_6 = function(t + C_6 * dt, &y_trial)?;
414                y_trial = &k_1 * (A_7_1 * dt)
415                    + &k_4 * (A_7_4 * dt)
416                    + &k_5 * (A_7_5 * dt)
417                    + &k_6 * (A_7_6 * dt)
418                    + &y;
419                k_7 = function(t + C_7 * dt, &y_trial)?;
420                y_trial = &k_1 * (A_8_1 * dt) + &k_6 * (A_8_6 * dt) + &k_7 * (A_8_7 * dt) + &y;
421                k_8 = function(t + C_8 * dt, &y_trial)?;
422                y_trial = &k_1 * (A_9_1 * dt)
423                    + &k_6 * (A_9_6 * dt)
424                    + &k_7 * (A_9_7 * dt)
425                    + &k_8 * (A_9_8 * dt)
426                    + &y;
427                k_9 = function(t + C_9 * dt, &y_trial)?;
428                y_trial = &k_1 * (A_10_1 * dt)
429                    + &k_6 * (A_10_6 * dt)
430                    + &k_7 * (A_10_7 * dt)
431                    + &k_8 * (A_10_8 * dt)
432                    + &k_9 * (A_10_9 * dt)
433                    + &y;
434                k_10 = function(t + C_10 * dt, &y_trial)?;
435                y_trial = &k_1 * (A_11_1 * dt)
436                    + &k_6 * (A_11_6 * dt)
437                    + &k_7 * (A_11_7 * dt)
438                    + &k_8 * (A_11_8 * dt)
439                    + &k_9 * (A_11_9 * dt)
440                    + &k_10 * (A_11_10 * dt)
441                    + &y;
442                k_11 = function(t + C_11 * dt, &y_trial)?;
443                y_trial = &k_1 * (A_12_1 * dt)
444                    + &k_6 * (A_12_6 * dt)
445                    + &k_7 * (A_12_7 * dt)
446                    + &k_8 * (A_12_8 * dt)
447                    + &k_9 * (A_12_9 * dt)
448                    + &k_10 * (A_12_10 * dt)
449                    + &k_11 * (A_12_11 * dt)
450                    + &y;
451                k_12 = function(t + C_12 * dt, &y_trial)?;
452                y_trial = &k_1 * (A_13_1 * dt)
453                    + &k_6 * (A_13_6 * dt)
454                    + &k_7 * (A_13_7 * dt)
455                    + &k_8 * (A_13_8 * dt)
456                    + &k_9 * (A_13_9 * dt)
457                    + &k_10 * (A_13_10 * dt)
458                    + &k_11 * (A_13_11 * dt)
459                    + &k_12 * (A_13_12 * dt)
460                    + &y;
461                k_13 = function(t + C_13 * dt, &y_trial)?;
462                y_trial = &k_1 * (A_14_1 * dt)
463                    + &k_6 * (A_14_6 * dt)
464                    + &k_7 * (A_14_7 * dt)
465                    + &k_8 * (A_14_8 * dt)
466                    + &k_9 * (A_14_9 * dt)
467                    + &k_10 * (A_14_10 * dt)
468                    + &k_11 * (A_14_11 * dt)
469                    + &k_12 * (A_14_12 * dt)
470                    + &k_13 * (A_14_13 * dt)
471                    + &y;
472                k_14 = function(t + C_14 * dt, &y_trial)?;
473                y_trial = &k_1 * (A_15_1 * dt)
474                    + &k_6 * (A_15_6 * dt)
475                    + &k_7 * (A_15_7 * dt)
476                    + &k_8 * (A_15_8 * dt)
477                    + &k_9 * (A_15_9 * dt)
478                    + &k_10 * (A_15_10 * dt)
479                    + &k_11 * (A_15_11 * dt)
480                    + &k_12 * (A_15_12 * dt)
481                    + &k_13 * (A_15_13 * dt)
482                    + &k_14 * (A_15_14 * dt)
483                    + &y;
484                k_15 = function(t + dt, &y_trial)?;
485                y_trial = (&k_1 * B_1
486                    + &k_8 * B_8
487                    + &k_9 * B_9
488                    + &k_10 * B_10
489                    + &k_11 * B_11
490                    + &k_12 * B_12
491                    + &k_13 * B_13
492                    + &k_14 * B_14
493                    + &k_15 * B_15)
494                    * dt
495                    + &y;
496            }
497            dydt_int.push(function(t + dt, &y_trial)?);
498            y_int.push(y_trial);
499        }
500        Ok((y_int, dydt_int))
501    }
502}