
Function tensor_accuracy 

pub fn tensor_accuracy<B>(
    logits: &Tensor<B>,
    targets: &Tensor<B>,
) -> Result<f64, Error>
where
    B: Backend,

Compute classification accuracy directly from a logits tensor and a targets tensor, where the targets are given either as one-hot vectors or as class indices.

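  • If targets has shape [batch, n_classes] (one-hot), the argmax over the class dimension is taken for both logits and targets, and the two are compared per sample.
  • If targets has shape [batch] or [batch, 1], its values are treated as class indices and compared against the argmax of logits.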
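The shape dispatch above can be sketched in plain Rust. This is a minimal, hypothetical illustration, not this crate's implementation: it does not use the real Tensor, Backend, or Error types because their APIs are not shown here. Instead, nested Vecs stand in for tensors, and a hypothetical Targets enum replaces the shape check. It also assumes that accuracy means the fraction of correct predictions in [0, 1], and that ties in argmax resolve to the first maximum. Returning an error for an empty batch or a batch-size mismatch is likewise an assumption.

```rust
/// Hypothetical stand-in for the two target layouts described above.
enum Targets {
    /// Shape [batch, n_classes]: one one-hot row per sample.
    OneHot(Vec<Vec<f32>>),
    /// Shape [batch] (a [batch, 1] tensor flattened by the caller): one class index per sample.
    Indices(Vec<usize>),
}

/// Index of the largest value in a row; the first maximum wins on ties (assumption).
fn argmax(row: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in row.iter().enumerate() {
        if v > row[best] {
            best = i;
        }
    }
    best
}

/// Fraction of samples whose predicted class (argmax of logits) equals the target class.
fn accuracy(logits: &[Vec<f32>], targets: &Targets) -> Result<f64, String> {
    let batch = logits.len();
    if batch == 0 {
        // Assumption: an empty batch is reported as an error rather than as NaN.
        return Err("empty batch".to_string());
    }
    // Reduce both target layouts to one class index per sample.
    let target_classes: Vec<usize> = match targets {
        Targets::OneHot(rows) => rows.iter().map(|r| argmax(r)).collect(),
        Targets::Indices(idx) => idx.clone(),
    };
    if target_classes.len() != batch {
        return Err(format!(
            "batch mismatch: {} logit rows vs {} targets",
            batch,
            target_classes.len()
        ));
    }
    // Count the samples where the predicted class matches the target class.
    let correct = logits
        .iter()
        .zip(&target_classes)
        .filter(|&(row, &t)| argmax(row) == t)
        .count();
    Ok(correct as f64 / batch as f64)
}

fn main() {
    let logits: Vec<Vec<f32>> = vec![
        vec![2.0, 0.1, 0.3], // predicts class 0
        vec![0.2, 1.5, 0.1], // predicts class 1
        vec![0.1, 0.2, 3.0], // predicts class 2
        vec![1.0, 0.9, 0.0], // predicts class 0
    ];
    let one_hot = Targets::OneHot(vec![
        vec![1.0, 0.0, 0.0],
        vec![0.0, 1.0, 0.0],
        vec![0.0, 0.0, 1.0],
        vec![0.0, 1.0, 0.0],
    ]);
    let indices = Targets::Indices(vec![0, 1, 2, 1]);
    println!("{:?}", accuracy(&logits, &one_hot)); // Ok(0.75)
    println!("{:?}", accuracy(&logits, &indices)); // Ok(0.75)
}
```

Both layouts encode the same labels, so they yield the same result: three of the four predictions match. Reducing one-hot targets to indices first keeps a single comparison path, which mirrors the description of taking the argmax of both tensors in the one-hot case.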