// shrew_optim/scheduler.rs
//
// Learning Rate Schedulers — Adjust the learning rate during training
//
// LR schedulers implement a strategy for changing the learning rate across
// training steps. They are designed to work with any Optimizer via
// `set_learning_rate()`.
//
// IMPLEMENTED:
// - StepLR: Decay by gamma every `step_size` steps
// - ExponentialLR: Multiply LR by gamma every step
// - LinearLR: Linear interpolation from start_factor to end_factor
// - CosineAnnealingLR: Cosine decay from initial LR to min LR
// - CosineWarmupLR: Linear warmup → cosine decay (standard for Transformers)
// - ReduceLROnPlateau: Reduce LR when a monitored metric stops improving
//
// USAGE:
//     let mut scheduler = CosineWarmupLR::new(initial_lr, warmup_steps, total_steps, min_lr);
//     for epoch in 0..epochs {
//         for batch in batches {
//             let lr = scheduler.step();
//             optimizer.set_learning_rate(lr);
//             // ... training step ...
//         }
//     }
use std::f64::consts::PI;
25
// Scheduler Trait

/// Common interface for step-based learning rate schedulers.
///
/// A scheduler owns an internal step counter. Every call to `step()` advances
/// that counter by one and yields the learning rate for the new step; the
/// counter can also be inspected, reset, or overwritten directly (e.g. when
/// resuming training from a checkpoint).
pub trait LrScheduler {
    /// Advance by one step and return the new learning rate.
    fn step(&mut self) -> f64;

    /// Get the current learning rate without advancing.
    fn current_lr(&self) -> f64;

    /// Get the current step count.
    fn current_step(&self) -> u64;

    /// Reset the scheduler to step 0.
    fn reset(&mut self);

    /// Set the internal step counter to a specific value (for checkpoint restore).
    fn set_step(&mut self, step: u64);
}
47
48// StepLR — Decay by gamma every N epochs
49
50/// Multiply the learning rate by `gamma` every `step_size` steps.
51///
52/// ```text
53/// lr = initial_lr * gamma^(current_step / step_size)
54/// ```
55///
56/// # Example
57/// ```ignore
58/// let mut sched = StepLR::new(0.1, 30, 0.1); // decay by 10x every 30 steps
59/// ```
60pub struct StepLR {
61 initial_lr: f64,
62 step_size: u64,
63 gamma: f64,
64 current: u64,
65}
66
67impl StepLR {
68 pub fn new(initial_lr: f64, step_size: u64, gamma: f64) -> Self {
69 StepLR {
70 initial_lr,
71 step_size,
72 gamma,
73 current: 0,
74 }
75 }
76}
77
78impl LrScheduler for StepLR {
79 fn step(&mut self) -> f64 {
80 self.current += 1;
81 self.current_lr()
82 }
83
84 fn current_lr(&self) -> f64 {
85 let n = self.current / self.step_size;
86 self.initial_lr * self.gamma.powi(n as i32)
87 }
88
89 fn current_step(&self) -> u64 {
90 self.current
91 }
92 fn reset(&mut self) {
93 self.current = 0;
94 }
95 fn set_step(&mut self, step: u64) {
96 self.current = step;
97 }
98}
99
100// ExponentialLR — Multiply LR by gamma every step
101
102/// Multiply the learning rate by `gamma` every step.
103///
104/// ```text
105/// lr = initial_lr * gamma^step
106/// ```
107pub struct ExponentialLR {
108 initial_lr: f64,
109 gamma: f64,
110 current: u64,
111}
112
113impl ExponentialLR {
114 pub fn new(initial_lr: f64, gamma: f64) -> Self {
115 ExponentialLR {
116 initial_lr,
117 gamma,
118 current: 0,
119 }
120 }
121}
122
123impl LrScheduler for ExponentialLR {
124 fn step(&mut self) -> f64 {
125 self.current += 1;
126 self.current_lr()
127 }
128
129 fn current_lr(&self) -> f64 {
130 self.initial_lr * self.gamma.powi(self.current as i32)
131 }
132
133 fn current_step(&self) -> u64 {
134 self.current
135 }
136 fn reset(&mut self) {
137 self.current = 0;
138 }
139 fn set_step(&mut self, step: u64) {
140 self.current = step;
141 }
142}
143
144// LinearLR — Linear interpolation between two factors
145
146/// Linearly interpolate the learning rate from `start_factor * initial_lr`
147/// to `end_factor * initial_lr` over `total_steps` steps.
148///
149/// After `total_steps`, the LR stays at `end_factor * initial_lr`.
150pub struct LinearLR {
151 initial_lr: f64,
152 start_factor: f64,
153 end_factor: f64,
154 total_steps: u64,
155 current: u64,
156}
157
158impl LinearLR {
159 pub fn new(initial_lr: f64, start_factor: f64, end_factor: f64, total_steps: u64) -> Self {
160 LinearLR {
161 initial_lr,
162 start_factor,
163 end_factor,
164 total_steps,
165 current: 0,
166 }
167 }
168}
169
170impl LrScheduler for LinearLR {
171 fn step(&mut self) -> f64 {
172 self.current += 1;
173 self.current_lr()
174 }
175
176 fn current_lr(&self) -> f64 {
177 if self.total_steps == 0 {
178 return self.initial_lr * self.end_factor;
179 }
180 let t = (self.current as f64 / self.total_steps as f64).min(1.0);
181 let factor = self.start_factor + (self.end_factor - self.start_factor) * t;
182 self.initial_lr * factor
183 }
184
185 fn current_step(&self) -> u64 {
186 self.current
187 }
188 fn reset(&mut self) {
189 self.current = 0;
190 }
191 fn set_step(&mut self, step: u64) {
192 self.current = step;
193 }
194}
195
196// CosineAnnealingLR — Cosine decay from initial to minimum LR
197
198/// Cosine annealing from `initial_lr` to `min_lr` over `total_steps`.
199///
200/// ```text
201/// lr = min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(π * step / total_steps))
202/// ```
203///
204/// After `total_steps`, the LR stays at `min_lr`.
205pub struct CosineAnnealingLR {
206 initial_lr: f64,
207 min_lr: f64,
208 total_steps: u64,
209 current: u64,
210}
211
212impl CosineAnnealingLR {
213 pub fn new(initial_lr: f64, total_steps: u64, min_lr: f64) -> Self {
214 CosineAnnealingLR {
215 initial_lr,
216 min_lr,
217 total_steps,
218 current: 0,
219 }
220 }
221}
222
223impl LrScheduler for CosineAnnealingLR {
224 fn step(&mut self) -> f64 {
225 self.current += 1;
226 self.current_lr()
227 }
228
229 fn current_lr(&self) -> f64 {
230 if self.current >= self.total_steps {
231 return self.min_lr;
232 }
233 let progress = self.current as f64 / self.total_steps as f64;
234 self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * (1.0 + (PI * progress).cos())
235 }
236
237 fn current_step(&self) -> u64 {
238 self.current
239 }
240 fn reset(&mut self) {
241 self.current = 0;
242 }
243 fn set_step(&mut self, step: u64) {
244 self.current = step;
245 }
246}
247
248// CosineWarmupLR — Linear warmup → cosine decay (THE transformer scheduler)
249
250/// Linear warmup from 0 to `initial_lr` over `warmup_steps`, then cosine
251/// decay from `initial_lr` to `min_lr` over the remaining steps.
252///
253/// This is the standard scheduler used for training transformers (GPT, BERT, etc.).
254///
255/// ```text
256/// warmup phase (step < warmup_steps):
257/// lr = initial_lr * step / warmup_steps
258///
259/// decay phase (step >= warmup_steps):
260/// progress = (step - warmup_steps) / (total_steps - warmup_steps)
261/// lr = min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(π * progress))
262/// ```
263pub struct CosineWarmupLR {
264 initial_lr: f64,
265 min_lr: f64,
266 warmup_steps: u64,
267 total_steps: u64,
268 current: u64,
269}
270
271impl CosineWarmupLR {
272 /// Create a cosine warmup scheduler.
273 ///
274 /// # Arguments
275 /// - `initial_lr`: Peak learning rate (reached at end of warmup)
276 /// - `warmup_steps`: Number of linear warmup steps
277 /// - `total_steps`: Total training steps (warmup + decay)
278 /// - `min_lr`: Minimum learning rate at end of training
279 pub fn new(initial_lr: f64, warmup_steps: u64, total_steps: u64, min_lr: f64) -> Self {
280 assert!(
281 warmup_steps <= total_steps,
282 "warmup_steps ({warmup_steps}) must be <= total_steps ({total_steps})"
283 );
284 CosineWarmupLR {
285 initial_lr,
286 min_lr,
287 warmup_steps,
288 total_steps,
289 current: 0,
290 }
291 }
292}
293
294impl LrScheduler for CosineWarmupLR {
295 fn step(&mut self) -> f64 {
296 self.current += 1;
297 self.current_lr()
298 }
299
300 fn current_lr(&self) -> f64 {
301 if self.current <= self.warmup_steps {
302 // Linear warmup: 0 → initial_lr
303 if self.warmup_steps == 0 {
304 return self.initial_lr;
305 }
306 self.initial_lr * (self.current as f64 / self.warmup_steps as f64)
307 } else if self.current >= self.total_steps {
308 // Past end of schedule
309 self.min_lr
310 } else {
311 // Cosine decay phase
312 let decay_steps = self.total_steps - self.warmup_steps;
313 let decay_current = self.current - self.warmup_steps;
314 let progress = decay_current as f64 / decay_steps as f64;
315 self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * (1.0 + (PI * progress).cos())
316 }
317 }
318
319 fn current_step(&self) -> u64 {
320 self.current
321 }
322 fn reset(&mut self) {
323 self.current = 0;
324 }
325 fn set_step(&mut self, step: u64) {
326 self.current = step;
327 }
328}
329
// ReduceLROnPlateau — Reduce LR when a metric stops improving

/// Reduce the learning rate when a monitored metric plateaus.
///
/// Unlike the step-based schedulers in this module, this one does not advance
/// automatically: the caller reports a metric value (e.g., validation loss)
/// via [`step_metric`](Self::step_metric), and the scheduler decides whether
/// the learning rate should be cut.
///
/// Configuration uses a builder pattern:
/// - `factor`: Factor to multiply LR by when reducing (default: 0.1)
/// - `patience`: Number of steps with no improvement before reducing (default: 10)
/// - `min_lr`: Lower bound on the learning rate (default: 1e-6)
/// - `threshold`: Minimum improvement to qualify as improvement (default: 1e-4)
///
/// # Example
/// ```ignore
/// let mut sched = ReduceLROnPlateau::new(0.01);
/// // After each epoch:
/// let new_lr = sched.step_metric(val_loss);
/// optimizer.set_learning_rate(new_lr);
/// ```
pub struct ReduceLROnPlateau {
    // Current learning rate, updated in place on each reduction.
    lr: f64,
    factor: f64,
    patience: u64,
    min_lr: f64,
    threshold: f64,
    /// Whether lower metric is better (true = min mode, false = max mode)
    mode_min: bool,
    // Best metric observed so far (±infinity until the first report).
    best: f64,
    num_bad_steps: u64,
    current_step_count: u64,
}

impl ReduceLROnPlateau {
    /// Create a new ReduceLROnPlateau with sensible defaults.
    ///
    /// Default: factor=0.1, patience=10, min_lr=1e-6, threshold=1e-4, mode=min
    pub fn new(initial_lr: f64) -> Self {
        ReduceLROnPlateau {
            lr: initial_lr,
            factor: 0.1,
            patience: 10,
            min_lr: 1e-6,
            threshold: 1e-4,
            mode_min: true,
            best: f64::INFINITY,
            num_bad_steps: 0,
            current_step_count: 0,
        }
    }

    /// Set the factor by which to reduce LR (default: 0.1).
    pub fn factor(mut self, factor: f64) -> Self {
        self.factor = factor;
        self
    }

    /// Set patience (steps without improvement before reducing, default: 10).
    pub fn patience(mut self, patience: u64) -> Self {
        self.patience = patience;
        self
    }

    /// Set the minimum learning rate (default: 1e-6).
    pub fn min_lr(mut self, min_lr: f64) -> Self {
        self.min_lr = min_lr;
        self
    }

    /// Set the improvement threshold (default: 1e-4).
    pub fn threshold(mut self, threshold: f64) -> Self {
        self.threshold = threshold;
        self
    }

    /// Set mode to maximize (higher metric = better).
    /// Default is minimize (lower metric = better).
    pub fn mode_max(mut self) -> Self {
        self.mode_min = false;
        self.best = f64::NEG_INFINITY;
        self
    }

    /// Report a metric value and return the (possibly updated) learning rate.
    ///
    /// Call this once per epoch/evaluation with the metric value (e.g., val loss).
    pub fn step_metric(&mut self, metric: f64) -> f64 {
        self.current_step_count += 1;

        // A report only counts as progress when it beats the best-so-far by
        // more than `threshold`, in the direction selected by the mode.
        let progressed = if self.mode_min {
            metric < self.best - self.threshold
        } else {
            metric > self.best + self.threshold
        };

        if progressed {
            self.best = metric;
            self.num_bad_steps = 0;
            return self.lr;
        }

        self.num_bad_steps += 1;
        if self.num_bad_steps >= self.patience {
            // Patience exhausted: cut the LR, but never below the floor.
            self.lr = (self.lr * self.factor).max(self.min_lr);
            self.num_bad_steps = 0;
        }
        self.lr
    }

    /// Get the current learning rate.
    pub fn lr(&self) -> f64 {
        self.lr
    }

    /// Get the best metric value seen so far.
    pub fn best_metric(&self) -> f64 {
        self.best
    }

    /// Get number of steps without improvement.
    pub fn bad_steps(&self) -> u64 {
        self.num_bad_steps
    }

    /// Reset state. Note: the learning rate itself is left unchanged.
    pub fn reset(&mut self) {
        self.num_bad_steps = 0;
        self.current_step_count = 0;
        self.best = if self.mode_min {
            f64::INFINITY
        } else {
            f64::NEG_INFINITY
        };
    }
}