
Function argmax_classes 

pub fn argmax_classes<B>(logits: &Tensor<B>) -> Result<Vec<usize>, Error>
where B: Backend,

Compute argmax along the last axis, returning class indices.

Input: logits or probabilities of shape [batch, n_classes]. Output: a Vec of length batch holding the predicted class index for each row.
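The contract above (one argmax per row of a [batch, n_classes] input) can be illustrated without the crate's Tensor or Backend types. The sketch below is a minimal, hypothetical stand-in, not this function's implementation. It makes several assumptions: `argmax_rows` is an invented helper, the input is a row-major f32 buffer, errors are plain `String`s instead of the crate's `Error`, and ties go to the lowest index. This page does not say how argmax_classes breaks ties or handles NaN.

```rust
/// Hypothetical helper (not part of this API): row-wise argmax over a
/// row-major `[batch, n_classes]` buffer, returning one class index per row.
fn argmax_rows(data: &[f32], n_classes: usize) -> Result<Vec<usize>, String> {
    if n_classes == 0 {
        return Err("n_classes must be > 0".to_string());
    }
    if data.len() % n_classes != 0 {
        return Err(format!(
            "buffer length {} is not a multiple of n_classes {}",
            data.len(),
            n_classes
        ));
    }
    let preds: Vec<usize> = data
        .chunks_exact(n_classes) // one chunk per batch row
        .map(|row| {
            // Strict `>` keeps the first index on ties. NaN never wins a
            // comparison, so a NaN in column 0 is returned as-is (assumed
            // behaviour; the real function may differ).
            let mut best = 0;
            for (i, &v) in row.iter().enumerate().skip(1) {
                if v > row[best] {
                    best = i;
                }
            }
            best
        })
        .collect();
    Ok(preds)
}

fn main() {
    // batch = 3, n_classes = 4
    let logits: [f32; 12] = [
        0.1, 2.5, -1.0, 0.3, // row 0: largest is 2.5 at index 1
        3.0, 3.0, 0.0, 1.0, // row 1: tie at 3.0, first index (0) wins
        -2.0, -1.0, -3.0, -0.5, // row 2: largest is -0.5 at index 3
    ];
    let preds = argmax_rows(&logits, 4).unwrap();
    println!("{:?}", preds); // prints [1, 0, 3]
}
```

Because softmax is strictly increasing within each row, taking the argmax of raw logits or of their softmax probabilities selects the same class. This is why the input may be either.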