shrew_nn/layernorm.rs
// LayerNorm — Layer Normalization
//
// Layer Normalization normalizes the activations WITHIN each sample
// (across the feature dimension), independently of the batch.
//
// FORMULA:
//   y = (x - mean(x)) / sqrt(var(x) + ε) * γ + β
//
// Where:
// - mean(x) and var(x) are computed over the normalized (feature) dimension
// - γ (gamma/weight) and β (beta/bias) are learnable parameters
// - ε is a small constant for numerical stability (default 1e-5)
//
// WHY LayerNorm?
//
// In transformers, LayerNorm is used instead of BatchNorm because:
// 1. It normalizes per sample, not per batch → works with any batch size (even 1)
// 2. It doesn't need running statistics at inference time
// 3. It's invariant to the batch composition
//
// Every transformer layer applies LayerNorm (sketched below):
//   x = x + Attention(LayerNorm(x))   ← pre-norm style (GPT-2, LLaMA)
//   x = LayerNorm(x + Attention(x))   ← post-norm style (original Transformer)
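//
// For concreteness, a pre-norm residual connection built on this module could
// be wired roughly like this (a sketch only: `attn` stands for an attention
// module defined elsewhere and is not part of this file):
//
//   let ln1 = LayerNorm::<B>::new(d_model, 1e-5, dtype, &device)?;
//   let h = x.add(&attn.forward(&ln1.forward(&x)?)?)?; // x + Attention(LayerNorm(x))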
//
// SHAPES:
//   Input:  [*, normalized_size]    (e.g., [batch, seq_len, d_model])
//   Output: same shape as input
//   γ, β:   [normalized_size]       (e.g., [d_model])
//
// This implementation normalizes over the last dimension only (the common
// case). For a typical transformer with d_model = 512, a LayerNorm with
// normalized_size = 512 normalizes the last dimension of each
// [batch, seq_len, 512] tensor.
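//
// A small numeric example of the formula above (pure arithmetic, values
// rounded, with γ = 1 and β = 0):
//   x       = [1.0, 2.0, 3.0]
//   mean(x) = 2.0
//   var(x)  = ((1-2)² + (2-2)² + (3-2)²) / 3 = 2/3
//   y       = (x - 2.0) / sqrt(2/3 + 1e-5) ≈ [-1.2247, 0.0, 1.2247]
// The result has mean ≈ 0 and variance ≈ 1 along the normalized dimension.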

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// Layer Normalization: normalizes over the last dimension of the input.
///
/// # Example
/// ```ignore
/// let ln = LayerNorm::<CpuBackend>::new(512, 1e-5, DType::F64, &dev)?;
/// let x = CpuTensor::rand((2, 10, 512), DType::F64, &dev)?;
/// let y = ln.forward(&x)?; // same shape, normalized
/// ```
pub struct LayerNorm<B: Backend> {
    /// Learnable scale parameter γ: [normalized_size]
    weight: Tensor<B>,
    /// Learnable shift parameter β: [normalized_size]
    bias: Tensor<B>,
    /// Size of the last dimension to normalize over.
    normalized_size: usize,
    /// Small constant for numerical stability.
    eps: f64,
}

impl<B: Backend> LayerNorm<B> {
    /// Create a new LayerNorm layer.
    ///
    /// # Arguments
    /// - `normalized_size`: size of the last dimension to normalize
    /// - `eps`: numerical stability constant (typically 1e-5)
    /// - `dtype`: data type for parameters
    /// - `device`: device to create parameters on
    pub fn new(normalized_size: usize, eps: f64, dtype: DType, device: &B::Device) -> Result<Self> {
        // γ initialized to 1 (identity scale)
        let weight = Tensor::<B>::ones(normalized_size, dtype, device)?.set_variable();
        // β initialized to 0 (no shift)
        let bias = Tensor::<B>::zeros(normalized_size, dtype, device)?.set_variable();

        Ok(LayerNorm {
            weight,
            bias,
            normalized_size,
            eps,
        })
    }

    /// Create from existing weight and bias tensors.
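    ///
    /// # Example
    /// A sketch of loading pretrained parameters, reusing the `CpuBackend`
    /// setup from the struct-level example above (placeholder values here;
    /// real weights would come from a checkpoint):
    /// ```ignore
    /// let w = Tensor::<CpuBackend>::ones(512, DType::F64, &dev)?;
    /// let b = Tensor::<CpuBackend>::zeros(512, DType::F64, &dev)?;
    /// let ln = LayerNorm::from_tensors(w, b, 1e-5)?;
    /// ```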
    pub fn from_tensors(weight: Tensor<B>, bias: Tensor<B>, eps: f64) -> Result<Self> {
        let normalized_size = weight.elem_count();
        Ok(LayerNorm {
            weight: weight.set_variable(),
            bias: bias.set_variable(),
            normalized_size,
            eps,
        })
    }

    pub fn eps(&self) -> f64 {
        self.eps
    }

    pub fn normalized_size(&self) -> usize {
        self.normalized_size
    }
}

impl<B: Backend> Module<B> for LayerNorm<B> {
    /// Forward pass: normalize over the last dimension, then scale + shift.
    ///
    /// For input [batch, seq, d_model]:
    /// 1. mean   = mean(x, dim=-1, keepdim=true)  → [batch, seq, 1]
    /// 2. var    = var(x, dim=-1, keepdim=true)   → [batch, seq, 1]
    /// 3. x_norm = (x - mean) / sqrt(var + eps)   → [batch, seq, d_model]
    /// 4. output = x_norm * γ + β                 → [batch, seq, d_model]
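    ///
    /// # Example
    /// A sketch mirroring the struct-level example; with the freshly
    /// initialized γ = 1, β = 0, the per-row mean of the output is ≈ 0:
    /// ```ignore
    /// let y = ln.forward(&x)?;         // [2, 10, 512]
    /// let row_mean = y.mean(2, true)?; // [2, 10, 1], values ≈ 0
    /// ```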
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let rank = x.rank();
        if rank == 0 {
            return Err(shrew_core::Error::msg(
                "LayerNorm: input must have at least 1 dimension",
            ));
        }
        let last_dim = rank - 1;

        // Step 1: Compute mean over the last dimension
        let mu = x.mean(last_dim, true)?; // [..., 1]

        // Step 2: Compute variance over the last dimension
        // var(x) = mean((x - mean)²)
        let centered = x.sub(&mu)?; // broadcasting: [..., D] - [..., 1] → [..., D]
        let sq = centered.square()?;
        let variance = sq.mean(last_dim, true)?; // [..., 1]

        // Step 3: Normalize
        // x_norm = centered / sqrt(var + eps)
        let std = variance.affine(1.0, self.eps)?.sqrt()?; // [..., 1]
        let x_norm = centered.div(&std)?; // broadcasting

        // Step 4: Scale and shift
        // Output = x_norm * γ + β
        // γ and β are shape [D], x_norm is [..., D]
        // Broadcasting handles this: [D] broadcasts to [..., D]
        let output = x_norm.mul(&self.weight)?.add(&self.bias)?;

        Ok(output)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        vec![
            ("weight".to_string(), self.weight.clone()),
            ("bias".to_string(), self.bias.clone()),
        ]
    }
}
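
// A backend-generic smoke-check sketch: it only exercises APIs already used in
// this file (`new`, `forward`, `mean`, `rank`), so a concrete test would supply
// a real backend, device, and input tensor (e.g. the CpuBackend/CpuTensor shown
// in the doc examples above).
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh LayerNorm (γ = 1, β = 0), runs `forward`, and returns the
    /// mean of the output over the normalized (last) dimension, which should be
    /// ≈ 0 everywhere.
    #[allow(dead_code)]
    fn normalized_output_mean<B: Backend>(
        x: &Tensor<B>,
        normalized_size: usize,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Tensor<B>> {
        let ln = LayerNorm::<B>::new(normalized_size, 1e-5, dtype, device)?;
        let y = ln.forward(x)?;
        y.mean(x.rank() - 1, true)
    }
}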