shrew_nn/attention.rs
// Multi-Head Attention: the core mechanism of the Transformer
//
// Multi-Head Attention (MHA) allows the model to jointly attend to information
// from different representation subspaces at different positions.
//
// INTUITION:
//
// Imagine you're reading a sentence: "The cat sat on the mat."
// For the word "sat", attention answers: "Which other words are relevant?"
// - The query ("sat") asks: "What am I looking for?"
// - The keys (all words) answer: "What do I contain?"
// - The values (all words) say: "Here's what I provide."
//
// The attention score between query q and key k is: score = q · k / √d_k
// High score = this key is relevant to this query.
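// A tiny worked example (illustrative numbers only): with d_k = 4,
//   q = [1, 0, 1, 0] and k = [1, 1, 0, 0]  →  q · k = 1, score = 1 / √4 = 0.5
// A key pointing in a similar direction to the query gets a larger dot product,
// and therefore a larger share of the attention weights after softmax.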
//
// MULTIPLE HEADS:
//
// Instead of one big attention, we split into h heads:
// 1. Project Q, K, V into h smaller subspaces (d_head = d_model / h)
// 2. Each head computes attention independently
// 3. Concatenate results and project back to d_model
//
// This lets different heads learn different types of relationships:
// - Head 1 might learn syntactic dependencies
// - Head 2 might learn semantic similarity
// - Head 3 might learn positional patterns
//
// MATHEMATICS:
//
// Input: x of shape [batch, seq_len, d_model]
//
// 1. Q = x @ W_Q   [batch, seq, d_model]
//    K = x @ W_K   [batch, seq, d_model]
//    V = x @ W_V   [batch, seq, d_model]
//
// 2. Reshape to heads: [batch, seq, h, d_head] → [batch, h, seq, d_head]
//
// 3. Attention per head:
//      scores  = Q @ K^T / √d_head          [batch, h, seq, seq]
//      weights = softmax(scores, dim=-1)
//      out     = weights @ V                [batch, h, seq, d_head]
//
// 4. Concatenate heads: [batch, seq, h * d_head] = [batch, seq, d_model]
//
// 5. Output projection: out @ W_O   [batch, seq, d_model]
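//
// Concretely, with batch = 2, seq = 10, d_model = 512, h = 8 (so d_head = 64),
// matching the example in the docs below:
//   Q, K, V:  [2, 10, 512] → [2, 8, 10, 64] after the head split
//   scores:   [2, 8, 10, 10]
//   out:      [2, 8, 10, 64] → [2, 10, 512] after concatenation and W_O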
//
// CAUSAL MASK (for autoregressive models like GPT):
//
// To prevent attending to future tokens, we add a mask before softmax:
//   scores[i, j] = -∞ if j > i (position j is after position i)
// After the softmax, the attention weights on those future positions are 0.
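//
// For example, with seq_len = 4 the additive mask is (0 = attend, -∞ = blocked;
// the implementation below substitutes -1e9 for -∞):
//
//   [ 0, -∞, -∞, -∞ ]
//   [ 0,  0, -∞, -∞ ]
//   [ 0,  0,  0, -∞ ]
//   [ 0,  0,  0,  0 ]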

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::linear::Linear;
use crate::module::Module;

/// Multi-Head Self-Attention module.
///
/// # Examples
/// ```ignore
/// let attn = MultiHeadAttention::<CpuBackend>::new(512, 8, DType::F64, &dev)?;
/// let x = CpuTensor::rand((2, 10, 512), DType::F64, &dev)?;
/// let y = attn.forward(&x)?; // [2, 10, 512]
/// ```
pub struct MultiHeadAttention<B: Backend> {
    /// Number of attention heads
    num_heads: usize,
    /// Dimension per head: d_head = d_model / num_heads
    head_dim: usize,
    /// Total model dimension
    d_model: usize,
    /// Query projection: [d_model, d_model]
    w_q: Linear<B>,
    /// Key projection: [d_model, d_model]
    w_k: Linear<B>,
    /// Value projection: [d_model, d_model]
    w_v: Linear<B>,
    /// Output projection: [d_model, d_model]
    w_o: Linear<B>,
    /// Scaling factor: 1/√d_head
    scale: f64,
    /// Whether to apply causal mask (for autoregressive models)
    causal: bool,
}

impl<B: Backend> MultiHeadAttention<B> {
    /// Create a new Multi-Head Attention module.
    ///
    /// # Arguments
    /// - `d_model`: total model dimension (must be divisible by num_heads)
    /// - `num_heads`: number of attention heads
    /// - `dtype`: data type for parameters
    /// - `device`: device to create parameters on
    pub fn new(d_model: usize, num_heads: usize, dtype: DType, device: &B::Device) -> Result<Self> {
        if !d_model.is_multiple_of(num_heads) {
            return Err(shrew_core::Error::msg(format!(
                "d_model ({}) must be divisible by num_heads ({})",
                d_model, num_heads
            )));
        }
        let head_dim = d_model / num_heads;

        let w_q = Linear::new(d_model, d_model, false, dtype, device)?;
        let w_k = Linear::new(d_model, d_model, false, dtype, device)?;
        let w_v = Linear::new(d_model, d_model, false, dtype, device)?;
        let w_o = Linear::new(d_model, d_model, false, dtype, device)?;

        Ok(MultiHeadAttention {
            num_heads,
            head_dim,
            d_model,
            w_q,
            w_k,
            w_v,
            w_o,
            scale: 1.0 / (head_dim as f64).sqrt(),
            causal: false,
        })
    }

    /// Enable causal (autoregressive) masking.
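    ///
    /// Causal masking is disabled by default. A minimal usage sketch, reusing the
    /// hypothetical `CpuBackend` setup from the struct-level example:
    /// ```ignore
    /// let attn = MultiHeadAttention::<CpuBackend>::new(512, 8, DType::F64, &dev)?
    ///     .with_causal(true);
    /// ```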
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

    pub fn num_heads(&self) -> usize {
        self.num_heads
    }

    pub fn d_model(&self) -> usize {
        self.d_model
    }

    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// Reshape [batch, seq, d_model] → [batch, seq, num_heads, head_dim]
    /// then transpose to [batch, num_heads, seq, head_dim]
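    ///
    /// For example, with the 8-head, d_model = 512 configuration from the docs
    /// above: [2, 10, 512] → [2, 8, 10, 64].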
    fn reshape_to_heads(&self, x: &Tensor<B>, batch: usize, seq: usize) -> Result<Tensor<B>> {
        // [batch, seq, d_model] → [batch, seq, num_heads, head_dim]
        let reshaped = x.reshape((batch, seq, self.num_heads, self.head_dim))?;
        // Transpose dims 1 and 2: [batch, num_heads, seq, head_dim]
        reshaped.transpose(1, 2)?.contiguous()
    }

    /// Inverse of reshape_to_heads:
    /// [batch, num_heads, seq, head_dim] → [batch, seq, d_model]
    fn reshape_from_heads(&self, x: &Tensor<B>, batch: usize, seq: usize) -> Result<Tensor<B>> {
        // [batch, num_heads, seq, head_dim] → [batch, seq, num_heads, head_dim]
        let transposed = x.transpose(1, 2)?.contiguous()?;
        // [batch, seq, num_heads * head_dim] = [batch, seq, d_model]
        transposed.reshape((batch, seq, self.d_model))
    }

    /// Create a causal mask: a [seq, seq] matrix that is 0 on and below the
    /// diagonal and a large negative value above it:
    /// mask[i][j] = -1e9 (≈ -∞) if j > i, else 0.
    fn create_causal_mask(
        &self,
        seq_len: usize,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Tensor<B>> {
        let mut mask_data = vec![0.0f64; seq_len * seq_len];
        for i in 0..seq_len {
            for j in (i + 1)..seq_len {
                mask_data[i * seq_len + j] = -1e9; // Large negative ≈ -infinity
            }
        }
        Tensor::<B>::from_f64_slice(&mask_data, (seq_len, seq_len), dtype, device)
    }
}

impl<B: Backend> Module<B> for MultiHeadAttention<B> {
    /// Forward pass: self-attention on input x.
    ///
    /// Input: [batch, seq_len, d_model]
    /// Output: [batch, seq_len, d_model]
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let dims = x.dims();
        if dims.len() != 3 {
            return Err(shrew_core::Error::msg(format!(
                "MultiHeadAttention expects 3D input [batch, seq, d_model], got {:?}",
                dims
            )));
        }
        let batch = dims[0];
        let seq = dims[1];

        // Step 1: Project to Q, K, V
        // Reshape x from [batch, seq, d_model] to [batch*seq, d_model] for Linear
        let x_2d = x.reshape((batch * seq, self.d_model))?;
        let q_2d = self.w_q.forward(&x_2d)?; // [batch*seq, d_model]
        let k_2d = self.w_k.forward(&x_2d)?;
        let v_2d = self.w_v.forward(&x_2d)?;

        // Reshape back to [batch, seq, d_model]
        let q = q_2d.reshape((batch, seq, self.d_model))?;
        let k = k_2d.reshape((batch, seq, self.d_model))?;
        let v = v_2d.reshape((batch, seq, self.d_model))?;

        // Step 2: Split into heads
        // [batch, num_heads, seq, head_dim]
        let q = self.reshape_to_heads(&q, batch, seq)?;
        let k = self.reshape_to_heads(&k, batch, seq)?;
        let v = self.reshape_to_heads(&v, batch, seq)?;

        // Step 3: Compute attention scores
        // scores = Q @ K^T / √d_head
        // Q: [batch, h, seq, d_head], K^T: [batch, h, d_head, seq]
        // scores: [batch, h, seq, seq]
        let k_t = k.transpose(2, 3)?.contiguous()?;
        let scores = q.matmul(&k_t)?.affine(self.scale, 0.0)?;

        // Step 4: Apply causal mask (optional)
        let scores = if self.causal {
            let mask = self.create_causal_mask(seq, x.dtype(), x.device())?;
            // mask is [seq, seq], scores are [batch, h, seq, seq];
            // broadcasting handles the batch and head dimensions.
            scores.add(&mask)?
        } else {
            scores
        };

        // Step 5: Softmax over last dimension (key positions)
        let attn_weights = scores.softmax(3)?; // [batch, h, seq, seq]

        // Step 6: Weighted sum of values
        // [batch, h, seq, seq] @ [batch, h, seq, d_head] = [batch, h, seq, d_head]
        let attn_output = attn_weights.matmul(&v)?;

        // Step 7: Concatenate heads
        // [batch, h, seq, d_head] → [batch, seq, d_model]
        let concat = self.reshape_from_heads(&attn_output, batch, seq)?;

        // Step 8: Output projection
        let concat_2d = concat.reshape((batch * seq, self.d_model))?;
        let output_2d = self.w_o.forward(&concat_2d)?;
        output_2d.reshape((batch, seq, self.d_model))
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = Vec::new();
        params.extend(self.w_q.parameters());
        params.extend(self.w_k.parameters());
        params.extend(self.w_v.parameters());
        params.extend(self.w_o.parameters());
        params
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = Vec::new();
        for (k, v) in self.w_q.named_parameters() {
            named.push((format!("w_q.{k}"), v));
        }
        for (k, v) in self.w_k.named_parameters() {
            named.push((format!("w_k.{k}"), v));
        }
        for (k, v) in self.w_v.named_parameters() {
            named.push((format!("w_v.{k}"), v));
        }
        for (k, v) in self.w_o.named_parameters() {
            named.push((format!("w_o.{k}"), v));
        }
        named
    }
}