1use shrew_core::backend::Backend;
20use shrew_core::dtype::DType;
21use shrew_core::error::Result;
22use shrew_core::shape::Shape;
23use shrew_core::tensor::Tensor;
24
25use crate::module::Module;
26
/// Group Normalization layer: normalizes activations over groups of
/// channels, then applies a learnable per-channel scale and shift.
pub struct GroupNorm<B: Backend> {
    /// Per-channel scale, shape `[num_channels]`; initialized to ones in `new`.
    weight: Tensor<B>,
    /// Per-channel shift, shape `[num_channels]`; initialized to zeros in `new`.
    bias: Tensor<B>,
    /// Number of groups the channel axis is split into when computing statistics.
    num_groups: usize,
    /// Channel count `C` the layer expects at `dims[1]` of its input.
    num_channels: usize,
    /// Small constant added to the variance before the square root, for stability.
    eps: f64,
}
42
43impl<B: Backend> GroupNorm<B> {
44 pub fn new(
53 num_groups: usize,
54 num_channels: usize,
55 eps: f64,
56 dtype: DType,
57 device: &B::Device,
58 ) -> Result<Self> {
59 #[allow(clippy::manual_is_multiple_of)]
60 if num_channels % num_groups != 0 {
61 return Err(shrew_core::Error::msg(format!(
62 "GroupNorm: num_channels ({}) must be divisible by num_groups ({})",
63 num_channels, num_groups
64 )));
65 }
66 let weight = Tensor::<B>::ones(num_channels, dtype, device)?.set_variable();
67 let bias = Tensor::<B>::zeros(num_channels, dtype, device)?.set_variable();
68 Ok(GroupNorm {
69 weight,
70 bias,
71 num_groups,
72 num_channels,
73 eps,
74 })
75 }
76
77 pub fn num_groups(&self) -> usize {
78 self.num_groups
79 }
80 pub fn num_channels(&self) -> usize {
81 self.num_channels
82 }
83}
84
85impl<B: Backend> Module<B> for GroupNorm<B> {
86 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
87 let dims = x.dims();
88 if dims.len() < 2 {
89 return Err(shrew_core::Error::msg(
90 "GroupNorm: input must be at least 2D [N, C, ...]",
91 ));
92 }
93 let n = dims[0];
94 let c = dims[1];
95 if c != self.num_channels {
96 return Err(shrew_core::Error::msg(format!(
97 "GroupNorm: expected {} channels, got {}",
98 self.num_channels, c
99 )));
100 }
101 let channels_per_group = c / self.num_groups;
102
103 let spatial: usize = dims[2..].iter().product();
105 let group_size = channels_per_group * spatial;
106
107 let x_flat = x.reshape(Shape::new(vec![n, self.num_groups, group_size]))?;
109
110 let mu = x_flat.mean(2, true)?; let centered = x_flat.sub(&mu)?;
113 let var = centered.square()?.mean(2, true)?; let std = var.affine(1.0, self.eps)?.sqrt()?;
115 let x_norm = centered.div(&std)?;
116
117 let x_norm = x_norm.reshape(Shape::new(dims.to_vec()))?;
119
120 let mut w_shape = vec![1usize; dims.len()];
122 w_shape[1] = c;
123 let gamma = self.weight.reshape(Shape::new(w_shape.clone()))?;
124 let beta = self.bias.reshape(Shape::new(w_shape))?;
125
126 x_norm.mul(&gamma)?.add(&beta)
127 }
128
129 fn parameters(&self) -> Vec<Tensor<B>> {
130 vec![self.weight.clone(), self.bias.clone()]
131 }
132
133 fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
134 vec![
135 ("weight".to_string(), self.weight.clone()),
136 ("bias".to_string(), self.bias.clone()),
137 ]
138 }
139}