// shrew_ir/lower.rs

// Lowering — AST → Graph IR
//
// This module transforms a parsed AST (shrew_ir::ast::Program) into a
// validated Graph IR (shrew_ir::graph::IrProgram). The lowering pass:
//
//   1. Resolves config values into a key-value map
//   2. Resolves type aliases
//   3. Lowers each @graph block into an IrGraph with typed nodes
//   4. Resolves operations from expressions (matmul, add, layer_norm, etc.)
//   5. Connects edges between nodes by name resolution
//   6. Lowers @training / @inference into config structs
//
// ERRORS: Lowering can fail if names are undefined, types mismatch, etc.
// We use the same Error type from error.rs.

use std::collections::HashMap;

use crate::ast::*;
use crate::error::{Error, ErrorKind, Result};
use crate::graph::*;
use crate::token::Span;

/// Lower a parsed AST program into a Graph IR program.
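///
/// A minimal sketch of the intended call flow, assuming this crate exposes a
/// parsing entry point that yields an `ast::Program` (marked `ignore` because
/// that API lives outside this module and the name is hypothetical):
///
/// ```ignore
/// let program = shrew_ir::parse(source_text)?; // hypothetical parse entry point
/// let ir = shrew_ir::lower::lower(&program)?;
/// for graph in &ir.graphs {
///     println!("lowered a graph with {} outputs", graph.outputs.len());
/// }
/// ```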
pub fn lower(program: &Program) -> Result<IrProgram> {
    let mut ctx = LowerCtx::new();
    ctx.lower_program(program)?;
    Ok(ctx.ir)
}

// Lowering context

struct LowerCtx {
    ir: IrProgram,
}

impl LowerCtx {
    fn new() -> Self {
        Self {
            ir: IrProgram::new(),
        }
    }

    fn lower_program(&mut self, program: &Program) -> Result<()> {
        // First pass: collect metadata, config, types (order-independent info)
        for item in &program.items {
            match item {
                TopLevel::Metadata(m) => self.lower_metadata(m)?,
                TopLevel::Config(c) => self.lower_config(c)?,
                TopLevel::Types(t) => self.lower_types(t)?,
                _ => {}
            }
        }

        // Second pass: lower graphs and other blocks
        for item in &program.items {
            match item {
                TopLevel::Graph(g) => {
                    let ir_graph = self.lower_graph(g)?;
                    self.ir.graphs.push(ir_graph);
                }
                TopLevel::Training(t) => {
                    self.ir.training = Some(self.lower_training(t)?);
                }
                TopLevel::Inference(i) => {
                    self.ir.inference = Some(self.lower_inference(i)?);
                }
                // Metadata, Config, Types already handled in first pass
                TopLevel::Metadata(_) | TopLevel::Config(_) | TopLevel::Types(_) => {}
                // Import, CustomOp, Metrics, Logging, Visualization — stored
                // but not deeply lowered yet (future work)
                _ => {}
            }
        }

        Ok(())
    }

    // @model

    fn lower_metadata(&mut self, meta: &MetadataBlock) -> Result<()> {
        for field in &meta.fields {
            let value = match &field.value {
                Literal::Str(s, _) => s.clone(),
                Literal::Int(n, _) => n.to_string(),
                Literal::Float(f, _) => f.to_string(),
                Literal::Bool(b, _) => b.to_string(),
                _ => format!("{:?}", field.value),
            };
            self.ir.metadata.insert(field.key.clone(), value);
        }
        Ok(())
    }

    // @config

    fn lower_config(&mut self, config: &ConfigBlock) -> Result<()> {
        for field in &config.fields {
            let value = self.eval_config_expr(&field.value)?;
            self.ir.config.insert(field.key.clone(), value);
        }
        Ok(())
    }

    /// Evaluate a config expression to a ConfigValue.
    /// Supports constant folding of arithmetic on integer/float literals.
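    ///
    /// A sketch of the folding behavior (illustrative, not a doctest):
    ///
    /// ```text
    /// // with @config { d_model: 512 } already lowered:
    /// d_model / 8    =>  ConfigValue::Int(64)
    /// 2.0 * 0.5      =>  ConfigValue::Float(1.0)
    /// batch          =>  ConfigValue::Str("batch")   // unresolved ident
    /// ```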
    fn eval_config_expr(&self, expr: &Expr) -> Result<ConfigValue> {
        match expr {
            Expr::Int(n, _) => Ok(ConfigValue::Int(*n)),
            Expr::Float(f, _) => Ok(ConfigValue::Float(*f)),
            Expr::Str(s, _) => Ok(ConfigValue::Str(s.clone())),
            Expr::Bool(b, _) => Ok(ConfigValue::Bool(*b)),
            Expr::List(items, _) => {
                let vals: Result<Vec<_>> = items.iter().map(|e| self.eval_config_expr(e)).collect();
                Ok(ConfigValue::List(vals?))
            }
            Expr::Ident(name, _) => {
                // Reference to another config value
                if let Some(val) = self.ir.config.get(name) {
                    Ok(val.clone())
                } else {
                    // Treat as a string (e.g., symbolic dimension name)
                    Ok(ConfigValue::Str(name.clone()))
                }
            }
            Expr::Binary {
                left,
                op,
                right,
                span,
            } => {
                let l = self.eval_config_expr(left)?;
                let r = self.eval_config_expr(right)?;
                self.eval_binary_config(&l, *op, &r, *span)
            }
            Expr::Unary {
                op: UnaryOp::Neg,
                operand,
                ..
            } => {
                let val = self.eval_config_expr(operand)?;
                match val {
                    ConfigValue::Int(n) => Ok(ConfigValue::Int(-n)),
                    ConfigValue::Float(f) => Ok(ConfigValue::Float(-f)),
                    _ => Ok(val),
                }
            }
            _ => {
                // For complex expressions, store as string representation
                Ok(ConfigValue::Str(format!("{expr:?}")))
            }
        }
    }

    fn eval_binary_config(
        &self,
        left: &ConfigValue,
        op: BinOp,
        right: &ConfigValue,
        span: Span,
    ) -> Result<ConfigValue> {
        match (left, right) {
            (ConfigValue::Int(a), ConfigValue::Int(b)) => {
                let result = match op {
                    BinOp::Add => a + b,
                    BinOp::Sub => a - b,
                    BinOp::Mul => a * b,
                    BinOp::Div => {
                        if *b == 0 {
                            return Err(Error::new(
                                ErrorKind::Message("division by zero".into()),
                                span,
                            ));
                        }
                        a / b
                    }
                    BinOp::Mod => {
                        // `%` panics on a zero divisor just like `/`, so guard it too
                        if *b == 0 {
                            return Err(Error::new(
                                ErrorKind::Message("modulo by zero".into()),
                                span,
                            ));
                        }
                        a % b
                    }
                    BinOp::Pow => {
                        // `as u32` would wrap a negative exponent into a huge one
                        if *b < 0 {
                            return Err(Error::new(
                                ErrorKind::Message("negative integer exponent".into()),
                                span,
                            ));
                        }
                        a.pow(*b as u32)
                    }
                    _ => return Ok(ConfigValue::Str(format!("{a} {op:?} {b}"))),
                };
                Ok(ConfigValue::Int(result))
            }
            (ConfigValue::Float(a), ConfigValue::Float(b)) => {
                let result = match op {
                    BinOp::Add => a + b,
                    BinOp::Sub => a - b,
                    BinOp::Mul => a * b,
                    BinOp::Div => a / b,
                    BinOp::Mod => a % b,
                    BinOp::Pow => a.powf(*b),
                    _ => return Ok(ConfigValue::Str(format!("{a} {op:?} {b}"))),
                };
                Ok(ConfigValue::Float(result))
            }
            (ConfigValue::Int(a), ConfigValue::Float(b)) => self.eval_binary_config(
                &ConfigValue::Float(*a as f64),
                op,
                &ConfigValue::Float(*b),
                span,
            ),
            (ConfigValue::Float(a), ConfigValue::Int(b)) => self.eval_binary_config(
                &ConfigValue::Float(*a),
                op,
                &ConfigValue::Float(*b as f64),
                span,
            ),
            _ => Ok(ConfigValue::Str(format!("{left:?} {op:?} {right:?}"))),
        }
    }

    // @types

    fn lower_types(&mut self, types: &TypesBlock) -> Result<()> {
        for def in &types.defs {
            let ty = self.lower_type_expr(&def.ty)?;
            self.ir.type_aliases.insert(def.name.clone(), ty);
        }
        Ok(())
    }

    fn lower_type_expr(&self, ty: &TypeExpr) -> Result<IrType> {
        match ty {
            TypeExpr::Tensor { dims, dtype, .. } => {
                let shape: Vec<Dim> = dims.iter().map(|d| self.lower_dim(d)).collect();
                let dt = lower_dtype(dtype);
                Ok(IrType::Tensor { shape, dtype: dt })
            }
            TypeExpr::Scalar(dt, _) => Ok(IrType::Scalar(lower_dtype(dt))),
            TypeExpr::Named(name, _) => {
                if let Some(resolved) = self.ir.type_aliases.get(name) {
                    Ok(resolved.clone())
                } else {
                    // Forward reference or unknown — keep as unknown
                    Ok(IrType::Unknown)
                }
            }
            TypeExpr::Dynamic(_) => Ok(IrType::Unknown),
            TypeExpr::IntDim(n, _) => {
                // A single integer dimension used as a type
                Ok(IrType::Tensor {
                    shape: vec![Dim::Fixed(*n)],
                    dtype: DType::F32,
                })
            }
            _ => Ok(IrType::Unknown),
        }
    }

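    /// Lower an AST dimension to an IR `Dim`, resolving symbolic names through
    /// the config map when possible. A sketch of the mapping (illustrative):
    ///
    /// ```text
    /// named "d_model"       =>  Dim::Fixed(512)          // config has d_model = 512
    /// named "heads"         =>  Dim::Symbolic("heads")   // no config entry
    /// concrete 64           =>  Dim::Fixed(64)
    /// dynamic / inferred    =>  Dim::Dynamic
    /// computed d_model / 8  =>  Dim::Fixed(64)           // constant-folded
    /// ```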
    fn lower_dim(&self, dim: &Dimension) -> Dim {
        match dim {
            Dimension::Named(name, _) => {
                // Try to resolve symbolic dim from config
                if let Some(ConfigValue::Int(n)) = self.ir.config.get(name) {
                    Dim::Fixed(*n)
                } else {
                    Dim::Symbolic(name.clone())
                }
            }
            Dimension::Concrete(n, _) => Dim::Fixed(*n),
            Dimension::Dynamic(_) => Dim::Dynamic,
            Dimension::Inferred(_) => Dim::Dynamic,
            Dimension::Computed(expr, _) => {
                // Try to evaluate the expression as a constant integer dim
                // (a single evaluation, truncating float results)
                match self.eval_config_expr(expr) {
                    Ok(ConfigValue::Int(n)) => Dim::Fixed(n),
                    Ok(ConfigValue::Float(f)) => Dim::Fixed(f as i64),
                    _ => Dim::Dynamic,
                }
            }
        }
    }

    // @graph

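    /// Lower a single `@graph` block into an `IrGraph`. A sketch of how a
    /// small block maps onto nodes (the surface syntax here is illustrative):
    ///
    /// ```text
    /// @graph encoder {
    ///     input x: Tensor[batch, d_model]     // Identity node, pushed to inputs
    ///     param W: Tensor[d_model, d_model]   // Identity node + IrParam entry
    ///     node h { op: matmul(x, W) }         // MatMul node with edges to x, W
    ///     output h                            // IrOutput pointing at h's NodeId
    /// }
    /// ```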
    fn lower_graph(&self, graph: &GraphBlock) -> Result<IrGraph> {
        let mut ir_graph = IrGraph::new(&graph.name);

        // Scope: name → NodeId for resolving references within this graph
        let mut scope: HashMap<String, NodeId> = HashMap::new();

        for stmt in &graph.body {
            match stmt {
                GraphStmt::Input(input) => {
                    let ty = self.lower_type_expr(&input.ty)?;
                    let id = ir_graph.add_node(&input.name, OpKind::Identity, vec![], ty);
                    ir_graph.inputs.push(id);
                    scope.insert(input.name.clone(), id);
                }
                GraphStmt::Param(param) => {
                    let ty = self.lower_type_expr(&param.ty)?;
                    let id = ir_graph.add_node(&param.name, OpKind::Identity, vec![], ty.clone());
                    scope.insert(param.name.clone(), id);

                    // Parse init strategy
                    let init = self.parse_init_strategy(&param.attrs);
                    let frozen = self.parse_frozen(&param.attrs);

                    ir_graph.params.push(IrParam {
                        node_id: id,
                        name: param.name.clone(),
                        ty,
                        init,
                        frozen,
                    });
                }
                GraphStmt::Node(node) => {
                    let (op, inputs) =
                        self.lower_node_body(&node.stmts, &mut ir_graph, &mut scope)?;
                    let output_type = node
                        .ty
                        .as_ref()
                        .map(|t| self.lower_type_expr(t))
                        .transpose()?
                        .unwrap_or(IrType::Unknown);

                    let id = ir_graph.add_node(&node.name, op, inputs, output_type);

                    // Transfer hints
                    for stmt in &node.stmts {
                        if let NodeStmt::Hint(hint, _) = stmt {
                            let ir_hint = match hint {
                                HintKind::RecomputeInBackward => IrHint::RecomputeInBackward,
                                HintKind::MustPreserve => IrHint::MustPreserve,
                                HintKind::InPlace => IrHint::InPlace,
                                HintKind::NoGrad => IrHint::NoGrad,
                                HintKind::Custom(s) => IrHint::Custom(s.clone()),
                            };
                            ir_graph.node_mut(id).hints.push(ir_hint);
                        }
                    }

                    // Transfer extra attributes
                    for stmt in &node.stmts {
                        if let NodeStmt::Attr(key, val, _) = stmt {
                            if let Some(attr) = self.expr_to_attr(val) {
                                ir_graph.node_mut(id).attrs.insert(key.clone(), attr);
                            }
                        }
                    }

                    scope.insert(node.name.clone(), id);
                }
                GraphStmt::Output(output) => {
                    // Derive a user-facing output name
                    let out_name = output.name.clone().unwrap_or_else(|| {
                        // Use the identifier name if it's a simple ident
                        if let Expr::Ident(ref ident, _) = output.expr {
                            ident.clone()
                        } else {
                            format!("__output_{}", ir_graph.outputs.len())
                        }
                    });

                    // Resolve the output expression to a node id
                    if let Some(id) = self.try_resolve_ident(&output.expr, &scope) {
                        ir_graph.outputs.push(IrOutput {
                            name: out_name,
                            node_id: id,
                        });
                    } else {
                        // Complex expression — lower it into a new node
                        let (op, inputs) =
                            self.lower_expr_to_op(&output.expr, &mut ir_graph, &mut scope)?;
                        let id = ir_graph.add_node(out_name.clone(), op, inputs, IrType::Unknown);
                        ir_graph.outputs.push(IrOutput {
                            name: out_name,
                            node_id: id,
                        });
                    }
                }
                GraphStmt::Assert(assert_stmt) => {
                    ir_graph.asserts.push(IrAssert {
                        message: assert_stmt.message.clone(),
                        expr_text: format!("{:?}", assert_stmt.condition),
                    });
                }
                GraphStmt::Check(_) => {
                    // Check blocks contain multiple asserts — lower each
                    // (Future: implement check block lowering)
                }
            }
        }

        Ok(ir_graph)
    }

    /// Parse param attributes to find the init strategy.
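    ///
    /// A sketch of the accepted spellings (illustrative):
    ///
    /// ```text
    /// init: "zeros"              =>  InitStrategy::Zeros
    /// init: "xavier_uniform"     =>  InitStrategy::XavierUniform
    /// init: "normal(0.0, 0.02)"  =>  InitStrategy::Normal { mean: 0.0, std: 0.02 }
    /// init: "uniform(-1, 1)"     =>  InitStrategy::Uniform { low: -1.0, high: 1.0 }
    /// init: "something_else"     =>  InitStrategy::Custom("something_else")
    /// ```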
    fn parse_init_strategy(&self, attrs: &[ParamAttr]) -> InitStrategy {
        for attr in attrs {
            if attr.key == "init" {
                if let Expr::Str(s, _) = &attr.value {
                    return match s.as_str() {
                        "zeros" => InitStrategy::Zeros,
                        "ones" => InitStrategy::Ones,
                        "xavier_uniform" => InitStrategy::XavierUniform,
                        "xavier_normal" => InitStrategy::XavierNormal,
                        "kaiming_uniform" => InitStrategy::KaimingUniform,
                        "kaiming_normal" => InitStrategy::KaimingNormal,
                        s if s.starts_with("normal(") => {
                            // Parse "normal(mean, std)"; strip_prefix/strip_suffix
                            // avoid the panic a raw slice takes on malformed input
                            let inner = s
                                .strip_prefix("normal(")
                                .and_then(|rest| rest.strip_suffix(')'))
                                .unwrap_or("");
                            let parts: Vec<&str> = inner.split(',').collect();
                            if parts.len() == 2 {
                                let mean = parts[0].trim().parse().unwrap_or(0.0);
                                let std = parts[1].trim().parse().unwrap_or(0.02);
                                InitStrategy::Normal { mean, std }
                            } else {
                                InitStrategy::Custom(s.to_string())
                            }
                        }
                        s if s.starts_with("uniform(") => {
                            // Parse "uniform(low, high)" the same way
                            let inner = s
                                .strip_prefix("uniform(")
                                .and_then(|rest| rest.strip_suffix(')'))
                                .unwrap_or("");
                            let parts: Vec<&str> = inner.split(',').collect();
                            if parts.len() == 2 {
                                let low = parts[0].trim().parse().unwrap_or(-1.0);
                                let high = parts[1].trim().parse().unwrap_or(1.0);
                                InitStrategy::Uniform { low, high }
                            } else {
                                InitStrategy::Custom(s.to_string())
                            }
                        }
                        other => InitStrategy::Custom(other.to_string()),
                    };
                }
            }
        }
        InitStrategy::Zeros // default
    }

    /// Parse param attributes to find the frozen flag.
    fn parse_frozen(&self, attrs: &[ParamAttr]) -> bool {
        for attr in attrs {
            if attr.key == "frozen" {
                if let Expr::Bool(b, _) = &attr.value {
                    return *b;
                }
            }
        }
        false
    }

    /// Lower the body of a node { op: ...; input: ...; } into (OpKind, inputs).
    fn lower_node_body(
        &self,
        stmts: &[NodeStmt],
        graph: &mut IrGraph,
        scope: &mut HashMap<String, NodeId>,
    ) -> Result<(OpKind, Vec<NodeId>)> {
        for stmt in stmts {
            if let NodeStmt::Op(expr, _) = stmt {
                return self.lower_expr_to_op(expr, graph, scope);
            }
        }
        // No op found — identity
        Ok((OpKind::Identity, vec![]))
    }

    /// Lower an expression into an (OpKind, inputs) pair.
    /// May create intermediate nodes in the graph for nested expressions.
    fn lower_expr_to_op(
        &self,
        expr: &Expr,
        graph: &mut IrGraph,
        scope: &mut HashMap<String, NodeId>,
    ) -> Result<(OpKind, Vec<NodeId>)> {
        match expr {
            // Function call: matmul(x, W), layer_norm(h, w, b, eps: 1e-5), etc.
            Expr::Call { func, args, span } => self.lower_call(func, args, graph, scope, *span),
            // Binary expression: tok_emb + pos_emb
            Expr::Binary {
                left, op, right, ..
            } => {
                let left_id = self.lower_expr_to_node(left, graph, scope)?;
                let right_id = self.lower_expr_to_node(right, graph, scope)?;

                let op_kind = match op {
                    BinOp::Add => OpKind::Add,
                    BinOp::Sub => OpKind::Sub,
                    BinOp::Mul => OpKind::Mul,
                    BinOp::Div => OpKind::Div,
                    BinOp::Mod => OpKind::Mod,
                    BinOp::Pow => OpKind::Pow,
                    BinOp::Eq => OpKind::Equal,
                    BinOp::Ne => OpKind::NotEqual,
                    BinOp::Lt => OpKind::Less,
                    BinOp::Gt => OpKind::Greater,
                    BinOp::Le => OpKind::LessEqual,
                    BinOp::Ge => OpKind::GreaterEqual,
                    BinOp::And => OpKind::And,
                    BinOp::Or => OpKind::Or,
                    _ => OpKind::Custom {
                        name: format!("{op:?}"),
                        attrs: HashMap::new(),
                    },
                };

                Ok((op_kind, vec![left_id, right_id]))
            }
            // Unary expression: -x, !cond
            Expr::Unary { op, operand, .. } => {
                let operand_id = self.lower_expr_to_node(operand, graph, scope)?;
                let op_kind = match op {
                    UnaryOp::Neg => OpKind::Neg,
                    UnaryOp::Not => OpKind::Not,
                    UnaryOp::BitNot => OpKind::Custom {
                        name: "bitnot".into(),
                        attrs: HashMap::new(),
                    },
                };
                Ok((op_kind, vec![operand_id]))
            }
            // Identifier reference: just references another node
            Expr::Ident(name, _) => {
                if let Some(&id) = scope.get(name) {
                    Ok((OpKind::Identity, vec![id]))
                } else {
                    Ok((OpKind::Identity, vec![]))
                }
            }
            // Repeat expression: repeat(4) { transformer_block(h, ...) }
            Expr::RepeatExpr { count, body, .. } => {
                let n = match count.as_ref() {
                    Expr::Int(n, _) => *n,
                    _ => 1,
                };
                let (inner_op, inner_inputs) = self.lower_expr_to_op(body, graph, scope)?;
                // Wrap the inner op in a Repeat, preserving the body operation
                Ok((
                    OpKind::Repeat {
                        count: n,
                        body_op: Box::new(inner_op),
                    },
                    inner_inputs,
                ))
            }
            // Constants
            Expr::Int(n, _) => Ok((OpKind::Constant(ConstantValue::Int(*n)), vec![])),
            Expr::Float(f, _) => Ok((OpKind::Constant(ConstantValue::Float(*f)), vec![])),
            Expr::Str(s, _) => Ok((OpKind::Constant(ConstantValue::Str(s.clone())), vec![])),
            Expr::Bool(b, _) => Ok((OpKind::Constant(ConstantValue::Bool(*b)), vec![])),
            Expr::Null(_) => Ok((OpKind::Constant(ConstantValue::Null), vec![])),
            // Fallback
            _ => Ok((
                OpKind::Custom {
                    name: format!("{expr:?}"),
                    attrs: HashMap::new(),
                },
                vec![],
            )),
        }
    }

    /// Lower an expression to a single NodeId, creating intermediate nodes as needed.
    fn lower_expr_to_node(
        &self,
        expr: &Expr,
        graph: &mut IrGraph,
        scope: &mut HashMap<String, NodeId>,
    ) -> Result<NodeId> {
        // Fast path: simple identifier — look up in scope
        if let Expr::Ident(name, _) = expr {
            if let Some(&id) = scope.get(name) {
                return Ok(id);
            }
        }

        // Complex expression: lower to (op, inputs) and create an anonymous node
        let (op, inputs) = self.lower_expr_to_op(expr, graph, scope)?;
        let name = format!("__anon_{}", graph.len());
        let id = graph.add_node(name, op, inputs, IrType::Unknown);
        Ok(id)
    }

    /// Lower a function call to (OpKind, inputs).
    /// Positional args are lowered to nodes (creating intermediates for nested calls).
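    ///
    /// A sketch of how one call is split (illustrative):
    ///
    /// ```text
    /// layer_norm(h, w, b, eps: 1e-5)
    ///   positional h, w, b  =>  three input NodeIds
    ///   named eps           =>  OpKind::LayerNorm { eps: 1e-5 }
    ///
    /// my_op(x, scale: 2)    =>  OpKind::Custom { name: "my_op", attrs: { scale: 2 } }
    /// ```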
    fn lower_call(
        &self,
        func: &str,
        args: &[Arg],
        graph: &mut IrGraph,
        scope: &mut HashMap<String, NodeId>,
        _span: Span,
    ) -> Result<(OpKind, Vec<NodeId>)> {
        // Resolve positional args to node IDs, creating intermediate nodes as needed
        let input_ids: Vec<NodeId> = args
            .iter()
            .filter(|a| a.name.is_none())
            .map(|a| self.lower_expr_to_node(&a.value, graph, scope))
            .collect::<Result<Vec<_>>>()?;

        // Collect named args
        let named: HashMap<&str, &Expr> = args
            .iter()
            .filter_map(|a| a.name.as_deref().map(|name| (name, &a.value)))
            .collect();

        let op = match func {
            // Core ops
            "matmul" | "mm" => OpKind::MatMul,
            "add" => OpKind::Add,
            "sub" => OpKind::Sub,
            "mul" => OpKind::Mul,
            "div" => OpKind::Div,

            // Activations
            "relu" => OpKind::Relu,
            "gelu" => OpKind::Gelu,
            "silu" | "swish" => OpKind::Silu,
            "sigmoid" => OpKind::Sigmoid,
            "tanh" => OpKind::Tanh,
            "exp" => OpKind::Exp,
            "log" => OpKind::Log,
            "sqrt" => OpKind::Sqrt,

            // Softmax
            "softmax" => {
                let dim = self.get_named_int(&named, "dim").unwrap_or(-1);
                OpKind::Softmax { dim }
            }

            // Embedding
            "embedding" | "Embedding" => OpKind::Embedding,

            // Linear
            "linear" | "Linear" => {
                let bias = named
                    .get("bias")
                    .is_none_or(|e| matches!(e, Expr::Bool(true, _)));
                OpKind::Linear { bias }
            }

            // Normalization
            "layer_norm" | "LayerNorm" => {
                let eps = self.get_named_float(&named, "eps").unwrap_or(1e-5);
                OpKind::LayerNorm { eps }
            }
            "batch_norm" | "BatchNorm" => {
                let eps = self.get_named_float(&named, "eps").unwrap_or(1e-5);
                OpKind::BatchNorm { eps }
            }

            // Attention
            "multi_head_attention" | "MultiHeadAttention" => {
                let n_heads = self.get_named_int(&named, "n_heads").unwrap_or(1);
                OpKind::MultiHeadAttention { n_heads }
            }
            "transformer_block" | "TransformerBlock" => {
                let n_heads = self.get_named_int(&named, "n_heads").unwrap_or(1);
                OpKind::TransformerBlock { n_heads }
            }

            // Reduction
            "sum" => {
                let dim = self.get_named_int(&named, "dim").unwrap_or(-1);
                let keepdim = self.get_named_bool(&named, "keepdim").unwrap_or(false);
                OpKind::Sum {
                    dims: vec![dim],
                    keepdim,
                }
            }
            "mean" => {
                let dim = self.get_named_int(&named, "dim").unwrap_or(-1);
                let keepdim = self.get_named_bool(&named, "keepdim").unwrap_or(false);
                OpKind::Mean {
                    dims: vec![dim],
                    keepdim,
                }
            }

            // Shape ops
            "transpose" => OpKind::Transpose,
            "concat" | "cat" => {
                let dim = self.get_named_int(&named, "dim").unwrap_or(0);
                OpKind::Concat { dim }
            }

            // Dropout
            "dropout" | "Dropout" => {
                let p = self.get_named_float(&named, "p").unwrap_or(0.0);
                OpKind::Dropout { p }
            }

            // Loss
            "cross_entropy" | "cross_entropy_loss" => OpKind::CrossEntropy,
            "mse_loss" => OpKind::MseLoss,

            // Range
            "range" => OpKind::Range,

            // Fallback: custom/unknown op
            other => OpKind::Custom {
                name: other.to_string(),
                attrs: named
                    .iter()
                    .map(|(k, v)| {
                        (
                            k.to_string(),
                            self.expr_to_attr(v)
                                .unwrap_or(AttrValue::Str(format!("{v:?}"))),
                        )
                    })
                    .collect(),
            },
        };

        Ok((op, input_ids))
    }

    /// Try to resolve a simple identifier to a node ID (no node creation).
    fn try_resolve_ident(&self, expr: &Expr, scope: &HashMap<String, NodeId>) -> Option<NodeId> {
        match expr {
            Expr::Ident(name, _) => scope.get(name).copied(),
            _ => None,
        }
    }

    /// Extract a named integer argument from a call.
    fn get_named_int(&self, named: &HashMap<&str, &Expr>, key: &str) -> Option<i64> {
        named.get(key).and_then(|e| match e {
            Expr::Int(n, _) => Some(*n),
            _ => None,
        })
    }

    /// Extract a named float argument from a call.
    fn get_named_float(&self, named: &HashMap<&str, &Expr>, key: &str) -> Option<f64> {
        named.get(key).and_then(|e| match e {
            Expr::Float(f, _) => Some(*f),
            Expr::Int(n, _) => Some(*n as f64),
            _ => None,
        })
    }

    /// Extract a named boolean argument from a call.
    fn get_named_bool(&self, named: &HashMap<&str, &Expr>, key: &str) -> Option<bool> {
        named.get(key).and_then(|e| match e {
            Expr::Bool(b, _) => Some(*b),
            _ => None,
        })
    }

    /// Convert an expression to an attribute value.
    fn expr_to_attr(&self, expr: &Expr) -> Option<AttrValue> {
        match expr {
            Expr::Int(n, _) => Some(AttrValue::Int(*n)),
            Expr::Float(f, _) => Some(AttrValue::Float(*f)),
            Expr::Str(s, _) => Some(AttrValue::Str(s.clone())),
            Expr::Bool(b, _) => Some(AttrValue::Bool(*b)),
            Expr::List(items, _) => {
                let vals: Vec<AttrValue> =
                    items.iter().filter_map(|e| self.expr_to_attr(e)).collect();
                Some(AttrValue::List(vals))
            }
            _ => None,
        }
    }

    // @training

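    /// Lower an `@training` block into a `TrainingConfig`, falling back to
    /// defaults (SGD, lr 0.01, fp32, single epoch/batch) for anything left
    /// unspecified. A sketch of the mapping (surface syntax illustrative):
    ///
    /// ```text
    /// @training {
    ///     model: encoder       =>  model_graph = "encoder"
    ///     loss: cross_entropy  =>  loss = "cross_entropy"
    ///     optimizer { type: "adam", lr: 3e-4, beta1: 0.9 }
    ///                          =>  OptimizerConfig { kind: "adam", lr: 3e-4,
    ///                                extra: { beta1: 0.9 } }
    ///     batch_size: 32       =>  batch_size = 32
    /// }
    /// ```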
    fn lower_training(&self, training: &TrainingBlock) -> Result<TrainingConfig> {
        let mut model_graph = String::new();
        let mut loss = String::new();
        let mut optimizer = OptimizerConfig {
            kind: "SGD".into(),
            lr: 0.01,
            extra: HashMap::new(),
        };
        let mut lr_schedule = None;
        let mut grad_clip = None;
        let mut precision = "fp32".to_string();
        let mut epochs: i64 = 1;
        let mut batch_size: i64 = 1;
        let mut accumulation_steps: i64 = 1;

        for field in &training.fields {
            match field {
                TrainingField::Model(name, _) => model_graph = name.clone(),
                TrainingField::Loss(name, _) => loss = name.clone(),
                TrainingField::Optimizer(fields, _) => {
                    optimizer = self.lower_optimizer_config(fields)?;
                }
                TrainingField::LrSchedule(fields, _) => {
                    lr_schedule = Some(self.lower_lr_schedule_config(fields)?);
                }
                TrainingField::GradClip(fields, _) => {
                    grad_clip = Some(self.lower_grad_clip_config(fields)?);
                }
                TrainingField::Generic(f) => match f.key.as_str() {
                    "precision" => {
                        if let Ok(ConfigValue::Str(s)) = self.eval_config_expr(&f.value) {
                            precision = s;
                        }
                    }
                    "epochs" => {
                        if let Ok(ConfigValue::Int(n)) = self.eval_config_expr(&f.value) {
                            epochs = n;
                        }
                    }
                    "batch_size" => {
                        if let Ok(ConfigValue::Int(n)) = self.eval_config_expr(&f.value) {
                            batch_size = n;
                        }
                    }
                    "accumulation_steps" => {
                        if let Ok(ConfigValue::Int(n)) = self.eval_config_expr(&f.value) {
                            accumulation_steps = n;
                        }
                    }
                    _ => {}
                },
            }
        }

        Ok(TrainingConfig {
            model_graph,
            loss,
            optimizer,
            lr_schedule,
            grad_clip,
            precision,
            epochs,
            batch_size,
            accumulation_steps,
        })
    }

    fn lower_optimizer_config(&self, fields: &[ExprField]) -> Result<OptimizerConfig> {
        let mut kind = "SGD".to_string();
        let mut lr = 0.01;
        let mut extra = HashMap::new();

        for f in fields {
            match f.key.as_str() {
                "type" => {
                    if let Ok(ConfigValue::Str(s)) = self.eval_config_expr(&f.value) {
                        kind = s;
                    }
                }
                "lr" | "learning_rate" => {
                    if let Ok(val) = self.eval_config_expr(&f.value) {
                        match &val {
                            ConfigValue::Float(v) => lr = *v,
                            ConfigValue::Int(n) => lr = *n as f64,
                            _ => {}
                        }
                    }
                }
                other => {
                    if let Ok(val) = self.eval_config_expr(&f.value) {
                        extra.insert(other.to_string(), val);
                    }
                }
            }
        }

        Ok(OptimizerConfig { kind, lr, extra })
    }

    fn lower_lr_schedule_config(&self, fields: &[ExprField]) -> Result<LrScheduleConfig> {
        let mut kind = "constant".to_string();
        let mut extra = HashMap::new();

        for f in fields {
            match f.key.as_str() {
                "type" => {
                    if let Ok(ConfigValue::Str(s)) = self.eval_config_expr(&f.value) {
                        kind = s;
                    }
                }
                other => {
                    if let Ok(val) = self.eval_config_expr(&f.value) {
                        extra.insert(other.to_string(), val);
                    }
                }
            }
        }

        Ok(LrScheduleConfig { kind, extra })
    }

    fn lower_grad_clip_config(&self, fields: &[ExprField]) -> Result<GradClipConfig> {
        let mut kind = "none".to_string();
        let mut extra = HashMap::new();

        for f in fields {
            match f.key.as_str() {
                "type" => {
                    if let Ok(ConfigValue::Str(s)) = self.eval_config_expr(&f.value) {
                        kind = s;
                    }
                }
                other => {
                    if let Ok(val) = self.eval_config_expr(&f.value) {
                        extra.insert(other.to_string(), val);
                    }
                }
            }
        }

        Ok(GradClipConfig { kind, extra })
    }

    // @inference

    fn lower_inference(&self, inference: &InferenceBlock) -> Result<InferenceConfig> {
        let mut model_graph = String::new();
        let mut quantization = None;
        let mut generation = None;

        for field in &inference.fields {
            match field {
                InferenceField::Model(name, _) => model_graph = name.clone(),
                InferenceField::Quantization(fields, _) => {
                    let mut map = HashMap::new();
                    for f in fields {
                        if let Ok(val) = self.eval_config_expr(&f.value) {
                            map.insert(f.key.clone(), val);
                        }
                    }
                    quantization = Some(map);
                }
                InferenceField::Generation(fields, _) => {
                    let mut map = HashMap::new();
                    for f in fields {
                        if let Ok(val) = self.eval_config_expr(&f.value) {
                            map.insert(f.key.clone(), val);
                        }
                    }
                    generation = Some(map);
                }
                _ => {}
            }
        }

        Ok(InferenceConfig {
            model_graph,
            quantization,
            generation,
        })
    }
}

// Helpers

/// Convert AST DTypeKind to IR DType.
fn lower_dtype(dt: &DTypeKind) -> DType {
    match dt {
        DTypeKind::F16 => DType::F16,
        DTypeKind::F32 => DType::F32,
        DTypeKind::F64 => DType::F64,
        DTypeKind::Bf16 => DType::Bf16,
        DTypeKind::I8 => DType::I8,
        DTypeKind::I16 => DType::I16,
        DTypeKind::I32 => DType::I32,
        DTypeKind::I64 => DType::I64,
        DTypeKind::U8 => DType::U8,
        DTypeKind::U16 => DType::U16,
        DTypeKind::U32 => DType::U32,
        DTypeKind::U64 => DType::U64,
        DTypeKind::Bool => DType::Bool,
        DTypeKind::Complex64 => DType::Complex64,
        DTypeKind::Complex128 => DType::Complex128,
    }
}
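
// A minimal sketch of a test for the pure helper above; it only assumes items
// already visible in this file (DTypeKind from crate::ast, DType from
// crate::graph) and makes no claims about the rest of the crate's API.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lower_dtype_maps_variants_one_to_one() {
        // Spot-check a few variants; matches! avoids assuming DType: PartialEq.
        assert!(matches!(lower_dtype(&DTypeKind::F32), DType::F32));
        assert!(matches!(lower_dtype(&DTypeKind::Bf16), DType::Bf16));
        assert!(matches!(lower_dtype(&DTypeKind::Bool), DType::Bool));
    }
}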