1use crate::graph::*;
17use std::collections::{HashMap, HashSet};
18
19pub fn optimize(program: &mut IrProgram) -> usize {
24 let mut total = 0;
25 for graph in &mut program.graphs {
26 total += optimize_graph(graph);
27 }
28 total
29}
30
31pub fn optimize_graph(graph: &mut IrGraph) -> usize {
34 let mut total = 0;
35 loop {
36 let mut changed = 0;
37 changed += eliminate_dead_code(graph);
38 changed += eliminate_identities(graph);
39 changed += fold_constants(graph);
40 changed += eliminate_common_subexprs(graph);
41 changed += fuse_operators(graph);
42 if changed == 0 {
43 break;
44 }
45 total += changed;
46 }
47 total
48}
49
50pub fn eliminate_dead_code(graph: &mut IrGraph) -> usize {
57 if graph.nodes.is_empty() {
58 return 0;
59 }
60
61 let mut reachable = HashSet::new();
63 let mut stack: Vec<NodeId> = graph.outputs.iter().map(|o| o.node_id).collect();
64
65 for param in &graph.params {
67 stack.push(param.node_id);
68 }
69
70 while let Some(id) = stack.pop() {
71 if !reachable.insert(id) {
72 continue;
73 }
74 if id.0 < graph.nodes.len() {
75 for &inp in &graph.nodes[id.0].inputs {
76 stack.push(inp);
77 }
78 }
79 }
80
81 let total = graph.nodes.len();
82 let dead_count = total - reachable.len();
83 if dead_count == 0 {
84 return 0;
85 }
86
87 let mut keep: Vec<bool> = vec![false; total];
89 for &id in &reachable {
90 keep[id.0] = true;
91 }
92
93 let mut old_to_new: Vec<Option<NodeId>> = vec![None; total];
94 let mut new_id = 0usize;
95 for old_id in 0..total {
96 if keep[old_id] {
97 old_to_new[old_id] = Some(NodeId(new_id));
98 new_id += 1;
99 }
100 }
101
102 let mut new_nodes = Vec::with_capacity(reachable.len());
104 for (old_id, node) in graph.nodes.drain(..).enumerate() {
105 if let Some(nid) = old_to_new[old_id] {
106 let mut node = node;
107 node.id = nid;
108 node.inputs = node
109 .inputs
110 .iter()
111 .filter_map(|&inp| old_to_new[inp.0])
112 .collect();
113 new_nodes.push(node);
114 }
115 }
116 graph.nodes = new_nodes;
117
118 graph.inputs = graph
120 .inputs
121 .iter()
122 .filter_map(|&id| old_to_new[id.0])
123 .collect();
124 graph.outputs.retain(|o| old_to_new[o.node_id.0].is_some());
125 for out in &mut graph.outputs {
126 if let Some(new) = old_to_new[out.node_id.0] {
127 out.node_id = new;
128 }
129 }
130
131 for param in &mut graph.params {
132 if let Some(new) = old_to_new[param.node_id.0] {
133 param.node_id = new;
134 }
135 }
136 graph.params.retain(|p| old_to_new[p.node_id.0].is_some());
137
138 graph.name_to_id.clear();
140 for node in &graph.nodes {
141 graph.name_to_id.insert(node.name.clone(), node.id);
142 }
143
144 dead_count
145}
146
147pub fn eliminate_identities(graph: &mut IrGraph) -> usize {
154 let input_set: HashSet<NodeId> = graph.inputs.iter().copied().collect();
157 let param_set: HashSet<NodeId> = graph.params.iter().map(|p| p.node_id).collect();
158
159 let mut identity_map: HashMap<NodeId, NodeId> = HashMap::new();
161 for node in &graph.nodes {
162 if matches!(node.op, OpKind::Identity)
163 && node.inputs.len() == 1
164 && !input_set.contains(&node.id)
165 && !param_set.contains(&node.id)
166 {
167 identity_map.insert(node.id, node.inputs[0]);
168 }
169 }
170
171 if identity_map.is_empty() {
172 return 0;
173 }
174
175 let mut resolved: HashMap<NodeId, NodeId> = HashMap::new();
177 for &id in identity_map.keys() {
178 let mut target = id;
179 let mut visited = HashSet::new();
180 while let Some(&next) = identity_map.get(&target) {
181 if !visited.insert(target) {
182 break; }
184 target = next;
185 }
186 resolved.insert(id, target);
187 }
188
189 let count = resolved.len();
190
191 for node in &mut graph.nodes {
193 for inp in &mut node.inputs {
194 if let Some(&target) = resolved.get(inp) {
195 *inp = target;
196 }
197 }
198 }
199
200 for out in &mut graph.outputs {
202 if let Some(&target) = resolved.get(&out.node_id) {
203 out.node_id = target;
204 }
205 }
206
207 eliminate_dead_code(graph);
209
210 count
211}
212
213pub fn fold_constants(graph: &mut IrGraph) -> usize {
219 let mut folded = 0;
220
221 for i in 0..graph.nodes.len() {
222 let node = &graph.nodes[i];
223
224 if node.inputs.len() != 2 {
226 continue;
227 }
228
229 let left_const = get_constant(&graph.nodes[node.inputs[0].0]);
230 let right_const = get_constant(&graph.nodes[node.inputs[1].0]);
231
232 let result = match (&node.op, left_const, right_const) {
233 (OpKind::Add, Some(ConstantValue::Int(a)), Some(ConstantValue::Int(b))) => {
234 Some(ConstantValue::Int(a + b))
235 }
236 (OpKind::Add, Some(ConstantValue::Float(a)), Some(ConstantValue::Float(b))) => {
237 Some(ConstantValue::Float(a + b))
238 }
239 (OpKind::Sub, Some(ConstantValue::Int(a)), Some(ConstantValue::Int(b))) => {
240 Some(ConstantValue::Int(a - b))
241 }
242 (OpKind::Sub, Some(ConstantValue::Float(a)), Some(ConstantValue::Float(b))) => {
243 Some(ConstantValue::Float(a - b))
244 }
245 (OpKind::Mul, Some(ConstantValue::Int(a)), Some(ConstantValue::Int(b))) => {
246 Some(ConstantValue::Int(a * b))
247 }
248 (OpKind::Mul, Some(ConstantValue::Float(a)), Some(ConstantValue::Float(b))) => {
249 Some(ConstantValue::Float(a * b))
250 }
251 (OpKind::Div, Some(ConstantValue::Int(a)), Some(ConstantValue::Int(b))) if b != 0 => {
252 Some(ConstantValue::Int(a / b))
253 }
254 (OpKind::Div, Some(ConstantValue::Float(a)), Some(ConstantValue::Float(b)))
255 if b != 0.0 =>
256 {
257 Some(ConstantValue::Float(a / b))
258 }
259 _ => None,
260 };
261
262 if let Some(val) = result {
263 graph.nodes[i].op = OpKind::Constant(val);
264 graph.nodes[i].inputs.clear();
265 folded += 1;
266 }
267 }
268
269 if folded > 0 {
270 eliminate_dead_code(graph);
271 }
272
273 folded
274}
275
276fn get_constant(node: &IrNode) -> Option<ConstantValue> {
277 match &node.op {
278 OpKind::Constant(v) => Some(v.clone()),
279 _ => None,
280 }
281}
282
283pub fn eliminate_common_subexprs(graph: &mut IrGraph) -> usize {
290 let protected: HashSet<NodeId> = graph
293 .inputs
294 .iter()
295 .copied()
296 .chain(graph.params.iter().map(|p| p.node_id))
297 .collect();
298
299 let mut canonical: HashMap<OpSignature, NodeId> = HashMap::new();
300 let mut redirect: HashMap<NodeId, NodeId> = HashMap::new();
301
302 for node in &graph.nodes {
303 if has_side_effects(&node.op) || protected.contains(&node.id) {
305 continue;
306 }
307
308 let sig = OpSignature {
309 op: op_discriminant(&node.op),
310 inputs: node.inputs.clone(),
311 };
312
313 if let Some(&existing_id) = canonical.get(&sig) {
314 redirect.insert(node.id, existing_id);
315 } else {
316 canonical.insert(sig, node.id);
317 }
318 }
319
320 if redirect.is_empty() {
321 return 0;
322 }
323
324 let count = redirect.len();
325
326 for node in &mut graph.nodes {
328 for inp in &mut node.inputs {
329 if let Some(&target) = redirect.get(inp) {
330 *inp = target;
331 }
332 }
333 }
334
335 for out in &mut graph.outputs {
337 if let Some(&target) = redirect.get(&out.node_id) {
338 out.node_id = target;
339 }
340 }
341
342 eliminate_dead_code(graph);
343
344 count
345}
346
/// Hash key used by `eliminate_common_subexprs`. Among eligible nodes, those with
/// equal signatures are treated as the same computation and merged.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct OpSignature {
    /// Operation key produced by `op_discriminant` (op kind plus any attributes it embeds).
    op: String,
    /// Input node ids in order. Order is significant, so `[a, b]` and `[b, a]` differ.
    inputs: Vec<NodeId>,
}
353
354fn op_discriminant(op: &OpKind) -> String {
356 match op {
357 OpKind::Add => "add".into(),
358 OpKind::Sub => "sub".into(),
359 OpKind::Mul => "mul".into(),
360 OpKind::Div => "div".into(),
361 OpKind::Mod => "mod".into(),
362 OpKind::Pow => "pow".into(),
363 OpKind::MatMul => "matmul".into(),
364 OpKind::Neg => "neg".into(),
365 OpKind::Relu => "relu".into(),
366 OpKind::Gelu => "gelu".into(),
367 OpKind::Silu => "silu".into(),
368 OpKind::Sigmoid => "sigmoid".into(),
369 OpKind::Tanh => "tanh".into(),
370 OpKind::Exp => "exp".into(),
371 OpKind::Log => "log".into(),
372 OpKind::Sqrt => "sqrt".into(),
373 OpKind::Transpose => "transpose".into(),
374 OpKind::Not => "not".into(),
375 OpKind::Identity => "identity".into(),
376 OpKind::Softmax { dim } => format!("softmax_{dim}"),
377 OpKind::LayerNorm { eps } => format!("layernorm_{eps}"),
378 OpKind::BatchNorm { eps } => format!("batchnorm_{eps}"),
379 OpKind::Sum { dims, keepdim } => format!("sum_{dims:?}_{keepdim}"),
380 OpKind::Mean { dims, keepdim } => format!("mean_{dims:?}_{keepdim}"),
381 OpKind::Max { dim, keepdim } => format!("max_{dim}_{keepdim}"),
382 OpKind::Min { dim, keepdim } => format!("min_{dim}_{keepdim}"),
383 OpKind::Variance { dims, keepdim } => format!("var_{dims:?}_{keepdim}"),
384 OpKind::Dropout { p } => format!("dropout_{p}"),
385 OpKind::Constant(v) => format!("const_{v}"),
386 _ => format!("nocse_{op:?}"),
388 }
389}
390
391fn has_side_effects(op: &OpKind) -> bool {
393 matches!(
394 op,
395 OpKind::Dropout { .. } | OpKind::Custom { .. } | OpKind::Call { .. } )
399}
400
401pub fn fuse_operators(graph: &mut IrGraph) -> usize {
415 let mut fused = 0;
416 fused += fuse_matmul_add(graph);
417 fused += fuse_add_relu(graph);
418 fused += fuse_matmul_relu(graph);
419 fused
420}
421
422fn fuse_matmul_add(graph: &mut IrGraph) -> usize {
429 let mut fused = 0;
430 let output_nodes: HashSet<NodeId> = graph.outputs.iter().map(|o| o.node_id).collect();
431
432 let matmul_ids: HashSet<NodeId> = graph
434 .nodes
435 .iter()
436 .filter(|n| matches!(n.op, OpKind::MatMul))
437 .map(|n| n.id)
438 .collect();
439
440 let mut consumers: HashMap<NodeId, usize> = HashMap::new();
442 for node in &graph.nodes {
443 for &inp in &node.inputs {
444 *consumers.entry(inp).or_insert(0) += 1;
445 }
446 }
447
448 for i in 0..graph.nodes.len() {
449 let node = &graph.nodes[i];
450 if !matches!(node.op, OpKind::Add) || node.inputs.len() != 2 {
451 continue;
452 }
453
454 let first_inp = node.inputs[0];
455 let second_inp = node.inputs[1];
456
457 if matmul_ids.contains(&first_inp)
459 && consumers.get(&first_inp).copied().unwrap_or(0) == 1
460 && !output_nodes.contains(&first_inp)
461 {
462 let matmul_inputs = graph.nodes[first_inp.0].inputs.clone();
465 graph.nodes[i].op = OpKind::Custom {
466 name: "fused_matmul_add".to_string(),
467 attrs: HashMap::new(),
468 };
469 graph.nodes[i].inputs = vec![matmul_inputs[0], matmul_inputs[1], second_inp];
470 graph.nodes[i].name = format!("{}_fused_matmul_add", graph.nodes[i].name);
471 fused += 1;
472 }
473 }
474
475 if fused > 0 {
476 eliminate_dead_code(graph);
477 }
478 fused
479}
480
481fn fuse_add_relu(graph: &mut IrGraph) -> usize {
483 let mut fused = 0;
484 let output_nodes: HashSet<NodeId> = graph.outputs.iter().map(|o| o.node_id).collect();
485
486 let add_sub_ids: HashSet<NodeId> = graph
487 .nodes
488 .iter()
489 .filter(|n| matches!(n.op, OpKind::Add | OpKind::Sub))
490 .map(|n| n.id)
491 .collect();
492
493 let mut consumers: HashMap<NodeId, usize> = HashMap::new();
494 for node in &graph.nodes {
495 for &inp in &node.inputs {
496 *consumers.entry(inp).or_insert(0) += 1;
497 }
498 }
499
500 for i in 0..graph.nodes.len() {
501 let node = &graph.nodes[i];
502 if !matches!(node.op, OpKind::Relu) || node.inputs.len() != 1 {
503 continue;
504 }
505
506 let inp = node.inputs[0];
507 if add_sub_ids.contains(&inp)
508 && consumers.get(&inp).copied().unwrap_or(0) == 1
509 && !output_nodes.contains(&inp)
510 {
511 let is_add = matches!(graph.nodes[inp.0].op, OpKind::Add);
512 let fused_name = if is_add {
513 "fused_add_relu"
514 } else {
515 "fused_sub_relu"
516 };
517 let prev_inputs = graph.nodes[inp.0].inputs.clone();
518
519 graph.nodes[i].op = OpKind::Custom {
520 name: fused_name.to_string(),
521 attrs: HashMap::new(),
522 };
523 graph.nodes[i].inputs = prev_inputs;
524 graph.nodes[i].name = format!("{}_fused", graph.nodes[i].name);
525 fused += 1;
526 }
527 }
528
529 if fused > 0 {
530 eliminate_dead_code(graph);
531 }
532 fused
533}
534
535fn fuse_matmul_relu(graph: &mut IrGraph) -> usize {
537 let mut fused = 0;
538 let output_nodes: HashSet<NodeId> = graph.outputs.iter().map(|o| o.node_id).collect();
539
540 let matmul_ids: HashSet<NodeId> = graph
541 .nodes
542 .iter()
543 .filter(|n| matches!(n.op, OpKind::MatMul))
544 .map(|n| n.id)
545 .collect();
546
547 let mut consumers: HashMap<NodeId, usize> = HashMap::new();
548 for node in &graph.nodes {
549 for &inp in &node.inputs {
550 *consumers.entry(inp).or_insert(0) += 1;
551 }
552 }
553
554 for i in 0..graph.nodes.len() {
555 let node = &graph.nodes[i];
556 if !matches!(node.op, OpKind::Relu) || node.inputs.len() != 1 {
557 continue;
558 }
559
560 let inp = node.inputs[0];
561 if matmul_ids.contains(&inp)
562 && consumers.get(&inp).copied().unwrap_or(0) == 1
563 && !output_nodes.contains(&inp)
564 {
565 let prev_inputs = graph.nodes[inp.0].inputs.clone();
566 graph.nodes[i].op = OpKind::Custom {
567 name: "fused_matmul_relu".to_string(),
568 attrs: HashMap::new(),
569 };
570 graph.nodes[i].inputs = prev_inputs;
571 graph.nodes[i].name = format!("{}_fused", graph.nodes[i].name);
572 fused += 1;
573 }
574 }
575
576 if fused > 0 {
577 eliminate_dead_code(graph);
578 }
579 fused
580}
581
/// Per-pass change counts accumulated by `optimize_graph_with_stats`.
///
/// Each field sums the value returned by the corresponding pass when the driver
/// calls it directly. Nodes removed by the dead-code elimination that other passes
/// run internally are not added to `dead_code_removed`.
#[derive(Debug, Clone, Default)]
pub struct OptStats {
    /// Sum of `eliminate_dead_code` results.
    pub dead_code_removed: usize,
    /// Sum of `eliminate_identities` results.
    pub identities_removed: usize,
    /// Sum of `fold_constants` results.
    pub constants_folded: usize,
    /// Sum of `eliminate_common_subexprs` results.
    pub cse_eliminated: usize,
    /// Sum of `fuse_operators` results.
    pub ops_fused: usize,
}

impl std::fmt::Display for OptStats {
    /// Renders the counts on one line using short labels, e.g.
    /// `OptStats { dce: 0, identity: 0, const_fold: 0, cse: 0, fusion: 0 }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            dead_code_removed,
            identities_removed,
            constants_folded,
            cse_eliminated,
            ops_fused,
        } = self;
        write!(
            f,
            "OptStats {{ dce: {}, identity: {}, const_fold: {}, cse: {}, fusion: {} }}",
            dead_code_removed, identities_removed, constants_folded, cse_eliminated, ops_fused,
        )
    }
}
607
608pub fn optimize_graph_with_stats(graph: &mut IrGraph) -> OptStats {
610 let mut stats = OptStats::default();
611 loop {
612 let dce = eliminate_dead_code(graph);
613 let ident = eliminate_identities(graph);
614 let cf = fold_constants(graph);
615 let cse = eliminate_common_subexprs(graph);
616 let fus = fuse_operators(graph);
617
618 stats.dead_code_removed += dce;
619 stats.identities_removed += ident;
620 stats.constants_folded += cf;
621 stats.cse_eliminated += cse;
622 stats.ops_fused += fus;
623
624 if dce + ident + cf + cse + fus == 0 {
625 break;
626 }
627 }
628 stats
629}