// shrew_optim/scheduler.rs
//
// Learning Rate Schedulers — Adjust the learning rate during training
//
// LR schedulers implement a strategy for changing the learning rate across
// training steps. They are designed to work with any Optimizer via
// `set_learning_rate()`.
//
// IMPLEMENTED:
//   - StepLR: Decay by gamma every `step_size` epochs
//   - CosineAnnealingLR: Cosine decay from initial LR to min LR
//   - CosineWarmupLR: Linear warmup → cosine decay (standard for Transformers)
//   - LinearLR: Linear interpolation from start_factor to end_factor
//   - ExponentialLR: Multiply LR by gamma every epoch
//   - ReduceLROnPlateau: Metric-driven decay when improvement stalls
//
// USAGE:
//   let mut scheduler = CosineWarmupLR::new(initial_lr, warmup_steps, total_steps, min_lr);
//   for epoch in 0..epochs {
//       for batch in batches {
//           let lr = scheduler.step();
//           optimizer.set_learning_rate(lr);
//           // ... training step ...
//       }
//   }
use std::f64::consts::PI;

// Scheduler Trait

/// Trait for learning rate schedulers.
///
/// Each call to `step()` advances the internal counter and returns the new LR.
/// The counter is the only mutable state the trait exposes, so a scheduler can
/// be checkpointed by saving `current_step()` and restored with `set_step()`.
pub trait LrScheduler {
    /// Advance by one step and return the new learning rate.
    fn step(&mut self) -> f64;

    /// Get the current learning rate without advancing.
    fn current_lr(&self) -> f64;

    /// Get the current step count.
    fn current_step(&self) -> u64;

    /// Reset the scheduler to step 0.
    fn reset(&mut self);

    /// Set the internal step counter to a specific value (for checkpoint restore).
    fn set_step(&mut self, step: u64);
}

48// StepLR — Decay by gamma every N epochs
49
50/// Multiply the learning rate by `gamma` every `step_size` steps.
51///
52/// ```text
53/// lr = initial_lr * gamma^(current_step / step_size)
54/// ```
55///
56/// # Example
57/// ```ignore
58/// let mut sched = StepLR::new(0.1, 30, 0.1); // decay by 10x every 30 steps
59/// ```
60pub struct StepLR {
61    initial_lr: f64,
62    step_size: u64,
63    gamma: f64,
64    current: u64,
65}
66
67impl StepLR {
68    pub fn new(initial_lr: f64, step_size: u64, gamma: f64) -> Self {
69        StepLR {
70            initial_lr,
71            step_size,
72            gamma,
73            current: 0,
74        }
75    }
76}
77
78impl LrScheduler for StepLR {
79    fn step(&mut self) -> f64 {
80        self.current += 1;
81        self.current_lr()
82    }
83
84    fn current_lr(&self) -> f64 {
85        let n = self.current / self.step_size;
86        self.initial_lr * self.gamma.powi(n as i32)
87    }
88
89    fn current_step(&self) -> u64 {
90        self.current
91    }
92    fn reset(&mut self) {
93        self.current = 0;
94    }
95    fn set_step(&mut self, step: u64) {
96        self.current = step;
97    }
98}
99
100// ExponentialLR — Multiply LR by gamma every step
101
102/// Multiply the learning rate by `gamma` every step.
103///
104/// ```text
105/// lr = initial_lr * gamma^step
106/// ```
107pub struct ExponentialLR {
108    initial_lr: f64,
109    gamma: f64,
110    current: u64,
111}
112
113impl ExponentialLR {
114    pub fn new(initial_lr: f64, gamma: f64) -> Self {
115        ExponentialLR {
116            initial_lr,
117            gamma,
118            current: 0,
119        }
120    }
121}
122
123impl LrScheduler for ExponentialLR {
124    fn step(&mut self) -> f64 {
125        self.current += 1;
126        self.current_lr()
127    }
128
129    fn current_lr(&self) -> f64 {
130        self.initial_lr * self.gamma.powi(self.current as i32)
131    }
132
133    fn current_step(&self) -> u64 {
134        self.current
135    }
136    fn reset(&mut self) {
137        self.current = 0;
138    }
139    fn set_step(&mut self, step: u64) {
140        self.current = step;
141    }
142}
143
144// LinearLR — Linear interpolation between two factors
145
146/// Linearly interpolate the learning rate from `start_factor * initial_lr`
147/// to `end_factor * initial_lr` over `total_steps` steps.
148///
149/// After `total_steps`, the LR stays at `end_factor * initial_lr`.
150pub struct LinearLR {
151    initial_lr: f64,
152    start_factor: f64,
153    end_factor: f64,
154    total_steps: u64,
155    current: u64,
156}
157
158impl LinearLR {
159    pub fn new(initial_lr: f64, start_factor: f64, end_factor: f64, total_steps: u64) -> Self {
160        LinearLR {
161            initial_lr,
162            start_factor,
163            end_factor,
164            total_steps,
165            current: 0,
166        }
167    }
168}
169
170impl LrScheduler for LinearLR {
171    fn step(&mut self) -> f64 {
172        self.current += 1;
173        self.current_lr()
174    }
175
176    fn current_lr(&self) -> f64 {
177        if self.total_steps == 0 {
178            return self.initial_lr * self.end_factor;
179        }
180        let t = (self.current as f64 / self.total_steps as f64).min(1.0);
181        let factor = self.start_factor + (self.end_factor - self.start_factor) * t;
182        self.initial_lr * factor
183    }
184
185    fn current_step(&self) -> u64 {
186        self.current
187    }
188    fn reset(&mut self) {
189        self.current = 0;
190    }
191    fn set_step(&mut self, step: u64) {
192        self.current = step;
193    }
194}
195
196// CosineAnnealingLR — Cosine decay from initial to minimum LR
197
198/// Cosine annealing from `initial_lr` to `min_lr` over `total_steps`.
199///
200/// ```text
201/// lr = min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(π * step / total_steps))
202/// ```
203///
204/// After `total_steps`, the LR stays at `min_lr`.
205pub struct CosineAnnealingLR {
206    initial_lr: f64,
207    min_lr: f64,
208    total_steps: u64,
209    current: u64,
210}
211
212impl CosineAnnealingLR {
213    pub fn new(initial_lr: f64, total_steps: u64, min_lr: f64) -> Self {
214        CosineAnnealingLR {
215            initial_lr,
216            min_lr,
217            total_steps,
218            current: 0,
219        }
220    }
221}
222
223impl LrScheduler for CosineAnnealingLR {
224    fn step(&mut self) -> f64 {
225        self.current += 1;
226        self.current_lr()
227    }
228
229    fn current_lr(&self) -> f64 {
230        if self.current >= self.total_steps {
231            return self.min_lr;
232        }
233        let progress = self.current as f64 / self.total_steps as f64;
234        self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * (1.0 + (PI * progress).cos())
235    }
236
237    fn current_step(&self) -> u64 {
238        self.current
239    }
240    fn reset(&mut self) {
241        self.current = 0;
242    }
243    fn set_step(&mut self, step: u64) {
244        self.current = step;
245    }
246}
247
248// CosineWarmupLR — Linear warmup → cosine decay (THE transformer scheduler)
249
250/// Linear warmup from 0 to `initial_lr` over `warmup_steps`, then cosine
251/// decay from `initial_lr` to `min_lr` over the remaining steps.
252///
253/// This is the standard scheduler used for training transformers (GPT, BERT, etc.).
254///
255/// ```text
256/// warmup phase (step < warmup_steps):
257///   lr = initial_lr * step / warmup_steps
258///
259/// decay phase (step >= warmup_steps):
260///   progress = (step - warmup_steps) / (total_steps - warmup_steps)
261///   lr = min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(π * progress))
262/// ```
263pub struct CosineWarmupLR {
264    initial_lr: f64,
265    min_lr: f64,
266    warmup_steps: u64,
267    total_steps: u64,
268    current: u64,
269}
270
271impl CosineWarmupLR {
272    /// Create a cosine warmup scheduler.
273    ///
274    /// # Arguments
275    /// - `initial_lr`: Peak learning rate (reached at end of warmup)
276    /// - `warmup_steps`: Number of linear warmup steps
277    /// - `total_steps`: Total training steps (warmup + decay)
278    /// - `min_lr`: Minimum learning rate at end of training
279    pub fn new(initial_lr: f64, warmup_steps: u64, total_steps: u64, min_lr: f64) -> Self {
280        assert!(
281            warmup_steps <= total_steps,
282            "warmup_steps ({warmup_steps}) must be <= total_steps ({total_steps})"
283        );
284        CosineWarmupLR {
285            initial_lr,
286            min_lr,
287            warmup_steps,
288            total_steps,
289            current: 0,
290        }
291    }
292}
293
294impl LrScheduler for CosineWarmupLR {
295    fn step(&mut self) -> f64 {
296        self.current += 1;
297        self.current_lr()
298    }
299
300    fn current_lr(&self) -> f64 {
301        if self.current <= self.warmup_steps {
302            // Linear warmup: 0 → initial_lr
303            if self.warmup_steps == 0 {
304                return self.initial_lr;
305            }
306            self.initial_lr * (self.current as f64 / self.warmup_steps as f64)
307        } else if self.current >= self.total_steps {
308            // Past end of schedule
309            self.min_lr
310        } else {
311            // Cosine decay phase
312            let decay_steps = self.total_steps - self.warmup_steps;
313            let decay_current = self.current - self.warmup_steps;
314            let progress = decay_current as f64 / decay_steps as f64;
315            self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * (1.0 + (PI * progress).cos())
316        }
317    }
318
319    fn current_step(&self) -> u64 {
320        self.current
321    }
322    fn reset(&mut self) {
323        self.current = 0;
324    }
325    fn set_step(&mut self, step: u64) {
326        self.current = step;
327    }
328}
329
// ReduceLROnPlateau — Reduce LR when a metric stops improving

/// Reduce the learning rate when a monitored metric plateaus.
///
/// Unlike the step-count schedulers, this one is metric-driven: you report a
/// value (e.g. validation loss) via `step_metric`, and the LR is multiplied
/// by `factor` after `patience` consecutive reports without improvement.
///
/// # Configuration (builder pattern)
/// - `factor`: multiplier applied when reducing (default: 0.1)
/// - `patience`: reports with no improvement before reducing (default: 10)
/// - `min_lr`: lower bound on the learning rate (default: 1e-6)
/// - `threshold`: minimum change that counts as improvement (default: 1e-4)
///
/// # Example
/// ```ignore
/// let mut sched = ReduceLROnPlateau::new(0.01);
/// // After each epoch:
/// let new_lr = sched.step_metric(val_loss);
/// optimizer.set_learning_rate(new_lr);
/// ```
pub struct ReduceLROnPlateau {
    /// Current learning rate (mutated in place when reduced).
    lr: f64,
    /// Multiplier applied to `lr` on each reduction.
    factor: f64,
    /// Consecutive bad reports tolerated before reducing.
    patience: u64,
    /// Floor below which `lr` is never reduced.
    min_lr: f64,
    /// Minimum absolute change that counts as improvement.
    threshold: f64,
    /// true = lower metric is better (min mode); false = max mode.
    minimize: bool,
    /// Best metric value observed so far.
    best_seen: f64,
    /// Consecutive reports without improvement.
    stalled: u64,
    /// Total number of metric reports received.
    reports: u64,
}

impl ReduceLROnPlateau {
    /// Create a scheduler with the default configuration:
    /// factor=0.1, patience=10, min_lr=1e-6, threshold=1e-4, mode=min.
    pub fn new(initial_lr: f64) -> Self {
        ReduceLROnPlateau {
            lr: initial_lr,
            factor: 0.1,
            patience: 10,
            min_lr: 1e-6,
            threshold: 1e-4,
            minimize: true,
            best_seen: f64::INFINITY,
            stalled: 0,
            reports: 0,
        }
    }

    /// Set the factor by which to reduce LR (default: 0.1).
    pub fn factor(mut self, factor: f64) -> Self {
        self.factor = factor;
        self
    }

    /// Set patience (steps without improvement before reducing, default: 10).
    pub fn patience(mut self, patience: u64) -> Self {
        self.patience = patience;
        self
    }

    /// Set the minimum learning rate (default: 1e-6).
    pub fn min_lr(mut self, min_lr: f64) -> Self {
        self.min_lr = min_lr;
        self
    }

    /// Set the improvement threshold (default: 1e-4).
    pub fn threshold(mut self, threshold: f64) -> Self {
        self.threshold = threshold;
        self
    }

    /// Switch to maximize mode (higher metric = better).
    /// Default is minimize (lower metric = better).
    pub fn mode_max(mut self) -> Self {
        self.minimize = false;
        self.best_seen = f64::NEG_INFINITY;
        self
    }

    /// Report a metric value and return the (possibly updated) learning rate.
    ///
    /// Call this once per epoch/evaluation with the metric value (e.g., val loss).
    pub fn step_metric(&mut self, metric: f64) -> f64 {
        self.reports += 1;

        let improved = if self.minimize {
            metric < self.best_seen - self.threshold
        } else {
            metric > self.best_seen + self.threshold
        };

        if improved {
            // New best: remember it and clear the stall counter.
            self.best_seen = metric;
            self.stalled = 0;
            return self.lr;
        }

        self.stalled += 1;
        if self.stalled >= self.patience {
            // Patience exhausted: decay the LR (clamped to the floor)
            // and start counting again.
            self.lr = (self.lr * self.factor).max(self.min_lr);
            self.stalled = 0;
        }
        self.lr
    }

    /// Get the current learning rate.
    pub fn lr(&self) -> f64 {
        self.lr
    }

    /// Get the best metric value seen so far.
    pub fn best_metric(&self) -> f64 {
        self.best_seen
    }

    /// Get number of consecutive reports without improvement.
    pub fn bad_steps(&self) -> u64 {
        self.stalled
    }

    /// Reset tracking state (keeps the current LR and configuration).
    pub fn reset(&mut self) {
        self.stalled = 0;
        self.reports = 0;
        self.best_seen = if self.minimize {
            f64::INFINITY
        } else {
            f64::NEG_INFINITY
        };
    }
}