shrew_nn/batchnorm.rs

// BatchNorm2d — 2D Batch Normalization
//
// Batch Normalization normalizes activations ACROSS the batch for each channel,
// stabilizing and accelerating training of deep convolutional networks.
//
// FORMULA (training mode):
//   x_hat = (x - mean_batch) / sqrt(var_batch + ε)
//   y = γ * x_hat + β
//
// Where mean_batch and var_batch are computed per-channel over (N, H, W).
//
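// Worked example with made-up numbers: if one channel holds the values
// {1, 2, 3, 4} across the batch, then mean_batch = 2.5 and var_batch = 1.25,
// so with ε ≈ 0, γ = 1, β = 0 the outputs are roughly {-1.34, -0.45, 0.45, 1.34}.
//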
// RUNNING STATISTICS:
//   During training, we maintain exponential moving averages:
//     running_mean = (1 - momentum) * running_mean + momentum * mean_batch
//     running_var  = (1 - momentum) * running_var  + momentum * var_batch
//
//   During eval, we use running_mean/running_var instead of batch stats.
//
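//   For instance, with momentum = 0.1, a running_mean of 0.0 and a batch mean
//   of 2.5, the update gives 0.9 * 0.0 + 0.1 * 2.5 = 0.25.
//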
// SHAPES:
//   Input:  [N, C, H, W]
//   Output: [N, C, H, W] (same shape)
//   γ, β:   [C]
//
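//   E.g. a [32, 16, 28, 28] input normalizes each of the 16 channels over
//   32 * 28 * 28 = 25088 values.
//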
// WHY BatchNorm2d?
//
// In CNNs, BatchNorm dramatically helps training:
// 1. Reduces internal covariate shift
// 2. Allows higher learning rates
// 3. Acts as a regularizer (reduces need for dropout)
// 4. Makes networks less sensitive to weight initialization

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::shape::Shape;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// 2D Batch Normalization layer for convolutional networks.
///
/// Normalizes each channel across the batch: for input `[N, C, H, W]`,
/// mean and variance are computed over `(N, H, W)` for each of `C` channels.
///
/// # Examples
/// ```ignore
/// let bn = BatchNorm2d::<CpuBackend>::new(16, 1e-5, 0.1, DType::F64, &dev)?;
/// // x: a [batch, 16, H, W] input tensor
/// let y = bn.forward(&x)?; // normalized, same shape
/// ```
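///
/// Switching to eval mode normalizes with the running statistics instead of
/// the batch statistics:
/// ```ignore
/// bn.eval();
/// let y_eval = bn.forward(&x)?; // uses running_mean / running_var
/// ```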
pub struct BatchNorm2d<B: Backend> {
    /// Learnable scale (gamma): [C]
    weight: Tensor<B>,
    /// Learnable shift (beta): [C]
    bias: Tensor<B>,
    /// Running mean (not trainable): [C]
    running_mean: std::cell::RefCell<Vec<f64>>,
    /// Running variance (not trainable): [C]
    running_var: std::cell::RefCell<Vec<f64>>,
    /// Number of channels.
    num_features: usize,
    /// Numerical stability constant.
    eps: f64,
    /// Momentum for running statistics update.
    momentum: f64,
    /// Whether we're in training mode (use batch stats) or eval (use running stats).
    training: std::cell::Cell<bool>,
}

impl<B: Backend> BatchNorm2d<B> {
    /// Create a new BatchNorm2d layer.
    ///
    /// # Arguments
    /// - `num_features`: number of channels (C)
    /// - `eps`: numerical stability constant (typically 1e-5)
    /// - `momentum`: EMA momentum for running stats (typically 0.1)
    /// - `dtype`: data type for learnable parameters
    /// - `device`: device
    pub fn new(
        num_features: usize,
        eps: f64,
        momentum: f64,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        // γ initialized to 1 (scale)
        let weight =
            Tensor::<B>::ones(Shape::new(vec![num_features]), dtype, device)?.set_variable();
        // β initialized to 0 (shift)
        let bias =
            Tensor::<B>::zeros(Shape::new(vec![num_features]), dtype, device)?.set_variable();

        Ok(BatchNorm2d {
            weight,
            bias,
            running_mean: std::cell::RefCell::new(vec![0.0; num_features]),
            running_var: std::cell::RefCell::new(vec![1.0; num_features]),
            num_features,
            eps,
            momentum,
            training: std::cell::Cell::new(true),
        })
    }

    /// Set training mode (use batch statistics).
    pub fn train(&self) {
        self.training.set(true);
    }

    /// Set evaluation mode (use running statistics).
    pub fn eval(&self) {
        self.training.set(false);
    }

    /// Whether the module is in training mode.
    pub fn is_training(&self) -> bool {
        self.training.get()
    }

    pub fn num_features(&self) -> usize {
        self.num_features
    }

    pub fn eps(&self) -> f64 {
        self.eps
    }

    pub fn weight(&self) -> &Tensor<B> {
        &self.weight
    }

    pub fn bias(&self) -> &Tensor<B> {
        &self.bias
    }

    /// Create from existing weight and bias tensors (for executor/model loading).
    /// Initializes running stats to mean=0, var=1.
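    ///
    /// # Examples
    /// ```ignore
    /// // gamma and beta are [C] tensors obtained elsewhere (e.g. from a
    /// // checkpoint); this is a sketch, not tied to a specific loading API.
    /// let bn = BatchNorm2d::from_tensors(gamma, beta, 1e-5)?;
    /// ```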
    pub fn from_tensors(weight: Tensor<B>, bias: Tensor<B>, eps: f64) -> Result<Self> {
        let num_features = weight.elem_count();
        Ok(BatchNorm2d {
            weight: weight.set_variable(),
            bias: bias.set_variable(),
            running_mean: std::cell::RefCell::new(vec![0.0; num_features]),
            running_var: std::cell::RefCell::new(vec![1.0; num_features]),
            num_features,
            eps,
            momentum: 0.1,
            training: std::cell::Cell::new(true),
        })
    }
}

impl<B: Backend> Module<B> for BatchNorm2d<B> {
    /// Forward pass: batch-normalize each channel.
    ///
    /// Training:  use batch mean/var, update running stats.
    /// Eval:      use running mean/var.
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        if x.rank() != 4 {
            return Err(shrew_core::Error::msg(format!(
                "BatchNorm2d: expected 4D input [N,C,H,W], got rank {}",
                x.rank()
            )));
        }

        let dims = x.dims();
        let (n, c, h, w) = (dims[0], dims[1], dims[2], dims[3]);

        if c != self.num_features {
            return Err(shrew_core::Error::msg(format!(
                "BatchNorm2d: expected {} channels, got {}",
                self.num_features, c
            )));
        }

        if self.training.get() {
            // Compute per-channel mean and variance entirely on-device.
            // Reshape [N,C,H,W] → [N,C,H*W] so we can reduce over dim=0 and dim=2.
            let x_flat = x.reshape(Shape::new(vec![n, c, h * w]))?;

            // Mean over N and spatial: first mean over dim=2 → [N,C], then mean over dim=0 → [C]
            let mean_spatial = x_flat.mean(2, false)?; // [N, C]
            let mean_batch = mean_spatial.mean(0, false)?; // [C]
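            // (Taking the mean of per-sample means is valid here because every
            // group reduced in the first step has the same H*W element count.)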

            // Variance: E[(x - mean)^2] — compute on device
            let mean_bcast = mean_batch.reshape(Shape::new(vec![1, c, 1, 1]))?;
            let diff = x.sub(&mean_bcast)?;
            let diff_sq = diff.mul(&diff)?;
            let diff_sq_flat = diff_sq.reshape(Shape::new(vec![n, c, h * w]))?;
            let var_spatial = diff_sq_flat.mean(2, false)?; // [N, C]
            let var_batch = var_spatial.mean(0, false)?; // [C]

            // Update running statistics (small host transfer — only C floats)
            {
                let mean_host = mean_batch.to_f64_vec()?;
                let var_host = var_batch.to_f64_vec()?;
                let mut rm = self.running_mean.borrow_mut();
                let mut rv = self.running_var.borrow_mut();
                for ci in 0..c {
                    rm[ci] = (1.0 - self.momentum) * rm[ci] + self.momentum * mean_host[ci];
                    rv[ci] = (1.0 - self.momentum) * rv[ci] + self.momentum * var_host[ci];
                }
            }

            self.apply_norm_tensors(x, &mean_batch, &var_batch, c)
        } else {
            // Eval mode: use running stats
            let rm = self.running_mean.borrow();
            let rv = self.running_var.borrow();
            let mean_t =
                Tensor::<B>::from_f64_slice(&rm, Shape::new(vec![c]), x.dtype(), x.device())?;
            let var_t =
                Tensor::<B>::from_f64_slice(&rv, Shape::new(vec![c]), x.dtype(), x.device())?;
            self.apply_norm_tensors(x, &mean_t, &var_t, c)
        }
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        vec![
            ("weight".to_string(), self.weight.clone()),
            ("bias".to_string(), self.bias.clone()),
        ]
    }

    fn set_training(&self, training: bool) {
        self.training.set(training);
    }

    fn is_training(&self) -> bool {
        self.training.get()
    }
}

impl<B: Backend> BatchNorm2d<B> {
    /// Apply normalization using on-device mean/var tensors (shape [C]).
    ///
    /// All computation stays on the device — no host round-trip.
    ///   x_hat = (x - mean[1,C,1,1]) * invstd[1,C,1,1]
    ///   y = gamma[1,C,1,1] * x_hat + beta[1,C,1,1]
    fn apply_norm_tensors(
        &self,
        x: &Tensor<B>,
        mean: &Tensor<B>,
        var: &Tensor<B>,
        c: usize,
    ) -> Result<Tensor<B>> {
        // Broadcast shapes: [C] → [1, C, 1, 1]
        let mean_b = mean.reshape(Shape::new(vec![1, c, 1, 1]))?;
        let var_b = var.reshape(Shape::new(vec![1, c, 1, 1]))?;

        // invstd = 1 / sqrt(var + eps)
        let var_eps = var_b.affine(1.0, self.eps)?; // var + eps
        let invstd = var_eps.sqrt()?.powf(-1.0)?; // 1/sqrt(var+eps)

        // x_hat = (x - mean) * invstd
        let x_hat = x.sub(&mean_b)?.mul(&invstd)?;

        // y = gamma * x_hat + beta
        let gamma = self.weight.reshape(Shape::new(vec![1, c, 1, 1]))?;
        let beta = self.bias.reshape(Shape::new(vec![1, c, 1, 1]))?;
        x_hat.mul(&gamma)?.add(&beta)
    }
}
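
// A minimal sanity-check sketch. It assumes `CpuBackend` (the backend used in
// the doc example above) lives at `shrew_core::backend::CpuBackend` and that
// its `Device` implements `Default`; adjust both to the crate's real paths and
// device constructor if they differ.
#[cfg(test)]
mod tests {
    use super::*;
    use shrew_core::backend::CpuBackend; // assumed path

    #[test]
    fn normalized_channels_have_zero_mean() -> Result<()> {
        let dev = <CpuBackend as Backend>::Device::default(); // assumed constructor
        let bn = BatchNorm2d::<CpuBackend>::new(2, 1e-5, 0.1, DType::F64, &dev)?;

        // [1, 2, 2, 2] input: channel 0 holds {1, 2, 3, 4}, channel 1 holds {10, 20, 30, 40}.
        let data = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
        let x = Tensor::<CpuBackend>::from_f64_slice(
            &data,
            Shape::new(vec![1, 2, 2, 2]),
            DType::F64,
            &dev,
        )?;

        let y = bn.forward(&x)?;
        let y_host = y.to_f64_vec()?;

        // With gamma = 1 and beta = 0, each output channel should have ~zero mean.
        for ch in 0..2 {
            let mean: f64 = y_host[ch * 4..(ch + 1) * 4].iter().sum::<f64>() / 4.0;
            assert!(mean.abs() < 1e-6, "channel {ch} mean was {mean}");
        }
        Ok(())
    }
}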