shrew_nn/rmsnorm.rs
// RMSNorm — Root Mean Square Layer Normalization
//
// RMSNorm (Zhang & Sennrich, 2019) is a simplification of LayerNorm that
// removes mean centering: it normalizes only by the root mean square of the
// values, so there is a scale (γ) but no shift (β) parameter.
//
// FORMULA:
//   x_norm = x / sqrt(mean(x²) + ε)
//   y      = γ * x_norm
//
// WHY RMSNorm?
//
// RMSNorm is used in modern LLMs (LLaMA, Mistral, Gemma) because:
//   1. It is computationally simpler than LayerNorm: no mean subtraction
//      and no bias term.
//   2. Its quality is empirically comparable to LayerNorm's.
//   3. The reduced operation count makes it slightly faster in practice.
//
// SHAPES:
//   Input:  [*, normalized_size]
//   Output: same shape as input
//   γ:      [normalized_size]
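//
// WORKED EXAMPLE (illustrative numbers, chosen only to make the arithmetic
// easy to follow):
//   For x = [3.0, 4.0], γ = [1.0, 1.0], ε = 0:
//     mean(x²) = (9 + 16) / 2 = 12.5
//     rms      = sqrt(12.5)   ≈ 3.5355
//     y        = x / rms      ≈ [0.8485, 1.1314]
//   Note that the output mean (≈ 0.99) is not zero: unlike LayerNorm,
//   RMSNorm rescales but does not recenter.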

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// RMS Normalization layer (used in LLaMA, Mistral, etc.).
///
/// Normalizes by root mean square without mean centering:
/// y = x / sqrt(mean(x²) + ε) * γ
///
/// # Example
/// ```ignore
/// let rms = RMSNorm::<CpuBackend>::new(512, 1e-5, DType::F64, &dev)?;
/// let x = CpuTensor::rand((2, 10, 512), DType::F64, &dev)?;
/// let y = rms.forward(&x)?;
/// ```
pub struct RMSNorm<B: Backend> {
    /// Learnable scale parameter γ: [normalized_size].
    weight: Tensor<B>,
    /// Size of the last (normalized) dimension.
    normalized_size: usize,
    /// Small constant added inside the square root for numerical stability.
    eps: f64,
}

impl<B: Backend> RMSNorm<B> {
    /// Create a new RMSNorm layer with γ initialized to ones.
    pub fn new(normalized_size: usize, eps: f64, dtype: DType, device: &B::Device) -> Result<Self> {
        let weight = Tensor::<B>::ones(normalized_size, dtype, device)?.set_variable();
        Ok(RMSNorm {
            weight,
            normalized_size,
            eps,
        })
    }

    /// Size of the normalized (last) dimension.
    pub fn normalized_size(&self) -> usize {
        self.normalized_size
    }

    /// Epsilon used for numerical stability.
    pub fn eps(&self) -> f64 {
        self.eps
    }
}

impl<B: Backend> Module<B> for RMSNorm<B> {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let rank = x.rank();
        if rank == 0 {
            return Err(shrew_core::Error::msg(
                "RMSNorm: input must have at least 1 dimension",
            ));
        }
        let last_dim = rank - 1;

        // RMS = sqrt(mean(x²) + ε), computed over the last dimension
        let x_sq = x.square()?;
        let mean_sq = x_sq.mean(last_dim, true)?; // [..., 1]
        let rms = mean_sq.affine(1.0, self.eps)?.sqrt()?; // [..., 1]

        // Normalize: [..., 1] broadcasts across the last dimension
        let x_norm = x.div(&rms)?;

        // Scale by γ ([normalized_size] broadcasts across leading dims)
        x_norm.mul(&self.weight)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![self.weight.clone()]
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        vec![("weight".to_string(), self.weight.clone())]
    }
}
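
// ---------------------------------------------------------------------------
// Minimal smoke test, included as a sketch rather than a definitive spec:
// the `shrew_core::cpu::{CpuBackend, CpuTensor}` path and the
// default-constructible `Device` are assumptions carried over from the doc
// example above; adjust them to the crate's real exports if they differ.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forward_preserves_shape() {
        // ASSUMPTION: hypothetical CPU backend names taken from the doc
        // example; not verified against the real crate.
        use shrew_core::cpu::{CpuBackend, CpuTensor};

        let dev = <CpuBackend as Backend>::Device::default();
        let rms = RMSNorm::<CpuBackend>::new(512, 1e-5, DType::F64, &dev).unwrap();
        let x = CpuTensor::rand((2, 10, 512), DType::F64, &dev).unwrap();
        let y = rms.forward(&x).unwrap();

        // RMSNorm only rescales along the last dimension, so the output
        // keeps the input's shape; rank() is the shape accessor this file
        // already uses.
        assert_eq!(y.rank(), x.rank());
        assert_eq!(rms.normalized_size(), 512);
    }
}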