shrew_optim/clip.rs

// Gradient Clipping — Prevent exploding gradients during training
//
// Gradient clipping limits the magnitude of gradients before the optimizer
// step, preventing catastrophically large updates.
//
// Two strategies:
//   1. ClipByNorm: Scale all gradients so their global L2 norm ≤ max_norm
//      (used by GPT, BERT, and most modern architectures)
//   2. ClipByValue: Clamp each gradient element to [-max_value, max_value]
//
// USAGE:
//   let grads = loss.backward()?;
//   let (clipped, _norm) = clip_grad_norm::<CpuBackend>(&grads, &params, 1.0)?;
//   optimizer.step(&clipped)?;

use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

// Clip by global L2 norm

/// Clip gradients by their global L2 norm.
///
/// If the total L2 norm of all gradients exceeds `max_norm`, all gradients
/// are scaled down proportionally so the total norm equals `max_norm`.
/// If the norm is already ≤ `max_norm`, gradients are returned unchanged.
///
/// Returns `(clipped_grads, total_norm)`.
///
/// # Arguments
/// - `grads`: gradient store from `loss.backward()`
/// - `params`: the parameters whose gradients to clip
/// - `max_norm`: maximum allowed L2 norm
///
/// # Example
/// ```ignore
/// let grads = loss.backward()?;
/// let (clipped, norm) = clip_grad_norm(&grads, &params, 1.0)?;
/// println!("Gradient norm: {norm:.4}");
/// optimizer.step(&clipped)?;
/// ```
pub fn clip_grad_norm<B: Backend>(
    grads: &GradStore<B>,
    params: &[Tensor<B>],
    max_norm: f64,
) -> Result<(GradStore<B>, f64)> {
    // 1. Compute global L2 norm: sqrt(sum of all grad elements squared)
    let mut total_norm_sq = 0.0f64;
    for param in params {
        if let Some(grad) = grads.get(param) {
            let data = grad.to_f64_vec()?;
            for &v in &data {
                total_norm_sq += v * v;
            }
        }
    }
    let total_norm = total_norm_sq.sqrt();

    // 2. If norm ≤ max_norm, return as-is
    if total_norm <= max_norm {
        return Ok((grads.clone(), total_norm));
    }

    // 3. Scale factor: max_norm / total_norm
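    //    e.g. total_norm = 10.0 with max_norm = 1.0 gives scale ≈ 0.1, shrinking
    //    every gradient element about 10x; the 1e-6 epsilon leaves the clipped
    //    norm just under max_norm rather than exactly on it.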
    let scale = max_norm / (total_norm + 1e-6);

    // 4. Build new GradStore with scaled gradients
    let mut clipped = GradStore::<B>::new();
    for param in params {
        if let Some(grad) = grads.get(param) {
            let scaled = grad.affine(scale, 0.0)?;
            clipped.accumulate(param.id(), scaled)?;
        }
    }

    Ok((clipped, total_norm))
}

/// Compute the global L2 norm of all gradients without clipping.
///
/// Useful for monitoring gradient magnitudes during training.
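///
/// # Example
/// A minimal monitoring sketch (assumes `loss` and `params` from the surrounding
/// training loop, as in the `clip_grad_norm` example above):
/// ```ignore
/// let grads = loss.backward()?;
/// let norm = grad_norm(&grads, &params)?;
/// println!("grad norm: {norm:.4}");
/// ```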
pub fn grad_norm<B: Backend>(grads: &GradStore<B>, params: &[Tensor<B>]) -> Result<f64> {
    let mut total_norm_sq = 0.0f64;
    for param in params {
        if let Some(grad) = grads.get(param) {
            let data = grad.to_f64_vec()?;
            for &v in &data {
                total_norm_sq += v * v;
            }
        }
    }
    Ok(total_norm_sq.sqrt())
}

// Clip by value

/// Clamp each gradient element to `[-max_value, max_value]`.
///
/// This is a simpler but less commonly used strategy than norm clipping.
///
/// Returns `clipped_grads`.
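///
/// # Example
/// A minimal sketch (assumes `loss` and `params` from the surrounding training loop):
/// ```ignore
/// let grads = loss.backward()?;
/// let clipped = clip_grad_value(&grads, &params, 0.5)?;
/// optimizer.step(&clipped)?;
/// ```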
pub fn clip_grad_value<B: Backend>(
    grads: &GradStore<B>,
    params: &[Tensor<B>],
    max_value: f64,
) -> Result<GradStore<B>> {
    let mut clipped = GradStore::<B>::new();
    for param in params {
        if let Some(grad) = grads.get(param) {
            let data = grad.to_f64_vec()?;
            let clamped: Vec<f64> = data
                .iter()
                .map(|&v| v.max(-max_value).min(max_value))
                .collect();
            let clamped_tensor = Tensor::<B>::from_f64_slice(
                &clamped,
                grad.shape().clone(),
                grad.dtype(),
                grad.device(),
            )?;
            clipped.accumulate(param.id(), clamped_tensor)?;
        }
    }
    Ok(clipped)
}

// Gradient Accumulation — Simulate larger batch sizes
//
// Gradient accumulation lets you simulate a larger effective batch size
// without increasing memory usage. Instead of stepping the optimizer
// after every batch, you accumulate gradients over N mini-batches
// and then step once with the averaged gradient.
//
// This is essential when:
//   - Your GPU memory can only fit small batches
//   - You need large effective batch sizes (e.g., for Transformers)
//
// Without this helper, the pattern requires manually managing a GradStore
// accumulator. This struct encapsulates that logic cleanly.
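//
// Effective batch size = per-step batch size × accum_steps: e.g. mini-batches
// of 8 accumulated over 4 steps behave, gradient-wise, like one batch of 32,
// while memory only ever holds the activations for a batch of 8.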

/// Gradient accumulation helper.
///
/// Accumulates gradients over multiple mini-batches and provides
/// the averaged gradient for the optimizer step.
///
/// # Example
/// ```ignore
/// let mut accum = GradAccumulator::<CpuBackend>::new(4); // 4 accumulation steps
///
/// for batch in &batches {
///     let loss = model.forward(batch)?;
///     let grads = loss.backward()?;
///
///     if let Some(avg_grads) = accum.step(&grads, &params)? {
///         optimizer.step(&avg_grads)?;
///     }
/// }
/// ```
pub struct GradAccumulator<B: Backend> {
    /// Number of steps to accumulate before yielding
    accum_steps: u64,
    /// Current step within the accumulation window
    current_step: u64,
    /// Accumulated gradient sums (param_id → gradient tensor)
    accumulated: Option<GradStore<B>>,
}

impl<B: Backend> GradAccumulator<B> {
    /// Create a new gradient accumulator.
    ///
    /// # Arguments
    /// - `accum_steps`: Number of mini-batches to accumulate before stepping
    pub fn new(accum_steps: u64) -> Self {
        assert!(accum_steps > 0, "accum_steps must be > 0");
        GradAccumulator {
            accum_steps,
            current_step: 0,
            accumulated: None,
        }
    }

    /// Add gradients from one mini-batch.
    ///
    /// Returns `Some(averaged_grads)` when `accum_steps` batches have been
    /// accumulated, otherwise returns `None`.
    ///
    /// The returned gradients are divided by `accum_steps` to produce
    /// the average gradient.
    pub fn step(
        &mut self,
        grads: &GradStore<B>,
        params: &[Tensor<B>],
    ) -> Result<Option<GradStore<B>>> {
        self.current_step += 1;

        // Accumulate
        let acc = self.accumulated.get_or_insert_with(GradStore::new);
        for param in params {
            if let Some(grad) = grads.get(param) {
                acc.accumulate(param.id(), grad.clone())?;
            }
        }

        // Check if we've accumulated enough
        if self.current_step >= self.accum_steps {
            let accumulated = self.accumulated.take().unwrap();
            self.current_step = 0;

            // Average: divide each gradient by accum_steps
            let scale = 1.0 / self.accum_steps as f64;
            let mut averaged = GradStore::<B>::new();
            for param in params {
                if let Some(grad) = accumulated.get(param) {
                    let avg = grad.affine(scale, 0.0)?;
                    averaged.accumulate(param.id(), avg)?;
                }
            }

            Ok(Some(averaged))
        } else {
            Ok(None)
        }
    }

    /// Reset the accumulator, discarding any accumulated gradients.
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.accumulated = None;
    }

    /// Get the number of accumulation steps.
    pub fn accum_steps(&self) -> u64 {
        self.accum_steps
    }

    /// Get the current step within the accumulation window.
    pub fn current_step(&self) -> u64 {
        self.current_step
    }
}
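
// COMBINED USAGE (sketch): gradient accumulation plus norm clipping in one loop.
// `model`, `optimizer`, `batches`, and `params` are assumed to come from the
// surrounding training code; they are not defined in this module.
//
//   let mut accum = GradAccumulator::<CpuBackend>::new(4);
//   for batch in &batches {
//       let loss = model.forward(batch)?;
//       let grads = loss.backward()?;
//       if let Some(avg_grads) = accum.step(&grads, &params)? {
//           let (clipped, norm) = clip_grad_norm(&avg_grads, &params, 1.0)?;
//           println!("grad norm: {norm:.4}");
//           optimizer.step(&clipped)?;
//       }
//   }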