// shrew_ir/optimize.rs

1// Graph Optimization — Transformations that simplify and optimize the IR graph
2//
3// These passes run after validation and shape inference, transforming the graph
4// to reduce computation and memory usage. Each pass is idempotent and composable.
5//
6// Implemented passes:
7//   1. Dead Code Elimination (DCE) — remove nodes not reachable from outputs
8//   2. Identity Elimination — remove pass-through identity nodes
9//   3. Constant Folding — evaluate constant sub-expressions at compile time
10//   4. Common Sub-expression Elimination (CSE) — share identical computations
11//   5. Operator Fusion — merge compatible adjacent operations
12//
13// The optimize() function runs all passes in a fixed-point loop until
14// no more transformations apply.
15
16use crate::graph::*;
17use std::collections::{HashMap, HashSet};
18
19// Public API
20
21/// Run all optimization passes on every graph in the program.
22/// Returns the total number of transformations applied.
23pub fn optimize(program: &mut IrProgram) -> usize {
24    let mut total = 0;
25    for graph in &mut program.graphs {
26        total += optimize_graph(graph);
27    }
28    total
29}
30
31/// Run all optimization passes on a single graph.
32/// Runs passes in a loop until convergence (fixed point).
33pub fn optimize_graph(graph: &mut IrGraph) -> usize {
34    let mut total = 0;
35    loop {
36        let mut changed = 0;
37        changed += eliminate_dead_code(graph);
38        changed += eliminate_identities(graph);
39        changed += fold_constants(graph);
40        changed += eliminate_common_subexprs(graph);
41        changed += fuse_operators(graph);
42        if changed == 0 {
43            break;
44        }
45        total += changed;
46    }
47    total
48}
49
50// Pass 1: Dead Code Elimination
51//
52// Walk backward from output nodes; any node not reachable is dead.
53// Dead nodes are replaced with no-op markers and then compacted.
54
55/// Remove nodes not reachable from any output. Returns count of removed nodes.
56pub fn eliminate_dead_code(graph: &mut IrGraph) -> usize {
57    if graph.nodes.is_empty() {
58        return 0;
59    }
60
61    // Find all reachable nodes by walking backwards from outputs
62    let mut reachable = HashSet::new();
63    let mut stack: Vec<NodeId> = graph.outputs.iter().map(|o| o.node_id).collect();
64
65    // Also keep param nodes alive
66    for param in &graph.params {
67        stack.push(param.node_id);
68    }
69
70    while let Some(id) = stack.pop() {
71        if !reachable.insert(id) {
72            continue;
73        }
74        if id.0 < graph.nodes.len() {
75            for &inp in &graph.nodes[id.0].inputs {
76                stack.push(inp);
77            }
78        }
79    }
80
81    let total = graph.nodes.len();
82    let dead_count = total - reachable.len();
83    if dead_count == 0 {
84        return 0;
85    }
86
87    // Build compaction map: old_id → new_id
88    let mut keep: Vec<bool> = vec![false; total];
89    for &id in &reachable {
90        keep[id.0] = true;
91    }
92
93    let mut old_to_new: Vec<Option<NodeId>> = vec![None; total];
94    let mut new_id = 0usize;
95    for old_id in 0..total {
96        if keep[old_id] {
97            old_to_new[old_id] = Some(NodeId(new_id));
98            new_id += 1;
99        }
100    }
101
102    // Compact nodes
103    let mut new_nodes = Vec::with_capacity(reachable.len());
104    for (old_id, node) in graph.nodes.drain(..).enumerate() {
105        if let Some(nid) = old_to_new[old_id] {
106            let mut node = node;
107            node.id = nid;
108            node.inputs = node
109                .inputs
110                .iter()
111                .filter_map(|&inp| old_to_new[inp.0])
112                .collect();
113            new_nodes.push(node);
114        }
115    }
116    graph.nodes = new_nodes;
117
118    // Remap inputs, outputs, params, name_to_id
119    graph.inputs = graph
120        .inputs
121        .iter()
122        .filter_map(|&id| old_to_new[id.0])
123        .collect();
124    graph.outputs.retain(|o| old_to_new[o.node_id.0].is_some());
125    for out in &mut graph.outputs {
126        if let Some(new) = old_to_new[out.node_id.0] {
127            out.node_id = new;
128        }
129    }
130
131    for param in &mut graph.params {
132        if let Some(new) = old_to_new[param.node_id.0] {
133            param.node_id = new;
134        }
135    }
136    graph.params.retain(|p| old_to_new[p.node_id.0].is_some());
137
138    // Rebuild name_to_id
139    graph.name_to_id.clear();
140    for node in &graph.nodes {
141        graph.name_to_id.insert(node.name.clone(), node.id);
142    }
143
144    dead_count
145}
146
147// Pass 2: Identity Elimination
148//
149// Identity nodes that aren't graph inputs or params can be collapsed: redirect
150// all consumers to use the identity's input directly.
151
152/// Eliminate redundant identity nodes. Returns count of identities removed.
153pub fn eliminate_identities(graph: &mut IrGraph) -> usize {
154    // Find identity nodes that can be removed
155    // Keep: input identities, param identities, identities without inputs
156    let input_set: HashSet<NodeId> = graph.inputs.iter().copied().collect();
157    let param_set: HashSet<NodeId> = graph.params.iter().map(|p| p.node_id).collect();
158
159    // Build identity map: node_id → its single input
160    let mut identity_map: HashMap<NodeId, NodeId> = HashMap::new();
161    for node in &graph.nodes {
162        if matches!(node.op, OpKind::Identity)
163            && node.inputs.len() == 1
164            && !input_set.contains(&node.id)
165            && !param_set.contains(&node.id)
166        {
167            identity_map.insert(node.id, node.inputs[0]);
168        }
169    }
170
171    if identity_map.is_empty() {
172        return 0;
173    }
174
175    // Resolve transitive chains: a→b→c → a→c
176    let mut resolved: HashMap<NodeId, NodeId> = HashMap::new();
177    for &id in identity_map.keys() {
178        let mut target = id;
179        let mut visited = HashSet::new();
180        while let Some(&next) = identity_map.get(&target) {
181            if !visited.insert(target) {
182                break; // cycle guard
183            }
184            target = next;
185        }
186        resolved.insert(id, target);
187    }
188
189    let count = resolved.len();
190
191    // Rewrite all node inputs
192    for node in &mut graph.nodes {
193        for inp in &mut node.inputs {
194            if let Some(&target) = resolved.get(inp) {
195                *inp = target;
196            }
197        }
198    }
199
200    // Rewrite graph outputs
201    for out in &mut graph.outputs {
202        if let Some(&target) = resolved.get(&out.node_id) {
203            out.node_id = target;
204        }
205    }
206
207    // Now run DCE to actually remove the unused identity nodes
208    eliminate_dead_code(graph);
209
210    count
211}
212
213// Pass 3: Constant Folding
214//
215// If a binary op has two constant inputs, evaluate the result at compile time.
216
217/// Fold constant expressions. Returns count of folded nodes.
218pub fn fold_constants(graph: &mut IrGraph) -> usize {
219    let mut folded = 0;
220
221    for i in 0..graph.nodes.len() {
222        let node = &graph.nodes[i];
223
224        // Only fold binary ops with two constant inputs
225        if node.inputs.len() != 2 {
226            continue;
227        }
228
229        let left_const = get_constant(&graph.nodes[node.inputs[0].0]);
230        let right_const = get_constant(&graph.nodes[node.inputs[1].0]);
231
232        let result = match (&node.op, left_const, right_const) {
233            (OpKind::Add, Some(ConstantValue::Int(a)), Some(ConstantValue::Int(b))) => {
234                Some(ConstantValue::Int(a + b))
235            }
236            (OpKind::Add, Some(ConstantValue::Float(a)), Some(ConstantValue::Float(b))) => {
237                Some(ConstantValue::Float(a + b))
238            }
239            (OpKind::Sub, Some(ConstantValue::Int(a)), Some(ConstantValue::Int(b))) => {
240                Some(ConstantValue::Int(a - b))
241            }
242            (OpKind::Sub, Some(ConstantValue::Float(a)), Some(ConstantValue::Float(b))) => {
243                Some(ConstantValue::Float(a - b))
244            }
245            (OpKind::Mul, Some(ConstantValue::Int(a)), Some(ConstantValue::Int(b))) => {
246                Some(ConstantValue::Int(a * b))
247            }
248            (OpKind::Mul, Some(ConstantValue::Float(a)), Some(ConstantValue::Float(b))) => {
249                Some(ConstantValue::Float(a * b))
250            }
251            (OpKind::Div, Some(ConstantValue::Int(a)), Some(ConstantValue::Int(b))) if b != 0 => {
252                Some(ConstantValue::Int(a / b))
253            }
254            (OpKind::Div, Some(ConstantValue::Float(a)), Some(ConstantValue::Float(b)))
255                if b != 0.0 =>
256            {
257                Some(ConstantValue::Float(a / b))
258            }
259            _ => None,
260        };
261
262        if let Some(val) = result {
263            graph.nodes[i].op = OpKind::Constant(val);
264            graph.nodes[i].inputs.clear();
265            folded += 1;
266        }
267    }
268
269    if folded > 0 {
270        eliminate_dead_code(graph);
271    }
272
273    folded
274}
275
276fn get_constant(node: &IrNode) -> Option<ConstantValue> {
277    match &node.op {
278        OpKind::Constant(v) => Some(v.clone()),
279        _ => None,
280    }
281}
282
283// Pass 4: Common Sub-expression Elimination (CSE)
284//
285// Two nodes with the same op and same inputs produce the same result.
286// Keep the first, redirect consumers of the second.
287
288/// Eliminate common sub-expressions. Returns count of CSE-eliminated nodes.
289pub fn eliminate_common_subexprs(graph: &mut IrGraph) -> usize {
290    // Collect input and param node IDs — these are semantically unique even
291    // when they share the same OpSignature (e.g. Identity with no inputs).
292    let protected: HashSet<NodeId> = graph
293        .inputs
294        .iter()
295        .copied()
296        .chain(graph.params.iter().map(|p| p.node_id))
297        .collect();
298
299    let mut canonical: HashMap<OpSignature, NodeId> = HashMap::new();
300    let mut redirect: HashMap<NodeId, NodeId> = HashMap::new();
301
302    for node in &graph.nodes {
303        // Don't CSE nodes with side effects or that are inputs/params
304        if has_side_effects(&node.op) || protected.contains(&node.id) {
305            continue;
306        }
307
308        let sig = OpSignature {
309            op: op_discriminant(&node.op),
310            inputs: node.inputs.clone(),
311        };
312
313        if let Some(&existing_id) = canonical.get(&sig) {
314            redirect.insert(node.id, existing_id);
315        } else {
316            canonical.insert(sig, node.id);
317        }
318    }
319
320    if redirect.is_empty() {
321        return 0;
322    }
323
324    let count = redirect.len();
325
326    // Rewrite inputs
327    for node in &mut graph.nodes {
328        for inp in &mut node.inputs {
329            if let Some(&target) = redirect.get(inp) {
330                *inp = target;
331            }
332        }
333    }
334
335    // Rewrite outputs
336    for out in &mut graph.outputs {
337        if let Some(&target) = redirect.get(&out.node_id) {
338            out.node_id = target;
339        }
340    }
341
342    eliminate_dead_code(graph);
343
344    count
345}
346
/// A signature for CSE comparison.
///
/// Two nodes with equal signatures compute the same value and can be merged.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct OpSignature {
    // Deterministic op key from op_discriminant() (op kind plus attributes).
    op: String,
    // Input node ids in order; operand order matters for non-commutative ops.
    inputs: Vec<NodeId>,
}
353
354/// Get a deterministic string key for an op (including parameters).
355fn op_discriminant(op: &OpKind) -> String {
356    match op {
357        OpKind::Add => "add".into(),
358        OpKind::Sub => "sub".into(),
359        OpKind::Mul => "mul".into(),
360        OpKind::Div => "div".into(),
361        OpKind::Mod => "mod".into(),
362        OpKind::Pow => "pow".into(),
363        OpKind::MatMul => "matmul".into(),
364        OpKind::Neg => "neg".into(),
365        OpKind::Relu => "relu".into(),
366        OpKind::Gelu => "gelu".into(),
367        OpKind::Silu => "silu".into(),
368        OpKind::Sigmoid => "sigmoid".into(),
369        OpKind::Tanh => "tanh".into(),
370        OpKind::Exp => "exp".into(),
371        OpKind::Log => "log".into(),
372        OpKind::Sqrt => "sqrt".into(),
373        OpKind::Transpose => "transpose".into(),
374        OpKind::Not => "not".into(),
375        OpKind::Identity => "identity".into(),
376        OpKind::Softmax { dim } => format!("softmax_{dim}"),
377        OpKind::LayerNorm { eps } => format!("layernorm_{eps}"),
378        OpKind::BatchNorm { eps } => format!("batchnorm_{eps}"),
379        OpKind::Sum { dims, keepdim } => format!("sum_{dims:?}_{keepdim}"),
380        OpKind::Mean { dims, keepdim } => format!("mean_{dims:?}_{keepdim}"),
381        OpKind::Max { dim, keepdim } => format!("max_{dim}_{keepdim}"),
382        OpKind::Min { dim, keepdim } => format!("min_{dim}_{keepdim}"),
383        OpKind::Variance { dims, keepdim } => format!("var_{dims:?}_{keepdim}"),
384        OpKind::Dropout { p } => format!("dropout_{p}"),
385        OpKind::Constant(v) => format!("const_{v}"),
386        // Ops with side effects or custom behavior shouldn't CSE
387        _ => format!("nocse_{op:?}"),
388    }
389}
390
391/// Check if an op has side effects (shouldn't be CSE'd).
392fn has_side_effects(op: &OpKind) -> bool {
393    matches!(
394        op,
395        OpKind::Dropout { .. }  // Dropout is random
396        | OpKind::Custom { .. } // Custom ops may have side effects
397        | OpKind::Call { .. } // Graph calls may have side effects
398    )
399}
400
401// Pass 5: Operator Fusion
402//
403// Fuse sequences of compatible operations into single fused ops:
404//
405//   MatMul + Add(bias)  → Linear { bias: true }
406//   Add + Relu          → FusedAddRelu (custom op)
407//   Sub + Relu          → FusedSubRelu
408//   MatMul + Relu       → FusedMatMulRelu
409//   Variance + Sqrt⁻¹ + Mul + Add  → LayerNorm (recognized pattern)
410//
411// Fusion reduces kernel launches and memory traffic on GPU backends.
412
413/// Fuse compatible operator sequences. Returns count of fusions applied.
414pub fn fuse_operators(graph: &mut IrGraph) -> usize {
415    let mut fused = 0;
416    fused += fuse_matmul_add(graph);
417    fused += fuse_add_relu(graph);
418    fused += fuse_matmul_relu(graph);
419    fused
420}
421
422/// Fuse MatMul(x, w) → Add(bias) into fused_matmul_add custom op.
423///
424/// Note: we do NOT fuse into OpKind::Linear because Linear::forward()
425/// transposes the weight (expects [out, in] layout), whereas raw matmul
426/// in .sw uses [in, out]. Using a custom op preserves the original semantics
427/// — a.matmul(b) + c with no implicit transpose.
428fn fuse_matmul_add(graph: &mut IrGraph) -> usize {
429    let mut fused = 0;
430    let output_nodes: HashSet<NodeId> = graph.outputs.iter().map(|o| o.node_id).collect();
431
432    // Find Add nodes whose first input comes from a MatMul
433    let matmul_ids: HashSet<NodeId> = graph
434        .nodes
435        .iter()
436        .filter(|n| matches!(n.op, OpKind::MatMul))
437        .map(|n| n.id)
438        .collect();
439
440    // Build consumer count: how many nodes consume each node
441    let mut consumers: HashMap<NodeId, usize> = HashMap::new();
442    for node in &graph.nodes {
443        for &inp in &node.inputs {
444            *consumers.entry(inp).or_insert(0) += 1;
445        }
446    }
447
448    for i in 0..graph.nodes.len() {
449        let node = &graph.nodes[i];
450        if !matches!(node.op, OpKind::Add) || node.inputs.len() != 2 {
451            continue;
452        }
453
454        let first_inp = node.inputs[0];
455        let second_inp = node.inputs[1];
456
457        // Pattern: Add(MatMul(x, w), bias) where MatMul has only 1 consumer
458        if matmul_ids.contains(&first_inp)
459            && consumers.get(&first_inp).copied().unwrap_or(0) == 1
460            && !output_nodes.contains(&first_inp)
461        {
462            // Fuse: replace Add node -> fused_matmul_add custom op
463            // Inputs become: [x, w, bias] from MatMul's [x, w] + Add's second input
464            let matmul_inputs = graph.nodes[first_inp.0].inputs.clone();
465            graph.nodes[i].op = OpKind::Custom {
466                name: "fused_matmul_add".to_string(),
467                attrs: HashMap::new(),
468            };
469            graph.nodes[i].inputs = vec![matmul_inputs[0], matmul_inputs[1], second_inp];
470            graph.nodes[i].name = format!("{}_fused_matmul_add", graph.nodes[i].name);
471            fused += 1;
472        }
473    }
474
475    if fused > 0 {
476        eliminate_dead_code(graph);
477    }
478    fused
479}
480
481/// Fuse Add/Sub + Relu into a single fused custom op.
482fn fuse_add_relu(graph: &mut IrGraph) -> usize {
483    let mut fused = 0;
484    let output_nodes: HashSet<NodeId> = graph.outputs.iter().map(|o| o.node_id).collect();
485
486    let add_sub_ids: HashSet<NodeId> = graph
487        .nodes
488        .iter()
489        .filter(|n| matches!(n.op, OpKind::Add | OpKind::Sub))
490        .map(|n| n.id)
491        .collect();
492
493    let mut consumers: HashMap<NodeId, usize> = HashMap::new();
494    for node in &graph.nodes {
495        for &inp in &node.inputs {
496            *consumers.entry(inp).or_insert(0) += 1;
497        }
498    }
499
500    for i in 0..graph.nodes.len() {
501        let node = &graph.nodes[i];
502        if !matches!(node.op, OpKind::Relu) || node.inputs.len() != 1 {
503            continue;
504        }
505
506        let inp = node.inputs[0];
507        if add_sub_ids.contains(&inp)
508            && consumers.get(&inp).copied().unwrap_or(0) == 1
509            && !output_nodes.contains(&inp)
510        {
511            let is_add = matches!(graph.nodes[inp.0].op, OpKind::Add);
512            let fused_name = if is_add {
513                "fused_add_relu"
514            } else {
515                "fused_sub_relu"
516            };
517            let prev_inputs = graph.nodes[inp.0].inputs.clone();
518
519            graph.nodes[i].op = OpKind::Custom {
520                name: fused_name.to_string(),
521                attrs: HashMap::new(),
522            };
523            graph.nodes[i].inputs = prev_inputs;
524            graph.nodes[i].name = format!("{}_fused", graph.nodes[i].name);
525            fused += 1;
526        }
527    }
528
529    if fused > 0 {
530        eliminate_dead_code(graph);
531    }
532    fused
533}
534
535/// Fuse MatMul + Relu into fused_matmul_relu.
536fn fuse_matmul_relu(graph: &mut IrGraph) -> usize {
537    let mut fused = 0;
538    let output_nodes: HashSet<NodeId> = graph.outputs.iter().map(|o| o.node_id).collect();
539
540    let matmul_ids: HashSet<NodeId> = graph
541        .nodes
542        .iter()
543        .filter(|n| matches!(n.op, OpKind::MatMul))
544        .map(|n| n.id)
545        .collect();
546
547    let mut consumers: HashMap<NodeId, usize> = HashMap::new();
548    for node in &graph.nodes {
549        for &inp in &node.inputs {
550            *consumers.entry(inp).or_insert(0) += 1;
551        }
552    }
553
554    for i in 0..graph.nodes.len() {
555        let node = &graph.nodes[i];
556        if !matches!(node.op, OpKind::Relu) || node.inputs.len() != 1 {
557            continue;
558        }
559
560        let inp = node.inputs[0];
561        if matmul_ids.contains(&inp)
562            && consumers.get(&inp).copied().unwrap_or(0) == 1
563            && !output_nodes.contains(&inp)
564        {
565            let prev_inputs = graph.nodes[inp.0].inputs.clone();
566            graph.nodes[i].op = OpKind::Custom {
567                name: "fused_matmul_relu".to_string(),
568                attrs: HashMap::new(),
569            };
570            graph.nodes[i].inputs = prev_inputs;
571            graph.nodes[i].name = format!("{}_fused", graph.nodes[i].name);
572            fused += 1;
573        }
574    }
575
576    if fused > 0 {
577        eliminate_dead_code(graph);
578    }
579    fused
580}
581
582// Convenience: Run specific passes
583
/// Statistics from optimization.
///
/// Accumulated counters, one per pass, summed across all fixed-point rounds.
#[derive(Debug, Clone, Default)]
pub struct OptStats {
    /// Nodes removed by dead code elimination.
    pub dead_code_removed: usize,
    /// Identity nodes collapsed by identity elimination.
    pub identities_removed: usize,
    /// Nodes rewritten to constants by constant folding.
    pub constants_folded: usize,
    /// Duplicate nodes merged by common sub-expression elimination.
    pub cse_eliminated: usize,
    /// Operator sequences merged by fusion.
    pub ops_fused: usize,
}
593
594impl std::fmt::Display for OptStats {
595    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596        write!(
597            f,
598            "OptStats {{ dce: {}, identity: {}, const_fold: {}, cse: {}, fusion: {} }}",
599            self.dead_code_removed,
600            self.identities_removed,
601            self.constants_folded,
602            self.cse_eliminated,
603            self.ops_fused,
604        )
605    }
606}
607
608/// Run all passes with detailed statistics.
609pub fn optimize_graph_with_stats(graph: &mut IrGraph) -> OptStats {
610    let mut stats = OptStats::default();
611    loop {
612        let dce = eliminate_dead_code(graph);
613        let ident = eliminate_identities(graph);
614        let cf = fold_constants(graph);
615        let cse = eliminate_common_subexprs(graph);
616        let fus = fuse_operators(graph);
617
618        stats.dead_code_removed += dce;
619        stats.identities_removed += ident;
620        stats.constants_folded += cf;
621        stats.cse_eliminated += cse;
622        stats.ops_fused += fus;
623
624        if dce + ident + cf + cse + fus == 0 {
625            break;
626        }
627    }
628    stats
629}