shrew_optim/rmsprop.rs

// RMSProp — Root Mean Square Propagation
//
// RMSProp was proposed by Geoff Hinton (in his Coursera lectures, unpublished).
// It maintains a running average of squared gradients to normalize the gradient,
// effectively giving each parameter its own adaptive learning rate.
//
// Update rule:
//   v = α * v + (1 - α) * grad²          (running average of squared gradients)
//   θ = θ - lr * grad / (√v + ε)
//
// With momentum:
//   v = α * v + (1 - α) * grad²
//   buf = momentum * buf + grad / (√v + ε)
//   θ = θ - lr * buf
//
// RMSProp is particularly useful for recurrent neural networks and
// non-stationary objectives. It can be viewed as a precursor to Adam:
// Adam combines RMSProp's adaptive scaling with momentum.
//
// HYPERPARAMETERS:
//   lr = 1e-2 (typical), α = 0.99, ε = 1e-8, momentum = 0, weight_decay = 0
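//
// WORKED EXAMPLE (illustrative numbers for a single scalar parameter, not
// taken from a real run):
//   lr = 1e-2, α = 0.99, ε = 1e-8, v starts at 0, grad = 0.1
//   v      = 0.99 * 0 + 0.01 * 0.1²           = 1e-4
//   update = 0.01 * 0.1 / (√1e-4 + 1e-8)      ≈ 0.1
//   On the first step √v ≈ √(1 - α) * |grad|, so the gradient's magnitude
//   largely cancels and only its sign sets the direction; this is the
//   per-parameter adaptive scaling in action.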

use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::optimizer::{Optimizer, OptimizerState, Stateful};

/// RMSProp optimizer.
///
/// Adapts the learning rate per-parameter using a running average
/// of squared gradients.
///
/// # Default hyperparameters
/// - `lr`: 0.01
/// - `alpha`: 0.99 (smoothing constant)
/// - `epsilon`: 1e-8
/// - `momentum`: 0.0
/// - `weight_decay`: 0.0
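///
/// # Example
///
/// A sketch of typical construction through the builder methods below. The
/// import paths and the way `params` is collected are assumptions here, since
/// they depend on how the crate re-exports this module and on your model code.
///
/// ```ignore
/// use shrew_core::{backend::Backend, tensor::Tensor};
/// use shrew_optim::rmsprop::RMSProp;
///
/// fn build_optimizer<B: Backend>(params: Vec<Tensor<B>>) -> RMSProp<B> {
///     // lr = 1e-2 with a milder smoothing constant, heavy-ball momentum,
///     // and a small L2 penalty folded into the gradient.
///     RMSProp::new(params, 1e-2)
///         .alpha(0.9)
///         .momentum(0.9)
///         .weight_decay(1e-4)
/// }
/// ```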
pub struct RMSProp<B: Backend> {
    params: Vec<Tensor<B>>,
    lr: f64,
    alpha: f64,
    epsilon: f64,
    momentum: f64,
    weight_decay: f64,
    /// Running average of squared gradients (one per parameter)
    v: Vec<Vec<f64>>,
    /// Momentum buffer (one per parameter, only if momentum > 0)
    buf: Vec<Vec<f64>>,
}

impl<B: Backend> RMSProp<B> {
    /// Create a new RMSProp optimizer with default hyperparameters.
    pub fn new(params: Vec<Tensor<B>>, lr: f64) -> Self {
        let n = params.len();
        RMSProp {
            params,
            lr,
            alpha: 0.99,
            epsilon: 1e-8,
            momentum: 0.0,
            weight_decay: 0.0,
            v: vec![Vec::new(); n],
            buf: vec![Vec::new(); n],
        }
    }

    /// Set the smoothing constant α (default: 0.99).
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set ε (numerical stability, default: 1e-8).
    pub fn epsilon(mut self, eps: f64) -> Self {
        self.epsilon = eps;
        self
    }

    /// Set momentum factor (default: 0).
    pub fn momentum(mut self, momentum: f64) -> Self {
        self.momentum = momentum;
        self
    }

    /// Set weight decay / L2 penalty (default: 0).
    pub fn weight_decay(mut self, wd: f64) -> Self {
        self.weight_decay = wd;
        self
    }

    /// Access current parameters.
    pub fn params(&self) -> &[Tensor<B>] {
        &self.params
    }

    /// Mutable access to current parameters.
    pub fn params_mut(&mut self) -> &mut Vec<Tensor<B>> {
        &mut self.params
    }
}

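// A sketch of how `step` is typically driven inside a training loop; producing
// the `GradStore` is the job of the autograd layer, and `compute_grads` below
// is only a placeholder:
//
//     let grads: GradStore<B> = compute_grads(&loss)?;
//     opt.step(&grads)?;                                    // updates parameters in place
//     opt.set_learning_rate(opt.learning_rate() * 0.99);    // optional LR decay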
impl<B: Backend> Optimizer<B> for RMSProp<B> {
    fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>> {
        let mut new_params = Vec::with_capacity(self.params.len());

        for (i, param) in self.params.iter().enumerate() {
            let grad = match grads.get(param) {
                Some(g) => g,
                None => {
                    new_params.push(param.clone());
                    continue;
                }
            };

            let mut grad_data = grad.to_f64_vec()?;
            let mut param_data = param.to_f64_vec()?;
            let n = param_data.len();

            // Initialize state on first step
            if self.v[i].is_empty() {
                self.v[i] = vec![0.0; n];
                if self.momentum > 0.0 {
                    self.buf[i] = vec![0.0; n];
                }
            }

            // Weight decay: add L2 penalty to gradient
            if self.weight_decay != 0.0 {
                for (g, &p) in grad_data.iter_mut().zip(param_data.iter()) {
                    *g += self.weight_decay * p;
                }
            }

            // Update running average of squared gradients: v = α * v + (1 - α) * g²
            for (v, &g) in self.v[i].iter_mut().zip(grad_data.iter()) {
                *v = self.alpha * *v + (1.0 - self.alpha) * g * g;
            }

            if self.momentum > 0.0 {
                // With momentum:
                //   buf = momentum * buf + grad / (√v + ε)
                //   param = param - lr * buf
                for j in 0..n {
                    self.buf[i][j] = self.momentum * self.buf[i][j]
                        + grad_data[j] / (self.v[i][j].sqrt() + self.epsilon);
                    param_data[j] -= self.lr * self.buf[i][j];
                }
            } else {
                // Without momentum:
                //   param = param - lr * grad / (√v + ε)
                for j in 0..n {
                    param_data[j] -= self.lr * grad_data[j] / (self.v[i][j].sqrt() + self.epsilon);
                }
            }

            param.update_data_inplace(&param_data)?;
            new_params.push(param.clone());
        }

        Ok(new_params)
    }

    fn learning_rate(&self) -> f64 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.lr = lr;
    }
}

// Stateful — Save/restore optimizer internal state
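//
// A sketch of a checkpoint round-trip; how `OptimizerState` is actually
// persisted (file format, serialization) is left to the caller and not shown:
//
//     let snapshot = opt.state_dict();   // scalars (lr, α, ε, ...) plus v/buf buffers
//     // ... later, on a freshly constructed RMSProp over the same parameters:
//     opt.load_state_dict(&snapshot)?;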

impl<B: Backend> Stateful for RMSProp<B> {
    fn state_dict(&self) -> OptimizerState {
        let mut state = OptimizerState::new("RMSProp");

        state.set_scalar("lr", self.lr);
        state.set_scalar("alpha", self.alpha);
        state.set_scalar("epsilon", self.epsilon);
        state.set_scalar("momentum", self.momentum);
        state.set_scalar("weight_decay", self.weight_decay);
        state.set_scalar("n_params", self.v.len() as f64);

        for (i, v) in self.v.iter().enumerate() {
            state.set_buffer(format!("v.{i}"), v.clone());
        }
        for (i, b) in self.buf.iter().enumerate() {
            if !b.is_empty() {
                state.set_buffer(format!("buf.{i}"), b.clone());
            }
        }

        state
    }

    fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()> {
        if state.optimizer_type != "RMSProp" {
            return Err(shrew_core::Error::msg(format!(
                "Cannot load {} state into RMSProp optimizer",
                state.optimizer_type
            )));
        }

        if let Some(lr) = state.get_scalar("lr") {
            self.lr = lr;
        }
        if let Some(a) = state.get_scalar("alpha") {
            self.alpha = a;
        }
        if let Some(eps) = state.get_scalar("epsilon") {
            self.epsilon = eps;
        }
        if let Some(m) = state.get_scalar("momentum") {
            self.momentum = m;
        }
        if let Some(wd) = state.get_scalar("weight_decay") {
            self.weight_decay = wd;
        }

        let n = self.v.len();
        for i in 0..n {
            if let Some(buf) = state.get_buffer(&format!("v.{i}")) {
                self.v[i] = buf.clone();
            }
            if let Some(buf) = state.get_buffer(&format!("buf.{i}")) {
                self.buf[i] = buf.clone();
            }
        }

        Ok(())
    }
}