```rust
pub fn clip_grad_norm<B>(
    grads: &GradStore<B>,
    params: &[Tensor<B>],
    max_norm: f64,
) -> Result<(GradStore<B>, f64), Error>
where
    B: Backend,
```
Clip gradients by their global L2 norm.
If the total L2 norm of all gradients exceeds max_norm, every gradient is
multiplied by the same factor, max_norm / total_norm, so the total norm becomes max_norm.
If the norm is already ≤ max_norm, gradients are returned unchanged.
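Written as a formula, under the usual reading of "global L2 norm" (every element of every gradient treated as one long vector), with $g_i$ the gradient of the $i$-th parameter and $j$ ranging over its elements:

```latex
\|g\| = \sqrt{\sum_{i} \sum_{j} g_{i,j}^{2}},
\qquad
g_i \leftarrow
\begin{cases}
g_i, & \|g\| \le \text{max\_norm}, \\[4pt]
g_i \cdot \dfrac{\text{max\_norm}}{\|g\|}, & \|g\| > \text{max\_norm}.
\end{cases}
```

For example, two scalar gradients 3 and 4 give $\|g\| = 5$; with max_norm = 1 both are multiplied by $1/5$, yielding 0.6 and 0.8, whose global norm is 1.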
Returns (clipped_grads, total_norm).
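The clipping rule can be sketched without the crate's own types. The function below, `clip_grad_norm_sketch`, is hypothetical: it works on plain `Vec<f64>` gradients rather than `GradStore<B>` and `Tensor<B>`, and it assumes the returned norm is the one measured before clipping (the description above does not state this explicitly). It only illustrates the technique, not this crate's implementation.

```rust
/// Hypothetical stand-alone sketch of global-norm gradient clipping over
/// plain `Vec<f64>` gradients; not this crate's `GradStore`/`Tensor` API.
/// Returns `(clipped_grads, total_norm)`, where `total_norm` is the norm
/// measured before clipping (an assumption of this sketch).
fn clip_grad_norm_sketch(grads: &[Vec<f64>], max_norm: f64) -> (Vec<Vec<f64>>, f64) {
    // Global L2 norm: square root of the sum of squares of every element
    // of every gradient, treated as one long vector.
    let total_norm = grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|x| x * x)
        .sum::<f64>()
        .sqrt();

    if total_norm <= max_norm {
        // Already within the limit: return the gradients unchanged.
        return (grads.to_vec(), total_norm);
    }

    // One shared factor for all gradients brings the global norm down
    // to max_norm while preserving the combined gradient's direction.
    let scale = max_norm / total_norm;
    let clipped: Vec<Vec<f64>> = grads
        .iter()
        .map(|g| g.iter().map(|x| x * scale).collect::<Vec<f64>>())
        .collect();
    (clipped, total_norm)
}

fn main() {
    let grads = vec![vec![3.0], vec![4.0]];
    let (clipped, norm) = clip_grad_norm_sketch(&grads, 1.0);
    println!("norm = {norm}, clipped = {clipped:?}");
}
```

Using one shared factor, rather than clipping each gradient separately, keeps the direction of the combined gradient unchanged and only shortens it.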
# Arguments

- `grads`: gradient store from `loss.backward()`
- `params`: the parameters whose gradients to clip
- `max_norm`: maximum allowed L2 norm
# Example
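```rust,ignore
// `loss`, `params`, and `optimizer` come from the surrounding training loop.
let grads = loss.backward()?;
let (clipped, norm) = clip_grad_norm(&grads, &params, 1.0)?;
println!("Gradient norm: {norm:.4}");
optimizer.step(&clipped)?;
```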