shrew_nn/attention.rs

// Multi-Head Attention — The core mechanism of the Transformer
//
// Multi-Head Attention (MHA) allows the model to jointly attend to information
// from different representation subspaces at different positions.
//
// INTUITION:
//
// Imagine you're reading a sentence: "The cat sat on the mat."
// For the word "sat", attention answers: "Which other words are relevant?"
//   - The query ("sat") asks: "What am I looking for?"
//   - The keys (all words) answer: "What do I contain?"
//   - The values (all words) say: "Here's what I provide."
//
// The attention score between query q and key k is: score = q · k / √d_k
// High score = this key is relevant to this query.
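//
// A tiny worked example (made-up numbers, d_k = 2):
//   q = [1, 0],  k1 = [1, 0],  k2 = [0, 1]
//   scores  = [q·k1, q·k2] / √2 ≈ [0.71, 0.00]
//   weights = softmax([0.71, 0.00]) ≈ [0.67, 0.33]
//   output  ≈ 0.67·v1 + 0.33·v2   (mostly v1, because k1 points along q)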
//
// MULTIPLE HEADS:
//
// Instead of one big attention, we split into h heads:
//   1. Project Q, K, V into h smaller subspaces (d_head = d_model / h)
//   2. Each head computes attention independently
//   3. Concatenate results and project back to d_model
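//
// For example, with d_model = 512 and h = 8 heads (the sizes used in the doc
// example further down), each head works in d_head = 512 / 8 = 64 dimensions.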
//
// This lets different heads learn different types of relationships:
//   - Head 1 might learn syntactic dependencies
//   - Head 2 might learn semantic similarity
//   - Head 3 might learn positional patterns
//
// MATHEMATICS:
//
//   Input: x of shape [batch, seq_len, d_model]
//
//   1. Q = x @ W_Q    [batch, seq, d_model]
//      K = x @ W_K    [batch, seq, d_model]
//      V = x @ W_V    [batch, seq, d_model]
//
//   2. Reshape to heads: [batch, seq, h, d_head] → [batch, h, seq, d_head]
//
//   3. Attention per head:
//      scores = Q @ K^T / √d_head     [batch, h, seq, seq]
//      weights = softmax(scores, dim=-1)
//      out = weights @ V               [batch, h, seq, d_head]
//
//   4. Concatenate heads: [batch, seq, h * d_head] = [batch, seq, d_model]
//
//   5. Output projection: out @ W_O    [batch, seq, d_model]
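//
//   Concrete shape trace (batch = 2, seq = 10, d_model = 512, h = 8,
//   so d_head = 64 — the same sizes as the doc example below):
//     Q, K, V:  [2, 10, 512]  →  per head: [2, 8, 10, 64]
//     scores:   [2, 8, 10, 10]
//     out:      [2, 8, 10, 64]  →  concat: [2, 10, 512]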
//
// CAUSAL MASK (for autoregressive models like GPT):
//
//   To prevent attending to future tokens, we add a mask before softmax:
//   scores[i, j] = -∞  if j > i  (position j is after position i)
//   After softmax, the attention weights for those future positions are 0.
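//
//   Example mask for seq_len = 3 (0 = attend, -∞ = blocked):
//     [[ 0, -∞, -∞],
//      [ 0,  0, -∞],
//      [ 0,  0,  0]]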

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::linear::Linear;
use crate::module::Module;

/// Multi-Head Self-Attention module.
///
/// # Examples
/// ```ignore
/// let attn = MultiHeadAttention::<CpuBackend>::new(512, 8, DType::F64, &dev)?;
/// let x = CpuTensor::rand((2, 10, 512), DType::F64, &dev)?;
/// let y = attn.forward(&x)?; // [2, 10, 512]
/// ```
pub struct MultiHeadAttention<B: Backend> {
    /// Number of attention heads
    num_heads: usize,
    /// Dimension per head: d_head = d_model / num_heads
    head_dim: usize,
    /// Total model dimension
    d_model: usize,
    /// Query projection: [d_model, d_model]
    w_q: Linear<B>,
    /// Key projection: [d_model, d_model]
    w_k: Linear<B>,
    /// Value projection: [d_model, d_model]
    w_v: Linear<B>,
    /// Output projection: [d_model, d_model]
    w_o: Linear<B>,
    /// Scaling factor: 1/√d_head
    scale: f64,
    /// Whether to apply causal mask (for autoregressive models)
    causal: bool,
}

impl<B: Backend> MultiHeadAttention<B> {
    /// Create a new Multi-Head Attention module.
    ///
    /// # Arguments
    /// - `d_model`: total model dimension (must be divisible by `num_heads`)
    /// - `num_heads`: number of attention heads
    /// - `dtype`: data type for parameters
    /// - `device`: device to create parameters on
    pub fn new(d_model: usize, num_heads: usize, dtype: DType, device: &B::Device) -> Result<Self> {
        if !d_model.is_multiple_of(num_heads) {
            return Err(shrew_core::Error::msg(format!(
                "d_model ({}) must be divisible by num_heads ({})",
                d_model, num_heads
            )));
        }
        let head_dim = d_model / num_heads;

        let w_q = Linear::new(d_model, d_model, false, dtype, device)?;
        let w_k = Linear::new(d_model, d_model, false, dtype, device)?;
        let w_v = Linear::new(d_model, d_model, false, dtype, device)?;
        let w_o = Linear::new(d_model, d_model, false, dtype, device)?;

        Ok(MultiHeadAttention {
            num_heads,
            head_dim,
            d_model,
            w_q,
            w_k,
            w_v,
            w_o,
            scale: 1.0 / (head_dim as f64).sqrt(),
            causal: false,
        })
    }

    /// Enable causal (autoregressive) masking.
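    ///
    /// # Examples
    /// A minimal sketch, reusing the `CpuBackend` and `dev` assumed in the
    /// struct-level example above:
    /// ```ignore
    /// let attn = MultiHeadAttention::<CpuBackend>::new(512, 8, DType::F64, &dev)?
    ///     .with_causal(true);
    /// ```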
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

    pub fn num_heads(&self) -> usize {
        self.num_heads
    }

    pub fn d_model(&self) -> usize {
        self.d_model
    }

    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// Reshape [batch, seq, d_model] → [batch, seq, num_heads, head_dim],
    /// then transpose to [batch, num_heads, seq, head_dim].
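    ///
    /// E.g. with batch = 2, seq = 10, num_heads = 8, head_dim = 64:
    /// [2, 10, 512] → [2, 10, 8, 64] → [2, 8, 10, 64]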
    fn reshape_to_heads(&self, x: &Tensor<B>, batch: usize, seq: usize) -> Result<Tensor<B>> {
        // [batch, seq, d_model] → [batch, seq, num_heads, head_dim]
        let reshaped = x.reshape((batch, seq, self.num_heads, self.head_dim))?;
        // Transpose dims 1 and 2: [batch, num_heads, seq, head_dim]
        reshaped.transpose(1, 2)?.contiguous()
    }

    /// Inverse of `reshape_to_heads`:
    /// [batch, num_heads, seq, head_dim] → [batch, seq, d_model]
    fn reshape_from_heads(&self, x: &Tensor<B>, batch: usize, seq: usize) -> Result<Tensor<B>> {
        // [batch, num_heads, seq, head_dim] → [batch, seq, num_heads, head_dim]
        let transposed = x.transpose(1, 2)?.contiguous()?;
        // [batch, seq, num_heads * head_dim] = [batch, seq, d_model]
        transposed.reshape((batch, seq, self.d_model))
    }

    /// Create a causal mask: a strictly upper-triangular matrix of large
    /// negative values (≈ -∞, stored as -1e9 so it stays finite in floats).
    /// Shape: [seq, seq], where mask[i][j] = -1e9 if j > i, else 0.
    fn create_causal_mask(
        &self,
        seq_len: usize,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Tensor<B>> {
        let mut mask_data = vec![0.0f64; seq_len * seq_len];
        for i in 0..seq_len {
            for j in (i + 1)..seq_len {
                mask_data[i * seq_len + j] = -1e9; // Large negative ≈ -infinity
            }
        }
        Tensor::<B>::from_f64_slice(&mask_data, (seq_len, seq_len), dtype, device)
    }
}

impl<B: Backend> Module<B> for MultiHeadAttention<B> {
    /// Forward pass: self-attention on input x.
    ///
    /// Input:  [batch, seq_len, d_model]
    /// Output: [batch, seq_len, d_model]
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let dims = x.dims();
        if dims.len() != 3 {
            return Err(shrew_core::Error::msg(format!(
                "MultiHeadAttention expects 3D input [batch, seq, d_model], got {:?}",
                dims
            )));
        }
        let batch = dims[0];
        let seq = dims[1];

        // Step 1: Project to Q, K, V.
        // Reshape x from [batch, seq, d_model] to [batch*seq, d_model] for Linear.
        let x_2d = x.reshape((batch * seq, self.d_model))?;
        let q_2d = self.w_q.forward(&x_2d)?; // [batch*seq, d_model]
        let k_2d = self.w_k.forward(&x_2d)?;
        let v_2d = self.w_v.forward(&x_2d)?;

        // Reshape back to [batch, seq, d_model]
        let q = q_2d.reshape((batch, seq, self.d_model))?;
        let k = k_2d.reshape((batch, seq, self.d_model))?;
        let v = v_2d.reshape((batch, seq, self.d_model))?;

        // Step 2: Split into heads
        // [batch, num_heads, seq, head_dim]
        let q = self.reshape_to_heads(&q, batch, seq)?;
        let k = self.reshape_to_heads(&k, batch, seq)?;
        let v = self.reshape_to_heads(&v, batch, seq)?;

        // Step 3: Compute attention scores
        // scores = Q @ K^T / √d_k
        // Q: [batch, h, seq, d_head], K^T: [batch, h, d_head, seq]
        // scores: [batch, h, seq, seq]
        let k_t = k.transpose(2, 3)?.contiguous()?;
        let scores = q.matmul(&k_t)?.affine(self.scale, 0.0)?;

        // Step 4: Apply causal mask (optional)
        let scores = if self.causal {
            let mask = self.create_causal_mask(seq, x.dtype(), x.device())?;
            // mask is [seq, seq]; scores are [batch, h, seq, seq].
            // Broadcasting handles the batch and head dimensions.
            scores.add(&mask)?
        } else {
            scores
        };

        // Step 5: Softmax over the last dimension (key positions)
        let attn_weights = scores.softmax(3)?; // [batch, h, seq, seq]

        // Step 6: Weighted sum of values
        // [batch, h, seq, seq] @ [batch, h, seq, d_head] = [batch, h, seq, d_head]
        let attn_output = attn_weights.matmul(&v)?;

        // Step 7: Concatenate heads
        // [batch, h, seq, d_head] → [batch, seq, d_model]
        let concat = self.reshape_from_heads(&attn_output, batch, seq)?;

        // Step 8: Output projection
        let concat_2d = concat.reshape((batch * seq, self.d_model))?;
        let output_2d = self.w_o.forward(&concat_2d)?;
        output_2d.reshape((batch, seq, self.d_model))
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = Vec::new();
        params.extend(self.w_q.parameters());
        params.extend(self.w_k.parameters());
        params.extend(self.w_v.parameters());
        params.extend(self.w_o.parameters());
        params
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = Vec::new();
        for (k, v) in self.w_q.named_parameters() {
            named.push((format!("w_q.{k}"), v));
        }
        for (k, v) in self.w_k.named_parameters() {
            named.push((format!("w_k.{k}"), v));
        }
        for (k, v) in self.w_v.named_parameters() {
            named.push((format!("w_v.{k}"), v));
        }
        for (k, v) in self.w_o.named_parameters() {
            named.push((format!("w_o.{k}"), v));
        }
        named
    }
}
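
// The sketch below is not part of the original module: two backend-free unit
// tests that check the numerical claims made in the comments above (the
// causal-mask pattern and the effect of masking on softmax), using plain f64
// arithmetic instead of a concrete `Backend`.
#[cfg(test)]
mod tests {
    /// Plain-f64 softmax over a slice (test-only helper, assumed here).
    fn softmax(xs: &[f64]) -> Vec<f64> {
        let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = xs.iter().map(|x| (x - max).exp()).collect();
        let sum: f64 = exps.iter().sum();
        exps.iter().map(|e| e / sum).collect()
    }

    /// Mirrors the loop in `create_causal_mask`: entries above the diagonal
    /// (j > i) get -1e9, everything else stays 0.
    #[test]
    fn causal_mask_is_strictly_upper_triangular() {
        let seq_len = 4;
        let mut mask = vec![0.0f64; seq_len * seq_len];
        for i in 0..seq_len {
            for j in (i + 1)..seq_len {
                mask[i * seq_len + j] = -1e9;
            }
        }
        for i in 0..seq_len {
            for j in 0..seq_len {
                let expected = if j > i { -1e9 } else { 0.0 };
                assert_eq!(mask[i * seq_len + j], expected);
            }
        }
    }

    /// Checks the header's claim: adding the mask before softmax drives the
    /// weights for future positions to (numerically) zero.
    #[test]
    fn masked_softmax_zeroes_future_positions() {
        // Row i = 0 of a 3-token sequence: positions 1 and 2 are "future".
        let scores = [1.0f64, 2.0, 3.0];
        let masked: Vec<f64> = scores
            .iter()
            .enumerate()
            .map(|(j, &s)| if j > 0 { s - 1e9 } else { s })
            .collect();
        let weights = softmax(&masked);
        assert!((weights[0] - 1.0).abs() < 1e-12);
        assert!(weights[1] < 1e-12 && weights[2] < 1e-12);
    }
}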