1#[cfg(test)]
2mod test;
3
4use super::{
5 super::{
6 Scalar, Tensor, TensorVec, Vector,
7 interpolate::{InterpolateSolution, InterpolateSolutionIV},
8 },
9 Explicit, ExplicitIV, IntegrationError, OdeSolver,
10};
11use crate::{ABS_TOL, REL_TOL};
12use std::ops::{Mul, Sub};
13
14#[doc = include_str!("doc.md")]
15#[derive(Debug)]
16pub struct BogackiShampine {
17 pub abs_tol: Scalar,
19 pub rel_tol: Scalar,
21 pub dt_beta: Scalar,
23 pub dt_expn: Scalar,
25 pub dt_cut: Scalar,
27 pub dt_min: Scalar,
29}
30
31impl Default for BogackiShampine {
32 fn default() -> Self {
33 Self {
34 abs_tol: ABS_TOL,
35 rel_tol: REL_TOL,
36 dt_beta: 0.9,
37 dt_expn: 3.0,
38 dt_cut: 0.5,
39 dt_min: ABS_TOL,
40 }
41 }
42}
43
44impl<Y, U> OdeSolver<Y, U> for BogackiShampine
45where
46 Y: Tensor,
47 U: TensorVec<Item = Y>,
48{
49 fn abs_tol(&self) -> Scalar {
50 self.abs_tol
51 }
52 fn dt_cut(&self) -> Scalar {
53 self.dt_cut
54 }
55 fn dt_min(&self) -> Scalar {
56 self.dt_min
57 }
58}
59
60impl<Y, U> Explicit<Y, U> for BogackiShampine
61where
62 Self: OdeSolver<Y, U>,
63 Y: Tensor,
64 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
65 U: TensorVec<Item = Y>,
66{
67 const SLOPES: usize = 4;
68 fn dt_beta(&self) -> Scalar {
69 self.dt_beta
70 }
71 fn dt_expn(&self) -> Scalar {
72 self.dt_expn
73 }
74 fn slopes(
75 &self,
76 mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
77 y: &Y,
78 t: Scalar,
79 dt: Scalar,
80 k: &mut [Y],
81 y_trial: &mut Y,
82 ) -> Result<Scalar, String> {
83 *y_trial = &k[0] * (0.5 * dt) + y;
84 k[1] = function(t + 0.5 * dt, y_trial)?;
85 *y_trial = &k[1] * (0.75 * dt) + y;
86 k[2] = function(t + 0.75 * dt, y_trial)?;
87 *y_trial = (&k[0] * 2.0 + &k[1] * 3.0 + &k[2] * 4.0) * (dt / 9.0) + y;
88 k[3] = function(t + dt, y_trial)?;
89 Ok(((&k[0] * -5.0 + &k[1] * 6.0 + &k[2] * 8.0 + &k[3] * -9.0) * (dt / 72.0)).norm_inf())
90 }
91 fn step(
92 &self,
93 _function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
94 y: &mut Y,
95 t: &mut Scalar,
96 y_sol: &mut U,
97 t_sol: &mut Vector,
98 dydt_sol: &mut U,
99 dt: &mut Scalar,
100 k: &mut [Y],
101 y_trial: &Y,
102 e: Scalar,
103 ) -> Result<(), String> {
104 if e < self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
105 k[0] = k[3].clone();
106 *t += *dt;
107 *y = y_trial.clone();
108 t_sol.push(*t);
109 y_sol.push(y.clone());
110 dydt_sol.push(k[0].clone());
111 }
112 if e > 0.0 {
114 *dt *= self.dt_beta() * (self.abs_tol() / e).powf(1.0 / self.dt_expn())
115 }
116 Ok(())
117 }
118}
119
120impl<Y, U> InterpolateSolution<Y, U> for BogackiShampine
121where
122 Y: Tensor,
123 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
124 U: TensorVec<Item = Y>,
125{
126 fn interpolate(
127 &self,
128 time: &Vector,
129 tp: &Vector,
130 yp: &U,
131 mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
132 ) -> Result<(U, U), IntegrationError> {
133 let mut dt;
134 let mut i;
135 let mut k_1;
136 let mut k_2;
137 let mut k_3;
138 let mut t;
139 let mut y;
140 let mut y_int = U::new();
141 let mut dydt_int = U::new();
142 let mut y_trial;
143 for time_k in time.iter() {
144 i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
145 if time_k == &tp[i] {
146 t = tp[i];
147 y_trial = yp[i].clone();
148 dt = 0.0;
149 } else {
150 t = tp[i - 1];
151 y = yp[i - 1].clone();
152 dt = time_k - t;
153 k_1 = function(t, &y)?;
154 y_trial = &k_1 * (0.5 * dt) + &y;
155 k_2 = function(t + 0.5 * dt, &y_trial)?;
156 y_trial = &k_2 * (0.75 * dt) + &y;
157 k_3 = function(t + 0.75 * dt, &y_trial)?;
158 y_trial = (&k_1 * 2.0 + &k_2 * 3.0 + &k_3 * 4.0) * (dt / 9.0) + &y;
159 }
160 dydt_int.push(function(t + dt, &y_trial)?);
161 y_int.push(y_trial);
162 }
163 Ok((y_int, dydt_int))
164 }
165}
166
167impl<Y, Z, U, V> ExplicitIV<Y, Z, U, V> for BogackiShampine
168where
169 Self: OdeSolver<Y, U>,
170 Y: Tensor,
171 Z: Tensor,
172 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
173 U: TensorVec<Item = Y>,
174 V: TensorVec<Item = Z>,
175{
176 const SLOPES: usize = 4;
177 fn slopes(
178 &self,
179 mut function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
180 mut evaluate: impl FnMut(Scalar, &Y, &Z) -> Result<Z, String>,
181 y: &Y,
182 z: &Z,
183 t: Scalar,
184 dt: Scalar,
185 k: &mut [Y],
186 y_trial: &mut Y,
187 z_trial: &mut Z,
188 ) -> Result<Scalar, String> {
189 *y_trial = &k[0] * (0.5 * dt) + y;
190 *z_trial = evaluate(t + 0.5 * dt, y_trial, z)?;
191 k[1] = function(t + 0.5 * dt, y_trial, z_trial)?;
192 *y_trial = &k[1] * (0.75 * dt) + y;
193 *z_trial = evaluate(t + 0.75 * dt, y_trial, z_trial)?;
194 k[2] = function(t + 0.75 * dt, y_trial, z_trial)?;
195 *y_trial = (&k[0] * 2.0 + &k[1] * 3.0 + &k[2] * 4.0) * (dt / 9.0) + y;
196 *z_trial = evaluate(t + dt, y_trial, z_trial)?;
197 k[3] = function(t + dt, y_trial, z_trial)?;
198 Ok(((&k[0] * -5.0 + &k[1] * 6.0 + &k[2] * 8.0 + &k[3] * -9.0) * (dt / 72.0)).norm_inf())
199 }
200 fn step(
201 &self,
202 _function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
203 y: &mut Y,
204 z: &mut Z,
205 t: &mut Scalar,
206 y_sol: &mut U,
207 z_sol: &mut V,
208 t_sol: &mut Vector,
209 dydt_sol: &mut U,
210 dt: &mut Scalar,
211 k: &mut [Y],
212 y_trial: &Y,
213 z_trial: &Z,
214 e: Scalar,
215 ) -> Result<(), String> {
216 if e < self.abs_tol || e / y_trial.norm_inf() < self.rel_tol {
217 k[0] = k[3].clone();
218 *t += *dt;
219 *y = y_trial.clone();
220 *z = z_trial.clone();
221 t_sol.push(*t);
222 y_sol.push(y.clone());
223 z_sol.push(z.clone());
224 dydt_sol.push(k[0].clone());
225 }
226 self.time_step(e, dt);
227 Ok(())
228 }
229}
230
231impl<Y, Z, U, V> InterpolateSolutionIV<Y, Z, U, V> for BogackiShampine
232where
233 Y: Tensor,
234 Z: Tensor,
235 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
236 U: TensorVec<Item = Y>,
237 V: TensorVec<Item = Z>,
238{
239 fn interpolate(
240 &self,
241 time: &Vector,
242 tp: &Vector,
243 yp: &U,
244 zp: &V,
245 mut function: impl FnMut(Scalar, &Y, &Z) -> Result<Y, String>,
246 mut evaluate: impl FnMut(Scalar, &Y, &Z) -> Result<Z, String>,
247 ) -> Result<(U, U, V), IntegrationError> {
248 let mut dt;
249 let mut i;
250 let mut k_1;
251 let mut k_2;
252 let mut k_3;
253 let mut t;
254 let mut y;
255 let mut y_int = U::new();
256 let mut z_int = V::new();
257 let mut dydt_int = U::new();
258 let mut y_trial;
259 let mut z_trial;
260 for time_k in time.iter() {
261 i = tp.iter().position(|tp_i| tp_i >= time_k).unwrap();
262 if time_k == &tp[i] {
263 t = tp[i];
264 y_trial = yp[i].clone();
265 z_trial = zp[i].clone();
266 dt = 0.0;
267 } else {
268 t = tp[i - 1];
269 y = yp[i - 1].clone();
270 z_trial = zp[i - 1].clone();
271 dt = time_k - t;
272 k_1 = function(t, &y, &z_trial)?;
273 y_trial = &k_1 * (0.5 * dt) + &y;
274 z_trial = evaluate(t + 0.5 * dt, &y_trial, &z_trial)?;
275 k_2 = function(t + 0.5 * dt, &y_trial, &z_trial)?;
276 y_trial = &k_2 * (0.75 * dt) + &y;
277 z_trial = evaluate(t + 0.75 * dt, &y_trial, &z_trial)?;
278 k_3 = function(t + 0.75 * dt, &y_trial, &z_trial)?;
279 y_trial = (&k_1 * 2.0 + &k_2 * 3.0 + &k_3 * 4.0) * (dt / 9.0) + &y;
280 z_trial = evaluate(t + dt, &y_trial, &z_trial)?;
281 }
282 dydt_int.push(function(t + dt, &y_trial, &z_trial)?);
283 y_int.push(y_trial);
284 z_int.push(z_trial);
285 }
286 Ok((y_int, dydt_int, z_int))
287 }
288}