shrew_nn/loss.rs
// Loss Functions
//
// Loss functions measure the difference between predictions and targets.
// The loss is a scalar value that the optimizer tries to minimize.
//
// Unless a reduction is specified, every loss function here returns a SCALAR
// tensor so that backward() works directly.
//
// KEY LOSSES:
//
// 1. MSE (Mean Squared Error): mean((pred - target)²)
//    Used for regression tasks. Penalizes large errors quadratically.
//
// 2. Cross-Entropy: -mean(sum_classes(target * log(softmax(pred))))
//    Used for classification. The standard loss for neural networks that
//    output class probabilities.
//
// 3. L1 Loss (Mean Absolute Error): mean(|pred - target|)
//    More robust to outliers than MSE.
//
// 4. Smooth L1 / Huber Loss: quadratic near zero, linear beyond beta.
//    Used in object detection (Faster R-CNN).
//
// 5. BCE (Binary Cross-Entropy): binary classification with probabilities.
//
// 6. BCE with Logits: combines sigmoid + BCE for numerical stability.
//
// 7. NLL Loss: negative log-likelihood with class indices.

use shrew_core::backend::Backend;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

/// Reduction mode for loss functions.
///
/// Controls how the per-element losses are aggregated:
/// - `Mean` (default): average over all elements
/// - `Sum`: sum over all elements
/// - `None`: return per-element losses (no reduction)
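///
/// # Example
/// A sketch (assuming `y_pred` and `y_true` are tensors of matching shape):
/// ```ignore
/// // Per-element squared errors instead of a scalar mean.
/// let per_elem = mse_loss_with_reduction(&y_pred, &y_true, Reduction::None)?;
/// ```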
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Reduction {
    /// Return the mean of all per-element losses (default).
    #[default]
    Mean,
    /// Return the sum of all per-element losses.
    Sum,
    /// Return per-element losses without reduction.
    None,
}

/// Apply the reduction to a per-element loss tensor.
fn apply_reduction<B: Backend>(loss: &Tensor<B>, reduction: Reduction) -> Result<Tensor<B>> {
    match reduction {
        Reduction::Mean => loss.mean_all(),
        Reduction::Sum => loss.sum_all(),
        Reduction::None => Ok(loss.clone()),
    }
}

/// Mean Squared Error loss: mean((prediction - target)²)
///
/// Both prediction and target must have the same shape.
/// Returns a scalar tensor.
///
/// # Example
/// ```ignore
/// let loss = mse_loss(&y_pred, &y_true)?;
/// let grads = loss.backward()?;
/// ```
pub fn mse_loss<B: Backend>(prediction: &Tensor<B>, target: &Tensor<B>) -> Result<Tensor<B>> {
    let diff = prediction.sub(target)?;
    let sq = diff.square()?;
    sq.mean_all()
}

/// MSE Loss with configurable reduction.
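///
/// # Example
/// A sketch (assuming `y_pred` and `y_true` are tensors of matching shape):
/// ```ignore
/// // Sum of squared errors instead of the mean.
/// let sse = mse_loss_with_reduction(&y_pred, &y_true, Reduction::Sum)?;
/// ```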
pub fn mse_loss_with_reduction<B: Backend>(
    prediction: &Tensor<B>,
    target: &Tensor<B>,
    reduction: Reduction,
) -> Result<Tensor<B>> {
    let diff = prediction.sub(target)?;
    let sq = diff.square()?;
    apply_reduction(&sq, reduction)
}

/// Cross-entropy loss with log-softmax for numerical stability.
///
/// Computes: -mean( sum_over_classes( target * log_softmax(prediction) ) )
///
/// # Arguments
/// - `logits`: raw scores [batch, num_classes] (NOT softmax-ed)
/// - `target`: one-hot encoded targets [batch, num_classes]
///
/// # Numerical Stability
/// Uses the tensor-level `log_softmax`, which computes:
///   log_softmax(x)_i = x_i - max(x) - log(sum(exp(x - max(x))))
/// This is built entirely from differentiable tensor ops, so gradients
/// flow back through the logits automatically.
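///
/// # Example
/// A sketch (assuming `logits` is [batch, classes] and `one_hot_targets` has
/// the same shape with exactly one 1.0 per row):
/// ```ignore
/// let loss = cross_entropy_loss(&logits, &one_hot_targets)?;
/// let grads = loss.backward()?;
/// ```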
pub fn cross_entropy_loss<B: Backend>(logits: &Tensor<B>, target: &Tensor<B>) -> Result<Tensor<B>> {
    let dims = logits.dims();
    if dims.len() != 2 {
        return Err(shrew_core::Error::msg(format!(
            "cross_entropy expects 2D logits [batch, classes], got {:?}",
            dims
        )));
    }

    // log_softmax along the class dimension (dim=1); fully differentiable
    let log_sm = logits.log_softmax(1)?;

    // Cross-entropy = -mean(sum_classes(target * log_softmax))
    let prod = target.mul(&log_sm)?;
    let sum_classes = prod.sum(1, false)?; // [batch]
    let mean_batch = sum_classes.mean_all()?; // scalar
    mean_batch.neg()
}

/// L1 Loss (Mean Absolute Error): mean(|prediction - target|)
///
/// More robust to outliers than MSE because errors grow linearly, not
/// quadratically. Commonly used in regression with noisy targets.
///
/// Both prediction and target must have the same shape.
/// Returns a scalar tensor.
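///
/// # Example
/// A sketch (assuming `y_pred` and `y_true` are tensors of matching shape):
/// ```ignore
/// let loss = l1_loss(&y_pred, &y_true)?;
/// let grads = loss.backward()?;
/// ```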
pub fn l1_loss<B: Backend>(prediction: &Tensor<B>, target: &Tensor<B>) -> Result<Tensor<B>> {
    let diff = prediction.sub(target)?;
    let abs_diff = diff.abs()?;
    abs_diff.mean_all()
}

/// L1 Loss with configurable reduction.
pub fn l1_loss_with_reduction<B: Backend>(
    prediction: &Tensor<B>,
    target: &Tensor<B>,
    reduction: Reduction,
) -> Result<Tensor<B>> {
    let diff = prediction.sub(target)?;
    let abs_diff = diff.abs()?;
    apply_reduction(&abs_diff, reduction)
}

/// Smooth L1 Loss (Huber Loss):
///
/// ```text
///           ⎧ 0.5 * x² / beta    if |x| < beta
/// loss(x) = ⎨
///           ⎩ |x| - 0.5 * beta   otherwise
/// ```
/// where x = prediction - target.
///
/// Transitions smoothly from L2 (near zero) to L1 (far from zero) at `beta`.
/// Used in Faster R-CNN and SSD for bounding box regression.
///
/// # Arguments
/// - `prediction`: predicted values (any shape)
/// - `target`: ground truth values (same shape)
/// - `beta`: threshold at which to switch from L2 to L1 (must be > 0)
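///
/// # Example
/// A sketch (assuming `box_pred` and `box_target` are tensors of matching
/// shape; beta = 1.0 is the common default for this loss):
/// ```ignore
/// let loss = smooth_l1_loss(&box_pred, &box_target, 1.0)?;
/// ```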
pub fn smooth_l1_loss<B: Backend>(
    prediction: &Tensor<B>,
    target: &Tensor<B>,
    beta: f64,
) -> Result<Tensor<B>> {
    if beta <= 0.0 {
        return Err(shrew_core::Error::msg("smooth_l1_loss: beta must be > 0"));
    }
    let diff = prediction.sub(target)?;
    let abs_diff = diff.abs()?;

    // L2 branch: 0.5 * x² / beta
    let l2_part = diff.square()?.affine(0.5 / beta, 0.0)?;

    // L1 branch: |x| - 0.5 * beta
    let l1_part = abs_diff.affine(1.0, -0.5 * beta)?;

    // Mask: 1 where |x| < beta, 0 otherwise
    let beta_tensor = Tensor::<B>::full(
        abs_diff.shape().clone(),
        beta,
        abs_diff.dtype(),
        abs_diff.device(),
    )?;
    let mask = abs_diff.lt(&beta_tensor)?;

    // Combine: where(mask, l2_part, l1_part)
    let result = Tensor::<B>::where_cond(&mask, &l2_part, &l1_part)?;
    result.mean_all()
}

/// Binary Cross-Entropy loss for probabilities in [0, 1].
///
/// Computes: -mean(target * log(pred) + (1 - target) * log(1 - pred))
///
/// # Arguments
/// - `prediction`: predicted probabilities in (0, 1), typically sigmoid output
/// - `target`: binary targets in {0, 1}
///
/// # Warning
/// Numerically unstable if prediction values are exactly 0 or 1.
/// Use `bce_with_logits_loss` for better stability.
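///
/// # Example
/// A sketch (assuming `probs` holds sigmoid outputs and `labels` holds
/// 0.0 / 1.0 values of the same shape):
/// ```ignore
/// let loss = bce_loss(&probs, &labels)?;
/// ```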
pub fn bce_loss<B: Backend>(prediction: &Tensor<B>, target: &Tensor<B>) -> Result<Tensor<B>> {
    // Clamp predictions to avoid log(0)
    let eps = 1e-7;
    let pred_clamped = prediction.clamp(eps, 1.0 - eps)?;

    // -[target * log(pred) + (1-target) * log(1-pred)]
    let log_pred = pred_clamped.log()?;

    let ones = Tensor::<B>::ones(
        prediction.shape().clone(),
        prediction.dtype(),
        prediction.device(),
    )?;
    let one_minus_pred = ones.sub(&pred_clamped)?;
    let log_one_minus_pred = one_minus_pred.log()?;

    let one_minus_target = ones.sub(target)?;

    // target * log(pred) + (1-target) * log(1-pred)
    let term1 = target.mul(&log_pred)?;
    let term2 = one_minus_target.mul(&log_one_minus_pred)?;
    let sum = term1.add(&term2)?;

    sum.mean_all()?.neg()
}

/// Binary Cross-Entropy with Logits (numerically stable).
///
/// Combines sigmoid + BCE in a single formula:
///   loss = mean(max(x, 0) - x*t + log(1 + exp(-|x|)))
///
/// This is numerically stable for any logit value.
///
/// # Arguments
/// - `logits`: raw scores (before sigmoid), any shape
/// - `target`: binary targets in {0, 1}, same shape
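///
/// # Example
/// A sketch (assuming `logits` are raw model outputs and `labels` are 0.0 / 1.0
/// values of the same shape; no sigmoid is applied beforehand):
/// ```ignore
/// let loss = bce_with_logits_loss(&logits, &labels)?;
/// ```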
pub fn bce_with_logits_loss<B: Backend>(
    logits: &Tensor<B>,
    target: &Tensor<B>,
) -> Result<Tensor<B>> {
    // Stable formula: max(x,0) - x*t + log(1 + exp(-|x|))
    //               = relu(x) - x*t + softplus(-|x|)
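    // Derivation sketch: -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))]
    // simplifies to x - x*t + softplus(-x); rewriting softplus(-x) as
    // max(-x, 0) + log(1 + exp(-|x|)) keeps the exp argument <= 0, so the
    // exponential can never overflow.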
    let relu_x = logits.relu()?;
    let x_times_t = logits.mul(target)?;
    let abs_x = logits.abs()?;
    let neg_abs = abs_x.neg()?;
    let exp_neg_abs = neg_abs.exp()?;

    let ones = Tensor::<B>::ones(logits.shape().clone(), logits.dtype(), logits.device())?;
    let one_plus_exp = ones.add(&exp_neg_abs)?;
    let log_term = one_plus_exp.log()?;

    // relu(x) - x*t + log(1 + exp(-|x|))
    let loss = relu_x.sub(&x_times_t)?.add(&log_term)?;
    loss.mean_all()
}

/// Negative Log-Likelihood Loss with integer class indices.
///
/// Computes: -mean(log_probs[i, target[i]]) for each sample i.
///
/// # Arguments
/// - `log_probs`: log-probabilities [batch, num_classes] (output of log_softmax)
/// - `targets`: class indices as f64 [batch]; each value in 0..num_classes
///
/// Typically used as: `nll_loss(&logits.log_softmax(1)?, &targets)`
///
/// Note: unlike `cross_entropy_loss`, which takes one-hot targets,
/// this takes integer class indices (more memory efficient).
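///
/// # Example
/// A sketch (assuming `targets` stores each sample's class index as a float,
/// e.g. [0.0, 2.0, 1.0]):
/// ```ignore
/// let loss = nll_loss(&logits.log_softmax(1)?, &targets)?;
/// let grads = loss.backward()?;
/// ```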
pub fn nll_loss<B: Backend>(log_probs: &Tensor<B>, targets: &Tensor<B>) -> Result<Tensor<B>> {
    let dims = log_probs.dims();
    if dims.len() != 2 {
        return Err(shrew_core::Error::msg(format!(
            "nll_loss expects 2D log_probs [batch, classes], got {:?}",
            dims
        )));
    }
    let batch = dims[0];
    let num_classes = dims[1];

    // Convert targets to one-hot, then use element-wise multiply
    let target_vals = targets.to_f64_vec()?;
    let mut one_hot = vec![0.0f64; batch * num_classes];
    for i in 0..batch {
        let cls = target_vals[i] as usize;
        if cls >= num_classes {
            return Err(shrew_core::Error::msg(format!(
                "nll_loss: target index {} out of range for {} classes",
                cls, num_classes
            )));
        }
        one_hot[i * num_classes + cls] = 1.0;
    }
    let one_hot_tensor = Tensor::<B>::from_f64_slice(
        &one_hot,
        (batch, num_classes),
        log_probs.dtype(),
        log_probs.device(),
    )?;

    // -mean(sum_classes(one_hot * log_probs))
    let prod = one_hot_tensor.mul(log_probs)?;
    let sum_classes = prod.sum(1, false)?;
    let mean_batch = sum_classes.mean_all()?;
    mean_batch.neg()
}