conspire/math/integrate/ode/explicit/variable_step/dormand_prince/
mod.rs1#[cfg(test)]
2mod test;
3
4use crate::math::{
5 Scalar, Tensor, TensorVec, Vector,
6 integrate::{Explicit, IntegrationError, OdeSolver, VariableStep, VariableStepExplicit},
7 interpolate::InterpolateSolution,
8};
9use crate::{ABS_TOL, REL_TOL};
10use std::ops::{Mul, Sub};
11
12const C_44_45: Scalar = 44.0 / 45.0;
13const C_56_15: Scalar = 56.0 / 15.0;
14const C_32_9: Scalar = 32.0 / 9.0;
15const C_8_9: Scalar = 8.0 / 9.0;
16const C_19372_6561: Scalar = 19372.0 / 6561.0;
17const C_25360_2187: Scalar = 25360.0 / 2187.0;
18const C_64448_6561: Scalar = 64448.0 / 6561.0;
19const C_212_729: Scalar = 212.0 / 729.0;
20const C_9017_3168: Scalar = 9017.0 / 3168.0;
21const C_355_33: Scalar = 355.0 / 33.0;
22const C_46732_5247: Scalar = 46732.0 / 5247.0;
23const C_49_176: Scalar = 49.0 / 176.0;
24const C_5103_18656: Scalar = 5103.0 / 18656.0;
25const C_35_384: Scalar = 35.0 / 384.0;
26const C_500_1113: Scalar = 500.0 / 1113.0;
27const C_125_192: Scalar = 125.0 / 192.0;
28const C_2187_6784: Scalar = 2187.0 / 6784.0;
29const C_11_84: Scalar = 11.0 / 84.0;
30const C_71_57600: Scalar = 71.0 / 57600.0;
31const C_71_16695: Scalar = 71.0 / 16695.0;
32const C_71_1920: Scalar = 71.0 / 1920.0;
33const C_17253_339200: Scalar = 17253.0 / 339200.0;
34const C_22_525: Scalar = 22.0 / 525.0;
35
#[doc = include_str!("doc.md")]
#[derive(Debug)]
pub struct DormandPrince {
    /// Absolute error tolerance. A trial step is accepted when its local error
    /// estimate is below this value (or when the relative criterion below holds).
    pub abs_tol: Scalar,
    /// Relative error tolerance. A trial step is also accepted when the local
    /// error estimate divided by the infinity norm of the trial state is below
    /// this value.
    pub rel_tol: Scalar,
    /// Step-size control parameter returned by `VariableStep::dt_beta`.
    /// NOTE(review): presumably a safety factor applied when rescaling `dt`;
    /// confirm against the `time_step` implementation.
    pub dt_beta: Scalar,
    /// Step-size control exponent returned by `VariableStep::dt_expn`.
    /// NOTE(review): exact use lives in `time_step` outside this file; confirm.
    pub dt_expn: Scalar,
    /// Step-size cut factor returned by `VariableStep::dt_cut`.
    /// NOTE(review): presumably the factor applied to `dt` on a rejected step;
    /// confirm against the `time_step` implementation.
    pub dt_cut: Scalar,
    /// Minimum step size returned by `VariableStep::dt_min`.
    /// NOTE(review): enforcement happens outside this file; confirm.
    pub dt_min: Scalar,
}
52
53impl Default for DormandPrince {
54 fn default() -> Self {
55 Self {
56 abs_tol: ABS_TOL,
57 rel_tol: REL_TOL,
58 dt_beta: 0.9,
59 dt_expn: 5.0,
60 dt_cut: 0.5,
61 dt_min: ABS_TOL,
62 }
63 }
64}
65
66impl<Y, U> OdeSolver<Y, U> for DormandPrince
67where
68 Y: Tensor,
69 U: TensorVec<Item = Y>,
70{
71}
72
73impl VariableStep for DormandPrince {
74 fn abs_tol(&self) -> Scalar {
75 self.abs_tol
76 }
77 fn rel_tol(&self) -> Scalar {
78 self.rel_tol
79 }
80 fn dt_beta(&self) -> Scalar {
81 self.dt_beta
82 }
83 fn dt_expn(&self) -> Scalar {
84 self.dt_expn
85 }
86 fn dt_cut(&self) -> Scalar {
87 self.dt_cut
88 }
89 fn dt_min(&self) -> Scalar {
90 self.dt_min
91 }
92}
93
94impl<Y, U> Explicit<Y, U> for DormandPrince
95where
96 Self: OdeSolver<Y, U>,
97 Y: Tensor,
98 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
99 U: TensorVec<Item = Y>,
100{
101 const SLOPES: usize = 7;
102 fn integrate(
103 &self,
104 function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
105 time: &[Scalar],
106 initial_condition: Y,
107 ) -> Result<(Vector, U, U), IntegrationError> {
108 self.integrate_variable_step(function, time, initial_condition)
109 }
110}
111
112pub fn slopes<Y>(
113 mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
114 y: &Y,
115 t: Scalar,
116 dt: Scalar,
117 k: &mut [Y],
118 y_trial: &mut Y,
119) -> Result<(), String>
120where
121 Y: Tensor,
122 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
123{
124 *y_trial = &k[0] * (0.2 * dt) + y;
125 k[1] = function(t + 0.2 * dt, y_trial)?;
126 *y_trial = &k[0] * (0.075 * dt) + &k[1] * (0.225 * dt) + y;
127 k[2] = function(t + 0.3 * dt, y_trial)?;
128 *y_trial = &k[0] * (C_44_45 * dt) - &k[1] * (C_56_15 * dt) + &k[2] * (C_32_9 * dt) + y;
129 k[3] = function(t + 0.8 * dt, y_trial)?;
130 *y_trial = &k[0] * (C_19372_6561 * dt) - &k[1] * (C_25360_2187 * dt)
131 + &k[2] * (C_64448_6561 * dt)
132 - &k[3] * (C_212_729 * dt)
133 + y;
134 k[4] = function(t + C_8_9 * dt, y_trial)?;
135 *y_trial = &k[0] * (C_9017_3168 * dt) - &k[1] * (C_355_33 * dt)
136 + &k[2] * (C_46732_5247 * dt)
137 + &k[3] * (C_49_176 * dt)
138 - &k[4] * (C_5103_18656 * dt)
139 + y;
140 k[5] = function(t + dt, y_trial)?;
141 *y_trial = (&k[0] * C_35_384 + &k[2] * C_500_1113 + &k[3] * C_125_192 - &k[4] * C_2187_6784
142 + &k[5] * C_11_84)
143 * dt
144 + y;
145 Ok(())
146}
147
148impl<Y, U> VariableStepExplicit<Y, U> for DormandPrince
149where
150 Self: OdeSolver<Y, U>,
151 Y: Tensor,
152 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
153 U: TensorVec<Item = Y>,
154{
155 fn slopes(
156 &self,
157 mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
158 y: &Y,
159 t: Scalar,
160 dt: Scalar,
161 k: &mut [Y],
162 y_trial: &mut Y,
163 ) -> Result<Scalar, String> {
164 slopes(&mut function, y, t, dt, k, y_trial)?;
165 k[6] = function(t + dt, y_trial)?;
166 Ok(
167 ((&k[0] * C_71_57600 - &k[2] * C_71_16695 + &k[3] * C_71_1920
168 - &k[4] * C_17253_339200
169 + &k[5] * C_22_525
170 - &k[6] * 0.025)
171 * dt)
172 .norm_inf(),
173 )
174 }
175 fn step(
176 &self,
177 _function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
178 y: &mut Y,
179 t: &mut Scalar,
180 y_sol: &mut U,
181 t_sol: &mut Vector,
182 dydt_sol: &mut U,
183 dt: &mut Scalar,
184 k: &mut [Y],
185 y_trial: &Y,
186 e: Scalar,
187 ) -> Result<(), String> {
188 if e < self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
189 k[0] = k[6].clone();
190 *t += *dt;
191 *y = y_trial.clone();
192 t_sol.push(*t);
193 y_sol.push(y.clone());
194 dydt_sol.push(k[0].clone());
195 }
196 self.time_step(e, dt);
197 Ok(())
198 }
199}
200
201impl<Y, U> InterpolateSolution<Y, U> for DormandPrince
202where
203 Y: Tensor,
204 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
205 U: TensorVec<Item = Y>,
206{
207 fn interpolate(
208 &self,
209 time: &Vector,
210 tp: &Vector,
211 yp: &U,
212 mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
213 ) -> Result<(U, U), IntegrationError> {
214 let mut dt;
215 let mut i;
216 let mut k_1;
217 let mut k_2;
218 let mut k_3;
219 let mut k_4;
220 let mut k_5;
221 let mut k_6;
222 let mut t;
223 let mut y;
224 let mut y_int = U::new();
225 let mut dydt_int = U::new();
226 let mut y_trial;
227 for time_k in time.iter() {
228 i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
229 if time_k == &tp[i] {
230 t = tp[i];
231 y_trial = yp[i].clone();
232 dt = 0.0;
233 } else {
234 t = tp[i - 1];
235 y = yp[i - 1].clone();
236 dt = time_k - t;
237 k_1 = function(t, &y)?;
238 y_trial = &k_1 * (0.2 * dt) + &y;
239 k_2 = function(t + 0.2 * dt, &y_trial)?;
240 y_trial = &k_1 * (0.075 * dt) + &k_2 * (0.225 * dt) + &y;
241 k_3 = function(t + 0.3 * dt, &y_trial)?;
242 y_trial = &k_1 * (C_44_45 * dt) - &k_2 * (C_56_15 * dt) + &k_3 * (C_32_9 * dt) + &y;
243 k_4 = function(t + 0.8 * dt, &y_trial)?;
244 y_trial = &k_1 * (C_19372_6561 * dt) - &k_2 * (C_25360_2187 * dt)
245 + &k_3 * (C_64448_6561 * dt)
246 - &k_4 * (C_212_729 * dt)
247 + &y;
248 k_5 = function(t + C_8_9 * dt, &y_trial)?;
249 y_trial = &k_1 * (C_9017_3168 * dt) - &k_2 * (C_355_33 * dt)
250 + &k_3 * (C_46732_5247 * dt)
251 + &k_4 * (C_49_176 * dt)
252 - &k_5 * (C_5103_18656 * dt)
253 + &y;
254 k_6 = function(t + dt, &y_trial)?;
255 y_trial = (&k_1 * C_35_384 + &k_3 * C_500_1113 + &k_4 * C_125_192
256 - &k_5 * C_2187_6784
257 + &k_6 * C_11_84)
258 * dt
259 + &y;
260 }
261 dydt_int.push(function(t + dt, &y_trial)?);
262 y_int.push(y_trial);
263 }
264 Ok((y_int, dydt_int))
265 }
266}