Function cross_entropy_loss 

pub fn cross_entropy_loss<B>(
    logits: &Tensor<B>,
    target: &Tensor<B>,
) -> Result<Tensor<B>, Error>
where B: Backend,

Cross-entropy loss with log-softmax for numerical stability.

Computes: -mean( sum_over_classes( target * log_softmax(logits) ) )
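
As a concrete check of this formula, here is a minimal standalone sketch on plain f32 slices (it does not use this crate's Tensor type; cross_entropy_scalar is a hypothetical helper written only for illustration):

// Standalone sketch of the formula above on plain f32 slices,
// independent of the Tensor<B> API.
fn cross_entropy_scalar(logits: &[f32], one_hot: &[f32]) -> f32 {
    // log_softmax(x)_i = x_i - max(x) - ln(sum_j(exp(x_j - max(x))))
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = logits.iter().map(|&x| (x - max).exp()).sum::<f32>().ln();
    logits
        .iter()
        .zip(one_hot)
        .map(|(&x, &t)| -t * (x - max - log_sum_exp))
        .sum()
}

fn main() {
    // One example, two classes: logits [2.0, 0.5], true class = 0.
    let loss = cross_entropy_scalar(&[2.0, 0.5], &[1.0, 0.0]);
    // softmax([2.0, 0.5]) ≈ (0.8176, 0.1824), so loss ≈ -ln(0.8176) ≈ 0.2014.
    assert!((loss - 0.2014).abs() < 1e-3);
    println!("loss = {loss}");
}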

§Arguments

  • logits: raw scores of shape [batch, num_classes] (NOT already softmaxed)
  • target: one-hot encoded targets of shape [batch, num_classes]

§Numerical Stability

Uses the tensor-level log_softmax, which computes:

log_softmax(x)_i = x_i - max(x) - log(sum_j(exp(x_j - max(x))))

Subtracting max(x) keeps the exponentials in range. Because this is built entirely from differentiable tensor ops, gradients flow back through logits automatically.
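
To see why the max subtraction matters, here is a small standalone sketch (again on plain f32, not this crate's tensor ops) comparing a naive log(sum(exp(x))) with the shifted version for large logits:

fn main() {
    // Large logits overflow a naive exp(): exp(1000.0) is +inf in f32.
    let logits = [1000.0_f32, 998.0, 995.0];

    // Naive: ln(sum(exp(x_i))) overflows to +inf, so log_softmax would be -inf/NaN.
    let naive = logits.iter().map(|&x| x.exp()).sum::<f32>().ln();
    assert!(naive.is_infinite());

    // Shifted: subtract max(x) first, so every exp() argument is <= 0 and stays finite.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = max + logits.iter().map(|&x| (x - max).exp()).sum::<f32>().ln();
    let log_softmax_0 = logits[0] - log_sum_exp; // ≈ -0.133
    assert!(log_softmax_0.is_finite());
    println!("log_softmax(x)_0 = {log_softmax_0}");
}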