shrew_nn/transformer.rs

// TransformerBlock — One layer of the Transformer architecture
//
// A TransformerBlock is the fundamental building block of models like
// GPT, BERT, LLaMA, and virtually all modern language models.
//
// ARCHITECTURE (Pre-Norm style, used in GPT-2, LLaMA, etc.):
//
//   ┌───────────────────────────────┐
//   │ Input: x [batch, seq, d_model]│
//   └───────────────┬───────────────┘
//                   │
//          ┌────────┴────────┐
//          │   LayerNorm 1   │
//          │   ↓             │
//          │   MHA           │  ← Multi-Head Self-Attention
//          │   ↓             │
//          └────────┬────────┘
//                   │ + x        ← Residual connection
//                   │
//          ┌────────┴────────┐
//          │   LayerNorm 2   │
//          │   ↓             │
//          │   FFN           │  ← Feed-Forward Network
//          │   ↓             │
//          └────────┬────────┘
//                   │ + x        ← Residual connection
//                   │
//   ┌───────────────┴───────────────┐
//   │ Output [batch, seq, d_model]  │
//   └───────────────────────────────┘
//
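// Equivalently, as two update equations (LN = LayerNorm):
//
//   h   = x + MHA(LN1(x))
//   out = h + FFN(LN2(h))
//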
// FEED-FORWARD NETWORK (FFN):
//
//   FFN(x) = Linear2(GELU(Linear1(x)))
//   Linear1: d_model → d_ff (expand, typically 4×d_model)
//   Linear2: d_ff → d_model (compress back)
//
// The FFN processes each position independently (the same transformation is
// applied to every token). It is where much of the model's "knowledge" is
// thought to be stored.
//
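// For example, with illustrative sizes d_model = 768 and d_ff = 3072 (values
// chosen here only to make the shapes concrete):
//
//   x             [batch, seq, 768]
//   Linear1(x)    [batch, seq, 3072]   (expand 4×)
//   GELU(·)       [batch, seq, 3072]   (elementwise)
//   Linear2(·)    [batch, seq, 768]    (project back to d_model)
//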
// RESIDUAL CONNECTIONS:
//
//   output = x + sublayer(x)
//
// These residual connections are crucial for training deep networks:
// - They provide gradient highways (gradients flow directly through the + x;
//   see the derivative below)
// - They allow the sublayer to learn a "correction" rather than the full mapping
// - Without them, transformers with >6 layers would be very hard to train
//
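// Concretely, differentiating output = x + sublayer(x) with respect to x gives
//
//   d(output)/dx = I + d(sublayer)/dx
//
// The identity term lets gradients reach earlier layers undiminished even when
// the sublayer's Jacobian is small or poorly conditioned.
//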
// WHY PRE-NORM?
//
// Pre-norm (LayerNorm before attention/FFN) is more stable for training than
// post-norm (LayerNorm after). Most modern models use pre-norm.
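//
// A rough sketch of the difference (sublayer = attention or FFN):
//
//   Pre-norm  (this file):                    x = x + sublayer(LayerNorm(x))
//   Post-norm (original Transformer, BERT):   x = LayerNorm(x + sublayer(x))
//
// With pre-norm the residual path is never passed through a LayerNorm, which
// tends to keep gradients well-behaved as depth grows.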

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::attention::MultiHeadAttention;
use crate::layernorm::LayerNorm;
use crate::linear::Linear;
use crate::module::Module;

/// A single Transformer block (pre-norm style).
///
/// Contains:
/// - Multi-head self-attention
/// - Feed-forward network (two linear layers with GELU)
/// - Two layer normalizations
/// - Residual connections around both sub-layers
pub struct TransformerBlock<B: Backend> {
    /// LayerNorm before attention
    ln1: LayerNorm<B>,
    /// Multi-head self-attention
    attn: MultiHeadAttention<B>,
    /// LayerNorm before FFN
    ln2: LayerNorm<B>,
    /// FFN first linear: d_model → d_ff
    ff1: Linear<B>,
    /// FFN second linear: d_ff → d_model
    ff2: Linear<B>,
    /// Model dimension
    d_model: usize,
}

impl<B: Backend> TransformerBlock<B> {
    /// Create a new TransformerBlock.
    ///
    /// # Arguments
    /// - `d_model`: model dimension (embedding size)
    /// - `num_heads`: number of attention heads
    /// - `d_ff`: feed-forward inner dimension (typically 4 * d_model)
    /// - `causal`: whether to use a causal (autoregressive) attention mask
    /// - `dtype`: data type
    /// - `device`: compute device
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        causal: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        let ln1 = LayerNorm::new(d_model, 1e-5, dtype, device)?;
        let attn = MultiHeadAttention::new(d_model, num_heads, dtype, device)?.with_causal(causal);
        let ln2 = LayerNorm::new(d_model, 1e-5, dtype, device)?;
        let ff1 = Linear::new(d_model, d_ff, true, dtype, device)?;
        let ff2 = Linear::new(d_ff, d_model, true, dtype, device)?;

        Ok(TransformerBlock {
            ln1,
            attn,
            ln2,
            ff1,
            ff2,
            d_model,
        })
    }

    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Forward pass through the FFN: Linear → GELU → Linear
    fn ffn(&self, x: &Tensor<B>, batch: usize, seq: usize) -> Result<Tensor<B>> {
        // Reshape [batch, seq, d_model] → [batch*seq, d_model] for Linear
        let x_2d = x.reshape((batch * seq, self.d_model))?;
        let h = self.ff1.forward(&x_2d)?.gelu()?;
        // h is [batch*seq, d_ff]; ff2 projects it back to [batch*seq, d_model]
        let out = self.ff2.forward(&h)?;
        // Restore the [batch, seq, d_model] shape
        out.reshape((batch, seq, self.d_model))
    }
}

impl<B: Backend> Module<B> for TransformerBlock<B> {
    /// Forward pass (pre-norm):
    ///   x = x + attention(layernorm(x))
    ///   x = x + ffn(layernorm(x))
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let dims = x.dims();
        if dims.len() != 3 {
            return Err(shrew_core::Error::msg(format!(
                "TransformerBlock expects [batch, seq, d_model], got {:?}",
                dims
            )));
        }
        let batch = dims[0];
        let seq = dims[1];

        // Sub-layer 1: Self-attention with residual
        let normed1 = self.ln1.forward(x)?;
        let attn_out = self.attn.forward(&normed1)?;
        let x = x.add(&attn_out)?; // Residual connection

        // Sub-layer 2: FFN with residual
        let normed2 = self.ln2.forward(&x)?;
        let ffn_out = self.ffn(&normed2, batch, seq)?;
        x.add(&ffn_out) // Residual connection
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = Vec::new();
        params.extend(self.ln1.parameters());
        params.extend(self.attn.parameters());
        params.extend(self.ln2.parameters());
        params.extend(self.ff1.parameters());
        params.extend(self.ff2.parameters());
        params
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = Vec::new();
        for (k, v) in self.ln1.named_parameters() {
            named.push((format!("ln1.{k}"), v));
        }
        for (k, v) in self.attn.named_parameters() {
            named.push((format!("attn.{k}"), v));
        }
        for (k, v) in self.ln2.named_parameters() {
            named.push((format!("ln2.{k}"), v));
        }
        for (k, v) in self.ff1.named_parameters() {
            named.push((format!("ff1.{k}"), v));
        }
        for (k, v) in self.ff2.named_parameters() {
            named.push((format!("ff2.{k}"), v));
        }
        named
    }
}
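
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the crate's public surface):
// stacking several blocks into a small pre-norm trunk. It relies only on items
// already used in this file (`TransformerBlock::new`, `Module::forward`,
// `Tensor`, `DType`, `B::Device`); the caller supplies the input tensor, since
// tensor construction is backend-specific. The function name and the
// hyperparameters (4 blocks, d_model = 128, 4 heads) are arbitrary example
// values, not anything defined elsewhere in the crate.
#[allow(dead_code)]
fn example_stack_forward<B: Backend>(
    x: &Tensor<B>, // expected shape: [batch, seq, 128]
    dtype: DType,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let d_model = 128;
    let num_heads = 4;
    let num_blocks = 4;

    // Build a small stack of causal, pre-norm blocks.
    let mut blocks = Vec::with_capacity(num_blocks);
    for _ in 0..num_blocks {
        blocks.push(TransformerBlock::new(
            d_model,
            num_heads,
            4 * d_model, // d_ff = 4 × d_model, the usual expansion factor
            true,        // causal (autoregressive) attention
            dtype,
            device,
        )?);
    }

    // Each block maps [batch, seq, d_model] → [batch, seq, d_model],
    // so the blocks compose directly.
    let mut h = blocks[0].forward(x)?;
    for block in &blocks[1..] {
        h = block.forward(&h)?;
    }
    Ok(h)
}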