// shrew_nn/groupnorm.rs

// GroupNorm — Group Normalization
//
// Group Normalization divides channels into groups and normalizes within each
// group. It's a generalization of LayerNorm and InstanceNorm:
//
//   - GroupNorm(1, C)   = LayerNorm over channels
//   - GroupNorm(C, C)   = InstanceNorm
//   - GroupNorm(G, C)   = normalize each of G groups of C/G channels
//
// Unlike BatchNorm, GroupNorm is independent of batch size, making it
// ideal for small-batch or single-sample training (e.g., object detection).
//
// SHAPES:
//   Input:  [N, C, *] (any spatial dimensions after channels)
//   Output: [N, C, *] (same shape)
//   weight: [C], bias: [C]
//   num_groups must divide C evenly.
19use shrew_core::backend::Backend;
20use shrew_core::dtype::DType;
21use shrew_core::error::Result;
22use shrew_core::shape::Shape;
23use shrew_core::tensor::Tensor;
24
25use crate::module::Module;
26
/// Group Normalization layer.
///
/// Splits the channel axis into `num_groups` groups, normalizes each group
/// independently, then applies a learnable per-channel affine transform.
///
/// # Examples
/// ```ignore
/// let gn = GroupNorm::<CpuBackend>::new(8, 32, 1e-5, DType::F64, &dev)?;
/// let x = CpuTensor::rand((2, 32, 16, 16), DType::F64, &dev)?;
/// let y = gn.forward(&x)?; // [2, 32, 16, 16]
/// ```
pub struct GroupNorm<B: Backend> {
    // Learnable per-channel scale gamma, shape [C]; initialized to ones in `new`.
    weight: Tensor<B>,
    // Learnable per-channel shift beta, shape [C]; initialized to zeros in `new`.
    bias: Tensor<B>,
    // Number of groups G the channel axis is divided into.
    num_groups: usize,
    // Total channel count C; `new` enforces C % G == 0.
    num_channels: usize,
    // Small constant added to the variance before sqrt for numerical stability.
    eps: f64,
}
42
43impl<B: Backend> GroupNorm<B> {
44    /// Create a new GroupNorm layer.
45    ///
46    /// # Arguments
47    /// - `num_groups`: number of groups to divide channels into
48    /// - `num_channels`: total number of channels (must be divisible by num_groups)
49    /// - `eps`: numerical stability constant
50    /// - `dtype`: data type
51    /// - `device`: compute device
52    pub fn new(
53        num_groups: usize,
54        num_channels: usize,
55        eps: f64,
56        dtype: DType,
57        device: &B::Device,
58    ) -> Result<Self> {
59        #[allow(clippy::manual_is_multiple_of)]
60        if num_channels % num_groups != 0 {
61            return Err(shrew_core::Error::msg(format!(
62                "GroupNorm: num_channels ({}) must be divisible by num_groups ({})",
63                num_channels, num_groups
64            )));
65        }
66        let weight = Tensor::<B>::ones(num_channels, dtype, device)?.set_variable();
67        let bias = Tensor::<B>::zeros(num_channels, dtype, device)?.set_variable();
68        Ok(GroupNorm {
69            weight,
70            bias,
71            num_groups,
72            num_channels,
73            eps,
74        })
75    }
76
77    pub fn num_groups(&self) -> usize {
78        self.num_groups
79    }
80    pub fn num_channels(&self) -> usize {
81        self.num_channels
82    }
83}
84
85impl<B: Backend> Module<B> for GroupNorm<B> {
86    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
87        let dims = x.dims();
88        if dims.len() < 2 {
89            return Err(shrew_core::Error::msg(
90                "GroupNorm: input must be at least 2D [N, C, ...]",
91            ));
92        }
93        let n = dims[0];
94        let c = dims[1];
95        if c != self.num_channels {
96            return Err(shrew_core::Error::msg(format!(
97                "GroupNorm: expected {} channels, got {}",
98                self.num_channels, c
99            )));
100        }
101        let channels_per_group = c / self.num_groups;
102
103        // Flatten spatial dims: [N, C, *] → product of spatial dims
104        let spatial: usize = dims[2..].iter().product();
105        let group_size = channels_per_group * spatial;
106
107        // Reshape to [N, G, channels_per_group * spatial]
108        let x_flat = x.reshape(Shape::new(vec![n, self.num_groups, group_size]))?;
109
110        // Mean and var within each group
111        let mu = x_flat.mean(2, true)?; // [N, G, 1]
112        let centered = x_flat.sub(&mu)?;
113        let var = centered.square()?.mean(2, true)?; // [N, G, 1]
114        let std = var.affine(1.0, self.eps)?.sqrt()?;
115        let x_norm = centered.div(&std)?;
116
117        // Reshape back to original shape
118        let x_norm = x_norm.reshape(Shape::new(dims.to_vec()))?;
119
120        // Scale and shift: build weight/bias to match [1, C, 1, 1, ...]
121        let mut w_shape = vec![1usize; dims.len()];
122        w_shape[1] = c;
123        let gamma = self.weight.reshape(Shape::new(w_shape.clone()))?;
124        let beta = self.bias.reshape(Shape::new(w_shape))?;
125
126        x_norm.mul(&gamma)?.add(&beta)
127    }
128
129    fn parameters(&self) -> Vec<Tensor<B>> {
130        vec![self.weight.clone(), self.bias.clone()]
131    }
132
133    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
134        vec![
135            ("weight".to_string(), self.weight.clone()),
136            ("bias".to_string(), self.bias.clone()),
137        ]
138    }
139}