shrew_nn/transformer.rs
// TransformerBlock: one layer of the Transformer architecture
//
// A TransformerBlock is the fundamental building block of models like
// GPT, BERT, LLaMA, and virtually all modern language models.
//
// ARCHITECTURE (Pre-Norm style, used in GPT-2, LLaMA, etc.):
//
// ┌───────────────────────────────┐
// │ Input: x [batch, seq, d_model]│
// └───────────────┬───────────────┘
//                 │
//        ┌────────┴────────┐
//        │   LayerNorm 1   │
//        │        ↓        │
//        │       MHA       │  ← Multi-Head Self-Attention
//        │        ↓        │
//        └────────┬────────┘
//                 │ + x       ← Residual connection
//                 │
//        ┌────────┴────────┐
//        │   LayerNorm 2   │
//        │        ↓        │
//        │       FFN       │  ← Feed-Forward Network
//        │        ↓        │
//        └────────┬────────┘
//                 │ + x       ← Residual connection
//                 │
// ┌───────────────┴───────────────┐
// │ Output [batch, seq, d_model]  │
// └───────────────────────────────┘
//
// FEED-FORWARD NETWORK (FFN):
//
//   FFN(x) = Linear2(GELU(Linear1(x)))
//   Linear1: d_model → d_ff     (expand, typically 4 × d_model)
//   Linear2: d_ff    → d_model  (compress back)
//
// The FFN processes each position independently: the same transformation is
// applied to every token. It is where much of the model's "knowledge" is stored.
//
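// As a concrete, purely illustrative example (values not taken from this crate),
// with d_model = 512 and d_ff = 4 * 512 = 2048 the shapes through the FFN are:
//
//   x       : [batch, seq, 512]
//   Linear1 : [batch, seq, 512]  → [batch, seq, 2048]
//   GELU    : element-wise, shape unchanged
//   Linear2 : [batch, seq, 2048] → [batch, seq, 512]
//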
// RESIDUAL CONNECTIONS:
//
//   output = x + sublayer(x)
//
// These residual connections are crucial for training deep networks:
// - They provide gradient highways: gradients flow directly through the "+ x"
//   term (see the sketch below).
// - They let the sublayer learn a "correction" to x rather than the full mapping.
// - Without them, transformers deeper than roughly 6 layers become very hard to train.
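//
// A one-line sketch of why the gradient highway exists: differentiating
// output = x + sublayer(x) with respect to x gives
//
//   d(output)/dx = I + d(sublayer)/dx
//
// so even when the sublayer's Jacobian is tiny (vanishing gradients), the
// identity term I still carries gradient straight back to earlier layers.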
//
// WHY PRE-NORM?
//
// Pre-norm (LayerNorm before attention/FFN) is more stable for training than
// post-norm (LayerNorm after). Most modern models use pre-norm.
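//
// In symbols, with Sublayer standing for either attention or the FFN:
//
//   Pre-norm  (this file):            x + Sublayer(LayerNorm(x))
//   Post-norm (original Transformer): LayerNorm(x + Sublayer(x))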

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::attention::MultiHeadAttention;
use crate::layernorm::LayerNorm;
use crate::linear::Linear;
use crate::module::Module;

/// A single Transformer block (pre-norm style).
///
/// Contains:
/// - Multi-head self-attention
/// - Feed-forward network (two linear layers with GELU)
/// - Two layer normalizations
/// - Residual connections around both sub-layers
pub struct TransformerBlock<B: Backend> {
    /// LayerNorm before attention
    ln1: LayerNorm<B>,
    /// Multi-head self-attention
    attn: MultiHeadAttention<B>,
    /// LayerNorm before FFN
    ln2: LayerNorm<B>,
    /// FFN first linear: d_model → d_ff
    ff1: Linear<B>,
    /// FFN second linear: d_ff → d_model
    ff2: Linear<B>,
    /// Model dimension
    d_model: usize,
}

impl<B: Backend> TransformerBlock<B> {
    /// Create a new TransformerBlock.
    ///
    /// # Arguments
    /// - `d_model`: model dimension (embedding size)
    /// - `num_heads`: number of attention heads
    /// - `d_ff`: feed-forward inner dimension (typically `4 * d_model`)
    /// - `causal`: whether to use a causal (autoregressive) attention mask
    /// - `dtype`: data type
    /// - `device`: compute device
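    ///
    /// # Example
    ///
    /// A minimal sketch, not a compiled doctest: the backend/device names below
    /// (`CpuBackend`, `CpuDevice`) are placeholders for whatever concrete
    /// `Backend`/`Device` types your build of shrew provides, and `DType::F32`
    /// is assumed to exist.
    ///
    /// ```ignore
    /// // Hypothetical backend/device types, for illustration only.
    /// let device = CpuDevice::default();
    /// let block: TransformerBlock<CpuBackend> =
    ///     TransformerBlock::new(512, 8, 2048, true, DType::F32, &device)?;
    /// assert_eq!(block.d_model(), 512);
    /// ```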
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        causal: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        let ln1 = LayerNorm::new(d_model, 1e-5, dtype, device)?;
        let attn = MultiHeadAttention::new(d_model, num_heads, dtype, device)?.with_causal(causal);
        let ln2 = LayerNorm::new(d_model, 1e-5, dtype, device)?;
        let ff1 = Linear::new(d_model, d_ff, true, dtype, device)?;
        let ff2 = Linear::new(d_ff, d_model, true, dtype, device)?;

        Ok(TransformerBlock {
            ln1,
            attn,
            ln2,
            ff1,
            ff2,
            d_model,
        })
    }

    /// Returns the model dimension `d_model`.
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Forward pass through the FFN: Linear → GELU → Linear.
    fn ffn(&self, x: &Tensor<B>, batch: usize, seq: usize) -> Result<Tensor<B>> {
        // Reshape [batch, seq, d_model] → [batch*seq, d_model] so Linear sees 2-D input.
        let x_2d = x.reshape((batch * seq, self.d_model))?;
        let h = self.ff1.forward(&x_2d)?.gelu()?;
        // h is [batch*seq, d_ff]; ff2 maps it back to [batch*seq, d_model].
        let out = self.ff2.forward(&h)?;
        // Restore the original 3-D shape.
        out.reshape((batch, seq, self.d_model))
    }
}

impl<B: Backend> Module<B> for TransformerBlock<B> {
    /// Forward pass (pre-norm):
    ///   x = x + attention(layernorm(x))
    ///   x = x + ffn(layernorm(x))
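    ///
    /// The output has the same `[batch, seq, d_model]` shape as the input, which
    /// is what allows blocks to be stacked. A hedged sketch (not a compiled
    /// doctest; `blocks` is assumed to be a `Vec<TransformerBlock<_>>` and `x`
    /// any `[batch, seq, d_model]` tensor):
    ///
    /// ```ignore
    /// // Run a GPT-style stack of blocks over the same-shaped activation.
    /// let mut h = x;
    /// for block in &blocks {
    ///     h = block.forward(&h)?;
    /// }
    /// ```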
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let dims = x.dims();
        if dims.len() != 3 {
            return Err(shrew_core::Error::msg(format!(
                "TransformerBlock expects [batch, seq, d_model], got {:?}",
                dims
            )));
        }
        let batch = dims[0];
        let seq = dims[1];

        // Sub-layer 1: self-attention with residual
        let normed1 = self.ln1.forward(x)?;
        let attn_out = self.attn.forward(&normed1)?;
        let x = x.add(&attn_out)?; // Residual connection

        // Sub-layer 2: FFN with residual
        let normed2 = self.ln2.forward(&x)?;
        let ffn_out = self.ffn(&normed2, batch, seq)?;
        x.add(&ffn_out) // Residual connection
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = Vec::new();
        params.extend(self.ln1.parameters());
        params.extend(self.attn.parameters());
        params.extend(self.ln2.parameters());
        params.extend(self.ff1.parameters());
        params.extend(self.ff2.parameters());
        params
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = Vec::new();
        for (k, v) in self.ln1.named_parameters() {
            named.push((format!("ln1.{k}"), v));
        }
        for (k, v) in self.attn.named_parameters() {
            named.push((format!("attn.{k}"), v));
        }
        for (k, v) in self.ln2.named_parameters() {
            named.push((format!("ln2.{k}"), v));
        }
        for (k, v) in self.ff1.named_parameters() {
            named.push((format!("ff1.{k}"), v));
        }
        for (k, v) in self.ff2.named_parameters() {
            named.push((format!("ff2.{k}"), v));
        }
        named
    }
}