1#[cfg(test)]
2mod test;
3
4mod backward_euler;
5mod bogacki_shampine;
6mod dormand_prince;
7mod verner_8;
8mod verner_9;
9
10pub use backward_euler::BackwardEuler;
11pub use bogacki_shampine::BogackiShampine;
12pub use dormand_prince::DormandPrince;
13pub use verner_8::Verner8;
14pub use verner_9::Verner9;
15
/// Short alias for [`BackwardEuler`].
pub type Ode1be = BackwardEuler;
/// Short alias for [`BogackiShampine`].
pub type Ode23 = BogackiShampine;
/// Short alias for [`DormandPrince`].
pub type Ode45 = DormandPrince;
/// Short alias for [`Verner8`].
pub type Ode78 = Verner8;
/// Short alias for [`Verner9`].
pub type Ode89 = Verner9;
21
22use super::{
25 Scalar, Solution, Tensor, TensorArray, TensorVec, TestError, Vector,
26 interpolate::InterpolateSolution,
27 optimize::{FirstOrderRootFinding, ZerothOrderRootFinding},
28};
29use crate::defeat_message;
30use std::{
31 fmt::{self, Debug, Display, Formatter},
32 ops::{Div, Mul, Sub},
33};
34
35pub trait OdeSolver<Y, U>
37where
38 Self: Debug,
39 Y: Tensor,
40 U: TensorVec<Item = Y>,
41{
42}
43
44impl<A, Y, U> OdeSolver<Y, U> for A
45where
46 A: Debug,
47 Y: Tensor,
48 U: TensorVec<Item = Y>,
49{
50}
51
52pub trait Explicit<Y, U>: OdeSolver<Y, U>
54where
55 Self: InterpolateSolution<Y, U>,
56 Y: Tensor,
57 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
58 U: TensorVec<Item = Y>,
59{
60 const SLOPES: usize;
61 fn integrate(
67 &self,
68 mut function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
69 time: &[Scalar],
70 initial_condition: Y,
71 ) -> Result<(Vector, U, U), IntegrationError> {
72 let t_0 = time[0];
73 let t_f = time[time.len() - 1];
74 if time.len() < 2 {
75 return Err(IntegrationError::LengthTimeLessThanTwo);
76 } else if t_0 >= t_f {
77 return Err(IntegrationError::InitialTimeNotLessThanFinalTime);
78 }
79 let mut t = t_0;
80 let mut dt = t_f;
81 let mut e;
82 let mut k = vec![Y::default(); Self::SLOPES];
83 k[0] = function(t, &initial_condition)?;
84 let mut t_sol = Vector::zero(0);
85 t_sol.push(t_0);
86 let mut y = initial_condition.clone();
87 let mut y_sol = U::zero(0);
88 y_sol.push(initial_condition.clone());
89 let mut dydt_sol = U::zero(0);
90 dydt_sol.push(k[0].clone());
91 let mut y_trial = Y::default();
92 while t < t_f {
93 e = match self.slopes(&mut function, &y, &t, &dt, &mut k, &mut y_trial) {
94 Ok(e) => e,
95 Err(error) => {
96 return Err(IntegrationError::Upstream(error, format!("{self:?}")));
97 }
98 };
99 match self.step(
100 &mut function,
101 &mut y,
102 &mut t,
103 &mut y_sol,
104 &mut t_sol,
105 &mut dydt_sol,
106 &mut dt,
107 &mut k,
108 &y_trial,
109 &e,
110 ) {
111 Ok(e) => e,
112 Err(error) => {
113 return Err(IntegrationError::Upstream(error, format!("{self:?}")));
114 }
115 };
116 dt = dt.min(t_f - t);
117 }
118 if time.len() > 2 {
119 let t_int = Vector::new(time);
120 let (y_int, dydt_int) = self.interpolate(&t_int, &t_sol, &y_sol, function)?;
121 Ok((t_int, y_int, dydt_int))
122 } else {
123 Ok((t_sol, y_sol, dydt_sol))
124 }
125 }
126 #[doc(hidden)]
127 fn slopes(
128 &self,
129 function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
130 y: &Y,
131 t: &Scalar,
132 dt: &Scalar,
133 k: &mut [Y],
134 y_trial: &mut Y,
135 ) -> Result<Scalar, String>;
136 #[allow(clippy::too_many_arguments)]
137 #[doc(hidden)]
138 fn step(
139 &self,
140 function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
141 y: &mut Y,
142 t: &mut Scalar,
143 y_sol: &mut U,
144 t_sol: &mut Vector,
145 dydt_sol: &mut U,
146 dt: &mut Scalar,
147 k: &mut [Y],
148 y_trial: &Y,
149 e: &Scalar,
150 ) -> Result<(), String>;
151}
152
153pub trait ImplicitZerothOrder<Y, U>: OdeSolver<Y, U>
155where
156 Self: InterpolateSolution<Y, U>,
157 Y: Solution,
158 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
159 U: TensorVec<Item = Y>,
160{
161 fn integrate(
167 &self,
168 function: impl Fn(Scalar, &Y) -> Result<Y, IntegrationError>,
169 time: &[Scalar],
170 initial_condition: Y,
171 solver: impl ZerothOrderRootFinding<Y>,
172 ) -> Result<(Vector, U, U), IntegrationError>;
173}
174
175pub trait ImplicitFirstOrder<Y, J, U>: OdeSolver<Y, U>
177where
178 Self: InterpolateSolution<Y, U>,
179 Y: Solution + Div<J, Output = Y>,
180 for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
181 J: Tensor + TensorArray,
182 U: TensorVec<Item = Y>,
183{
184 fn integrate(
190 &self,
191 function: impl Fn(Scalar, &Y) -> Result<Y, IntegrationError>,
192 jacobian: impl Fn(Scalar, &Y) -> Result<J, IntegrationError>,
193 time: &[Scalar],
194 initial_condition: Y,
195 solver: impl FirstOrderRootFinding<Y, J, Y>,
196 ) -> Result<(Vector, U, U), IntegrationError>;
197}
198
/// Errors produced while integrating an ordinary differential equation.
pub enum IntegrationError {
    /// The first entry of the time grid is not strictly less than the last.
    InitialTimeNotLessThanFinalTime,
    /// A message passed through unchanged, e.g. one converted from a `String`
    /// via `From<String>`.
    Intermediate(String),
    /// The time grid has fewer than two entries.
    LengthTimeLessThanTwo,
    /// An error reported by a solver's stepping hooks, paired with the
    /// `Debug` representation of the integrator that failed.
    Upstream(String, String),
}
206
207impl From<String> for IntegrationError {
208 fn from(error: String) -> Self {
209 Self::Intermediate(error)
210 }
211}
212
213impl Debug for IntegrationError {
214 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
215 let error = match self {
216 Self::InitialTimeNotLessThanFinalTime => {
217 "\x1b[1;91mThe initial time must precede the final time.".to_string()
218 }
219 Self::Intermediate(message) => message.to_string(),
220 Self::LengthTimeLessThanTwo => {
221 "\x1b[1;91mThe time must contain at least two entries.".to_string()
222 }
223 Self::Upstream(error, integrator) => {
224 format!(
225 "{error}\x1b[0;91m\n\
226 In integrator: {integrator}."
227 )
228 }
229 };
230 write!(f, "\n{}\n\x1b[0;2;31m{}\x1b[0m\n", error, defeat_message())
231 }
232}
233
234impl Display for IntegrationError {
235 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
236 let error = match self {
237 Self::InitialTimeNotLessThanFinalTime => {
238 "\x1b[1;91mThe initial time must precede the final time.".to_string()
239 }
240 Self::Intermediate(message) => message.to_string(),
241 Self::LengthTimeLessThanTwo => {
242 "\x1b[1;91mThe time must contain at least two entries.".to_string()
243 }
244 Self::Upstream(error, integrator) => {
245 format!(
246 "{error}\x1b[0;91m\n\
247 In integrator: {integrator}."
248 )
249 }
250 };
251 write!(f, "{error}\x1b[0m")
252 }
253}
254
255impl From<IntegrationError> for String {
256 fn from(error: IntegrationError) -> Self {
257 format!("{}", error)
258 }
259}
260
261impl From<IntegrationError> for TestError {
262 fn from(error: IntegrationError) -> Self {
263 TestError {
264 message: error.to_string(),
265 }
266 }
267}