conspire/math/integrate/ode/implicit/mod.rs

1#[cfg(test)]
2mod test;
3
4use crate::math::{
5    Scalar, Tensor, TensorVec, Vector,
6    integrate::{FixedStep, IntegrationError, OdeSolver},
7    optimize::{EqualityConstraint, FirstOrderRootFinding, ZerothOrderRootFinding},
8};
9
10pub mod backward_euler;
11pub mod midpoint;
12pub mod trapezoidal;
13
14/// Zeroth-order implicit ordinary differential equation solvers.
15pub trait ImplicitZerothOrder<Y, U>
16where
17    Self: FixedStep + OdeSolver<Y, U>,
18    Y: Tensor,
19    U: TensorVec<Item = Y>,
20{
21    #[doc = include_str!("doc.md")]
22    fn integrate(
23        &self,
24        mut function: impl FnMut(Scalar, &Y) -> Result<Y, IntegrationError>,
25        time: &[Scalar],
26        initial_condition: Y,
27        solver: impl ZerothOrderRootFinding<Y>,
28    ) -> Result<(Vector, U, U), IntegrationError> {
29        let t_0 = time[0];
30        let t_f = time[time.len() - 1];
31        let mut t_sol: Vector;
32        if time.len() < 2 {
33            return Err(IntegrationError::LengthTimeLessThanTwo);
34        } else if t_0 >= t_f {
35            return Err(IntegrationError::InitialTimeNotLessThanFinalTime);
36        } else if time.len() == 2 {
37            if self.dt() <= 0.0 || self.dt().is_nan() {
38                return Err(IntegrationError::TimeStepNotSet(
39                    time[0],
40                    time[1],
41                    format!("{self:?}"),
42                ));
43            } else {
44                let num_steps = ((t_f - t_0) / self.dt()).ceil() as usize;
45                t_sol = (0..num_steps)
46                    .map(|step| t_0 + (step as Scalar) * self.dt())
47                    .collect();
48                t_sol.push(t_f);
49            }
50        } else {
51            t_sol = time.iter().copied().collect();
52        }
53        let mut index = 0;
54        let mut t = t_0;
55        let mut dt;
56        let mut t_trial;
57        let mut y = initial_condition.clone();
58        let mut y_sol = U::new();
59        y_sol.push(initial_condition.clone());
60        let mut dydt_sol = U::new();
61        dydt_sol.push(function(t, &y.clone())?);
62        let mut y_trial;
63        while t < t_f {
64            t_trial = t_sol[index + 1];
65            dt = t_trial - t;
66            y_trial = match solver.root(
67                |y_trial: &Y| self.residual(&mut function, t, &y, t_trial, y_trial, dt),
68                y.clone(),
69                EqualityConstraint::None,
70            ) {
71                Ok(solution) => solution,
72                Err(error) => {
73                    return Err(IntegrationError::Upstream(
74                        format!("{error}"),
75                        format!("{self:?}"),
76                    ));
77                }
78            };
79            t = t_trial;
80            y = y_trial;
81            y_sol.push(y.clone());
82            dydt_sol.push(function(t, &y)?);
83            index += 1;
84        }
85        Ok((t_sol, y_sol, dydt_sol))
86    }
87    fn residual(
88        &self,
89        function: impl FnMut(Scalar, &Y) -> Result<Y, IntegrationError>,
90        t: Scalar,
91        y: &Y,
92        t_trial: Scalar,
93        y_trial: &Y,
94        dt: Scalar,
95    ) -> Result<Y, String>;
96}
97
98/// First-order implicit ordinary differential equation solvers.
99pub trait ImplicitFirstOrder<Y, J, U>
100where
101    Self: ImplicitZerothOrder<Y, U>,
102    Y: Tensor,
103    U: TensorVec<Item = Y>,
104{
105    #[doc = include_str!("doc.md")]
106    fn integrate(
107        &self,
108        mut function: impl FnMut(Scalar, &Y) -> Result<Y, IntegrationError>,
109        mut jacobian: impl FnMut(Scalar, &Y) -> Result<J, IntegrationError>,
110        time: &[Scalar],
111        initial_condition: Y,
112        solver: impl FirstOrderRootFinding<Y, J, Y>,
113    ) -> Result<(Vector, U, U), IntegrationError> {
114        let t_0 = time[0];
115        let t_f = time[time.len() - 1];
116        let mut t_sol: Vector;
117        if time.len() < 2 {
118            return Err(IntegrationError::LengthTimeLessThanTwo);
119        } else if t_0 >= t_f {
120            return Err(IntegrationError::InitialTimeNotLessThanFinalTime);
121        } else if time.len() == 2 {
122            if self.dt() <= 0.0 || self.dt().is_nan() {
123                return Err(IntegrationError::TimeStepNotSet(
124                    time[0],
125                    time[1],
126                    format!("{self:?}"),
127                ));
128            } else {
129                let num_steps = ((t_f - t_0) / self.dt()).ceil() as usize;
130                t_sol = (0..num_steps)
131                    .map(|step| t_0 + (step as Scalar) * self.dt())
132                    .collect();
133                t_sol.push(t_f);
134            }
135        } else {
136            t_sol = time.iter().copied().collect();
137        }
138        let mut index = 0;
139        let mut t = t_0;
140        let mut dt;
141        let mut t_trial;
142        let mut y = initial_condition.clone();
143        let mut y_sol = U::new();
144        y_sol.push(initial_condition.clone());
145        let mut dydt_sol = U::new();
146        dydt_sol.push(function(t, &y.clone())?);
147        let mut y_trial;
148        while t < t_f {
149            t_trial = t_sol[index + 1];
150            dt = t_trial - t;
151            y_trial = match solver.root(
152                |y_trial: &Y| self.residual(&mut function, t, &y, t_trial, y_trial, dt),
153                |y_trial: &Y| self.hessian(&mut jacobian, t, &y, t_trial, y_trial, dt),
154                y.clone(),
155                EqualityConstraint::None,
156            ) {
157                Ok(solution) => solution,
158                Err(error) => {
159                    return Err(IntegrationError::Upstream(
160                        format!("{error}"),
161                        format!("{self:?}"),
162                    ));
163                }
164            };
165            t = t_trial;
166            y = y_trial;
167            y_sol.push(y.clone());
168            dydt_sol.push(function(t, &y)?);
169            index += 1;
170        }
171        Ok((t_sol, y_sol, dydt_sol))
172    }
173    fn hessian(
174        &self,
175        jacobian: impl FnMut(Scalar, &Y) -> Result<J, IntegrationError>,
176        t: Scalar,
177        y: &Y,
178        t_trial: Scalar,
179        y_trial: &Y,
180        dt: Scalar,
181    ) -> Result<J, String>;
182}