shrew_optim/clip.rs
// Gradient Clipping — Prevent exploding gradients during training
//
// Gradient clipping limits the magnitude of gradients before the optimizer
// step, preventing catastrophically large updates.
//
// Two strategies:
//   1. ClipByNorm:  Scale all gradients so their global L2 norm ≤ max_norm
//      (used by GPT, BERT, and most modern architectures)
//   2. ClipByValue: Clamp each gradient element to [-max_value, max_value]
//
// USAGE:
//   let grads = loss.backward()?;
//   let (clipped, _norm) = clip_grad_norm::<CpuBackend>(&grads, &params, 1.0)?;
//   optimizer.step(&clipped)?;

use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

// Clip by global L2 norm

/// Clip gradients by their global L2 norm.
///
/// If the total L2 norm of all gradients exceeds `max_norm`, all gradients
/// are scaled down proportionally so the total norm equals `max_norm`.
/// If the norm is already ≤ `max_norm`, gradients are returned unchanged.
///
/// Returns `(clipped_grads, total_norm)`.
///
/// # Arguments
/// - `grads`: gradient store from `loss.backward()`
/// - `params`: the parameters whose gradients to clip
/// - `max_norm`: maximum allowed L2 norm
///
/// # Example
/// ```ignore
/// let grads = loss.backward()?;
/// let (clipped, norm) = clip_grad_norm(&grads, &params, 1.0)?;
/// println!("Gradient norm: {norm:.4}");
/// optimizer.step(&clipped)?;
/// ```
pub fn clip_grad_norm<B: Backend>(
    grads: &GradStore<B>,
    params: &[Tensor<B>],
    max_norm: f64,
) -> Result<(GradStore<B>, f64)> {
    // 1. Compute global L2 norm: sqrt(sum of all grad elements squared)
    let mut total_norm_sq = 0.0f64;
    for param in params {
        if let Some(grad) = grads.get(param) {
            let data = grad.to_f64_vec()?;
            for &v in &data {
                total_norm_sq += v * v;
            }
        }
    }
    let total_norm = total_norm_sq.sqrt();

    // 2. If norm ≤ max_norm, return as-is
    if total_norm <= max_norm {
        return Ok((grads.clone(), total_norm));
    }

    // 3. Scale factor: max_norm / total_norm
    let scale = max_norm / (total_norm + 1e-6);

    // 4. Build new GradStore with scaled gradients
    let mut clipped = GradStore::<B>::new();
    for param in params {
        if let Some(grad) = grads.get(param) {
            let scaled = grad.affine(scale, 0.0)?;
            clipped.accumulate(param.id(), scaled)?;
        }
    }

    Ok((clipped, total_norm))
}
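
// Worked example (illustrative): suppose two parameters have gradients
// [3.0, 4.0] and [0.0, 0.0]. The global L2 norm is sqrt(3^2 + 4^2) = 5.0.
// With max_norm = 1.0 the scale factor is 1.0 / (5.0 + 1e-6) ≈ 0.2, so the
// clipped gradients become roughly [0.6, 0.8] and [0.0, 0.0]: the update
// direction is preserved, only its magnitude shrinks.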

/// Compute the global L2 norm of all gradients without clipping.
///
/// Useful for monitoring gradient magnitudes during training.
pub fn grad_norm<B: Backend>(grads: &GradStore<B>, params: &[Tensor<B>]) -> Result<f64> {
    let mut total_norm_sq = 0.0f64;
    for param in params {
        if let Some(grad) = grads.get(param) {
            let data = grad.to_f64_vec()?;
            for &v in &data {
                total_norm_sq += v * v;
            }
        }
    }
    Ok(total_norm_sq.sqrt())
}
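
// Sketch (hypothetical training loop): `grad_norm` can be logged periodically
// to spot instability before it derails training. `loss`, `params`, and `step`
// are assumed to come from the surrounding loop, not from this module.
//
//   let grads = loss.backward()?;
//   let norm = grad_norm(&grads, &params)?;
//   if step % 100 == 0 {
//       println!("step {step}: grad norm = {norm:.4}");
//   }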

// Clip by value

/// Clamp each gradient element to `[-max_value, max_value]`.
///
/// This is a simpler but less commonly used strategy than norm clipping.
///
/// Returns `clipped_grads`.
pub fn clip_grad_value<B: Backend>(
    grads: &GradStore<B>,
    params: &[Tensor<B>],
    max_value: f64,
) -> Result<GradStore<B>> {
    let mut clipped = GradStore::<B>::new();
    for param in params {
        if let Some(grad) = grads.get(param) {
            let data = grad.to_f64_vec()?;
            let clamped: Vec<f64> = data
                .iter()
                .map(|&v| v.max(-max_value).min(max_value))
                .collect();
            let clamped_tensor = Tensor::<B>::from_f64_slice(
                &clamped,
                grad.shape().clone(),
                grad.dtype(),
                grad.device(),
            )?;
            clipped.accumulate(param.id(), clamped_tensor)?;
        }
    }
    Ok(clipped)
}
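
// Note (illustrative comparison): value clipping can change the gradient's
// direction, while norm clipping only rescales it. For a gradient [3.0, 0.3]
// with max_value = 1.0, clip_grad_value yields [1.0, 0.3] (a new direction),
// whereas clip_grad_norm with max_norm = 1.0 yields roughly [0.995, 0.0995]
// (the same direction, shorter).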

// Gradient Accumulation — Simulate larger batch sizes
//
// Gradient accumulation lets you simulate a larger effective batch size
// without increasing memory usage. Instead of stepping the optimizer
// after every batch, you accumulate gradients over N mini-batches
// and then step once with the averaged gradient.
//
// This is essential when:
// - Your GPU memory can only fit small batches
// - You need large effective batch sizes (e.g., for Transformers)
//
// Without this helper, the pattern requires manually managing a GradStore
// accumulator. This struct encapsulates that logic cleanly.
/// Gradient accumulation helper.
///
/// Accumulates gradients over multiple mini-batches and provides
/// the averaged gradient for the optimizer step.
///
/// # Example
/// ```ignore
/// let mut accum = GradAccumulator::<CpuBackend>::new(4); // 4 accumulation steps
///
/// for (i, batch) in batches.iter().enumerate() {
///     let loss = model.forward(batch)?;
///     let grads = loss.backward()?;
///
///     if let Some(avg_grads) = accum.step(&grads, &params)? {
///         optimizer.step(&avg_grads)?;
///     }
/// }
/// ```
pub struct GradAccumulator<B: Backend> {
    /// Number of steps to accumulate before yielding
    accum_steps: u64,
    /// Current step within the accumulation window
    current_step: u64,
    /// Accumulated gradient sums (param_id → gradient tensor)
    accumulated: Option<GradStore<B>>,
}

impl<B: Backend> GradAccumulator<B> {
    /// Create a new gradient accumulator.
    ///
    /// # Arguments
    /// - `accum_steps`: Number of mini-batches to accumulate before stepping
    pub fn new(accum_steps: u64) -> Self {
        assert!(accum_steps > 0, "accum_steps must be > 0");
        GradAccumulator {
            accum_steps,
            current_step: 0,
            accumulated: None,
        }
    }

    /// Add gradients from one mini-batch.
    ///
    /// Returns `Some(averaged_grads)` when `accum_steps` batches have been
    /// accumulated, otherwise returns `None`.
    ///
    /// The returned gradients are divided by `accum_steps` to produce
    /// the average gradient.
    pub fn step(
        &mut self,
        grads: &GradStore<B>,
        params: &[Tensor<B>],
    ) -> Result<Option<GradStore<B>>> {
        self.current_step += 1;

        // Accumulate
        let acc = self.accumulated.get_or_insert_with(GradStore::new);
        for param in params {
            if let Some(grad) = grads.get(param) {
                acc.accumulate(param.id(), grad.clone())?;
            }
        }

        // Check if we've accumulated enough
        if self.current_step >= self.accum_steps {
            let accumulated = self.accumulated.take().unwrap();
            self.current_step = 0;

            // Average: divide each gradient by accum_steps
            let scale = 1.0 / self.accum_steps as f64;
            let mut averaged = GradStore::<B>::new();
            for param in params {
                if let Some(grad) = accumulated.get(param) {
                    let avg = grad.affine(scale, 0.0)?;
                    averaged.accumulate(param.id(), avg)?;
                }
            }

            Ok(Some(averaged))
        } else {
            Ok(None)
        }
    }

    /// Reset the accumulator, discarding any accumulated gradients.
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.accumulated = None;
    }

    /// Get the number of accumulation steps.
    pub fn accum_steps(&self) -> u64 {
        self.accum_steps
    }

    /// Get the current step within the accumulation window.
    pub fn current_step(&self) -> u64 {
        self.current_step
    }
}
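
// Combined sketch (hypothetical training loop): accumulation and norm clipping
// compose naturally: clip the averaged gradient once per optimizer step.
// `model`, `optimizer`, `batches`, and `params` are assumed to exist in the
// calling code; `CpuBackend` stands in for whatever Backend impl the caller uses.
//
//   let mut accum = GradAccumulator::<CpuBackend>::new(4);
//   for batch in &batches {
//       let loss = model.forward(batch)?;
//       let grads = loss.backward()?;
//       if let Some(avg_grads) = accum.step(&grads, &params)? {
//           let (clipped, _norm) = clip_grad_norm(&avg_grads, &params, 1.0)?;
//           optimizer.step(&clipped)?;
//       }
//   }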