shrew/exec/jit.rs

1// =============================================================================
2// JIT Graph Compilation — Compile IR graphs into optimized execution plans
3// =============================================================================
4//
5// The default `Executor` interprets the IR graph on every call:
6//   - Recomputes topological order each time
7//   - Looks up each node in a HashMap
8//   - Matches on OpKind (50+ variants) per node
9//   - Allocates intermediates into a HashMap with no reuse
10//
11// The JIT compiler transforms the graph into a pre-compiled execution plan
12// that eliminates most of this overhead:
13//
14// COMPONENTS:
15//
16//   CompiledGraph    — The compiled execution plan for one IR graph
17//   Instruction      — A single operation in the compiled plan
18//   MemoryPlan       — Buffer lifecycle analysis and reuse
19//   SlotRef          — A reference to a buffer slot (register analogy)
20//   JitExecutor      — Runs compiled graphs instead of re-interpreting
21//   CompileStats     — Compilation statistics
22//
23// WORKFLOW:
24//
25//   1. Compile:    JitExecutor::compile(program) → JitExecutor
26//   2. Run:        executor.run("Forward", &inputs) → JitResult
27//   3. Recompile:  executor.recompile("Forward") — after graph changes
28//
29// OPTIMIZATIONS:
30//
31//   - Pre-computed topological order (computed once at compile time)
32//   - Instruction tape (flat Vec<Instruction>, no HashMap lookup)
33//   - Memory planning: per-value liveness analysis (physical slot reuse not yet implemented)
34//   - Dead value early-free (values dropped as soon as last consumer is done)
35//   - Fused dispatch (fused ops become single instructions)
36//   - Input/param slot assignments pre-computed (one name lookup per load, not per use)
37
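// EXAMPLE (illustrative sketch, mirroring the `# Usage` doc on `JitExecutor`
// below; `CpuBackend`, `program`, `device`, `config`, and `inputs` are assumed
// to be set up by the caller):
//
//   let jit = JitExecutor::<CpuBackend>::compile(program, device, config)?;
//   let out = jit.run("Forward", &inputs)?;          // no re-interpretation per call
//   println!("{}", jit.stats("Forward").unwrap());   // CompileStats summary
//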
38use std::collections::HashMap;
39use std::fmt;
40use std::time::Instant;
41
42use shrew_core::backend::Backend;
43use shrew_core::dtype::DType as CoreDType;
44use shrew_core::error::Result;
45use shrew_core::tensor::Tensor;
46
47use shrew_ir::graph::{ConstantValue, Dim, IrGraph, IrNode, IrProgram, IrType, NodeId, OpKind};
48
49use shrew_nn::Module;
50
51use super::engine::{ir_dtype_to_core, RuntimeConfig};
52
53// =============================================================================
54// Instruction — A single pre-compiled operation
55// =============================================================================
56
57/// Specifies which buffer slot an instruction reads from.
58#[derive(Debug, Clone)]
59pub struct SlotRef {
60    /// Index into the buffer table.
61    pub slot: usize,
62}
63
64/// A single operation in the compiled execution plan.
65///
66/// Unlike the interpreter, instructions are a flat enum with pre-resolved
67/// input/output buffer slots — name lookups happen only in LoadInput/LoadParam.
68#[derive(Debug, Clone)]
69pub enum Instruction {
70    // ── Source instructions (produce values from external sources) ──
71    /// Load a graph input into a buffer slot.
72    LoadInput {
73        /// Input name (for lookup in the provided HashMap).
74        name: String,
75        /// Buffer slot to store the input tensor.
76        dst: usize,
77    },
78    /// Load a parameter into a buffer slot.
79    LoadParam {
80        /// Key: (graph_name, param_name).
81        graph_name: String,
82        param_name: String,
83        /// Buffer slot to store the parameter tensor.
84        dst: usize,
85    },
86
87    // ── Unary operations ──
88    Unary {
89        op: UnaryInstr,
90        src: usize,
91        dst: usize,
92    },
93
94    // ── Binary operations ──
95    Binary {
96        op: BinaryInstr,
97        lhs: usize,
98        rhs: usize,
99        dst: usize,
100    },
101
102    // ── Reduction operations ──
103    Reduce {
104        op: ReduceInstr,
105        src: usize,
106        dst: usize,
107        dims: Vec<i64>,
108        keepdim: bool,
109    },
110
111    // ── Shape operations ──
112    Reshape {
113        src: usize,
114        dst: usize,
115        /// Pre-resolved concrete shape (symbolic dims resolved at compile time).
116        shape: Vec<usize>,
117    },
118    Transpose {
119        src: usize,
120        dst: usize,
121    },
122    Permute {
123        src: usize,
124        dst: usize,
125        dims: Vec<i64>,
126    },
127    Expand {
128        src: usize,
129        dst: usize,
130        shape: Vec<usize>,
131    },
132    Concat {
133        srcs: Vec<usize>,
134        dst: usize,
135        dim: usize,
136    },
137    Split {
138        src: usize,
139        dst: usize,
140        dim: usize,
141        chunks: usize,
142    },
143
144    // ── Neural network operations ──
145    Softmax {
146        src: usize,
147        dst: usize,
148        dim: usize,
149    },
150    Embedding {
151        indices: usize,
152        table: usize,
153        dst: usize,
154    },
155    Linear {
156        input: usize,
157        weight: usize,
158        bias: Option<usize>,
159        dst: usize,
160    },
161    LayerNorm {
162        input: usize,
163        weight: usize,
164        bias: usize,
165        dst: usize,
166        eps: f64,
167    },
168    BatchNorm {
169        input: usize,
170        weight: Option<usize>,
171        bias: Option<usize>,
172        dst: usize,
173        eps: f64,
174    },
175    MultiHeadAttention {
176        input: usize,
177        dst: usize,
178        n_heads: usize,
179    },
180    TransformerBlock {
181        input: usize,
182        dst: usize,
183        n_heads: usize,
184    },
185    Dropout {
186        src: usize,
187        dst: usize,
188        p: f64,
189    },
190
191    // ── Loss functions ──
192    CrossEntropy {
193        predictions: usize,
194        targets: usize,
195        dst: usize,
196    },
197    MseLoss {
198        predictions: usize,
199        targets: usize,
200        dst: usize,
201    },
202
203    // ── Constants ──
204    Constant {
205        value: ConstantValue,
206        output_type: IrType,
207        dst: usize,
208    },
209
210    // ── Control flow ──
211    Repeat {
212        count: i64,
213        body_op: Box<OpKind>,
214        src: usize,
215        dst: usize,
216    },
217    Call {
218        graph_name: String,
219        inputs: Vec<usize>,
220        dst: usize,
221    },
222
223    // ── Comparison / logical ──
224    Compare {
225        op: CompareInstr,
226        lhs: usize,
227        rhs: usize,
228        dst: usize,
229    },
230    LogicalNot {
231        src: usize,
232        dst: usize,
233    },
234    LogicalBinOp {
235        op: LogicalBinInstr,
236        lhs: usize,
237        rhs: usize,
238        dst: usize,
239    },
240
241    // ── Fused operations (from IR optimizer) ──
242    FusedMatMulAdd {
243        a: usize,
244        b: usize,
245        c: usize,
246        dst: usize,
247    },
248    FusedAddRelu {
249        lhs: usize,
250        rhs: usize,
251        dst: usize,
252    },
253    FusedSubRelu {
254        lhs: usize,
255        rhs: usize,
256        dst: usize,
257    },
258    FusedMatMulRelu {
259        lhs: usize,
260        rhs: usize,
261        dst: usize,
262    },
263
264    // ── Identity (pass-through) ──
265    Copy {
266        src: usize,
267        dst: usize,
268    },
269
270    // ── Range ──
271    Range {
272        inputs: Vec<usize>,
273        output_type: IrType,
274        dst: usize,
275    },
276
277    // ── Free a buffer slot (dead value elimination) ──
278    Free {
279        slot: usize,
280    },
281}
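
// Illustrative lowering (a sketch, not produced verbatim by this file; see
// `compile_node` below for the exact rules). A graph computing
// `y = relu(matmul(x, w) + b)`, where `x` is a graph input and `w`, `b` are
// params of graph "Forward", compiles (with no fusion pass) to roughly:
//
//   [0] LoadInput { name: "x", dst: 0 }
//   [1] LoadParam { graph_name: "Forward", param_name: "w", dst: 1 }
//   [2] LoadParam { graph_name: "Forward", param_name: "b", dst: 2 }
//   [3] Binary    { op: MatMul, lhs: 0, rhs: 1, dst: 3 }
//   [4] Binary    { op: Add, lhs: 3, rhs: 2, dst: 4 }
//   [5] Free      { slot: 3 }   // matmul result is dead after [4]
//   [6] Unary     { op: Relu, src: 4, dst: 5 }
//
// Inputs and params are externally owned and never freed; slot 4 is last used
// by the final instruction, so no Free is emitted for it either.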
282
283/// Unary operation variants (pre-dispatched).
284#[derive(Debug, Clone, Copy)]
285pub enum UnaryInstr {
286    Neg,
287    Relu,
288    Gelu,
289    Silu,
290    Sigmoid,
291    Tanh,
292    Exp,
293    Log,
294    Sqrt,
295}
296
297/// Binary operation variants (pre-dispatched).
298#[derive(Debug, Clone, Copy)]
299pub enum BinaryInstr {
300    Add,
301    Sub,
302    Mul,
303    Div,
304    MatMul,
305    Pow,
306    Mod,
307}
308
309/// Reduction operation variants.
310#[derive(Debug, Clone, Copy)]
311pub enum ReduceInstr {
312    Sum,
313    Mean,
314    Max,
315    Min,
316    Variance,
317}
318
319/// Comparison operation variants.
320#[derive(Debug, Clone, Copy)]
321pub enum CompareInstr {
322    Equal,
323    NotEqual,
324    Less,
325    Greater,
326    LessEqual,
327    GreaterEqual,
328}
329
330/// Logical binary operation variants.
331#[derive(Debug, Clone, Copy)]
332pub enum LogicalBinInstr {
333    And,
334    Or,
335}
336
337// =============================================================================
338// MemoryPlan — Buffer lifecycle analysis
339// =============================================================================
340
341/// Tracks when each value is first produced and last consumed.
342#[derive(Debug, Clone)]
343pub struct ValueLifetime {
344    /// Instruction index where this value is produced.
345    pub produced_at: usize,
346    /// Instruction index where this value is last consumed (inclusive).
347    pub last_used_at: usize,
348    /// Node ID from the original graph.
349    pub node_id: NodeId,
350    /// Whether this value is a graph output (must not be freed).
351    pub is_output: bool,
352    /// Whether this value is an input or parameter (externally owned).
353    pub is_external: bool,
354}
355
356/// The memory plan for a compiled graph — maps NodeIds to buffer slots
357/// and tracks lifetimes for dead value elimination.
358#[derive(Debug, Clone)]
359pub struct MemoryPlan {
360    /// Number of buffer slots needed.
361    pub num_slots: usize,
362    /// Mapping from NodeId → buffer slot.
363    pub node_to_slot: HashMap<usize, usize>,
364    /// Lifetime of each slot.
365    pub lifetimes: Vec<ValueLifetime>,
366    /// Free instructions to insert (slot, after_instruction_idx).
367    pub free_points: Vec<(usize, usize)>,
368    /// Number of buffers reused.
369    pub reuse_count: usize,
370}
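
// For the lowering sketch after `Instruction` above, the plan would record
// (illustrative; indices refer to the tape before `Free` insertion):
//
//   value        slot  produced_at  last_used_at  is_output  is_external
//   x            0     0            3             false      true
//   w            1     1            3             false      true
//   b            2     2            4             false      true
//   matmul       3     3            4             false      false  (freed after [4])
//   add          4     4            5             false      false  (last used by the final instruction, so not freed)
//   relu (out)   5     5            5             true       false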
371
372// =============================================================================
373// CompiledGraph — A fully compiled execution plan
374// =============================================================================
375
376/// The compiled execution plan for a single IR graph.
377///
378/// Contains a flat instruction tape, memory plan, and metadata for
379/// efficient repeated execution.
380#[derive(Debug)]
381pub struct CompiledGraph {
382    /// Name of the source graph.
383    pub graph_name: String,
384    /// Flat instruction tape — executed sequentially.
385    pub instructions: Vec<Instruction>,
386    /// Memory plan — buffer slot assignments.
387    pub memory_plan: MemoryPlan,
388    /// Output slot mappings: name → buffer slot.
389    pub output_slots: HashMap<String, usize>,
390    /// Total number of buffer slots.
391    pub num_slots: usize,
392    /// Compilation statistics.
393    pub stats: CompileStats,
394}
395
396/// Statistics from the compilation process.
397#[derive(Debug, Clone)]
398pub struct CompileStats {
399    /// Number of instructions in the compiled plan.
400    pub num_instructions: usize,
401    /// Number of nodes in the source graph.
402    pub num_source_nodes: usize,
403    /// Number of buffer slots allocated.
404    pub num_slots: usize,
405    /// Number of buffer slots reused.
406    pub num_reused: usize,
407    /// Number of free instructions inserted.
408    pub num_frees: usize,
409    /// Number of fused instructions.
410    pub num_fused: usize,
411    /// Compilation time in microseconds.
412    pub compile_time_us: u64,
413}
414
415impl fmt::Display for CompileStats {
416    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
417        write!(
418            f,
419            "CompiledGraph: {} instructions ({} source nodes), {} slots ({} reused), {} frees, {} fused, compiled in {}μs",
420            self.num_instructions,
421            self.num_source_nodes,
422            self.num_slots,
423            self.num_reused,
424            self.num_frees,
425            self.num_fused,
426            self.compile_time_us,
427        )
428    }
429}
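
// For the sketch above, `Display` would print something like (compile time
// illustrative):
//
//   CompiledGraph: 7 instructions (6 source nodes), 6 slots (0 reused), 1 frees, 0 fused, compiled in 84μs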
430
431// =============================================================================
432// Compilation — transform IrGraph → CompiledGraph
433// =============================================================================
434
435/// Compile a single IR graph into an optimized execution plan.
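///
/// # Example
///
/// Illustrative sketch (assumes a parsed `IrProgram` containing a "Forward"
/// graph and a `RuntimeConfig` built by the caller):
///
/// ```ignore
/// let graph = program.get_graph("Forward").unwrap();
/// let cg = compile_graph(graph, &program, &config)?;
/// println!("{}", cg.stats);
/// ```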
436pub fn compile_graph(
437    graph: &IrGraph,
438    program: &IrProgram,
439    config: &RuntimeConfig,
440) -> Result<CompiledGraph> {
441    let start = Instant::now();
442
443    // 1. Get topological order (computed once)
444    let order = graph.topo_order();
445
446    // 2. Assign buffer slots
447    let mut node_to_slot: HashMap<usize, usize> = HashMap::new();
448    let mut next_slot = 0usize;
449    let mut produced_at: HashMap<usize, usize> = HashMap::new();
450    let mut last_used_at: HashMap<usize, usize> = HashMap::new();
451
452    // Pre-compute output node IDs for quick lookup
453    let output_node_ids: std::collections::HashSet<usize> =
454        graph.outputs.iter().map(|o| o.node_id.0).collect();
455
456    // Pre-compute input/param node IDs
457    let input_node_ids: std::collections::HashSet<usize> =
458        graph.inputs.iter().map(|id| id.0).collect();
459    let param_node_ids: std::collections::HashSet<usize> =
460        graph.params.iter().map(|p| p.node_id.0).collect();
461
462    // Assign a slot to each node in topo order
463    for &node_id in &order {
464        let slot = next_slot;
465        node_to_slot.insert(node_id.0, slot);
466        next_slot += 1;
467    }
468
469    // 3. Build instruction tape
470    let mut instructions: Vec<Instruction> = Vec::with_capacity(order.len() + graph.inputs.len());
471    let mut num_fused = 0;
472
473    for (instr_idx, &node_id) in order.iter().enumerate() {
474        let node = graph.node(node_id);
475        let dst = node_to_slot[&node_id.0];
476
477        // Track production point
478        produced_at.insert(node_id.0, instr_idx);
479
480        // Track consumption points for inputs
481        for &input_id in &node.inputs {
482            last_used_at.insert(input_id.0, instr_idx);
483        }
484
485        // Check if this node is an input
486        if input_node_ids.contains(&node_id.0) {
487            instructions.push(Instruction::LoadInput {
488                name: node.name.clone(),
489                dst,
490            });
491            continue;
492        }
493
494        // Check if this node is a parameter
495        if param_node_ids.contains(&node_id.0) {
496            if let Some(param) = graph.params.iter().find(|p| p.node_id == node_id) {
497                instructions.push(Instruction::LoadParam {
498                    graph_name: graph.name.clone(),
499                    param_name: param.name.clone(),
500                    dst,
501                });
502            }
503            continue;
504        }
505
506        // Compile the operation
507        let instr = compile_node(graph, node, &node_to_slot, config, program)?;
508
509        // Track fused instructions
510        match &instr {
511            Instruction::FusedMatMulAdd { .. }
512            | Instruction::FusedAddRelu { .. }
513            | Instruction::FusedSubRelu { .. }
514            | Instruction::FusedMatMulRelu { .. } => {
515                num_fused += 1;
516            }
517            _ => {}
518        }
519
520        instructions.push(instr);
521    }
522
523    // 4. Compute lifetimes and insert free instructions
524    let mut lifetimes = Vec::new();
525    let mut free_points = Vec::new();
526
527    for &node_id in &order {
528        let slot = node_to_slot[&node_id.0];
529        let is_output = output_node_ids.contains(&node_id.0);
530        let is_external =
531            input_node_ids.contains(&node_id.0) || param_node_ids.contains(&node_id.0);
532        let prod = produced_at.get(&node_id.0).copied().unwrap_or(0);
533        let last = last_used_at.get(&node_id.0).copied().unwrap_or(prod);
534
535        lifetimes.push(ValueLifetime {
536            produced_at: prod,
537            last_used_at: last,
538            node_id,
539            is_output,
540            is_external,
541        });
542
543        // Insert free point if this value is not an output and not external
544        if !is_output && !is_external && last < instructions.len().saturating_sub(1) {
545            free_points.push((slot, last));
546        }
547    }
548
549    // Sort frees by position, latest first, so inserting from the back never shifts earlier insertion points
550    free_points.sort_by(|a, b| b.1.cmp(&a.1));
551
552    // Insert free instructions after the last use
553    let num_frees = free_points.len();
554    for (slot, after_idx) in &free_points {
555        let insert_pos = (*after_idx + 1).min(instructions.len());
556        instructions.insert(insert_pos, Instruction::Free { slot: *slot });
557    }
558
559    // 5. Build output slot mapping
560    let mut output_slots = HashMap::new();
561    for output in &graph.outputs {
562        if let Some(&slot) = node_to_slot.get(&output.node_id.0) {
563            output_slots.insert(output.name.clone(), slot);
564        }
565    }
566
567    let num_slots = next_slot;
568    let compile_time = start.elapsed();
569
570    let memory_plan = MemoryPlan {
571        num_slots,
572        node_to_slot,
573        lifetimes,
574        free_points: Vec::new(), // Already applied
575        reuse_count: 0,          // No physical reuse in this version (logical slots)
576    };
577
578    let stats = CompileStats {
579        num_instructions: instructions.len(),
580        num_source_nodes: graph.nodes.len(),
581        num_slots,
582        num_reused: 0,
583        num_frees,
584        num_fused,
585        compile_time_us: compile_time.as_micros() as u64,
586    };
587
588    Ok(CompiledGraph {
589        graph_name: graph.name.clone(),
590        instructions,
591        memory_plan,
592        output_slots,
593        num_slots,
594        stats,
595    })
596}
597
598/// Compile a single IR node into an instruction.
599fn compile_node(
600    _graph: &IrGraph,
601    node: &IrNode,
602    node_to_slot: &HashMap<usize, usize>,
603    config: &RuntimeConfig,
604    program: &IrProgram,
605) -> Result<Instruction> {
606    let dst = node_to_slot[&node.id.0];
607
608    // Helper: get slot for an input
609    let slot = |idx: usize| -> Result<usize> {
610        let input_id = node.inputs.get(idx).ok_or_else(|| {
611            shrew_core::Error::msg(format!(
612                "Node '{}' expected input at index {}, but has {} inputs",
613                node.name,
614                idx,
615                node.inputs.len()
616            ))
617        })?;
618        node_to_slot.get(&input_id.0).copied().ok_or_else(|| {
619            shrew_core::Error::msg(format!(
620                "Node '{}' input {} (NodeId {}) not found in slot map",
621                node.name, idx, input_id.0
622            ))
623        })
624    };
625
626    match &node.op {
627        OpKind::Identity => Ok(Instruction::Copy { src: slot(0)?, dst }),
628
629        // ── Unary ──
630        OpKind::Neg => Ok(Instruction::Unary {
631            op: UnaryInstr::Neg,
632            src: slot(0)?,
633            dst,
634        }),
635        OpKind::Relu => Ok(Instruction::Unary {
636            op: UnaryInstr::Relu,
637            src: slot(0)?,
638            dst,
639        }),
640        OpKind::Gelu => Ok(Instruction::Unary {
641            op: UnaryInstr::Gelu,
642            src: slot(0)?,
643            dst,
644        }),
645        OpKind::Silu => Ok(Instruction::Unary {
646            op: UnaryInstr::Silu,
647            src: slot(0)?,
648            dst,
649        }),
650        OpKind::Sigmoid => Ok(Instruction::Unary {
651            op: UnaryInstr::Sigmoid,
652            src: slot(0)?,
653            dst,
654        }),
655        OpKind::Tanh => Ok(Instruction::Unary {
656            op: UnaryInstr::Tanh,
657            src: slot(0)?,
658            dst,
659        }),
660        OpKind::Exp => Ok(Instruction::Unary {
661            op: UnaryInstr::Exp,
662            src: slot(0)?,
663            dst,
664        }),
665        OpKind::Log => Ok(Instruction::Unary {
666            op: UnaryInstr::Log,
667            src: slot(0)?,
668            dst,
669        }),
670        OpKind::Sqrt => Ok(Instruction::Unary {
671            op: UnaryInstr::Sqrt,
672            src: slot(0)?,
673            dst,
674        }),
675
676        // ── Binary ──
677        OpKind::Add => Ok(Instruction::Binary {
678            op: BinaryInstr::Add,
679            lhs: slot(0)?,
680            rhs: slot(1)?,
681            dst,
682        }),
683        OpKind::Sub => Ok(Instruction::Binary {
684            op: BinaryInstr::Sub,
685            lhs: slot(0)?,
686            rhs: slot(1)?,
687            dst,
688        }),
689        OpKind::Mul => Ok(Instruction::Binary {
690            op: BinaryInstr::Mul,
691            lhs: slot(0)?,
692            rhs: slot(1)?,
693            dst,
694        }),
695        OpKind::Div => Ok(Instruction::Binary {
696            op: BinaryInstr::Div,
697            lhs: slot(0)?,
698            rhs: slot(1)?,
699            dst,
700        }),
701        OpKind::MatMul => Ok(Instruction::Binary {
702            op: BinaryInstr::MatMul,
703            lhs: slot(0)?,
704            rhs: slot(1)?,
705            dst,
706        }),
707        OpKind::Pow => Ok(Instruction::Binary {
708            op: BinaryInstr::Pow,
709            lhs: slot(0)?,
710            rhs: slot(1)?,
711            dst,
712        }),
713        OpKind::Mod => Ok(Instruction::Binary {
714            op: BinaryInstr::Mod,
715            lhs: slot(0)?,
716            rhs: slot(1)?,
717            dst,
718        }),
719
720        // ── Transpose ──
721        OpKind::Transpose => Ok(Instruction::Transpose { src: slot(0)?, dst }),
722
723        // ── Reductions ──
724        OpKind::Sum { dims, keepdim } => Ok(Instruction::Reduce {
725            op: ReduceInstr::Sum,
726            src: slot(0)?,
727            dst,
728            dims: dims.clone(),
729            keepdim: *keepdim,
730        }),
731        OpKind::Mean { dims, keepdim } => Ok(Instruction::Reduce {
732            op: ReduceInstr::Mean,
733            src: slot(0)?,
734            dst,
735            dims: dims.clone(),
736            keepdim: *keepdim,
737        }),
738        OpKind::Max { dim, keepdim } => Ok(Instruction::Reduce {
739            op: ReduceInstr::Max,
740            src: slot(0)?,
741            dst,
742            dims: vec![*dim],
743            keepdim: *keepdim,
744        }),
745        OpKind::Min { dim, keepdim } => Ok(Instruction::Reduce {
746            op: ReduceInstr::Min,
747            src: slot(0)?,
748            dst,
749            dims: vec![*dim],
750            keepdim: *keepdim,
751        }),
752        OpKind::Variance { dims, keepdim } => Ok(Instruction::Reduce {
753            op: ReduceInstr::Variance,
754            src: slot(0)?,
755            dst,
756            dims: dims.clone(),
757            keepdim: *keepdim,
758        }),
759
760        // ── Softmax ──
761        OpKind::Softmax { dim } => {
762            let d = *dim;
763            Ok(Instruction::Softmax {
764                src: slot(0)?,
765                dst,
766                dim: d as usize,
767            })
768        }
769
770        // ── Shape ops ──
771        OpKind::Reshape { target_shape } | OpKind::View { target_shape } => {
772            let shape = resolve_shape_vec(target_shape, config, program)?;
773            Ok(Instruction::Reshape {
774                src: slot(0)?,
775                dst,
776                shape,
777            })
778        }
779        OpKind::Permute { dims } => Ok(Instruction::Permute {
780            src: slot(0)?,
781            dst,
782            dims: dims.clone(),
783        }),
784        OpKind::Expand { target_shape } => {
785            let shape = resolve_shape_vec(target_shape, config, program)?;
786            Ok(Instruction::Expand {
787                src: slot(0)?,
788                dst,
789                shape,
790            })
791        }
792        OpKind::Concat { dim } => {
793            let srcs: Vec<usize> = (0..node.inputs.len())
794                .map(&slot)
795                .collect::<Result<Vec<_>>>()?;
796            Ok(Instruction::Concat {
797                srcs,
798                dst,
799                dim: *dim as usize,
800            })
801        }
802        OpKind::Split { dim, chunks } => Ok(Instruction::Split {
803            src: slot(0)?,
804            dst,
805            dim: resolve_neg_dim(*dim, 4), // negative dim resolved at compile time against an assumed rank of 4
806            chunks: *chunks as usize,
807        }),
808
809        // ── NN layers ──
810        OpKind::Embedding => Ok(Instruction::Embedding {
811            indices: slot(0)?,
812            table: slot(1)?,
813            dst,
814        }),
815        OpKind::Linear { bias } => {
816            let bias_slot = if *bias && node.inputs.len() >= 3 {
817                Some(slot(2)?)
818            } else {
819                None
820            };
821            Ok(Instruction::Linear {
822                input: slot(0)?,
823                weight: slot(1)?,
824                bias: bias_slot,
825                dst,
826            })
827        }
828        OpKind::LayerNorm { eps } => Ok(Instruction::LayerNorm {
829            input: slot(0)?,
830            weight: slot(1)?,
831            bias: slot(2)?,
832            dst,
833            eps: *eps,
834        }),
835        OpKind::BatchNorm { eps } => {
836            let weight = if node.inputs.len() >= 2 {
837                Some(slot(1)?)
838            } else {
839                None
840            };
841            let bias = if node.inputs.len() >= 3 {
842                Some(slot(2)?)
843            } else {
844                None
845            };
846            Ok(Instruction::BatchNorm {
847                input: slot(0)?,
848                weight,
849                bias,
850                dst,
851                eps: *eps,
852            })
853        }
854        OpKind::MultiHeadAttention { n_heads } => Ok(Instruction::MultiHeadAttention {
855            input: slot(0)?,
856            dst,
857            n_heads: *n_heads as usize,
858        }),
859        OpKind::TransformerBlock { n_heads } => Ok(Instruction::TransformerBlock {
860            input: slot(0)?,
861            dst,
862            n_heads: *n_heads as usize,
863        }),
864        OpKind::Dropout { p } => Ok(Instruction::Dropout {
865            src: slot(0)?,
866            dst,
867            p: *p,
868        }),
869
870        // ── Loss ──
871        OpKind::CrossEntropy => Ok(Instruction::CrossEntropy {
872            predictions: slot(0)?,
873            targets: slot(1)?,
874            dst,
875        }),
876        OpKind::MseLoss => Ok(Instruction::MseLoss {
877            predictions: slot(0)?,
878            targets: slot(1)?,
879            dst,
880        }),
881
882        // ── Constants ──
883        OpKind::Constant(val) => Ok(Instruction::Constant {
884            value: val.clone(),
885            output_type: node.output_type.clone(),
886            dst,
887        }),
888
889        // ── Repeat ──
890        OpKind::Repeat { count, body_op } => Ok(Instruction::Repeat {
891            count: *count,
892            body_op: body_op.clone(),
893            src: slot(0)?,
894            dst,
895        }),
896
897        // ── Call ──
898        OpKind::Call { graph_name } => {
899            let inputs: Vec<usize> = (0..node.inputs.len())
900                .map(&slot)
901                .collect::<Result<Vec<_>>>()?;
902            Ok(Instruction::Call {
903                graph_name: graph_name.clone(),
904                inputs,
905                dst,
906            })
907        }
908
909        // ── Range ──
910        OpKind::Range => {
911            let inputs: Vec<usize> = (0..node.inputs.len())
912                .map(&slot)
913                .collect::<Result<Vec<_>>>()?;
914            Ok(Instruction::Range {
915                inputs,
916                output_type: node.output_type.clone(),
917                dst,
918            })
919        }
920
921        // ── Comparison ──
922        OpKind::Equal => Ok(Instruction::Compare {
923            op: CompareInstr::Equal,
924            lhs: slot(0)?,
925            rhs: slot(1)?,
926            dst,
927        }),
928        OpKind::NotEqual => Ok(Instruction::Compare {
929            op: CompareInstr::NotEqual,
930            lhs: slot(0)?,
931            rhs: slot(1)?,
932            dst,
933        }),
934        OpKind::Less => Ok(Instruction::Compare {
935            op: CompareInstr::Less,
936            lhs: slot(0)?,
937            rhs: slot(1)?,
938            dst,
939        }),
940        OpKind::Greater => Ok(Instruction::Compare {
941            op: CompareInstr::Greater,
942            lhs: slot(0)?,
943            rhs: slot(1)?,
944            dst,
945        }),
946        OpKind::LessEqual => Ok(Instruction::Compare {
947            op: CompareInstr::LessEqual,
948            lhs: slot(0)?,
949            rhs: slot(1)?,
950            dst,
951        }),
952        OpKind::GreaterEqual => Ok(Instruction::Compare {
953            op: CompareInstr::GreaterEqual,
954            lhs: slot(0)?,
955            rhs: slot(1)?,
956            dst,
957        }),
958
959        // ── Logical ──
960        OpKind::And => Ok(Instruction::LogicalBinOp {
961            op: LogicalBinInstr::And,
962            lhs: slot(0)?,
963            rhs: slot(1)?,
964            dst,
965        }),
966        OpKind::Or => Ok(Instruction::LogicalBinOp {
967            op: LogicalBinInstr::Or,
968            lhs: slot(0)?,
969            rhs: slot(1)?,
970            dst,
971        }),
972        OpKind::Not => Ok(Instruction::LogicalNot { src: slot(0)?, dst }),
973
974        // ── Fused ops (from IR optimizer) ──
975        OpKind::Custom { name, .. } => match name.as_str() {
976            "fused_matmul_add" => Ok(Instruction::FusedMatMulAdd {
977                a: slot(0)?,
978                b: slot(1)?,
979                c: slot(2)?,
980                dst,
981            }),
982            "fused_add_relu" => Ok(Instruction::FusedAddRelu {
983                lhs: slot(0)?,
984                rhs: slot(1)?,
985                dst,
986            }),
987            "fused_sub_relu" => Ok(Instruction::FusedSubRelu {
988                lhs: slot(0)?,
989                rhs: slot(1)?,
990                dst,
991            }),
992            "fused_matmul_relu" => Ok(Instruction::FusedMatMulRelu {
993                lhs: slot(0)?,
994                rhs: slot(1)?,
995                dst,
996            }),
997            other => Err(shrew_core::Error::msg(format!(
998                "Unknown custom op '{}' during JIT compilation",
999                other
1000            ))),
1001        },
1002    }
1003}
1004
1005// =============================================================================
1006// JitExecutor — Runs compiled graphs
1007// =============================================================================
1008
1009/// A JIT-compiled executor that runs pre-compiled graph execution plans.
1010///
1011/// Unlike the interpreter (`Executor`), the JIT executor:
1012/// - Pre-compiles each graph into a flat instruction tape
1013/// - Pre-resolves all buffer slot assignments
1014/// - Inserts dead-value-free instructions for memory efficiency
1015/// - Dispatches operations without HashMap lookups or string matching
1016///
1017/// # Usage
1018/// ```ignore
1019/// let jit = JitExecutor::<CpuBackend>::compile(program, device, config)?;
1020/// let result = jit.run("Forward", &inputs)?;
1021/// let output = result.get("output").unwrap();
1022/// ```
1023pub struct JitExecutor<B: Backend> {
1024    /// Compiled graphs, keyed by graph name.
1025    compiled: HashMap<String, CompiledGraph>,
1026    /// The source IR program.
1027    program: IrProgram,
1028    /// Runtime configuration.
1029    config: RuntimeConfig,
1030    /// Device.
1031    device: B::Device,
1032    /// Initialized parameters.
1033    params: HashMap<(String, String), Tensor<B>>,
1034}
1035
1036/// Result of a JIT execution.
1037#[derive(Debug)]
1038pub struct JitResult<B: Backend> {
1039    /// Output tensors, keyed by output name.
1040    pub outputs: HashMap<String, Tensor<B>>,
1041}
1042
1043impl<B: Backend> JitResult<B> {
1044    /// Get the first output tensor.
1045    pub fn output(&self) -> Option<&Tensor<B>> {
1046        self.outputs.values().next()
1047    }
1048
1049    /// Get an output by name.
1050    pub fn get(&self, name: &str) -> Option<&Tensor<B>> {
1051        self.outputs.get(name)
1052    }
1053}
1054
1055impl<B: Backend> JitExecutor<B> {
1056    /// Compile all graphs in a program and create a JIT executor.
1057    pub fn compile(program: IrProgram, device: B::Device, config: RuntimeConfig) -> Result<Self> {
1058        let mut compiled = HashMap::new();
1059
1060        // Compile each graph
1061        for graph in &program.graphs {
1062            let cg = compile_graph(graph, &program, &config)?;
1063            compiled.insert(graph.name.clone(), cg);
1064        }
1065
1066        // Initialize parameters (reuse logic from Executor)
1067        let mut params = HashMap::new();
1068        for graph in &program.graphs {
1069            for param in &graph.params {
1070                let tensor = init_param::<B>(
1071                    &param.ty,
1072                    &param.init,
1073                    param.frozen,
1074                    &config,
1075                    &program,
1076                    &device,
1077                )?;
1078                params.insert((graph.name.clone(), param.name.clone()), tensor);
1079            }
1080        }
1081
1082        Ok(Self {
1083            compiled,
1084            program,
1085            config,
1086            device,
1087            params,
1088        })
1089    }
1090
1091    /// Get compilation statistics for a graph.
1092    pub fn stats(&self, graph_name: &str) -> Option<&CompileStats> {
1093        self.compiled.get(graph_name).map(|cg| &cg.stats)
1094    }
1095
1096    /// Get all compilation statistics.
1097    pub fn all_stats(&self) -> Vec<(&str, &CompileStats)> {
1098        self.compiled
1099            .iter()
1100            .map(|(name, cg)| (name.as_str(), &cg.stats))
1101            .collect()
1102    }
1103
1104    /// Run a compiled graph with the given inputs.
1105    pub fn run(
1106        &self,
1107        graph_name: &str,
1108        inputs: &HashMap<String, Tensor<B>>,
1109    ) -> Result<JitResult<B>> {
1110        let cg = self.compiled.get(graph_name).ok_or_else(|| {
1111            shrew_core::Error::msg(format!(
1112                "Graph '{}' not compiled. Available: {:?}",
1113                graph_name,
1114                self.compiled.keys().collect::<Vec<_>>()
1115            ))
1116        })?;
1117
1118        // Allocate buffer table (slots)
1119        let mut slots: Vec<Option<Tensor<B>>> = vec![None; cg.num_slots];
1120
1121        // Execute instruction tape
1122        for instr in &cg.instructions {
1123            match instr {
1124                Instruction::LoadInput { name, dst } => {
1125                    if let Some(tensor) = inputs.get(name) {
1126                        slots[*dst] = Some(tensor.clone());
1127                    }
1128                }
1129
1130                Instruction::LoadParam {
1131                    graph_name,
1132                    param_name,
1133                    dst,
1134                } => {
1135                    let key = (graph_name.clone(), param_name.clone());
1136                    if let Some(tensor) = self.params.get(&key) {
1137                        slots[*dst] = Some(tensor.clone());
1138                    }
1139                }
1140
1141                Instruction::Unary { op, src, dst } => {
1142                    let t = get_slot(&slots, *src)?;
1143                    let result = match op {
1144                        UnaryInstr::Neg => t.neg(),
1145                        UnaryInstr::Relu => t.relu(),
1146                        UnaryInstr::Gelu => t.gelu(),
1147                        UnaryInstr::Silu => t.silu(),
1148                        UnaryInstr::Sigmoid => t.sigmoid(),
1149                        UnaryInstr::Tanh => t.tanh(),
1150                        UnaryInstr::Exp => t.exp(),
1151                        UnaryInstr::Log => t.log(),
1152                        UnaryInstr::Sqrt => t.sqrt(),
1153                    }?;
1154                    slots[*dst] = Some(result);
1155                }
1156
1157                Instruction::Binary { op, lhs, rhs, dst } => {
1158                    let a = get_slot(&slots, *lhs)?;
1159                    let b = get_slot(&slots, *rhs)?;
1160                    let result = match op {
1161                        BinaryInstr::Add => a.add(b),
1162                        BinaryInstr::Sub => a.sub(b),
1163                        BinaryInstr::Mul => a.mul(b),
1164                        BinaryInstr::Div => a.div(b),
1165                        BinaryInstr::MatMul => a.matmul(b),
1166                        BinaryInstr::Pow => a.log()?.mul(b)?.exp(), // x^y = exp(y*ln(x)), valid only for x > 0
1167                        BinaryInstr::Mod => {
1168                            let quotient = a.div(b)?.floor()?;
1169                            let product = quotient.mul(b)?;
1170                            a.sub(&product)
1171                        }
1172                    }?;
1173                    slots[*dst] = Some(result);
1174                }
1175
1176                Instruction::Reduce {
1177                    op,
1178                    src,
1179                    dst,
1180                    dims,
1181                    keepdim,
1182                } => {
1183                    let t = get_slot(&slots, *src)?;
1184                    let result = match op {
1185                        ReduceInstr::Sum => {
1186                            if dims.is_empty() || (dims.len() == 1 && dims[0] == -1) {
1187                                t.sum_all()
1188                            } else {
1189                                let d = resolve_neg_dim(dims[0], t.rank());
1190                                t.sum(d as usize, *keepdim)
1191                            }
1192                        }
1193                        ReduceInstr::Mean => {
1194                            if dims.is_empty() || (dims.len() == 1 && dims[0] == -1) {
1195                                t.mean_all()
1196                            } else {
1197                                let d = resolve_neg_dim(dims[0], t.rank());
1198                                t.mean(d as usize, *keepdim)
1199                            }
1200                        }
1201                        ReduceInstr::Max => {
1202                            let d = resolve_neg_dim(dims[0], t.rank());
1203                            t.max(d as usize, *keepdim)
1204                        }
1205                        ReduceInstr::Min => {
1206                            let d = resolve_neg_dim(dims[0], t.rank());
1207                            t.min(d as usize, *keepdim)
1208                        }
1209                        ReduceInstr::Variance => {
1210                            if dims.is_empty() {
1211                                t.var(0, *keepdim)
1212                            } else {
1213                                let d = resolve_neg_dim(dims[0], t.rank());
1214                                t.var(d as usize, *keepdim)
1215                            }
1216                        }
1217                    }?;
1218                    slots[*dst] = Some(result);
1219                }
1220
1221                Instruction::Reshape { src, dst, shape } => {
1222                    let t = get_slot(&slots, *src)?;
1223                    slots[*dst] = Some(t.reshape(shape.clone())?);
1224                }
1225
1226                Instruction::Transpose { src, dst } => {
1227                    let t = get_slot(&slots, *src)?;
1228                    let rank = t.rank();
1229                    slots[*dst] = Some(t.transpose(rank - 2, rank - 1)?);
1230                }
1231
1232                Instruction::Permute { src, dst, dims } => {
1233                    let t = get_slot(&slots, *src)?;
1234                    let mut result = t.clone();
1235                    let mut current: Vec<usize> = (0..t.rank()).collect();
1236                    for i in 0..dims.len() {
1237                        let target = dims[i] as usize;
1238                        if current[i] != target {
1239                            if let Some(j) = current.iter().position(|&x| x == target) {
1240                                result = result.transpose(i, j)?;
1241                                current.swap(i, j);
1242                            }
1243                        }
1244                    }
1245                    slots[*dst] = Some(result);
1246                }
1247
1248                Instruction::Expand { src, dst, shape } => {
1249                    let t = get_slot(&slots, *src)?;
1250                    slots[*dst] = Some(t.expand(shape.clone())?);
1251                }
1252
1253                Instruction::Concat { srcs, dst, dim } => {
1254                    let tensors: Vec<Tensor<B>> = srcs
1255                        .iter()
1256                        .map(|s| get_slot(&slots, *s).cloned())
1257                        .collect::<Result<Vec<_>>>()?;
1258                    slots[*dst] = Some(Tensor::<B>::cat(&tensors, *dim)?);
1259                }
1260
1261                Instruction::Split {
1262                    src,
1263                    dst,
1264                    dim,
1265                    chunks,
1266                } => {
1267                    let t = get_slot(&slots, *src)?;
1268                    let result = t.chunk(*chunks, *dim)?;
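                    // NOTE: only the first chunk is kept; the tape carries a single
                    // dst per instruction, so multi-output Split is not representable yet.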
1269                    if let Some(first) = result.into_iter().next() {
1270                        slots[*dst] = Some(first);
1271                    }
1272                }
1273
1274                Instruction::Softmax { src, dst, dim } => {
1275                    let t = get_slot(&slots, *src)?;
1276                    slots[*dst] = Some(t.softmax(*dim)?);
1277                }
1278
1279                Instruction::Embedding {
1280                    indices,
1281                    table,
1282                    dst,
1283                } => {
1284                    let idx = get_slot(&slots, *indices)?;
1285                    let tbl = get_slot(&slots, *table)?;
1286                    let emb = shrew_nn::Embedding::<B>::from_tensor(tbl.clone())?;
1287                    slots[*dst] = Some(emb.forward(idx)?);
1288                }
1289
1290                Instruction::Linear {
1291                    input,
1292                    weight,
1293                    bias,
1294                    dst,
1295                } => {
1296                    let inp = get_slot(&slots, *input)?;
1297                    let w = get_slot(&slots, *weight)?;
1298                    let b = bias.map(|s| get_slot(&slots, s).cloned()).transpose()?;
1299                    let lin = shrew_nn::Linear::<B>::from_tensors(w.clone(), b)?;
1300                    slots[*dst] = Some(lin.forward(inp)?);
1301                }
1302
1303                Instruction::LayerNorm {
1304                    input,
1305                    weight,
1306                    bias,
1307                    dst,
1308                    eps,
1309                } => {
1310                    let inp = get_slot(&slots, *input)?;
1311                    let w = get_slot(&slots, *weight)?;
1312                    let b = get_slot(&slots, *bias)?;
1313                    let ln = shrew_nn::LayerNorm::<B>::from_tensors(w.clone(), b.clone(), *eps)?;
1314                    slots[*dst] = Some(ln.forward(inp)?);
1315                }
1316
1317                Instruction::BatchNorm {
1318                    input,
1319                    weight,
1320                    bias,
1321                    dst,
1322                    eps,
1323                } => {
1324                    let inp = get_slot(&slots, *input)?;
1325                    if let (Some(ws), Some(bs)) = (weight, bias) {
1326                        let w = get_slot(&slots, *ws)?;
1327                        let b = get_slot(&slots, *bs)?;
1328                        let bn =
1329                            shrew_nn::BatchNorm2d::<B>::from_tensors(w.clone(), b.clone(), *eps)?;
1330                        slots[*dst] = Some(bn.forward(inp)?);
1331                    } else {
1332                        let dims = inp.dims();
1333                        let c = if dims.len() == 4 { dims[1] } else { dims[0] };
1334                        let bn = shrew_nn::BatchNorm2d::<B>::new(
1335                            c,
1336                            *eps,
1337                            0.1,
1338                            inp.dtype(),
1339                            &self.device,
1340                        )?;
1341                        slots[*dst] = Some(bn.forward(inp)?);
1342                    }
1343                }
1344
1345                Instruction::MultiHeadAttention {
1346                    input,
1347                    dst,
1348                    n_heads,
1349                } => {
1350                    let inp = get_slot(&slots, *input)?;
1351                    let d_model = *inp
1352                        .dims()
1353                        .last()
1354                        .ok_or_else(|| shrew_core::Error::msg("MHA input has no dimensions"))?;
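                    // NOTE: constructs a fresh MultiHeadAttention (newly initialized
                    // weights) on every call; it does not read from the param table.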
1355                    let mha = shrew_nn::MultiHeadAttention::<B>::new(
1356                        d_model,
1357                        *n_heads,
1358                        inp.dtype(),
1359                        inp.device(),
1360                    )?;
1361                    slots[*dst] = Some(mha.forward(inp)?);
1362                }
1363
1364                Instruction::TransformerBlock {
1365                    input,
1366                    dst,
1367                    n_heads,
1368                } => {
1369                    let inp = get_slot(&slots, *input)?;
1370                    let dims = inp.dims();
1371                    let d_model = dims[dims.len() - 1];
1372                    let d_ff = d_model * 4;
1373                    let block = shrew_nn::TransformerBlock::<B>::new(
1374                        d_model,
1375                        *n_heads,
1376                        d_ff,
1377                        true,
1378                        inp.dtype(),
1379                        inp.device(),
1380                    )?;
1381                    slots[*dst] = Some(block.forward(inp)?);
1382                }
1383
1384                Instruction::Dropout { src, dst, p } => {
1385                    let t = get_slot(&slots, *src)?;
1386                    let dropout = shrew_nn::Dropout::new(*p);
1387                    if self.config.training {
1388                        slots[*dst] = Some(dropout.forward_t(t)?);
1389                    } else {
1390                        slots[*dst] = Some(t.clone());
1391                    }
1392                }
1393
1394                Instruction::CrossEntropy {
1395                    predictions,
1396                    targets,
1397                    dst,
1398                } => {
1399                    let p = get_slot(&slots, *predictions)?;
1400                    let t = get_slot(&slots, *targets)?;
1401                    slots[*dst] = Some(shrew_nn::cross_entropy_loss(p, t)?);
1402                }
1403
1404                Instruction::MseLoss {
1405                    predictions,
1406                    targets,
1407                    dst,
1408                } => {
1409                    let p = get_slot(&slots, *predictions)?;
1410                    let t = get_slot(&slots, *targets)?;
1411                    slots[*dst] = Some(shrew_nn::mse_loss(p, t)?);
1412                }
1413
1414                Instruction::Constant {
1415                    value,
1416                    output_type,
1417                    dst,
1418                } => {
1419                    let tensor = materialize_constant::<B>(
1420                        value,
1421                        output_type,
1422                        self.config.default_dtype,
1423                        &self.device,
1424                    )?;
1425                    slots[*dst] = Some(tensor);
1426                }
1427
1428                Instruction::Repeat {
1429                    count,
1430                    body_op,
1431                    src,
1432                    dst,
1433                } => {
1434                    let t = get_slot(&slots, *src)?;
1435                    let mut current = t.clone();
1436                    for _ in 0..(*count).max(0) {
1437                        current = execute_body_op::<B>(body_op, &current, &self.device)?;
1438                    }
1439                    slots[*dst] = Some(current);
1440                }
1441
1442                Instruction::Call {
1443                    graph_name,
1444                    inputs: input_slots,
1445                    dst,
1446                } => {
1447                    let _sub_cg = self.compiled.get(graph_name).ok_or_else(|| {
1448                        shrew_core::Error::msg(format!(
1449                            "Called graph '{}' not compiled",
1450                            graph_name
1451                        ))
1452                    })?;
1453                    let sub_graph = self.program.get_graph(graph_name).ok_or_else(|| {
1454                        shrew_core::Error::msg(format!("Called graph '{}' not found", graph_name))
1455                    })?;
1456                    let mut sub_inputs = HashMap::new();
1457                    for (i, &input_id) in sub_graph.inputs.iter().enumerate() {
1458                        if let Some(&s) = input_slots.get(i) {
1459                            if let Some(tensor) = &slots[s] {
1460                                let input_name = sub_graph.node(input_id).name.clone();
1461                                sub_inputs.insert(input_name, tensor.clone());
1462                            }
1463                        }
1464                    }
1465                    let result = self.run(graph_name, &sub_inputs)?;
1466                    if let Some(out) = result.output() {
1467                        slots[*dst] = Some(out.clone());
1468                    }
1469                }
1470
1471                Instruction::Compare { op, lhs, rhs, dst } => {
1472                    let a = get_slot(&slots, *lhs)?;
1473                    let b = get_slot(&slots, *rhs)?;
1474                    let result = match op {
1475                        CompareInstr::Equal => a.eq(b),
1476                        CompareInstr::NotEqual => a.ne(b),
1477                        CompareInstr::Less => a.lt(b),
1478                        CompareInstr::Greater => a.gt(b),
1479                        CompareInstr::LessEqual => a.le(b),
1480                        CompareInstr::GreaterEqual => a.ge(b),
1481                    }?;
1482                    slots[*dst] = Some(result);
1483                }
1484
1485                Instruction::LogicalNot { src, dst } => {
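                    // NOTE: this arm and LogicalBinOp below evaluate element-wise on the
                    // host via an f64 round-trip, then re-upload the result as a U8 tensor.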
1486                    let t = get_slot(&slots, *src)?;
1487                    let data = t.to_f64_vec()?;
1488                    let result: Vec<f64> = data
1489                        .iter()
1490                        .map(|&v| if v == 0.0 { 1.0 } else { 0.0 })
1491                        .collect();
1492                    let n = result.len();
1493                    slots[*dst] = Some(Tensor::<B>::from_f64_slice(
1494                        &result,
1495                        n,
1496                        CoreDType::U8,
1497                        &self.device,
1498                    )?);
1499                }
1500
1501                Instruction::LogicalBinOp { op, lhs, rhs, dst } => {
1502                    let a = get_slot(&slots, *lhs)?;
1503                    let b = get_slot(&slots, *rhs)?;
1504                    let a_data = a.to_f64_vec()?;
1505                    let b_data = b.to_f64_vec()?;
1506                    let result: Vec<f64> = a_data
1507                        .iter()
1508                        .zip(b_data.iter())
1509                        .map(|(&x, &y)| match op {
1510                            LogicalBinInstr::And => {
1511                                if x != 0.0 && y != 0.0 {
1512                                    1.0
1513                                } else {
1514                                    0.0
1515                                }
1516                            }
1517                            LogicalBinInstr::Or => {
1518                                if x != 0.0 || y != 0.0 {
1519                                    1.0
1520                                } else {
1521                                    0.0
1522                                }
1523                            }
1524                        })
1525                        .collect();
1526                    let n = result.len();
1527                    slots[*dst] = Some(Tensor::<B>::from_f64_slice(
1528                        &result,
1529                        n,
1530                        CoreDType::U8,
1531                        &self.device,
1532                    )?);
1533                }
1534
1535                Instruction::FusedMatMulAdd { a, b, c, dst } => {
1536                    let at = get_slot(&slots, *a)?;
1537                    let bt = get_slot(&slots, *b)?;
1538                    let ct = get_slot(&slots, *c)?;
1539                    slots[*dst] = Some(at.matmul(bt)?.add(ct)?);
1540                }
1541
1542                Instruction::FusedAddRelu { lhs, rhs, dst } => {
1543                    let a = get_slot(&slots, *lhs)?;
1544                    let b = get_slot(&slots, *rhs)?;
1545                    slots[*dst] = Some(a.add(b)?.relu()?);
1546                }
1547
1548                Instruction::FusedSubRelu { lhs, rhs, dst } => {
1549                    let a = get_slot(&slots, *lhs)?;
1550                    let b = get_slot(&slots, *rhs)?;
1551                    slots[*dst] = Some(a.sub(b)?.relu()?);
1552                }
1553
1554                Instruction::FusedMatMulRelu { lhs, rhs, dst } => {
1555                    let a = get_slot(&slots, *lhs)?;
1556                    let b = get_slot(&slots, *rhs)?;
1557                    slots[*dst] = Some(a.matmul(b)?.relu()?);
1558                }
1559
1560                Instruction::Copy { src, dst } => {
1561                    let t = get_slot(&slots, *src)?;
1562                    slots[*dst] = Some(t.clone());
1563                }
1564
1565                Instruction::Range {
1566                    inputs: input_slots,
1567                    output_type,
1568                    dst,
1569                } => {
1570                    let (start, end) = if input_slots.len() >= 2 {
1571                        let s = get_slot(&slots, input_slots[0])?.to_scalar_f64()?;
1572                        let e = get_slot(&slots, input_slots[1])?.to_scalar_f64()?;
1573                        (s as i64, e as i64)
1574                    } else if input_slots.len() == 1 {
1575                        (
1576                            0i64,
1577                            get_slot(&slots, input_slots[0])?.to_scalar_f64()? as i64,
1578                        )
1579                    } else {
1580                        match output_type {
1581                            IrType::Tensor { shape, .. } => {
1582                                if let Some(Dim::Fixed(n)) = shape.first() {
1583                                    (0, *n)
1584                                } else {
1585                                    (0, 1)
1586                                }
1587                            }
1588                            _ => (0, 1),
1589                        }
1590                    };
1591                    let data: Vec<f64> = (start..end).map(|i| i as f64).collect();
1592                    let len = data.len();
1593                    slots[*dst] = Some(Tensor::<B>::from_f64_slice(
1594                        &data,
1595                        len,
1596                        CoreDType::I64,
1597                        &self.device,
1598                    )?);
1599                }
1600
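                // Free: drop the tensor held in this slot so its memory can be
                // reclaimed before the rest of the tape runs.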
1601                Instruction::Free { slot } => {
1602                    slots[*slot] = None;
1603                }
1604            }
1605        }
1606
1607        // Collect outputs
1608        let mut outputs = HashMap::new();
1609        for (name, &slot) in &cg.output_slots {
1610            if let Some(tensor) = &slots[slot] {
1611                outputs.insert(name.clone(), tensor.clone());
1612            }
1613        }
1614
1615        Ok(JitResult { outputs })
1616    }
1617
1618    /// Get the underlying program.
1619    pub fn program(&self) -> &IrProgram {
1620        &self.program
1621    }
1622
1623    /// Get the runtime config.
1624    pub fn config(&self) -> &RuntimeConfig {
1625        &self.config
1626    }
1627
1628    /// Get all parameters.
1629    pub fn params(&self) -> &HashMap<(String, String), Tensor<B>> {
1630        &self.params
1631    }
1632
1633    /// Get parameters for a specific graph.
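    ///
    /// Tensors are returned in the parameter map's iteration order; pair them
    /// with gradients or updated values using the same order (see `update_params`).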
1634    pub fn graph_params(&self, graph_name: &str) -> Vec<Tensor<B>> {
1635        self.params
1636            .iter()
1637            .filter(|((g, _), _)| g == graph_name)
1638            .map(|(_, t)| t.clone())
1639            .collect()
1640    }
1641
1642    /// Update parameters after an optimizer step.
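    ///
    /// `new_params` is matched to existing parameters by the same map iteration
    /// order that `graph_params` uses, so fetch, step, and write back without
    /// mutating the parameter map in between.
    ///
    /// # Example
    /// ```ignore
    /// // Hypothetical training step; `optimizer.step` is illustrative and not
    /// // an API defined in this module.
    /// let params = jit.graph_params("Forward");
    /// let updated = optimizer.step(&params, &grads)?;
    /// jit.update_params("Forward", &updated);
    /// ```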
1643    pub fn update_params(&mut self, graph_name: &str, new_params: &[Tensor<B>]) {
1644        let param_names: Vec<String> = self
1645            .params
1646            .keys()
1647            .filter(|(g, _)| g == graph_name)
1648            .map(|(_, n)| n.clone())
1649            .collect();
1650
1651        for (name, tensor) in param_names.into_iter().zip(new_params.iter()) {
1652            self.params
1653                .insert((graph_name.to_string(), name), tensor.clone());
1654        }
1655    }
1656
1657    /// Recompile a single graph (e.g., after the graph or its shapes have changed).
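    ///
    /// # Example
    /// ```ignore
    /// // Hypothetical: rebuild the plan for "Forward" after the program changes.
    /// jit.recompile("Forward")?;
    /// ```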
1658    pub fn recompile(&mut self, graph_name: &str) -> Result<()> {
1659        let graph = self
1660            .program
1661            .get_graph(graph_name)
1662            .ok_or_else(|| shrew_core::Error::msg(format!("Graph '{}' not found", graph_name)))?;
1663        let cg = compile_graph(graph, &self.program, &self.config)?;
1664        self.compiled.insert(graph_name.to_string(), cg);
1665        Ok(())
1666    }
1667
1668    /// Dump the compiled instruction tape for a graph (debugging).
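    ///
    /// # Example
    /// ```ignore
    /// // Minimal sketch: print the compiled tape for the "Forward" graph.
    /// if let Some(tape) = jit.dump("Forward") {
    ///     println!("{tape}");
    /// }
    /// ```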
1669    pub fn dump(&self, graph_name: &str) -> Option<String> {
1670        let cg = self.compiled.get(graph_name)?;
1671        let mut out = format!("=== JIT Compiled: {} ===\n", cg.graph_name);
1672        out.push_str(&format!("{}\n\n", cg.stats));
1673        for (i, instr) in cg.instructions.iter().enumerate() {
1674            out.push_str(&format!("  [{:>3}] {:?}\n", i, instr));
1675        }
1676        out.push_str(&format!("\nOutputs: {:?}\n", cg.output_slots));
1677        Some(out)
1678    }
1679}
1680
1681// =============================================================================
1682// Helper functions
1683// =============================================================================
1684
1685/// Get a tensor from a buffer slot.
1686fn get_slot<B: Backend>(slots: &[Option<Tensor<B>>], idx: usize) -> Result<&Tensor<B>> {
1687    slots.get(idx).and_then(|s| s.as_ref()).ok_or_else(|| {
1688        shrew_core::Error::msg(format!(
1689            "Buffer slot {} is empty (value was freed or never produced)",
1690            idx
1691        ))
1692    })
1693}
1694
1695/// Resolve a possibly-negative dimension index to an absolute axis.
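///
/// A quick sketch of the mapping:
/// ```ignore
/// assert_eq!(resolve_neg_dim(-1, 3), 2); // last axis of a rank-3 tensor
/// assert_eq!(resolve_neg_dim(1, 3), 1);  // non-negative indices pass through
/// ```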
1696fn resolve_neg_dim(dim: i64, rank: usize) -> usize {
1697    if dim < 0 {
1698        (rank as i64 + dim) as usize
1699    } else {
1700        dim as usize
1701    }
1702}
1703
1704/// Resolve a slice of `Dim`s to a concrete shape.
1705fn resolve_shape_vec(
1706    dims: &[Dim],
1707    config: &RuntimeConfig,
1708    program: &IrProgram,
1709) -> Result<Vec<usize>> {
1710    dims.iter()
1711        .map(|d| resolve_dim(d, config, program))
1712        .collect()
1713}
1714
1715/// Resolve a single Dim.
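///
/// Symbolic names are looked up in the runtime config first (caller overrides
/// win), then in the program's `config` block; anything still unresolved, and
/// any `Dim::Dynamic`, is an error.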
1716fn resolve_dim(dim: &Dim, config: &RuntimeConfig, program: &IrProgram) -> Result<usize> {
1717    match dim {
1718        Dim::Fixed(n) => Ok(*n as usize),
1719        Dim::Symbolic(name) => {
1720            if let Some(&val) = config.dims.get(name) {
1721                return Ok(val);
1722            }
1723            if let Some(shrew_ir::graph::ConfigValue::Int(n)) = program.config.get(name) {
1724                return Ok(*n as usize);
1725            }
1726            Err(shrew_core::Error::msg(format!(
1727                "Unresolved symbolic dimension: '{}'",
1728                name
1729            )))
1730        }
1731        Dim::Dynamic => Err(shrew_core::Error::msg(
1732            "Cannot resolve dynamic dimension at compile time",
1733        )),
1734    }
1735}
1736
1737/// Initialize a parameter tensor.
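///
/// The Xavier/Kaiming variants derive their scale from `compute_fans`:
/// Xavier uniform draws from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)),
/// Xavier normal uses std = sqrt(2 / (fan_in + fan_out)), Kaiming uniform draws
/// from U(-b, b) with b = sqrt(3 / fan_in), and Kaiming normal uses
/// std = sqrt(2 / fan_in); uniform draws are remapped from [0, 1) via `affine`.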
1738fn init_param<B: Backend>(
1739    ty: &IrType,
1740    init: &shrew_ir::graph::InitStrategy,
1741    frozen: bool,
1742    config: &RuntimeConfig,
1743    program: &IrProgram,
1744    device: &B::Device,
1745) -> Result<Tensor<B>> {
1746    let (shape, dtype) = resolve_type(ty, config, program)?;
1747
1748    let tensor = match init {
1749        shrew_ir::graph::InitStrategy::Zeros => Tensor::<B>::zeros(shape, dtype, device)?,
1750        shrew_ir::graph::InitStrategy::Ones => Tensor::<B>::ones(shape, dtype, device)?,
1751        shrew_ir::graph::InitStrategy::Normal { mean, std } => {
1752            Tensor::<B>::randn(shape, dtype, device)?.affine(*std, *mean)?
1753        }
1754        shrew_ir::graph::InitStrategy::Uniform { low, high } => {
1755            Tensor::<B>::rand(shape, dtype, device)?.affine(*high - *low, *low)?
1756        }
1757        shrew_ir::graph::InitStrategy::XavierUniform => {
1758            let (fan_in, fan_out) = compute_fans(&shape);
1759            let a = (6.0_f64 / (fan_in + fan_out) as f64).sqrt();
1760            Tensor::<B>::rand(shape, dtype, device)?.affine(2.0 * a, -a)?
1761        }
1762        shrew_ir::graph::InitStrategy::XavierNormal => {
1763            let (fan_in, fan_out) = compute_fans(&shape);
1764            let std = (2.0_f64 / (fan_in + fan_out) as f64).sqrt();
1765            Tensor::<B>::randn(shape, dtype, device)?.affine(std, 0.0)?
1766        }
1767        shrew_ir::graph::InitStrategy::KaimingUniform => {
1768            let (fan_in, _) = compute_fans(&shape);
1769            let bound = (3.0_f64 / fan_in as f64).sqrt();
1770            Tensor::<B>::rand(shape, dtype, device)?.affine(2.0 * bound, -bound)?
1771        }
1772        shrew_ir::graph::InitStrategy::KaimingNormal => {
1773            let (fan_in, _) = compute_fans(&shape);
1774            let std = (2.0_f64 / fan_in as f64).sqrt();
1775            Tensor::<B>::randn(shape, dtype, device)?.affine(std, 0.0)?
1776        }
1777        shrew_ir::graph::InitStrategy::Custom(_) => Tensor::<B>::randn(shape, dtype, device)?,
1778    };
1779
1780    if frozen {
1781        Ok(tensor)
1782    } else {
1783        Ok(tensor.set_variable())
1784    }
1785}
1786
1787/// Resolve an `IrType` to a concrete `(Shape, CoreDType)`; scalars and ints map to a one-element shape.
1788fn resolve_type(
1789    ty: &IrType,
1790    config: &RuntimeConfig,
1791    program: &IrProgram,
1792) -> Result<(shrew_core::Shape, CoreDType)> {
1793    match ty {
1794        IrType::Tensor { shape, dtype } => {
1795            let dims: Vec<usize> = shape
1796                .iter()
1797                .map(|d| resolve_dim(d, config, program))
1798                .collect::<Result<Vec<_>>>()?;
1799            let core_dtype = ir_dtype_to_core(*dtype)?;
1800            Ok((shrew_core::Shape::new(dims), core_dtype))
1801        }
1802        IrType::Scalar(dtype) => {
1803            let core_dtype = ir_dtype_to_core(*dtype)?;
1804            Ok((shrew_core::Shape::new(vec![1]), core_dtype))
1805        }
1806        IrType::Int => Ok((shrew_core::Shape::new(vec![1]), CoreDType::I64)),
1807        _ => Ok((shrew_core::Shape::new(vec![1]), config.default_dtype)),
1808    }
1809}
1810
1811/// Materialize a constant as a tensor.
1812fn materialize_constant<B: Backend>(
1813    val: &ConstantValue,
1814    ty: &IrType,
1815    default_dtype: CoreDType,
1816    device: &B::Device,
1817) -> Result<Tensor<B>> {
1818    match val {
1819        ConstantValue::Int(n) => {
1820            Tensor::<B>::from_f64_slice(&[*n as f64], 1, CoreDType::I64, device)
1821        }
1822        ConstantValue::Float(f) => {
1823            let dtype = match ty {
1824                IrType::Tensor { dtype, .. } => ir_dtype_to_core(*dtype)?,
1825                IrType::Scalar(dtype) => ir_dtype_to_core(*dtype)?,
1826                _ => default_dtype,
1827            };
1828            Tensor::<B>::from_f64_slice(&[*f], 1, dtype, device)
1829        }
1830        ConstantValue::Bool(b) => {
1831            Tensor::<B>::from_f64_slice(&[if *b { 1.0 } else { 0.0 }], 1, CoreDType::U8, device)
1832        }
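        // Strings and nulls have no numeric payload; fall back to a zero scalar
        // so the destination slot still receives a tensor.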
1833        ConstantValue::Str(_) => Tensor::<B>::zeros(1, default_dtype, device),
1834        ConstantValue::Null => Tensor::<B>::zeros(1, default_dtype, device),
1835    }
1836}
1837
1838/// Execute a body op (for Repeat instruction).
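///
/// Note: each supported op constructs its module on every call with freshly
/// initialized weights; parameters are not taken from the executor's table.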
1839fn execute_body_op<B: Backend>(
1840    op: &OpKind,
1841    input: &Tensor<B>,
1842    _device: &B::Device,
1843) -> Result<Tensor<B>> {
1844    match op {
1845        OpKind::TransformerBlock { n_heads } => {
1846            let dims = input.dims();
1847            let d_model = dims[dims.len() - 1];
1848            let d_ff = d_model * 4;
1849            let block = shrew_nn::TransformerBlock::<B>::new(
1850                d_model,
1851                *n_heads as usize,
1852                d_ff,
1853                true,
1854                input.dtype(),
1855                input.device(),
1856            )?;
1857            block.forward(input)
1858        }
1859        OpKind::MultiHeadAttention { n_heads } => {
1860            let d_model = *input
1861                .dims()
1862                .last()
1863                .ok_or_else(|| shrew_core::Error::msg("MHA input has no dimensions"))?;
1864            let mha = shrew_nn::MultiHeadAttention::<B>::new(
1865                d_model,
1866                *n_heads as usize,
1867                input.dtype(),
1868                input.device(),
1869            )?;
1870            mha.forward(input)
1871        }
1872        _ => Err(shrew_core::Error::msg(format!(
1873            "Unsupported op in Repeat body: {:?}",
1874            op
1875        ))),
1876    }
1877}
1878
1879/// Compute (fan_in, fan_out) from a parameter shape.
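///
/// For example, a `[out = 128, in = 64]` linear weight yields
/// `(fan_in, fan_out) = (64, 128)`; a conv kernel shaped `[out, in, kh, kw]`
/// multiplies the channel counts by the receptive field `kh * kw`.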
1880fn compute_fans(shape: &shrew_core::Shape) -> (usize, usize) {
1881    let dims = shape.dims();
1882    match dims.len() {
1883        0 => (1, 1),
1884        1 => (dims[0], dims[0]),
1885        2 => (dims[1], dims[0]),
1886        _ => {
1887            let receptive: usize = dims[2..].iter().product();
1888            let fan_in = dims[1] * receptive;
1889            let fan_out = dims[0] * receptive;
1890            (fan_in, fan_out)
1891        }
1892    }
1893}
1894
1895// =============================================================================
1896// Convenience: parse → lower → validate → optimize → JIT compile
1897// =============================================================================
1898
1899/// Parse, lower, validate, optimize, and JIT-compile a `.sw` program.
1900///
1901/// This is the recommended entry point for production use: it produces a JIT
1902/// executor that avoids the interpreter's per-node dispatch and allocation overhead.
1903///
1904/// # Example
1905/// ```ignore
1906/// let jit = load_jit::<CpuBackend>(source, CpuDevice, RuntimeConfig::default())?;
1907/// let result = jit.run("Forward", &inputs)?;
1908/// ```
1909pub fn load_jit<B: Backend>(
1910    source: &str,
1911    device: B::Device,
1912    config: RuntimeConfig,
1913) -> Result<JitExecutor<B>> {
1914    let ast =
1915        shrew_ir::parse(source).map_err(|e| shrew_core::Error::msg(format!("Parse error: {e}")))?;
1916    let mut ir = shrew_ir::lower(&ast)
1917        .map_err(|e| shrew_core::Error::msg(format!("Lowering error: {e}")))?;
1918
1919    if let Err(errors) = shrew_ir::validate(&ir) {
1920        let msg = errors
1921            .iter()
1922            .map(|e| e.to_string())
1923            .collect::<Vec<_>>()
1924            .join("\n");
1925        return Err(shrew_core::Error::msg(format!("Validation errors:\n{msg}")));
1926    }
1927
1928    shrew_ir::infer_shapes(&mut ir);
1929    shrew_ir::optimize(&mut ir);
1930
1931    JitExecutor::<B>::compile(ir, device, config)
1932}