Function clip_grad_norm 

pub fn clip_grad_norm<B>(
    grads: &GradStore<B>,
    params: &[Tensor<B>],
    max_norm: f64,
) -> Result<(GradStore<B>, f64), Error>
where
    B: Backend,

Clip gradients by their global L2 norm.

If the total L2 norm of all gradients exceeds max_norm, every gradient is scaled by max_norm / total_norm so that the resulting global norm equals max_norm. If the norm is already ≤ max_norm, the gradients are returned unchanged.

Returns (clipped_grads, total_norm).
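
The scaling rule behind this is straightforward. The sketch below illustrates it on plain f64 buffers rather than the crate's GradStore and Tensor types; clip_by_global_norm and everything else in it are illustrative only and not part of this API.

// Illustrative sketch of global-norm clipping over plain f64 buffers.
// The real function operates on a GradStore keyed by parameters.
fn clip_by_global_norm(grads: &mut [Vec<f64>], max_norm: f64) -> f64 {
    // Total L2 norm across all gradient buffers.
    let total_norm = grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|v| v * v)
        .sum::<f64>()
        .sqrt();
    if total_norm > max_norm {
        // Scaling every element by this factor makes the global norm equal max_norm.
        let scale = max_norm / total_norm;
        for g in grads.iter_mut() {
            for v in g.iter_mut() {
                *v *= scale;
            }
        }
    }
    total_norm
}

fn main() {
    let mut grads = vec![vec![3.0, 4.0], vec![0.0]]; // global norm = 5.0
    let norm = clip_by_global_norm(&mut grads, 1.0);
    assert_eq!(norm, 5.0);
    // Each element was scaled by 1.0 / 5.0, so 3.0 becomes 0.6.
    assert!((grads[0][0] - 0.6).abs() < 1e-12);
}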

§Arguments

  • grads: gradient store from loss.backward()
  • params: the parameters whose gradients to clip
  • max_norm: maximum allowed L2 norm

§Example

// Compute gradients, clip them to a global L2 norm of 1.0, then apply the update.
let grads = loss.backward()?;
let (clipped, norm) = clip_grad_norm(&grads, &params, 1.0)?;
println!("Gradient norm: {norm:.4}");
optimizer.step(&clipped)?;