use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::optimizer::{Optimizer, OptimizerState, Stateful};

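/// Stochastic gradient descent with optional heavy-ball momentum and L2
/// weight decay (either is disabled by passing `0.0`).
///
/// With learning rate `lr`, momentum `mu`, and weight decay `wd`, each step
/// computes, per parameter `p` with gradient `g`:
///
/// ```text
/// g <- g + wd * p    // weight decay
/// v <- mu * v + g    // momentum (velocity starts at zero)
/// p <- p - lr * g    // descent step (g is v when momentum is on)
/// ```
///
/// Velocity buffers live on the host as `Vec<f64>` and are allocated lazily,
/// so parameters that never receive a gradient never allocate one.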
pub struct SGD<B: Backend> {
    params: Vec<Tensor<B>>,
    lr: f64,
    momentum: f64,
    weight_decay: f64,
    velocities: Vec<Option<Vec<f64>>>,
}

impl<B: Backend> SGD<B> {
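    /// Creates an SGD optimizer over `params` with the given learning rate,
    /// momentum factor, and weight decay. Pass `0.0` for `momentum` or
    /// `weight_decay` to disable them.
    ///
    /// A minimal usage sketch; the `CpuBackend` type, `model.parameters()`,
    /// and `loss.backward()` are illustrative assumptions, not part of this
    /// file:
    ///
    /// ```ignore
    /// let mut opt = SGD::<CpuBackend>::new(model.parameters(), 0.01, 0.9, 1e-4);
    /// let grads = loss.backward()?;
    /// opt.step(&grads)?;
    /// ```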
    pub fn new(params: Vec<Tensor<B>>, lr: f64, momentum: f64, weight_decay: f64) -> Self {
        let n = params.len();
        SGD {
            params,
            lr,
            momentum,
            weight_decay,
            velocities: vec![None; n],
        }
    }

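    /// Replaces the full set of tracked parameter tensors, e.g. after the
    /// model has been rebuilt. Velocity buffers are kept as-is and matched by
    /// index, so the new parameters should arrive in the original order.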
    pub fn update_params(&mut self, new_params: Vec<Tensor<B>>) {
        self.params = new_params;
    }

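    /// Read-only view of the tracked parameters.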
    pub fn params(&self) -> &[Tensor<B>] {
        &self.params
    }

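    /// Mutable access to the tracked parameters.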
    pub fn params_mut(&mut self) -> &mut Vec<Tensor<B>> {
        &mut self.params
    }
}

impl<B: Backend> Optimizer<B> for SGD<B> {
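    /// Runs one optimization step: updates every tracked parameter in place
    /// and returns clones of the updated tensors.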
    fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>> {
        let mut new_params = Vec::with_capacity(self.params.len());

        for (i, param) in self.params.iter().enumerate() {
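            // A parameter with no entry in the gradient store (frozen, or
            // unused in this graph) passes through unchanged.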
            let grad = match grads.get(param) {
                Some(g) => g,
                None => {
                    new_params.push(param.clone());
                    continue;
                }
            };

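            // Work on f64 host-side copies of the gradient and parameter
            // data; this keeps the update backend-agnostic at the cost of a
            // round trip through `to_f64_vec`.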
            let mut grad_data = grad.to_f64_vec()?;
            let param_data = param.to_f64_vec()?;

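            // L2 weight decay folded into the gradient: g <- g + wd * p.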
            if self.weight_decay != 0.0 {
                for (g, &p) in grad_data.iter_mut().zip(param_data.iter()) {
                    *g += self.weight_decay * p;
                }
            }

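            // Heavy-ball momentum without dampening: v <- momentum * v + g.
            // The velocity buffer is allocated lazily on first use.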
            if self.momentum != 0.0 {
                let velocity = self.velocities[i].get_or_insert_with(|| vec![0.0; grad_data.len()]);
                for (v, &g) in velocity.iter_mut().zip(grad_data.iter()) {
                    *v = self.momentum * *v + g;
                }
                grad_data = velocity.clone();
            }

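            // Descent step: p <- p - lr * g (g is the velocity when momentum
            // is enabled).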
            let updated: Vec<f64> = param_data
                .iter()
                .zip(grad_data.iter())
                .map(|(&p, &g)| p - self.lr * g)
                .collect();

            param.update_data_inplace(&updated)?;

            new_params.push(param.clone());
        }

        Ok(new_params)
    }

    fn learning_rate(&self) -> f64 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.lr = lr;
    }
}

impl<B: Backend> Stateful for SGD<B> {
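    /// Serializes the hyperparameters and any initialized velocity buffers.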
    fn state_dict(&self) -> OptimizerState {
        let mut state = OptimizerState::new("SGD");

        state.set_scalar("lr", self.lr);
        state.set_scalar("momentum", self.momentum);
        state.set_scalar("weight_decay", self.weight_decay);
        state.set_scalar("n_params", self.velocities.len() as f64);

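        // Velocities are serialized sparsely: only buffers that momentum has
        // actually initialized are written out.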
        for (i, vel) in self.velocities.iter().enumerate() {
            if let Some(v) = vel {
                state.set_buffer(format!("velocity.{i}"), v.clone());
            }
        }

        state
    }

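    /// Restores hyperparameters and velocity buffers from `state`, rejecting
    /// state saved by a different optimizer type.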
    fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()> {
        if state.optimizer_type != "SGD" {
            return Err(shrew_core::Error::msg(format!(
                "Cannot load {} state into SGD optimizer",
                state.optimizer_type
            )));
        }

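        // Hyperparameters are restored individually, so a state dict missing
        // a field leaves the current value in place.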
        if let Some(lr) = state.get_scalar("lr") {
            self.lr = lr;
        }
        if let Some(m) = state.get_scalar("momentum") {
            self.momentum = m;
        }
        if let Some(wd) = state.get_scalar("weight_decay") {
            self.weight_decay = wd;
        }

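        // Restore whichever velocity buffers the state contains; indices
        // absent from the state keep their current (possibly `None`) value.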
        let n = self.velocities.len();
        for i in 0..n {
            if let Some(buf) = state.get_buffer(&format!("velocity.{i}")) {
                self.velocities[i] = Some(buf.clone());
            }
        }

        Ok(())
    }
}