use shrew_core::backend::Backend;
use shrew_core::tensor::Tensor;
use shrew_core::Result;

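/// How per-class scores are combined into a single multi-class score.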
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Average {
    /// Unweighted mean of the per-class scores.
    Macro,
    /// Pool true/false positives and negatives across classes, then score once.
    Micro,
    /// Mean of the per-class scores, weighted by each class's support.
    Weighted,
}

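/// A confusion matrix accumulated from integer class predictions.
///
/// Rows index the true class, columns the predicted class.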
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    /// `matrix[true_class][predicted_class]` holds the pair count.
    pub matrix: Vec<Vec<u64>>,
    pub n_classes: usize,
}

impl ConfusionMatrix {
    /// Builds the matrix from parallel slices of predicted and true class
    /// indices. Pairs with an out-of-range class index are skipped.
    pub fn from_predictions(predictions: &[usize], targets: &[usize], n_classes: usize) -> Self {
        let mut matrix = vec![vec![0u64; n_classes]; n_classes];
        for (&pred, &target) in predictions.iter().zip(targets.iter()) {
            if target < n_classes && pred < n_classes {
                matrix[target][pred] += 1;
            }
        }
        ConfusionMatrix { matrix, n_classes }
    }

    /// True positives for class `c`: predicted `c` and truly `c`.
    pub fn tp(&self, c: usize) -> u64 {
        self.matrix[c][c]
    }

    /// False positives for class `c`: predicted `c` but truly another class.
    pub fn fp(&self, c: usize) -> u64 {
        (0..self.n_classes)
            .map(|r| if r != c { self.matrix[r][c] } else { 0 })
            .sum()
    }

    /// False negatives for class `c`: truly `c` but predicted another class.
    pub fn fn_(&self, c: usize) -> u64 {
        (0..self.n_classes)
            .map(|col| if col != c { self.matrix[c][col] } else { 0 })
            .sum()
    }

    /// True negatives for class `c`: everything not counted by `tp`, `fp`, or `fn_`.
    pub fn tn(&self, c: usize) -> u64 {
        let total: u64 = self.matrix.iter().flat_map(|r| r.iter()).sum();
        total - self.tp(c) - self.fp(c) - self.fn_(c)
    }

    /// Number of samples whose true class is `c`.
    pub fn support(&self, c: usize) -> u64 {
        self.matrix[c].iter().sum()
    }

    /// Total number of counted samples.
    pub fn total(&self) -> u64 {
        self.matrix.iter().flat_map(|r| r.iter()).sum()
    }

    /// Renders the matrix as a fixed-width text table.
    pub fn to_string_table(&self) -> String {
        let mut s = String::new();
        s.push_str(&format!("{:>8}", ""));
        for c in 0..self.n_classes {
            s.push_str(&format!("{:>8}", format!("Pred {c}")));
        }
        s.push('\n');
        for r in 0..self.n_classes {
            s.push_str(&format!("{:>8}", format!("True {r}")));
            for c in 0..self.n_classes {
                s.push_str(&format!("{:>8}", self.matrix[r][c]));
            }
            s.push('\n');
        }
        s
    }
}

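/// Fraction of positions where `predictions[i] == targets[i]`.
///
/// For example, `accuracy(&[0, 1], &[0, 0])` is `0.5`. Returns `0.0` for
/// empty input.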
pub fn accuracy(predictions: &[usize], targets: &[usize]) -> f64 {
    if predictions.is_empty() {
        return 0.0;
    }
    let correct = predictions
        .iter()
        .zip(targets.iter())
        .filter(|(p, t)| p == t)
        .count();
    correct as f64 / predictions.len() as f64
}

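/// Precision (`tp / (tp + fp)`) over `n_classes` classes, combined according
/// to `avg`. Classes that are never predicted contribute a precision of `0.0`.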
pub fn precision(predictions: &[usize], targets: &[usize], n_classes: usize, avg: Average) -> f64 {
    let cm = ConfusionMatrix::from_predictions(predictions, targets, n_classes);
    match avg {
        Average::Micro => {
            let total_tp: u64 = (0..n_classes).map(|c| cm.tp(c)).sum();
            let total_tp_fp: u64 = (0..n_classes).map(|c| cm.tp(c) + cm.fp(c)).sum();
            if total_tp_fp == 0 {
                0.0
            } else {
                total_tp as f64 / total_tp_fp as f64
            }
        }
        Average::Macro => {
            let precs: Vec<f64> = (0..n_classes)
                .map(|c| {
                    let denom = cm.tp(c) + cm.fp(c);
                    if denom == 0 {
                        0.0
                    } else {
                        cm.tp(c) as f64 / denom as f64
                    }
                })
                .collect();
            precs.iter().sum::<f64>() / n_classes as f64
        }
        Average::Weighted => {
            let total = cm.total() as f64;
            if total == 0.0 {
                return 0.0;
            }
            (0..n_classes)
                .map(|c| {
                    let denom = cm.tp(c) + cm.fp(c);
                    let p = if denom == 0 {
                        0.0
                    } else {
                        cm.tp(c) as f64 / denom as f64
                    };
                    p * cm.support(c) as f64 / total
                })
                .sum()
        }
    }
}

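/// Recall (`tp / (tp + fn)`) over `n_classes` classes, combined according to
/// `avg`. Classes with no true samples contribute a recall of `0.0`.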
pub fn recall(predictions: &[usize], targets: &[usize], n_classes: usize, avg: Average) -> f64 {
    let cm = ConfusionMatrix::from_predictions(predictions, targets, n_classes);
    match avg {
        Average::Micro => {
            let total_tp: u64 = (0..n_classes).map(|c| cm.tp(c)).sum();
            let total_tp_fn: u64 = (0..n_classes).map(|c| cm.tp(c) + cm.fn_(c)).sum();
            if total_tp_fn == 0 {
                0.0
            } else {
                total_tp as f64 / total_tp_fn as f64
            }
        }
        Average::Macro => {
            let recs: Vec<f64> = (0..n_classes)
                .map(|c| {
                    let denom = cm.tp(c) + cm.fn_(c);
                    if denom == 0 {
                        0.0
                    } else {
                        cm.tp(c) as f64 / denom as f64
                    }
                })
                .collect();
            recs.iter().sum::<f64>() / n_classes as f64
        }
        Average::Weighted => {
            let total = cm.total() as f64;
            if total == 0.0 {
                return 0.0;
            }
            (0..n_classes)
                .map(|c| {
                    let denom = cm.tp(c) + cm.fn_(c);
                    let r = if denom == 0 {
                        0.0
                    } else {
                        cm.tp(c) as f64 / denom as f64
                    };
                    r * cm.support(c) as f64 / total
                })
                .sum()
        }
    }
}

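/// F1 as the harmonic mean of the averaged precision and recall.
///
/// Note that for `Average::Macro` this is the harmonic mean of macro precision
/// and macro recall, which is not the same quantity as the mean of per-class
/// F1 scores; the per-class values are available via [`classification_report`].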
pub fn f1_score(predictions: &[usize], targets: &[usize], n_classes: usize, avg: Average) -> f64 {
    let p = precision(predictions, targets, n_classes, avg);
    let r = recall(predictions, targets, n_classes, avg);
    if p + r == 0.0 {
        0.0
    } else {
        2.0 * p * r / (p + r)
    }
}

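/// Per-class precision, recall, F1, and support, one entry per class.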
pub fn classification_report(
    predictions: &[usize],
    targets: &[usize],
    n_classes: usize,
) -> Vec<ClassMetrics> {
    let cm = ConfusionMatrix::from_predictions(predictions, targets, n_classes);
    (0..n_classes)
        .map(|c| {
            let tp = cm.tp(c) as f64;
            let fp = cm.fp(c) as f64;
            let fn_ = cm.fn_(c) as f64;
            let prec = if tp + fp == 0.0 { 0.0 } else { tp / (tp + fp) };
            let rec = if tp + fn_ == 0.0 {
                0.0
            } else {
                tp / (tp + fn_)
            };
            let f1 = if prec + rec == 0.0 {
                0.0
            } else {
                2.0 * prec * rec / (prec + rec)
            };
            ClassMetrics {
                class: c,
                precision: prec,
                recall: rec,
                f1,
                support: cm.support(c),
            }
        })
        .collect()
}

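/// Metrics for a single class, as produced by [`classification_report`].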
#[derive(Debug, Clone)]
pub struct ClassMetrics {
    pub class: usize,
    pub precision: f64,
    pub recall: f64,
    pub f1: f64,
    pub support: u64,
}

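/// Top-k accuracy over a row-major score matrix.
///
/// `scores` is flattened `[n_samples * n_classes]`; sample `i` occupies
/// `scores[i * n_classes..(i + 1) * n_classes]`. A sample counts as correct
/// when its true class is among the `k` highest-scoring classes.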
pub fn top_k_accuracy(scores: &[f64], targets: &[usize], n_classes: usize, k: usize) -> f64 {
    let n_samples = targets.len();
    if n_samples == 0 {
        return 0.0;
    }
    let mut correct = 0usize;
    for i in 0..n_samples {
        let row = &scores[i * n_classes..(i + 1) * n_classes];
        let mut indexed: Vec<(usize, f64)> = row.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let top_k: Vec<usize> = indexed.iter().take(k).map(|(idx, _)| *idx).collect();
        if top_k.contains(&targets[i]) {
            correct += 1;
        }
    }
    correct as f64 / n_samples as f64
}

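/// Coefficient of determination: `1 - ss_res / ss_tot`, where `ss_res` is the
/// sum of squared residuals and `ss_tot` the total sum of squares around the
/// target mean. Returns `0.0` when the targets are empty or constant.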
pub fn r2_score(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = targets.len() as f64;
    if n == 0.0 {
        return 0.0;
    }
    let mean_y: f64 = targets.iter().sum::<f64>() / n;
    let ss_tot: f64 = targets.iter().map(|y| (y - mean_y).powi(2)).sum();
    let ss_res: f64 = targets
        .iter()
        .zip(predictions.iter())
        .map(|(y, p)| (y - p).powi(2))
        .sum();
    if ss_tot == 0.0 {
        return 0.0;
    }
    1.0 - ss_res / ss_tot
}

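/// Mean absolute error: `mean(|target - prediction|)`.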
pub fn mae(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = targets.len() as f64;
    if n == 0.0 {
        return 0.0;
    }
    targets
        .iter()
        .zip(predictions.iter())
        .map(|(y, p)| (y - p).abs())
        .sum::<f64>()
        / n
}

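/// Root mean squared error: `sqrt(mean((target - prediction)^2))`.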
pub fn rmse(predictions: &[f64], targets: &[f64]) -> f64 {
    let n = targets.len() as f64;
    if n == 0.0 {
        return 0.0;
    }
    let mse: f64 = targets
        .iter()
        .zip(predictions.iter())
        .map(|(y, p)| (y - p).powi(2))
        .sum::<f64>()
        / n;
    mse.sqrt()
}

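/// Mean absolute percentage error, in percent. Samples whose target is
/// exactly `0.0` are excluded from both the numerator and the denominator.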
pub fn mape(predictions: &[f64], targets: &[f64]) -> f64 {
    let mut sum = 0.0;
    let mut count = 0usize;
    for (y, p) in targets.iter().zip(predictions.iter()) {
        // Skip zero targets entirely so they don't dilute the average.
        if *y != 0.0 {
            sum += ((y - p) / y).abs();
            count += 1;
        }
    }
    if count == 0 {
        return 0.0;
    }
    sum / count as f64 * 100.0
}

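/// Perplexity from a mean cross-entropy loss in nats: `exp(loss)`.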
pub fn perplexity(cross_entropy_loss: f64) -> f64 {
    cross_entropy_loss.exp()
}

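/// Perplexity from per-token log-probabilities (natural log):
/// `exp(-mean(log_probs))`. Returns infinity for empty input.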
pub fn perplexity_from_log_probs(log_probs: &[f64]) -> f64 {
    let n = log_probs.len() as f64;
    if n == 0.0 {
        return f64::INFINITY;
    }
    let avg_neg_log_prob = -log_probs.iter().sum::<f64>() / n;
    avg_neg_log_prob.exp()
}

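/// Argmax over the last dimension of `logits`, treating all leading
/// dimensions as a flattened batch.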
pub fn argmax_classes<B: Backend>(logits: &Tensor<B>) -> Result<Vec<usize>> {
    let data = logits.to_f64_vec()?;
    let dims = logits.dims();
    let n_classes = *dims.last().unwrap_or(&1);
    let batch = data.len() / n_classes;

    let mut classes = Vec::with_capacity(batch);
    for i in 0..batch {
        let row = &data[i * n_classes..(i + 1) * n_classes];
        let (max_idx, _) = row
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or((0, &0.0));
        classes.push(max_idx);
    }
    Ok(classes)
}

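/// Accuracy of `logits` against `targets`. Targets may be either class
/// indices or one-hot/probability rows with the same trailing dimension as
/// `logits`; one-hot rows are reduced by argmax first.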
pub fn tensor_accuracy<B: Backend>(logits: &Tensor<B>, targets: &Tensor<B>) -> Result<f64> {
    let pred_classes = argmax_classes(logits)?;
    let target_data = targets.to_f64_vec()?;
    let target_dims = targets.dims();
    let logit_dims = logits.dims();

    let target_classes: Vec<usize> =
        if target_dims.len() >= 2 && target_dims.last() == logit_dims.last() {
            // One-hot / probability targets: reduce each row by argmax.
            let n_classes = *target_dims.last().unwrap_or(&1);
            let batch = target_data.len() / n_classes;
            (0..batch)
                .map(|i| {
                    let row = &target_data[i * n_classes..(i + 1) * n_classes];
                    row.iter()
                        .enumerate()
                        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                        .map(|(idx, _)| idx)
                        .unwrap_or(0)
                })
                .collect()
        } else {
            // Integer class-index targets.
            target_data.iter().map(|v| *v as usize).collect()
        };

    Ok(accuracy(&pred_classes, &target_classes))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accuracy_perfect() {
        assert_eq!(accuracy(&[0, 1, 2, 0], &[0, 1, 2, 0]), 1.0);
    }

    #[test]
    fn test_accuracy_50_percent() {
        assert_eq!(accuracy(&[0, 1, 0, 1], &[0, 0, 1, 1]), 0.5);
    }

    #[test]
    fn test_confusion_matrix_binary() {
        let preds = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 0, 1];
        let cm = ConfusionMatrix::from_predictions(&preds, &targets, 2);
        assert_eq!(cm.tp(1), 2);
        assert_eq!(cm.fp(1), 1);
        assert_eq!(cm.fn_(1), 1);
        assert_eq!(cm.tn(1), 2);
    }

    #[test]
    fn test_precision_recall_f1_binary() {
        let preds = [1, 1, 1, 0, 0, 0];
        let targets = [1, 1, 0, 0, 0, 1];
        let p = precision(&preds, &targets, 2, Average::Macro);
        let r = recall(&preds, &targets, 2, Average::Macro);
        let f = f1_score(&preds, &targets, 2, Average::Macro);
        assert!((p - 0.6667).abs() < 0.01);
        assert!((r - 0.6667).abs() < 0.01);
        assert!((f - 0.6667).abs() < 0.01);
    }

    #[test]
    fn test_precision_micro() {
        let preds = [0, 1, 2, 0, 1, 2];
        let targets = [0, 1, 2, 1, 0, 2];
        // For single-label classification, micro precision equals accuracy.
        let p = precision(&preds, &targets, 3, Average::Micro);
        let a = accuracy(&preds, &targets);
        assert!((p - a).abs() < 1e-10);
    }

    #[test]
    fn test_classification_report() {
        let preds = [0, 0, 1, 1, 2, 2];
        let targets = [0, 1, 1, 2, 2, 0];
        let report = classification_report(&preds, &targets, 3);
        assert_eq!(report.len(), 3);
        assert_eq!(report[0].support, 2);
    }

    #[test]
    fn test_r2_perfect() {
        let preds = [1.0, 2.0, 3.0, 4.0];
        let targets = [1.0, 2.0, 3.0, 4.0];
        assert!((r2_score(&preds, &targets) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_mae_rmse() {
        let preds = [1.0, 2.0, 3.0];
        let targets = [1.0, 3.0, 5.0];
        assert!((mae(&preds, &targets) - 1.0).abs() < 1e-10);
        // Squared errors are 0, 1, 4, so the MSE is 5/3.
        assert!((rmse(&preds, &targets) - (5.0f64 / 3.0).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_perplexity() {
        assert!((perplexity(0.0) - 1.0).abs() < 1e-10);
        assert!((perplexity(1.0) - std::f64::consts::E).abs() < 1e-10);
    }

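    // Two tokens each with probability 1/2 give a perplexity of exactly 2.
    #[test]
    fn test_perplexity_from_log_probs() {
        let lp = 0.5f64.ln();
        assert!((perplexity_from_log_probs(&[lp, lp]) - 2.0).abs() < 1e-10);
        assert!(perplexity_from_log_probs(&[]).is_infinite());
    }
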
    #[test]
    fn test_top_k_accuracy() {
        // Two samples, three classes, row-major scores.
        let scores = [0.1, 0.7, 0.2, 0.8, 0.05, 0.15];
        let targets = [1, 2];
        assert!((top_k_accuracy(&scores, &targets, 3, 1) - 0.5).abs() < 1e-10);
        assert!((top_k_accuracy(&scores, &targets, 3, 2) - 1.0).abs() < 1e-10);
    }

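    // The zero target is skipped by mape; the remaining samples have absolute
    // percentage errors of 50% and 12.5%, so the mean is 31.25%.
    #[test]
    fn test_mape_skips_zero_targets() {
        let preds = [1.0, 1.5, 7.0];
        let targets = [0.0, 1.0, 8.0];
        assert!((mape(&preds, &targets) - 31.25).abs() < 1e-10);
    }
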
    #[test]
    fn test_argmax_classes() {
        use shrew_cpu::{CpuBackend, CpuDevice};
        let t = Tensor::<CpuBackend>::from_f64_slice(
            &[0.1, 0.9, 0.3, 0.8, 0.1, 0.1],
            vec![2, 3],
            shrew_core::DType::F64,
            &CpuDevice,
        )
        .unwrap();
        let classes = argmax_classes(&t).unwrap();
        assert_eq!(classes, vec![1, 0]);
    }

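    // Mirrors test_argmax_classes: one-hot targets with the same trailing
    // dimension as the logits take the argmax path inside tensor_accuracy.
    #[test]
    fn test_tensor_accuracy_one_hot() {
        use shrew_cpu::{CpuBackend, CpuDevice};
        let logits = Tensor::<CpuBackend>::from_f64_slice(
            &[0.1, 0.9, 0.3, 0.8, 0.1, 0.1],
            vec![2, 3],
            shrew_core::DType::F64,
            &CpuDevice,
        )
        .unwrap();
        let targets = Tensor::<CpuBackend>::from_f64_slice(
            &[0.0, 1.0, 0.0, 1.0, 0.0, 0.0],
            vec![2, 3],
            shrew_core::DType::F64,
            &CpuDevice,
        )
        .unwrap();
        assert!((tensor_accuracy(&logits, &targets).unwrap() - 1.0).abs() < 1e-10);
    }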
}