use std::collections::HashMap;
use std::fmt;

/// Unique identifier of a node within an [`IrGraph`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);

impl fmt::Display for NodeId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "n{}", self.0)
    }
}

/// A single tensor dimension.
#[derive(Debug, Clone, PartialEq)]
pub enum Dim {
    /// Statically known size.
    Fixed(i64),
    /// Named symbolic size (e.g. a batch or sequence dimension).
    Symbolic(String),
    /// Size unknown until runtime.
    Dynamic,
}

impl fmt::Display for Dim {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Dim::Fixed(n) => write!(f, "{n}"),
            Dim::Symbolic(s) => write!(f, "{s}"),
            Dim::Dynamic => write!(f, "?"),
        }
    }
}

/// Element data type of a tensor or scalar value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    F16,
    F32,
    F64,
    Bf16,
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    Bool,
    Complex64,
    Complex128,
}

impl fmt::Display for DType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DType::F16 => write!(f, "f16"),
            DType::F32 => write!(f, "f32"),
            DType::F64 => write!(f, "f64"),
            DType::Bf16 => write!(f, "bf16"),
            DType::I8 => write!(f, "i8"),
            DType::I16 => write!(f, "i16"),
            DType::I32 => write!(f, "i32"),
            DType::I64 => write!(f, "i64"),
            DType::U8 => write!(f, "u8"),
            DType::U16 => write!(f, "u16"),
            DType::U32 => write!(f, "u32"),
            DType::U64 => write!(f, "u64"),
            DType::Bool => write!(f, "bool"),
            DType::Complex64 => write!(f, "complex64"),
            DType::Complex128 => write!(f, "complex128"),
        }
    }
}

/// The type of the value produced by a node.
#[derive(Debug, Clone, PartialEq)]
pub enum IrType {
    /// A tensor with a (possibly symbolic) shape and an element dtype.
    Tensor { shape: Vec<Dim>, dtype: DType },
    /// A scalar of the given dtype.
    Scalar(DType),
    Int,
    Str,
    Boolean,
    /// Not yet inferred.
    Unknown,
}

impl fmt::Display for IrType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            IrType::Tensor { shape, dtype } => {
                let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
                write!(f, "Tensor<[{}], {}>", dims.join(", "), dtype)
            }
            IrType::Scalar(dt) => write!(f, "{dt}"),
            IrType::Int => write!(f, "int"),
            IrType::Str => write!(f, "str"),
            IrType::Boolean => write!(f, "bool"),
            IrType::Unknown => write!(f, "?"),
        }
    }
}

/// The operation performed by a node.
#[derive(Debug, Clone)]
pub enum OpKind {
    // Lookup and sequence-generation ops.
    Embedding,
    Range,

    // Unary ops.
    Neg,
    Relu,
    Gelu,
    Silu,
    Sigmoid,
    Tanh,
    Exp,
    Log,
    Sqrt,
    Transpose,

    // Binary ops and matrix multiplication.
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    Pow,
    MatMul,

    // Reductions.
    Sum {
        dims: Vec<i64>,
        keepdim: bool,
    },
    Mean {
        dims: Vec<i64>,
        keepdim: bool,
    },
    Max {
        dim: i64,
        keepdim: bool,
    },
    Min {
        dim: i64,
        keepdim: bool,
    },
    Variance {
        dims: Vec<i64>,
        keepdim: bool,
    },

    // Normalization.
    LayerNorm {
        eps: f64,
    },
    BatchNorm {
        eps: f64,
    },

    // Attention / transformer building blocks.
    MultiHeadAttention {
        n_heads: i64,
    },
    TransformerBlock {
        n_heads: i64,
    },
    Softmax {
        dim: i64,
    },

    // Shape manipulation.
    Reshape {
        target_shape: Vec<Dim>,
    },
    View {
        target_shape: Vec<Dim>,
    },
    Permute {
        dims: Vec<i64>,
    },
    Concat {
        dim: i64,
    },
    Split {
        dim: i64,
        chunks: i64,
    },
    Expand {
        target_shape: Vec<Dim>,
    },

    Dropout {
        p: f64,
    },

    Linear {
        bias: bool,
    },

    // Losses.
    CrossEntropy,
    MseLoss,

    // Comparisons.
    Equal,
    NotEqual,
    Less,
    Greater,
    LessEqual,
    GreaterEqual,

    // Boolean logic.
    And,
    Or,
    Not,

    /// A literal constant.
    Constant(ConstantValue),

    /// Repeats a sub-operation `count` times.
    Repeat {
        count: i64,
        body_op: Box<OpKind>,
    },

    /// Escape hatch for operations not modeled explicitly.
    Custom {
        name: String,
        attrs: HashMap<String, AttrValue>,
    },

    /// Invokes another graph by name.
    Call {
        graph_name: String,
    },

    Identity,
}

impl fmt::Display for OpKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OpKind::Add => write!(f, "add"),
            OpKind::Sub => write!(f, "sub"),
            OpKind::Mul => write!(f, "mul"),
            OpKind::Div => write!(f, "div"),
            OpKind::MatMul => write!(f, "matmul"),
            OpKind::Embedding => write!(f, "embedding"),
            OpKind::LayerNorm { eps } => write!(f, "layer_norm(eps={eps})"),
            OpKind::Softmax { dim } => write!(f, "softmax(dim={dim})"),
            OpKind::Relu => write!(f, "relu"),
            OpKind::Gelu => write!(f, "gelu"),
            OpKind::Transpose => write!(f, "transpose"),
            OpKind::Constant(v) => write!(f, "const({v})"),
            OpKind::Custom { name, .. } => write!(f, "custom({name})"),
            OpKind::Call { graph_name } => write!(f, "call({graph_name})"),
            OpKind::Identity => write!(f, "identity"),
            other => write!(f, "{other:?}"),
        }
    }
}

/// A literal constant value embedded in the graph.
#[derive(Debug, Clone)]
pub enum ConstantValue {
    Int(i64),
    Float(f64),
    Str(String),
    Bool(bool),
    Null,
}

impl fmt::Display for ConstantValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ConstantValue::Int(n) => write!(f, "{n}"),
            ConstantValue::Float(v) => write!(f, "{v}"),
            ConstantValue::Str(s) => write!(f, "\"{s}\""),
            ConstantValue::Bool(b) => write!(f, "{b}"),
            ConstantValue::Null => write!(f, "null"),
        }
    }
}

/// An attribute value attached to a node or custom op.
#[derive(Debug, Clone)]
pub enum AttrValue {
    Int(i64),
    Float(f64),
    Str(String),
    Bool(bool),
    List(Vec<AttrValue>),
}

/// A single node (operation) in an [`IrGraph`].
#[derive(Debug, Clone)]
pub struct IrNode {
    /// Graph-unique identifier of this node.
    pub id: NodeId,
    /// Human-readable name of the node.
    pub name: String,
    /// The operation this node performs.
    pub op: OpKind,
    /// Nodes whose outputs feed this node, in argument order.
    pub inputs: Vec<NodeId>,
    /// Type of the value this node produces.
    pub output_type: IrType,
    /// Additional attributes.
    pub attrs: HashMap<String, AttrValue>,
    /// Optimization / lowering hints.
    pub hints: Vec<IrHint>,
}

/// Hints attached to a node for later optimization or lowering passes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IrHint {
    RecomputeInBackward,
    MustPreserve,
    InPlace,
    NoGrad,
    Custom(String),
}

/// A learnable parameter owned by a graph.
#[derive(Debug, Clone)]
pub struct IrParam {
    /// The node that produces this parameter's value.
    pub node_id: NodeId,
    /// Parameter name.
    pub name: String,
    /// Parameter type (usually a tensor).
    pub ty: IrType,
    /// How the parameter is initialized.
    pub init: InitStrategy,
    /// If true, the parameter is excluded from gradient updates.
    pub frozen: bool,
}

/// Initialization strategy for a parameter.
#[derive(Debug, Clone)]
pub enum InitStrategy {
    Zeros,
    Ones,
    Normal { mean: f64, std: f64 },
    Uniform { low: f64, high: f64 },
    XavierUniform,
    XavierNormal,
    KaimingUniform,
    KaimingNormal,
    Custom(String),
}

/// An assertion attached to a graph.
#[derive(Debug, Clone)]
pub struct IrAssert {
    /// Optional message reported when the assertion fails.
    pub message: Option<String>,
    /// Source text of the asserted expression.
    pub expr_text: String,
}

/// A named output of a graph.
#[derive(Debug, Clone)]
pub struct IrOutput {
    /// Output name.
    pub name: String,
    /// Node whose value is returned under this name.
    pub node_id: NodeId,
}

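/// A single computation graph: its nodes, graph inputs, named outputs,
/// parameters, and assertions.
///
/// A minimal construction sketch (illustrative only; assumes the items in
/// this module are in scope):
///
/// ```ignore
/// let mut g = IrGraph::new("tiny");
/// let x = g.add_node("x", OpKind::Identity, vec![], IrType::Unknown);
/// let y = g.add_node("y", OpKind::Relu, vec![x], IrType::Unknown);
/// g.add_output(y);
/// assert_eq!(g.topo_order(), vec![x, y]);
/// ```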
#[derive(Debug, Clone)]
pub struct IrGraph {
    /// Graph name.
    pub name: String,
    /// All nodes, indexed by `NodeId`.
    pub nodes: Vec<IrNode>,
    /// Nodes that act as graph inputs.
    pub inputs: Vec<NodeId>,
    /// Named graph outputs.
    pub outputs: Vec<IrOutput>,
    /// Learnable parameters owned by this graph.
    pub params: Vec<IrParam>,
    /// Assertions attached to this graph.
    pub asserts: Vec<IrAssert>,
    /// Lookup table from node name to id.
    pub name_to_id: HashMap<String, NodeId>,
}

impl IrGraph {
    /// Creates an empty graph with the given name.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            nodes: Vec::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            params: Vec::new(),
            asserts: Vec::new(),
            name_to_id: HashMap::new(),
        }
    }

    /// Appends a node and returns its freshly assigned id.
    pub fn add_node(
        &mut self,
        name: impl Into<String>,
        op: OpKind,
        inputs: Vec<NodeId>,
        output_type: IrType,
    ) -> NodeId {
        let id = NodeId(self.nodes.len());
        let name = name.into();
        self.name_to_id.insert(name.clone(), id);
        self.nodes.push(IrNode {
            id,
            name,
            op,
            inputs,
            output_type,
            attrs: HashMap::new(),
            hints: Vec::new(),
        });
        id
    }

    /// Marks a node as a graph output, reusing the node's name.
    pub fn add_output(&mut self, node_id: NodeId) {
        let name = self.nodes[node_id.0].name.clone();
        self.outputs.push(IrOutput { name, node_id });
    }

    /// Marks a node as a graph output under an explicit name.
    pub fn add_output_named(&mut self, name: impl Into<String>, node_id: NodeId) {
        self.outputs.push(IrOutput {
            name: name.into(),
            node_id,
        });
    }

    /// Looks up a node by name.
    pub fn get_node(&self, name: &str) -> Option<&IrNode> {
        self.name_to_id.get(name).map(|id| &self.nodes[id.0])
    }

    /// Returns the node with the given id.
    pub fn node(&self, id: NodeId) -> &IrNode {
        &self.nodes[id.0]
    }

    /// Returns a mutable reference to the node with the given id.
    pub fn node_mut(&mut self, id: NodeId) -> &mut IrNode {
        &mut self.nodes[id.0]
    }

    /// Number of nodes in the graph.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns true if the graph has no nodes.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Returns the node ids in a topological order (Kahn's algorithm).
    pub fn topo_order(&self) -> Vec<NodeId> {
        let n = self.nodes.len();
        let mut in_degree = vec![0u32; n];
        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];

        // Build the adjacency list and count incoming edges per node.
        for node in &self.nodes {
            for &inp in &node.inputs {
                adj[inp.0].push(node.id.0);
                in_degree[node.id.0] += 1;
            }
        }

        // Start from nodes with no inputs and repeatedly release their successors.
        let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
        let mut order = Vec::with_capacity(n);

        while let Some(u) = queue.pop() {
            order.push(NodeId(u));
            for &v in &adj[u] {
                in_degree[v] -= 1;
                if in_degree[v] == 0 {
                    queue.push(v);
                }
            }
        }

        order
    }

    /// Renders a human-readable, multi-line dump of the graph for debugging.
    pub fn dump(&self) -> String {
        let mut out = format!(
            "=== IrGraph: {} ({} nodes) ===\n",
            self.name,
            self.nodes.len()
        );

        for node in &self.nodes {
            let inputs: Vec<String> = node
                .inputs
                .iter()
                .map(|id| format!("{}({})", self.nodes[id.0].name, id))
                .collect();
            out.push_str(&format!(
                " {} [{}]: {} <- [{}] :: {}\n",
                node.id,
                node.name,
                node.op,
                inputs.join(", "),
                node.output_type,
            ));
            for hint in &node.hints {
                out.push_str(&format!(" hint: {hint:?}\n"));
            }
        }

        out.push_str(&format!(" inputs: {:?}\n", self.inputs));
        out.push_str(&format!(
            " outputs: {:?}\n",
            self.outputs
                .iter()
                .map(|o| (&o.name, o.node_id))
                .collect::<Vec<_>>()
        ));
        out.push_str(&format!(" params: {} total\n", self.params.len()));
        out
    }
}

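/// A complete lowered program: global metadata and configuration, type
/// aliases, all graphs, and optional training / inference sections.
///
/// A minimal sketch (illustrative only; assumes the items in this module are
/// in scope):
///
/// ```ignore
/// let mut prog = IrProgram::new();
/// prog.metadata.insert("source".into(), "example".into());
/// prog.graphs.push(IrGraph::new("forward"));
/// assert!(prog.get_graph("forward").is_some());
/// ```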
#[derive(Debug, Clone)]
pub struct IrProgram {
    /// Free-form metadata (name/value pairs).
    pub metadata: HashMap<String, String>,
    /// Global configuration values.
    pub config: HashMap<String, ConfigValue>,
    /// Named type aliases.
    pub type_aliases: HashMap<String, IrType>,
    /// All graphs in the program.
    pub graphs: Vec<IrGraph>,
    /// Optional training section.
    pub training: Option<TrainingConfig>,
    /// Optional inference section.
    pub inference: Option<InferenceConfig>,
}

impl IrProgram {
    pub fn new() -> Self {
        Self {
            metadata: HashMap::new(),
            config: HashMap::new(),
            type_aliases: HashMap::new(),
            graphs: Vec::new(),
            training: None,
            inference: None,
        }
    }

    /// Looks up a graph by name.
    pub fn get_graph(&self, name: &str) -> Option<&IrGraph> {
        self.graphs.iter().find(|g| g.name == name)
    }
}

impl Default for IrProgram {
    fn default() -> Self {
        Self::new()
    }
}

/// A configuration value.
#[derive(Debug, Clone)]
pub enum ConfigValue {
    Int(i64),
    Float(f64),
    Str(String),
    Bool(bool),
    List(Vec<ConfigValue>),
}

impl fmt::Display for ConfigValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ConfigValue::Int(n) => write!(f, "{n}"),
            ConfigValue::Float(v) => write!(f, "{v}"),
            ConfigValue::Str(s) => write!(f, "\"{s}\""),
            ConfigValue::Bool(b) => write!(f, "{b}"),
            ConfigValue::List(items) => {
                let s: Vec<String> = items.iter().map(|i| i.to_string()).collect();
                write!(f, "[{}]", s.join(", "))
            }
        }
    }
}

/// Training section of a program.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    pub model_graph: String,
    pub loss: String,
    pub optimizer: OptimizerConfig,
    pub lr_schedule: Option<LrScheduleConfig>,
    pub grad_clip: Option<GradClipConfig>,
    pub precision: String,
    pub epochs: i64,
    pub batch_size: i64,
    pub accumulation_steps: i64,
}

/// Optimizer choice and hyperparameters.
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    pub kind: String,
    pub lr: f64,
    pub extra: HashMap<String, ConfigValue>,
}

/// Learning-rate schedule configuration.
#[derive(Debug, Clone)]
pub struct LrScheduleConfig {
    pub kind: String,
    pub extra: HashMap<String, ConfigValue>,
}

/// Gradient clipping configuration.
#[derive(Debug, Clone)]
pub struct GradClipConfig {
    pub kind: String,
    pub extra: HashMap<String, ConfigValue>,
}

/// Inference section of a program.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    pub model_graph: String,
    pub quantization: Option<HashMap<String, ConfigValue>>,
    pub generation: Option<HashMap<String, ConfigValue>>,
}