shrew_nn/batchnorm.rs
// BatchNorm2d — 2D Batch Normalization
//
// Batch Normalization normalizes activations ACROSS the batch for each channel,
// stabilizing and accelerating training of deep convolutional networks.
//
// FORMULA (training mode):
//   x_hat = (x - mean_batch) / sqrt(var_batch + ε)
//   y     = γ * x_hat + β
//
// Where mean_batch and var_batch are computed per-channel over (N, H, W).
//
// RUNNING STATISTICS:
// During training, we maintain exponential moving averages:
//   running_mean = (1 - momentum) * running_mean + momentum * mean_batch
//   running_var  = (1 - momentum) * running_var  + momentum * var_batch
//
// During eval, we use running_mean/running_var instead of batch stats.
//
// SHAPES:
//   Input:  [N, C, H, W]
//   Output: [N, C, H, W] (same shape)
//   γ, β:   [C]
//
// WHY BatchNorm2d?
//
// In CNNs, BatchNorm dramatically helps training:
//   1. Reduces internal covariate shift
//   2. Allows higher learning rates
//   3. Acts as a regularizer (reduces the need for dropout)
//   4. Makes networks less sensitive to weight initialization

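// WORKED EXAMPLE (illustrative numbers, not taken from any test in this crate):
// Suppose a single channel sees the batch values {1.0, 3.0}, with γ = 1, β = 0 and
// negligible ε. Then mean_batch = 2.0 and var_batch = 1.0, so the normalized outputs
// are {-1.0, +1.0}. With momentum = 0.1 and running stats initialized to
// (mean = 0.0, var = 1.0), the update gives running_mean = 0.9 * 0.0 + 0.1 * 2.0 = 0.2
// and running_var = 0.9 * 1.0 + 0.1 * 1.0 = 1.0.
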
use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::shape::Shape;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// 2D Batch Normalization layer for convolutional networks.
///
/// Normalizes each channel across the batch: for input `[N, C, H, W]`,
/// mean and variance are computed over `(N, H, W)` for each of the `C` channels.
///
/// # Examples
/// ```ignore
/// let bn = BatchNorm2d::<CpuBackend>::new(16, 1e-5, 0.1, DType::F64, &dev)?;
/// // `x` is an input tensor of shape [batch, 16, H, W]
/// let y = bn.forward(&x)?; // normalized, same shape
/// ```
pub struct BatchNorm2d<B: Backend> {
    /// Learnable scale (gamma): [C]
    weight: Tensor<B>,
    /// Learnable shift (beta): [C]
    bias: Tensor<B>,
    /// Running mean (not trainable): [C]
    running_mean: std::cell::RefCell<Vec<f64>>,
    /// Running variance (not trainable): [C]
    running_var: std::cell::RefCell<Vec<f64>>,
    /// Number of channels.
    num_features: usize,
    /// Numerical stability constant.
    eps: f64,
    /// Momentum for running statistics update.
    momentum: f64,
    /// Whether we're in training mode (use batch stats) or eval (use running stats).
    training: std::cell::Cell<bool>,
}

impl<B: Backend> BatchNorm2d<B> {
    /// Create a new BatchNorm2d layer.
    ///
    /// # Arguments
    /// - `num_features`: number of channels (C)
    /// - `eps`: numerical stability constant (typically 1e-5)
    /// - `momentum`: EMA momentum for running stats (typically 0.1)
    /// - `dtype`: data type for the learnable parameters
    /// - `device`: device on which the parameters are allocated
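    ///
    /// # Examples
    /// ```ignore
    /// // Sketch only: `CpuBackend` and `dev` are assumed to be in scope, as in the
    /// // struct-level example above.
    /// let bn = BatchNorm2d::<CpuBackend>::new(16, 1e-5, 0.1, DType::F64, &dev)?;
    /// assert_eq!(bn.num_features(), 16);
    /// assert!(bn.is_training()); // layers start in training mode
    /// ```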
    pub fn new(
        num_features: usize,
        eps: f64,
        momentum: f64,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        // γ initialized to 1 (scale)
        let weight =
            Tensor::<B>::ones(Shape::new(vec![num_features]), dtype, device)?.set_variable();
        // β initialized to 0 (shift)
        let bias =
            Tensor::<B>::zeros(Shape::new(vec![num_features]), dtype, device)?.set_variable();

        Ok(BatchNorm2d {
            weight,
            bias,
            running_mean: std::cell::RefCell::new(vec![0.0; num_features]),
            running_var: std::cell::RefCell::new(vec![1.0; num_features]),
            num_features,
            eps,
            momentum,
            training: std::cell::Cell::new(true),
        })
    }

    /// Set training mode (use batch statistics).
    pub fn train(&self) {
        self.training.set(true);
    }

    /// Set evaluation mode (use running statistics).
    pub fn eval(&self) {
        self.training.set(false);
    }

    /// Whether the module is in training mode.
    pub fn is_training(&self) -> bool {
        self.training.get()
    }

    /// Number of channels (C) this layer normalizes.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Numerical stability constant ε.
    pub fn eps(&self) -> f64 {
        self.eps
    }

    /// Learnable scale γ, shape `[C]`.
    pub fn weight(&self) -> &Tensor<B> {
        &self.weight
    }

    /// Learnable shift β, shape `[C]`.
    pub fn bias(&self) -> &Tensor<B> {
        &self.bias
    }

    /// Create from existing weight and bias tensors (for executor/model loading).
    /// Initializes running stats to mean = 0, var = 1, and momentum to 0.1.
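    ///
    /// # Examples
    /// ```ignore
    /// // Sketch only: `gamma` and `beta` are assumed to be `[C]`-shaped tensors loaded
    /// // elsewhere (e.g. from a checkpoint).
    /// let bn = BatchNorm2d::from_tensors(gamma, beta, 1e-5)?;
    /// // Running stats start at mean = 0, var = 1.
    /// ```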
    pub fn from_tensors(weight: Tensor<B>, bias: Tensor<B>, eps: f64) -> Result<Self> {
        let num_features = weight.elem_count();
        Ok(BatchNorm2d {
            weight: weight.set_variable(),
            bias: bias.set_variable(),
            running_mean: std::cell::RefCell::new(vec![0.0; num_features]),
            running_var: std::cell::RefCell::new(vec![1.0; num_features]),
            num_features,
            eps,
            momentum: 0.1,
            training: std::cell::Cell::new(true),
        })
    }
}

impl<B: Backend> Module<B> for BatchNorm2d<B> {
    /// Forward pass: batch-normalize each channel.
    ///
    /// Training: use batch mean/var, update running stats.
    /// Eval: use running mean/var.
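    ///
    /// # Examples
    /// ```ignore
    /// // Sketch only: `bn` and a 4D input `x` of shape [N, C, H, W] are assumed to exist.
    /// let y_train = bn.forward(&x)?; // uses batch stats and updates the running stats
    /// bn.eval();
    /// let y_eval = bn.forward(&x)?; // uses the stored running stats; no state is updated
    /// ```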
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        if x.rank() != 4 {
            return Err(shrew_core::Error::msg(format!(
                "BatchNorm2d: expected 4D input [N,C,H,W], got rank {}",
                x.rank()
            )));
        }

        let dims = x.dims();
        let (n, c, h, w) = (dims[0], dims[1], dims[2], dims[3]);

        if c != self.num_features {
            return Err(shrew_core::Error::msg(format!(
                "BatchNorm2d: expected {} channels, got {}",
                self.num_features, c
            )));
        }

        if self.training.get() {
            // Compute per-channel mean and variance entirely on-device.
            // Reshape [N,C,H,W] → [N,C,H*W] so we can reduce over dim=0 and dim=2.
            let x_flat = x.reshape(Shape::new(vec![n, c, h * w]))?;

            // Mean over N and spatial: first mean over dim=2 → [N,C], then mean over dim=0 → [C]
            let mean_spatial = x_flat.mean(2, false)?; // [N, C]
            let mean_batch = mean_spatial.mean(0, false)?; // [C]

            // Variance: E[(x - mean)^2] — compute on device
            let mean_bcast = mean_batch.reshape(Shape::new(vec![1, c, 1, 1]))?;
            let diff = x.sub(&mean_bcast)?;
            let diff_sq = diff.mul(&diff)?;
            let diff_sq_flat = diff_sq.reshape(Shape::new(vec![n, c, h * w]))?;
            let var_spatial = diff_sq_flat.mean(2, false)?; // [N, C]
            let var_batch = var_spatial.mean(0, false)?; // [C]

            // Update running statistics (small host transfer — only C floats)
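            // For example, with momentum = 0.1: new running_mean = 0.9 * old running_mean + 0.1 * mean_batch.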
            {
                let mean_host = mean_batch.to_f64_vec()?;
                let var_host = var_batch.to_f64_vec()?;
                let mut rm = self.running_mean.borrow_mut();
                let mut rv = self.running_var.borrow_mut();
                for ci in 0..c {
                    rm[ci] = (1.0 - self.momentum) * rm[ci] + self.momentum * mean_host[ci];
                    rv[ci] = (1.0 - self.momentum) * rv[ci] + self.momentum * var_host[ci];
                }
            }

            self.apply_norm_tensors(x, &mean_batch, &var_batch, c)
        } else {
            // Eval mode: use running stats
            let rm = self.running_mean.borrow();
            let rv = self.running_var.borrow();
            let mean_t =
                Tensor::<B>::from_f64_slice(&rm, Shape::new(vec![c]), x.dtype(), x.device())?;
            let var_t =
                Tensor::<B>::from_f64_slice(&rv, Shape::new(vec![c]), x.dtype(), x.device())?;
            self.apply_norm_tensors(x, &mean_t, &var_t, c)
        }
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        vec![
            ("weight".to_string(), self.weight.clone()),
            ("bias".to_string(), self.bias.clone()),
        ]
    }

    fn set_training(&self, training: bool) {
        self.training.set(training);
    }

    fn is_training(&self) -> bool {
        self.training.get()
    }
}

impl<B: Backend> BatchNorm2d<B> {
    /// Apply normalization using on-device mean/var tensors (shape [C]).
    ///
    /// All computation stays on the device — no host round-trip.
    ///   x_hat = (x - mean[1,C,1,1]) * invstd[1,C,1,1]
    ///   y     = gamma[1,C,1,1] * x_hat + beta[1,C,1,1]
    fn apply_norm_tensors(
        &self,
        x: &Tensor<B>,
        mean: &Tensor<B>,
        var: &Tensor<B>,
        c: usize,
    ) -> Result<Tensor<B>> {
        // Broadcast shapes: [C] → [1, C, 1, 1]
        let mean_b = mean.reshape(Shape::new(vec![1, c, 1, 1]))?;
        let var_b = var.reshape(Shape::new(vec![1, c, 1, 1]))?;

        // invstd = 1 / sqrt(var + eps)
        let var_eps = var_b.affine(1.0, self.eps)?; // var + eps
        let invstd = var_eps.sqrt()?.powf(-1.0)?; // 1/sqrt(var+eps)

        // x_hat = (x - mean) * invstd
        let x_hat = x.sub(&mean_b)?.mul(&invstd)?;

        // y = gamma * x_hat + beta
        let gamma = self.weight.reshape(Shape::new(vec![1, c, 1, 1]))?;
        let beta = self.bias.reshape(Shape::new(vec![1, c, 1, 1]))?;
        x_hat.mul(&gamma)?.add(&beta)
    }
}
267}