shrew_nn/rmsnorm.rs

// RMSNorm — Root Mean Square Layer Normalization
//
// RMSNorm is a simplification of LayerNorm that removes the mean centering.
// It normalizes only by the RMS (root mean square) of the values.
//
// FORMULA:
//   x_norm = x / sqrt(mean(x²) + ε)
//   y = γ * x_norm
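//
// Quick worked example of the formula (with γ = 1, ε negligible):
//   x        = [3.0, 4.0]
//   mean(x²) = (9 + 16) / 2 = 12.5
//   rms      = sqrt(12.5) ≈ 3.5355
//   y        = x / rms ≈ [0.8485, 1.1314]   (y now has RMS 1)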
//
// WHY RMSNorm?
//
// RMSNorm is used in modern LLMs (LLaMA, Mistral, Gemma) because:
// 1. It's computationally simpler (no mean subtraction)
// 2. Empirically comparable performance to LayerNorm
// 3. Slightly faster due to fewer operations
//
// SHAPES:
//   Input:  [*, normalized_size]
//   Output: same shape as input
//   γ:      [normalized_size]

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// RMS Normalization layer (used in LLaMA, Mistral, etc.).
///
/// Normalizes by root-mean-square without mean centering:
///   y = x / sqrt(mean(x²) + ε) * γ
///
/// # Example
/// ```ignore
/// let rms = RMSNorm::<CpuBackend>::new(512, 1e-5, DType::F64, &dev)?;
/// let x = CpuTensor::rand((2, 10, 512), DType::F64, &dev)?;
/// let y = rms.forward(&x)?;
/// ```
pub struct RMSNorm<B: Backend> {
    /// Learnable scale parameter γ: [normalized_size]
    weight: Tensor<B>,
    normalized_size: usize,
    eps: f64,
}

impl<B: Backend> RMSNorm<B> {
    /// Create a new RMSNorm layer with γ initialized to ones and marked as a
    /// variable (trainable parameter).
    pub fn new(normalized_size: usize, eps: f64, dtype: DType, device: &B::Device) -> Result<Self> {
        let weight = Tensor::<B>::ones(normalized_size, dtype, device)?.set_variable();
        Ok(RMSNorm {
            weight,
            normalized_size,
            eps,
        })
    }

    /// Size of the normalized (last) dimension.
    pub fn normalized_size(&self) -> usize {
        self.normalized_size
    }

    /// Numerical-stability epsilon added to mean(x²).
    pub fn eps(&self) -> f64 {
        self.eps
    }
}

impl<B: Backend> Module<B> for RMSNorm<B> {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let rank = x.rank();
        if rank == 0 {
            return Err(shrew_core::Error::msg(
                "RMSNorm: input must have at least 1 dimension",
            ));
        }
        let last_dim = rank - 1;
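
        // Shape trace for x: [B, T, D]
        //   x_sq:    [B, T, D]
        //   mean_sq: [B, T, 1]   (keepdim over the last axis)
        //   rms:     [B, T, 1]
        //   x_norm:  [B, T, D]   (rms broadcasts across D)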

        // RMS = sqrt(mean(x²) + eps)
        let x_sq = x.square()?;
        let mean_sq = x_sq.mean(last_dim, true)?; // [..., 1]
        let rms = mean_sq.affine(1.0, self.eps)?.sqrt()?; // [..., 1]

        // Normalize
        let x_norm = x.div(&rms)?; // broadcasting

        // Scale by γ
        x_norm.mul(&self.weight)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![self.weight.clone()]
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        vec![("weight".to_string(), self.weight.clone())]
    }
}
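
// A minimal smoke-test sketch, NOT confirmed against the real shrew_core API.
// It assumes: `CpuBackend` lives at `shrew_core::backend::CpuBackend`, its
// `Device` implements `Default`, and `Tensor` exposes `rand` (as in the doc
// example above) plus a `dims()` shape accessor. Adjust these names to the
// actual crate surface before use.
#[cfg(test)]
mod tests {
    use super::*;
    use shrew_core::backend::CpuBackend; // assumed path

    #[test]
    fn forward_preserves_shape() {
        // Assumed: the CPU backend's Device implements Default.
        let dev = <CpuBackend as Backend>::Device::default();
        let rms = RMSNorm::<CpuBackend>::new(8, 1e-5, DType::F64, &dev).unwrap();
        // `rand` mirrors the doc example; assumed available on Tensor<CpuBackend>.
        let x = Tensor::<CpuBackend>::rand((2, 4, 8), DType::F64, &dev).unwrap();
        let y = rms.forward(&x).unwrap();
        // RMSNorm must not change the input shape. `dims()` is an assumed accessor.
        assert_eq!(y.dims(), x.dims());
    }
}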