use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::shape::Shape;
use shrew_core::tensor::Tensor;

use crate::module::Module;

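/// Group Normalization (Wu & He, 2018): the channel dimension is split into
/// `num_groups` groups, and each group is normalized over its channels and all
/// trailing (spatial) positions, then scaled and shifted per channel:
///
/// `y = (x - mean) / sqrt(var + eps) * weight + bias`
///
/// Usage sketch (illustrative only; how the concrete backend and device are
/// constructed depends on the backend crate in use):
///
/// ```ignore
/// let gn = GroupNorm::<B>::new(32, 256, 1e-5, DType::F32, &device)?;
/// let y = gn.forward(&x)?; // x: [N, 256, H, W]
/// ```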
pub struct GroupNorm<B: Backend> {
    weight: Tensor<B>,
    bias: Tensor<B>,
    num_groups: usize,
    num_channels: usize,
    eps: f64,
}

impl<B: Backend> GroupNorm<B> {
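    /// Builds a `GroupNorm` layer over `num_channels` channels split into
    /// `num_groups` groups. The per-channel `weight` is initialized to ones and
    /// `bias` to zeros; both are registered as trainable variables.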
    pub fn new(
        num_groups: usize,
        num_channels: usize,
        eps: f64,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        if !num_channels.is_multiple_of(num_groups) {
            return Err(shrew_core::Error::msg(format!(
                "GroupNorm: num_channels ({}) must be divisible by num_groups ({})",
                num_channels, num_groups
            )));
        }
        let weight = Tensor::<B>::ones(num_channels, dtype, device)?.set_variable();
        let bias = Tensor::<B>::zeros(num_channels, dtype, device)?.set_variable();
        Ok(GroupNorm {
            weight,
            bias,
            num_groups,
            num_channels,
            eps,
        })
    }

    pub fn num_groups(&self) -> usize {
        self.num_groups
    }

    pub fn num_channels(&self) -> usize {
        self.num_channels
    }
}

impl<B: Backend> Module<B> for GroupNorm<B> {
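    /// Applies group normalization to an input of shape `[N, C, ...]`, where `C`
    /// must equal `num_channels`. Statistics are computed per sample and per group
    /// over that group's channels and all trailing (spatial) dimensions.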
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let dims = x.dims();
        if dims.len() < 2 {
            return Err(shrew_core::Error::msg(
                "GroupNorm: input must be at least 2D [N, C, ...]",
            ));
        }
        let n = dims[0];
        let c = dims[1];
        if c != self.num_channels {
            return Err(shrew_core::Error::msg(format!(
                "GroupNorm: expected {} channels, got {}",
                self.num_channels, c
            )));
        }
        let channels_per_group = c / self.num_groups;

        // Collapse all trailing spatial dimensions, then view the input as
        // [N, num_groups, channels_per_group * spatial] so that each group's
        // statistics can be reduced along a single axis.
        let spatial: usize = dims[2..].iter().product();
        let group_size = channels_per_group * spatial;

        let x_flat = x.reshape(Shape::new(vec![n, self.num_groups, group_size]))?;

        // Per-group mean and (biased) variance, keeping the reduced axis so the
        // statistics broadcast back over the group.
        let mu = x_flat.mean(2, true)?;
        let centered = x_flat.sub(&mu)?;
        let var = centered.square()?.mean(2, true)?;
        let std = var.affine(1.0, self.eps)?.sqrt()?;
        let x_norm = centered.div(&std)?;

        // Restore the original layout, then apply the learned per-channel scale
        // and shift. weight/bias are reshaped to [1, C, 1, ...] so they broadcast
        // over the batch and spatial dimensions.
        let x_norm = x_norm.reshape(Shape::new(dims.to_vec()))?;

        let mut w_shape = vec![1usize; dims.len()];
        w_shape[1] = c;
        let gamma = self.weight.reshape(Shape::new(w_shape.clone()))?;
        let beta = self.bias.reshape(Shape::new(w_shape))?;

        x_norm.mul(&gamma)?.add(&beta)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        vec![
            ("weight".to_string(), self.weight.clone()),
            ("bias".to_string(), self.bias.clone()),
        ]
    }
}