conspire/math/integrate/ode/implicit/midpoint/
mod.rs

1#[cfg(test)]
2mod test;
3
4use crate::math::{
5    Scalar, Tensor, TensorArray, TensorVec,
6    integrate::{FixedStep, ImplicitFirstOrder, ImplicitZerothOrder, IntegrationError, OdeSolver},
7};
8use std::{
9    fmt::Debug,
10    ops::{Add, Mul, Sub},
11};
12
13#[doc = include_str!("doc.md")]
14#[derive(Debug, Default)]
15pub struct Midpoint {
16    /// Fixed value for the time step.
17    dt: Scalar,
18}
19
20impl<Y, U> OdeSolver<Y, U> for Midpoint
21where
22    Y: Tensor,
23    U: TensorVec<Item = Y>,
24{
25}
26
27impl FixedStep for Midpoint {
28    fn dt(&self) -> Scalar {
29        self.dt
30    }
31}
32
33impl<Y, U> ImplicitZerothOrder<Y, U> for Midpoint
34where
35    Y: Tensor,
36    for<'a> &'a Y: Mul<Scalar, Output = Y> + Add<&'a Y, Output = Y> + Sub<&'a Y, Output = Y>,
37    U: TensorVec<Item = Y>,
38{
39    fn residual(
40        &self,
41        mut function: impl FnMut(Scalar, &Y) -> Result<Y, IntegrationError>,
42        t: Scalar,
43        y: &Y,
44        _t_trial: Scalar,
45        y_trial: &Y,
46        dt: Scalar,
47    ) -> Result<Y, String> {
48        Ok(y_trial - y - function(t + 0.5 * dt, &((y + y_trial) * 0.5))? * dt)
49    }
50}
51
52impl<Y, J, U> ImplicitFirstOrder<Y, J, U> for Midpoint
53where
54    Y: Tensor,
55    for<'a> &'a Y: Mul<Scalar, Output = Y> + Add<&'a Y, Output = Y> + Sub<&'a Y, Output = Y>,
56    J: Tensor + TensorArray,
57    U: TensorVec<Item = Y>,
58{
59    fn hessian(
60        &self,
61        mut jacobian: impl FnMut(Scalar, &Y) -> Result<J, IntegrationError>,
62        t: Scalar,
63        y: &Y,
64        _t_trial: Scalar,
65        y_trial: &Y,
66        dt: Scalar,
67    ) -> Result<J, String> {
68        Ok(J::identity() - jacobian(t + 0.5 * dt, &((y + y_trial) * 0.5))? * (dt * 0.5))
69    }
70}