// shrew_nn/layernorm.rs

// LayerNorm — Layer Normalization
//
// Layer Normalization normalizes the activations WITHIN each sample
// (across the feature dimension), independent of the batch.
//
// FORMULA:
//   y = (x - mean(x)) / sqrt(var(x) + ε) * γ + β
//
// Where:
//   - mean(x) and var(x) are computed over the last `normalized_shape` dims
//   - γ (gamma/weight) and β (beta/bias) are learnable parameters
//   - ε is a small constant for numerical stability (default 1e-5)
//
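// WORKED EXAMPLE (a hand-computed sketch for intuition, with γ = 1, β = 0, and
// ε omitted):
//   x    = [1.0, 2.0, 3.0]
//   mean = 2.0
//   var  = ((1-2)² + (2-2)² + (3-2)²) / 3 = 2/3 ≈ 0.6667
//   y    = (x - 2.0) / sqrt(0.6667)       ≈ [-1.2247, 0.0, 1.2247]
//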
// WHY LayerNorm?
//
// In transformers, LayerNorm is used instead of BatchNorm because:
// 1. It normalizes per-sample, not per-batch → works with variable batch sizes
// 2. It doesn't need running statistics at inference time
// 3. It's invariant to the batch composition
//
// Every transformer layer applies LayerNorm:
//   x = x + Attention(LayerNorm(x))    ← pre-norm style (GPT-2, LLaMA)
//   x = LayerNorm(x + Attention(x))    ← post-norm style (original Transformer)
//
// SHAPES:
//   Input:  [*, normalized_shape]  (e.g., [batch, seq_len, d_model])
//   Output: same shape as input
//   γ, β:   [normalized_shape]     (e.g., [d_model])
//
// The normalization is applied over the last `len(normalized_shape)` dimensions;
// this implementation covers the common case of a single trailing dimension.
// For a typical transformer with d_model=512: LayerNorm(512) normalizes the
// last dimension of each [batch, seq_len, 512] tensor.

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// Layer Normalization: normalizes over the last dimension of the input.
///
/// # Example
/// ```ignore
/// let ln = LayerNorm::<CpuBackend>::new(512, 1e-5, DType::F64, &dev)?;
/// let x = CpuTensor::rand((2, 10, 512), DType::F64, &dev)?;
/// let y = ln.forward(&x)?; // same shape, normalized
/// ```
pub struct LayerNorm<B: Backend> {
    /// Learnable scale parameter γ: [normalized_size]
    weight: Tensor<B>,
    /// Learnable shift parameter β: [normalized_size]
    bias: Tensor<B>,
    /// Size of the last dimension to normalize over.
    normalized_size: usize,
    /// Small constant for numerical stability.
    eps: f64,
}

impl<B: Backend> LayerNorm<B> {
    /// Create a new LayerNorm layer.
    ///
    /// # Arguments
    /// - `normalized_size`: size of the last dimension to normalize
    /// - `eps`: numerical stability constant (typically 1e-5)
    /// - `dtype`: data type for parameters
    /// - `device`: device to create parameters on
    pub fn new(normalized_size: usize, eps: f64, dtype: DType, device: &B::Device) -> Result<Self> {
        // γ initialized to 1 (identity scale)
        let weight = Tensor::<B>::ones(normalized_size, dtype, device)?.set_variable();
        // β initialized to 0 (no shift)
        let bias = Tensor::<B>::zeros(normalized_size, dtype, device)?.set_variable();

        Ok(LayerNorm {
            weight,
            bias,
            normalized_size,
            eps,
        })
    }

    /// Create from existing weight and bias tensors.
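    ///
    /// # Example
    /// ```ignore
    /// // Sketch: wrapping externally provided γ/β tensors. `CpuBackend` and `dev`
    /// // are assumed to be set up as in the struct-level example above.
    /// let gamma = Tensor::<CpuBackend>::ones(512, DType::F64, &dev)?;
    /// let beta = Tensor::<CpuBackend>::zeros(512, DType::F64, &dev)?;
    /// let ln = LayerNorm::from_tensors(gamma, beta, 1e-5)?;
    /// ```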
    pub fn from_tensors(weight: Tensor<B>, bias: Tensor<B>, eps: f64) -> Result<Self> {
        let normalized_size = weight.elem_count();
        Ok(LayerNorm {
            weight: weight.set_variable(),
            bias: bias.set_variable(),
            normalized_size,
            eps,
        })
    }

    pub fn eps(&self) -> f64 {
        self.eps
    }

    pub fn normalized_size(&self) -> usize {
        self.normalized_size
    }
}

impl<B: Backend> Module<B> for LayerNorm<B> {
    /// Forward pass: normalize over last dimension, then scale + shift.
    ///
    /// For input [batch, seq, d_model]:
    ///   1. mean = mean(x, dim=-1, keepdim=true)    → [batch, seq, 1]
    ///   2. var  = var(x, dim=-1, keepdim=true)     → [batch, seq, 1]
    ///   3. x_norm = (x - mean) / sqrt(var + eps)   → [batch, seq, d_model]
    ///   4. output = x_norm * γ + β                 → [batch, seq, d_model]
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let rank = x.rank();
        if rank == 0 {
            return Err(shrew_core::Error::msg(
                "LayerNorm: input must have at least 1 dimension",
            ));
        }
        let last_dim = rank - 1;

        // Step 1: Compute mean over last dimension
        let mu = x.mean(last_dim, true)?; // [..., 1]

        // Step 2: Compute variance over last dimension
        // var(x) = mean((x - mean)²)
        let centered = x.sub(&mu)?; // broadcasting: [..., D] - [..., 1] → [..., D]
        let sq = centered.square()?;
        let variance = sq.mean(last_dim, true)?; // [..., 1]

        // Step 3: Normalize
        // x_norm = centered / sqrt(var + eps)
        let std = variance.affine(1.0, self.eps)?.sqrt()?; // [..., 1]
        let x_norm = centered.div(&std)?; // broadcasting

        // Step 4: Scale and shift
        // output = x_norm * γ + β
        // γ and β are shape [D], x_norm is [..., D]
        // Broadcasting handles this: [D] broadcasts to [..., D]
        let output = x_norm.mul(&self.weight)?.add(&self.bias)?;

        Ok(output)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        vec![
            ("weight".to_string(), self.weight.clone()),
            ("bias".to_string(), self.bias.clone()),
        ]
    }
}
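
// -----------------------------------------------------------------------------
// Usage sketch (illustrative only, not referenced elsewhere): the pre-norm
// residual pattern from the header comment, x = x + SubLayer(LayerNorm(x)),
// expressed with this module. `sublayer` stands in for any shape-preserving
// `Module` (attention, MLP, ...); the helper name is ours, not a crate API.
#[allow(dead_code)]
fn pre_norm_residual<B: Backend>(
    ln: &LayerNorm<B>,
    sublayer: &impl Module<B>,
    x: &Tensor<B>,
) -> Result<Tensor<B>> {
    // Normalize first (pre-norm), run the sub-layer, then add back the residual input.
    let normed = ln.forward(x)?;
    let transformed = sublayer.forward(&normed)?;
    x.add(&transformed)
}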