shrew_nn/metrics.rs

// Evaluation Metrics
//
// Classification metrics:  accuracy, precision, recall, f1_score, confusion_matrix
// Regression metrics:      r2_score, mae, rmse, mape
// Language model metrics:  perplexity
// Ranking metrics:         top_k_accuracy
//
// All functions operate on f64 slices or Tensor<B> for maximum flexibility.
// For classification, we follow sklearn conventions:
//   - predictions = class indices (argmax of logits)
//   - targets = true class indices
//
// Multi-class averaging: macro (default), micro, weighted, per-class.

use shrew_core::backend::Backend;
use shrew_core::tensor::Tensor;
use shrew_core::Result;

// Averaging strategy

/// How to average per-class metrics for multi-class problems.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Average {
    /// Compute metric per class, then take the unweighted mean.
    Macro,
    /// Pool counts across classes first: total TP / (total TP + total FP or FN).
    Micro,
    /// Per-class metric weighted by class support (number of true instances).
    Weighted,
}
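
// Example with illustrative numbers: given per-class precisions [1.0, 0.0]
// and supports [9, 1], Macro yields (1.0 + 0.0) / 2 = 0.5 and Weighted yields
// 1.0 * 0.9 + 0.0 * 0.1 = 0.9, while Micro pools the raw TP/FP counts across
// classes before dividing.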

// Confusion matrix

/// NxN confusion matrix. Entry [i][j] = count of samples with true class i
/// predicted as class j.
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    pub matrix: Vec<Vec<u64>>,
    pub n_classes: usize,
}

impl ConfusionMatrix {
    /// Build a confusion matrix from predicted and true class indices.
    /// Indices outside 0..n_classes are silently ignored.
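    ///
    /// A minimal usage sketch with made-up labels:
    /// ```ignore
    /// let cm = ConfusionMatrix::from_predictions(&[0, 1, 1], &[0, 1, 0], 2);
    /// assert_eq!(cm.tp(1), 1); // one class-1 sample predicted correctly
    /// assert_eq!(cm.fp(1), 1); // one class-0 sample predicted as class 1
    /// ```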
    pub fn from_predictions(predictions: &[usize], targets: &[usize], n_classes: usize) -> Self {
        let mut matrix = vec![vec![0u64; n_classes]; n_classes];
        for (&pred, &target) in predictions.iter().zip(targets.iter()) {
            if target < n_classes && pred < n_classes {
                matrix[target][pred] += 1;
            }
        }
        ConfusionMatrix { matrix, n_classes }
    }

    /// True positives for class c.
    pub fn tp(&self, c: usize) -> u64 {
        self.matrix[c][c]
    }

    /// False positives for class c (predicted c but was not c).
    pub fn fp(&self, c: usize) -> u64 {
        (0..self.n_classes)
            .map(|r| if r != c { self.matrix[r][c] } else { 0 })
            .sum()
    }

    /// False negatives for class c (was c but predicted other).
    pub fn fn_(&self, c: usize) -> u64 {
        (0..self.n_classes)
            .map(|col| if col != c { self.matrix[c][col] } else { 0 })
            .sum()
    }

    /// True negatives for class c.
    pub fn tn(&self, c: usize) -> u64 {
        let total: u64 = self.matrix.iter().flat_map(|r| r.iter()).sum();
        total - self.tp(c) - self.fp(c) - self.fn_(c)
    }

    /// Support (number of true instances) for class c.
    pub fn support(&self, c: usize) -> u64 {
        self.matrix[c].iter().sum()
    }

    /// Total number of samples.
    pub fn total(&self) -> u64 {
        self.matrix.iter().flat_map(|r| r.iter()).sum()
    }

    /// Pretty-print the confusion matrix.
    pub fn to_string_table(&self) -> String {
        let mut s = String::new();
        s.push_str(&format!("{:>8}", ""));
        for c in 0..self.n_classes {
            s.push_str(&format!("{:>8}", format!("Pred {c}")));
        }
        s.push('\n');
        for r in 0..self.n_classes {
            s.push_str(&format!("{:>8}", format!("True {r}")));
            for c in 0..self.n_classes {
                s.push_str(&format!("{:>8}", self.matrix[r][c]));
            }
            s.push('\n');
        }
        s
    }
}

// Classification metrics (from class indices)

/// Classification accuracy: fraction of correct predictions.
pub fn accuracy(predictions: &[usize], targets: &[usize]) -> f64 {
    if predictions.is_empty() {
        return 0.0;
    }
    let correct = predictions
        .iter()
        .zip(targets.iter())
        .filter(|(p, t)| p == t)
        .count();
    correct as f64 / predictions.len() as f64
}

/// Precision for multi-class classification.
///
/// Precision = TP / (TP + FP) — how many selected items are relevant.
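///
/// A minimal sketch with toy labels:
/// ```ignore
/// // class 0: TP=1, FP=0 → 1.0; class 1: TP=1, FP=1 → 0.5; macro mean = 0.75
/// let p = precision(&[1, 1, 0], &[1, 0, 0], 2, Average::Macro);
/// assert!((p - 0.75).abs() < 1e-10);
/// ```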
pub fn precision(predictions: &[usize], targets: &[usize], n_classes: usize, avg: Average) -> f64 {
    let cm = ConfusionMatrix::from_predictions(predictions, targets, n_classes);
    match avg {
        Average::Micro => {
            let total_tp: u64 = (0..n_classes).map(|c| cm.tp(c)).sum();
            let total_tp_fp: u64 = (0..n_classes).map(|c| cm.tp(c) + cm.fp(c)).sum();
            if total_tp_fp == 0 {
                0.0
            } else {
                total_tp as f64 / total_tp_fp as f64
            }
        }
        Average::Macro => {
            let precs: Vec<f64> = (0..n_classes)
                .map(|c| {
                    let denom = cm.tp(c) + cm.fp(c);
                    if denom == 0 {
                        0.0
                    } else {
                        cm.tp(c) as f64 / denom as f64
                    }
                })
                .collect();
            precs.iter().sum::<f64>() / n_classes as f64
        }
        Average::Weighted => {
            let total = cm.total() as f64;
            if total == 0.0 {
                return 0.0;
            }
            (0..n_classes)
                .map(|c| {
                    let denom = cm.tp(c) + cm.fp(c);
                    let p = if denom == 0 {
                        0.0
                    } else {
                        cm.tp(c) as f64 / denom as f64
                    };
                    p * cm.support(c) as f64 / total
                })
                .sum()
        }
    }
}

/// Recall for multi-class classification.
///
/// Recall = TP / (TP + FN) — how many relevant items are selected.
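///
/// Same toy labels as the precision sketch:
/// ```ignore
/// // class 0: TP=1, FN=1 → 0.5; class 1: TP=1, FN=0 → 1.0; macro mean = 0.75
/// let r = recall(&[1, 1, 0], &[1, 0, 0], 2, Average::Macro);
/// assert!((r - 0.75).abs() < 1e-10);
/// ```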
pub fn recall(predictions: &[usize], targets: &[usize], n_classes: usize, avg: Average) -> f64 {
    let cm = ConfusionMatrix::from_predictions(predictions, targets, n_classes);
    match avg {
        Average::Micro => {
            let total_tp: u64 = (0..n_classes).map(|c| cm.tp(c)).sum();
            let total_tp_fn: u64 = (0..n_classes).map(|c| cm.tp(c) + cm.fn_(c)).sum();
            if total_tp_fn == 0 {
                0.0
            } else {
                total_tp as f64 / total_tp_fn as f64
            }
        }
        Average::Macro => {
            let recs: Vec<f64> = (0..n_classes)
                .map(|c| {
                    let denom = cm.tp(c) + cm.fn_(c);
                    if denom == 0 {
                        0.0
                    } else {
                        cm.tp(c) as f64 / denom as f64
                    }
                })
                .collect();
            recs.iter().sum::<f64>() / n_classes as f64
        }
        Average::Weighted => {
            let total = cm.total() as f64;
            if total == 0.0 {
                return 0.0;
            }
            (0..n_classes)
                .map(|c| {
                    let denom = cm.tp(c) + cm.fn_(c);
                    let r = if denom == 0 {
                        0.0
                    } else {
                        cm.tp(c) as f64 / denom as f64
                    };
                    r * cm.support(c) as f64 / total
                })
                .sum()
        }
    }
}

/// F1 Score — harmonic mean of precision and recall.
///
/// F1 = 2 * (precision * recall) / (precision + recall)
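///
/// Worked numbers: precision 0.5 and recall 1.0 give 2 * 0.5 / 1.5 ≈ 0.667.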
pub fn f1_score(predictions: &[usize], targets: &[usize], n_classes: usize, avg: Average) -> f64 {
    let p = precision(predictions, targets, n_classes, avg);
    let r = recall(predictions, targets, n_classes, avg);
    if p + r == 0.0 {
        0.0
    } else {
        2.0 * p * r / (p + r)
    }
}

/// Per-class precision, recall, F1, and support — like sklearn's classification_report.
pub fn classification_report(
    predictions: &[usize],
    targets: &[usize],
    n_classes: usize,
) -> Vec<ClassMetrics> {
    let cm = ConfusionMatrix::from_predictions(predictions, targets, n_classes);
    (0..n_classes)
        .map(|c| {
            let tp = cm.tp(c) as f64;
            let fp = cm.fp(c) as f64;
            let fn_ = cm.fn_(c) as f64;
            let prec = if tp + fp == 0.0 { 0.0 } else { tp / (tp + fp) };
            let rec = if tp + fn_ == 0.0 {
                0.0
            } else {
                tp / (tp + fn_)
            };
            let f1 = if prec + rec == 0.0 {
                0.0
            } else {
                2.0 * prec * rec / (prec + rec)
            };
            ClassMetrics {
                class: c,
                precision: prec,
                recall: rec,
                f1,
                support: cm.support(c),
            }
        })
        .collect()
}

/// Per-class metric report entry.
#[derive(Debug, Clone)]
pub struct ClassMetrics {
    pub class: usize,
    pub precision: f64,
    pub recall: f64,
    pub f1: f64,
    pub support: u64,
}

/// Top-K accuracy: fraction of samples where the true class is in the top-K predictions.
///
/// `scores` is a flat array of shape [n_samples, n_classes] with raw logits/probs.
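///
/// A minimal sketch, two samples over three classes:
/// ```ignore
/// // sample 0 ranks class 1 first; sample 1 ranks class 0 first, then class 2
/// let scores = [0.1, 0.7, 0.2, 0.8, 0.05, 0.15];
/// assert!((top_k_accuracy(&scores, &[1, 2], 3, 2) - 1.0).abs() < 1e-10);
/// ```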
pub fn top_k_accuracy(scores: &[f64], targets: &[usize], n_classes: usize, k: usize) -> f64 {
    let n_samples = targets.len();
    if n_samples == 0 {
        return 0.0;
    }
    let mut correct = 0usize;
    for i in 0..n_samples {
        let row = &scores[i * n_classes..(i + 1) * n_classes];
        // Get top-k indices
        let mut indexed: Vec<(usize, f64)> = row.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let top_k: Vec<usize> = indexed.iter().take(k).map(|(idx, _)| *idx).collect();
        if top_k.contains(&targets[i]) {
            correct += 1;
        }
    }
    correct as f64 / n_samples as f64
}

// Regression metrics

/// R² (coefficient of determination).
///
/// R² = 1 - SS_res / SS_tot, where:
///   SS_res = sum((y_true - y_pred)²)
///   SS_tot = sum((y_true - mean(y_true))²)
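///
/// Worked numbers: targets [1, 2, 3] have mean 2 and SS_tot = 2; predictions
/// [1, 2, 4] give SS_res = 1, so R² = 1 - 1/2 = 0.5:
/// ```ignore
/// assert!((r2_score(&[1.0, 2.0, 4.0], &[1.0, 2.0, 3.0]) - 0.5).abs() < 1e-10);
/// ```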
pub fn r2_score(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = targets.len() as f64;
    if n == 0.0 {
        return 0.0;
    }
    let mean_y: f64 = targets.iter().sum::<f64>() / n;
    let ss_tot: f64 = targets.iter().map(|y| (y - mean_y).powi(2)).sum();
    let ss_res: f64 = targets
        .iter()
        .zip(predictions.iter())
        .map(|(y, p)| (y - p).powi(2))
        .sum();
    if ss_tot == 0.0 {
        return 0.0;
    }
    1.0 - ss_res / ss_tot
}

/// Mean Absolute Error: mean(|y_true - y_pred|).
pub fn mae(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = targets.len() as f64;
    if n == 0.0 {
        return 0.0;
    }
    targets
        .iter()
        .zip(predictions.iter())
        .map(|(y, p)| (y - p).abs())
        .sum::<f64>()
        / n
}

/// Root Mean Squared Error: sqrt(mean((y_true - y_pred)²)).
pub fn rmse(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = targets.len() as f64;
    if n == 0.0 {
        return 0.0;
    }
    let mse: f64 = targets
        .iter()
        .zip(predictions.iter())
        .map(|(y, p)| (y - p).powi(2))
        .sum::<f64>()
        / n;
    mse.sqrt()
}

/// Mean Absolute Percentage Error: mean(|y_true - y_pred| / |y_true|) * 100,
/// averaged over samples with nonzero y_true (zero targets are skipped).
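///
/// A minimal sketch; the zero target contributes nothing:
/// ```ignore
/// // only targets 1.0 and 2.0 count: (0.1/1.0 + 0.2/2.0) / 2 * 100 = 10.0
/// let m = mape(&[1.1, 1.8, 0.3], &[1.0, 2.0, 0.0]);
/// assert!((m - 10.0).abs() < 1e-9);
/// ```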
pub fn mape(predictions: &[f64], targets: &[f64]) -> f64 {
    let terms: Vec<f64> = targets
        .iter()
        .zip(predictions.iter())
        .filter(|(y, _)| **y != 0.0)
        .map(|(y, p)| ((y - p) / y).abs())
        .collect();
    if terms.is_empty() {
        return 0.0;
    }
    // Average over the kept (nonzero-target) samples, not the full length,
    // so skipped zeros do not deflate the result.
    terms.iter().sum::<f64>() / terms.len() as f64 * 100.0
}

// Language model metrics

/// Perplexity from cross-entropy loss: exp(loss).
///
/// Lower perplexity = better language model. For example, a mean
/// cross-entropy of ln(2) ≈ 0.693 corresponds to a perplexity of 2.
pub fn perplexity(cross_entropy_loss: f64) -> f64 {
    cross_entropy_loss.exp()
}

/// Perplexity from a flat array of per-token log-probabilities.
///
/// PPL = exp(-1/N * sum(log_probs))
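///
/// Sanity check: a uniform model over V tokens assigns log-prob -ln(V) to
/// every token, so its perplexity is exactly V:
/// ```ignore
/// let log_probs = [-(4.0f64).ln(); 10];
/// assert!((perplexity_from_log_probs(&log_probs) - 4.0).abs() < 1e-10);
/// ```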
pub fn perplexity_from_log_probs(log_probs: &[f64]) -> f64 {
    let n = log_probs.len() as f64;
    if n == 0.0 {
        return f64::INFINITY;
    }
    let avg_neg_log_prob = -log_probs.iter().sum::<f64>() / n;
    avg_neg_log_prob.exp()
}

// Tensor-level helpers

/// Compute argmax along the last axis, returning class indices.
///
/// Input: [batch, n_classes] logits/probabilities.
/// Output: Vec of length `batch` with predicted class indices.
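///
/// A sketch mirroring the test at the bottom of this file, assuming the CPU
/// backend from `shrew_cpu`:
/// ```ignore
/// let t = Tensor::<CpuBackend>::from_f64_slice(
///     &[0.1, 0.9, 0.3, 0.8, 0.1, 0.1],
///     vec![2, 3],
///     shrew_core::DType::F64,
///     &CpuDevice,
/// )?;
/// assert_eq!(argmax_classes(&t)?, vec![1, 0]);
/// ```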
pub fn argmax_classes<B: Backend>(logits: &Tensor<B>) -> Result<Vec<usize>> {
    let data = logits.to_f64_vec()?;
    let dims = logits.dims();
    let n_classes = *dims.last().unwrap_or(&1);
    let batch = data.len() / n_classes;

    let mut classes = Vec::with_capacity(batch);
    for i in 0..batch {
        let row = &data[i * n_classes..(i + 1) * n_classes];
        let (max_idx, _) = row
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or((0, &0.0));
        classes.push(max_idx);
    }
    Ok(classes)
}

/// Compute accuracy directly from logit tensors and one-hot/class-index targets.
///
/// - If `targets` has shape [batch, n_classes] (one-hot), takes argmax of both.
/// - If `targets` has shape [batch] or [batch, 1], treats as class indices.
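///
/// A sketch with class-index targets, again assuming the `shrew_cpu` backend:
/// ```ignore
/// // logits argmax to classes [1, 0]; targets are [1, 1] → accuracy 0.5
/// let logits = Tensor::<CpuBackend>::from_f64_slice(
///     &[0.1, 0.9, 0.8, 0.2], vec![2, 2], shrew_core::DType::F64, &CpuDevice)?;
/// let targets = Tensor::<CpuBackend>::from_f64_slice(
///     &[1.0, 1.0], vec![2], shrew_core::DType::F64, &CpuDevice)?;
/// assert!((tensor_accuracy(&logits, &targets)? - 0.5).abs() < 1e-10);
/// ```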
pub fn tensor_accuracy<B: Backend>(logits: &Tensor<B>, targets: &Tensor<B>) -> Result<f64> {
    let pred_classes = argmax_classes(logits)?;
    let target_data = targets.to_f64_vec()?;
    let target_dims = targets.dims();
    let logit_dims = logits.dims();

    let target_classes: Vec<usize> =
        if target_dims.len() >= 2 && target_dims.last() == logit_dims.last() {
            // One-hot encoded targets — take argmax
            let n_classes = *target_dims.last().unwrap_or(&1);
            let batch = target_data.len() / n_classes;
            (0..batch)
                .map(|i| {
                    let row = &target_data[i * n_classes..(i + 1) * n_classes];
                    row.iter()
                        .enumerate()
                        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                        .map(|(idx, _)| idx)
                        .unwrap_or(0)
                })
                .collect()
        } else {
            // Class indices
            target_data.iter().map(|v| *v as usize).collect()
        };

    Ok(accuracy(&pred_classes, &target_classes))
}

// Tests

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accuracy_perfect() {
        assert_eq!(accuracy(&[0, 1, 2, 0], &[0, 1, 2, 0]), 1.0);
    }

    #[test]
    fn test_accuracy_50_percent() {
        assert_eq!(accuracy(&[0, 1, 0, 1], &[0, 0, 1, 1]), 0.5);
    }

    #[test]
    fn test_confusion_matrix_binary() {
        // TP=2, FP=1, FN=1, TN=2
        let preds = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 0, 1];
        let cm = ConfusionMatrix::from_predictions(&preds, &targets, 2);
        assert_eq!(cm.tp(1), 2); // true positives for class 1
        assert_eq!(cm.fp(1), 1); // false positives for class 1
        assert_eq!(cm.fn_(1), 1); // false negatives for class 1
        assert_eq!(cm.tn(1), 2); // true negatives for class 1
    }

    #[test]
    fn test_precision_recall_f1_binary() {
        let preds = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 0, 1];
        let p = precision(&preds, &targets, 2, Average::Macro);
        let r = recall(&preds, &targets, 2, Average::Macro);
        let f = f1_score(&preds, &targets, 2, Average::Macro);
        assert!((p - 0.6667).abs() < 0.01);
        assert!((r - 0.6667).abs() < 0.01);
        assert!((f - 0.6667).abs() < 0.01);
    }

    #[test]
    fn test_precision_micro() {
        let preds = [0, 1, 2, 0, 1, 2];
        let targets = [0, 1, 2, 1, 0, 2];
        // micro precision == accuracy for multi-class
        let p = precision(&preds, &targets, 3, Average::Micro);
        let a = accuracy(&preds, &targets);
        assert!((p - a).abs() < 1e-10);
    }

    #[test]
    fn test_classification_report() {
        let preds = [0, 0, 1, 1, 2, 2];
        let targets = [0, 1, 1, 2, 2, 0];
        let report = classification_report(&preds, &targets, 3);
        assert_eq!(report.len(), 3);
        assert_eq!(report[0].support, 2); // class 0 has 2 true samples
    }

    #[test]
    fn test_r2_perfect() {
        let preds = [1.0, 2.0, 3.0, 4.0];
        let targets = [1.0, 2.0, 3.0, 4.0];
        assert!((r2_score(&preds, &targets) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_mae_rmse() {
        let preds = [1.0, 2.0, 3.0];
        let targets = [1.0, 3.0, 5.0];
        assert!((mae(&preds, &targets) - 1.0).abs() < 1e-10);
        // squared errors are [0, 1, 4], so RMSE = sqrt(5/3)
        assert!((rmse(&preds, &targets) - (5.0f64 / 3.0).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_perplexity() {
        assert!((perplexity(0.0) - 1.0).abs() < 1e-10);
        assert!((perplexity(1.0) - std::f64::consts::E).abs() < 1e-10);
    }

    #[test]
    fn test_top_k_accuracy() {
        // 2 samples, 3 classes
        let scores = [
            0.1, 0.7, 0.2, // sample 0: pred class 1
            0.8, 0.05, 0.15, // sample 1: pred class 0, 2nd is class 2
        ];
        let targets = [1, 2];
        // top-1: only sample 0 correct
        assert!((top_k_accuracy(&scores, &targets, 3, 1) - 0.5).abs() < 1e-10);
        // top-2: both correct (class 2 is 2nd highest for sample 1)
        assert!((top_k_accuracy(&scores, &targets, 3, 2) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_argmax_classes() {
        use shrew_cpu::{CpuBackend, CpuDevice};
        let t = Tensor::<CpuBackend>::from_f64_slice(
            &[
                0.1, 0.9, 0.3, // class 1
                0.8, 0.1, 0.1, // class 0
            ],
            vec![2, 3],
            shrew_core::DType::F64,
            &CpuDevice,
        )
        .unwrap();
        let classes = argmax_classes(&t).unwrap();
        assert_eq!(classes, vec![1, 0]);
    }
}