1use shrew_core::backend::Backend;
24use shrew_core::backprop::GradStore;
25use shrew_core::error::Result;
26use shrew_core::tensor::Tensor;
27
28use crate::optimizer::{Optimizer, OptimizerState, Stateful};
29
/// RMSProp optimizer: scales each gradient by the inverse square root of a
/// running average of its squared magnitude, with optional classical
/// momentum and L2 weight decay.
pub struct RMSProp<B: Backend> {
    // Parameters being optimized; updated in place on each `step`.
    params: Vec<Tensor<B>>,
    // Learning rate.
    lr: f64,
    // Smoothing constant for the running average of squared gradients
    // (default 0.99, set in `new`).
    alpha: f64,
    // Added to the denominator for numerical stability (default 1e-8).
    epsilon: f64,
    // Momentum factor; momentum buffers are only allocated when > 0.
    momentum: f64,
    // L2 penalty coefficient folded into the gradient before the update.
    weight_decay: f64,
    // Per-parameter running average of squared gradients; each inner Vec is
    // lazily sized to the parameter's element count on the first `step`.
    v: Vec<Vec<f64>>,
    // Per-parameter momentum buffers; stay empty while momentum == 0.
    buf: Vec<Vec<f64>>,
}
53
54impl<B: Backend> RMSProp<B> {
55 pub fn new(params: Vec<Tensor<B>>, lr: f64) -> Self {
57 let n = params.len();
58 RMSProp {
59 params,
60 lr,
61 alpha: 0.99,
62 epsilon: 1e-8,
63 momentum: 0.0,
64 weight_decay: 0.0,
65 v: vec![Vec::new(); n],
66 buf: vec![Vec::new(); n],
67 }
68 }
69
70 pub fn alpha(mut self, alpha: f64) -> Self {
72 self.alpha = alpha;
73 self
74 }
75
76 pub fn epsilon(mut self, eps: f64) -> Self {
78 self.epsilon = eps;
79 self
80 }
81
82 pub fn momentum(mut self, momentum: f64) -> Self {
84 self.momentum = momentum;
85 self
86 }
87
88 pub fn weight_decay(mut self, wd: f64) -> Self {
90 self.weight_decay = wd;
91 self
92 }
93
94 pub fn params(&self) -> &[Tensor<B>] {
96 &self.params
97 }
98
99 pub fn params_mut(&mut self) -> &mut Vec<Tensor<B>> {
101 &mut self.params
102 }
103}
104
105impl<B: Backend> Optimizer<B> for RMSProp<B> {
106 fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>> {
107 let mut new_params = Vec::with_capacity(self.params.len());
108
109 for (i, param) in self.params.iter().enumerate() {
110 let grad = match grads.get(param) {
111 Some(g) => g,
112 None => {
113 new_params.push(param.clone());
114 continue;
115 }
116 };
117
118 let mut grad_data = grad.to_f64_vec()?;
119 let mut param_data = param.to_f64_vec()?;
120 let n = param_data.len();
121
122 if self.v[i].is_empty() {
124 self.v[i] = vec![0.0; n];
125 if self.momentum > 0.0 {
126 self.buf[i] = vec![0.0; n];
127 }
128 }
129
130 if self.weight_decay != 0.0 {
132 for (g, &p) in grad_data.iter_mut().zip(param_data.iter()) {
133 *g += self.weight_decay * p;
134 }
135 }
136
137 for (v, &g) in self.v[i].iter_mut().zip(grad_data.iter()) {
139 *v = self.alpha * *v + (1.0 - self.alpha) * g * g;
140 }
141
142 if self.momentum > 0.0 {
143 for j in 0..n {
147 self.buf[i][j] = self.momentum * self.buf[i][j]
148 + grad_data[j] / (self.v[i][j].sqrt() + self.epsilon);
149 param_data[j] -= self.lr * self.buf[i][j];
150 }
151 } else {
152 for j in 0..n {
155 param_data[j] -= self.lr * grad_data[j] / (self.v[i][j].sqrt() + self.epsilon);
156 }
157 }
158
159 param.update_data_inplace(¶m_data)?;
160 new_params.push(param.clone());
161 }
162
163 Ok(new_params)
164 }
165
166 fn learning_rate(&self) -> f64 {
167 self.lr
168 }
169
170 fn set_learning_rate(&mut self, lr: f64) {
171 self.lr = lr;
172 }
173}
174
175impl<B: Backend> Stateful for RMSProp<B> {
178 fn state_dict(&self) -> OptimizerState {
179 let mut state = OptimizerState::new("RMSProp");
180
181 state.set_scalar("lr", self.lr);
182 state.set_scalar("alpha", self.alpha);
183 state.set_scalar("epsilon", self.epsilon);
184 state.set_scalar("momentum", self.momentum);
185 state.set_scalar("weight_decay", self.weight_decay);
186 state.set_scalar("n_params", self.v.len() as f64);
187
188 for (i, v) in self.v.iter().enumerate() {
189 state.set_buffer(format!("v.{i}"), v.clone());
190 }
191 for (i, b) in self.buf.iter().enumerate() {
192 if !b.is_empty() {
193 state.set_buffer(format!("buf.{i}"), b.clone());
194 }
195 }
196
197 state
198 }
199
200 fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()> {
201 if state.optimizer_type != "RMSProp" {
202 return Err(shrew_core::Error::msg(format!(
203 "Cannot load {} state into RMSProp optimizer",
204 state.optimizer_type
205 )));
206 }
207
208 if let Some(lr) = state.get_scalar("lr") {
209 self.lr = lr;
210 }
211 if let Some(a) = state.get_scalar("alpha") {
212 self.alpha = a;
213 }
214 if let Some(eps) = state.get_scalar("epsilon") {
215 self.epsilon = eps;
216 }
217 if let Some(m) = state.get_scalar("momentum") {
218 self.momentum = m;
219 }
220 if let Some(wd) = state.get_scalar("weight_decay") {
221 self.weight_decay = wd;
222 }
223
224 let n = self.v.len();
225 for i in 0..n {
226 if let Some(buf) = state.get_buffer(&format!("v.{i}")) {
227 self.v[i] = buf.clone();
228 }
229 if let Some(buf) = state.get_buffer(&format!("buf.{i}")) {
230 self.buf[i] = buf.clone();
231 }
232 }
233
234 Ok(())
235 }
236}