use std::collections::HashMap;

use crate::ast::*;
use crate::error::{Error, ErrorKind, Result};
use crate::graph::*;
use crate::token::Span;

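/// Lowers a parsed `Program` into the IR representation.
///
/// Walks the program twice: metadata, config, and type blocks first so
/// later blocks can resolve names against them, then graphs and the
/// training/inference blocks.
///
/// A minimal usage sketch; `parse` is a hypothetical front-end function,
/// not defined in this module:
///
/// ```ignore
/// let program = parse(source)?;
/// let ir = lower(&program)?;
/// println!("lowered {} graph(s)", ir.graphs.len());
/// ```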
pub fn lower(program: &Program) -> Result<IrProgram> {
    let mut ctx = LowerCtx::new();
    ctx.lower_program(program)?;
    Ok(ctx.ir)
}

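/// Accumulates the IR as the AST is walked.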
struct LowerCtx {
    ir: IrProgram,
}

impl LowerCtx {
    fn new() -> Self {
        Self {
            ir: IrProgram::new(),
        }
    }

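    /// Two-pass lowering: declaration-like blocks (metadata, config,
    /// types) are lowered before graphs, so dimension names and type
    /// aliases resolve regardless of block order in the source file.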
    fn lower_program(&mut self, program: &Program) -> Result<()> {
        // First pass: declaration-like blocks.
        for item in &program.items {
            match item {
                TopLevel::Metadata(m) => self.lower_metadata(m)?,
                TopLevel::Config(c) => self.lower_config(c)?,
                TopLevel::Types(t) => self.lower_types(t)?,
                _ => {}
            }
        }

        // Second pass: graphs and execution blocks.
        for item in &program.items {
            match item {
                TopLevel::Graph(g) => {
                    let ir_graph = self.lower_graph(g)?;
                    self.ir.graphs.push(ir_graph);
                }
                TopLevel::Training(t) => {
                    self.ir.training = Some(self.lower_training(t)?);
                }
                TopLevel::Inference(i) => {
                    self.ir.inference = Some(self.lower_inference(i)?);
                }
                // Already handled in the first pass.
                TopLevel::Metadata(_) | TopLevel::Config(_) | TopLevel::Types(_) => {}
                _ => {}
            }
        }

        Ok(())
    }

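    /// Flattens metadata fields into string key/value pairs; non-literal
    /// values fall back to their `Debug` rendering.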
    fn lower_metadata(&mut self, meta: &MetadataBlock) -> Result<()> {
        for field in &meta.fields {
            let value = match &field.value {
                Literal::Str(s, _) => s.clone(),
                Literal::Int(n, _) => n.to_string(),
                Literal::Float(f, _) => f.to_string(),
                Literal::Bool(b, _) => b.to_string(),
                _ => format!("{:?}", field.value),
            };
            self.ir.metadata.insert(field.key.clone(), value);
        }
        Ok(())
    }

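    /// Evaluates each config field eagerly and stores the folded value, so
    /// later fields can refer to earlier ones by name.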
    fn lower_config(&mut self, config: &ConfigBlock) -> Result<()> {
        for field in &config.fields {
            let value = self.eval_config_expr(&field.value)?;
            self.ir.config.insert(field.key.clone(), value);
        }
        Ok(())
    }

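    /// Constant-folds a config expression into a `ConfigValue`.
    ///
    /// Identifiers resolve against previously lowered config entries;
    /// anything that cannot be folded degrades to a string rather than
    /// erroring. A sketch, assuming `hidden = 256` was lowered earlier:
    ///
    /// ```ignore
    /// // hidden * 4   =>  ConfigValue::Int(1024)
    /// // hidden_typo  =>  ConfigValue::Str("hidden_typo")
    /// ```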
    fn eval_config_expr(&self, expr: &Expr) -> Result<ConfigValue> {
        match expr {
            Expr::Int(n, _) => Ok(ConfigValue::Int(*n)),
            Expr::Float(f, _) => Ok(ConfigValue::Float(*f)),
            Expr::Str(s, _) => Ok(ConfigValue::Str(s.clone())),
            Expr::Bool(b, _) => Ok(ConfigValue::Bool(*b)),
            Expr::List(items, _) => {
                let vals: Result<Vec<_>> = items.iter().map(|e| self.eval_config_expr(e)).collect();
                Ok(ConfigValue::List(vals?))
            }
            Expr::Ident(name, _) => {
                if let Some(val) = self.ir.config.get(name) {
                    Ok(val.clone())
                } else {
                    // Unresolved identifiers keep their literal spelling.
                    Ok(ConfigValue::Str(name.clone()))
                }
            }
            Expr::Binary {
                left,
                op,
                right,
                span,
            } => {
                let l = self.eval_config_expr(left)?;
                let r = self.eval_config_expr(right)?;
                self.eval_binary_config(&l, *op, &r, *span)
            }
            Expr::Unary {
                op: UnaryOp::Neg,
                operand,
                ..
            } => {
                let val = self.eval_config_expr(operand)?;
                match val {
                    ConfigValue::Int(n) => Ok(ConfigValue::Int(-n)),
                    ConfigValue::Float(f) => Ok(ConfigValue::Float(-f)),
                    _ => Ok(val),
                }
            }
            // Anything unfoldable degrades to its Debug rendering.
            _ => Ok(ConfigValue::Str(format!("{expr:?}"))),
        }
    }

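    /// Folds a binary operation over two `ConfigValue`s. Mixed int/float
    /// operands promote to float; integer division by zero is the only
    /// hard error.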
    fn eval_binary_config(
        &self,
        left: &ConfigValue,
        op: BinOp,
        right: &ConfigValue,
        span: Span,
    ) -> Result<ConfigValue> {
        match (left, right) {
            (ConfigValue::Int(a), ConfigValue::Int(b)) => {
                let result = match op {
                    BinOp::Add => a + b,
                    BinOp::Sub => a - b,
                    BinOp::Mul => a * b,
                    BinOp::Div => {
                        if *b == 0 {
                            return Err(Error::new(
                                ErrorKind::Message("division by zero".into()),
                                span,
                            ));
                        }
                        a / b
                    }
                    BinOp::Mod => a % b,
                    // NB: the exponent is truncated to u32, so negative
                    // exponents wrap around.
                    BinOp::Pow => a.pow(*b as u32),
                    _ => return Ok(ConfigValue::Str(format!("{a} {op:?} {b}"))),
                };
                Ok(ConfigValue::Int(result))
            }
            (ConfigValue::Float(a), ConfigValue::Float(b)) => {
                let result = match op {
                    BinOp::Add => a + b,
                    BinOp::Sub => a - b,
                    BinOp::Mul => a * b,
                    BinOp::Div => a / b,
                    BinOp::Mod => a % b,
                    BinOp::Pow => a.powf(*b),
                    _ => return Ok(ConfigValue::Str(format!("{a} {op:?} {b}"))),
                };
                Ok(ConfigValue::Float(result))
            }
            // Mixed operands: promote the integer side to float and retry.
            (ConfigValue::Int(a), ConfigValue::Float(b)) => self.eval_binary_config(
                &ConfigValue::Float(*a as f64),
                op,
                &ConfigValue::Float(*b),
                span,
            ),
            (ConfigValue::Float(a), ConfigValue::Int(b)) => self.eval_binary_config(
                &ConfigValue::Float(*a),
                op,
                &ConfigValue::Float(*b as f64),
                span,
            ),
            _ => Ok(ConfigValue::Str(format!("{left:?} {op:?} {right:?}"))),
        }
    }

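    /// Registers each type alias, resolving it eagerly so later aliases
    /// can build on earlier ones.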
    fn lower_types(&mut self, types: &TypesBlock) -> Result<()> {
        for def in &types.defs {
            let ty = self.lower_type_expr(&def.ty)?;
            self.ir.type_aliases.insert(def.name.clone(), ty);
        }
        Ok(())
    }

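    /// Lowers a surface type expression to an `IrType`. Unknown names and
    /// unsupported forms lower to `IrType::Unknown` rather than failing.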
    fn lower_type_expr(&self, ty: &TypeExpr) -> Result<IrType> {
        match ty {
            TypeExpr::Tensor { dims, dtype, .. } => {
                let shape: Vec<Dim> = dims.iter().map(|d| self.lower_dim(d)).collect();
                let dt = lower_dtype(dtype);
                Ok(IrType::Tensor { shape, dtype: dt })
            }
            TypeExpr::Scalar(dt, _) => Ok(IrType::Scalar(lower_dtype(dt))),
            TypeExpr::Named(name, _) => {
                if let Some(resolved) = self.ir.type_aliases.get(name) {
                    Ok(resolved.clone())
                } else {
                    Ok(IrType::Unknown)
                }
            }
            TypeExpr::Dynamic(_) => Ok(IrType::Unknown),
            // A bare integer dimension defaults to a 1-D f32 tensor.
            TypeExpr::IntDim(n, _) => Ok(IrType::Tensor {
                shape: vec![Dim::Fixed(*n)],
                dtype: DType::F32,
            }),
            _ => Ok(IrType::Unknown),
        }
    }

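    /// Resolves a dimension: named dims are looked up in the config (and
    /// become fixed if an integer is found), computed dims are constant
    /// folded, and everything unresolvable stays dynamic.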
    fn lower_dim(&self, dim: &Dimension) -> Dim {
        match dim {
            Dimension::Named(name, _) => {
                if let Some(ConfigValue::Int(n)) = self.ir.config.get(name) {
                    Dim::Fixed(*n)
                } else {
                    Dim::Symbolic(name.clone())
                }
            }
            Dimension::Concrete(n, _) => Dim::Fixed(*n),
            Dimension::Dynamic(_) => Dim::Dynamic,
            Dimension::Inferred(_) => Dim::Dynamic,
            Dimension::Computed(expr, _) => match self.eval_config_expr(expr) {
                Ok(ConfigValue::Int(n)) => Dim::Fixed(n),
                Ok(ConfigValue::Float(f)) => Dim::Fixed(f as i64),
                _ => Dim::Dynamic,
            },
        }
    }

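    /// Lowers one graph block. A local `scope` maps source-level names to
    /// node ids so that later statements can reference earlier inputs,
    /// params, and nodes by name.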
    fn lower_graph(&self, graph: &GraphBlock) -> Result<IrGraph> {
        let mut ir_graph = IrGraph::new(&graph.name);

        let mut scope: HashMap<String, NodeId> = HashMap::new();

        for stmt in &graph.body {
            match stmt {
                GraphStmt::Input(input) => {
                    let ty = self.lower_type_expr(&input.ty)?;
                    let id = ir_graph.add_node(&input.name, OpKind::Identity, vec![], ty);
                    ir_graph.inputs.push(id);
                    scope.insert(input.name.clone(), id);
                }
                GraphStmt::Param(param) => {
                    let ty = self.lower_type_expr(&param.ty)?;
                    let id = ir_graph.add_node(&param.name, OpKind::Identity, vec![], ty.clone());
                    scope.insert(param.name.clone(), id);

                    let init = self.parse_init_strategy(&param.attrs);
                    let frozen = self.parse_frozen(&param.attrs);

                    ir_graph.params.push(IrParam {
                        node_id: id,
                        name: param.name.clone(),
                        ty,
                        init,
                        frozen,
                    });
                }
                GraphStmt::Node(node) => {
                    let (op, inputs) =
                        self.lower_node_body(&node.stmts, &mut ir_graph, &mut scope)?;
                    let output_type = node
                        .ty
                        .as_ref()
                        .map(|t| self.lower_type_expr(t))
                        .transpose()?
                        .unwrap_or(IrType::Unknown);

                    let id = ir_graph.add_node(&node.name, op, inputs, output_type);

                    // Collect scheduling/optimization hints declared on the node.
                    for stmt in &node.stmts {
                        if let NodeStmt::Hint(hint, _) = stmt {
                            let ir_hint = match hint {
                                HintKind::RecomputeInBackward => IrHint::RecomputeInBackward,
                                HintKind::MustPreserve => IrHint::MustPreserve,
                                HintKind::InPlace => IrHint::InPlace,
                                HintKind::NoGrad => IrHint::NoGrad,
                                HintKind::Custom(s) => IrHint::Custom(s.clone()),
                            };
                            ir_graph.node_mut(id).hints.push(ir_hint);
                        }
                    }

                    // Collect literal attributes declared on the node.
                    for stmt in &node.stmts {
                        if let NodeStmt::Attr(key, val, _) = stmt {
                            if let Some(attr) = self.expr_to_attr(val) {
                                ir_graph.node_mut(id).attrs.insert(key.clone(), attr);
                            }
                        }
                    }

                    scope.insert(node.name.clone(), id);
                }
                GraphStmt::Output(output) => {
                    let out_name = output.name.clone().unwrap_or_else(|| {
                        if let Expr::Ident(ref ident, _) = output.expr {
                            ident.clone()
                        } else {
                            format!("__output_{}", ir_graph.outputs.len())
                        }
                    });

                    if let Some(id) = self.try_resolve_ident(&output.expr, &scope) {
                        ir_graph.outputs.push(IrOutput {
                            name: out_name,
                            node_id: id,
                        });
                    } else {
                        let (op, inputs) =
                            self.lower_expr_to_op(&output.expr, &mut ir_graph, &mut scope)?;
                        let id = ir_graph.add_node(out_name.clone(), op, inputs, IrType::Unknown);
                        ir_graph.outputs.push(IrOutput {
                            name: out_name,
                            node_id: id,
                        });
                    }
                }
                GraphStmt::Assert(assert_stmt) => {
                    ir_graph.asserts.push(IrAssert {
                        message: assert_stmt.message.clone(),
                        expr_text: format!("{:?}", assert_stmt.condition),
                    });
                }
                GraphStmt::Check(_) => {}
            }
        }

        Ok(ir_graph)
    }

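    /// Parses an `init` attribute into an `InitStrategy`. Well-known names
    /// map directly; `normal(mean, std)` and `uniform(low, high)` are
    /// parsed structurally; anything else becomes `Custom`. The default is
    /// `Zeros` when no `init` attribute is present.
    ///
    /// A sketch of the mapping:
    ///
    /// ```ignore
    /// // "xavier_uniform"    => InitStrategy::XavierUniform
    /// // "normal(0.0, 0.02)" => InitStrategy::Normal { mean: 0.0, std: 0.02 }
    /// // "my_custom_init"    => InitStrategy::Custom("my_custom_init".into())
    /// ```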
    fn parse_init_strategy(&self, attrs: &[ParamAttr]) -> InitStrategy {
        for attr in attrs {
            if attr.key == "init" {
                if let Expr::Str(s, _) = &attr.value {
                    return match s.as_str() {
                        "zeros" => InitStrategy::Zeros,
                        "ones" => InitStrategy::Ones,
                        "xavier_uniform" => InitStrategy::XavierUniform,
                        "xavier_normal" => InitStrategy::XavierNormal,
                        "kaiming_uniform" => InitStrategy::KaimingUniform,
                        "kaiming_normal" => InitStrategy::KaimingNormal,
                        s if s.starts_with("normal(") => {
                            // Strip the `normal(` prefix and trailing `)`.
                            let inner = &s[7..s.len().saturating_sub(1)];
                            let parts: Vec<&str> = inner.split(',').collect();
                            if parts.len() == 2 {
                                let mean = parts[0].trim().parse().unwrap_or(0.0);
                                let std = parts[1].trim().parse().unwrap_or(0.02);
                                InitStrategy::Normal { mean, std }
                            } else {
                                InitStrategy::Custom(s.to_string())
                            }
                        }
                        s if s.starts_with("uniform(") => {
                            // Strip the `uniform(` prefix and trailing `)`.
                            let inner = &s[8..s.len().saturating_sub(1)];
                            let parts: Vec<&str> = inner.split(',').collect();
                            if parts.len() == 2 {
                                let low = parts[0].trim().parse().unwrap_or(-1.0);
                                let high = parts[1].trim().parse().unwrap_or(1.0);
                                InitStrategy::Uniform { low, high }
                            } else {
                                InitStrategy::Custom(s.to_string())
                            }
                        }
                        other => InitStrategy::Custom(other.to_string()),
                    };
                }
            }
        }
        InitStrategy::Zeros
    }

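    /// Returns `true` only when a literal `frozen: true` attribute is present.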
    fn parse_frozen(&self, attrs: &[ParamAttr]) -> bool {
        for attr in attrs {
            if attr.key == "frozen" {
                if let Expr::Bool(b, _) = &attr.value {
                    return *b;
                }
            }
        }
        false
    }

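    /// Lowers a node body to its (op, inputs) pair. The first `op`
    /// statement wins; a body with no op statement lowers to `Identity`.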
    fn lower_node_body(
        &self,
        stmts: &[NodeStmt],
        graph: &mut IrGraph,
        scope: &mut HashMap<String, NodeId>,
    ) -> Result<(OpKind, Vec<NodeId>)> {
        for stmt in stmts {
            if let NodeStmt::Op(expr, _) = stmt {
                return self.lower_expr_to_op(expr, graph, scope);
            }
        }
        Ok((OpKind::Identity, vec![]))
    }

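    /// Lowers an expression to an op plus its input node ids, materializing
    /// anonymous intermediate nodes for nested sub-expressions as needed.
    /// Unhandled expression forms degrade to a `Custom` op named after
    /// their `Debug` rendering.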
    fn lower_expr_to_op(
        &self,
        expr: &Expr,
        graph: &mut IrGraph,
        scope: &mut HashMap<String, NodeId>,
    ) -> Result<(OpKind, Vec<NodeId>)> {
        match expr {
            Expr::Call { func, args, span } => self.lower_call(func, args, graph, scope, *span),
            Expr::Binary {
                left, op, right, ..
            } => {
                let left_id = self.lower_expr_to_node(left, graph, scope)?;
                let right_id = self.lower_expr_to_node(right, graph, scope)?;

                let op_kind = match op {
                    BinOp::Add => OpKind::Add,
                    BinOp::Sub => OpKind::Sub,
                    BinOp::Mul => OpKind::Mul,
                    BinOp::Div => OpKind::Div,
                    BinOp::Mod => OpKind::Mod,
                    BinOp::Pow => OpKind::Pow,
                    BinOp::Eq => OpKind::Equal,
                    BinOp::Ne => OpKind::NotEqual,
                    BinOp::Lt => OpKind::Less,
                    BinOp::Gt => OpKind::Greater,
                    BinOp::Le => OpKind::LessEqual,
                    BinOp::Ge => OpKind::GreaterEqual,
                    BinOp::And => OpKind::And,
                    BinOp::Or => OpKind::Or,
                    _ => OpKind::Custom {
                        name: format!("{op:?}"),
                        attrs: HashMap::new(),
                    },
                };

                Ok((op_kind, vec![left_id, right_id]))
            }
            Expr::Unary { op, operand, .. } => {
                let operand_id = self.lower_expr_to_node(operand, graph, scope)?;
                let op_kind = match op {
                    UnaryOp::Neg => OpKind::Neg,
                    UnaryOp::Not => OpKind::Not,
                    UnaryOp::BitNot => OpKind::Custom {
                        name: "bitnot".into(),
                        attrs: HashMap::new(),
                    },
                };
                Ok((op_kind, vec![operand_id]))
            }
            Expr::Ident(name, _) => {
                if let Some(&id) = scope.get(name) {
                    Ok((OpKind::Identity, vec![id]))
                } else {
                    Ok((OpKind::Identity, vec![]))
                }
            }
            Expr::RepeatExpr { count, body, .. } => {
                // Non-literal repeat counts silently default to 1.
                let n = match count.as_ref() {
                    Expr::Int(n, _) => *n,
                    _ => 1,
                };
                let (inner_op, inner_inputs) = self.lower_expr_to_op(body, graph, scope)?;
                Ok((
                    OpKind::Repeat {
                        count: n,
                        body_op: Box::new(inner_op),
                    },
                    inner_inputs,
                ))
            }
            Expr::Int(n, _) => Ok((OpKind::Constant(ConstantValue::Int(*n)), vec![])),
            Expr::Float(f, _) => Ok((OpKind::Constant(ConstantValue::Float(*f)), vec![])),
            Expr::Str(s, _) => Ok((OpKind::Constant(ConstantValue::Str(s.clone())), vec![])),
            Expr::Bool(b, _) => Ok((OpKind::Constant(ConstantValue::Bool(*b)), vec![])),
            Expr::Null(_) => Ok((OpKind::Constant(ConstantValue::Null), vec![])),
            _ => Ok((
                OpKind::Custom {
                    name: format!("{expr:?}"),
                    attrs: HashMap::new(),
                },
                vec![],
            )),
        }
    }

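    /// Like `lower_expr_to_op`, but guarantees a node id: identifiers
    /// already in scope are reused, anything else gets a fresh `__anon_N`
    /// node.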
    fn lower_expr_to_node(
        &self,
        expr: &Expr,
        graph: &mut IrGraph,
        scope: &mut HashMap<String, NodeId>,
    ) -> Result<NodeId> {
        if let Expr::Ident(name, _) = expr {
            if let Some(&id) = scope.get(name) {
                return Ok(id);
            }
        }

        let (op, inputs) = self.lower_expr_to_op(expr, graph, scope)?;
        let name = format!("__anon_{}", graph.len());
        let id = graph.add_node(name, op, inputs, IrType::Unknown);
        Ok(id)
    }

    fn lower_call(
        &self,
        func: &str,
        args: &[Arg],
        graph: &mut IrGraph,
        scope: &mut HashMap<String, NodeId>,
        _span: Span,
    ) -> Result<(OpKind, Vec<NodeId>)> {
        // Positional arguments become data inputs.
        let input_ids: Vec<NodeId> = args
            .iter()
            .filter(|a| a.name.is_none())
            .map(|a| self.lower_expr_to_node(&a.value, graph, scope))
            .collect::<Result<Vec<_>>>()?;

        // Named arguments parameterize the op itself.
        let named: HashMap<&str, &Expr> = args
            .iter()
            .filter_map(|a| a.name.as_deref().map(|name| (name, &a.value)))
            .collect();

        let op = match func {
            // Arithmetic.
            "matmul" | "mm" => OpKind::MatMul,
            "add" => OpKind::Add,
            "sub" => OpKind::Sub,
            "mul" => OpKind::Mul,
            "div" => OpKind::Div,

            // Activations and element-wise math.
            "relu" => OpKind::Relu,
            "gelu" => OpKind::Gelu,
            "silu" | "swish" => OpKind::Silu,
            "sigmoid" => OpKind::Sigmoid,
            "tanh" => OpKind::Tanh,
            "exp" => OpKind::Exp,
            "log" => OpKind::Log,
            "sqrt" => OpKind::Sqrt,

            "softmax" => {
                let dim = self.get_named_int(&named, "dim").unwrap_or(-1);
                OpKind::Softmax { dim }
            }

            "embedding" | "Embedding" => OpKind::Embedding,

            "linear" | "Linear" => {
                // Bias defaults to true unless `bias=false` is given.
                let bias = named
                    .get("bias")
                    .is_none_or(|e| matches!(e, Expr::Bool(true, _)));
                OpKind::Linear { bias }
            }

            // Normalization.
            "layer_norm" | "LayerNorm" => {
                let eps = self.get_named_float(&named, "eps").unwrap_or(1e-5);
                OpKind::LayerNorm { eps }
            }
            "batch_norm" | "BatchNorm" => {
                let eps = self.get_named_float(&named, "eps").unwrap_or(1e-5);
                OpKind::BatchNorm { eps }
            }

            // Attention.
            "multi_head_attention" | "MultiHeadAttention" => {
                let n_heads = self.get_named_int(&named, "n_heads").unwrap_or(1);
                OpKind::MultiHeadAttention { n_heads }
            }
            "transformer_block" | "TransformerBlock" => {
                let n_heads = self.get_named_int(&named, "n_heads").unwrap_or(1);
                OpKind::TransformerBlock { n_heads }
            }

            // Reductions.
            "sum" => {
                let dim = self.get_named_int(&named, "dim").unwrap_or(-1);
                let keepdim = self.get_named_bool(&named, "keepdim").unwrap_or(false);
                OpKind::Sum {
                    dims: vec![dim],
                    keepdim,
                }
            }
            "mean" => {
                let dim = self.get_named_int(&named, "dim").unwrap_or(-1);
                let keepdim = self.get_named_bool(&named, "keepdim").unwrap_or(false);
                OpKind::Mean {
                    dims: vec![dim],
                    keepdim,
                }
            }

            // Shape manipulation.
            "transpose" => OpKind::Transpose,
            "concat" | "cat" => {
                let dim = self.get_named_int(&named, "dim").unwrap_or(0);
                OpKind::Concat { dim }
            }

            "dropout" | "Dropout" => {
                let p = self.get_named_float(&named, "p").unwrap_or(0.0);
                OpKind::Dropout { p }
            }

            // Losses.
            "cross_entropy" | "cross_entropy_loss" => OpKind::CrossEntropy,
            "mse_loss" => OpKind::MseLoss,

            "range" => OpKind::Range,

            // Anything else is preserved as a custom op with its named
            // arguments lowered to attributes.
            other => OpKind::Custom {
                name: other.to_string(),
                attrs: named
                    .iter()
                    .map(|(k, v)| {
                        (
                            k.to_string(),
                            self.expr_to_attr(v)
                                .unwrap_or_else(|| AttrValue::Str(format!("{v:?}"))),
                        )
                    })
                    .collect(),
            },
        };

        Ok((op, input_ids))
    }

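    /// Returns the node id for a bare identifier already in scope, if any.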
    fn try_resolve_ident(&self, expr: &Expr, scope: &HashMap<String, NodeId>) -> Option<NodeId> {
        match expr {
            Expr::Ident(name, _) => scope.get(name).copied(),
            _ => None,
        }
    }

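    // Typed accessors for named call arguments. Note that `get_named_float`
    // also accepts integer literals, widening them to f64, so `eps=1` and
    // `eps=1.0` behave the same.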
    fn get_named_int(&self, named: &HashMap<&str, &Expr>, key: &str) -> Option<i64> {
        named.get(key).and_then(|e| match e {
            Expr::Int(n, _) => Some(*n),
            _ => None,
        })
    }

    fn get_named_float(&self, named: &HashMap<&str, &Expr>, key: &str) -> Option<f64> {
        named.get(key).and_then(|e| match e {
            Expr::Float(f, _) => Some(*f),
            Expr::Int(n, _) => Some(*n as f64),
            _ => None,
        })
    }

    fn get_named_bool(&self, named: &HashMap<&str, &Expr>, key: &str) -> Option<bool> {
        named.get(key).and_then(|e| match e {
            Expr::Bool(b, _) => Some(*b),
            _ => None,
        })
    }

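    /// Converts a literal expression (or list of literals) to an attribute
    /// value; returns `None` for anything non-literal.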
    fn expr_to_attr(&self, expr: &Expr) -> Option<AttrValue> {
        match expr {
            Expr::Int(n, _) => Some(AttrValue::Int(*n)),
            Expr::Float(f, _) => Some(AttrValue::Float(*f)),
            Expr::Str(s, _) => Some(AttrValue::Str(s.clone())),
            Expr::Bool(b, _) => Some(AttrValue::Bool(*b)),
            Expr::List(items, _) => {
                let vals: Vec<AttrValue> =
                    items.iter().filter_map(|e| self.expr_to_attr(e)).collect();
                Some(AttrValue::List(vals))
            }
            _ => None,
        }
    }

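    /// Lowers the training block, starting from conservative defaults
    /// (SGD at lr 0.01, fp32, one epoch, batch size 1, no accumulation)
    /// and overriding whatever the block specifies.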
    fn lower_training(&self, training: &TrainingBlock) -> Result<TrainingConfig> {
        // Defaults; each field below may override them.
        let mut model_graph = String::new();
        let mut loss = String::new();
        let mut optimizer = OptimizerConfig {
            kind: "SGD".into(),
            lr: 0.01,
            extra: HashMap::new(),
        };
        let mut lr_schedule = None;
        let mut grad_clip = None;
        let mut precision = "fp32".to_string();
        let mut epochs: i64 = 1;
        let mut batch_size: i64 = 1;
        let mut accumulation_steps: i64 = 1;

        for field in &training.fields {
            match field {
                TrainingField::Model(name, _) => model_graph = name.clone(),
                TrainingField::Loss(name, _) => loss = name.clone(),
                TrainingField::Optimizer(fields, _) => {
                    optimizer = self.lower_optimizer_config(fields)?;
                }
                TrainingField::LrSchedule(fields, _) => {
                    lr_schedule = Some(self.lower_lr_schedule_config(fields)?);
                }
                TrainingField::GradClip(fields, _) => {
                    grad_clip = Some(self.lower_grad_clip_config(fields)?);
                }
                TrainingField::Generic(f) => match f.key.as_str() {
                    "precision" => {
                        if let Ok(ConfigValue::Str(s)) = self.eval_config_expr(&f.value) {
                            precision = s;
                        }
                    }
                    "epochs" => {
                        if let Ok(ConfigValue::Int(n)) = self.eval_config_expr(&f.value) {
                            epochs = n;
                        }
                    }
                    "batch_size" => {
                        if let Ok(ConfigValue::Int(n)) = self.eval_config_expr(&f.value) {
                            batch_size = n;
                        }
                    }
                    "accumulation_steps" => {
                        if let Ok(ConfigValue::Int(n)) = self.eval_config_expr(&f.value) {
                            accumulation_steps = n;
                        }
                    }
                    _ => {}
                },
            }
        }

        Ok(TrainingConfig {
            model_graph,
            loss,
            optimizer,
            lr_schedule,
            grad_clip,
            precision,
            epochs,
            batch_size,
            accumulation_steps,
        })
    }

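    /// Lowers an optimizer block: `type` selects the optimizer, `lr` (or
    /// `learning_rate`) sets the learning rate, and every other field is
    /// kept verbatim in `extra`.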
    fn lower_optimizer_config(&self, fields: &[ExprField]) -> Result<OptimizerConfig> {
        let mut kind = "SGD".to_string();
        let mut lr = 0.01;
        let mut extra = HashMap::new();

        for f in fields {
            match f.key.as_str() {
                "type" => {
                    if let Ok(ConfigValue::Str(s)) = self.eval_config_expr(&f.value) {
                        kind = s;
                    }
                }
                "lr" | "learning_rate" => {
                    if let Ok(val) = self.eval_config_expr(&f.value) {
                        match &val {
                            ConfigValue::Float(v) => lr = *v,
                            ConfigValue::Int(n) => lr = *n as f64,
                            _ => {}
                        }
                    }
                }
                other => {
                    if let Ok(val) = self.eval_config_expr(&f.value) {
                        extra.insert(other.to_string(), val);
                    }
                }
            }
        }

        Ok(OptimizerConfig { kind, lr, extra })
    }

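    // The lr-schedule and grad-clip blocks lower the same way: `type`
    // picks the kind, all remaining fields land in `extra`.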
    fn lower_lr_schedule_config(&self, fields: &[ExprField]) -> Result<LrScheduleConfig> {
        let mut kind = "constant".to_string();
        let mut extra = HashMap::new();

        for f in fields {
            match f.key.as_str() {
                "type" => {
                    if let Ok(ConfigValue::Str(s)) = self.eval_config_expr(&f.value) {
                        kind = s;
                    }
                }
                other => {
                    if let Ok(val) = self.eval_config_expr(&f.value) {
                        extra.insert(other.to_string(), val);
                    }
                }
            }
        }

        Ok(LrScheduleConfig { kind, extra })
    }

    fn lower_grad_clip_config(&self, fields: &[ExprField]) -> Result<GradClipConfig> {
        let mut kind = "none".to_string();
        let mut extra = HashMap::new();

        for f in fields {
            match f.key.as_str() {
                "type" => {
                    if let Ok(ConfigValue::Str(s)) = self.eval_config_expr(&f.value) {
                        kind = s;
                    }
                }
                other => {
                    if let Ok(val) = self.eval_config_expr(&f.value) {
                        extra.insert(other.to_string(), val);
                    }
                }
            }
        }

        Ok(GradClipConfig { kind, extra })
    }

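    /// Lowers the inference block; quantization and generation settings
    /// are kept as free-form key/value maps.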
    fn lower_inference(&self, inference: &InferenceBlock) -> Result<InferenceConfig> {
        let mut model_graph = String::new();
        let mut quantization = None;
        let mut generation = None;

        for field in &inference.fields {
            match field {
                InferenceField::Model(name, _) => model_graph = name.clone(),
                InferenceField::Quantization(fields, _) => {
                    let mut map = HashMap::new();
                    for f in fields {
                        if let Ok(val) = self.eval_config_expr(&f.value) {
                            map.insert(f.key.clone(), val);
                        }
                    }
                    quantization = Some(map);
                }
                InferenceField::Generation(fields, _) => {
                    let mut map = HashMap::new();
                    for f in fields {
                        if let Ok(val) = self.eval_config_expr(&f.value) {
                            map.insert(f.key.clone(), val);
                        }
                    }
                    generation = Some(map);
                }
                _ => {}
            }
        }

        Ok(InferenceConfig {
            model_graph,
            quantization,
            generation,
        })
    }
}

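/// One-to-one mapping from surface dtype names to IR dtypes.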
fn lower_dtype(dt: &DTypeKind) -> DType {
    match dt {
        DTypeKind::F16 => DType::F16,
        DTypeKind::F32 => DType::F32,
        DTypeKind::F64 => DType::F64,
        DTypeKind::Bf16 => DType::Bf16,
        DTypeKind::I8 => DType::I8,
        DTypeKind::I16 => DType::I16,
        DTypeKind::I32 => DType::I32,
        DTypeKind::I64 => DType::I64,
        DTypeKind::U8 => DType::U8,
        DTypeKind::U16 => DType::U16,
        DTypeKind::U32 => DType::U32,
        DTypeKind::U64 => DType::U64,
        DTypeKind::Bool => DType::Bool,
        DTypeKind::Complex64 => DType::Complex64,
        DTypeKind::Complex128 => DType::Complex128,
    }
}