shrew_ir/
shapes.rs

// Shape Inference: propagate tensor shapes through the computation graph
//
// Shape inference fills in Unknown output types by propagating shapes forward
// through the graph in topological order. This enables:
//
//   1. Early detection of shape mismatches
//   2. Memory planning (allocate exact buffer sizes)
//   3. Kernel selection (choose optimal implementation for shapes)
//
// After inference, every node output_type should ideally be a concrete
// IrType::Tensor (except for truly dynamic/symbolic shapes).
//
// RULES:
//   - Identity: output = input shape
//   - Add/Sub/Mul/Div: shapes must broadcast, output = broadcast shape
//   - MatMul: [.., M, K] × [.., K, N] → [.., M, N]
//   - Relu/Gelu/etc: shape preserved
//   - Softmax: shape preserved
//   - LayerNorm: shape preserved
//   - Transpose: last two dims swapped
//   - Sum/Mean: dims removed (or kept if keepdim)
//   - Reshape: target shape
//   - Concat: sum along concat dim
//   - Dropout: shape preserved
//   - Embedding: indices_shape + [embed_dim]
//   - Linear: [.., in_features] → [.., out_features]
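//
// Example derivations under these rules (illustrative):
//   Add:       [2, 1, 4] + [3, 1]                    → [2, 3, 4]
//   Linear:    input [B, T, 512], weight [2048, 512] → [B, T, 2048]
//   Embedding: table [V, 512], indices [B, T]        → [B, T, 512]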

use crate::graph::*;

// Public API

/// Run shape inference on all graphs in the program.
/// Modifies node output_type in place.
pub fn infer_shapes(program: &mut IrProgram) {
    for graph in &mut program.graphs {
        infer_graph_shapes(graph);
    }
}

/// Run shape inference on a single graph.
pub fn infer_graph_shapes(graph: &mut IrGraph) {
    let order = graph.topo_order();
    for id in order {
        let inferred = infer_node_type(graph, id);
        if let Some(ty) = inferred {
            graph.node_mut(id).output_type = ty;
        }
    }
}
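
// Usage sketch (illustrative; `lower_module_to_ir` is a hypothetical stand-in
// for however an IrProgram gets built elsewhere in the crate):
//
//     let mut program = lower_module_to_ir(&module);
//     infer_shapes(&mut program);
//     // Nodes whose output_type was IrType::Unknown now carry a concrete
//     // type wherever the rules above could determine one.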

// Per-node inference

fn infer_node_type(graph: &IrGraph, id: NodeId) -> Option<IrType> {
    let node = graph.node(id);

    // If already a concrete type, keep it
    if !matches!(node.output_type, IrType::Unknown) {
        return None;
    }

    let inputs: Vec<&IrType> = node
        .inputs
        .iter()
        .map(|&i| &graph.node(i).output_type)
        .collect();

    match &node.op {
        // Identity: propagate from single input
        OpKind::Identity => inputs.first().map(|t| (*t).clone()),

        // Unary element-wise: shape preserved
        OpKind::Neg
        | OpKind::Relu
        | OpKind::Gelu
        | OpKind::Silu
        | OpKind::Sigmoid
        | OpKind::Tanh
        | OpKind::Exp
        | OpKind::Log
        | OpKind::Sqrt
        | OpKind::Not => inputs.first().map(|t| (*t).clone()),

        // Softmax: shape preserved
        OpKind::Softmax { .. } => inputs.first().map(|t| (*t).clone()),

        // Dropout: shape preserved
        OpKind::Dropout { .. } => inputs.first().map(|t| (*t).clone()),

        // LayerNorm / BatchNorm: shape preserved
        OpKind::LayerNorm { .. } | OpKind::BatchNorm { .. } => inputs.first().map(|t| (*t).clone()),

        // Binary element-wise: broadcast
        OpKind::Add | OpKind::Sub | OpKind::Mul | OpKind::Div | OpKind::Mod | OpKind::Pow => {
            if inputs.len() == 2 {
                broadcast_shapes(inputs[0], inputs[1])
            } else {
                None
            }
        }

        // Comparison: broadcast shape, bool dtype
        OpKind::Equal
        | OpKind::NotEqual
        | OpKind::Less
        | OpKind::Greater
        | OpKind::LessEqual
        | OpKind::GreaterEqual => {
            if inputs.len() == 2 {
                if let Some(IrType::Tensor { shape, .. }) = broadcast_shapes(inputs[0], inputs[1]) {
                    Some(IrType::Tensor {
                        shape,
                        dtype: DType::Bool,
                    })
                } else {
                    None
                }
            } else {
                None
            }
        }

        // Logical: broadcast shape, bool dtype
        OpKind::And | OpKind::Or => {
            if inputs.len() == 2 {
                if let Some(IrType::Tensor { shape, .. }) = broadcast_shapes(inputs[0], inputs[1]) {
                    Some(IrType::Tensor {
                        shape,
                        dtype: DType::Bool,
                    })
                } else {
                    None
                }
            } else {
                None
            }
        }

        // MatMul: [.., M, K] × [.., K, N] → [.., M, N]
        OpKind::MatMul => {
            if inputs.len() == 2 {
                infer_matmul(inputs[0], inputs[1])
            } else {
                None
            }
        }

        // Transpose: swap last two dims
        OpKind::Transpose => {
            if let Some(IrType::Tensor { shape, dtype }) = inputs.first() {
                if shape.len() >= 2 {
                    let mut new_shape = shape.clone();
                    let n = new_shape.len();
                    new_shape.swap(n - 1, n - 2);
                    Some(IrType::Tensor {
                        shape: new_shape,
                        dtype: *dtype,
                    })
                } else {
                    Some(IrType::Tensor {
                        shape: shape.clone(),
                        dtype: *dtype,
                    })
                }
            } else {
                None
            }
        }

        // Permute: reorder dimensions
        OpKind::Permute { dims } => {
            if let Some(IrType::Tensor { shape, dtype }) = inputs.first() {
                let new_shape: Vec<Dim> = dims
                    .iter()
                    .map(|&d| {
                        let idx = if d < 0 { shape.len() as i64 + d } else { d } as usize;
                        shape.get(idx).cloned().unwrap_or(Dim::Dynamic)
                    })
                    .collect();
                Some(IrType::Tensor {
                    shape: new_shape,
                    dtype: *dtype,
                })
            } else {
                None
            }
        }

        // Reshape / View: copy the target shape (no -1/inferred-dim resolution here)
        OpKind::Reshape { target_shape } | OpKind::View { target_shape } => {
            if let Some(IrType::Tensor { dtype, .. }) = inputs.first() {
                Some(IrType::Tensor {
                    shape: target_shape.clone(),
                    dtype: *dtype,
                })
            } else {
                None
            }
        }

        // Expand: target shape
        OpKind::Expand { target_shape } => {
            if let Some(IrType::Tensor { dtype, .. }) = inputs.first() {
                Some(IrType::Tensor {
                    shape: target_shape.clone(),
                    dtype: *dtype,
                })
            } else {
                None
            }
        }

        // Reduction ops
        OpKind::Sum { dims, keepdim }
        | OpKind::Mean { dims, keepdim }
        | OpKind::Variance { dims, keepdim } => {
            if let Some(IrType::Tensor { shape, dtype }) = inputs.first() {
                Some(infer_reduction(shape, dims, *keepdim, *dtype))
            } else {
                None
            }
        }

        OpKind::Max { dim, keepdim } | OpKind::Min { dim, keepdim } => {
            if let Some(IrType::Tensor { shape, dtype }) = inputs.first() {
                Some(infer_reduction(shape, &[*dim], *keepdim, *dtype))
            } else {
                None
            }
        }

        // Concat: sum along dim
        OpKind::Concat { dim } => infer_concat(&inputs, *dim),

        // Embedding: indices → [.., embed_dim]
        OpKind::Embedding => {
            // If we have table and indices: table=[V, D], indices=[..] → [.., D]
            if inputs.len() >= 2 {
                if let (
                    IrType::Tensor {
                        shape: table_shape,
                        dtype,
                    },
                    IrType::Tensor {
                        shape: idx_shape, ..
                    },
                ) = (inputs[0], inputs[1])
                {
                    if let Some(embed_dim) = table_shape.last() {
                        let mut out_shape = idx_shape.clone();
                        out_shape.push(embed_dim.clone());
                        return Some(IrType::Tensor {
                            shape: out_shape,
                            dtype: *dtype,
                        });
                    }
                }
            }
            None
        }

        // Linear
        OpKind::Linear { .. } => {
            // input=[.., in_features], weight=[out, in] → [.., out_features]
            if inputs.len() >= 2 {
                if let (
                    IrType::Tensor {
                        shape: in_shape,
                        dtype,
                    },
                    IrType::Tensor { shape: w_shape, .. },
                ) = (inputs[0], inputs[1])
                {
                    if !in_shape.is_empty() && w_shape.len() == 2 {
                        let mut out_shape = in_shape[..in_shape.len() - 1].to_vec();
                        out_shape.push(w_shape[0].clone());
                        return Some(IrType::Tensor {
                            shape: out_shape,
                            dtype: *dtype,
                        });
                    }
                }
            }
            None
        }

        // Loss functions: output is scalar
        OpKind::CrossEntropy | OpKind::MseLoss => {
            if let Some(IrType::Tensor { dtype, .. }) = inputs.first() {
                Some(IrType::Scalar(*dtype))
            } else {
                None
            }
        }

        // Attention / Transformer: shape preserved (simplified)
        OpKind::MultiHeadAttention { .. } | OpKind::TransformerBlock { .. } => {
            inputs.first().map(|t| (*t).clone())
        }

        // Repeat: shape preserved
        OpKind::Repeat { .. } => inputs.first().map(|t| (*t).clone()),

        // Constants: already typed at creation
        OpKind::Constant(_) => None,

        // Custom / Call: can't infer
        OpKind::Custom { .. } | OpKind::Call { .. } => None,

        // Split / Range: complex, skip
        OpKind::Split { .. } | OpKind::Range => None,
    }
}

// Shape helpers

/// Broadcast two tensor types. Returns the output type if compatible.
fn broadcast_shapes(left: &IrType, right: &IrType) -> Option<IrType> {
    if let (
        IrType::Tensor {
            shape: ls,
            dtype: ld,
        },
        IrType::Tensor {
            shape: rs,
            dtype: _rd,
        },
    ) = (left, right)
    {
        // dtype of output = left dtype (assume matching)
        let dtype = *ld;
        let max_rank = ls.len().max(rs.len());
        let mut result = Vec::with_capacity(max_rank);

        for i in 0..max_rank {
            let l_idx = if i < ls.len() {
                Some(&ls[ls.len() - 1 - i])
            } else {
                None
            };
            let r_idx = if i < rs.len() {
                Some(&rs[rs.len() - 1 - i])
            } else {
                None
            };

            let dim = match (l_idx, r_idx) {
                (Some(l), None) => l.clone(),
                (None, Some(r)) => r.clone(),
                (Some(l), Some(r)) => broadcast_dim(l, r)?,
                (None, None) => unreachable!(),
            };
            result.push(dim);
        }

        result.reverse();
        Some(IrType::Tensor {
            shape: result,
            dtype,
        })
    } else {
        None
    }
}

/// Broadcast a single dimension pair.
fn broadcast_dim(a: &Dim, b: &Dim) -> Option<Dim> {
    match (a, b) {
        (Dim::Fixed(1), other) | (other, Dim::Fixed(1)) => Some(other.clone()),
        (Dim::Fixed(x), Dim::Fixed(y)) if x == y => Some(Dim::Fixed(*x)),
        (Dim::Fixed(_), Dim::Fixed(_)) => None, // incompatible
        (Dim::Symbolic(s), Dim::Symbolic(t)) if s == t => Some(Dim::Symbolic(s.clone())),
        (Dim::Dynamic, _) | (_, Dim::Dynamic) => Some(Dim::Dynamic),
        (Dim::Symbolic(s), _) | (_, Dim::Symbolic(s)) => Some(Dim::Symbolic(s.clone())),
    }
}
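
// Dim-level broadcasting behaviour of broadcast_dim, spelled out (illustrative):
//   Fixed(1)    vs Fixed(5)    → Fixed(5)
//   Fixed(3)    vs Fixed(3)    → Fixed(3)
//   Fixed(3)    vs Fixed(5)    → None (incompatible)
//   Dynamic     vs Fixed(7)    → Dynamic
//   Symbolic(s) vs Symbolic(s) → Symbolic(s)
//   Symbolic(s) vs Fixed(4)    → Symbolic(s)
//   Symbolic(s) vs Symbolic(t) → Symbolic(s)   (mismatched names keep the left name)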

/// MatMul shape: [.., M, K] × [.., K, N] → [.., M, N]
fn infer_matmul(left: &IrType, right: &IrType) -> Option<IrType> {
    if let (IrType::Tensor { shape: ls, dtype }, IrType::Tensor { shape: rs, .. }) = (left, right) {
        if ls.len() < 2 || rs.len() < 2 {
            // 1D matmul: not fully handled, return Dynamic
            return Some(IrType::Tensor {
                shape: vec![Dim::Dynamic],
                dtype: *dtype,
            });
        }

        // Batch dims = broadcast of everything except last 2
        let l_batch = &ls[..ls.len() - 2];
        let r_batch = &rs[..rs.len() - 2];
        let batch = broadcast_batch_dims(l_batch, r_batch);

        let m = ls[ls.len() - 2].clone();
        let n = rs[rs.len() - 1].clone();

        let mut shape = batch;
        shape.push(m);
        shape.push(n);

        Some(IrType::Tensor {
            shape,
            dtype: *dtype,
        })
    } else {
        None
    }
}
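
// Example with broadcast batch dims (illustrative):
//   [8, 1, 16, 32] × [4, 32, 64] → batch [8, 4], M = 16, N = 64 → [8, 4, 16, 64]
// Note: the contraction dims (K on each side) are not checked for equality by
// this helper.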

fn broadcast_batch_dims(a: &[Dim], b: &[Dim]) -> Vec<Dim> {
    let max_len = a.len().max(b.len());
    let mut result = Vec::with_capacity(max_len);
    for i in 0..max_len {
        let l = if i < a.len() {
            Some(&a[a.len() - 1 - i])
        } else {
            None
        };
        let r = if i < b.len() {
            Some(&b[b.len() - 1 - i])
        } else {
            None
        };
        let dim = match (l, r) {
            (Some(l), None) => l.clone(),
            (None, Some(r)) => r.clone(),
            (Some(l), Some(r)) => broadcast_dim(l, r).unwrap_or(Dim::Dynamic),
            (None, None) => unreachable!(),
        };
        result.push(dim);
    }
    result.reverse();
    result
}

/// Infer the result of a reduction op.
fn infer_reduction(shape: &[Dim], dims: &[i64], keepdim: bool, dtype: DType) -> IrType {
    let rank = shape.len();
    let normalized: Vec<usize> = dims
        .iter()
        .map(|&d| {
            if d < 0 {
                (rank as i64 + d) as usize
            } else {
                d as usize
            }
        })
        .collect();

    if keepdim {
        let new_shape: Vec<Dim> = shape
            .iter()
            .enumerate()
            .map(|(i, d)| {
                if normalized.contains(&i) {
                    Dim::Fixed(1)
                } else {
                    d.clone()
                }
            })
            .collect();
        IrType::Tensor {
            shape: new_shape,
            dtype,
        }
    } else {
        let new_shape: Vec<Dim> = shape
            .iter()
            .enumerate()
            .filter(|(i, _)| !normalized.contains(i))
            .map(|(_, d)| d.clone())
            .collect();
        if new_shape.is_empty() {
            IrType::Scalar(dtype)
        } else {
            IrType::Tensor {
                shape: new_shape,
                dtype,
            }
        }
    }
}
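
// Examples (illustrative):
//   shape [2, 3, 4], dims = [-1], keepdim = false → [2, 3]
//   shape [2, 3, 4], dims = [-1], keepdim = true  → [2, 3, 1]
//   shape [5],       dims = [0],  keepdim = false → Scalar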

/// Infer concat shape.
fn infer_concat(inputs: &[&IrType], dim: i64) -> Option<IrType> {
    if inputs.is_empty() {
        return None;
    }

    // Use the first input as a template
    if let IrType::Tensor {
        shape: first_shape,
        dtype,
    } = inputs[0]
    {
        let rank = first_shape.len();
        let d = if dim < 0 {
            (rank as i64 + dim) as usize
        } else {
            dim as usize
        };

        if d >= rank {
            return None;
        }

        let mut result_shape = first_shape.clone();

        // Sum the concat dimension across all inputs
        let mut total = dim_value(&first_shape[d]);
        for &input in &inputs[1..] {
            if let IrType::Tensor { shape, .. } = input {
                if shape.len() == rank {
                    total = match (total, dim_value(&shape[d])) {
                        (Some(a), Some(b)) => Some(a + b),
                        _ => None,
                    };
                }
            }
        }

        result_shape[d] = match total {
            Some(n) => Dim::Fixed(n),
            None => Dim::Dynamic,
        };

        Some(IrType::Tensor {
            shape: result_shape,
            dtype: *dtype,
        })
    } else {
        None
    }
}
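
// Examples (illustrative):
//   [2, 3] ++ [2, 5] along dim 1       → [2, 8]
//   [2, 3] ++ [Dynamic, 3] along dim 0 → [Dynamic, 3]   (any non-fixed dim makes the total Dynamic)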

fn dim_value(dim: &Dim) -> Option<i64> {
    match dim {
        Dim::Fixed(n) => Some(*n),
        _ => None,
    }
}
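
// A few sanity checks for the shape helpers above. This is a minimal sketch:
// it sticks to constructors already used in this module (Dim::Fixed, DType::Bool,
// IrType::Tensor, IrType::Scalar) and pattern-matches results rather than
// assuming PartialEq on the IR types. DType::Bool is a placeholder only; the
// dtype does not affect any of the shape logic under test.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a tensor type with the given fixed dims.
    fn tensor(dims: &[i64]) -> IrType {
        IrType::Tensor {
            shape: dims.iter().map(|&d| Dim::Fixed(d)).collect(),
            dtype: DType::Bool,
        }
    }

    #[test]
    fn broadcast_basic() {
        // [2, 1, 4] broadcast with [3, 1] should give [2, 3, 4]
        let out = broadcast_shapes(&tensor(&[2, 1, 4]), &tensor(&[3, 1])).unwrap();
        match out {
            IrType::Tensor { shape, .. } => assert!(matches!(
                shape.as_slice(),
                [Dim::Fixed(2), Dim::Fixed(3), Dim::Fixed(4)]
            )),
            _ => panic!("expected a tensor type"),
        }
    }

    #[test]
    fn broadcast_rejects_mismatch() {
        // Differing non-1 dims cannot broadcast
        assert!(broadcast_shapes(&tensor(&[2, 3]), &tensor(&[2, 5])).is_none());
    }

    #[test]
    fn matmul_batched() {
        // [8, 16, 32] x [8, 32, 64] -> [8, 16, 64]
        let out = infer_matmul(&tensor(&[8, 16, 32]), &tensor(&[8, 32, 64])).unwrap();
        match out {
            IrType::Tensor { shape, .. } => assert!(matches!(
                shape.as_slice(),
                [Dim::Fixed(8), Dim::Fixed(16), Dim::Fixed(64)]
            )),
            _ => panic!("expected a tensor type"),
        }
    }

    #[test]
    fn reduction_collapses_to_scalar() {
        // Reducing every dim without keepdim yields a scalar type
        let out = infer_reduction(&[Dim::Fixed(4)], &[0], false, DType::Bool);
        assert!(matches!(out, IrType::Scalar(_)));
    }
}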