// shrew_ir/validate.rs

// Graph Validation — Checks well-formedness of the IR before optimization
//
// Validation catches structural errors that the parser/lowering cannot detect:
//
//   1. Dangling inputs — node references a NodeId that doesn't exist
//   2. Cycles — the graph must be a DAG
//   3. Duplicate names — every node name must be unique within a graph
//   4. Input/output validity — listed inputs/outputs exist in the graph
//   5. Parameter validity — params reference real nodes with correct types
//   6. Type consistency — binary ops require compatible types
//   7. Program-level — training/inference reference existing graphs

13use crate::graph::*;
14use std::collections::{HashMap, HashSet};
15
// Public API

18/// Validate an entire IrProgram. Returns all errors found (does not stop at first).
19pub fn validate(program: &IrProgram) -> std::result::Result<(), Vec<ValidationError>> {
20    let mut errors = Vec::new();
21
22    for graph in &program.graphs {
23        validate_graph(graph, &mut errors);
24    }
25
26    validate_program_refs(program, &mut errors);
27
28    if errors.is_empty() {
29        Ok(())
30    } else {
31        Err(errors)
32    }
33}
34
35/// Validate a single graph. Returns Ok(()) or the list of errors.
36pub fn validate_graph_standalone(graph: &IrGraph) -> std::result::Result<(), Vec<ValidationError>> {
37    let mut errors = Vec::new();
38    validate_graph(graph, &mut errors);
39    if errors.is_empty() {
40        Ok(())
41    } else {
42        Err(errors)
43    }
44}
45
// Validation errors

/// A validation error with context.
///
/// Carries where the error occurred (graph and, when applicable, node)
/// alongside the specific [`ValidationErrorKind`]. Rendered via `Display`
/// as `[location] message`.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Which graph this error belongs to (empty for program-level).
    pub graph: String,
    /// Which node, if applicable.
    pub node: Option<String>,
    /// The error kind.
    pub kind: ValidationErrorKind,
}
58
59impl std::fmt::Display for ValidationError {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        let loc = if let Some(node) = &self.node {
62            format!("{}::{}", self.graph, node)
63        } else if !self.graph.is_empty() {
64            self.graph.clone()
65        } else {
66            "program".to_string()
67        };
68        write!(f, "[{loc}] {}", self.kind)
69    }
70}
71
/// Specific validation error kinds.
///
/// Each variant corresponds to one of the structural checks listed in the
/// module header; `Display` produces the user-facing message.
#[derive(Debug, Clone)]
pub enum ValidationErrorKind {
    /// A node references a NodeId that doesn't exist.
    DanglingInput { node_id: NodeId, input_id: NodeId },
    /// The graph contains a cycle.
    CycleDetected,
    /// Two nodes share the same name.
    DuplicateName { name: String },
    /// An input listed in graph.inputs doesn't exist.
    InvalidInput { node_id: NodeId },
    /// An output listed in graph.outputs doesn't exist.
    InvalidOutput { node_id: NodeId },
    /// A parameter references a non-existent node.
    InvalidParamNode { param_name: String, node_id: NodeId },
    /// A parameter doesn't have a Tensor type.
    ParamNotTensor { param_name: String },
    /// Binary op has wrong number of inputs (also reused as the generic
    /// arity error for ops that are neither unary nor binary).
    BinaryOpArity { expected: usize, got: usize },
    /// Unary op has wrong number of inputs.
    UnaryOpArity { expected: usize, got: usize },
    /// Type mismatch on binary op inputs.
    TypeMismatch { left: IrType, right: IrType },
    /// Training config references a graph that doesn't exist.
    TrainingGraphNotFound { name: String },
    /// Inference config references a graph that doesn't exist.
    InferenceGraphNotFound { name: String },
    /// Graph has no outputs.
    NoOutputs,
    /// Reduction op with out-of-range dimension.
    InvalidDim { dim: i64, rank: usize },
}
104
// Human-readable message for each validation failure. These strings are
// user-facing diagnostics; keep wording stable.
impl std::fmt::Display for ValidationErrorKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DanglingInput { node_id, input_id } => {
                write!(f, "node {node_id} references non-existent input {input_id}")
            }
            Self::CycleDetected => write!(f, "cycle detected — graph is not a DAG"),
            Self::DuplicateName { name } => write!(f, "duplicate node name \"{name}\""),
            Self::InvalidInput { node_id } => write!(f, "graph input {node_id} does not exist"),
            Self::InvalidOutput { node_id } => write!(f, "graph output {node_id} does not exist"),
            Self::InvalidParamNode {
                param_name,
                node_id,
            } => write!(
                f,
                "parameter \"{param_name}\" references non-existent node {node_id}"
            ),
            Self::ParamNotTensor { param_name } => {
                write!(f, "parameter \"{param_name}\" must have Tensor type")
            }
            Self::BinaryOpArity { expected, got } => {
                write!(f, "binary op expects {expected} inputs, got {got}")
            }
            Self::UnaryOpArity { expected, got } => {
                write!(f, "unary op expects {expected} input, got {got}")
            }
            Self::TypeMismatch { left, right } => write!(f, "type mismatch: {left} vs {right}"),
            Self::TrainingGraphNotFound { name } => {
                write!(f, "@training references non-existent graph \"{name}\"")
            }
            Self::InferenceGraphNotFound { name } => {
                write!(f, "@inference references non-existent graph \"{name}\"")
            }
            Self::NoOutputs => write!(f, "graph has no outputs"),
            Self::InvalidDim { dim, rank } => {
                write!(f, "dimension {dim} out of range for rank-{rank} tensor")
            }
        }
    }
}
145
// Graph-level validation

148fn validate_graph(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
149    let gname = &graph.name;
150
151    // 1. No outputs check
152    if graph.outputs.is_empty() {
153        errors.push(ValidationError {
154            graph: gname.clone(),
155            node: None,
156            kind: ValidationErrorKind::NoOutputs,
157        });
158    }
159
160    // 2. Duplicate node names
161    check_duplicate_names(graph, errors);
162
163    // 3. Dangling input references
164    let has_dangling = check_dangling_inputs(graph, errors);
165
166    // 4. Graph inputs/outputs reference valid nodes
167    check_io_validity(graph, errors);
168
169    // 5. Parameter validity
170    check_params(graph, errors);
171
172    // 6. Op arity checks
173    check_op_arity(graph, errors);
174
175    // 7. Type consistency on binary ops (skip if dangling inputs)
176    if !has_dangling {
177        check_type_consistency(graph, errors);
178    }
179
180    // 8. Cycle detection (skip if dangling — topo_order would panic)
181    if !has_dangling {
182        check_acyclic(graph, errors);
183    }
184
185    // 9. Dimension bounds (skip if dangling)
186    if !has_dangling {
187        check_dim_bounds(graph, errors);
188    }
189}
190
191fn check_duplicate_names(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
192    let mut seen: HashMap<&str, usize> = HashMap::new();
193    for node in &graph.nodes {
194        let count = seen.entry(&node.name).or_insert(0);
195        *count += 1;
196        if *count == 2 {
197            // Report on second occurrence
198            errors.push(ValidationError {
199                graph: graph.name.clone(),
200                node: Some(node.name.clone()),
201                kind: ValidationErrorKind::DuplicateName {
202                    name: node.name.clone(),
203                },
204            });
205        }
206    }
207}
208
209fn check_dangling_inputs(graph: &IrGraph, errors: &mut Vec<ValidationError>) -> bool {
210    let max_id = graph.nodes.len();
211    let mut found = false;
212    for node in &graph.nodes {
213        for &inp in &node.inputs {
214            if inp.0 >= max_id {
215                found = true;
216                errors.push(ValidationError {
217                    graph: graph.name.clone(),
218                    node: Some(node.name.clone()),
219                    kind: ValidationErrorKind::DanglingInput {
220                        node_id: node.id,
221                        input_id: inp,
222                    },
223                });
224            }
225        }
226    }
227    found
228}
229
230fn check_io_validity(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
231    let max_id = graph.nodes.len();
232    for &id in &graph.inputs {
233        if id.0 >= max_id {
234            errors.push(ValidationError {
235                graph: graph.name.clone(),
236                node: None,
237                kind: ValidationErrorKind::InvalidInput { node_id: id },
238            });
239        }
240    }
241    for out in &graph.outputs {
242        if out.node_id.0 >= max_id {
243            errors.push(ValidationError {
244                graph: graph.name.clone(),
245                node: None,
246                kind: ValidationErrorKind::InvalidOutput {
247                    node_id: out.node_id,
248                },
249            });
250        }
251    }
252}
253
254fn check_params(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
255    let max_id = graph.nodes.len();
256    for param in &graph.params {
257        if param.node_id.0 >= max_id {
258            errors.push(ValidationError {
259                graph: graph.name.clone(),
260                node: None,
261                kind: ValidationErrorKind::InvalidParamNode {
262                    param_name: param.name.clone(),
263                    node_id: param.node_id,
264                },
265            });
266            continue;
267        }
268        // Params should be Tensor type (or Unknown before inference)
269        match &param.ty {
270            IrType::Tensor { .. } | IrType::Unknown => {}
271            _other => {
272                errors.push(ValidationError {
273                    graph: graph.name.clone(),
274                    node: Some(param.name.clone()),
275                    kind: ValidationErrorKind::ParamNotTensor {
276                        param_name: param.name.clone(),
277                    },
278                });
279            }
280        }
281    }
282}
283
284fn check_op_arity(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
285    for node in &graph.nodes {
286        let (min, max) = expected_arity(&node.op);
287        let got = node.inputs.len();
288        if got < min || got > max {
289            let kind = if is_binary_like(&node.op) {
290                ValidationErrorKind::BinaryOpArity { expected: min, got }
291            } else if is_unary_like(&node.op) {
292                ValidationErrorKind::UnaryOpArity { expected: min, got }
293            } else {
294                // For other ops, use binary arity error as generic
295                ValidationErrorKind::BinaryOpArity { expected: min, got }
296            };
297            errors.push(ValidationError {
298                graph: graph.name.clone(),
299                node: Some(node.name.clone()),
300                kind,
301            });
302        }
303    }
304}
305
/// Return (min_inputs, max_inputs) for an op.
///
/// Bounds are inclusive and checked by `check_op_arity`. Several block-level
/// ops accept optional trailing inputs (weights, biases), hence ranges
/// rather than exact counts.
fn expected_arity(op: &OpKind) -> (usize, usize) {
    match op {
        // Leaf-ish ops: Constant takes no inputs; Range accepts up to two
        // (presumably start/end operands — TODO confirm against lowering).
        OpKind::Constant(_) | OpKind::Range => (0, 2),

        // Exactly one input — unary ops
        OpKind::Neg
        | OpKind::Relu
        | OpKind::Gelu
        | OpKind::Silu
        | OpKind::Sigmoid
        | OpKind::Tanh
        | OpKind::Exp
        | OpKind::Log
        | OpKind::Sqrt
        | OpKind::Transpose
        | OpKind::Not => (1, 1),

        // One-input reductions / shape ops
        OpKind::Sum { .. }
        | OpKind::Mean { .. }
        | OpKind::Max { .. }
        | OpKind::Min { .. }
        | OpKind::Variance { .. } => (1, 1),
        OpKind::Reshape { .. }
        | OpKind::View { .. }
        | OpKind::Permute { .. }
        | OpKind::Expand { .. } => (1, 1),
        OpKind::Softmax { .. } => (1, 1),
        OpKind::Dropout { .. } => (1, 1),

        // Exactly two inputs — binary ops
        OpKind::Add
        | OpKind::Sub
        | OpKind::Mul
        | OpKind::Div
        | OpKind::Mod
        | OpKind::Pow
        | OpKind::MatMul => (2, 2),
        OpKind::Equal
        | OpKind::NotEqual
        | OpKind::Less
        | OpKind::Greater
        | OpKind::LessEqual
        | OpKind::GreaterEqual => (2, 2),
        OpKind::And | OpKind::Or => (2, 2),

        // Normalization: input [+ weight [+ bias]] (1-3)
        OpKind::LayerNorm { .. } | OpKind::BatchNorm { .. } => (1, 3),

        // Embedding: table + indices (2), or just indices (1)
        OpKind::Embedding => (1, 2),

        // Linear: input [+ weight [+ bias]] (1-3)
        OpKind::Linear { .. } => (1, 3),

        // Concat/Split: Concat is variadic (64 is an arbitrary upper cap);
        // Split takes exactly one input.
        OpKind::Concat { .. } => (1, 64),
        OpKind::Split { .. } => (1, 1),

        // Loss functions: predictions + targets
        OpKind::CrossEntropy | OpKind::MseLoss => (2, 2),

        // Attention blocks: variable input
        OpKind::MultiHeadAttention { .. } => (1, 6),
        OpKind::TransformerBlock { .. } => (1, 6),

        // Repeat: the body subgraph input (variadic, capped at 64)
        OpKind::Repeat { .. } => (1, 64),

        // Identity: at most one input (zero is tolerated)
        OpKind::Identity => (0, 1),

        // Custom/Call: any number (capped at 64)
        OpKind::Custom { .. } | OpKind::Call { .. } => (0, 64),
    }
}
384
385fn is_binary_like(op: &OpKind) -> bool {
386    matches!(
387        op,
388        OpKind::Add
389            | OpKind::Sub
390            | OpKind::Mul
391            | OpKind::Div
392            | OpKind::Mod
393            | OpKind::Pow
394            | OpKind::MatMul
395            | OpKind::Equal
396            | OpKind::NotEqual
397            | OpKind::Less
398            | OpKind::Greater
399            | OpKind::LessEqual
400            | OpKind::GreaterEqual
401            | OpKind::And
402            | OpKind::Or
403    )
404}
405
406fn is_unary_like(op: &OpKind) -> bool {
407    matches!(
408        op,
409        OpKind::Neg
410            | OpKind::Relu
411            | OpKind::Gelu
412            | OpKind::Silu
413            | OpKind::Sigmoid
414            | OpKind::Tanh
415            | OpKind::Exp
416            | OpKind::Log
417            | OpKind::Sqrt
418            | OpKind::Transpose
419            | OpKind::Not
420    )
421}
422
423fn check_type_consistency(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
424    for node in &graph.nodes {
425        if !is_binary_like(&node.op) || node.inputs.len() != 2 {
426            continue;
427        }
428        let left_ty = &graph.nodes[node.inputs[0].0].output_type;
429        let right_ty = &graph.nodes[node.inputs[1].0].output_type;
430
431        // Skip if either type is Unknown (not yet inferred)
432        if matches!(left_ty, IrType::Unknown) || matches!(right_ty, IrType::Unknown) {
433            continue;
434        }
435
436        // For tensor ops, check dtype compatibility
437        if let (IrType::Tensor { dtype: ld, .. }, IrType::Tensor { dtype: rd, .. }) =
438            (left_ty, right_ty)
439        {
440            if ld != rd {
441                errors.push(ValidationError {
442                    graph: graph.name.clone(),
443                    node: Some(node.name.clone()),
444                    kind: ValidationErrorKind::TypeMismatch {
445                        left: left_ty.clone(),
446                        right: right_ty.clone(),
447                    },
448                });
449            }
450        }
451    }
452}
453
454fn check_acyclic(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
455    let order = graph.topo_order();
456    // If topo_order returns fewer nodes than exist, there's a cycle
457    if order.len() < graph.nodes.len() {
458        errors.push(ValidationError {
459            graph: graph.name.clone(),
460            node: None,
461            kind: ValidationErrorKind::CycleDetected,
462        });
463    }
464}
465
466fn check_dim_bounds(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
467    for node in &graph.nodes {
468        // Check softmax dim
469        if let OpKind::Softmax { dim } = &node.op {
470            if let Some(rank) = output_rank(graph, &node.inputs) {
471                if !is_valid_dim(*dim, rank) {
472                    errors.push(ValidationError {
473                        graph: graph.name.clone(),
474                        node: Some(node.name.clone()),
475                        kind: ValidationErrorKind::InvalidDim { dim: *dim, rank },
476                    });
477                }
478            }
479        }
480        // Check reduction dims
481        match &node.op {
482            OpKind::Sum { dims, .. }
483            | OpKind::Mean { dims, .. }
484            | OpKind::Variance { dims, .. } => {
485                if node.inputs.len() == 1 {
486                    if let Some(rank) = node_rank(graph, node.inputs[0]) {
487                        for d in dims {
488                            if !is_valid_dim(*d, rank) {
489                                errors.push(ValidationError {
490                                    graph: graph.name.clone(),
491                                    node: Some(node.name.clone()),
492                                    kind: ValidationErrorKind::InvalidDim { dim: *d, rank },
493                                });
494                            }
495                        }
496                    }
497                }
498            }
499            OpKind::Max { dim, .. } | OpKind::Min { dim, .. } => {
500                if node.inputs.len() == 1 {
501                    if let Some(rank) = node_rank(graph, node.inputs[0]) {
502                        if !is_valid_dim(*dim, rank) {
503                            errors.push(ValidationError {
504                                graph: graph.name.clone(),
505                                node: Some(node.name.clone()),
506                                kind: ValidationErrorKind::InvalidDim { dim: *dim, rank },
507                            });
508                        }
509                    }
510                }
511            }
512            _ => {}
513        }
514    }
515}
516
517/// Get the rank of a node's output if it's a tensor.
518fn node_rank(graph: &IrGraph, id: NodeId) -> Option<usize> {
519    match &graph.nodes[id.0].output_type {
520        IrType::Tensor { shape, .. } => Some(shape.len()),
521        _ => None,
522    }
523}
524
525/// Get the rank of the first input node.
526fn output_rank(graph: &IrGraph, inputs: &[NodeId]) -> Option<usize> {
527    inputs.first().and_then(|id| node_rank(graph, *id))
528}
529
/// Check if a dimension index is valid for the given rank.
/// Supports negative indexing (e.g., -1 = last dim), so the legal
/// half-open range is `-rank..rank`.
fn is_valid_dim(dim: i64, rank: usize) -> bool {
    let r = rank as i64;
    (-r..r).contains(&dim)
}
536
// Program-level validation

539fn validate_program_refs(program: &IrProgram, errors: &mut Vec<ValidationError>) {
540    let graph_names: HashSet<&str> = program.graphs.iter().map(|g| g.name.as_str()).collect();
541
542    if let Some(training) = &program.training {
543        if !graph_names.contains(training.model_graph.as_str()) {
544            errors.push(ValidationError {
545                graph: String::new(),
546                node: None,
547                kind: ValidationErrorKind::TrainingGraphNotFound {
548                    name: training.model_graph.clone(),
549                },
550            });
551        }
552    }
553
554    if let Some(inference) = &program.inference {
555        if !graph_names.contains(inference.model_graph.as_str()) {
556            errors.push(ValidationError {
557                graph: String::new(),
558                node: None,
559                kind: ValidationErrorKind::InferenceGraphNotFound {
560                    name: inference.model_graph.clone(),
561                },
562            });
563        }
564    }
565}