use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::optimizer::{Optimizer, OptimizerState, Stateful};

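/// Rectified Adam (RAdam, Liu et al., 2019): Adam with a variance-rectification
/// term. While the second-moment estimate is still unreliable (the first few
/// steps), updates fall back to plain bias-corrected momentum, which removes
/// the need for a hand-tuned learning-rate warmup.
///
/// A minimal usage sketch; `params` and `grads` are placeholders for whatever
/// your model and backward pass provide:
///
/// ```ignore
/// let mut opt = RAdam::new(params, 1e-3).weight_decay(0.01);
/// let updated = opt.step(&grads)?;
/// ```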
pub struct RAdam<B: Backend> {
    params: Vec<Tensor<B>>,
    lr: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
    weight_decay: f64,
    /// Number of optimizer steps taken so far.
    t: u64,
    /// First-moment (mean) estimates, one lazily allocated buffer per parameter.
    m: Vec<Vec<f64>>,
    /// Second-moment (uncentered variance) estimates, one buffer per parameter.
    v: Vec<Vec<f64>>,
    /// Maximum length of the approximated SMA: rho_inf = 2 / (1 - beta2) - 1.
    rho_inf: f64,
}

impl<B: Backend> RAdam<B> {
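    /// Creates a RAdam optimizer with defaults beta1 = 0.9, beta2 = 0.999,
    /// epsilon = 1e-8, and no weight decay; override them with the builder
    /// methods below.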
    pub fn new(params: Vec<Tensor<B>>, lr: f64) -> Self {
        let n = params.len();
        let beta2 = 0.999;
        RAdam {
            params,
            lr,
            beta1: 0.9,
            beta2,
            epsilon: 1e-8,
            weight_decay: 0.0,
            t: 0,
            m: vec![Vec::new(); n],
            v: vec![Vec::new(); n],
            rho_inf: 2.0 / (1.0 - beta2) - 1.0,
        }
    }

    pub fn beta1(mut self, beta1: f64) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Sets beta2 and recomputes rho_inf, which depends on it.
    pub fn beta2(mut self, beta2: f64) -> Self {
        self.beta2 = beta2;
        self.rho_inf = 2.0 / (1.0 - beta2) - 1.0;
        self
    }

    pub fn epsilon(mut self, eps: f64) -> Self {
        self.epsilon = eps;
        self
    }

    pub fn weight_decay(mut self, wd: f64) -> Self {
        self.weight_decay = wd;
        self
    }

    pub fn params(&self) -> &[Tensor<B>] {
        &self.params
    }

    pub fn params_mut(&mut self) -> &mut Vec<Tensor<B>> {
        &mut self.params
    }

    pub fn step_count(&self) -> u64 {
        self.t
    }
}

impl<B: Backend> Optimizer<B> for RAdam<B> {
    #[allow(clippy::needless_range_loop)]
    fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>> {
        self.t += 1;
        let mut new_params = Vec::with_capacity(self.params.len());

        for (i, param) in self.params.iter().enumerate() {
            // Parameters without a gradient this step pass through unchanged.
            let grad = match grads.get(param) {
                Some(g) => g,
                None => {
                    new_params.push(param.clone());
                    continue;
                }
            };

            let grad_data = grad.to_f64_vec()?;
            let mut param_data = param.to_f64_vec()?;
            let n = param_data.len();

            // Allocate the moment buffers on first use of this parameter.
            if self.m[i].is_empty() {
                self.m[i] = vec![0.0; n];
                self.v[i] = vec![0.0; n];
            }

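            // Weight decay is decoupled (AdamW-style): it shrinks the
            // parameters directly, scaled by lr, instead of adding
            // wd * param to the gradient as classic L2 regularization would.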
            if self.weight_decay != 0.0 {
                for j in 0..n {
                    param_data[j] -= self.lr * self.weight_decay * param_data[j];
                }
            }

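            // Exponential moving averages of the gradient and its square:
            //   m <- beta1 * m + (1 - beta1) * g
            //   v <- beta2 * v + (1 - beta2) * g^2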
            for (m, &g) in self.m[i].iter_mut().zip(grad_data.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }
            for (v, &g) in self.v[i].iter_mut().zip(grad_data.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

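            // Bias-correction denominators, plus the length of the
            // approximated simple moving average:
            //   rho_t = rho_inf - 2 * t * beta2^t / (1 - beta2^t)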
            let bc1 = 1.0 - self.beta1.powi(self.t as i32);

            let beta2_t = self.beta2.powi(self.t as i32);
            let bc2 = 1.0 - beta2_t;
            let rho_t = self.rho_inf - 2.0 * self.t as f64 * beta2_t / bc2;

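            // The variance rectification is only well-defined once rho_t grows
            // past the threshold (a handful of steps with the default beta2);
            // before that, take a plain bias-corrected momentum step.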
            if rho_t > 5.0 {
                // Rectification term r from the RAdam paper.
                let r = ((rho_t - 4.0) * (rho_t - 2.0) * self.rho_inf
                    / ((self.rho_inf - 4.0) * (self.rho_inf - 2.0) * rho_t))
                    .sqrt();

                for j in 0..n {
                    let m_hat = self.m[i][j] / bc1;
                    let v_hat = self.v[i][j] / bc2;
                    param_data[j] -= self.lr * r * m_hat / (v_hat.sqrt() + self.epsilon);
                }
            } else {
                for j in 0..n {
                    let m_hat = self.m[i][j] / bc1;
                    param_data[j] -= self.lr * m_hat;
                }
            }

            param.update_data_inplace(&param_data)?;
            new_params.push(param.clone());
        }

        Ok(new_params)
    }

    fn learning_rate(&self) -> f64 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.lr = lr;
    }
}

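// Checkpointing: hyperparameters and the step count are saved as scalars, and
// the per-parameter moment buffers under "m.{i}" / "v.{i}", so a restored run
// continues with identical optimizer state.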
impl<B: Backend> Stateful for RAdam<B> {
    fn state_dict(&self) -> OptimizerState {
        let mut state = OptimizerState::new("RAdam");

        state.set_scalar("t", self.t as f64);
        state.set_scalar("lr", self.lr);
        state.set_scalar("beta1", self.beta1);
        state.set_scalar("beta2", self.beta2);
        state.set_scalar("epsilon", self.epsilon);
        state.set_scalar("weight_decay", self.weight_decay);
        state.set_scalar("rho_inf", self.rho_inf);
        state.set_scalar("n_params", self.m.len() as f64);

        for (i, m) in self.m.iter().enumerate() {
            state.set_buffer(format!("m.{i}"), m.clone());
        }
        for (i, v) in self.v.iter().enumerate() {
            state.set_buffer(format!("v.{i}"), v.clone());
        }

        state
    }

    fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()> {
        if state.optimizer_type != "RAdam" {
            return Err(shrew_core::Error::msg(format!(
                "Cannot load {} state into RAdam optimizer",
                state.optimizer_type
            )));
        }

        self.t = state.get_scalar("t").unwrap_or(0.0) as u64;
        if let Some(lr) = state.get_scalar("lr") {
            self.lr = lr;
        }
        if let Some(b1) = state.get_scalar("beta1") {
            self.beta1 = b1;
        }
        if let Some(b2) = state.get_scalar("beta2") {
            self.beta2 = b2;
        }
        if let Some(eps) = state.get_scalar("epsilon") {
            self.epsilon = eps;
        }
        if let Some(wd) = state.get_scalar("weight_decay") {
            self.weight_decay = wd;
        }
        if let Some(ri) = state.get_scalar("rho_inf") {
            self.rho_inf = ri;
        }

        let n = self.m.len();
        for i in 0..n {
            if let Some(buf) = state.get_buffer(&format!("m.{i}")) {
                self.m[i] = buf.clone();
            }
            if let Some(buf) = state.get_buffer(&format!("v.{i}")) {
                self.v[i] = buf.clone();
            }
        }

        Ok(())
    }
}
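
#[cfg(test)]
mod tests {
    // A sanity check of the rectification schedule as plain scalar math,
    // mirroring the formulas used in `step` above; it needs no Backend, so it
    // is a sketch of the schedule's shape rather than a full optimizer test.
    #[test]
    fn rho_t_warms_up_toward_rho_inf() {
        let beta2: f64 = 0.999;
        let rho_inf = 2.0 / (1.0 - beta2) - 1.0; // 1999.0 for the default beta2
        let rho_t = |t: u64| {
            let beta2_t = beta2.powi(t as i32);
            rho_inf - 2.0 * t as f64 * beta2_t / (1.0 - beta2_t)
        };

        // Early steps stay at or below the rectification threshold, so the
        // optimizer takes unrectified momentum steps first.
        assert!(rho_t(1) <= 5.0);
        assert!(rho_t(4) <= 5.0);

        // Later steps rectify, and rho_t approaches (but never exceeds) rho_inf.
        assert!(rho_t(1000) > 5.0);
        assert!(rho_t(100_000) < rho_inf);
        assert!(rho_inf - rho_t(100_000) < 1.0);
    }
}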