use crate::graph::*;

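/// Runs shape inference over every graph in the program.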
pub fn infer_shapes(program: &mut IrProgram) {
    for graph in &mut program.graphs {
        infer_graph_shapes(graph);
    }
}

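/// Infers output types for a single graph, visiting nodes in
/// topological order so every node sees its inputs' types first.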
pub fn infer_graph_shapes(graph: &mut IrGraph) {
    let order = graph.topo_order();
    for id in order {
        let inferred = infer_node_type(graph, id);
        if let Some(ty) = inferred {
            graph.node_mut(id).output_type = ty;
        }
    }
}

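/// Computes the output type of a single node from its inputs' types.
/// Returns `None` when the node is already typed or when no type can
/// be determined from the available information.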
fn infer_node_type(graph: &IrGraph, id: NodeId) -> Option<IrType> {
    let node = graph.node(id);

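    // Leave already-typed nodes alone; inference only fills in unknowns.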
    if !matches!(node.output_type, IrType::Unknown) {
        return None;
    }

    let inputs: Vec<&IrType> = node
        .inputs
        .iter()
        .map(|&i| &graph.node(i).output_type)
        .collect();

    match &node.op {
        OpKind::Identity => inputs.first().map(|t| (*t).clone()),

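        // Elementwise unary ops preserve the input's shape and dtype.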
        OpKind::Neg
        | OpKind::Relu
        | OpKind::Gelu
        | OpKind::Silu
        | OpKind::Sigmoid
        | OpKind::Tanh
        | OpKind::Exp
        | OpKind::Log
        | OpKind::Sqrt
        | OpKind::Not => inputs.first().map(|t| (*t).clone()),

        OpKind::Softmax { .. } => inputs.first().map(|t| (*t).clone()),

        OpKind::Dropout { .. } => inputs.first().map(|t| (*t).clone()),

        OpKind::LayerNorm { .. } | OpKind::BatchNorm { .. } => inputs.first().map(|t| (*t).clone()),

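        // Binary arithmetic broadcasts the two operand shapes together.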
        OpKind::Add | OpKind::Sub | OpKind::Mul | OpKind::Div | OpKind::Mod | OpKind::Pow => {
            if inputs.len() == 2 {
                broadcast_shapes(inputs[0], inputs[1])
            } else {
                None
            }
        }

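        // Comparisons and logical ops broadcast like arithmetic but
        // always produce a Bool tensor.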
        OpKind::Equal
        | OpKind::NotEqual
        | OpKind::Less
        | OpKind::Greater
        | OpKind::LessEqual
        | OpKind::GreaterEqual => {
            if inputs.len() == 2 {
                if let Some(IrType::Tensor { shape, .. }) = broadcast_shapes(inputs[0], inputs[1]) {
                    Some(IrType::Tensor {
                        shape,
                        dtype: DType::Bool,
                    })
                } else {
                    None
                }
            } else {
                None
            }
        }

        OpKind::And | OpKind::Or => {
            if inputs.len() == 2 {
                if let Some(IrType::Tensor { shape, .. }) = broadcast_shapes(inputs[0], inputs[1]) {
                    Some(IrType::Tensor {
                        shape,
                        dtype: DType::Bool,
                    })
                } else {
                    None
                }
            } else {
                None
            }
        }

        OpKind::MatMul => {
            if inputs.len() == 2 {
                infer_matmul(inputs[0], inputs[1])
            } else {
                None
            }
        }

        OpKind::Transpose => {
            if let Some(IrType::Tensor { shape, dtype }) = inputs.first() {
                if shape.len() >= 2 {
                    let mut new_shape = shape.clone();
                    let n = new_shape.len();
                    new_shape.swap(n - 1, n - 2);
                    Some(IrType::Tensor {
                        shape: new_shape,
                        dtype: *dtype,
                    })
                } else {
                    Some(IrType::Tensor {
                        shape: shape.clone(),
                        dtype: *dtype,
                    })
                }
            } else {
                None
            }
        }

        OpKind::Permute { dims } => {
            if let Some(IrType::Tensor { shape, dtype }) = inputs.first() {
                let new_shape: Vec<Dim> = dims
                    .iter()
                    .map(|&d| {
                        let idx = if d < 0 { shape.len() as i64 + d } else { d } as usize;
                        shape.get(idx).cloned().unwrap_or(Dim::Dynamic)
                    })
                    .collect();
                Some(IrType::Tensor {
                    shape: new_shape,
                    dtype: *dtype,
                })
            } else {
                None
            }
        }

        OpKind::Reshape { target_shape } | OpKind::View { target_shape } => {
            if let Some(IrType::Tensor { dtype, .. }) = inputs.first() {
                Some(IrType::Tensor {
                    shape: target_shape.clone(),
                    dtype: *dtype,
                })
            } else {
                None
            }
        }

        OpKind::Expand { target_shape } => {
            if let Some(IrType::Tensor { dtype, .. }) = inputs.first() {
                Some(IrType::Tensor {
                    shape: target_shape.clone(),
                    dtype: *dtype,
                })
            } else {
                None
            }
        }

        OpKind::Sum { dims, keepdim }
        | OpKind::Mean { dims, keepdim }
        | OpKind::Variance { dims, keepdim } => {
            if let Some(IrType::Tensor { shape, dtype }) = inputs.first() {
                Some(infer_reduction(shape, dims, *keepdim, *dtype))
            } else {
                None
            }
        }

        OpKind::Max { dim, keepdim } | OpKind::Min { dim, keepdim } => {
            if let Some(IrType::Tensor { shape, dtype }) = inputs.first() {
                Some(infer_reduction(shape, &[*dim], *keepdim, *dtype))
            } else {
                None
            }
        }

        OpKind::Concat { dim } => infer_concat(&inputs, *dim),

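        // Embedding lookup: the output is the index shape with the
        // table's trailing embedding dimension appended.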
        OpKind::Embedding => {
            if inputs.len() >= 2 {
                if let (
                    IrType::Tensor {
                        shape: table_shape,
                        dtype,
                    },
                    IrType::Tensor {
                        shape: idx_shape, ..
                    },
                ) = (inputs[0], inputs[1])
                {
                    if let Some(embed_dim) = table_shape.last() {
                        let mut out_shape = idx_shape.clone();
                        out_shape.push(embed_dim.clone());
                        return Some(IrType::Tensor {
                            shape: out_shape,
                            dtype: *dtype,
                        });
                    }
                }
            }
            None
        }

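        // Linear: keep all but the input's last dimension and append
        // the weight's first dimension (an `[out, in]` weight layout).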
        OpKind::Linear { .. } => {
            if inputs.len() >= 2 {
                if let (
                    IrType::Tensor {
                        shape: in_shape,
                        dtype,
                    },
                    IrType::Tensor { shape: w_shape, .. },
                ) = (inputs[0], inputs[1])
                {
                    if !in_shape.is_empty() && w_shape.len() == 2 {
                        let mut out_shape = in_shape[..in_shape.len() - 1].to_vec();
                        out_shape.push(w_shape[0].clone());
                        return Some(IrType::Tensor {
                            shape: out_shape,
                            dtype: *dtype,
                        });
                    }
                }
            }
            None
        }

        OpKind::CrossEntropy | OpKind::MseLoss => {
            if let Some(IrType::Tensor { dtype, .. }) = inputs.first() {
                Some(IrType::Scalar(*dtype))
            } else {
                None
            }
        }

        OpKind::MultiHeadAttention { .. } | OpKind::TransformerBlock { .. } => {
            inputs.first().map(|t| (*t).clone())
        }

        OpKind::Repeat { .. } => inputs.first().map(|t| (*t).clone()),

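        // No inference for these: returning None leaves their types
        // unchanged for other passes or runtime information to fill in.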
        OpKind::Constant(_) => None,

        OpKind::Custom { .. } | OpKind::Call { .. } => None,

        OpKind::Split { .. } | OpKind::Range => None,
    }
}

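/// Broadcasts two tensor types together, NumPy-style: shapes are
/// aligned from the trailing dimension and merged pairwise with
/// `broadcast_dim`. The left operand's dtype is used for the result.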
fn broadcast_shapes(left: &IrType, right: &IrType) -> Option<IrType> {
    if let (
        IrType::Tensor {
            shape: ls,
            dtype: ld,
        },
        IrType::Tensor {
            shape: rs,
            dtype: _rd,
        },
    ) = (left, right)
    {
        let dtype = *ld;
        let max_rank = ls.len().max(rs.len());
        let mut result = Vec::with_capacity(max_rank);

        for i in 0..max_rank {
            let l_idx = if i < ls.len() {
                Some(&ls[ls.len() - 1 - i])
            } else {
                None
            };
            let r_idx = if i < rs.len() {
                Some(&rs[rs.len() - 1 - i])
            } else {
                None
            };

            let dim = match (l_idx, r_idx) {
                (Some(l), None) => l.clone(),
                (None, Some(r)) => r.clone(),
                (Some(l), Some(r)) => broadcast_dim(l, r)?,
                (None, None) => unreachable!(),
            };
            result.push(dim);
        }

        result.reverse();
        Some(IrType::Tensor {
            shape: result,
            dtype,
        })
    } else {
        None
    }
}

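/// Merges one pair of dimensions under broadcasting rules: a fixed 1
/// yields the other dimension, equal sizes match, mismatched fixed
/// sizes fail, and dynamic or symbolic dimensions are propagated
/// optimistically.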
fn broadcast_dim(a: &Dim, b: &Dim) -> Option<Dim> {
    match (a, b) {
        (Dim::Fixed(1), other) | (other, Dim::Fixed(1)) => Some(other.clone()),
        (Dim::Fixed(x), Dim::Fixed(y)) if x == y => Some(Dim::Fixed(*x)),
        (Dim::Fixed(_), Dim::Fixed(_)) => None,
        (Dim::Symbolic(s), Dim::Symbolic(t)) if s == t => Some(Dim::Symbolic(s.clone())),
        (Dim::Dynamic, _) | (_, Dim::Dynamic) => Some(Dim::Dynamic),
        (Dim::Symbolic(s), _) | (_, Dim::Symbolic(s)) => Some(Dim::Symbolic(s.clone())),
    }
}

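/// Infers the result of a (possibly batched) matrix multiply: batch
/// dimensions broadcast, and the last two output dimensions are `M`
/// from the left operand and `N` from the right. Operands of rank < 2
/// fall back to a single dynamic dimension.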
fn infer_matmul(left: &IrType, right: &IrType) -> Option<IrType> {
    if let (IrType::Tensor { shape: ls, dtype }, IrType::Tensor { shape: rs, .. }) = (left, right) {
        if ls.len() < 2 || rs.len() < 2 {
            return Some(IrType::Tensor {
                shape: vec![Dim::Dynamic],
                dtype: *dtype,
            });
        }

        let l_batch = &ls[..ls.len() - 2];
        let r_batch = &rs[..rs.len() - 2];
        let batch = broadcast_batch_dims(l_batch, r_batch);

        let m = ls[ls.len() - 2].clone();
        let n = rs[rs.len() - 1].clone();

        let mut shape = batch;
        shape.push(m);
        shape.push(n);

        Some(IrType::Tensor {
            shape,
            dtype: *dtype,
        })
    } else {
        None
    }
}

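/// Broadcasts the leading (batch) dimensions of two matmul operands,
/// falling back to `Dim::Dynamic` for pairs that cannot be reconciled.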
fn broadcast_batch_dims(a: &[Dim], b: &[Dim]) -> Vec<Dim> {
    let max_len = a.len().max(b.len());
    let mut result = Vec::with_capacity(max_len);
    for i in 0..max_len {
        let l = if i < a.len() {
            Some(&a[a.len() - 1 - i])
        } else {
            None
        };
        let r = if i < b.len() {
            Some(&b[b.len() - 1 - i])
        } else {
            None
        };
        let dim = match (l, r) {
            (Some(l), None) => l.clone(),
            (None, Some(r)) => r.clone(),
            (Some(l), Some(r)) => broadcast_dim(l, r).unwrap_or(Dim::Dynamic),
            (None, None) => unreachable!(),
        };
        result.push(dim);
    }
    result.reverse();
    result
}

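/// Computes the shape of a reduction over `dims` (negative indices
/// count from the back). With `keepdim`, reduced axes become size 1;
/// otherwise they are dropped, and reducing away every axis yields a
/// scalar.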
fn infer_reduction(shape: &[Dim], dims: &[i64], keepdim: bool, dtype: DType) -> IrType {
    let rank = shape.len();
    let normalized: Vec<usize> = dims
        .iter()
        .map(|&d| {
            if d < 0 {
                (rank as i64 + d) as usize
            } else {
                d as usize
            }
        })
        .collect();

    if keepdim {
        let new_shape: Vec<Dim> = shape
            .iter()
            .enumerate()
            .map(|(i, d)| {
                if normalized.contains(&i) {
                    Dim::Fixed(1)
                } else {
                    d.clone()
                }
            })
            .collect();
        IrType::Tensor {
            shape: new_shape,
            dtype,
        }
    } else {
        let new_shape: Vec<Dim> = shape
            .iter()
            .enumerate()
            .filter(|(i, _)| !normalized.contains(i))
            .map(|(_, d)| d.clone())
            .collect();
        if new_shape.is_empty() {
            IrType::Scalar(dtype)
        } else {
            IrType::Tensor {
                shape: new_shape,
                dtype,
            }
        }
    }
}

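/// Infers the result of concatenating `inputs` along `dim`: the output
/// matches the first input's shape except along `dim`, where sizes are
/// summed when all are fixed and otherwise become dynamic.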
fn infer_concat(inputs: &[&IrType], dim: i64) -> Option<IrType> {
    if inputs.is_empty() {
        return None;
    }

    if let IrType::Tensor {
        shape: first_shape,
        dtype,
    } = inputs[0]
    {
        let rank = first_shape.len();
        let d = if dim < 0 {
            (rank as i64 + dim) as usize
        } else {
            dim as usize
        };

        if d >= rank {
            return None;
        }

        let mut result_shape = first_shape.clone();

        let mut total = dim_value(&first_shape[d]);
        for &input in &inputs[1..] {
            if let IrType::Tensor { shape, .. } = input {
                if shape.len() == rank {
                    total = match (total, dim_value(&shape[d])) {
                        (Some(a), Some(b)) => Some(a + b),
                        _ => None,
                    };
                    continue;
                }
            }
            // A non-tensor input or a rank mismatch means the concat
            // size cannot be determined statically; previously such
            // inputs were silently skipped and undercounted the total.
            total = None;
        }

        result_shape[d] = match total {
            Some(n) => Dim::Fixed(n),
            None => Dim::Dynamic,
        };

        Some(IrType::Tensor {
            shape: result_shape,
            dtype: *dtype,
        })
    } else {
        None
    }
}

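/// Extracts the concrete size of a dimension, if it has one.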
fn dim_value(dim: &Dim) -> Option<i64> {
    match dim {
        Dim::Fixed(n) => Some(*n),
        _ => None,
    }
}