// shrew_nn/attention.rs
//
// Multi-Head Attention — The core mechanism of the Transformer
//
// Multi-Head Attention (MHA) allows the model to jointly attend to information
// from different representation subspaces at different positions.
//
// INTUITION:
//
// Imagine you're reading a sentence: "The cat sat on the mat."
// For the word "sat", attention answers: "Which other words are relevant?"
// - The query ("sat") asks: "What am I looking for?"
// - The keys (all words) answer: "What do I contain?"
// - The values (all words) say: "Here's what I provide."
//
// The attention score between query q and key k is: score = q · k / √d_k
// High score = this key is relevant to this query.
//
// MULTIPLE HEADS:
//
// Instead of one big attention, we split into h heads:
// 1. Project Q, K, V into h smaller subspaces (d_head = d_model / h)
// 2. Each head computes attention independently
// 3. Concatenate results and project back to d_model
//
// This lets different heads learn different types of relationships:
// - Head 1 might learn syntactic dependencies
// - Head 2 might learn semantic similarity
// - Head 3 might learn positional patterns
//
// MATHEMATICS:
//
// Input: x of shape [batch, seq_len, d_model]
//
// 1. Q = x @ W_Q    [batch, seq, d_model]
//    K = x @ W_K    [batch, seq, d_model]
//    V = x @ W_V    [batch, seq, d_model]
//
// 2. Reshape to heads: [batch, seq, h, d_head] → [batch, h, seq, d_head]
//
// 3. Attention per head:
//    scores  = Q @ K^T / √d_head       [batch, h, seq, seq]
//    weights = softmax(scores, dim=-1)
//    out     = weights @ V             [batch, h, seq, d_head]
//
// 4. Concatenate heads: [batch, seq, h * d_head] = [batch, seq, d_model]
//
// 5. Output projection: out @ W_O     [batch, seq, d_model]
//
// CAUSAL MASK (for autoregressive models like GPT):
//
// To prevent attending to future tokens, we add a mask before softmax:
//   scores[i, j] = -∞ if j > i   (position j is after position i)
// This makes softmax(scores[i, j]) = 0 for future positions.
54use shrew_core::backend::Backend;
55use shrew_core::dtype::DType;
56use shrew_core::error::Result;
57use shrew_core::tensor::Tensor;
58
59use crate::linear::Linear;
60use crate::module::Module;
61
/// Multi-Head Self-Attention module.
///
/// Projects the input to queries, keys, and values, computes scaled
/// dot-product attention independently in `num_heads` subspaces, then
/// concatenates the heads and applies a final output projection.
/// Causal masking is off by default; enable it with [`Self::with_causal`].
///
/// # Examples
/// ```ignore
/// let attn = MultiHeadAttention::<CpuBackend>::new(512, 8, DType::F64, &dev)?;
/// let x = CpuTensor::rand((2, 10, 512), DType::F64, &dev)?;
/// let y = attn.forward(&x)?; // [2, 10, 512]
/// ```
pub struct MultiHeadAttention<B: Backend> {
    /// Number of attention heads.
    num_heads: usize,
    /// Dimension per head: `d_head = d_model / num_heads`.
    head_dim: usize,
    /// Total model dimension (`num_heads * head_dim`).
    d_model: usize,
    /// Query projection, shape [d_model, d_model].
    w_q: Linear<B>,
    /// Key projection, shape [d_model, d_model].
    w_k: Linear<B>,
    /// Value projection, shape [d_model, d_model].
    w_v: Linear<B>,
    /// Output projection, shape [d_model, d_model].
    w_o: Linear<B>,
    /// Pre-computed scaling factor `1/√d_head`, applied to raw attention scores.
    scale: f64,
    /// Whether to add a causal (autoregressive) mask to the scores before softmax.
    causal: bool,
}
90
91impl<B: Backend> MultiHeadAttention<B> {
92 /// Create a new Multi-Head Attention module.
93 ///
94 /// # Arguments
95 /// - `d_model`: total model dimension (must be divisible by num_heads)
96 /// - `num_heads`: number of attention heads
97 /// - `dtype`: data type for parameters
98 /// - `device`: device to create parameters on
99 pub fn new(d_model: usize, num_heads: usize, dtype: DType, device: &B::Device) -> Result<Self> {
100 #[allow(clippy::manual_is_multiple_of)]
101 if d_model % num_heads != 0 {
102 return Err(shrew_core::Error::msg(format!(
103 "d_model ({}) must be divisible by num_heads ({})",
104 d_model, num_heads
105 )));
106 }
107 let head_dim = d_model / num_heads;
108
109 let w_q = Linear::new(d_model, d_model, false, dtype, device)?;
110 let w_k = Linear::new(d_model, d_model, false, dtype, device)?;
111 let w_v = Linear::new(d_model, d_model, false, dtype, device)?;
112 let w_o = Linear::new(d_model, d_model, false, dtype, device)?;
113
114 Ok(MultiHeadAttention {
115 num_heads,
116 head_dim,
117 d_model,
118 w_q,
119 w_k,
120 w_v,
121 w_o,
122 scale: 1.0 / (head_dim as f64).sqrt(),
123 causal: false,
124 })
125 }
126
127 /// Enable causal (autoregressive) masking.
128 pub fn with_causal(mut self, causal: bool) -> Self {
129 self.causal = causal;
130 self
131 }
132
133 pub fn num_heads(&self) -> usize {
134 self.num_heads
135 }
136
137 pub fn d_model(&self) -> usize {
138 self.d_model
139 }
140
141 pub fn head_dim(&self) -> usize {
142 self.head_dim
143 }
144
145 /// Reshape [batch, seq, d_model] → [batch, seq, num_heads, head_dim]
146 /// then transpose to [batch, num_heads, seq, head_dim]
147 fn reshape_to_heads(&self, x: &Tensor<B>, batch: usize, seq: usize) -> Result<Tensor<B>> {
148 // [batch, seq, d_model] → [batch, seq, num_heads, head_dim]
149 let reshaped = x.reshape((batch, seq, self.num_heads, self.head_dim))?;
150 // Transpose dims 1 and 2: [batch, num_heads, seq, head_dim]
151 reshaped.transpose(1, 2)?.contiguous()
152 }
153
154 /// Inverse of reshape_to_heads:
155 /// [batch, num_heads, seq, head_dim] → [batch, seq, d_model]
156 fn reshape_from_heads(&self, x: &Tensor<B>, batch: usize, seq: usize) -> Result<Tensor<B>> {
157 // [batch, num_heads, seq, head_dim] → [batch, seq, num_heads, head_dim]
158 let transposed = x.transpose(1, 2)?.contiguous()?;
159 // [batch, seq, num_heads * head_dim] = [batch, seq, d_model]
160 transposed.reshape((batch, seq, self.d_model))
161 }
162
163 /// Create a causal mask: upper-triangular matrix of -infinity.
164 /// Shape: [seq, seq] where mask[i][j] = -1e9 if j > i, else 0
165 fn create_causal_mask(
166 &self,
167 seq_len: usize,
168 dtype: DType,
169 device: &B::Device,
170 ) -> Result<Tensor<B>> {
171 let mut mask_data = vec![0.0f64; seq_len * seq_len];
172 for i in 0..seq_len {
173 for j in (i + 1)..seq_len {
174 mask_data[i * seq_len + j] = -1e9; // Large negative ≈ -infinity
175 }
176 }
177 Tensor::<B>::from_f64_slice(&mask_data, (seq_len, seq_len), dtype, device)
178 }
179}
180
181impl<B: Backend> Module<B> for MultiHeadAttention<B> {
182 /// Forward pass: self-attention on input x.
183 ///
184 /// Input: [batch, seq_len, d_model]
185 /// Output: [batch, seq_len, d_model]
186 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
187 let dims = x.dims();
188 if dims.len() != 3 {
189 return Err(shrew_core::Error::msg(format!(
190 "MultiHeadAttention expects 3D input [batch, seq, d_model], got {:?}",
191 dims
192 )));
193 }
194 let batch = dims[0];
195 let seq = dims[1];
196
197 // Step 1: Project to Q, K, V
198 // Reshape x from [batch, seq, d_model] to [batch*seq, d_model] for Linear
199 let x_2d = x.reshape((batch * seq, self.d_model))?;
200 let q_2d = self.w_q.forward(&x_2d)?; // [batch*seq, d_model]
201 let k_2d = self.w_k.forward(&x_2d)?;
202 let v_2d = self.w_v.forward(&x_2d)?;
203
204 // Reshape back to [batch, seq, d_model]
205 let q = q_2d.reshape((batch, seq, self.d_model))?;
206 let k = k_2d.reshape((batch, seq, self.d_model))?;
207 let v = v_2d.reshape((batch, seq, self.d_model))?;
208
209 // Step 2: Split into heads
210 // [batch, num_heads, seq, head_dim]
211 let q = self.reshape_to_heads(&q, batch, seq)?;
212 let k = self.reshape_to_heads(&k, batch, seq)?;
213 let v = self.reshape_to_heads(&v, batch, seq)?;
214
215 // Step 3: Compute attention scores
216 // scores = Q @ K^T / √d_k
217 // Q: [batch, h, seq, d_head], K^T: [batch, h, d_head, seq]
218 // scores: [batch, h, seq, seq]
219 let k_t = k.transpose(2, 3)?.contiguous()?;
220 let scores = q.matmul(&k_t)?.affine(self.scale, 0.0)?;
221
222 // Step 4: Apply causal mask (optional)
223 let scores = if self.causal {
224 let mask = self.create_causal_mask(seq, x.dtype(), x.device())?;
225 // mask is [seq, seq], scores are [batch, h, seq, seq]
226 // Broadcasting handles the batch and head dimensions
227 scores.add(&mask)?
228 } else {
229 scores
230 };
231
232 // Step 5: Softmax over last dimension (key positions)
233 let attn_weights = scores.softmax(3)?; // [batch, h, seq, seq]
234
235 // Step 6: Weighted sum of values
236 // [batch, h, seq, seq] @ [batch, h, seq, d_head] = [batch, h, seq, d_head]
237 let attn_output = attn_weights.matmul(&v)?;
238
239 // Step 7: Concatenate heads
240 // [batch, h, seq, d_head] → [batch, seq, d_model]
241 let concat = self.reshape_from_heads(&attn_output, batch, seq)?;
242
243 // Step 8: Output projection
244 let concat_2d = concat.reshape((batch * seq, self.d_model))?;
245 let output_2d = self.w_o.forward(&concat_2d)?;
246 output_2d.reshape((batch, seq, self.d_model))
247 }
248
249 fn parameters(&self) -> Vec<Tensor<B>> {
250 let mut params = Vec::new();
251 params.extend(self.w_q.parameters());
252 params.extend(self.w_k.parameters());
253 params.extend(self.w_v.parameters());
254 params.extend(self.w_o.parameters());
255 params
256 }
257
258 fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
259 let mut named = Vec::new();
260 for (k, v) in self.w_q.named_parameters() {
261 named.push((format!("w_q.{k}"), v));
262 }
263 for (k, v) in self.w_k.named_parameters() {
264 named.push((format!("w_k.{k}"), v));
265 }
266 for (k, v) in self.w_v.named_parameters() {
267 named.push((format!("w_v.{k}"), v));
268 }
269 for (k, v) in self.w_o.named_parameters() {
270 named.push((format!("w_o.{k}"), v));
271 }
272 named
273 }
274}