shrew_optim/sgd.rs

// SGD - Stochastic Gradient Descent
//
// The simplest optimizer: θ_new = θ - lr * gradient
//
// With momentum (the default in practice):
//   v = momentum * v_prev + gradient
//   θ_new = θ - lr * v
//
// Momentum accelerates convergence by accumulating a velocity vector.
// Think of a ball rolling down a hill: momentum helps it push through
// flat regions and small bumps.
//
// SGD with momentum is surprisingly competitive. Many state-of-the-art
// models (especially in computer vision) are trained with SGD + momentum.
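//
// Worked example (illustrative numbers only): with lr = 0.1, momentum = 0.9,
// and a constant gradient of 1.0, starting from v = 0:
//   step 1: v = 0.9*0.0 + 1.0 = 1.00   so θ changes by -0.1 * 1.00 = -0.100
//   step 2: v = 0.9*1.0 + 1.0 = 1.90   so θ changes by -0.1 * 1.90 = -0.190
//   step 3: v = 0.9*1.9 + 1.0 = 2.71   so θ changes by -0.1 * 2.71 = -0.271
// The step size approaches lr / (1 - momentum) = 1.0 per update, which is the
// "acceleration" described above.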

use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::optimizer::{Optimizer, OptimizerState, Stateful};

/// Stochastic Gradient Descent optimizer with optional momentum.
///
/// # Parameters
/// - `lr`: learning rate (typical: 0.01 - 0.1)
/// - `momentum`: momentum factor (typical: 0.9; 0 = no momentum)
/// - `weight_decay`: L2 regularization coefficient (typical: 1e-4)
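///
/// # Example
///
/// An illustrative sketch only (not taken from the crate's examples);
/// `model.parameters()` and `loss.backward()` below are placeholders for
/// whatever the surrounding code actually provides for collecting parameters
/// and gradients:
///
/// ```ignore
/// let mut opt = SGD::new(model.parameters(), 0.01, 0.9, 1e-4);
/// let grads = loss.backward()?;        // gradients from backprop (GradStore<B>)
/// let new_params = opt.step(&grads)?;  // updates the parameters in place
/// opt.update_params(new_params);       // keep the optimizer's own references current
/// ```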
pub struct SGD<B: Backend> {
    params: Vec<Tensor<B>>,
    lr: f64,
    momentum: f64,
    weight_decay: f64,
    /// Velocity buffers for momentum (one per parameter, allocated lazily on first use)
    velocities: Vec<Option<Vec<f64>>>,
}

impl<B: Backend> SGD<B> {
    /// Create a new SGD optimizer.
    ///
    /// # Arguments
    /// - `params`: the model parameters to optimize
    /// - `lr`: learning rate
    /// - `momentum`: momentum factor (0 for vanilla SGD)
    /// - `weight_decay`: L2 regularization strength
    pub fn new(params: Vec<Tensor<B>>, lr: f64, momentum: f64, weight_decay: f64) -> Self {
        let n = params.len();
        SGD {
            params,
            lr,
            momentum,
            weight_decay,
            velocities: vec![None; n],
        }
    }

    /// Update the parameter references (needed after `step()` returns new tensors).
    pub fn update_params(&mut self, new_params: Vec<Tensor<B>>) {
        self.params = new_params;
    }

    /// Access the current parameters.
    pub fn params(&self) -> &[Tensor<B>] {
        &self.params
    }

    /// Mutable access to the current parameters (for checkpoint loading).
    pub fn params_mut(&mut self) -> &mut Vec<Tensor<B>> {
        &mut self.params
    }
}

impl<B: Backend> Optimizer<B> for SGD<B> {
    fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>> {
        let mut new_params = Vec::with_capacity(self.params.len());

        for (i, param) in self.params.iter().enumerate() {
            let grad = match grads.get(param) {
                Some(g) => g,
                None => {
                    // No gradient for this parameter; keep unchanged
                    new_params.push(param.clone());
                    continue;
                }
            };

            let mut grad_data = grad.to_f64_vec()?;
            let param_data = param.to_f64_vec()?;

            // Apply weight decay: grad = grad + weight_decay * param
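            // (This coupled form of L2 regularization corresponds to adding
            // 0.5 * weight_decay * ||θ||² to the loss.)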
            if self.weight_decay != 0.0 {
                for (g, &p) in grad_data.iter_mut().zip(param_data.iter()) {
                    *g += self.weight_decay * p;
                }
            }

            // Apply momentum
            if self.momentum != 0.0 {
                let velocity = self.velocities[i].get_or_insert_with(|| vec![0.0; grad_data.len()]);
                for (v, &g) in velocity.iter_mut().zip(grad_data.iter()) {
                    *v = self.momentum * *v + g;
                }
                grad_data = velocity.clone();
            }

            // Update: param = param - lr * grad_with_momentum
            let updated: Vec<f64> = param_data
                .iter()
                .zip(grad_data.iter())
                .map(|(&p, &g)| p - self.lr * g)
                .collect();

            // Update storage in place so model layers sharing this tensor see the new values
            param.update_data_inplace(&updated)?;

            new_params.push(param.clone());
        }

        Ok(new_params)
    }

    fn learning_rate(&self) -> f64 {
        self.lr
    }

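    /// Set the learning rate in place. This makes it easy to drive a simple
    /// schedule from the training loop, e.g.
    /// `opt.set_learning_rate(opt.learning_rate() * 0.95)` once per epoch
    /// (the 0.95 decay factor is purely illustrative).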
    fn set_learning_rate(&mut self, lr: f64) {
        self.lr = lr;
    }
}

// Stateful: save/restore the optimizer's internal state

impl<B: Backend> Stateful for SGD<B> {
    fn state_dict(&self) -> OptimizerState {
        let mut state = OptimizerState::new("SGD");

        state.set_scalar("lr", self.lr);
        state.set_scalar("momentum", self.momentum);
        state.set_scalar("weight_decay", self.weight_decay);
        state.set_scalar("n_params", self.velocities.len() as f64);

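        // Only velocity buffers that have actually been initialized are saved,
        // keyed by parameter index ("velocity.0", "velocity.1", ...).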
        for (i, vel) in self.velocities.iter().enumerate() {
            if let Some(v) = vel {
                state.set_buffer(format!("velocity.{i}"), v.clone());
            }
        }

        state
    }

    fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()> {
        if state.optimizer_type != "SGD" {
            return Err(shrew_core::Error::msg(format!(
                "Cannot load {} state into SGD optimizer",
                state.optimizer_type
            )));
        }

        if let Some(lr) = state.get_scalar("lr") {
            self.lr = lr;
        }
        if let Some(m) = state.get_scalar("momentum") {
            self.momentum = m;
        }
        if let Some(wd) = state.get_scalar("weight_decay") {
            self.weight_decay = wd;
        }

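        // Restore whichever velocity buffers the checkpoint provides; entries
        // that are missing are left untouched.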
        let n = self.velocities.len();
        for i in 0..n {
            if let Some(buf) = state.get_buffer(&format!("velocity.{i}")) {
                self.velocities[i] = Some(buf.clone());
            }
        }

        Ok(())
    }
}
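
#[cfg(test)]
mod tests {
    // Plain-f64 illustrations of the update rules documented at the top of this
    // file. They mirror the arithmetic inside `SGD::step` without touching any
    // Backend, so they are a sketch of the math rather than a test of the
    // tensor plumbing.

    #[test]
    fn momentum_recurrence_matches_hand_computation() {
        let (lr, momentum) = (0.1_f64, 0.9_f64);
        let grad = 1.0_f64;
        let (mut v, mut theta) = (0.0_f64, 0.0_f64);
        let mut velocities = Vec::new();
        for _ in 0..3 {
            v = momentum * v + grad;
            theta -= lr * v;
            velocities.push(v);
        }
        // v: 1.0, 1.9, 2.71 (see the worked example in the header comment)
        assert!((velocities[0] - 1.0).abs() < 1e-12);
        assert!((velocities[1] - 1.9).abs() < 1e-12);
        assert!((velocities[2] - 2.71).abs() < 1e-12);
        // θ has moved by -(0.1 + 0.19 + 0.271) = -0.561 in total.
        assert!((theta + 0.561).abs() < 1e-12);
    }

    #[test]
    fn weight_decay_adds_scaled_parameter_to_gradient() {
        // grad = grad + weight_decay * param, as applied in `step`.
        let (weight_decay, p, g) = (1e-4_f64, 2.0_f64, 0.5_f64);
        let effective = g + weight_decay * p;
        assert!((effective - 0.5002).abs() < 1e-12);
    }
}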