shrew_optim/adam.rs

// Adam / AdamW — Adaptive Moment Estimation
//
// Adam is the most widely-used optimizer in deep learning. It maintains
// TWO moving averages per parameter:
//
//   m (1st moment): exponential average of gradients (direction)
//   v (2nd moment): exponential average of squared gradients (magnitude)
//
// Update rule (per step t, with bias correction):
//   m = β1 * m + (1 - β1) * g
//   v = β2 * v + (1 - β2) * g²
//   m_hat = m / (1 - β1^t)
//   v_hat = v / (1 - β2^t)
//   θ = θ - lr * m_hat / (√v_hat + ε)
//
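// A worked single-scalar sketch of step t = 1 with the defaults below
// (illustrative numbers, not taken from the paper), following the update rule above:
//   g = 0.5, lr = 1e-3, β1 = 0.9, β2 = 0.999
//   m = 0.9 * 0 + 0.1 * 0.5          = 0.05
//   v = 0.999 * 0 + 0.001 * 0.25     = 0.00025
//   m_hat = 0.05    / (1 - 0.9)      = 0.5
//   v_hat = 0.00025 / (1 - 0.999)    = 0.25
//   θ -= 1e-3 * 0.5 / (√0.25 + 1e-8) ≈ 1e-3
// i.e. on the first step the parameter moves by roughly lr in the gradient's direction.
//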
// WHY ADAM WORKS SO WELL:
//
// - The 1st moment (m) acts like momentum, smoothing gradient direction
// - The 2nd moment (v) acts like per-parameter learning rate scaling:
//   parameters with historically large gradients get smaller updates,
//   and parameters with small gradients get larger updates.
//   This is "adaptive" — hence the name.
//
// AdamW DIFFERENCE:
//
// Standard Adam applies weight decay INSIDE the gradient (coupled).
// AdamW applies weight decay DIRECTLY to the parameters (decoupled).
// This is more principled and is the standard choice for training Transformers.
//
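// In update-rule form (g is the raw gradient, wd the decay coefficient),
// matching the two branches in `step()` below:
//   coupled (Adam):    g ← g + wd * θ         (decay flows through m and v)
//   decoupled (AdamW): θ ← θ - lr * wd * θ    (decay bypasses m and v entirely)
//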
// HYPERPARAMETERS:
//   lr = 1e-3, β1 = 0.9, β2 = 0.999, ε = 1e-8 follow the original Adam paper;
//   weight_decay = 0.01 is the conventional AdamW setting (plain Adam below
//   defaults to weight_decay = 0.0).

use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::optimizer::{Optimizer, OptimizerState, Stateful};

/// Adam optimizer (Adaptive Moment Estimation).
///
/// Standard defaults: lr=1e-3, β1=0.9, β2=0.999, ε=1e-8
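///
/// A minimal usage sketch; `model.params()` and `loss.backward()` are
/// illustrative stand-ins for this crate's actual parameter-collection and
/// backprop entry points, which may be named differently:
///
/// ```ignore
/// // `model.params()` and `loss.backward()` are placeholders.
/// let mut opt = Adam::new(model.params(), 1e-3).weight_decay(1e-4);
/// let grads = loss.backward()?;   // a GradStore with one gradient per parameter
/// opt.step(&grads)?;              // updates the parameter tensors in place
/// ```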
pub struct Adam<B: Backend> {
    params: Vec<Tensor<B>>,
    lr: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
    weight_decay: f64,
    /// Whether to use decoupled weight decay (AdamW style)
    decoupled_decay: bool,
    /// Step counter (for bias correction)
    t: u64,
    /// First moment vectors (one per parameter)
    m: Vec<Vec<f64>>,
    /// Second moment vectors (one per parameter)
    v: Vec<Vec<f64>>,
}

impl<B: Backend> Adam<B> {
    /// Create a standard Adam optimizer.
    pub fn new(params: Vec<Tensor<B>>, lr: f64) -> Self {
        let n = params.len();
        Adam {
            params,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            decoupled_decay: false,
            t: 0,
            m: vec![Vec::new(); n],
            v: vec![Vec::new(); n],
        }
    }

    /// Set β1 (1st moment decay rate).
    pub fn beta1(mut self, beta1: f64) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Set β2 (2nd moment decay rate).
    pub fn beta2(mut self, beta2: f64) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Set ε (numerical stability term).
    pub fn epsilon(mut self, eps: f64) -> Self {
        self.epsilon = eps;
        self
    }

    /// Set weight decay (L2 penalty).
    pub fn weight_decay(mut self, wd: f64) -> Self {
        self.weight_decay = wd;
        self
    }

    /// Access current parameters.
    pub fn params(&self) -> &[Tensor<B>] {
        &self.params
    }

    /// Mutable access to current parameters (for checkpoint loading).
    pub fn params_mut(&mut self) -> &mut Vec<Tensor<B>> {
        &mut self.params
    }

    /// Get the step count.
    pub fn step_count(&self) -> u64 {
        self.t
    }

    /// Set the learning rate (used by LR schedulers).
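    ///
    /// A minimal warmup sketch (the schedule shown is hypothetical, not part of
    /// this crate; `opt` is an `Adam` built elsewhere):
    ///
    /// ```ignore
    /// // linear warmup of the learning rate over the first 1_000 steps
    /// let warmup = (opt.step_count() as f64 / 1_000.0).min(1.0);
    /// opt.set_lr(1e-3 * warmup);
    /// ```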
    pub fn set_lr(&mut self, lr: f64) {
        self.lr = lr;
    }
}

/// AdamW optimizer (Adam with decoupled weight decay).
///
/// This is the standard optimizer for training Transformers.
/// The key difference from Adam: weight decay is applied directly to
/// the parameters, not mixed into the gradient.
pub struct AdamW<B: Backend>(pub Adam<B>);

impl<B: Backend> AdamW<B> {
    /// Create an AdamW optimizer.
    ///
    /// β1=0.9, β2=0.999 and ε=1e-8 are used; `lr` and `weight_decay` are
    /// supplied by the caller (1e-3 and 0.01 are the conventional choices).
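    ///
    /// A minimal usage sketch (`model.params()` and `grads` are illustrative
    /// stand-ins for this crate's actual APIs):
    ///
    /// ```ignore
    /// // `model.params()` and `grads` are placeholders.
    /// let mut opt = AdamW::new(model.params(), 1e-3, 0.01);
    /// opt.step(&grads)?;
    /// ```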
    pub fn new(params: Vec<Tensor<B>>, lr: f64, weight_decay: f64) -> Self {
        let mut adam = Adam::new(params, lr);
        adam.decoupled_decay = true;
        adam.weight_decay = weight_decay;
        AdamW(adam)
    }

    /// Set weight decay.
    pub fn weight_decay(mut self, wd: f64) -> Self {
        self.0.weight_decay = wd;
        self
    }

    /// Set β1.
    pub fn beta1(mut self, beta1: f64) -> Self {
        self.0.beta1 = beta1;
        self
    }

    /// Set β2.
    pub fn beta2(mut self, beta2: f64) -> Self {
        self.0.beta2 = beta2;
        self
    }

    /// Access current parameters.
    pub fn params(&self) -> &[Tensor<B>] {
        self.0.params()
    }

    /// Mutable access to current parameters (for checkpoint loading).
    pub fn params_mut(&mut self) -> &mut Vec<Tensor<B>> {
        self.0.params_mut()
    }

    /// Set the learning rate (used by LR schedulers).
    pub fn set_lr(&mut self, lr: f64) {
        self.0.set_lr(lr);
    }
}

impl<B: Backend> Optimizer<B> for Adam<B> {
    #[allow(clippy::needless_range_loop)]
    fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>> {
        self.t += 1;
        let mut new_params = Vec::with_capacity(self.params.len());

        for (i, param) in self.params.iter().enumerate() {
            let grad = match grads.get(param) {
                Some(g) => g,
                None => {
                    new_params.push(param.clone());
                    continue;
                }
            };

            let mut grad_data = grad.to_f64_vec()?;
            let mut param_data = param.to_f64_vec()?;
            let n = param_data.len();

            // Initialize moment vectors on first step
            if self.m[i].is_empty() {
                self.m[i] = vec![0.0; n];
                self.v[i] = vec![0.0; n];
            }

            // Coupled weight decay (standard Adam): add to gradient
            if self.weight_decay != 0.0 && !self.decoupled_decay {
                for (g, &p) in grad_data.iter_mut().zip(param_data.iter()) {
                    *g += self.weight_decay * p;
                }
            }

            // Update 1st moment: m = β1 * m + (1 - β1) * grad
            for (m, &g) in self.m[i].iter_mut().zip(grad_data.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update 2nd moment: v = β2 * v + (1 - β2) * grad²
            for (v, &g) in self.v[i].iter_mut().zip(grad_data.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            // Bias correction factors
            let bc1 = 1.0 - self.beta1.powi(self.t as i32);
            let bc2 = 1.0 - self.beta2.powi(self.t as i32);
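            // With the defaults, at t = 1: bc1 = 0.1 and bc2 = 0.001, so the
            // zero-initialized moments are scaled back up (~10x and ~1000x) to
            // undo their bias toward zero; both factors approach 1 as t grows.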

            // Update parameters: θ = θ - lr * m_hat / (√v_hat + ε)
            for j in 0..n {
                let m_hat = self.m[i][j] / bc1;
                let v_hat = self.v[i][j] / bc2;

                // Decoupled weight decay (AdamW): subtract directly from param
                if self.weight_decay != 0.0 && self.decoupled_decay {
                    param_data[j] -= self.lr * self.weight_decay * param_data[j];
                }

                param_data[j] -= self.lr * m_hat / (v_hat.sqrt() + self.epsilon);
            }

            // Update storage in-place so model layers sharing this tensor see new values
            param.update_data_inplace(&param_data)?;

            new_params.push(param.clone());
        }

        // Keep self.params pointing to the same tensors (they've been updated in-place)
        Ok(new_params)
    }

    fn learning_rate(&self) -> f64 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.lr = lr;
    }
}

impl<B: Backend> Optimizer<B> for AdamW<B> {
    fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>> {
        self.0.step(grads)
    }

    fn learning_rate(&self) -> f64 {
        self.0.learning_rate()
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.0.set_learning_rate(lr);
    }
}

// Stateful — Save/restore optimizer internal state
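//
// A minimal save/restore sketch. How an `OptimizerState` is persisted to disk
// is outside this file and assumed to live elsewhere in the crate; `opt` and
// `params` are placeholders for an optimizer and parameter list built elsewhere:
//
//   let state = opt.state_dict();
//   // ... serialize `state`, later rebuild an optimizer over the same params ...
//   let mut restored = Adam::new(params, 1e-3);
//   restored.load_state_dict(&state)?;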

impl<B: Backend> Stateful for Adam<B> {
    fn state_dict(&self) -> OptimizerState {
        let name = if self.decoupled_decay {
            "AdamW"
        } else {
            "Adam"
        };
        let mut state = OptimizerState::new(name);

        state.set_scalar("t", self.t as f64);
        state.set_scalar("lr", self.lr);
        state.set_scalar("beta1", self.beta1);
        state.set_scalar("beta2", self.beta2);
        state.set_scalar("epsilon", self.epsilon);
        state.set_scalar("weight_decay", self.weight_decay);
        state.set_scalar(
            "decoupled_decay",
            if self.decoupled_decay { 1.0 } else { 0.0 },
        );
        state.set_scalar("n_params", self.m.len() as f64);

        for (i, m) in self.m.iter().enumerate() {
            state.set_buffer(format!("m.{i}"), m.clone());
        }
        for (i, v) in self.v.iter().enumerate() {
            state.set_buffer(format!("v.{i}"), v.clone());
        }

        state
    }

    fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()> {
        if state.optimizer_type != "Adam" && state.optimizer_type != "AdamW" {
            return Err(shrew_core::Error::msg(format!(
                "Cannot load {} state into Adam/AdamW optimizer",
                state.optimizer_type
            )));
        }

        self.t = state.get_scalar("t").unwrap_or(0.0) as u64;
        if let Some(lr) = state.get_scalar("lr") {
            self.lr = lr;
        }
        if let Some(b1) = state.get_scalar("beta1") {
            self.beta1 = b1;
        }
        if let Some(b2) = state.get_scalar("beta2") {
            self.beta2 = b2;
        }
        if let Some(eps) = state.get_scalar("epsilon") {
            self.epsilon = eps;
        }
        if let Some(wd) = state.get_scalar("weight_decay") {
            self.weight_decay = wd;
        }
        if let Some(dd) = state.get_scalar("decoupled_decay") {
            self.decoupled_decay = dd != 0.0;
        }

        let n = self.m.len();
        for i in 0..n {
            if let Some(buf) = state.get_buffer(&format!("m.{i}")) {
                self.m[i] = buf.clone();
            }
            if let Some(buf) = state.get_buffer(&format!("v.{i}")) {
                self.v[i] = buf.clone();
            }
        }

        Ok(())
    }
}

impl<B: Backend> Stateful for AdamW<B> {
    fn state_dict(&self) -> OptimizerState {
        self.0.state_dict()
    }

    fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()> {
        self.0.load_state_dict(state)
    }
}