shrew_nn/loss.rs

// Loss Functions
//
// Loss functions measure the difference between predictions and targets.
// The loss is a scalar value that the optimizer tries to minimize.
//
// All loss functions return a SCALAR tensor by default, so that backward()
// works directly (the *_with_reduction variants can opt out via Reduction::None).
//
// KEY LOSSES:
//
// 1. MSE (Mean Squared Error): mean((pred - target)²)
//    Used for regression tasks. Penalizes large errors quadratically.
//
// 2. Cross-Entropy: -mean(sum_classes(target * log(softmax(pred))))
//    Used for classification. The standard loss for neural networks that
//    output class probabilities.
//
// 3. L1 Loss (Mean Absolute Error): mean(|pred - target|)
//    More robust to outliers than MSE.
//
// 4. Smooth L1 / Huber Loss: quadratic below beta, linear beyond it.
//    Used in object detection (Faster R-CNN).
//
// 5. BCE (Binary Cross-Entropy): binary classification with probabilities.
//
// 6. BCE with Logits: combines sigmoid + BCE for numerical stability.
//
// 7. NLL Loss: negative log-likelihood with class indices.
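//
// QUICK USAGE (a minimal sketch; `logits` and `one_hot_targets` are placeholder
// tensors of shape [batch, num_classes]):
//
//     let loss = cross_entropy_loss(&logits, &one_hot_targets)?;  // scalar loss
//     let grads = loss.backward()?;                               // gradients for the optimizer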

use shrew_core::backend::Backend;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

/// Reduction mode for loss functions.
///
/// Controls how the per-element losses are aggregated:
/// - `Mean` (default): average over all elements
/// - `Sum`: sum over all elements
/// - `None`: return per-element losses (no reduction)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Reduction {
    /// Return the mean of all per-element losses (default).
    #[default]
    Mean,
    /// Return the sum of all per-element losses.
    Sum,
    /// Return per-element losses without reduction.
    None,
}

/// Apply the reduction to a per-element loss tensor.
fn apply_reduction<B: Backend>(loss: &Tensor<B>, reduction: Reduction) -> Result<Tensor<B>> {
    match reduction {
        Reduction::Mean => loss.mean_all(),
        Reduction::Sum => loss.sum_all(),
        Reduction::None => Ok(loss.clone()),
    }
}

/// Mean Squared Error loss: mean((prediction - target)²)
///
/// Both prediction and target must have the same shape.
/// Returns a scalar tensor.
///
/// # Example
/// ```ignore
/// let loss = mse_loss(&y_pred, &y_true)?;
/// let grads = loss.backward()?;
/// ```
pub fn mse_loss<B: Backend>(prediction: &Tensor<B>, target: &Tensor<B>) -> Result<Tensor<B>> {
    let diff = prediction.sub(target)?;
    let sq = diff.square()?;
    sq.mean_all()
}

/// MSE Loss with configurable reduction.
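///
/// # Example
/// A minimal usage sketch (`y_pred` and `y_true` are placeholder tensors of the same shape):
/// ```ignore
/// // Sketch only: Reduction::None keeps the per-element losses.
/// let per_element = mse_loss_with_reduction(&y_pred, &y_true, Reduction::None)?;
/// let summed = mse_loss_with_reduction(&y_pred, &y_true, Reduction::Sum)?;
/// ```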
pub fn mse_loss_with_reduction<B: Backend>(
    prediction: &Tensor<B>,
    target: &Tensor<B>,
    reduction: Reduction,
) -> Result<Tensor<B>> {
    let diff = prediction.sub(target)?;
    let sq = diff.square()?;
    apply_reduction(&sq, reduction)
}

/// Cross-entropy loss with log-softmax for numerical stability.
///
/// Computes: -mean( sum_over_classes( target * log_softmax(logits) ) )
///
/// # Arguments
/// - `logits`: raw scores [batch, num_classes] (NOT softmax-ed)
/// - `target`: one-hot encoded targets [batch, num_classes]
///
/// # Numerical Stability
/// Uses the tensor-level `log_softmax` which computes:
///   log_softmax(x)_i = x_i - max(x) - log(sum(exp(x - max(x))))
/// This is built entirely from differentiable tensor ops, so gradients
/// flow back through the logits automatically.
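///
/// # Example
/// A minimal sketch with 2 samples and 3 classes (`logits` is a placeholder [2, 3] tensor):
/// ```ignore
/// // Sketch only: one-hot targets built with the same dtype/device as the logits.
/// let target = Tensor::<B>::from_f64_slice(
///     &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
///     (2, 3),
///     logits.dtype(),
///     logits.device(),
/// )?;
/// let loss = cross_entropy_loss(&logits, &target)?;
/// let grads = loss.backward()?;
/// ```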
pub fn cross_entropy_loss<B: Backend>(logits: &Tensor<B>, target: &Tensor<B>) -> Result<Tensor<B>> {
    let dims = logits.dims();
    if dims.len() != 2 {
        return Err(shrew_core::Error::msg(format!(
            "cross_entropy expects 2D logits [batch, classes], got {:?}",
            dims
        )));
    }

    // log_softmax along class dimension (dim=1) — fully differentiable
    let log_sm = logits.log_softmax(1)?;

    // Cross-entropy = -mean(sum_classes(target * log_softmax))
    let prod = target.mul(&log_sm)?;
    let sum_classes = prod.sum(1, false)?; // [batch]
    let mean_batch = sum_classes.mean_all()?; // scalar
    mean_batch.neg()
}

/// L1 Loss (Mean Absolute Error): mean(|prediction - target|)
///
/// More robust to outliers than MSE because errors grow linearly, not
/// quadratically. Commonly used in regression with noisy targets.
///
/// Both prediction and target must have the same shape.
/// Returns a scalar tensor.
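///
/// # Example
/// A minimal usage sketch (`y_pred` and `y_true` are placeholder tensors of the same shape):
/// ```ignore
/// let loss = l1_loss(&y_pred, &y_true)?;
/// let grads = loss.backward()?;
/// ```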
pub fn l1_loss<B: Backend>(prediction: &Tensor<B>, target: &Tensor<B>) -> Result<Tensor<B>> {
    let diff = prediction.sub(target)?;
    let abs_diff = diff.abs()?;
    abs_diff.mean_all()
}

/// L1 Loss with configurable reduction.
pub fn l1_loss_with_reduction<B: Backend>(
    prediction: &Tensor<B>,
    target: &Tensor<B>,
    reduction: Reduction,
) -> Result<Tensor<B>> {
    let diff = prediction.sub(target)?;
    let abs_diff = diff.abs()?;
    apply_reduction(&abs_diff, reduction)
}

/// Smooth L1 Loss (Huber Loss):
///
/// ```text
///            ⎧ 0.5 * x² / beta    if |x| < beta
/// loss(x) =  ⎨
///            ⎩ |x| - 0.5 * beta   otherwise
/// ```
/// where x = prediction - target.
///
/// Transitions smoothly from L2 (near zero) to L1 (far from zero) at `beta`.
/// Used in Faster R-CNN and SSD for bounding box regression.
///
/// # Arguments
/// - `prediction`: predicted values (any shape)
/// - `target`: ground truth values (same shape)
/// - `beta`: threshold at which to switch from L2 to L1 (must be > 0)
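///
/// # Example
/// A minimal sketch (`pred_boxes` and `target_boxes` are placeholder tensors;
/// beta = 1.0 is the commonly used default for box regression):
/// ```ignore
/// let loss = smooth_l1_loss(&pred_boxes, &target_boxes, 1.0)?;
/// ```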
pub fn smooth_l1_loss<B: Backend>(
    prediction: &Tensor<B>,
    target: &Tensor<B>,
    beta: f64,
) -> Result<Tensor<B>> {
    if beta <= 0.0 {
        return Err(shrew_core::Error::msg("smooth_l1_loss: beta must be > 0"));
    }
    let diff = prediction.sub(target)?;
    let abs_diff = diff.abs()?;

    // L2 branch: 0.5 * x² / beta
    let l2_part = diff.square()?.affine(0.5 / beta, 0.0)?;

    // L1 branch: |x| - 0.5 * beta
    let l1_part = abs_diff.affine(1.0, -0.5 * beta)?;

    // Mask: 1 where |x| < beta, 0 otherwise
    let beta_tensor = Tensor::<B>::full(
        abs_diff.shape().clone(),
        beta,
        abs_diff.dtype(),
        abs_diff.device(),
    )?;
    let mask = abs_diff.lt(&beta_tensor)?;

    // Combine: where(mask, l2_part, l1_part)
    let result = Tensor::<B>::where_cond(&mask, &l2_part, &l1_part)?;
    result.mean_all()
}

/// Binary Cross-Entropy loss for probabilities in [0, 1].
///
/// Computes: -mean(target * log(pred) + (1 - target) * log(1 - pred))
///
/// # Arguments
/// - `prediction`: predicted probabilities in (0, 1) — typically sigmoid output
/// - `target`: binary targets in {0, 1}
///
/// # Warning
/// Predictions are clamped to [1e-7, 1 - 1e-7] internally to avoid log(0), but
/// computing BCE from probabilities still loses precision near 0 and 1.
/// Use `bce_with_logits_loss` for better numerical stability.
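///
/// # Example
/// A minimal usage sketch (`probs` are sigmoid outputs, `labels` are 0/1 targets of the same shape):
/// ```ignore
/// let loss = bce_loss(&probs, &labels)?;
/// ```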
pub fn bce_loss<B: Backend>(prediction: &Tensor<B>, target: &Tensor<B>) -> Result<Tensor<B>> {
    // Clamp predictions to avoid log(0)
    let eps = 1e-7;
    let pred_clamped = prediction.clamp(eps, 1.0 - eps)?;

    // -[target * log(pred) + (1-target) * log(1-pred)]
    let log_pred = pred_clamped.log()?;

    let ones = Tensor::<B>::ones(
        prediction.shape().clone(),
        prediction.dtype(),
        prediction.device(),
    )?;
    let one_minus_pred = ones.sub(&pred_clamped)?;
    let log_one_minus_pred = one_minus_pred.log()?;

    let one_minus_target = ones.sub(target)?;

    // target * log(pred) + (1-target) * log(1-pred)
    let term1 = target.mul(&log_pred)?;
    let term2 = one_minus_target.mul(&log_one_minus_pred)?;
    let sum = term1.add(&term2)?;

    sum.mean_all()?.neg()
}

/// Binary Cross-Entropy with Logits (numerically stable).
///
/// Combines sigmoid + BCE in a single formula:
///   loss = mean(max(x, 0) - x*t + log(1 + exp(-|x|)))
///
/// This is numerically stable for any logit value.
///
/// # Arguments
/// - `logits`: raw scores (before sigmoid), any shape
/// - `target`: binary targets in {0, 1}, same shape
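///
/// # Example
/// A minimal usage sketch (`raw_scores` are pre-sigmoid logits, `labels` are 0/1 targets):
/// ```ignore
/// let loss = bce_with_logits_loss(&raw_scores, &labels)?;
/// let grads = loss.backward()?;
/// ```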
pub fn bce_with_logits_loss<B: Backend>(
    logits: &Tensor<B>,
    target: &Tensor<B>,
) -> Result<Tensor<B>> {
    // Stable formula: max(x,0) - x*t + log(1 + exp(-|x|))
    // = relu(x) - x*t + softplus(-|x|)
    let relu_x = logits.relu()?;
    let x_times_t = logits.mul(target)?;
    let abs_x = logits.abs()?;
    let neg_abs = abs_x.neg()?;
    let exp_neg_abs = neg_abs.exp()?;

    let ones = Tensor::<B>::ones(logits.shape().clone(), logits.dtype(), logits.device())?;
    let one_plus_exp = ones.add(&exp_neg_abs)?;
    let log_term = one_plus_exp.log()?;

    // relu(x) - x*t + log(1 + exp(-|x|))
    let loss = relu_x.sub(&x_times_t)?.add(&log_term)?;
    loss.mean_all()
}

/// Negative Log-Likelihood Loss with integer class indices.
///
/// Computes: -mean over samples i of log_probs[i, targets[i]].
///
/// # Arguments
/// - `log_probs`: log-probabilities [batch, num_classes] (output of log_softmax)
/// - `targets`: class indices as f64 [batch] — each value in 0..num_classes
///
/// Typically used as: `nll_loss(&logits.log_softmax(1)?, &targets)`
///
/// Note: unlike `cross_entropy_loss`, which takes one-hot targets,
/// this takes integer class indices (more memory efficient).
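///
/// # Example
/// A minimal sketch (`targets` is a placeholder [batch] tensor of class indices, e.g. [2.0, 0.0]):
/// ```ignore
/// let log_probs = logits.log_softmax(1)?;
/// let loss = nll_loss(&log_probs, &targets)?;
/// ```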
pub fn nll_loss<B: Backend>(log_probs: &Tensor<B>, targets: &Tensor<B>) -> Result<Tensor<B>> {
    let dims = log_probs.dims();
    if dims.len() != 2 {
        return Err(shrew_core::Error::msg(format!(
            "nll_loss expects 2D log_probs [batch, classes], got {:?}",
            dims
        )));
    }
    let batch = dims[0];
    let num_classes = dims[1];

    // Convert targets to one-hot, then use element-wise multiply
    let target_vals = targets.to_f64_vec()?;
    let mut one_hot = vec![0.0f64; batch * num_classes];
    for i in 0..batch {
        let cls = target_vals[i] as usize;
        if cls >= num_classes {
            return Err(shrew_core::Error::msg(format!(
                "nll_loss: target index {} out of range for {} classes",
                cls, num_classes
            )));
        }
        one_hot[i * num_classes + cls] = 1.0;
    }
    let one_hot_tensor = Tensor::<B>::from_f64_slice(
        &one_hot,
        (batch, num_classes),
        log_probs.dtype(),
        log_probs.device(),
    )?;

    // -mean(sum_classes(one_hot * log_probs))
    let prod = one_hot_tensor.mul(log_probs)?;
    let sum_classes = prod.sum(1, false)?;
    let mean_batch = sum_classes.mean_all()?;
    mean_batch.neg()
}