shrew_nn/
attention.rs

1// Multi-Head Attention — The core mechanism of the Transformer
2//
3// Multi-Head Attention (MHA) allows the model to jointly attend to information
4// from different representation subspaces at different positions.
5//
6// INTUITION:
7//
8// Imagine you're reading a sentence: "The cat sat on the mat."
9// For the word "sat", attention answers: "Which other words are relevant?"
10//   - The query ("sat") asks: "What am I looking for?"
11//   - The keys (all words) answer: "What do I contain?"
12//   - The values (all words) say: "Here's what I provide."
13//
14// The attention score between query q and key k is: score = q · k / √d_k
15// High score = this key is relevant to this query.
16//
17// MULTIPLE HEADS:
18//
19// Instead of one big attention, we split into h heads:
20//   1. Project Q, K, V into h smaller subspaces (d_head = d_model / h)
21//   2. Each head computes attention independently
22//   3. Concatenate results and project back to d_model
23//
24// This lets different heads learn different types of relationships:
25//   - Head 1 might learn syntactic dependencies
26//   - Head 2 might learn semantic similarity
27//   - Head 3 might learn positional patterns
28//
29// MATHEMATICS:
30//
31//   Input: x of shape [batch, seq_len, d_model]
32//
33//   1. Q = x @ W_Q    [batch, seq, d_model]
34//      K = x @ W_K    [batch, seq, d_model]
35//      V = x @ W_V    [batch, seq, d_model]
36//
37//   2. Reshape to heads: [batch, seq, h, d_head] → [batch, h, seq, d_head]
38//
39//   3. Attention per head:
40//      scores = Q @ K^T / √d_head     [batch, h, seq, seq]
41//      weights = softmax(scores, dim=-1)
42//      out = weights @ V               [batch, h, seq, d_head]
43//
44//   4. Concatenate heads: [batch, seq, h * d_head] = [batch, seq, d_model]
45//
46//   5. Output projection: out @ W_O    [batch, seq, d_model]
47//
48// CAUSAL MASK (for autoregressive models like GPT):
49//
50//   To prevent attending to future tokens, we add a mask before softmax:
51//   scores[i, j] = -∞  if j > i  (position j is after position i)
52//   This makes softmax(scores[i, j]) = 0 for future positions.
53
54use shrew_core::backend::Backend;
55use shrew_core::dtype::DType;
56use shrew_core::error::Result;
57use shrew_core::tensor::Tensor;
58
59use crate::linear::Linear;
60use crate::module::Module;
61
/// Multi-Head Self-Attention module.
///
/// Holds four bias-free square projections (Q, K, V, output) plus the
/// pre-computed softmax scale; see the module-level docs for the math.
/// All four projections are [d_model, d_model]; heads are produced by
/// reshaping, not by separate per-head weight matrices.
///
/// # Examples
/// ```ignore
/// let attn = MultiHeadAttention::<CpuBackend>::new(512, 8, DType::F64, &dev)?;
/// let x = CpuTensor::rand((2, 10, 512), DType::F64, &dev)?;
/// let y = attn.forward(&x)?; // [2, 10, 512]
/// ```
pub struct MultiHeadAttention<B: Backend> {
    /// Number of attention heads
    num_heads: usize,
    /// Dimension per head: d_head = d_model / num_heads
    head_dim: usize,
    /// Total model dimension (num_heads * head_dim)
    d_model: usize,
    /// Query projection: [d_model, d_model]
    w_q: Linear<B>,
    /// Key projection: [d_model, d_model]
    w_k: Linear<B>,
    /// Value projection: [d_model, d_model]
    w_v: Linear<B>,
    /// Output projection: [d_model, d_model]
    w_o: Linear<B>,
    /// Scaling factor: 1/√d_head, applied to raw attention scores
    scale: f64,
    /// Whether to apply causal mask (for autoregressive models).
    /// Defaults to false; enabled via `with_causal`.
    causal: bool,
}
90
91impl<B: Backend> MultiHeadAttention<B> {
92    /// Create a new Multi-Head Attention module.
93    ///
94    /// # Arguments
95    /// - `d_model`: total model dimension (must be divisible by num_heads)
96    /// - `num_heads`: number of attention heads
97    /// - `dtype`: data type for parameters
98    /// - `device`: device to create parameters on
99    pub fn new(d_model: usize, num_heads: usize, dtype: DType, device: &B::Device) -> Result<Self> {
100        #[allow(clippy::manual_is_multiple_of)]
101        if d_model % num_heads != 0 {
102            return Err(shrew_core::Error::msg(format!(
103                "d_model ({}) must be divisible by num_heads ({})",
104                d_model, num_heads
105            )));
106        }
107        let head_dim = d_model / num_heads;
108
109        let w_q = Linear::new(d_model, d_model, false, dtype, device)?;
110        let w_k = Linear::new(d_model, d_model, false, dtype, device)?;
111        let w_v = Linear::new(d_model, d_model, false, dtype, device)?;
112        let w_o = Linear::new(d_model, d_model, false, dtype, device)?;
113
114        Ok(MultiHeadAttention {
115            num_heads,
116            head_dim,
117            d_model,
118            w_q,
119            w_k,
120            w_v,
121            w_o,
122            scale: 1.0 / (head_dim as f64).sqrt(),
123            causal: false,
124        })
125    }
126
127    /// Enable causal (autoregressive) masking.
128    pub fn with_causal(mut self, causal: bool) -> Self {
129        self.causal = causal;
130        self
131    }
132
133    pub fn num_heads(&self) -> usize {
134        self.num_heads
135    }
136
137    pub fn d_model(&self) -> usize {
138        self.d_model
139    }
140
141    pub fn head_dim(&self) -> usize {
142        self.head_dim
143    }
144
145    /// Reshape [batch, seq, d_model] → [batch, seq, num_heads, head_dim]
146    /// then transpose to [batch, num_heads, seq, head_dim]
147    fn reshape_to_heads(&self, x: &Tensor<B>, batch: usize, seq: usize) -> Result<Tensor<B>> {
148        // [batch, seq, d_model] → [batch, seq, num_heads, head_dim]
149        let reshaped = x.reshape((batch, seq, self.num_heads, self.head_dim))?;
150        // Transpose dims 1 and 2: [batch, num_heads, seq, head_dim]
151        reshaped.transpose(1, 2)?.contiguous()
152    }
153
154    /// Inverse of reshape_to_heads:
155    /// [batch, num_heads, seq, head_dim] → [batch, seq, d_model]
156    fn reshape_from_heads(&self, x: &Tensor<B>, batch: usize, seq: usize) -> Result<Tensor<B>> {
157        // [batch, num_heads, seq, head_dim] → [batch, seq, num_heads, head_dim]
158        let transposed = x.transpose(1, 2)?.contiguous()?;
159        // [batch, seq, num_heads * head_dim] = [batch, seq, d_model]
160        transposed.reshape((batch, seq, self.d_model))
161    }
162
163    /// Create a causal mask: upper-triangular matrix of -infinity.
164    /// Shape: [seq, seq] where mask[i][j] = -1e9 if j > i, else 0
165    fn create_causal_mask(
166        &self,
167        seq_len: usize,
168        dtype: DType,
169        device: &B::Device,
170    ) -> Result<Tensor<B>> {
171        let mut mask_data = vec![0.0f64; seq_len * seq_len];
172        for i in 0..seq_len {
173            for j in (i + 1)..seq_len {
174                mask_data[i * seq_len + j] = -1e9; // Large negative ≈ -infinity
175            }
176        }
177        Tensor::<B>::from_f64_slice(&mask_data, (seq_len, seq_len), dtype, device)
178    }
179}
180
181impl<B: Backend> Module<B> for MultiHeadAttention<B> {
182    /// Forward pass: self-attention on input x.
183    ///
184    /// Input:  [batch, seq_len, d_model]
185    /// Output: [batch, seq_len, d_model]
186    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
187        let dims = x.dims();
188        if dims.len() != 3 {
189            return Err(shrew_core::Error::msg(format!(
190                "MultiHeadAttention expects 3D input [batch, seq, d_model], got {:?}",
191                dims
192            )));
193        }
194        let batch = dims[0];
195        let seq = dims[1];
196
197        // Step 1: Project to Q, K, V
198        // Reshape x from [batch, seq, d_model] to [batch*seq, d_model] for Linear
199        let x_2d = x.reshape((batch * seq, self.d_model))?;
200        let q_2d = self.w_q.forward(&x_2d)?; // [batch*seq, d_model]
201        let k_2d = self.w_k.forward(&x_2d)?;
202        let v_2d = self.w_v.forward(&x_2d)?;
203
204        // Reshape back to [batch, seq, d_model]
205        let q = q_2d.reshape((batch, seq, self.d_model))?;
206        let k = k_2d.reshape((batch, seq, self.d_model))?;
207        let v = v_2d.reshape((batch, seq, self.d_model))?;
208
209        // Step 2: Split into heads
210        // [batch, num_heads, seq, head_dim]
211        let q = self.reshape_to_heads(&q, batch, seq)?;
212        let k = self.reshape_to_heads(&k, batch, seq)?;
213        let v = self.reshape_to_heads(&v, batch, seq)?;
214
215        // Step 3: Compute attention scores
216        // scores = Q @ K^T / √d_k
217        // Q: [batch, h, seq, d_head], K^T: [batch, h, d_head, seq]
218        // scores: [batch, h, seq, seq]
219        let k_t = k.transpose(2, 3)?.contiguous()?;
220        let scores = q.matmul(&k_t)?.affine(self.scale, 0.0)?;
221
222        // Step 4: Apply causal mask (optional)
223        let scores = if self.causal {
224            let mask = self.create_causal_mask(seq, x.dtype(), x.device())?;
225            // mask is [seq, seq], scores are [batch, h, seq, seq]
226            // Broadcasting handles the batch and head dimensions
227            scores.add(&mask)?
228        } else {
229            scores
230        };
231
232        // Step 5: Softmax over last dimension (key positions)
233        let attn_weights = scores.softmax(3)?; // [batch, h, seq, seq]
234
235        // Step 6: Weighted sum of values
236        // [batch, h, seq, seq] @ [batch, h, seq, d_head] = [batch, h, seq, d_head]
237        let attn_output = attn_weights.matmul(&v)?;
238
239        // Step 7: Concatenate heads
240        // [batch, h, seq, d_head] → [batch, seq, d_model]
241        let concat = self.reshape_from_heads(&attn_output, batch, seq)?;
242
243        // Step 8: Output projection
244        let concat_2d = concat.reshape((batch * seq, self.d_model))?;
245        let output_2d = self.w_o.forward(&concat_2d)?;
246        output_2d.reshape((batch, seq, self.d_model))
247    }
248
249    fn parameters(&self) -> Vec<Tensor<B>> {
250        let mut params = Vec::new();
251        params.extend(self.w_q.parameters());
252        params.extend(self.w_k.parameters());
253        params.extend(self.w_v.parameters());
254        params.extend(self.w_o.parameters());
255        params
256    }
257
258    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
259        let mut named = Vec::new();
260        for (k, v) in self.w_q.named_parameters() {
261            named.push((format!("w_q.{k}"), v));
262        }
263        for (k, v) in self.w_k.named_parameters() {
264            named.push((format!("w_k.{k}"), v));
265        }
266        for (k, v) in self.w_v.named_parameters() {
267            named.push((format!("w_v.{k}"), v));
268        }
269        for (k, v) in self.w_o.named_parameters() {
270            named.push((format!("w_o.{k}"), v));
271        }
272        named
273    }
274}