1use crate::graph::*;
14use std::collections::{HashMap, HashSet};
15
16pub fn validate(program: &IrProgram) -> std::result::Result<(), Vec<ValidationError>> {
20 let mut errors = Vec::new();
21
22 for graph in &program.graphs {
23 validate_graph(graph, &mut errors);
24 }
25
26 validate_program_refs(program, &mut errors);
27
28 if errors.is_empty() {
29 Ok(())
30 } else {
31 Err(errors)
32 }
33}
34
35pub fn validate_graph_standalone(graph: &IrGraph) -> std::result::Result<(), Vec<ValidationError>> {
37 let mut errors = Vec::new();
38 validate_graph(graph, &mut errors);
39 if errors.is_empty() {
40 Ok(())
41 } else {
42 Err(errors)
43 }
44}
45
/// A single problem found while validating an IR program or one of its graphs.
///
/// The `Display` implementation renders it as `[location] message`, where the
/// location is `graph::node`, just `graph`, or the literal `program` when the
/// graph name is empty and no node is set.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Name of the graph the error belongs to; left empty for program-level errors.
    pub graph: String,
    /// Name of the offending node (or parameter, for parameter type errors), if any.
    pub node: Option<String>,
    /// What exactly went wrong.
    pub kind: ValidationErrorKind,
}
58
59impl std::fmt::Display for ValidationError {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 let loc = if let Some(node) = &self.node {
62 format!("{}::{}", self.graph, node)
63 } else if !self.graph.is_empty() {
64 self.graph.clone()
65 } else {
66 "program".to_string()
67 };
68 write!(f, "[{loc}] {}", self.kind)
69 }
70}
71
/// The specific category of a [`ValidationError`], with the data needed to describe it.
#[derive(Debug, Clone)]
pub enum ValidationErrorKind {
    /// Node `node_id` lists `input_id` as an input, but no node with that id exists.
    DanglingInput { node_id: NodeId, input_id: NodeId },
    /// The graph's nodes cannot all be topologically ordered, so it is not a DAG.
    CycleDetected,
    /// Two or more nodes in the same graph share `name`.
    DuplicateName { name: String },
    /// A graph-level input refers to a node id that does not exist.
    InvalidInput { node_id: NodeId },
    /// A graph-level output refers to a node id that does not exist.
    InvalidOutput { node_id: NodeId },
    /// Parameter `param_name` is bound to a node id that does not exist.
    InvalidParamNode { param_name: String, node_id: NodeId },
    /// Parameter `param_name` has a type that is neither a tensor nor unknown.
    ParamNotTensor { param_name: String },
    /// A node received `got` inputs where `expected` were required.
    ///
    /// Also used as the fallback for ops that are neither binary-like nor unary-like.
    BinaryOpArity { expected: usize, got: usize },
    /// A unary-like node received `got` inputs where `expected` were required.
    UnaryOpArity { expected: usize, got: usize },
    /// The two tensor operands of a binary-like op have different dtypes.
    TypeMismatch { left: IrType, right: IrType },
    /// The program's training section names a model graph that is not in the program.
    TrainingGraphNotFound { name: String },
    /// The program's inference section names a model graph that is not in the program.
    InferenceGraphNotFound { name: String },
    /// The graph declares no outputs.
    NoOutputs,
    /// `dim` is outside the valid range `[-rank, rank)` for a tensor of rank `rank`.
    InvalidDim { dim: i64, rank: usize },
}
104
105impl std::fmt::Display for ValidationErrorKind {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 match self {
108 Self::DanglingInput { node_id, input_id } => {
109 write!(f, "node {node_id} references non-existent input {input_id}")
110 }
111 Self::CycleDetected => write!(f, "cycle detected — graph is not a DAG"),
112 Self::DuplicateName { name } => write!(f, "duplicate node name \"{name}\""),
113 Self::InvalidInput { node_id } => write!(f, "graph input {node_id} does not exist"),
114 Self::InvalidOutput { node_id } => write!(f, "graph output {node_id} does not exist"),
115 Self::InvalidParamNode {
116 param_name,
117 node_id,
118 } => write!(
119 f,
120 "parameter \"{param_name}\" references non-existent node {node_id}"
121 ),
122 Self::ParamNotTensor { param_name } => {
123 write!(f, "parameter \"{param_name}\" must have Tensor type")
124 }
125 Self::BinaryOpArity { expected, got } => {
126 write!(f, "binary op expects {expected} inputs, got {got}")
127 }
128 Self::UnaryOpArity { expected, got } => {
129 write!(f, "unary op expects {expected} input, got {got}")
130 }
131 Self::TypeMismatch { left, right } => write!(f, "type mismatch: {left} vs {right}"),
132 Self::TrainingGraphNotFound { name } => {
133 write!(f, "@training references non-existent graph \"{name}\"")
134 }
135 Self::InferenceGraphNotFound { name } => {
136 write!(f, "@inference references non-existent graph \"{name}\"")
137 }
138 Self::NoOutputs => write!(f, "graph has no outputs"),
139 Self::InvalidDim { dim, rank } => {
140 write!(f, "dimension {dim} out of range for rank-{rank} tensor")
141 }
142 }
143 }
144}
145
146fn validate_graph(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
149 let gname = &graph.name;
150
151 if graph.outputs.is_empty() {
153 errors.push(ValidationError {
154 graph: gname.clone(),
155 node: None,
156 kind: ValidationErrorKind::NoOutputs,
157 });
158 }
159
160 check_duplicate_names(graph, errors);
162
163 let has_dangling = check_dangling_inputs(graph, errors);
165
166 check_io_validity(graph, errors);
168
169 check_params(graph, errors);
171
172 check_op_arity(graph, errors);
174
175 if !has_dangling {
177 check_type_consistency(graph, errors);
178 }
179
180 if !has_dangling {
182 check_acyclic(graph, errors);
183 }
184
185 if !has_dangling {
187 check_dim_bounds(graph, errors);
188 }
189}
190
191fn check_duplicate_names(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
192 let mut seen: HashMap<&str, usize> = HashMap::new();
193 for node in &graph.nodes {
194 let count = seen.entry(&node.name).or_insert(0);
195 *count += 1;
196 if *count == 2 {
197 errors.push(ValidationError {
199 graph: graph.name.clone(),
200 node: Some(node.name.clone()),
201 kind: ValidationErrorKind::DuplicateName {
202 name: node.name.clone(),
203 },
204 });
205 }
206 }
207}
208
209fn check_dangling_inputs(graph: &IrGraph, errors: &mut Vec<ValidationError>) -> bool {
210 let max_id = graph.nodes.len();
211 let mut found = false;
212 for node in &graph.nodes {
213 for &inp in &node.inputs {
214 if inp.0 >= max_id {
215 found = true;
216 errors.push(ValidationError {
217 graph: graph.name.clone(),
218 node: Some(node.name.clone()),
219 kind: ValidationErrorKind::DanglingInput {
220 node_id: node.id,
221 input_id: inp,
222 },
223 });
224 }
225 }
226 }
227 found
228}
229
230fn check_io_validity(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
231 let max_id = graph.nodes.len();
232 for &id in &graph.inputs {
233 if id.0 >= max_id {
234 errors.push(ValidationError {
235 graph: graph.name.clone(),
236 node: None,
237 kind: ValidationErrorKind::InvalidInput { node_id: id },
238 });
239 }
240 }
241 for out in &graph.outputs {
242 if out.node_id.0 >= max_id {
243 errors.push(ValidationError {
244 graph: graph.name.clone(),
245 node: None,
246 kind: ValidationErrorKind::InvalidOutput {
247 node_id: out.node_id,
248 },
249 });
250 }
251 }
252}
253
254fn check_params(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
255 let max_id = graph.nodes.len();
256 for param in &graph.params {
257 if param.node_id.0 >= max_id {
258 errors.push(ValidationError {
259 graph: graph.name.clone(),
260 node: None,
261 kind: ValidationErrorKind::InvalidParamNode {
262 param_name: param.name.clone(),
263 node_id: param.node_id,
264 },
265 });
266 continue;
267 }
268 match ¶m.ty {
270 IrType::Tensor { .. } | IrType::Unknown => {}
271 _other => {
272 errors.push(ValidationError {
273 graph: graph.name.clone(),
274 node: Some(param.name.clone()),
275 kind: ValidationErrorKind::ParamNotTensor {
276 param_name: param.name.clone(),
277 },
278 });
279 }
280 }
281 }
282}
283
284fn check_op_arity(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
285 for node in &graph.nodes {
286 let (min, max) = expected_arity(&node.op);
287 let got = node.inputs.len();
288 if got < min || got > max {
289 let kind = if is_binary_like(&node.op) {
290 ValidationErrorKind::BinaryOpArity { expected: min, got }
291 } else if is_unary_like(&node.op) {
292 ValidationErrorKind::UnaryOpArity { expected: min, got }
293 } else {
294 ValidationErrorKind::BinaryOpArity { expected: min, got }
296 };
297 errors.push(ValidationError {
298 graph: graph.name.clone(),
299 node: Some(node.name.clone()),
300 kind,
301 });
302 }
303 }
304}
305
306fn expected_arity(op: &OpKind) -> (usize, usize) {
308 match op {
309 OpKind::Constant(_) | OpKind::Range => (0, 2),
311
312 OpKind::Neg
314 | OpKind::Relu
315 | OpKind::Gelu
316 | OpKind::Silu
317 | OpKind::Sigmoid
318 | OpKind::Tanh
319 | OpKind::Exp
320 | OpKind::Log
321 | OpKind::Sqrt
322 | OpKind::Transpose
323 | OpKind::Not => (1, 1),
324
325 OpKind::Sum { .. }
327 | OpKind::Mean { .. }
328 | OpKind::Max { .. }
329 | OpKind::Min { .. }
330 | OpKind::Variance { .. } => (1, 1),
331 OpKind::Reshape { .. }
332 | OpKind::View { .. }
333 | OpKind::Permute { .. }
334 | OpKind::Expand { .. } => (1, 1),
335 OpKind::Softmax { .. } => (1, 1),
336 OpKind::Dropout { .. } => (1, 1),
337
338 OpKind::Add
340 | OpKind::Sub
341 | OpKind::Mul
342 | OpKind::Div
343 | OpKind::Mod
344 | OpKind::Pow
345 | OpKind::MatMul => (2, 2),
346 OpKind::Equal
347 | OpKind::NotEqual
348 | OpKind::Less
349 | OpKind::Greater
350 | OpKind::LessEqual
351 | OpKind::GreaterEqual => (2, 2),
352 OpKind::And | OpKind::Or => (2, 2),
353
354 OpKind::LayerNorm { .. } | OpKind::BatchNorm { .. } => (1, 3),
356
357 OpKind::Embedding => (1, 2),
359
360 OpKind::Linear { .. } => (1, 3),
362
363 OpKind::Concat { .. } => (1, 64),
365 OpKind::Split { .. } => (1, 1),
366
367 OpKind::CrossEntropy | OpKind::MseLoss => (2, 2),
369
370 OpKind::MultiHeadAttention { .. } => (1, 6),
372 OpKind::TransformerBlock { .. } => (1, 6),
373
374 OpKind::Repeat { .. } => (1, 64),
376
377 OpKind::Identity => (0, 1),
379
380 OpKind::Custom { .. } | OpKind::Call { .. } => (0, 64),
382 }
383}
384
385fn is_binary_like(op: &OpKind) -> bool {
386 matches!(
387 op,
388 OpKind::Add
389 | OpKind::Sub
390 | OpKind::Mul
391 | OpKind::Div
392 | OpKind::Mod
393 | OpKind::Pow
394 | OpKind::MatMul
395 | OpKind::Equal
396 | OpKind::NotEqual
397 | OpKind::Less
398 | OpKind::Greater
399 | OpKind::LessEqual
400 | OpKind::GreaterEqual
401 | OpKind::And
402 | OpKind::Or
403 )
404}
405
406fn is_unary_like(op: &OpKind) -> bool {
407 matches!(
408 op,
409 OpKind::Neg
410 | OpKind::Relu
411 | OpKind::Gelu
412 | OpKind::Silu
413 | OpKind::Sigmoid
414 | OpKind::Tanh
415 | OpKind::Exp
416 | OpKind::Log
417 | OpKind::Sqrt
418 | OpKind::Transpose
419 | OpKind::Not
420 )
421}
422
423fn check_type_consistency(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
424 for node in &graph.nodes {
425 if !is_binary_like(&node.op) || node.inputs.len() != 2 {
426 continue;
427 }
428 let left_ty = &graph.nodes[node.inputs[0].0].output_type;
429 let right_ty = &graph.nodes[node.inputs[1].0].output_type;
430
431 if matches!(left_ty, IrType::Unknown) || matches!(right_ty, IrType::Unknown) {
433 continue;
434 }
435
436 if let (IrType::Tensor { dtype: ld, .. }, IrType::Tensor { dtype: rd, .. }) =
438 (left_ty, right_ty)
439 {
440 if ld != rd {
441 errors.push(ValidationError {
442 graph: graph.name.clone(),
443 node: Some(node.name.clone()),
444 kind: ValidationErrorKind::TypeMismatch {
445 left: left_ty.clone(),
446 right: right_ty.clone(),
447 },
448 });
449 }
450 }
451 }
452}
453
454fn check_acyclic(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
455 let order = graph.topo_order();
456 if order.len() < graph.nodes.len() {
458 errors.push(ValidationError {
459 graph: graph.name.clone(),
460 node: None,
461 kind: ValidationErrorKind::CycleDetected,
462 });
463 }
464}
465
466fn check_dim_bounds(graph: &IrGraph, errors: &mut Vec<ValidationError>) {
467 for node in &graph.nodes {
468 if let OpKind::Softmax { dim } = &node.op {
470 if let Some(rank) = output_rank(graph, &node.inputs) {
471 if !is_valid_dim(*dim, rank) {
472 errors.push(ValidationError {
473 graph: graph.name.clone(),
474 node: Some(node.name.clone()),
475 kind: ValidationErrorKind::InvalidDim { dim: *dim, rank },
476 });
477 }
478 }
479 }
480 match &node.op {
482 OpKind::Sum { dims, .. }
483 | OpKind::Mean { dims, .. }
484 | OpKind::Variance { dims, .. } => {
485 if node.inputs.len() == 1 {
486 if let Some(rank) = node_rank(graph, node.inputs[0]) {
487 for d in dims {
488 if !is_valid_dim(*d, rank) {
489 errors.push(ValidationError {
490 graph: graph.name.clone(),
491 node: Some(node.name.clone()),
492 kind: ValidationErrorKind::InvalidDim { dim: *d, rank },
493 });
494 }
495 }
496 }
497 }
498 }
499 OpKind::Max { dim, .. } | OpKind::Min { dim, .. } => {
500 if node.inputs.len() == 1 {
501 if let Some(rank) = node_rank(graph, node.inputs[0]) {
502 if !is_valid_dim(*dim, rank) {
503 errors.push(ValidationError {
504 graph: graph.name.clone(),
505 node: Some(node.name.clone()),
506 kind: ValidationErrorKind::InvalidDim { dim: *dim, rank },
507 });
508 }
509 }
510 }
511 }
512 _ => {}
513 }
514 }
515}
516
517fn node_rank(graph: &IrGraph, id: NodeId) -> Option<usize> {
519 match &graph.nodes[id.0].output_type {
520 IrType::Tensor { shape, .. } => Some(shape.len()),
521 _ => None,
522 }
523}
524
525fn output_rank(graph: &IrGraph, inputs: &[NodeId]) -> Option<usize> {
527 inputs.first().and_then(|id| node_rank(graph, *id))
528}
529
/// Returns `true` when `dim` addresses an axis of a rank-`rank` tensor.
///
/// Negative values count from the end, so the accepted range is
/// `-rank..rank`; a rank of zero accepts nothing.
fn is_valid_dim(dim: i64, rank: usize) -> bool {
    let bound = rank as i64;
    (-bound..bound).contains(&dim)
}
536
537fn validate_program_refs(program: &IrProgram, errors: &mut Vec<ValidationError>) {
540 let graph_names: HashSet<&str> = program.graphs.iter().map(|g| g.name.as_str()).collect();
541
542 if let Some(training) = &program.training {
543 if !graph_names.contains(training.model_graph.as_str()) {
544 errors.push(ValidationError {
545 graph: String::new(),
546 node: None,
547 kind: ValidationErrorKind::TrainingGraphNotFound {
548 name: training.model_graph.clone(),
549 },
550 });
551 }
552 }
553
554 if let Some(inference) = &program.inference {
555 if !graph_names.contains(inference.model_graph.as_str()) {
556 errors.push(ValidationError {
557 graph: String::new(),
558 node: None,
559 kind: ValidationErrorKind::InferenceGraphNotFound {
560 name: inference.model_graph.clone(),
561 },
562 });
563 }
564 }
565}