// shrew_nn/dropout.rs
//
// Dropout — Regularization via random zeroing
//
// During training, Dropout randomly sets elements to zero with probability p.
// The remaining elements are scaled by 1/(1-p) to preserve the expected value.
// This prevents co-adaptation of neurons and improves generalization.
//
// During inference (eval mode), Dropout does nothing (identity function).
//
// The training flag uses Cell<bool> for interior mutability, so set_training
// can be called through the Module trait's &self interface.

use std::cell::Cell;

use shrew_core::backend::Backend;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// Applies dropout regularization.
///
/// During training: randomly zeros elements with probability `p`,
/// scales remaining by `1/(1-p)`.
///
/// During eval: identity (no-op).
pub struct Dropout {
    /// Probability of an element being zeroed; `new` enforces `[0, 1)`.
    p: f64,
    /// Whether we're in training mode. `Cell` provides interior mutability
    /// so the mode can be flipped through the `Module` trait's `&self`
    /// methods without requiring `&mut self`.
    training: Cell<bool>,
}

33impl Dropout {
34 /// Create a new Dropout layer.
35 pub fn new(p: f64) -> Self {
36 assert!(
37 (0.0..1.0).contains(&p),
38 "Dropout probability must be in [0, 1)"
39 );
40 Dropout {
41 p,
42 training: Cell::new(true),
43 }
44 }
45
46 /// Set training/eval mode directly (works without specifying backend).
47 pub fn set_training(&self, training: bool) {
48 self.training.set(training);
49 }
50
51 /// Whether module is in training mode (works without specifying backend).
52 pub fn is_training(&self) -> bool {
53 self.training.get()
54 }
55
56 /// Apply dropout: randomly zero elements during training.
57 pub fn forward_t<B: Backend>(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
58 if !self.training.get() || self.p == 0.0 {
59 return Ok(x.clone());
60 }
61
62 let scale = 1.0 / (1.0 - self.p);
63
64 // Generate random mask on-device (no host round-trip)
65 let mask = Tensor::<B>::rand(x.shape().clone(), x.dtype(), x.device())?;
66 let threshold = Tensor::<B>::full(x.shape().clone(), self.p, x.dtype(), x.device())?;
67
68 // keep_mask: U8 tensor — 1 where mask >= p (keep), 0 where drop
69 let keep_mask = mask.ge(&threshold)?;
70
71 // Build zero tensor & scaled input
72 let zeros = Tensor::<B>::zeros(x.shape().clone(), x.dtype(), x.device())?;
73 let scaled_x = x.affine(scale, 0.0)?;
74
75 // Select: where keep → scaled_x, where drop → 0
76 Tensor::<B>::where_cond(&keep_mask, &scaled_x, &zeros)
77 }
78}
79
80// Module impl — note that Dropout has no trainable parameters.
81impl<B: Backend> Module<B> for Dropout {
82 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
83 self.forward_t(x)
84 }
85
86 fn parameters(&self) -> Vec<Tensor<B>> {
87 vec![] // No trainable parameters
88 }
89
90 fn set_training(&self, training: bool) {
91 self.training.set(training);
92 }
93
94 fn is_training(&self) -> bool {
95 self.training.get()
96 }
97}