conspire/math/optimize/line_search/
mod.rs1use crate::{
2    defeat_message,
3    math::{Jacobian, Scalar, Solution},
4};
5use std::{
6    fmt::{self, Debug, Display, Formatter},
7    ops::Mul,
8};
9
10#[derive(Debug)]
12pub enum LineSearch {
13    Armijo {
15        control: Scalar,
16        cut_back: Scalar,
17        max_steps: usize,
18    },
19    Error { cut_back: Scalar, max_steps: usize },
21    Goldstein {
23        control: Scalar,
24        cut_back: Scalar,
25        max_steps: usize,
26    },
27    Wolfe {
29        control_1: Scalar,
30        control_2: Scalar,
31        cut_back: Scalar,
32        max_steps: usize,
33        strong: bool,
34    },
35    None,
37}
38
39impl Default for LineSearch {
40    fn default() -> Self {
41        Self::Armijo {
42            control: 1e-3,
43            cut_back: 9e-1,
44            max_steps: 100,
45        }
46    }
47}
48
49impl Display for LineSearch {
50    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
51        match self {
52            Self::Armijo { .. } => write!(f, "Armijo"),
53            Self::Error { .. } => write!(f, "Error"),
54            Self::Goldstein { .. } => write!(f, "Goldstein"),
55            Self::Wolfe { .. } => write!(f, "Wolfe"),
56            Self::None { .. } => write!(f, "None"),
57        }
58    }
59}
60
61impl LineSearch {
62    pub fn backtrack<X, J>(
63        &self,
64        function: impl Fn(&X) -> Result<Scalar, String>,
65        jacobian: impl Fn(&X) -> Result<J, String>,
66        argument: &X,
67        jacobian0: &J,
68        decrement: &X,
69        step_size: Scalar,
70    ) -> Result<Scalar, LineSearchError>
71    where
72        J: Jacobian,
73        for<'a> &'a J: From<&'a X>,
74        X: Solution,
75        for<'a> &'a X: Mul<Scalar, Output = X>,
76    {
77        if step_size <= 0.0 {
78            return Err(LineSearchError::NegativeStepSize(
79                format!("{self:?}"),
80                step_size,
81            ));
82        }
83        let mut n = step_size;
84        let f = if let Ok(value) = function(argument) {
85            value
86        } else {
87            return Err(LineSearchError::InvalidStartingPoint(format!("{self:?}")));
88        };
89        let m = jacobian0.full_contraction(decrement.into());
90        if m <= 0.0 {
91            return Err(LineSearchError::NotDescentDirection(format!("{self:?}")));
92        }
93        match self {
94            Self::Armijo {
95                control,
96                cut_back,
97                max_steps,
98            } => {
99                let mut f_n;
100                let t = control * m;
101                for _ in 0..*max_steps {
102                    f_n = function(&(decrement * -n + argument));
103                    if let Ok(value) = f_n
104                        && f - value >= n * t
105                    {
106                        return Ok(n);
107                    } else {
108                        n *= cut_back
109                    }
110                }
111                Err(LineSearchError::MaximumStepsReached(
112                    format!("{self:?}"),
113                    *max_steps,
114                ))
115            }
116            Self::Error {
117                cut_back,
118                max_steps,
119            } => {
120                for _ in 0..*max_steps {
121                    if function(&(decrement * -n + argument)).is_ok() {
122                        return Ok(n);
123                    } else {
124                        n *= cut_back
125                    }
126                }
127                Err(LineSearchError::MaximumStepsReached(
128                    format!("{self:?}"),
129                    *max_steps,
130                ))
131            }
132            Self::Goldstein {
133                control,
134                cut_back,
135                max_steps,
136            } => {
137                let mut f_n;
138                let t = control * m;
139                let u = (1.0 - control) * m;
140                let mut v;
141                for _ in 0..*max_steps {
142                    f_n = function(&(decrement * -n + argument));
143                    if let Ok(value) = f_n {
144                        v = f - value;
145                        if n * u < v || v < n * t {
146                            n *= cut_back
147                        } else {
148                            return Ok(n);
149                        }
150                    } else {
151                        n *= cut_back
152                    }
153                }
154                Err(LineSearchError::MaximumStepsReached(
155                    format!("{self:?}"),
156                    *max_steps,
157                ))
158            }
159            Self::Wolfe {
160                control_1,
161                control_2,
162                cut_back,
163                max_steps,
164                strong,
165            } => {
166                let mut f_n;
167                let mut j_n;
168                let t_1 = control_1 * m;
169                let t_2 = control_2 * m;
170                let mut trial_argument = decrement * -n + argument;
171                for _ in 0..*max_steps {
172                    f_n = function(&trial_argument);
173                    j_n = jacobian(&trial_argument);
174                    if let Ok(f_val) = f_n
175                        && let Ok(j_val) = j_n
176                        && f - f_val >= n * t_1
177                        && if *strong {
178                            j_val.full_contraction(decrement.into()) < t_2
179                        } else {
180                            j_val.full_contraction(decrement.into()).abs() < t_2.abs() }
182                    {
183                        return Ok(n);
184                    } else {
185                        n *= cut_back;
186                        trial_argument = decrement * -n + argument
187                    }
188                }
189                Err(LineSearchError::MaximumStepsReached(
190                    format!("{self:?}"),
191                    *max_steps,
192                ))
193            }
194            Self::None => {
195                panic!("Cannot call backtracking line search when there is no algorithm.")
196            }
197        }
198    }
199}
200
201pub enum LineSearchError {
203    InvalidStartingPoint(String),
204    MaximumStepsReached(String, usize),
205    NegativeStepSize(String, Scalar),
206    NotDescentDirection(String),
207}
208
209impl Debug for LineSearchError {
210    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
211        let error = match self {
212            Self::InvalidStartingPoint(line_search) => {
213                format!(
214                    "\x1b[1;91mStaring point is invalid.\x1b[0;91m\n\
215                     In line search: {line_search}."
216                )
217            }
218            Self::MaximumStepsReached(line_search, steps) => {
219                format!(
220                    "\x1b[1;91mMaximum number of steps ({steps}) reached.\x1b[0;91m\n\
221                     In line search: {line_search}."
222                )
223            }
224            Self::NegativeStepSize(line_search, step_size) => {
225                format!(
226                    "\x1b[1;91mNegative step size ({step_size}) encountered.\x1b[0;91m\n\
227                     In line search: {line_search}."
228                )
229            }
230            Self::NotDescentDirection(line_search) => {
231                format!(
232                    "\x1b[1;91mDirection is not a descent direction.\x1b[0;91m\n\
233                     In line search: {line_search}."
234                )
235            }
236        };
237        write!(f, "\n{error}\n\x1b[0;2;31m{}\x1b[0m\n", defeat_message())
238    }
239}
240
241impl Display for LineSearchError {
242    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
243        let error = match self {
244            Self::InvalidStartingPoint(line_search) => {
245                format!(
246                    "\x1b[1;91mStaring point is invalid.\x1b[0;91m\n\
247                     In line search: {line_search}."
248                )
249            }
250            Self::MaximumStepsReached(line_search, steps) => {
251                format!(
252                    "\x1b[1;91mMaximum number of steps ({steps}) reached.\x1b[0;91m\n\
253                     In line search: {line_search}."
254                )
255            }
256            Self::NegativeStepSize(line_search, step_size) => {
257                format!(
258                    "\x1b[1;91mNegative step size ({step_size}) encountered.\x1b[0;91m\n\
259                     In line search: {line_search}."
260                )
261            }
262            Self::NotDescentDirection(line_search) => {
263                format!(
264                    "\x1b[1;91mDirection is not a descent direction.\x1b[0;91m\n\
265                     In line search: {line_search}."
266                )
267            }
268        };
269        write!(f, "{error}\x1b[0m")
270    }
271}