1#[cfg(test)]
2mod test;
3
4use crate::math::{
5 Scalar, Tensor, TensorVec, Vector,
6 integrate::{
7 Explicit, ExplicitInternalVariables, IntegrationError, OdeSolver, VariableStep,
8 VariableStepExplicit, VariableStepExplicitInternalVariables,
9 },
10 interpolate::{InterpolateSolution, InterpolateSolutionInternalVariables},
11};
12use crate::{ABS_TOL, REL_TOL};
13use std::ops::{Mul, Sub};
14
#[doc = include_str!("doc.md")]
#[derive(Debug)]
pub struct BogackiShampine {
    /// Absolute error tolerance.
    ///
    /// A step is accepted when its error estimate is below this value. The explicit step
    /// in this file also uses it as the target error in the step-size update.
    pub abs_tol: Scalar,
    /// Relative error tolerance.
    ///
    /// A step is also accepted when its error estimate divided by the infinity norm of
    /// the trial solution is below this value.
    pub rel_tol: Scalar,
    /// Safety factor multiplying the step-size update
    /// `dt *= dt_beta * (abs_tol / e)^(1 / dt_expn)`.
    pub dt_beta: Scalar,
    /// Exponent denominator in the step-size update
    /// `dt *= dt_beta * (abs_tol / e)^(1 / dt_expn)`.
    pub dt_expn: Scalar,
    /// Exposed through `VariableStep::dt_cut`; not read directly in this file.
    /// NOTE(review): presumably the factor used to shrink the step after a failed
    /// step — confirm in the shared variable-step driver.
    pub dt_cut: Scalar,
    /// Exposed through `VariableStep::dt_min`; not read directly in this file.
    /// NOTE(review): presumably the smallest step the driver allows — confirm.
    pub dt_min: Scalar,
}
31
32impl Default for BogackiShampine {
33 fn default() -> Self {
34 Self {
35 abs_tol: ABS_TOL,
36 rel_tol: REL_TOL,
37 dt_beta: 0.9,
38 dt_expn: 3.0,
39 dt_cut: 0.5,
40 dt_min: ABS_TOL,
41 }
42 }
43}
44
45impl<Y, U> OdeSolver<Y, U> for BogackiShampine
46where
47 Y: Tensor,
48 U: TensorVec<Item = Y>,
49{
50}
51
52impl VariableStep for BogackiShampine {
53 fn abs_tol(&self) -> Scalar {
54 self.abs_tol
55 }
56 fn rel_tol(&self) -> Scalar {
57 self.rel_tol
58 }
59 fn dt_beta(&self) -> Scalar {
60 self.dt_beta
61 }
62 fn dt_expn(&self) -> Scalar {
63 self.dt_expn
64 }
65 fn dt_cut(&self) -> Scalar {
66 self.dt_cut
67 }
68 fn dt_min(&self) -> Scalar {
69 self.dt_min
70 }
71}
72
73impl<Y, U> Explicit<Y, U> for BogackiShampine
74where
75 Self: OdeSolver<Y, U>,
76 Y: Tensor,
77 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
78 U: TensorVec<Item = Y>,
79{
80 const SLOPES: usize = 4;
81 fn integrate(
82 &self,
83 function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
84 time: &[Scalar],
85 initial_condition: Y,
86 ) -> Result<(Vector, U, U), IntegrationError> {
87 self.integrate_variable_step(function, time, initial_condition)
88 }
89}
90
91pub fn slopes<Y>(
92 mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
93 y: &Y,
94 t: Scalar,
95 dt: Scalar,
96 k: &mut [Y],
97 y_trial: &mut Y,
98) -> Result<(), String>
99where
100 Y: Tensor,
101 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
102{
103 *y_trial = &k[0] * (0.5 * dt) + y;
104 k[1] = function(t + 0.5 * dt, y_trial)?;
105 *y_trial = &k[1] * (0.75 * dt) + y;
106 k[2] = function(t + 0.75 * dt, y_trial)?;
107 *y_trial = (&k[0] * 2.0 + &k[1] * 3.0 + &k[2] * 4.0) * (dt / 9.0) + y;
108 Ok(())
109}
110
111impl<Y, U> VariableStepExplicit<Y, U> for BogackiShampine
112where
113 Self: OdeSolver<Y, U>,
114 Y: Tensor,
115 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
116 U: TensorVec<Item = Y>,
117{
118 fn slopes(
119 &self,
120 mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
121 y: &Y,
122 t: Scalar,
123 dt: Scalar,
124 k: &mut [Y],
125 y_trial: &mut Y,
126 ) -> Result<Scalar, String> {
127 slopes(&mut function, y, t, dt, k, y_trial)?;
128 k[3] = function(t + dt, y_trial)?;
129 Ok(((&k[0] * -5.0 + &k[1] * 6.0 + &k[2] * 8.0 + &k[3] * -9.0) * (dt / 72.0)).norm_inf())
130 }
131 fn step(
132 &self,
133 _function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
134 y: &mut Y,
135 t: &mut Scalar,
136 y_sol: &mut U,
137 t_sol: &mut Vector,
138 dydt_sol: &mut U,
139 dt: &mut Scalar,
140 k: &mut [Y],
141 y_trial: &Y,
142 e: Scalar,
143 ) -> Result<(), String> {
144 if e < self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
145 k[0] = k[3].clone();
146 *t += *dt;
147 *y = y_trial.clone();
148 t_sol.push(*t);
149 y_sol.push(y.clone());
150 dydt_sol.push(k[0].clone());
151 }
152 if e > 0.0 {
154 *dt *= self.dt_beta() * (self.abs_tol() / e).powf(1.0 / self.dt_expn())
155 }
156 Ok(())
157 }
158}
159
160impl<Y, U> InterpolateSolution<Y, U> for BogackiShampine
161where
162 Y: Tensor,
163 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
164 U: TensorVec<Item = Y>,
165{
166 fn interpolate(
167 &self,
168 time: &Vector,
169 tp: &Vector,
170 yp: &U,
171 mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
172 ) -> Result<(U, U), IntegrationError> {
173 let mut dt;
174 let mut i;
175 let mut k_1;
176 let mut k_2;
177 let mut k_3;
178 let mut t;
179 let mut y;
180 let mut y_int = U::new();
181 let mut dydt_int = U::new();
182 let mut y_trial;
183 for time_k in time.iter() {
184 i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
185 if time_k == &tp[i] {
186 t = tp[i];
187 y_trial = yp[i].clone();
188 dt = 0.0;
189 } else {
190 t = tp[i - 1];
191 y = yp[i - 1].clone();
192 dt = time_k - t;
193 k_1 = function(t, &y)?;
194 y_trial = &k_1 * (0.5 * dt) + &y;
195 k_2 = function(t + 0.5 * dt, &y_trial)?;
196 y_trial = &k_2 * (0.75 * dt) + &y;
197 k_3 = function(t + 0.75 * dt, &y_trial)?;
198 y_trial = (&k_1 * 2.0 + &k_2 * 3.0 + &k_3 * 4.0) * (dt / 9.0) + &y;
199 }
200 dydt_int.push(function(t + dt, &y_trial)?);
201 y_int.push(y_trial);
202 }
203 Ok((y_int, dydt_int))
204 }
205}
206
207impl<Y, Z, U, V> ExplicitInternalVariables<Y, Z, U, V> for BogackiShampine
208where
209 Self: OdeSolver<Y, U>,
210 Y: Tensor,
211 Z: Tensor,
212 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
213 U: TensorVec<Item = Y>,
214 V: TensorVec<Item = Z>,
215{
216 fn integrate_and_evaluate(
217 &self,
218 function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
219 evaluate: impl FnMut(Scalar, &Y, &Z) -> Result<Z, String>,
220 time: &[Scalar],
221 initial_condition: Y,
222 initial_evaluation: Z,
223 ) -> Result<(Vector, U, U, V), IntegrationError> {
224 self.integrate_and_evaluate_variable_step(
225 function,
226 evaluate,
227 time,
228 initial_condition,
229 initial_evaluation,
230 )
231 }
232}
233
234impl<Y, Z, U, V> VariableStepExplicitInternalVariables<Y, Z, U, V> for BogackiShampine
235where
236 Self: OdeSolver<Y, U>,
237 Y: Tensor,
238 Z: Tensor,
239 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
240 U: TensorVec<Item = Y>,
241 V: TensorVec<Item = Z>,
242{
243 fn slopes(
244 &self,
245 mut function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
246 mut evaluate: impl FnMut(Scalar, &Y, &Z) -> Result<Z, String>,
247 y: &Y,
248 z: &Z,
249 t: Scalar,
250 dt: Scalar,
251 k: &mut [Y],
252 y_trial: &mut Y,
253 z_trial: &mut Z,
254 ) -> Result<Scalar, String> {
255 *y_trial = &k[0] * (0.5 * dt) + y;
256 *z_trial = evaluate(t + 0.5 * dt, y_trial, z)?;
257 k[1] = function(t + 0.5 * dt, y_trial, z_trial)?;
258 *y_trial = &k[1] * (0.75 * dt) + y;
259 *z_trial = evaluate(t + 0.75 * dt, y_trial, z_trial)?;
260 k[2] = function(t + 0.75 * dt, y_trial, z_trial)?;
261 *y_trial = (&k[0] * 2.0 + &k[1] * 3.0 + &k[2] * 4.0) * (dt / 9.0) + y;
262 *z_trial = evaluate(t + dt, y_trial, z_trial)?;
263 k[3] = function(t + dt, y_trial, z_trial)?;
264 Ok(((&k[0] * -5.0 + &k[1] * 6.0 + &k[2] * 8.0 + &k[3] * -9.0) * (dt / 72.0)).norm_inf())
265 }
266 fn step(
267 &self,
268 _function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
269 y: &mut Y,
270 z: &mut Z,
271 t: &mut Scalar,
272 y_sol: &mut U,
273 z_sol: &mut V,
274 t_sol: &mut Vector,
275 dydt_sol: &mut U,
276 dt: &mut Scalar,
277 k: &mut [Y],
278 y_trial: &Y,
279 z_trial: &Z,
280 e: Scalar,
281 ) -> Result<(), String> {
282 if e < self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
283 k[0] = k[3].clone();
284 *t += *dt;
285 *y = y_trial.clone();
286 *z = z_trial.clone();
287 t_sol.push(*t);
288 y_sol.push(y.clone());
289 z_sol.push(z.clone());
290 dydt_sol.push(k[0].clone());
291 }
292 self.time_step(e, dt);
293 Ok(())
294 }
295}
296
297impl<Y, Z, U, V> InterpolateSolutionInternalVariables<Y, Z, U, V> for BogackiShampine
298where
299 Y: Tensor,
300 Z: Tensor,
301 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
302 U: TensorVec<Item = Y>,
303 V: TensorVec<Item = Z>,
304{
305 fn interpolate(
306 &self,
307 time: &Vector,
308 tp: &Vector,
309 yp: &U,
310 zp: &V,
311 mut function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
312 mut evaluate: impl FnMut(Scalar, &Y, &Z) -> Result<Z, String>,
313 ) -> Result<(U, U, V), IntegrationError> {
314 let mut dt;
315 let mut i;
316 let mut k_1;
317 let mut k_2;
318 let mut k_3;
319 let mut t;
320 let mut y;
321 let mut y_int = U::new();
322 let mut z_int = V::new();
323 let mut dydt_int = U::new();
324 let mut y_trial;
325 let mut z_trial;
326 for time_k in time.iter() {
327 i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
328 if time_k == &tp[i] {
329 t = tp[i];
330 y_trial = yp[i].clone();
331 z_trial = zp[i].clone();
332 dt = 0.0;
333 } else {
334 t = tp[i - 1];
335 y = yp[i - 1].clone();
336 z_trial = zp[i - 1].clone();
337 dt = time_k - t;
338 k_1 = function(t, &y, &z_trial)?;
339 y_trial = &k_1 * (0.5 * dt) + &y;
340 z_trial = evaluate(t + 0.5 * dt, &y_trial, &z_trial)?;
341 k_2 = function(t + 0.5 * dt, &y_trial, &z_trial)?;
342 y_trial = &k_2 * (0.75 * dt) + &y;
343 z_trial = evaluate(t + 0.75 * dt, &y_trial, &z_trial)?;
344 k_3 = function(t + 0.75 * dt, &y_trial, &z_trial)?;
345 y_trial = (&k_1 * 2.0 + &k_2 * 3.0 + &k_3 * 4.0) * (dt / 9.0) + &y;
346 z_trial = evaluate(t + dt, &y_trial, &z_trial)?;
347 }
348 dydt_int.push(function(t + dt, &y_trial, &z_trial)?);
349 y_int.push(y_trial);
350 z_int.push(z_trial);
351 }
352 Ok((y_int, dydt_int, z_int))
353 }
354}