shrew_ir/graph.rs

// Graph IR — Validated intermediate representation for .sw programs
//
// The Graph IR is the layer between the raw AST and execution. It represents
// the computation as a directed acyclic graph (DAG) where:
//
//   - Each node is a well-typed operation (matmul, add, relu, etc.)
//   - Edges represent data flow between operations
//   - Parameters and inputs are explicitly tracked
//   - The graph is validated: no dangling references, type-checked dims
//
// The IR is designed for:
//   1. Validation — catch errors before execution
//   2. Optimization — constant folding, fusion, dead code elimination
//   3. Scheduling — determine execution order
//   4. Code generation — emit backend-specific kernels
//
// ARCHITECTURE:
//   AST (from parser) ► Lowering ► GraphIR (this module)
//                                         │
//                                         ├ validate()
//                                         ├ optimize()
//                                         └ schedule() → execution plan

use std::collections::HashMap;
use std::fmt;

// Node identifiers

/// Unique identifier for a node in the graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);

impl fmt::Display for NodeId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "n{}", self.0)
    }
}

// Type information

/// A resolved tensor shape — dimensions are either concrete or symbolic.
#[derive(Debug, Clone, PartialEq)]
pub enum Dim {
    /// Known at compile time: 768, 50257
    Fixed(i64),
    /// Symbolic, resolved at runtime: Batch, SeqLen
    Symbolic(String),
    /// Unknown / dynamic
    Dynamic,
}

impl fmt::Display for Dim {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Dim::Fixed(n) => write!(f, "{n}"),
            Dim::Symbolic(s) => write!(f, "{s}"),
            Dim::Dynamic => write!(f, "?"),
        }
    }
}

/// Resolved data type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    F16,
    F32,
    F64,
    Bf16,
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    Bool,
    Complex64,
    Complex128,
}

impl fmt::Display for DType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DType::F16 => write!(f, "f16"),
            DType::F32 => write!(f, "f32"),
            DType::F64 => write!(f, "f64"),
            DType::Bf16 => write!(f, "bf16"),
            DType::I8 => write!(f, "i8"),
            DType::I16 => write!(f, "i16"),
            DType::I32 => write!(f, "i32"),
            DType::I64 => write!(f, "i64"),
            DType::U8 => write!(f, "u8"),
            DType::U16 => write!(f, "u16"),
            DType::U32 => write!(f, "u32"),
            DType::U64 => write!(f, "u64"),
            DType::Bool => write!(f, "bool"),
            DType::Complex64 => write!(f, "complex64"),
            DType::Complex128 => write!(f, "complex128"),
        }
    }
}

/// The resolved type of a value in the graph.
#[derive(Debug, Clone, PartialEq)]
pub enum IrType {
    /// A tensor with shape and dtype.
    Tensor { shape: Vec<Dim>, dtype: DType },
    /// A scalar value.
    Scalar(DType),
    /// Integer (used for things like dimension values).
    Int,
    /// String (used for attribute values).
    Str,
    /// Boolean.
    Boolean,
    /// Unknown / to be inferred.
    Unknown,
}

impl fmt::Display for IrType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            IrType::Tensor { shape, dtype } => {
                let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
                write!(f, "Tensor<[{}], {}>", dims.join(", "), dtype)
            }
            IrType::Scalar(dt) => write!(f, "{dt}"),
            IrType::Int => write!(f, "int"),
            IrType::Str => write!(f, "str"),
            IrType::Boolean => write!(f, "bool"),
            IrType::Unknown => write!(f, "?"),
        }
    }
}

// Operations

/// An operation in the computation graph.
#[derive(Debug, Clone)]
pub enum OpKind {
    // Tensor creation
    /// Look up rows in an embedding table.
    Embedding,
    /// Generate a range of values.
    Range,

    // Unary ops
    Neg,
    Relu,
    Gelu,
    Silu,
    Sigmoid,
    Tanh,
    Exp,
    Log,
    Sqrt,
    Transpose,

    // Binary ops
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    Pow,
    MatMul,

    // Reduction ops
    Sum {
        dims: Vec<i64>,
        keepdim: bool,
    },
    Mean {
        dims: Vec<i64>,
        keepdim: bool,
    },
    Max {
        dim: i64,
        keepdim: bool,
    },
    Min {
        dim: i64,
        keepdim: bool,
    },
    Variance {
        dims: Vec<i64>,
        keepdim: bool,
    },

    // Normalization
    LayerNorm {
        eps: f64,
    },
    BatchNorm {
        eps: f64,
    },

    // Attention
    MultiHeadAttention {
        n_heads: i64,
    },
    TransformerBlock {
        n_heads: i64,
    },
    Softmax {
        dim: i64,
    },

    // Shape ops
    Reshape {
        target_shape: Vec<Dim>,
    },
    View {
        target_shape: Vec<Dim>,
    },
    Permute {
        dims: Vec<i64>,
    },
    Concat {
        dim: i64,
    },
    Split {
        dim: i64,
        chunks: i64,
    },
    Expand {
        target_shape: Vec<Dim>,
    },

    // Dropout
    Dropout {
        p: f64,
    },

    // Linear
    Linear {
        bias: bool,
    },

    // Loss functions
    CrossEntropy,
    MseLoss,

    // Comparison
    Equal,
    NotEqual,
    Less,
    Greater,
    LessEqual,
    GreaterEqual,

    // Logical
    And,
    Or,
    Not,

    // Constants
    Constant(ConstantValue),

    // Control flow
    Repeat {
        count: i64,
        body_op: Box<OpKind>,
    },

    // Custom / user-defined
    Custom {
        name: String,
        attrs: HashMap<String, AttrValue>,
    },

    // Graph call (calling another @graph)
    Call {
        graph_name: String,
    },

    // Identity (pass-through, used for inputs)
    Identity,
}

impl fmt::Display for OpKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OpKind::Add => write!(f, "add"),
            OpKind::Sub => write!(f, "sub"),
            OpKind::Mul => write!(f, "mul"),
            OpKind::Div => write!(f, "div"),
            OpKind::MatMul => write!(f, "matmul"),
            OpKind::Embedding => write!(f, "embedding"),
            OpKind::LayerNorm { eps } => write!(f, "layer_norm(eps={eps})"),
            OpKind::Softmax { dim } => write!(f, "softmax(dim={dim})"),
            OpKind::Relu => write!(f, "relu"),
            OpKind::Gelu => write!(f, "gelu"),
            OpKind::Transpose => write!(f, "transpose"),
            OpKind::Constant(v) => write!(f, "const({v})"),
            OpKind::Custom { name, .. } => write!(f, "custom({name})"),
            OpKind::Call { graph_name } => write!(f, "call({graph_name})"),
            OpKind::Identity => write!(f, "identity"),
            // Variants without a hand-written form fall back to their Debug output.
            other => write!(f, "{other:?}"),
        }
    }
}

/// A constant value embedded in the graph.
#[derive(Debug, Clone)]
pub enum ConstantValue {
    Int(i64),
    Float(f64),
    Str(String),
    Bool(bool),
    Null,
}

impl fmt::Display for ConstantValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ConstantValue::Int(n) => write!(f, "{n}"),
            ConstantValue::Float(v) => write!(f, "{v}"),
            ConstantValue::Str(s) => write!(f, "\"{s}\""),
            ConstantValue::Bool(b) => write!(f, "{b}"),
            ConstantValue::Null => write!(f, "null"),
        }
    }
}

/// An attribute value (for custom ops and node attrs).
#[derive(Debug, Clone)]
pub enum AttrValue {
    Int(i64),
    Float(f64),
    Str(String),
    Bool(bool),
    List(Vec<AttrValue>),
}

// Graph nodes

/// A node in the computation graph.
#[derive(Debug, Clone)]
pub struct IrNode {
    /// Unique identifier.
    pub id: NodeId,
    /// User-visible name (from the .sw source).
    pub name: String,
    /// The operation this node performs.
    pub op: OpKind,
    /// Input edges: which nodes feed into this one.
    pub inputs: Vec<NodeId>,
    /// The resolved type of this node's output.
    pub output_type: IrType,
    /// Optional attributes (key-value metadata).
    pub attrs: HashMap<String, AttrValue>,
    /// Execution hints from the source.
    pub hints: Vec<IrHint>,
}

/// Execution hints attached to nodes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IrHint {
    RecomputeInBackward,
    MustPreserve,
    InPlace,
    NoGrad,
    Custom(String),
}

// Parameters (learnable weights)

/// A learnable parameter in the graph.
#[derive(Debug, Clone)]
pub struct IrParam {
    /// Reference to the node that holds the param value.
    pub node_id: NodeId,
    /// Parameter name.
    pub name: String,
    /// Type (always a Tensor).
    pub ty: IrType,
    /// Initialization strategy.
    pub init: InitStrategy,
    /// Whether the parameter is frozen (no gradient).
    pub frozen: bool,
}

/// How to initialize a parameter.
#[derive(Debug, Clone)]
pub enum InitStrategy {
    Zeros,
    Ones,
    Normal { mean: f64, std: f64 },
    Uniform { low: f64, high: f64 },
    XavierUniform,
    XavierNormal,
    KaimingUniform,
    KaimingNormal,
    Custom(String),
}

// Assertions

/// A compile-time or runtime assertion from @assert.
#[derive(Debug, Clone)]
pub struct IrAssert {
    /// Human-readable description.
    pub message: Option<String>,
    /// The assertion expression as a string (for diagnostics).
    pub expr_text: String,
}

// Graph definition

/// An output of a graph, preserving the user-facing name across optimizations.
#[derive(Debug, Clone)]
pub struct IrOutput {
    /// The user-visible output name (e.g. "out" from `output out;`).
    pub name: String,
    /// The node that produces this output (may be remapped by optimizations).
    pub node_id: NodeId,
}

#[derive(Debug, Clone)]
pub struct IrGraph {
    /// Graph name (e.g., "Forward").
    pub name: String,
    /// All nodes, indexed by NodeId.
    pub nodes: Vec<IrNode>,
    /// Input nodes (by NodeId).
    pub inputs: Vec<NodeId>,
    /// Named outputs.
    pub outputs: Vec<IrOutput>,
    /// Learnable parameters.
    pub params: Vec<IrParam>,
    /// Assertions to verify.
    pub asserts: Vec<IrAssert>,
    /// Node lookup by name.
    pub name_to_id: HashMap<String, NodeId>,
}

impl IrGraph {
    /// Create a new empty graph.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            nodes: Vec::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            params: Vec::new(),
            asserts: Vec::new(),
            name_to_id: HashMap::new(),
        }
    }

    /// Add a node and return its NodeId.
    ///
    /// If a node with the same name already exists, the new node shadows it
    /// in the name lookup map; the old node itself remains in `nodes`.
    pub fn add_node(
        &mut self,
        name: impl Into<String>,
        op: OpKind,
        inputs: Vec<NodeId>,
        output_type: IrType,
    ) -> NodeId {
        let id = NodeId(self.nodes.len());
        let name = name.into();
        self.name_to_id.insert(name.clone(), id);
        self.nodes.push(IrNode {
            id,
            name,
            op,
            inputs,
            output_type,
            attrs: HashMap::new(),
            hints: Vec::new(),
        });
        id
    }

    /// Register a node as an output. Uses the node's name as the output name.
    pub fn add_output(&mut self, node_id: NodeId) {
        let name = self.nodes[node_id.0].name.clone();
        self.outputs.push(IrOutput { name, node_id });
    }

    /// Register a node as an output with a custom name.
    pub fn add_output_named(&mut self, name: impl Into<String>, node_id: NodeId) {
        self.outputs.push(IrOutput {
            name: name.into(),
            node_id,
        });
    }

    /// Look up a node by name.
    pub fn get_node(&self, name: &str) -> Option<&IrNode> {
        self.name_to_id.get(name).map(|id| &self.nodes[id.0])
    }

    /// Get a node by its ID.
    pub fn node(&self, id: NodeId) -> &IrNode {
        &self.nodes[id.0]
    }

    /// Get a mutable reference to a node by its ID.
    pub fn node_mut(&mut self, id: NodeId) -> &mut IrNode {
        &mut self.nodes[id.0]
    }

    /// Return the total number of nodes.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Check if the graph is empty.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Return a topological ordering of node IDs for execution.
    ///
    /// Uses Kahn's algorithm. If the graph contains a cycle, the nodes on the
    /// cycle are omitted, so the result has fewer than `self.len()` entries.
    pub fn topo_order(&self) -> Vec<NodeId> {
        let n = self.nodes.len();
        let mut in_degree = vec![0u32; n];
        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];

        // Build the adjacency list and in-degree counts from the input edges.
        for node in &self.nodes {
            for &inp in &node.inputs {
                adj[inp.0].push(node.id.0);
                in_degree[node.id.0] += 1;
            }
        }

        // Repeatedly release nodes whose inputs have all been ordered, using a
        // Vec as a LIFO work list (the order is valid but not unique).
        let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
        let mut order = Vec::with_capacity(n);

        while let Some(u) = queue.pop() {
            order.push(NodeId(u));
            for &v in &adj[u] {
                in_degree[v] -= 1;
                if in_degree[v] == 0 {
                    queue.push(v);
                }
            }
        }

        order
    }

    /// Pretty-print the graph for debugging.
    pub fn dump(&self) -> String {
        let mut out = format!(
            "=== IrGraph: {} ({} nodes) ===\n",
            self.name,
            self.nodes.len()
        );

        for node in &self.nodes {
            let inputs: Vec<String> = node
                .inputs
                .iter()
                .map(|id| format!("{}({})", self.nodes[id.0].name, id))
                .collect();
            out.push_str(&format!(
                "  {} [{}]: {} <- [{}] :: {}\n",
                node.id,
                node.name,
                node.op,
                inputs.join(", "),
                node.output_type,
            ));
            for hint in &node.hints {
                out.push_str(&format!("    hint: {hint:?}\n"));
            }
        }

        out.push_str(&format!("  inputs:  {:?}\n", self.inputs));
        out.push_str(&format!(
            "  outputs: {:?}\n",
            self.outputs
                .iter()
                .map(|o| (&o.name, o.node_id))
                .collect::<Vec<_>>()
        ));
        out.push_str(&format!("  params:  {} total\n", self.params.len()));
        out
    }
}
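
// For reference, `dump()` output for a small hand-built graph looks roughly
// like the sketch below: the graph `y = relu(matmul(x, w))` with hypothetical
// names and shapes (see the tests module at the bottom for runnable code):
//
//   === IrGraph: Forward (4 nodes) ===
//     n0 [x]: identity <- [] :: Tensor<[B, 768], f32>
//     n1 [w]: identity <- [] :: Tensor<[768, 10], f32>
//     n2 [h]: matmul <- [x(n0), w(n1)] :: Tensor<[B, 10], f32>
//     n3 [y]: relu <- [h(n2)] :: Tensor<[B, 10], f32>
//     inputs:  [NodeId(0)]
//     outputs: [("y", NodeId(3))]
//     params:  0 total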

// Program-level IR (multiple graphs + config)

/// The full program IR — lowered from the AST.
#[derive(Debug, Clone)]
pub struct IrProgram {
    /// Model metadata.
    pub metadata: HashMap<String, String>,
    /// Configuration values.
    pub config: HashMap<String, ConfigValue>,
    /// Named type aliases.
    pub type_aliases: HashMap<String, IrType>,
    /// Computation graphs.
    pub graphs: Vec<IrGraph>,
    /// Training configuration.
    pub training: Option<TrainingConfig>,
    /// Inference configuration.
    pub inference: Option<InferenceConfig>,
}

impl IrProgram {
    pub fn new() -> Self {
        Self {
            metadata: HashMap::new(),
            config: HashMap::new(),
            type_aliases: HashMap::new(),
            graphs: Vec::new(),
            training: None,
            inference: None,
        }
    }

    /// Find a graph by name.
    pub fn get_graph(&self, name: &str) -> Option<&IrGraph> {
        self.graphs.iter().find(|g| g.name == name)
    }
}

impl Default for IrProgram {
    fn default() -> Self {
        Self::new()
    }
}

/// A configuration value.
#[derive(Debug, Clone)]
pub enum ConfigValue {
    Int(i64),
    Float(f64),
    Str(String),
    Bool(bool),
    List(Vec<ConfigValue>),
}

impl fmt::Display for ConfigValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ConfigValue::Int(n) => write!(f, "{n}"),
            ConfigValue::Float(v) => write!(f, "{v}"),
            ConfigValue::Str(s) => write!(f, "\"{s}\""),
            ConfigValue::Bool(b) => write!(f, "{b}"),
            ConfigValue::List(items) => {
                let s: Vec<String> = items.iter().map(|i| i.to_string()).collect();
                write!(f, "[{}]", s.join(", "))
            }
        }
    }
}

/// Training configuration (lowered from @training).
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    pub model_graph: String,
    pub loss: String,
    pub optimizer: OptimizerConfig,
    pub lr_schedule: Option<LrScheduleConfig>,
    pub grad_clip: Option<GradClipConfig>,
    pub precision: String,
    pub epochs: i64,
    pub batch_size: i64,
    pub accumulation_steps: i64,
}

#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    pub kind: String,
    pub lr: f64,
    pub extra: HashMap<String, ConfigValue>,
}

#[derive(Debug, Clone)]
pub struct LrScheduleConfig {
    pub kind: String,
    pub extra: HashMap<String, ConfigValue>,
}

#[derive(Debug, Clone)]
pub struct GradClipConfig {
    pub kind: String,
    pub extra: HashMap<String, ConfigValue>,
}

/// Inference configuration (lowered from @inference).
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    pub model_graph: String,
    pub quantization: Option<HashMap<String, ConfigValue>>,
    pub generation: Option<HashMap<String, ConfigValue>>,
}
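
// A minimal usage sketch, written as a test module. The graph built here
// (`y = relu(matmul(x, w))`) and its names, shapes, and dtypes are
// hypothetical; the tests only exercise the public API defined above.
#[cfg(test)]
mod tests {
    use super::*;

    fn batch() -> Dim {
        Dim::Symbolic("B".to_string())
    }

    fn tensor(shape: Vec<Dim>) -> IrType {
        IrType::Tensor { shape, dtype: DType::F32 }
    }

    #[test]
    fn build_graph_and_topo_order() {
        let mut g = IrGraph::new("Forward");

        // Two leaf nodes and two compute nodes: y = relu(matmul(x, w)).
        let x = g.add_node("x", OpKind::Identity, vec![], tensor(vec![batch(), Dim::Fixed(768)]));
        let w = g.add_node("w", OpKind::Identity, vec![], tensor(vec![Dim::Fixed(768), Dim::Fixed(10)]));
        let h = g.add_node("h", OpKind::MatMul, vec![x, w], tensor(vec![batch(), Dim::Fixed(10)]));
        let y = g.add_node("y", OpKind::Relu, vec![h], tensor(vec![batch(), Dim::Fixed(10)]));

        g.inputs.push(x);
        g.add_output(y);

        assert_eq!(g.len(), 4);
        assert_eq!(g.get_node("h").unwrap().op.to_string(), "matmul");
        assert_eq!(g.outputs[0].name, "y");

        // Every node must appear after all of its inputs.
        let order = g.topo_order();
        assert_eq!(order.len(), 4);
        let pos = |id: NodeId| order.iter().position(|&n| n == id).unwrap();
        assert!(pos(x) < pos(h) && pos(w) < pos(h) && pos(h) < pos(y));
    }

    #[test]
    fn display_formatting() {
        assert_eq!(NodeId(3).to_string(), "n3");
        assert_eq!(tensor(vec![batch(), Dim::Fixed(768)]).to_string(), "Tensor<[B, 768], f32>");
        assert_eq!(OpKind::Softmax { dim: -1 }.to_string(), "softmax(dim=-1)");
        assert_eq!(ConfigValue::Str("adamw".into()).to_string(), "\"adamw\"");
    }
}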