shrew/exec/engine.rs

// =============================================================================
// Engine — Core graph execution engine
// =============================================================================
//
// Walks the IrGraph in topological order, dispatching each node to the
// appropriate tensor operation. Manages parameter initialization and the
// mapping from NodeId → live Tensor.
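//
// Typical usage (a sketch; the backend type, graph name, and inputs map below
// are illustrative, not defined in this file):
//
//     let exec = Executor::<MyBackend>::new(program, device, RuntimeConfig::default())?;
//     let result = exec.run("main", &inputs)?;
//     let output = result.output();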

use std::collections::HashMap;

use shrew_core::backend::Backend;
use shrew_core::dtype::DType as CoreDType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use shrew_ir::graph::{
    ConfigValue, ConstantValue, DType as IrDType, Dim, InitStrategy, IrGraph, IrNode, IrProgram,
    IrType, OpKind,
};

use shrew_nn::{
    cross_entropy_loss, mse_loss, Dropout, Embedding, LayerNorm, Linear, Module, TransformerBlock,
};

// ─────────────────────────────────────────────────────────────────────────────
// Runtime configuration
// ─────────────────────────────────────────────────────────────────────────────

/// Runtime configuration for resolving symbolic dimensions and execution mode.
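///
/// A minimal construction sketch using the builder methods below (the dimension
/// names are illustrative):
///
/// ```ignore
/// let cfg = RuntimeConfig::default()
///     .set_dim("Batch", 4)
///     .set_dim("SeqLen", 128)
///     .with_training(true);
/// ```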
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Maps symbolic dimension names to concrete values (e.g., "Batch" → 4).
    pub dims: HashMap<String, usize>,
    /// Default data type when unspecified (default: F32).
    pub default_dtype: CoreDType,
    /// Whether we're in training mode (affects dropout, etc.).
    pub training: bool,
}

impl Default for RuntimeConfig {
    fn default() -> Self {
        Self {
            dims: HashMap::new(),
            default_dtype: CoreDType::F32,
            training: false,
        }
    }
}

impl RuntimeConfig {
    /// Set a symbolic dimension value.
    pub fn set_dim(mut self, name: impl Into<String>, value: usize) -> Self {
        self.dims.insert(name.into(), value);
        self
    }

    /// Set training mode.
    pub fn with_training(mut self, training: bool) -> Self {
        self.training = training;
        self
    }

    /// Set default dtype.
    pub fn with_dtype(mut self, dtype: CoreDType) -> Self {
        self.default_dtype = dtype;
        self
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Execution result
// ─────────────────────────────────────────────────────────────────────────────

/// The result of executing a graph.
#[derive(Debug)]
pub struct ExecResult<B: Backend> {
    /// Output tensors, keyed by node name.
    pub outputs: HashMap<String, Tensor<B>>,
    /// All intermediate values, keyed by NodeId.
    pub values: HashMap<usize, Tensor<B>>,
}

impl<B: Backend> ExecResult<B> {
    /// Get the first (or only) output tensor.
    pub fn output(&self) -> Option<&Tensor<B>> {
        self.outputs.values().next()
    }

    /// Get an output by node name.
    pub fn get(&self, name: &str) -> Option<&Tensor<B>> {
        self.outputs.get(name)
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Executor
// ─────────────────────────────────────────────────────────────────────────────

/// Executes IrProgram graphs on the Shrew tensor runtime.
pub struct Executor<B: Backend> {
    /// The lowered IR program.
    program: IrProgram,
    /// Runtime configuration (symbolic dims, dtype, training mode).
    config: RuntimeConfig,
    /// Device to execute on.
    device: B::Device,
    /// Initialized parameter tensors, keyed by (graph_name, param_name).
    params: HashMap<(String, String), Tensor<B>>,
}

impl<B: Backend> Executor<B> {
    /// Create a new executor. Initializes all parameters.
    pub fn new(program: IrProgram, device: B::Device, config: RuntimeConfig) -> Result<Self> {
        let mut exec = Self {
            program,
            config,
            device,
            params: HashMap::new(),
        };
        exec.init_all_params()?;
        Ok(exec)
    }

    /// Get the underlying IR program.
    pub fn program(&self) -> &IrProgram {
        &self.program
    }

    /// Get a reference to the runtime config.
    pub fn config(&self) -> &RuntimeConfig {
        &self.config
    }

    /// Get a mutable reference to the runtime config.
    pub fn config_mut(&mut self) -> &mut RuntimeConfig {
        &mut self.config
    }

    /// Get all parameter tensors.
    pub fn params(&self) -> &HashMap<(String, String), Tensor<B>> {
        &self.params
    }

    /// Get flattened parameter list (all graphs).
    pub fn all_params(&self) -> Vec<Tensor<B>> {
        self.params.values().cloned().collect()
    }

    /// Get all parameters as `(key, tensor)` pairs, where key = `"graph/param"`.
    pub fn named_params(&self) -> Vec<(String, Tensor<B>)> {
        let mut pairs: Vec<(String, Tensor<B>)> = self
            .params
            .iter()
            .map(|((g, p), t)| (format!("{g}/{p}"), t.clone()))
            .collect();
        pairs.sort_by(|a, b| a.0.cmp(&b.0));
        pairs
    }

    /// Set a parameter by its `"graph/param"` key. Returns true if found.
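    ///
    /// A usage sketch (the graph and parameter names here are illustrative only):
    ///
    /// ```ignore
    /// let replaced = exec.set_param_by_key("encoder/w_q", loaded_tensor);
    /// ```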
    pub fn set_param_by_key(&mut self, key: &str, tensor: Tensor<B>) -> bool {
        if let Some(pos) = key.find('/') {
            let graph = &key[..pos];
            let param = &key[pos + 1..];
            let k = (graph.to_string(), param.to_string());
            if let std::collections::hash_map::Entry::Occupied(mut e) = self.params.entry(k) {
                e.insert(tensor.set_variable());
                return true;
            }
        }
        false
    }

    /// The device this executor is running on.
    pub fn device(&self) -> &B::Device {
        &self.device
    }

    /// Execute a named graph with given inputs.
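    ///
    /// A sketch of a call (graph, input, and output names are illustrative; `x` is
    /// assumed to be a `Tensor<B>` already created on this executor's device):
    ///
    /// ```ignore
    /// let mut inputs = HashMap::new();
    /// inputs.insert("x".to_string(), x);
    /// let result = exec.run("main", &inputs)?;
    /// let logits = result.get("logits");
    /// ```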
    pub fn run(
        &self,
        graph_name: &str,
        inputs: &HashMap<String, Tensor<B>>,
    ) -> Result<ExecResult<B>> {
        let graph = self.program.get_graph(graph_name).ok_or_else(|| {
            shrew_core::Error::msg(format!("Graph '{}' not found in program", graph_name))
        })?;
        self.execute_graph(graph, inputs)
    }

    /// Execute a graph, returning output tensors and all intermediate values.
    fn execute_graph(
        &self,
        graph: &IrGraph,
        inputs: &HashMap<String, Tensor<B>>,
    ) -> Result<ExecResult<B>> {
        let order = graph.topo_order();
        let mut values: HashMap<usize, Tensor<B>> = HashMap::new();

        // Map input nodes to their provided tensors
        for &input_id in &graph.inputs {
            let node = graph.node(input_id);
            if let Some(tensor) = inputs.get(&node.name) {
                values.insert(input_id.0, tensor.clone());
            }
        }

        // Map parameter nodes to their initialized tensors
        for param in &graph.params {
            let key = (graph.name.clone(), param.name.clone());
            if let Some(tensor) = self.params.get(&key) {
                values.insert(param.node_id.0, tensor.clone());
            }
        }

        // Execute each node in topological order
        for &node_id in &order {
            if values.contains_key(&node_id.0) {
                continue; // Already initialized (input or param)
            }
            let node = graph.node(node_id);
            let result = self.execute_node(graph, node, &values)?;
            values.insert(node_id.0, result);
        }

        // Collect outputs
        let mut outputs = HashMap::new();
        for output in &graph.outputs {
            if let Some(tensor) = values.get(&output.node_id.0) {
                outputs.insert(output.name.clone(), tensor.clone());
            }
        }

        Ok(ExecResult { outputs, values })
    }

    /// Execute a single node given its inputs' current values.
    fn execute_node(
        &self,
        _graph: &IrGraph,
        node: &IrNode,
        values: &HashMap<usize, Tensor<B>>,
    ) -> Result<Tensor<B>> {
        // Collect input tensors for this node
        let input_tensors: Vec<&Tensor<B>> = node
            .inputs
            .iter()
            .filter_map(|id| values.get(&id.0))
            .collect();

        match &node.op {
            // ── Identity: pass-through ──
            OpKind::Identity => input_tensors.first().map(|t| (*t).clone()).ok_or_else(|| {
                shrew_core::Error::msg(format!("Identity node '{}' has no input", node.name))
            }),

            // ── Unary ops ──
            OpKind::Neg => unary(&input_tensors, &node.name, |t| t.neg()),
            OpKind::Relu => unary(&input_tensors, &node.name, |t| t.relu()),
            OpKind::Gelu => unary(&input_tensors, &node.name, |t| t.gelu()),
            OpKind::Silu => unary(&input_tensors, &node.name, |t| t.silu()),
            OpKind::Sigmoid => unary(&input_tensors, &node.name, |t| t.sigmoid()),
            OpKind::Tanh => unary(&input_tensors, &node.name, |t| t.tanh()),
            OpKind::Exp => unary(&input_tensors, &node.name, |t| t.exp()),
            OpKind::Log => unary(&input_tensors, &node.name, |t| t.log()),
            OpKind::Sqrt => unary(&input_tensors, &node.name, |t| t.sqrt()),

            // ── Transpose ──
            OpKind::Transpose => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let rank = t.rank();
                if rank < 2 {
                    return Err(shrew_core::Error::msg(format!(
                        "Transpose requires rank >= 2, got {} for '{}'",
                        rank, node.name
                    )));
                }
                t.transpose(rank - 2, rank - 1)
            }

            // ── Binary ops ──
            OpKind::Add => binary(&input_tensors, &node.name, |a, b| a.add(b)),
            OpKind::Sub => binary(&input_tensors, &node.name, |a, b| a.sub(b)),
            OpKind::Mul => binary(&input_tensors, &node.name, |a, b| a.mul(b)),
            OpKind::Div => binary(&input_tensors, &node.name, |a, b| a.div(b)),
            OpKind::MatMul => binary(&input_tensors, &node.name, |a, b| a.matmul(b)),

            // ── Pow: x^y via exp(y * ln(x)) ──
            OpKind::Pow => {
                let base = require_input(&input_tensors, 0, &node.name)?;
                let exp_t = require_input(&input_tensors, 1, &node.name)?;
                // x^y = exp(y * ln(x))
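                // Note: this identity assumes every element of the base is strictly
                // positive; non-positive entries yield NaN or -inf from `log`.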
                base.log()?.mul(exp_t)?.exp()
            }

            // ── Mod: a - floor(a / b) * b ──
            OpKind::Mod => {
                let a = require_input(&input_tensors, 0, &node.name)?;
                let b = require_input(&input_tensors, 1, &node.name)?;
                let quotient = a.div(b)?.floor()?;
                let product = quotient.mul(b)?;
                a.sub(&product)
            }

            // ── Reduction ops ──
            OpKind::Sum { dims, keepdim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                if dims.is_empty() || (dims.len() == 1 && dims[0] == -1) {
                    t.sum_all()
                } else {
                    let dim = resolve_neg_dim(dims[0], t.rank());
                    t.sum(dim, *keepdim)
                }
            }

            OpKind::Mean { dims, keepdim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                if dims.is_empty() || (dims.len() == 1 && dims[0] == -1) {
                    t.mean_all()
                } else {
                    let dim = resolve_neg_dim(dims[0], t.rank());
                    t.mean(dim, *keepdim)
                }
            }

            OpKind::Max { dim, keepdim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let d = resolve_neg_dim(*dim, t.rank());
                t.max(d, *keepdim)
            }

            OpKind::Min { dim, keepdim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let d = resolve_neg_dim(*dim, t.rank());
                t.min(d, *keepdim)
            }

            OpKind::Variance { dims, keepdim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                if dims.is_empty() {
                    t.var(0, *keepdim)
                } else {
                    let dim = resolve_neg_dim(dims[0], t.rank());
                    t.var(dim, *keepdim)
                }
            }

            // ── Softmax ──
            OpKind::Softmax { dim } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let d = resolve_neg_dim(*dim, t.rank());
                t.softmax(d)
            }

            // ── Shape ops ──
            OpKind::Reshape { target_shape } | OpKind::View { target_shape } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let shape = self.resolve_shape_vec(target_shape)?;
                t.reshape(shape)
            }

            OpKind::Permute { dims: perm_dims } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                // Apply successive transpositions to achieve the permutation
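                // (e.g. perm = [0, 2, 1] needs a single transpose(1, 2), while
                // perm = [2, 0, 1] needs transpose(0, 2) followed by transpose(1, 2))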
                let mut result = t.clone();
                let mut current: Vec<usize> = (0..t.rank()).collect();
                for i in 0..perm_dims.len() {
                    let target = perm_dims[i] as usize;
                    if current[i] != target {
                        let j = current.iter().position(|&x| x == target).ok_or_else(|| {
                            shrew_core::Error::msg(format!(
                                "permute: dimension {} not found in current layout",
                                target
                            ))
                        })?;
                        result = result.transpose(i, j)?;
                        current.swap(i, j);
                    }
                }
                Ok(result)
            }

            OpKind::Expand { target_shape } => {
                let t = require_input(&input_tensors, 0, &node.name)?;
                let shape = self.resolve_shape_vec(target_shape)?;
                t.expand(shape)
            }

            OpKind::Concat { dim } => {
                if input_tensors.is_empty() {
                    return Err(shrew_core::Error::msg(format!(
                        "Concat node '{}' has no inputs",
                        node.name
                    )));
                }
                let owned: Vec<Tensor<B>> = input_tensors.iter().map(|t| (*t).clone()).collect();
                // Resolve a possibly negative concat axis against the first input's rank
                let d = resolve_neg_dim(*dim, owned[0].rank());
                Tensor::<B>::cat(&owned, d)
            }

            // ── Embedding ──
            // Convention: embedding(indices, weight_table)
            OpKind::Embedding => {
                let indices = require_input(&input_tensors, 0, &node.name)?;
                let table = require_input(&input_tensors, 1, &node.name)?;
                let emb = Embedding::<B>::from_tensor(table.clone())?;
                emb.forward(indices)
            }

            // ── Linear ──
            // Convention: linear(input, weight) or linear(input, weight, bias)
            OpKind::Linear { bias } => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                let weight = require_input(&input_tensors, 1, &node.name)?;
                if *bias && input_tensors.len() >= 3 {
                    let bias_t = require_input(&input_tensors, 2, &node.name)?;
                    let lin = Linear::<B>::from_tensors(weight.clone(), Some(bias_t.clone()))?;
                    lin.forward(input)
                } else {
                    let lin = Linear::<B>::from_tensors(weight.clone(), None)?;
                    lin.forward(input)
                }
            }

            // ── LayerNorm ──
            // Convention: layer_norm(input, weight, bias)
            OpKind::LayerNorm { eps } => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                let weight = require_input(&input_tensors, 1, &node.name)?;
                let bias_t = require_input(&input_tensors, 2, &node.name)?;
                let ln = LayerNorm::<B>::from_tensors(weight.clone(), bias_t.clone(), *eps)?;
                ln.forward(input)
            }

            // ── MultiHeadAttention ──
            OpKind::MultiHeadAttention { n_heads } => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                let d_model = *input
                    .dims()
                    .last()
                    .ok_or_else(|| shrew_core::Error::msg("MHA input has no dimensions"))?;
                let mha = shrew_nn::MultiHeadAttention::<B>::new(
                    d_model,
                    *n_heads as usize,
                    input.dtype(),
                    input.device(),
                )?;
                mha.forward(input)
            }

            // ── TransformerBlock ──
            OpKind::TransformerBlock { n_heads } => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                let dims = input.dims();
                if dims.len() != 3 {
                    return Err(shrew_core::Error::msg(format!(
                        "TransformerBlock expects [batch, seq, d_model], got {:?}",
                        dims
                    )));
                }
                let d_model = dims[2];
                let d_ff = d_model * 4;
                let block = TransformerBlock::<B>::new(
                    d_model,
                    *n_heads as usize,
                    d_ff,
                    true, // causal by default
                    input.dtype(),
                    input.device(),
                )?;
                block.forward(input)
            }

            // ── Dropout ──
            OpKind::Dropout { p } => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                let dropout = Dropout::new(*p);
                if self.config.training {
                    dropout.forward_t(input)
                } else {
                    Ok(input.clone())
                }
            }

            // ── Loss functions ──
            OpKind::CrossEntropy => {
                let predictions = require_input(&input_tensors, 0, &node.name)?;
                let targets = require_input(&input_tensors, 1, &node.name)?;
                cross_entropy_loss(predictions, targets)
            }

            OpKind::MseLoss => {
                let predictions = require_input(&input_tensors, 0, &node.name)?;
                let targets = require_input(&input_tensors, 1, &node.name)?;
                mse_loss(predictions, targets)
            }

            // ── Comparison ops ──
            OpKind::Equal
            | OpKind::NotEqual
            | OpKind::Less
            | OpKind::Greater
            | OpKind::LessEqual
            | OpKind::GreaterEqual => {
                let lhs = require_input(&input_tensors, 0, &node.name)?;
                let rhs = require_input(&input_tensors, 1, &node.name)?;
                match &node.op {
                    OpKind::Equal => lhs.eq(rhs),
                    OpKind::NotEqual => lhs.ne(rhs),
                    OpKind::Less => lhs.lt(rhs),
                    OpKind::Greater => lhs.gt(rhs),
                    OpKind::LessEqual => lhs.le(rhs),
                    OpKind::GreaterEqual => lhs.ge(rhs),
                    _ => unreachable!(),
                }
            }

            // ── Constants ──
            OpKind::Constant(val) => self.materialize_constant(val, &node.output_type),

            // ── Repeat: execute body_op N times in sequence ──
            OpKind::Repeat { count, body_op } => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                let mut current = input.clone();
                for _ in 0..*count {
                    current = self.execute_body_op(body_op, &current)?;
                }
                Ok(current)
            }

            // ── Call: execute another graph ──
            OpKind::Call { graph_name } => {
                // Build inputs for the sub-graph
                let sub_graph = self.program.get_graph(graph_name).ok_or_else(|| {
                    shrew_core::Error::msg(format!("Called graph '{}' not found", graph_name))
                })?;
                let mut sub_inputs = HashMap::new();
                for (i, &input_id) in sub_graph.inputs.iter().enumerate() {
                    let input_node = sub_graph.node(input_id);
                    if let Some(tensor) = input_tensors.get(i) {
                        sub_inputs.insert(input_node.name.clone(), (*tensor).clone());
                    }
                }
                let result = self.execute_graph(sub_graph, &sub_inputs)?;
                result.output().cloned().ok_or_else(|| {
                    shrew_core::Error::msg(format!(
                        "Called graph '{}' produced no output",
                        graph_name
                    ))
                })
            }

            // ── Range ──
            OpKind::Range => {
                // range(start, end) → 1D tensor [start, start+1, ..., end-1]
                let (start, end) = if input_tensors.len() >= 2 {
                    let s = input_tensors[0].to_scalar_f64()?;
                    let e = input_tensors[1].to_scalar_f64()?;
                    (s as i64, e as i64)
                } else if input_tensors.len() == 1 {
                    (0i64, input_tensors[0].to_scalar_f64()? as i64)
                } else {
                    // Try resolving from output type shape
                    match &node.output_type {
                        IrType::Tensor { shape, .. } => {
                            if let Some(Dim::Fixed(n)) = shape.first() {
                                (0, *n)
                            } else if let Some(Dim::Symbolic(name)) = shape.first() {
                                let n = self.resolve_symbolic(name)? as i64;
                                (0, n)
                            } else {
                                (0, 1)
                            }
                        }
                        _ => (0, 1),
                    }
                };
                let data: Vec<f64> = (start..end).map(|i| i as f64).collect();
                let len = data.len();
                Tensor::<B>::from_f64_slice(&data, len, CoreDType::I64, &self.device)
            }

            // ── BatchNorm ──
            // Convention: batch_norm(input, weight, bias)
            OpKind::BatchNorm { eps } => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                if input_tensors.len() >= 3 {
                    let weight = require_input(&input_tensors, 1, &node.name)?;
                    let bias_t = require_input(&input_tensors, 2, &node.name)?;
                    let bn = shrew_nn::BatchNorm2d::<B>::from_tensors(
                        weight.clone(),
                        bias_t.clone(),
                        *eps,
                    )?;
                    bn.forward(input)
                } else {
                    // No weight/bias provided — create default BatchNorm from channels
                    let dims = input.dims();
                    if dims.len() != 4 {
                        return Err(shrew_core::Error::msg(format!(
                            "BatchNorm expects 4D input [N,C,H,W], got {:?}",
                            dims
                        )));
                    }
                    let c = dims[1];
                    let bn =
                        shrew_nn::BatchNorm2d::<B>::new(c, *eps, 0.1, input.dtype(), &self.device)?;
                    bn.forward(input)
                }
            }

            // ── Split ──
            OpKind::Split { dim, chunks } => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                let d = resolve_neg_dim(*dim, input.rank());
                let result = input.chunk(*chunks as usize, d)?;
                // Return first chunk (Split in IR produces a single node)
                result
                    .into_iter()
                    .next()
                    .ok_or_else(|| shrew_core::Error::msg("Split produced no chunks"))
            }

            // ── Logical ops (on comparison results) ──
            OpKind::And => {
                let lhs = require_input(&input_tensors, 0, &node.name)?;
                let rhs = require_input(&input_tensors, 1, &node.name)?;
                // a AND b: evaluated element-wise on the host (both operands are
                // pulled back to CPU), then rebuilt as a U8 mask in the input's shape
                let a_data = lhs.to_f64_vec()?;
                let b_data = rhs.to_f64_vec()?;
                let result: Vec<f64> = a_data
                    .iter()
                    .zip(b_data.iter())
                    .map(|(&a, &b)| if a != 0.0 && b != 0.0 { 1.0 } else { 0.0 })
                    .collect();
                let n = result.len();
                Tensor::<B>::from_f64_slice(&result, n, CoreDType::U8, &self.device)?
                    .reshape(lhs.dims().to_vec())
            }
            OpKind::Or => {
                let lhs = require_input(&input_tensors, 0, &node.name)?;
                let rhs = require_input(&input_tensors, 1, &node.name)?;
                let a_data = lhs.to_f64_vec()?;
                let b_data = rhs.to_f64_vec()?;
                let result: Vec<f64> = a_data
                    .iter()
                    .zip(b_data.iter())
                    .map(|(&a, &b)| if a != 0.0 || b != 0.0 { 1.0 } else { 0.0 })
                    .collect();
                let n = result.len();
                Tensor::<B>::from_f64_slice(&result, n, CoreDType::U8, &self.device)?
                    .reshape(lhs.dims().to_vec())
            }
            OpKind::Not => {
                let input = require_input(&input_tensors, 0, &node.name)?;
                let data = input.to_f64_vec()?;
                let result: Vec<f64> = data
                    .iter()
                    .map(|&v| if v == 0.0 { 1.0 } else { 0.0 })
                    .collect();
                let n = result.len();
                Tensor::<B>::from_f64_slice(&result, n, CoreDType::U8, &self.device)?
                    .reshape(input.dims().to_vec())
            }

            // ── Custom op ──
            OpKind::Custom { name, .. } => {
                match name.as_str() {
                    // Fused matmul + add: a.matmul(b) + c (no weight transpose)
                    "fused_matmul_add" => {
                        let a = require_input(&input_tensors, 0, &node.name)?;
                        let b = require_input(&input_tensors, 1, &node.name)?;
                        let c = require_input(&input_tensors, 2, &node.name)?;
                        a.matmul(b)?.add(c)
                    }
                    // Fused add + relu
                    "fused_add_relu" => {
                        let a = require_input(&input_tensors, 0, &node.name)?;
                        let b = require_input(&input_tensors, 1, &node.name)?;
                        a.add(b)?.relu()
                    }
                    // Fused sub + relu
                    "fused_sub_relu" => {
                        let a = require_input(&input_tensors, 0, &node.name)?;
                        let b = require_input(&input_tensors, 1, &node.name)?;
                        a.sub(b)?.relu()
                    }
                    // Fused matmul + relu
                    "fused_matmul_relu" => {
                        let a = require_input(&input_tensors, 0, &node.name)?;
                        let b = require_input(&input_tensors, 1, &node.name)?;
                        a.matmul(b)?.relu()
                    }
                    _ => Err(shrew_core::Error::msg(format!(
                        "Custom op '{}' is not implemented in the executor",
                        name
                    ))),
                }
            }
        }
    }

    /// Execute a body op (used inside Repeat).
    fn execute_body_op(&self, op: &OpKind, input: &Tensor<B>) -> Result<Tensor<B>> {
        match op {
            OpKind::TransformerBlock { n_heads } => {
                let dims = input.dims();
                if dims.len() != 3 {
                    return Err(shrew_core::Error::msg(format!(
                        "TransformerBlock expects [batch, seq, d_model], got {:?}",
                        dims
                    )));
                }
                let d_model = dims[2];
                let d_ff = d_model * 4;
                let block = TransformerBlock::<B>::new(
                    d_model,
                    *n_heads as usize,
                    d_ff,
                    true,
                    input.dtype(),
                    input.device(),
                )?;
                block.forward(input)
            }
            OpKind::MultiHeadAttention { n_heads } => {
                let d_model = *input
                    .dims()
                    .last()
                    .ok_or_else(|| shrew_core::Error::msg("MHA input has no dimensions"))?;
                let mha = shrew_nn::MultiHeadAttention::<B>::new(
                    d_model,
                    *n_heads as usize,
                    input.dtype(),
                    input.device(),
                )?;
                mha.forward(input)
            }
            // Other ops are not supported inside Repeat; surface a clear error
            // rather than silently skipping the body.
            _ => Err(shrew_core::Error::msg(format!(
                "Unsupported op in Repeat body: {:?}. \
                 Only TransformerBlock and MultiHeadAttention are supported.",
                op
            ))),
        }
    }

    // ─────────────────────────────────────────────────────────────────────
    // Parameter initialization
    // ─────────────────────────────────────────────────────────────────────

    /// Initialize all parameters across all graphs.
    fn init_all_params(&mut self) -> Result<()> {
        let graphs: Vec<(String, Vec<_>)> = self
            .program
            .graphs
            .iter()
            .map(|g| {
                (
                    g.name.clone(),
                    g.params
                        .iter()
                        .map(|p| (p.name.clone(), p.ty.clone(), p.init.clone(), p.frozen))
                        .collect::<Vec<_>>(),
                )
            })
            .collect();

        for (graph_name, params) in &graphs {
            for (param_name, ty, init, frozen) in params {
                let tensor = self.init_param(ty, init, *frozen)?;
                self.params
                    .insert((graph_name.clone(), param_name.clone()), tensor);
            }
        }
        Ok(())
    }

    /// Initialize a single parameter tensor based on its type and init strategy.
    fn init_param(&self, ty: &IrType, init: &InitStrategy, frozen: bool) -> Result<Tensor<B>> {
        let (shape, dtype) = self.resolve_type(ty)?;
        let tensor = match init {
            InitStrategy::Zeros => Tensor::<B>::zeros(shape, dtype, &self.device)?,
            InitStrategy::Ones => Tensor::<B>::ones(shape, dtype, &self.device)?,
            InitStrategy::Normal { mean, std } => {
                Tensor::<B>::randn(shape, dtype, &self.device)?.affine(*std, *mean)?
            }
            InitStrategy::Uniform { low, high } => {
                let range = high - low;
                Tensor::<B>::rand(shape, dtype, &self.device)?.affine(range, *low)?
            }
            InitStrategy::XavierUniform => {
                // Xavier uniform: U(-a, a) where a = sqrt(6 / (fan_in + fan_out))
                let (fan_in, fan_out) = compute_fans(&shape);
                let a = (6.0_f64 / (fan_in + fan_out) as f64).sqrt();
                Tensor::<B>::rand(shape, dtype, &self.device)?.affine(2.0 * a, -a)?
            }
            InitStrategy::XavierNormal => {
                // Xavier normal: N(0, std) where std = sqrt(2 / (fan_in + fan_out))
                let (fan_in, fan_out) = compute_fans(&shape);
                let std = (2.0_f64 / (fan_in + fan_out) as f64).sqrt();
                Tensor::<B>::randn(shape, dtype, &self.device)?.affine(std, 0.0)?
            }
            InitStrategy::KaimingUniform => {
                // Kaiming uniform: U(-bound, bound) where bound = sqrt(3 / fan_in)
                let (fan_in, _) = compute_fans(&shape);
                let bound = (3.0_f64 / fan_in as f64).sqrt();
                Tensor::<B>::rand(shape, dtype, &self.device)?.affine(2.0 * bound, -bound)?
            }
            InitStrategy::KaimingNormal => {
                // Kaiming normal: N(0, std) where std = sqrt(2 / fan_in)
                let (fan_in, _) = compute_fans(&shape);
                let std = (2.0_f64 / fan_in as f64).sqrt();
                Tensor::<B>::randn(shape, dtype, &self.device)?.affine(std, 0.0)?
            }
            InitStrategy::Custom(_) => Tensor::<B>::randn(shape, dtype, &self.device)?,
        };

        if frozen {
            Ok(tensor)
        } else {
            Ok(tensor.set_variable())
        }
    }

    /// Update parameters after an optimizer step.
    ///
    /// `new_params` are paired positionally with this graph's parameters, in the
    /// same iteration order that `graph_params` uses.
    pub fn update_params(&mut self, graph_name: &str, new_params: &[Tensor<B>]) {
        let param_names: Vec<String> = self
            .params
            .keys()
            .filter(|(g, _)| g == graph_name)
            .map(|(_, n)| n.clone())
            .collect();

        for (name, tensor) in param_names.into_iter().zip(new_params.iter()) {
            self.params
                .insert((graph_name.to_string(), name), tensor.clone());
        }
    }

    /// Collect parameters for a specific graph (for optimizer).
    pub fn graph_params(&self, graph_name: &str) -> Vec<Tensor<B>> {
        self.params
            .iter()
            .filter(|((g, _), _)| g == graph_name)
            .map(|(_, t)| t.clone())
            .collect()
    }

    // ─────────────────────────────────────────────────────────────────────
    // Helpers
    // ─────────────────────────────────────────────────────────────────────

    /// Resolve a Dim to a concrete usize.
    fn resolve_dim(&self, dim: &Dim) -> Result<usize> {
        match dim {
            Dim::Fixed(n) => Ok(*n as usize),
            Dim::Symbolic(name) => self.resolve_symbolic(name),
            Dim::Dynamic => Err(shrew_core::Error::msg(
                "Cannot resolve dynamic dimension at runtime",
            )),
        }
    }

    /// Resolve a symbolic dimension name.
    fn resolve_symbolic(&self, name: &str) -> Result<usize> {
        // Try runtime config
        if let Some(&val) = self.config.dims.get(name) {
            return Ok(val);
        }
        // Try program config
        if let Some(ConfigValue::Int(n)) = self.program.config.get(name) {
            return Ok(*n as usize);
        }
        Err(shrew_core::Error::msg(format!(
            "Unresolved symbolic dimension: '{}'. Set it via RuntimeConfig::set_dim()",
            name
        )))
    }

    /// Resolve an IrType to a concrete (Shape, CoreDType).
    fn resolve_type(&self, ty: &IrType) -> Result<(shrew_core::Shape, CoreDType)> {
        match ty {
            IrType::Tensor { shape, dtype } => {
                let dims: Vec<usize> = shape
                    .iter()
                    .map(|d| self.resolve_dim(d))
                    .collect::<Result<Vec<_>>>()?;
                let core_dtype = ir_dtype_to_core(*dtype)?;
                Ok((shrew_core::Shape::new(dims), core_dtype))
            }
            IrType::Scalar(dtype) => {
                let core_dtype = ir_dtype_to_core(*dtype)?;
                Ok((shrew_core::Shape::new(vec![1]), core_dtype))
            }
            IrType::Int => Ok((shrew_core::Shape::new(vec![1]), CoreDType::I64)),
            _ => Ok((shrew_core::Shape::new(vec![1]), self.config.default_dtype)),
        }
    }

    /// Resolve a slice of Dims to a concrete shape vector.
    fn resolve_shape_vec(&self, dims: &[Dim]) -> Result<Vec<usize>> {
        dims.iter().map(|d| self.resolve_dim(d)).collect()
    }

    /// Materialize a constant value as a tensor.
    fn materialize_constant(&self, val: &ConstantValue, ty: &IrType) -> Result<Tensor<B>> {
        match val {
            ConstantValue::Int(n) => {
                Tensor::<B>::from_f64_slice(&[*n as f64], 1, CoreDType::I64, &self.device)
            }
            ConstantValue::Float(f) => Tensor::<B>::from_f64_slice(
                &[*f],
                1,
                ir_type_dtype(ty, self.config.default_dtype)?,
                &self.device,
            ),
            ConstantValue::Bool(b) => Tensor::<B>::from_f64_slice(
                &[if *b { 1.0 } else { 0.0 }],
                1,
                CoreDType::U8,
                &self.device,
            ),
            ConstantValue::Str(_) => {
                // Strings can't be tensors — return a dummy scalar
                Tensor::<B>::zeros(1, self.config.default_dtype, &self.device)
            }
            ConstantValue::Null => Tensor::<B>::zeros(1, self.config.default_dtype, &self.device),
        }
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Free helpers
// ─────────────────────────────────────────────────────────────────────────────

/// Convert IR DType to core DType.
pub fn ir_dtype_to_core(dt: IrDType) -> Result<CoreDType> {
    match dt {
        IrDType::F32 => Ok(CoreDType::F32),
        IrDType::F64 => Ok(CoreDType::F64),
        IrDType::U8 => Ok(CoreDType::U8),
        IrDType::U32 => Ok(CoreDType::U32),
        IrDType::I64 => Ok(CoreDType::I64),
        // Map unsupported types to the closest supported dtype
        IrDType::F16 | IrDType::Bf16 => Ok(CoreDType::F32),
        IrDType::I8 | IrDType::I16 | IrDType::I32 => Ok(CoreDType::I64),
        IrDType::U16 => Ok(CoreDType::U32),
        IrDType::U64 => Ok(CoreDType::U32),
        IrDType::Bool => Ok(CoreDType::U8),
        _ => Err(shrew_core::Error::msg(format!(
            "Unsupported IR dtype: {dt}"
        ))),
    }
}

/// Extract dtype from IrType, with a fallback default.
fn ir_type_dtype(ty: &IrType, default: CoreDType) -> Result<CoreDType> {
    match ty {
        IrType::Tensor { dtype, .. } => ir_dtype_to_core(*dtype),
        IrType::Scalar(dtype) => ir_dtype_to_core(*dtype),
        _ => Ok(default),
    }
}

/// Resolve a negative dimension index.
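/// e.g. `dim = -1` with `rank = 3` resolves to axis `2`.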
fn resolve_neg_dim(dim: i64, rank: usize) -> usize {
    if dim < 0 {
        (rank as i64 + dim) as usize
    } else {
        dim as usize
    }
}

/// Require an input at a given index.
fn require_input<'a, B: Backend>(
    inputs: &[&'a Tensor<B>],
    idx: usize,
    node_name: &str,
) -> Result<&'a Tensor<B>> {
    inputs.get(idx).copied().ok_or_else(|| {
        shrew_core::Error::msg(format!(
            "Node '{}' expected input at index {}, but only {} inputs available",
            node_name,
            idx,
            inputs.len()
        ))
    })
}

/// Execute a unary op.
fn unary<B: Backend>(
    inputs: &[&Tensor<B>],
    node_name: &str,
    f: impl FnOnce(&Tensor<B>) -> Result<Tensor<B>>,
) -> Result<Tensor<B>> {
    let t = require_input(inputs, 0, node_name)?;
    f(t)
}

/// Execute a binary op.
fn binary<B: Backend>(
    inputs: &[&Tensor<B>],
    node_name: &str,
    f: impl FnOnce(&Tensor<B>, &Tensor<B>) -> Result<Tensor<B>>,
) -> Result<Tensor<B>> {
    let a = require_input(inputs, 0, node_name)?;
    let b = require_input(inputs, 1, node_name)?;
    f(a, b)
}

/// Compute (fan_in, fan_out) from a parameter shape.
///
/// Follows PyTorch conventions:
/// - 1-D (bias): fan_in = fan_out = shape[0]
/// - 2-D (linear weight): fan_in = shape[1], fan_out = shape[0]
/// - 3-D+ (conv weight): fan_in = shape[1] * receptive, fan_out = shape[0] * receptive
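///
/// For example, a linear weight of shape `[128, 64]` gives `(fan_in, fan_out) = (64, 128)`;
/// a conv weight of shape `[16, 3, 3, 3]` gives `(27, 144)` (receptive field = 3 * 3 = 9).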
fn compute_fans(shape: &shrew_core::Shape) -> (usize, usize) {
    let dims = shape.dims();
    match dims.len() {
        0 => (1, 1),
        1 => (dims[0], dims[0]),
        2 => (dims[1], dims[0]),
        _ => {
            // Conv: [out_channels, in_channels, *kernel_size]
            let receptive: usize = dims[2..].iter().product();
            let fan_in = dims[1] * receptive;
            let fan_out = dims[0] * receptive;
            (fan_in, fan_out)
        }
    }
}