conspire/math/optimize/line_search/
mod.rs1use crate::{
2 defeat_message,
3 math::{Jacobian, Scalar, Solution},
4};
5use std::{
6 fmt::{self, Debug, Display, Formatter},
7 ops::Mul,
8};
9
10#[derive(Debug)]
12pub enum LineSearch {
13 Armijo {
15 control: Scalar,
16 cut_back: Scalar,
17 max_steps: usize,
18 },
19 Error { cut_back: Scalar, max_steps: usize },
21 Goldstein {
23 control: Scalar,
24 cut_back: Scalar,
25 max_steps: usize,
26 },
27 Wolfe {
29 control_1: Scalar,
30 control_2: Scalar,
31 cut_back: Scalar,
32 max_steps: usize,
33 strong: bool,
34 },
35 None,
37}
38
39impl Default for LineSearch {
40 fn default() -> Self {
41 Self::Armijo {
42 control: 1e-3,
43 cut_back: 9e-1,
44 max_steps: 100,
45 }
46 }
47}
48
49impl Display for LineSearch {
50 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
51 match self {
52 Self::Armijo { .. } => write!(f, "Armijo"),
53 Self::Error { .. } => write!(f, "Error"),
54 Self::Goldstein { .. } => write!(f, "Goldstein"),
55 Self::Wolfe { .. } => write!(f, "Wolfe"),
56 Self::None { .. } => write!(f, "None"),
57 }
58 }
59}
60
61impl LineSearch {
62 pub fn backtrack<X, J>(
63 &self,
64 mut function: impl FnMut(&X) -> Result<Scalar, String>,
65 mut jacobian: impl FnMut(&X) -> Result<J, String>,
66 argument: &X,
67 jacobian0: &J,
68 decrement: &X,
69 step_size: Scalar,
70 ) -> Result<Scalar, LineSearchError>
71 where
72 J: Jacobian,
73 for<'a> &'a J: From<&'a X>,
74 X: Solution,
75 for<'a> &'a X: Mul<Scalar, Output = X>,
76 {
77 if step_size <= 0.0 {
78 return Err(LineSearchError::NegativeStepSize(
79 format!("{self:?}"),
80 step_size,
81 ));
82 }
83 let mut n = step_size;
84 let f = if let Ok(value) = function(argument) {
85 value
86 } else {
87 return Err(LineSearchError::InvalidStartingPoint(format!("{self:?}")));
88 };
89 let m = jacobian0.full_contraction(decrement.into());
90 if m <= 0.0 {
91 return Err(LineSearchError::NotDescentDirection(format!("{self:?}")));
92 }
93 match self {
94 Self::Armijo {
95 control,
96 cut_back,
97 max_steps,
98 } => {
99 let mut f_n;
100 let t = control * m;
101 for _ in 0..*max_steps {
102 f_n = function(&(decrement * -n + argument));
103 if let Ok(value) = f_n
104 && f - value >= n * t
105 {
106 return Ok(n);
107 } else {
108 n *= cut_back
109 }
110 }
111 Err(LineSearchError::MaximumStepsReached(
112 format!("{self:?}"),
113 *max_steps,
114 ))
115 }
116 Self::Error {
117 cut_back,
118 max_steps,
119 } => {
120 for _ in 0..*max_steps {
121 if function(&(decrement * -n + argument)).is_ok() {
122 return Ok(n);
123 } else {
124 n *= cut_back
125 }
126 }
127 Err(LineSearchError::MaximumStepsReached(
128 format!("{self:?}"),
129 *max_steps,
130 ))
131 }
132 Self::Goldstein {
133 control,
134 cut_back,
135 max_steps,
136 } => {
137 let mut f_n;
138 let t = control * m;
139 let u = (1.0 - control) * m;
140 let mut v;
141 for _ in 0..*max_steps {
142 f_n = function(&(decrement * -n + argument));
143 if let Ok(value) = f_n {
144 v = f - value;
145 if n * u < v || v < n * t {
146 n *= cut_back
147 } else {
148 return Ok(n);
149 }
150 } else {
151 n *= cut_back
152 }
153 }
154 Err(LineSearchError::MaximumStepsReached(
155 format!("{self:?}"),
156 *max_steps,
157 ))
158 }
159 Self::Wolfe {
160 control_1,
161 control_2,
162 cut_back,
163 max_steps,
164 strong,
165 } => {
166 let mut f_n;
167 let mut j_n;
168 let t_1 = control_1 * m;
169 let t_2 = control_2 * m;
170 let mut trial_argument = decrement * -n + argument;
171 for _ in 0..*max_steps {
172 f_n = function(&trial_argument);
173 j_n = jacobian(&trial_argument);
174 if let Ok(f_val) = f_n
175 && let Ok(j_val) = j_n
176 && f - f_val >= n * t_1
177 && if *strong {
178 j_val.full_contraction(decrement.into()) < t_2
179 } else {
180 j_val.full_contraction(decrement.into()).abs() < t_2.abs() }
182 {
183 return Ok(n);
184 } else {
185 n *= cut_back;
186 trial_argument = decrement * -n + argument
187 }
188 }
189 Err(LineSearchError::MaximumStepsReached(
190 format!("{self:?}"),
191 *max_steps,
192 ))
193 }
194 Self::None => {
195 panic!("Cannot call backtracking line search when there is no algorithm.")
196 }
197 }
198 }
199}
200
201pub enum LineSearchError {
203 InvalidStartingPoint(String),
204 MaximumStepsReached(String, usize),
205 NegativeStepSize(String, Scalar),
206 NotDescentDirection(String),
207}
208
209impl Debug for LineSearchError {
210 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
211 let error = match self {
212 Self::InvalidStartingPoint(line_search) => {
213 format!(
214 "\x1b[1;91mStaring point is invalid.\x1b[0;91m\n\
215 In line search: {line_search}."
216 )
217 }
218 Self::MaximumStepsReached(line_search, steps) => {
219 format!(
220 "\x1b[1;91mMaximum number of steps ({steps}) reached.\x1b[0;91m\n\
221 In line search: {line_search}."
222 )
223 }
224 Self::NegativeStepSize(line_search, step_size) => {
225 format!(
226 "\x1b[1;91mNegative step size ({step_size}) encountered.\x1b[0;91m\n\
227 In line search: {line_search}."
228 )
229 }
230 Self::NotDescentDirection(line_search) => {
231 format!(
232 "\x1b[1;91mDirection is not a descent direction.\x1b[0;91m\n\
233 In line search: {line_search}."
234 )
235 }
236 };
237 write!(f, "\n{error}\n\x1b[0;2;31m{}\x1b[0m\n", defeat_message())
238 }
239}
240
241impl Display for LineSearchError {
242 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
243 let error = match self {
244 Self::InvalidStartingPoint(line_search) => {
245 format!(
246 "\x1b[1;91mStaring point is invalid.\x1b[0;91m\n\
247 In line search: {line_search}."
248 )
249 }
250 Self::MaximumStepsReached(line_search, steps) => {
251 format!(
252 "\x1b[1;91mMaximum number of steps ({steps}) reached.\x1b[0;91m\n\
253 In line search: {line_search}."
254 )
255 }
256 Self::NegativeStepSize(line_search, step_size) => {
257 format!(
258 "\x1b[1;91mNegative step size ({step_size}) encountered.\x1b[0;91m\n\
259 In line search: {line_search}."
260 )
261 }
262 Self::NotDescentDirection(line_search) => {
263 format!(
264 "\x1b[1;91mDirection is not a descent direction.\x1b[0;91m\n\
265 In line search: {line_search}."
266 )
267 }
268 };
269 write!(f, "{error}\x1b[0m")
270 }
271}