shrew_optim/
radam.rs

1// RAdam — Rectified Adam
2//
3// RAdam (Liyuan Liu et al., 2019) addresses Adam's variance problem in early
4// training. Standard Adam's adaptive learning rate can have high variance in
5// the first few steps because the 2nd moment estimate (v) is poorly calibrated.
6//
7// RAdam automatically detects when the variance of the adaptive learning rate
8// is too high and falls back to SGD with momentum until the 2nd moment
9// estimate has accumulated enough samples.
10//
11// The key insight: compute ρ (rho), an approximation of the length of the
12// SMA (simple moving average) of the adaptive learning rate. When ρ > 5,
13// the variance is low enough to use the adaptive step. Otherwise, use
14// a momentum-only step.
15//
16// This eliminates the need for a learning rate warmup, which is one of the
17// most fragile hyperparameters in transformer training.
18//
19// Update rule:
20//   m = β1 * m + (1 - β1) * grad
21//   v = β2 * v + (1 - β2) * grad²
22//   m_hat = m / (1 - β1^t)
23//
24//   ρ_inf = 2/(1-β2) - 1
25//   ρ_t = ρ_inf - 2*t*β2^t/(1-β2^t)
26//
27//   if ρ_t > 5:  (variance is tractable → use adaptive step)
28//     v_hat = v / (1 - β2^t)
29//     r = √((ρ_t-4)(ρ_t-2)ρ_inf / ((ρ_inf-4)(ρ_inf-2)ρ_t))
30//     θ = θ - lr * r * m_hat / (√v_hat + ε)
31//   else:  (variance too high → momentum-only step)
32//     θ = θ - lr * m_hat
33//
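// HYPERPARAMETERS (defaults match Adam):
//   lr (required by `new`; 1e-3 is a typical choice), β1 = 0.9, β2 = 0.999, ε = 1e-8,
//   weight_decay = 0 (optional decoupled, AdamW-style decay applied to θ before the update)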
36
37use shrew_core::backend::Backend;
38use shrew_core::backprop::GradStore;
39use shrew_core::error::Result;
40use shrew_core::tensor::Tensor;
41
42use crate::optimizer::{Optimizer, OptimizerState, Stateful};
43
44/// Rectified Adam (RAdam) optimizer.
45///
46/// Automatically switches between adaptive and momentum-only updates
47/// based on the variance of the adaptive learning rate, eliminating
48/// the need for learning rate warmup.
49pub struct RAdam<B: Backend> {
50    params: Vec<Tensor<B>>,
51    lr: f64,
52    beta1: f64,
53    beta2: f64,
54    epsilon: f64,
55    weight_decay: f64,
56    /// Step counter
57    t: u64,
58    /// First moment (mean of gradients)
59    m: Vec<Vec<f64>>,
60    /// Second moment (mean of squared gradients)
61    v: Vec<Vec<f64>>,
62    /// ρ_inf = 2/(1-β2) - 1 (precomputed)
63    rho_inf: f64,
64}
65
66impl<B: Backend> RAdam<B> {
67    /// Create a new RAdam optimizer.
68    pub fn new(params: Vec<Tensor<B>>, lr: f64) -> Self {
69        let n = params.len();
70        let beta2 = 0.999;
71        RAdam {
72            params,
73            lr,
74            beta1: 0.9,
75            beta2,
76            epsilon: 1e-8,
77            weight_decay: 0.0,
78            t: 0,
79            m: vec![Vec::new(); n],
80            v: vec![Vec::new(); n],
81            rho_inf: 2.0 / (1.0 - beta2) - 1.0,
82        }
83    }
84
85    /// Set β1.
86    pub fn beta1(mut self, beta1: f64) -> Self {
87        self.beta1 = beta1;
88        self
89    }
90
91    /// Set β2 (also recomputes ρ_inf).
92    pub fn beta2(mut self, beta2: f64) -> Self {
93        self.beta2 = beta2;
94        self.rho_inf = 2.0 / (1.0 - beta2) - 1.0;
95        self
96    }
97
98    /// Set ε.
99    pub fn epsilon(mut self, eps: f64) -> Self {
100        self.epsilon = eps;
101        self
102    }
103
104    /// Set weight decay.
105    pub fn weight_decay(mut self, wd: f64) -> Self {
106        self.weight_decay = wd;
107        self
108    }
109
110    /// Access current parameters.
111    pub fn params(&self) -> &[Tensor<B>] {
112        &self.params
113    }
114
115    /// Mutable access to current parameters.
116    pub fn params_mut(&mut self) -> &mut Vec<Tensor<B>> {
117        &mut self.params
118    }
119
120    /// Get the step count.
121    pub fn step_count(&self) -> u64 {
122        self.t
123    }
124}
125
126impl<B: Backend> Optimizer<B> for RAdam<B> {
127    #[allow(clippy::needless_range_loop)]
128    fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>> {
129        self.t += 1;
130        let mut new_params = Vec::with_capacity(self.params.len());
131
132        for (i, param) in self.params.iter().enumerate() {
133            let grad = match grads.get(param) {
134                Some(g) => g,
135                None => {
136                    new_params.push(param.clone());
137                    continue;
138                }
139            };
140
141            let grad_data = grad.to_f64_vec()?;
142            let mut param_data = param.to_f64_vec()?;
143            let n = param_data.len();
144
145            // Initialize moments on first step
146            if self.m[i].is_empty() {
147                self.m[i] = vec![0.0; n];
148                self.v[i] = vec![0.0; n];
149            }
150
151            // Weight decay (decoupled, like AdamW)
152            if self.weight_decay != 0.0 {
153                for j in 0..n {
154                    param_data[j] -= self.lr * self.weight_decay * param_data[j];
155                }
156            }
157
158            // Update moments
159            for (m, &g) in self.m[i].iter_mut().zip(grad_data.iter()) {
160                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
161            }
162            for (v, &g) in self.v[i].iter_mut().zip(grad_data.iter()) {
163                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
164            }
165
166            // Bias-corrected first moment
167            let bc1 = 1.0 - self.beta1.powi(self.t as i32);
168
169            // Compute ρ_t
170            let beta2_t = self.beta2.powi(self.t as i32);
171            let bc2 = 1.0 - beta2_t;
172            let rho_t = self.rho_inf - 2.0 * self.t as f64 * beta2_t / bc2;
173
174            if rho_t > 5.0 {
175                // Variance is tractable: use adaptive step with rectification
176                let r = ((rho_t - 4.0) * (rho_t - 2.0) * self.rho_inf
177                    / ((self.rho_inf - 4.0) * (self.rho_inf - 2.0) * rho_t))
178                    .sqrt();
179
180                for j in 0..n {
181                    let m_hat = self.m[i][j] / bc1;
182                    let v_hat = self.v[i][j] / bc2;
183                    param_data[j] -= self.lr * r * m_hat / (v_hat.sqrt() + self.epsilon);
184                }
185            } else {
186                // Variance too high: use momentum-only step (no adaptive LR)
187                for j in 0..n {
188                    let m_hat = self.m[i][j] / bc1;
189                    param_data[j] -= self.lr * m_hat;
190                }
191            }
192
193            param.update_data_inplace(&param_data)?;
194            new_params.push(param.clone());
195        }
196
197        Ok(new_params)
198    }
199
200    fn learning_rate(&self) -> f64 {
201        self.lr
202    }
203
204    fn set_learning_rate(&mut self, lr: f64) {
205        self.lr = lr;
206    }
207}
208
209// Stateful — Save/restore optimizer internal state
210
211impl<B: Backend> Stateful for RAdam<B> {
212    fn state_dict(&self) -> OptimizerState {
213        let mut state = OptimizerState::new("RAdam");
214
215        state.set_scalar("t", self.t as f64);
216        state.set_scalar("lr", self.lr);
217        state.set_scalar("beta1", self.beta1);
218        state.set_scalar("beta2", self.beta2);
219        state.set_scalar("epsilon", self.epsilon);
220        state.set_scalar("weight_decay", self.weight_decay);
221        state.set_scalar("rho_inf", self.rho_inf);
222        state.set_scalar("n_params", self.m.len() as f64);
223
224        for (i, m) in self.m.iter().enumerate() {
225            state.set_buffer(format!("m.{i}"), m.clone());
226        }
227        for (i, v) in self.v.iter().enumerate() {
228            state.set_buffer(format!("v.{i}"), v.clone());
229        }
230
231        state
232    }
233
234    fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()> {
235        if state.optimizer_type != "RAdam" {
236            return Err(shrew_core::Error::msg(format!(
237                "Cannot load {} state into RAdam optimizer",
238                state.optimizer_type
239            )));
240        }
241
242        self.t = state.get_scalar("t").unwrap_or(0.0) as u64;
243        if let Some(lr) = state.get_scalar("lr") {
244            self.lr = lr;
245        }
246        if let Some(b1) = state.get_scalar("beta1") {
247            self.beta1 = b1;
248        }
249        if let Some(b2) = state.get_scalar("beta2") {
250            self.beta2 = b2;
251        }
252        if let Some(eps) = state.get_scalar("epsilon") {
253            self.epsilon = eps;
254        }
255        if let Some(wd) = state.get_scalar("weight_decay") {
256            self.weight_decay = wd;
257        }
258        if let Some(ri) = state.get_scalar("rho_inf") {
259            self.rho_inf = ri;
260        }
261
262        let n = self.m.len();
263        for i in 0..n {
264            if let Some(buf) = state.get_buffer(&format!("m.{i}")) {
265                self.m[i] = buf.clone();
266            }
267            if let Some(buf) = state.get_buffer(&format!("v.{i}")) {
268                self.v[i] = buf.clone();
269            }
270        }
271
272        Ok(())
273    }
274}