conspire/math/integrate/backward_euler/
mod.rs

1#[cfg(test)]
2mod test;
3
4use super::{
5    super::{
6        Hessian, Jacobian, Scalar, Solution, Tensor, TensorArray, TensorVec, Vector,
7        interpolate::InterpolateSolution,
8        optimize::{EqualityConstraint, FirstOrderRootFinding, ZerothOrderRootFinding},
9    },
10    ImplicitFirstOrder, ImplicitZerothOrder, IntegrationError, OdeSolver,
11};
12use crate::ABS_TOL;
13use std::{
14    fmt::Debug,
15    ops::{Div, Mul, Sub},
16};
17
18#[doc = include_str!("doc.md")]
19#[derive(Debug)]
20pub struct BackwardEuler {
21    /// Cut back factor for the time step.
22    pub dt_cut: Scalar,
23    /// Minimum value for the time step.
24    pub dt_min: Scalar,
25}
26
27impl Default for BackwardEuler {
28    fn default() -> Self {
29        Self {
30            dt_cut: 0.5,
31            dt_min: ABS_TOL,
32        }
33    }
34}
35
36impl<Y, U> OdeSolver<Y, U> for BackwardEuler
37where
38    Y: Tensor,
39    U: TensorVec<Item = Y>,
40{
41    fn abs_tol(&self) -> Scalar {
42        unimplemented!()
43    }
44    fn dt_cut(&self) -> Scalar {
45        self.dt_cut
46    }
47    fn dt_min(&self) -> Scalar {
48        self.dt_min
49    }
50}
51
52impl<Y, U> ImplicitZerothOrder<Y, U> for BackwardEuler
53where
54    Self: InterpolateSolution<Y, U>,
55    Y: Solution,
56    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
57    U: TensorVec<Item = Y>,
58    Vector: From<Y>,
59{
60    fn integrate(
61        &self,
62        function: impl Fn(Scalar, &Y) -> Result<Y, IntegrationError>,
63        time: &[Scalar],
64        initial_condition: Y,
65        solver: impl ZerothOrderRootFinding<Y>,
66    ) -> Result<(Vector, U, U), IntegrationError> {
67        let t_0 = time[0];
68        let t_f = time[time.len() - 1];
69        if time.len() < 2 {
70            return Err(IntegrationError::LengthTimeLessThanTwo);
71        } else if t_0 >= t_f {
72            return Err(IntegrationError::InitialTimeNotLessThanFinalTime);
73        }
74        let mut index = 0;
75        let mut t = t_0;
76        let mut dt;
77        let mut t_sol = Vector::new();
78        t_sol.push(t_0);
79        let mut t_trial;
80        let mut y = initial_condition.clone();
81        let mut y_sol = U::new();
82        y_sol.push(initial_condition.clone());
83        let mut dydt_sol = U::new();
84        dydt_sol.push(function(t, &y.clone())?);
85        let mut y_trial;
86        while t < t_f {
87            t_trial = time[index + 1];
88            dt = t_trial - t;
89            y_trial = match solver.root(
90                |y_trial: &Y| Ok(y_trial - &y - &(&function(t_trial, y_trial)? * dt)),
91                y.clone(),
92                EqualityConstraint::None,
93            ) {
94                Ok(solution) => solution,
95                Err(error) => {
96                    return Err(IntegrationError::Upstream(
97                        format!("{error:?}"),
98                        format!("{self:?}"),
99                    ));
100                }
101            };
102            t = t_trial;
103            y = y_trial;
104            t_sol.push(t);
105            y_sol.push(y.clone());
106            dydt_sol.push(function(t, &y)?);
107            index += 1;
108        }
109        Ok((t_sol, y_sol, dydt_sol))
110    }
111}
112
113impl<Y, J, U> ImplicitFirstOrder<Y, J, U> for BackwardEuler
114where
115    Self: InterpolateSolution<Y, U>,
116    Y: Jacobian + Solution + Div<J, Output = Y>,
117    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
118    J: Hessian + TensorArray,
119    U: TensorVec<Item = Y>,
120    Vector: From<Y>,
121{
122    fn integrate(
123        &self,
124        function: impl Fn(Scalar, &Y) -> Result<Y, IntegrationError>,
125        jacobian: impl Fn(Scalar, &Y) -> Result<J, IntegrationError>,
126        time: &[Scalar],
127        initial_condition: Y,
128        solver: impl FirstOrderRootFinding<Y, J, Y>,
129    ) -> Result<(Vector, U, U), IntegrationError> {
130        let t_0 = time[0];
131        let t_f = time[time.len() - 1];
132        if time.len() < 2 {
133            return Err(IntegrationError::LengthTimeLessThanTwo);
134        } else if t_0 >= t_f {
135            return Err(IntegrationError::InitialTimeNotLessThanFinalTime);
136        }
137        let mut index = 0;
138        let mut t = t_0;
139        let mut dt;
140        let identity = J::identity();
141        let mut t_sol = Vector::new();
142        t_sol.push(t_0);
143        let mut t_trial;
144        let mut y = initial_condition.clone();
145        let mut y_sol = U::new();
146        y_sol.push(initial_condition.clone());
147        let mut dydt_sol = U::new();
148        dydt_sol.push(function(t, &y.clone())?);
149        let mut y_trial;
150        while t < t_f {
151            t_trial = time[index + 1];
152            dt = t_trial - t;
153            y_trial = match solver.root(
154                |y_trial: &Y| Ok(y_trial - &y - &(&function(t_trial, y_trial)? * dt)),
155                |y_trial: &Y| Ok(jacobian(t_trial, y_trial)? * -dt + &identity),
156                y.clone(),
157                EqualityConstraint::None,
158            ) {
159                Ok(solution) => solution,
160                Err(error) => {
161                    return Err(IntegrationError::Upstream(
162                        format!("{error:?}"),
163                        format!("{self:?}"),
164                    ));
165                }
166            };
167            t = t_trial;
168            y = y_trial;
169            t_sol.push(t);
170            y_sol.push(y.clone());
171            dydt_sol.push(function(t, &y)?);
172            index += 1;
173        }
174        Ok((t_sol, y_sol, dydt_sol))
175    }
176}
177
178impl<Y, U> InterpolateSolution<Y, U> for BackwardEuler
179where
180    Y: Jacobian + Solution + TensorArray,
181    for<'a> &'a Y: Mul<Scalar, Output = Y> + Sub<&'a Y, Output = Y>,
182    U: TensorVec<Item = Y>,
183    Vector: From<Y>,
184{
185    fn interpolate(
186        &self,
187        _time: &Vector,
188        _tp: &Vector,
189        _yp: &U,
190        mut _function: impl FnMut(Scalar, &Y) -> Result<Y, String>,
191    ) -> Result<(U, U), IntegrationError> {
192        unimplemented!()
193    }
194}