use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::optimizer::{Optimizer, OptimizerState, Stateful};

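/// Adam optimizer (adaptive moment estimation).
///
/// Keeps exponential moving averages of the gradient (first moment `m`) and of
/// the squared gradient (second moment `v`) for every parameter, applies bias
/// correction, and scales each update by the inverse square root of `v`.
/// Defaults: `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`, no weight decay.
///
/// A minimal usage sketch (illustrative only; obtaining `params` from a model
/// and `grads` from a backward pass is assumed and not shown here):
///
/// ```ignore
/// let mut opt = Adam::new(params, 1e-3)
///     .beta1(0.9)
///     .beta2(0.999)
///     .weight_decay(1e-2);
/// let updated = opt.step(&grads)?;
/// ```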
pub struct Adam<B: Backend> {
    params: Vec<Tensor<B>>,
    lr: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
    weight_decay: f64,
    /// When true, weight decay is applied directly to the parameters
    /// (AdamW-style) instead of being folded into the gradient.
    decoupled_decay: bool,
    /// Number of optimizer steps taken so far.
    t: u64,
    /// First-moment (mean) estimates, one buffer per parameter.
    m: Vec<Vec<f64>>,
    /// Second-moment (uncentered variance) estimates, one buffer per parameter.
    v: Vec<Vec<f64>>,
}

impl<B: Backend> Adam<B> {
    /// Creates an Adam optimizer with the default hyperparameters
    /// (`beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`, no weight decay).
    pub fn new(params: Vec<Tensor<B>>, lr: f64) -> Self {
        let n = params.len();
        Adam {
            params,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            decoupled_decay: false,
            t: 0,
            m: vec![Vec::new(); n],
            v: vec![Vec::new(); n],
        }
    }

    pub fn beta1(mut self, beta1: f64) -> Self {
        self.beta1 = beta1;
        self
    }

    pub fn beta2(mut self, beta2: f64) -> Self {
        self.beta2 = beta2;
        self
    }

    pub fn epsilon(mut self, eps: f64) -> Self {
        self.epsilon = eps;
        self
    }

    pub fn weight_decay(mut self, wd: f64) -> Self {
        self.weight_decay = wd;
        self
    }

    pub fn params(&self) -> &[Tensor<B>] {
        &self.params
    }

    pub fn params_mut(&mut self) -> &mut Vec<Tensor<B>> {
        &mut self.params
    }

    pub fn step_count(&self) -> u64 {
        self.t
    }

    pub fn set_lr(&mut self, lr: f64) {
        self.lr = lr;
    }
}

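/// AdamW: Adam with decoupled weight decay.
///
/// Identical to [`Adam`] except that weight decay is applied directly to the
/// parameters at each step rather than being folded into the gradient,
/// following Loshchilov & Hutter, "Decoupled Weight Decay Regularization".
/// Implemented as a thin wrapper that enables `decoupled_decay` on an inner
/// [`Adam`].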
pub struct AdamW<B: Backend>(pub Adam<B>);

impl<B: Backend> AdamW<B> {
    pub fn new(params: Vec<Tensor<B>>, lr: f64, weight_decay: f64) -> Self {
        let mut adam = Adam::new(params, lr);
        adam.decoupled_decay = true;
        adam.weight_decay = weight_decay;
        AdamW(adam)
    }

    pub fn weight_decay(mut self, wd: f64) -> Self {
        self.0.weight_decay = wd;
        self
    }

    pub fn beta1(mut self, beta1: f64) -> Self {
        self.0.beta1 = beta1;
        self
    }

    pub fn beta2(mut self, beta2: f64) -> Self {
        self.0.beta2 = beta2;
        self
    }

    pub fn params(&self) -> &[Tensor<B>] {
        self.0.params()
    }

    pub fn params_mut(&mut self) -> &mut Vec<Tensor<B>> {
        self.0.params_mut()
    }

    pub fn set_lr(&mut self, lr: f64) {
        self.0.set_lr(lr);
    }
}

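// The update implemented by `step` below is the standard Adam rule (with
// optional coupled L2 or decoupled weight decay):
//
//   m_t   = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t   = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   m_hat = m_t / (1 - beta1^t),   v_hat = v_t / (1 - beta2^t)
//   p_t   = p_{t-1} - lr * m_hat / (sqrt(v_hat) + epsilon)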
impl<B: Backend> Optimizer<B> for Adam<B> {
    #[allow(clippy::needless_range_loop)]
    fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>> {
        self.t += 1;
        let mut new_params = Vec::with_capacity(self.params.len());

        for (i, param) in self.params.iter().enumerate() {
            // Parameters without a gradient are passed through unchanged.
            let grad = match grads.get(param) {
                Some(g) => g,
                None => {
                    new_params.push(param.clone());
                    continue;
                }
            };

            let mut grad_data = grad.to_f64_vec()?;
            let mut param_data = param.to_f64_vec()?;
            let n = param_data.len();

            // Lazily allocate the moment buffers on first use.
            if self.m[i].is_empty() {
                self.m[i] = vec![0.0; n];
                self.v[i] = vec![0.0; n];
            }

            // Classic (coupled) L2 weight decay is folded into the gradient.
            if self.weight_decay != 0.0 && !self.decoupled_decay {
                for (g, &p) in grad_data.iter_mut().zip(param_data.iter()) {
                    *g += self.weight_decay * p;
                }
            }

            // Update the biased first-moment estimate.
            for (m, &g) in self.m[i].iter_mut().zip(grad_data.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update the biased second-moment estimate.
            for (v, &g) in self.v[i].iter_mut().zip(grad_data.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            // Bias-correction terms for step t.
            let bc1 = 1.0 - self.beta1.powi(self.t as i32);
            let bc2 = 1.0 - self.beta2.powi(self.t as i32);

            for j in 0..n {
                let m_hat = self.m[i][j] / bc1;
                let v_hat = self.v[i][j] / bc2;

                // Decoupled (AdamW) weight decay acts directly on the parameter.
                if self.weight_decay != 0.0 && self.decoupled_decay {
                    param_data[j] -= self.lr * self.weight_decay * param_data[j];
                }

                param_data[j] -= self.lr * m_hat / (v_hat.sqrt() + self.epsilon);
            }

            param.update_data_inplace(&param_data)?;

            new_params.push(param.clone());
        }

        Ok(new_params)
    }

    fn learning_rate(&self) -> f64 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.lr = lr;
    }
}

impl<B: Backend> Optimizer<B> for AdamW<B> {
    fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>> {
        self.0.step(grads)
    }

    fn learning_rate(&self) -> f64 {
        self.0.learning_rate()
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.0.set_learning_rate(lr);
    }
}

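// Checkpoint layout used by the `Stateful` impl below: the step count and
// hyperparameters are stored as scalars ("t", "lr", "beta1", "beta2",
// "epsilon", "weight_decay", "decoupled_decay", "n_params"), and the
// per-parameter moment buffers are stored under the keys "m.{i}" and "v.{i}".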
impl<B: Backend> Stateful for Adam<B> {
    fn state_dict(&self) -> OptimizerState {
        let name = if self.decoupled_decay {
            "AdamW"
        } else {
            "Adam"
        };
        let mut state = OptimizerState::new(name);

        state.set_scalar("t", self.t as f64);
        state.set_scalar("lr", self.lr);
        state.set_scalar("beta1", self.beta1);
        state.set_scalar("beta2", self.beta2);
        state.set_scalar("epsilon", self.epsilon);
        state.set_scalar("weight_decay", self.weight_decay);
        state.set_scalar(
            "decoupled_decay",
            if self.decoupled_decay { 1.0 } else { 0.0 },
        );
        state.set_scalar("n_params", self.m.len() as f64);

        for (i, m) in self.m.iter().enumerate() {
            state.set_buffer(format!("m.{i}"), m.clone());
        }
        for (i, v) in self.v.iter().enumerate() {
            state.set_buffer(format!("v.{i}"), v.clone());
        }

        state
    }

    fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()> {
        if state.optimizer_type != "Adam" && state.optimizer_type != "AdamW" {
            return Err(shrew_core::Error::msg(format!(
                "Cannot load {} state into Adam/AdamW optimizer",
                state.optimizer_type
            )));
        }

        self.t = state.get_scalar("t").unwrap_or(0.0) as u64;
        if let Some(lr) = state.get_scalar("lr") {
            self.lr = lr;
        }
        if let Some(b1) = state.get_scalar("beta1") {
            self.beta1 = b1;
        }
        if let Some(b2) = state.get_scalar("beta2") {
            self.beta2 = b2;
        }
        if let Some(eps) = state.get_scalar("epsilon") {
            self.epsilon = eps;
        }
        if let Some(wd) = state.get_scalar("weight_decay") {
            self.weight_decay = wd;
        }
        if let Some(dd) = state.get_scalar("decoupled_decay") {
            self.decoupled_decay = dd != 0.0;
        }

        let n = self.m.len();
        for i in 0..n {
            if let Some(buf) = state.get_buffer(&format!("m.{i}")) {
                self.m[i] = buf.clone();
            }
            if let Some(buf) = state.get_buffer(&format!("v.{i}")) {
                self.v[i] = buf.clone();
            }
        }

        Ok(())
    }
}

impl<B: Backend> Stateful for AdamW<B> {
    fn state_dict(&self) -> OptimizerState {
        self.0.state_dict()
    }

    fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()> {
        self.0.load_state_dict(state)
    }
}