1use crate::token::Span;
10
11#[derive(Debug, Clone)]
15pub struct Program {
16 pub items: Vec<TopLevel>,
17}
18
19#[derive(Debug, Clone)]
20pub enum TopLevel {
21 Import(ImportStmt),
22 Metadata(MetadataBlock),
23 Config(ConfigBlock),
24 Types(TypesBlock),
25 Graph(GraphBlock),
26 CustomOp(CustomOpBlock),
27 Training(TrainingBlock),
28 Inference(InferenceBlock),
29 Metrics(MetricsBlock),
30 Logging(LoggingBlock),
31 Visualization(VisualizationBlock),
32}
33
34#[derive(Debug, Clone)]
38pub struct ImportStmt {
39 pub path: String,
40 pub alias: Option<String>,
41 pub span: Span,
42}
43
44#[derive(Debug, Clone)]
48pub struct MetadataBlock {
49 pub fields: Vec<Field>,
50 pub span: Span,
51}
52
53#[derive(Debug, Clone)]
57pub struct ConfigBlock {
58 pub fields: Vec<ExprField>,
59 pub span: Span,
60}
61
62#[derive(Debug, Clone)]
66pub struct TypesBlock {
67 pub defs: Vec<TypeDef>,
68 pub span: Span,
69}
70
71#[derive(Debug, Clone)]
72pub struct TypeDef {
73 pub name: String,
74 pub ty: TypeExpr,
75 pub span: Span,
76}
77
78#[derive(Debug, Clone)]
80pub enum TypeExpr {
81 Tensor {
83 dims: Vec<Dimension>,
84 dtype: DTypeKind,
85 span: Span,
86 },
87 Scalar(DTypeKind, Span),
89 Tuple(Vec<TypeExpr>, Span),
91 List(Box<TypeExpr>, Span),
93 Dict(Vec<(String, TypeExpr)>, Span),
95 Named(String, Span),
97 Dynamic(Span),
99 IntDim(i64, Span),
101 BinaryDim {
103 left: Box<TypeExpr>,
104 op: BinOp,
105 right: Box<TypeExpr>,
106 span: Span,
107 },
108}
109
110#[derive(Debug, Clone)]
112pub enum Dimension {
113 Named(String, Span),
115 Concrete(i64, Span),
117 Dynamic(Span),
119 Inferred(Span),
121 Computed(Box<Expr>, Span),
123}
124
/// Element data types that may appear in type expressions.
///
/// Variant names follow the common `<kind><bits>` convention (e.g. `F32`,
/// `I64`, `U8`). `Bf16` is presumably bfloat16, and the `Complex*` bit counts
/// presumably give the total width — confirm against the lexer/codegen.
///
/// `Hash` is derived alongside `Eq` so dtypes can key hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DTypeKind {
    F16,
    F32,
    F64,
    Bf16,
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    Bool,
    Complex64,
    Complex128,
}
144
145#[derive(Debug, Clone)]
149pub struct GraphBlock {
150 pub name: String,
151 pub params: Vec<ParamDef>,
152 pub return_type: Option<TypeExpr>,
153 pub body: Vec<GraphStmt>,
154 pub span: Span,
155}
156
157#[derive(Debug, Clone)]
159pub struct ParamDef {
160 pub name: String,
161 pub ty: TypeExpr,
162 pub optional: bool,
163 pub span: Span,
164}
165
166#[derive(Debug, Clone)]
168pub enum GraphStmt {
169 Input(InputDecl),
170 Output(OutputDecl),
171 Param(ParamDecl),
172 Node(NodeDecl),
173 Assert(AssertStmt),
174 Check(CheckBlock),
175}
176
177#[derive(Debug, Clone)]
179pub struct InputDecl {
180 pub name: String,
181 pub ty: TypeExpr,
182 pub optional: bool,
183 pub span: Span,
184}
185
186#[derive(Debug, Clone)]
188pub struct OutputDecl {
189 pub name: Option<String>,
190 pub expr: Expr,
191 pub span: Span,
192}
193
194#[derive(Debug, Clone)]
196pub struct ParamDecl {
197 pub name: String,
198 pub ty: TypeExpr,
199 pub attrs: Vec<ParamAttr>,
200 pub span: Span,
201}
202
203#[derive(Debug, Clone)]
204pub struct ParamAttr {
205 pub key: String,
206 pub value: Expr,
207 pub span: Span,
208}
209
210#[derive(Debug, Clone)]
212pub struct NodeDecl {
213 pub name: String,
214 pub ty: Option<TypeExpr>,
215 pub stmts: Vec<NodeStmt>,
216 pub span: Span,
217}
218
219#[derive(Debug, Clone)]
220pub enum NodeStmt {
221 Op(Expr, Span),
222 InputRef(Expr, Span),
223 OutputType(TypeExpr, Span),
224 Hint(HintKind, Span),
225 Attr(String, Expr, Span),
226}
227
/// A hint attached to a node via [`NodeStmt::Hint`].
///
/// The meaning of each hint is defined by the passes that consume it; the
/// descriptions below are inferred from the names and should be confirmed there.
/// `Hash` is derived alongside `Eq` so hints can be collected into sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum HintKind {
    /// Presumably: recompute this node during the backward pass.
    RecomputeInBackward,
    /// Presumably: the node must not be optimized away.
    MustPreserve,
    /// Presumably: the node may compute its result in place.
    InPlace,
    /// Presumably: no gradient is propagated through this node.
    NoGrad,
    /// Any hint not covered above, kept by name.
    Custom(String),
}
236
237#[derive(Debug, Clone)]
239pub struct AssertStmt {
240 pub condition: Expr,
241 pub message: Option<String>,
242 pub span: Span,
243}
244
245#[derive(Debug, Clone)]
247pub struct CheckBlock {
248 pub name: String,
249 pub conditions: Vec<AssertStmt>,
250 pub span: Span,
251}
252
253#[derive(Debug, Clone)]
256pub struct CustomOpBlock {
257 pub name: String,
258 pub stmts: Vec<CustomOpStmt>,
259 pub span: Span,
260}
261
262#[derive(Debug, Clone)]
263pub enum CustomOpStmt {
264 Signature {
265 params: Vec<ParamDef>,
266 return_type: TypeExpr,
267 span: Span,
268 },
269 Impl {
270 target: String,
271 attrs: Vec<ExprField>,
272 span: Span,
273 },
274 Gradient {
275 target: String,
276 body: Vec<CustomOpStmt>,
277 span: Span,
278 },
279}
280
281#[derive(Debug, Clone)]
284pub struct TrainingBlock {
285 pub fields: Vec<TrainingField>,
286 pub span: Span,
287}
288
289#[derive(Debug, Clone)]
290pub enum TrainingField {
291 Model(String, Span),
292 Loss(String, Span),
293 Optimizer(Vec<ExprField>, Span),
294 LrSchedule(Vec<ExprField>, Span),
295 GradClip(Vec<ExprField>, Span),
296 Generic(ExprField),
297}
298
299#[derive(Debug, Clone)]
302pub struct InferenceBlock {
303 pub fields: Vec<InferenceField>,
304 pub span: Span,
305}
306
307#[derive(Debug, Clone)]
308pub enum InferenceField {
309 Model(String, Span),
310 Optimizations(Vec<Expr>, Span),
311 Quantization(Vec<ExprField>, Span),
312 Generation(Vec<ExprField>, Span),
313 Generic(ExprField),
314}
315
316#[derive(Debug, Clone)]
319pub struct MetricsBlock {
320 pub name: String,
321 pub defs: Vec<MetricDef>,
322 pub span: Span,
323}
324
325#[derive(Debug, Clone)]
326pub struct MetricDef {
327 pub name: String,
328 pub attrs: Vec<ExprField>,
329 pub span: Span,
330}
331
332#[derive(Debug, Clone)]
333pub struct LoggingBlock {
334 pub fields: Vec<ExprField>,
335 pub span: Span,
336}
337
338#[derive(Debug, Clone)]
339pub struct VisualizationBlock {
340 pub plots: Vec<PlotDef>,
341 pub span: Span,
342}
343
344#[derive(Debug, Clone)]
345pub struct PlotDef {
346 pub name: String,
347 pub attrs: Vec<ExprField>,
348 pub span: Span,
349}
350
351#[derive(Debug, Clone)]
355pub struct Field {
356 pub key: String,
357 pub value: Literal,
358 pub span: Span,
359}
360
361#[derive(Debug, Clone)]
363pub struct ExprField {
364 pub key: String,
365 pub value: Expr,
366 pub span: Span,
367}
368
369#[derive(Debug, Clone)]
372pub enum Expr {
373 Int(i64, Span),
375 Float(f64, Span),
377 Str(String, Span),
379 Bool(bool, Span),
381 Null(Span),
383 Ident(String, Span),
385 Binary {
387 left: Box<Expr>,
388 op: BinOp,
389 right: Box<Expr>,
390 span: Span,
391 },
392 Unary {
394 op: UnaryOp,
395 operand: Box<Expr>,
396 span: Span,
397 },
398 Call {
400 func: String,
401 args: Vec<Arg>,
402 span: Span,
403 },
404 QualifiedCall {
406 path: Vec<String>,
407 args: Vec<Arg>,
408 span: Span,
409 },
410 Member {
412 object: Box<Expr>,
413 field: String,
414 span: Span,
415 },
416 Index {
418 object: Box<Expr>,
419 index: Box<Expr>,
420 end: Option<Box<Expr>>,
421 span: Span,
422 },
423 List(Vec<Expr>, Span),
425 Dict(Vec<(String, Expr)>, Span),
427 Paren(Box<Expr>, Span),
429 IfExpr {
431 cond: Box<Expr>,
432 then_branch: Box<Expr>,
433 else_branch: Option<Box<Expr>>,
434 span: Span,
435 },
436 RepeatExpr {
438 count: Box<Expr>,
439 body: Box<Expr>,
440 span: Span,
441 },
442 Closure {
444 params: Vec<String>,
445 body: Box<Expr>,
446 span: Span,
447 },
448}
449
450impl Expr {
451 pub fn span(&self) -> Span {
452 match self {
453 Expr::Int(_, s)
454 | Expr::Float(_, s)
455 | Expr::Str(_, s)
456 | Expr::Bool(_, s)
457 | Expr::Null(s)
458 | Expr::Ident(_, s) => *s,
459 Expr::Binary { span, .. }
460 | Expr::Unary { span, .. }
461 | Expr::Call { span, .. }
462 | Expr::QualifiedCall { span, .. }
463 | Expr::Member { span, .. }
464 | Expr::Index { span, .. }
465 | Expr::List(_, span)
466 | Expr::Dict(_, span)
467 | Expr::Paren(_, span)
468 | Expr::IfExpr { span, .. }
469 | Expr::RepeatExpr { span, .. }
470 | Expr::Closure { span, .. } => *span,
471 }
472 }
473}
474
475#[derive(Debug, Clone)]
477pub struct Arg {
478 pub name: Option<String>,
479 pub value: Expr,
480 pub span: Span,
481}
482
/// Binary operators, used by [`Expr::Binary`] and [`TypeExpr::BinaryDim`].
///
/// `Hash` is derived alongside `Eq` so operators can key hash maps
/// (e.g. precedence or lowering tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Remainder.
    Mod,
    /// Exponentiation.
    Pow,
    /// Equality comparison.
    Eq,
    /// Inequality comparison.
    Ne,
    /// Less-than comparison.
    Lt,
    /// Greater-than comparison.
    Gt,
    /// Less-than-or-equal comparison.
    Le,
    /// Greater-than-or-equal comparison.
    Ge,
    /// Logical and.
    And,
    /// Logical or.
    Or,
    /// Null-coalescing; presumably yields the left operand unless it is null.
    NullCoalesce,
    /// Bitwise and.
    BitAnd,
    /// Bitwise or.
    BitOr,
    /// Bitwise exclusive or.
    BitXor,
    /// Shift left.
    Shl,
    /// Shift right.
    Shr,
}
507
/// Unary operators, used by [`Expr::Unary`].
///
/// `Hash` is derived alongside `Eq` so operators can key hash maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Arithmetic negation.
    Neg,
    /// Logical not.
    Not,
    /// Bitwise not.
    BitNot,
}
515
516#[derive(Debug, Clone)]
519pub enum Literal {
520 Int(i64, Span),
521 Float(f64, Span),
522 Str(String, Span),
523 Bool(bool, Span),
524 Null(Span),
525 List(Vec<Literal>, Span),
526 Dict(Vec<(String, Literal)>, Span),
527}