shrew_nn/dropout.rs

// Dropout — Regularization via random zeroing
//
// During training, Dropout randomly sets elements to zero with probability p.
// The remaining elements are scaled by 1/(1-p) to preserve the expected value.
// This prevents co-adaptation of neurons and improves generalization.
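// (Worked example, with illustrative numbers not taken from the code: for
// p = 0.25, each element survives with probability 0.75 and each survivor is
// scaled by 1/(1 - 0.25) = 4/3, so the expected value is 0.75 * (4/3) * x = x.)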
//
// During inference (eval mode), Dropout does nothing (identity function).
//
// The training flag uses Cell<bool> for interior mutability, so set_training
// can be called through the Module trait's &self interface.

use std::cell::Cell;

use shrew_core::backend::Backend;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// Applies dropout regularization.
///
/// During training: randomly zeros elements with probability `p`,
/// scales remaining by `1/(1-p)`.
///
/// During eval: identity (no-op).
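///
/// # Example
///
/// A minimal usage sketch (not compiled as a doc-test): it assumes `x` is an
/// existing `Tensor` on some concrete `Backend` and that the caller is in a
/// function returning `Result`.
///
/// ```ignore
/// let dropout = Dropout::new(0.5);
/// let y_train = dropout.forward_t(&x)?; // random zeroing, survivors scaled by 2.0
/// dropout.set_training(false);
/// let y_eval = dropout.forward_t(&x)?;  // eval mode: returns x unchanged
/// ```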
pub struct Dropout {
    /// Probability of an element being zeroed.
    p: f64,
    /// Whether we're in training mode (Cell for interior mutability).
    training: Cell<bool>,
}

impl Dropout {
    /// Create a new Dropout layer.
    pub fn new(p: f64) -> Self {
        assert!(
            (0.0..1.0).contains(&p),
            "Dropout probability must be in [0, 1)"
        );
        Dropout {
            p,
            training: Cell::new(true),
        }
    }

    /// Set training/eval mode directly (works without specifying backend).
    pub fn set_training(&self, training: bool) {
        self.training.set(training);
    }

    /// Whether the module is in training mode (works without specifying backend).
    pub fn is_training(&self) -> bool {
        self.training.get()
    }

    /// Apply dropout: randomly zero elements during training.
    pub fn forward_t<B: Backend>(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        if !self.training.get() || self.p == 0.0 {
            return Ok(x.clone());
        }

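        // Inverted dropout: each element survives with probability (1 - p), so
        // scaling survivors by 1/(1 - p) keeps the expected activation equal to x.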
        let scale = 1.0 / (1.0 - self.p);

        // Generate random mask on-device (no host round-trip)
        let mask = Tensor::<B>::rand(x.shape().clone(), x.dtype(), x.device())?;
        let threshold = Tensor::<B>::full(x.shape().clone(), self.p, x.dtype(), x.device())?;

        // keep_mask: U8 tensor — 1 where mask >= p (keep), 0 where drop
        let keep_mask = mask.ge(&threshold)?;

        // Build zero tensor & scaled input
        let zeros = Tensor::<B>::zeros(x.shape().clone(), x.dtype(), x.device())?;
        let scaled_x = x.affine(scale, 0.0)?;

        // Select: where keep → scaled_x, where drop → 0
        Tensor::<B>::where_cond(&keep_mask, &scaled_x, &zeros)
    }
}

// Module impl — note that Dropout has no trainable parameters.
impl<B: Backend> Module<B> for Dropout {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        self.forward_t(x)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![] // No trainable parameters
    }

    fn set_training(&self, training: bool) {
        self.training.set(training);
    }

    fn is_training(&self) -> bool {
        self.training.get()
    }
}