shrew_nn/
groupnorm.rs

1// GroupNorm — Group Normalization
2//
3// Group Normalization divides channels into groups and normalizes within each
4// group. It's a generalization of LayerNorm and InstanceNorm:
5//
6//   - GroupNorm(1, C)   = LayerNorm over channels
7//   - GroupNorm(C, C)   = InstanceNorm
8//   - GroupNorm(G, C)   = normalize each of G groups of C/G channels
9//
10// Unlike BatchNorm, GroupNorm is independent of batch size, making it
11// ideal for small-batch or single-sample training (e.g., object detection).
12//
13// SHAPES:
14//   Input:  [N, C, *] (any spatial dimensions after channels)
15//   Output: [N, C, *] (same shape)
16//   weight: [C], bias: [C]
17//   num_groups must divide C evenly.
18
19use shrew_core::backend::Backend;
20use shrew_core::dtype::DType;
21use shrew_core::error::Result;
22use shrew_core::shape::Shape;
23use shrew_core::tensor::Tensor;
24
25use crate::module::Module;
26
/// Group Normalization layer.
///
/// Splits the `C` channels into `num_groups` groups, normalizes each group
/// independently (zero mean, unit variance), then applies a learnable
/// per-channel affine transform.
///
/// # Examples
/// ```ignore
/// let gn = GroupNorm::<CpuBackend>::new(8, 32, 1e-5, DType::F64, &dev)?;
/// let x = CpuTensor::rand((2, 32, 16, 16), DType::F64, &dev)?;
/// let y = gn.forward(&x)?; // [2, 32, 16, 16]
/// ```
pub struct GroupNorm<B: Backend> {
    // Learnable per-channel scale (gamma), shape [C]; initialized to ones.
    weight: Tensor<B>,
    // Learnable per-channel shift (beta), shape [C]; initialized to zeros.
    bias: Tensor<B>,
    // Number of groups the channel axis is partitioned into.
    num_groups: usize,
    // Total channel count C; must be a multiple of `num_groups`.
    num_channels: usize,
    // Small constant added to the variance for numerical stability.
    eps: f64,
}
42
43impl<B: Backend> GroupNorm<B> {
44    /// Create a new GroupNorm layer.
45    ///
46    /// # Arguments
47    /// - `num_groups`: number of groups to divide channels into
48    /// - `num_channels`: total number of channels (must be divisible by num_groups)
49    /// - `eps`: numerical stability constant
50    /// - `dtype`: data type
51    /// - `device`: compute device
52    pub fn new(
53        num_groups: usize,
54        num_channels: usize,
55        eps: f64,
56        dtype: DType,
57        device: &B::Device,
58    ) -> Result<Self> {
59        if !num_channels.is_multiple_of(num_groups) {
60            return Err(shrew_core::Error::msg(format!(
61                "GroupNorm: num_channels ({}) must be divisible by num_groups ({})",
62                num_channels, num_groups
63            )));
64        }
65        let weight = Tensor::<B>::ones(num_channels, dtype, device)?.set_variable();
66        let bias = Tensor::<B>::zeros(num_channels, dtype, device)?.set_variable();
67        Ok(GroupNorm {
68            weight,
69            bias,
70            num_groups,
71            num_channels,
72            eps,
73        })
74    }
75
76    pub fn num_groups(&self) -> usize {
77        self.num_groups
78    }
79    pub fn num_channels(&self) -> usize {
80        self.num_channels
81    }
82}
83
84impl<B: Backend> Module<B> for GroupNorm<B> {
85    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
86        let dims = x.dims();
87        if dims.len() < 2 {
88            return Err(shrew_core::Error::msg(
89                "GroupNorm: input must be at least 2D [N, C, ...]",
90            ));
91        }
92        let n = dims[0];
93        let c = dims[1];
94        if c != self.num_channels {
95            return Err(shrew_core::Error::msg(format!(
96                "GroupNorm: expected {} channels, got {}",
97                self.num_channels, c
98            )));
99        }
100        let channels_per_group = c / self.num_groups;
101
102        // Flatten spatial dims: [N, C, *] → product of spatial dims
103        let spatial: usize = dims[2..].iter().product();
104        let group_size = channels_per_group * spatial;
105
106        // Reshape to [N, G, channels_per_group * spatial]
107        let x_flat = x.reshape(Shape::new(vec![n, self.num_groups, group_size]))?;
108
109        // Mean and var within each group
110        let mu = x_flat.mean(2, true)?; // [N, G, 1]
111        let centered = x_flat.sub(&mu)?;
112        let var = centered.square()?.mean(2, true)?; // [N, G, 1]
113        let std = var.affine(1.0, self.eps)?.sqrt()?;
114        let x_norm = centered.div(&std)?;
115
116        // Reshape back to original shape
117        let x_norm = x_norm.reshape(Shape::new(dims.to_vec()))?;
118
119        // Scale and shift: build weight/bias to match [1, C, 1, 1, ...]
120        let mut w_shape = vec![1usize; dims.len()];
121        w_shape[1] = c;
122        let gamma = self.weight.reshape(Shape::new(w_shape.clone()))?;
123        let beta = self.bias.reshape(Shape::new(w_shape))?;
124
125        x_norm.mul(&gamma)?.add(&beta)
126    }
127
128    fn parameters(&self) -> Vec<Tensor<B>> {
129        vec![self.weight.clone(), self.bias.clone()]
130    }
131
132    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
133        vec![
134            ("weight".to_string(), self.weight.clone()),
135            ("bias".to_string(), self.bias.clone()),
136        ]
137    }
138}