// shrew_ir/ast.rs

// AST — Abstract Syntax Tree for the .sw language
//
// Every node in the AST corresponds to a production in the grammar.
// The AST is purely syntactic — no type checking, no name resolution,
// no graph construction. Those happen in later IR lowering passes.
//
// DESIGN: Every node stores a Span for error reporting back to the user.

9use crate::token::Span;
10
11// Top-level program
12
13/// A complete .sw program is a sequence of directives and imports.
14#[derive(Debug, Clone)]
15pub struct Program {
16    pub items: Vec<TopLevel>,
17}
18
19#[derive(Debug, Clone)]
20pub enum TopLevel {
21    Import(ImportStmt),
22    Metadata(MetadataBlock),
23    Config(ConfigBlock),
24    Types(TypesBlock),
25    Graph(GraphBlock),
26    CustomOp(CustomOpBlock),
27    Training(TrainingBlock),
28    Inference(InferenceBlock),
29    Metrics(MetricsBlock),
30    Logging(LoggingBlock),
31    Visualization(VisualizationBlock),
32}
33
34// Import
35
36/// `@import "path/to/file.sw" as alias;`
37#[derive(Debug, Clone)]
38pub struct ImportStmt {
39    pub path: String,
40    pub alias: Option<String>,
41    pub span: Span,
42}
43
44// Metadata (@model)
45
46/// `@model { name: "GPT-2"; version: "1.0"; }`
47#[derive(Debug, Clone)]
48pub struct MetadataBlock {
49    pub fields: Vec<Field>,
50    pub span: Span,
51}
52
53// Config (@config)
54
55/// `@config { d_model: 768; n_heads: 12; }`
56#[derive(Debug, Clone)]
57pub struct ConfigBlock {
58    pub fields: Vec<ExprField>,
59    pub span: Span,
60}
61
62// Types (@types)
63
64/// `@types { type Hidden = Tensor<[Batch, 768], f32>; }`
65#[derive(Debug, Clone)]
66pub struct TypesBlock {
67    pub defs: Vec<TypeDef>,
68    pub span: Span,
69}
70
71#[derive(Debug, Clone)]
72pub struct TypeDef {
73    pub name: String,
74    pub ty: TypeExpr,
75    pub span: Span,
76}
77
78/// A type expression (Tensor<dims, dtype>, scalar, tuple, etc.).
79#[derive(Debug, Clone)]
80pub enum TypeExpr {
81    /// `Tensor<[dim1, dim2, ...], dtype>`
82    Tensor {
83        dims: Vec<Dimension>,
84        dtype: DTypeKind,
85        span: Span,
86    },
87    /// A bare scalar dtype: `f32`, `i64`, etc.
88    Scalar(DTypeKind, Span),
89    /// `(TypeA, TypeB, ...)`
90    Tuple(Vec<TypeExpr>, Span),
91    /// `[TypeExpr]` — list of
92    List(Box<TypeExpr>, Span),
93    /// `{ field: Type, ... }` — dict/struct
94    Dict(Vec<(String, TypeExpr)>, Span),
95    /// A named type alias reference
96    Named(String, Span),
97    /// `?` — dynamic/unknown
98    Dynamic(Span),
99    /// An integer dimension used as a type: concrete dimension value
100    IntDim(i64, Span),
101    /// Arithmetic on dimensions: e.g. `D / 2`
102    BinaryDim {
103        left: Box<TypeExpr>,
104        op: BinOp,
105        right: Box<TypeExpr>,
106        span: Span,
107    },
108}
109
110/// Dimension in a Tensor type.
111#[derive(Debug, Clone)]
112pub enum Dimension {
113    /// Named/symbolic: `Batch`, `SeqLen`
114    Named(String, Span),
115    /// Concrete: `768`
116    Concrete(i64, Span),
117    /// Dynamic: `?`
118    Dynamic(Span),
119    /// Inferred: `_`
120    Inferred(Span),
121    /// Computed: `D / 2`
122    Computed(Box<Expr>, Span),
123}
124
/// Data types supported by the language.
///
/// Derives `Hash` (in addition to `Eq`) so dtypes can key hash maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DTypeKind {
    F16,
    F32,
    F64,
    Bf16,
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    Bool,
    Complex64,
    Complex128,
}

impl DTypeKind {
    /// `true` for the real floating-point types `F16`, `F32`, `F64`, `Bf16`.
    pub fn is_floating_point(self) -> bool {
        matches!(
            self,
            DTypeKind::F16 | DTypeKind::F32 | DTypeKind::F64 | DTypeKind::Bf16
        )
    }

    /// `true` for the signed (`I*`) and unsigned (`U*`) integer types.
    /// `Bool` is not counted as an integer.
    pub fn is_integer(self) -> bool {
        matches!(
            self,
            DTypeKind::I8
                | DTypeKind::I16
                | DTypeKind::I32
                | DTypeKind::I64
                | DTypeKind::U8
                | DTypeKind::U16
                | DTypeKind::U32
                | DTypeKind::U64
        )
    }

    /// `true` for `Complex64` and `Complex128`.
    pub fn is_complex(self) -> bool {
        matches!(self, DTypeKind::Complex64 | DTypeKind::Complex128)
    }
}

145// Graph (@graph)
146
147/// `@graph Forward(x: Tensor<[B,S,D], f32>) -> Tensor<[B,S,D], f32> { ... }`
148#[derive(Debug, Clone)]
149pub struct GraphBlock {
150    pub name: String,
151    pub params: Vec<ParamDef>,
152    pub return_type: Option<TypeExpr>,
153    pub body: Vec<GraphStmt>,
154    pub span: Span,
155}
156
157/// A parameter definition: `name: Type [?]`
158#[derive(Debug, Clone)]
159pub struct ParamDef {
160    pub name: String,
161    pub ty: TypeExpr,
162    pub optional: bool,
163    pub span: Span,
164}
165
166/// Statements inside a graph body.
167#[derive(Debug, Clone)]
168pub enum GraphStmt {
169    Input(InputDecl),
170    Output(OutputDecl),
171    Param(ParamDecl),
172    Node(NodeDecl),
173    Assert(AssertStmt),
174    Check(CheckBlock),
175}
176
177/// `input x: Tensor<[B,S], f32>;`
178#[derive(Debug, Clone)]
179pub struct InputDecl {
180    pub name: String,
181    pub ty: TypeExpr,
182    pub optional: bool,
183    pub span: Span,
184}
185
186/// `output logits: softmax(h);` or `output expr;`
187#[derive(Debug, Clone)]
188pub struct OutputDecl {
189    pub name: Option<String>,
190    pub expr: Expr,
191    pub span: Span,
192}
193
194/// `param W: Tensor<[D,D], f32> { init: "normal(0,0.02)"; frozen: false; };`
195#[derive(Debug, Clone)]
196pub struct ParamDecl {
197    pub name: String,
198    pub ty: TypeExpr,
199    pub attrs: Vec<ParamAttr>,
200    pub span: Span,
201}
202
203#[derive(Debug, Clone)]
204pub struct ParamAttr {
205    pub key: String,
206    pub value: Expr,
207    pub span: Span,
208}
209
210/// `node h { op: matmul(x, W); }`  or  `node h: Type { ... };`
211#[derive(Debug, Clone)]
212pub struct NodeDecl {
213    pub name: String,
214    pub ty: Option<TypeExpr>,
215    pub stmts: Vec<NodeStmt>,
216    pub span: Span,
217}
218
219#[derive(Debug, Clone)]
220pub enum NodeStmt {
221    Op(Expr, Span),
222    InputRef(Expr, Span),
223    OutputType(TypeExpr, Span),
224    Hint(HintKind, Span),
225    Attr(String, Expr, Span),
226}
227
/// Execution hints attached to a node.
///
/// Derives `Hash` (alongside `Eq`) so hints can be deduplicated in sets
/// or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum HintKind {
    RecomputeInBackward,
    MustPreserve,
    InPlace,
    NoGrad,
    /// A hint not covered by the built-in variants, stored by name.
    Custom(String),
}

237/// `@assert shape(x) == [B, S, D], "shape mismatch";`
238#[derive(Debug, Clone)]
239pub struct AssertStmt {
240    pub condition: Expr,
241    pub message: Option<String>,
242    pub span: Span,
243}
244
245/// `@check name { assert ...; assert ...; }`
246#[derive(Debug, Clone)]
247pub struct CheckBlock {
248    pub name: String,
249    pub conditions: Vec<AssertStmt>,
250    pub span: Span,
251}
252
253// Custom operators (@custom_op)
254
255#[derive(Debug, Clone)]
256pub struct CustomOpBlock {
257    pub name: String,
258    pub stmts: Vec<CustomOpStmt>,
259    pub span: Span,
260}
261
262#[derive(Debug, Clone)]
263pub enum CustomOpStmt {
264    Signature {
265        params: Vec<ParamDef>,
266        return_type: TypeExpr,
267        span: Span,
268    },
269    Impl {
270        target: String,
271        attrs: Vec<ExprField>,
272        span: Span,
273    },
274    Gradient {
275        target: String,
276        body: Vec<CustomOpStmt>,
277        span: Span,
278    },
279}
280
281// Training (@training)
282
283#[derive(Debug, Clone)]
284pub struct TrainingBlock {
285    pub fields: Vec<TrainingField>,
286    pub span: Span,
287}
288
289#[derive(Debug, Clone)]
290pub enum TrainingField {
291    Model(String, Span),
292    Loss(String, Span),
293    Optimizer(Vec<ExprField>, Span),
294    LrSchedule(Vec<ExprField>, Span),
295    GradClip(Vec<ExprField>, Span),
296    Generic(ExprField),
297}
298
299// Inference (@inference)
300
301#[derive(Debug, Clone)]
302pub struct InferenceBlock {
303    pub fields: Vec<InferenceField>,
304    pub span: Span,
305}
306
307#[derive(Debug, Clone)]
308pub enum InferenceField {
309    Model(String, Span),
310    Optimizations(Vec<Expr>, Span),
311    Quantization(Vec<ExprField>, Span),
312    Generation(Vec<ExprField>, Span),
313    Generic(ExprField),
314}
315
316// Metrics & Logging
317
318#[derive(Debug, Clone)]
319pub struct MetricsBlock {
320    pub name: String,
321    pub defs: Vec<MetricDef>,
322    pub span: Span,
323}
324
325#[derive(Debug, Clone)]
326pub struct MetricDef {
327    pub name: String,
328    pub attrs: Vec<ExprField>,
329    pub span: Span,
330}
331
332#[derive(Debug, Clone)]
333pub struct LoggingBlock {
334    pub fields: Vec<ExprField>,
335    pub span: Span,
336}
337
338#[derive(Debug, Clone)]
339pub struct VisualizationBlock {
340    pub plots: Vec<PlotDef>,
341    pub span: Span,
342}
343
344#[derive(Debug, Clone)]
345pub struct PlotDef {
346    pub name: String,
347    pub attrs: Vec<ExprField>,
348    pub span: Span,
349}
350
351// Shared field types
352
353/// A field with a literal value: `name: "GPT-2";`
354#[derive(Debug, Clone)]
355pub struct Field {
356    pub key: String,
357    pub value: Literal,
358    pub span: Span,
359}
360
361/// A field with an expression value: `d_model: 768;`
362#[derive(Debug, Clone)]
363pub struct ExprField {
364    pub key: String,
365    pub value: Expr,
366    pub span: Span,
367}
368
369// Expressions
370
371#[derive(Debug, Clone)]
372pub enum Expr {
373    /// Integer literal: `42`
374    Int(i64, Span),
375    /// Float literal: `3.14`
376    Float(f64, Span),
377    /// String literal: `"hello"`
378    Str(String, Span),
379    /// Boolean: `true`, `false`
380    Bool(bool, Span),
381    /// Null: `null`
382    Null(Span),
383    /// Identifier: `x`, `Batch`
384    Ident(String, Span),
385    /// Binary expression: `a + b`
386    Binary {
387        left: Box<Expr>,
388        op: BinOp,
389        right: Box<Expr>,
390        span: Span,
391    },
392    /// Unary expression: `-x`, `!cond`
393    Unary {
394        op: UnaryOp,
395        operand: Box<Expr>,
396        span: Span,
397    },
398    /// Function call: `matmul(x, w)`
399    Call {
400        func: String,
401        args: Vec<Arg>,
402        span: Span,
403    },
404    /// Qualified call: `module.func(args)` or `mod::func(args)`
405    QualifiedCall {
406        path: Vec<String>,
407        args: Vec<Arg>,
408        span: Span,
409    },
410    /// Member access: `x.shape`
411    Member {
412        object: Box<Expr>,
413        field: String,
414        span: Span,
415    },
416    /// Index access: `x[0]` or `x[1:3]`
417    Index {
418        object: Box<Expr>,
419        index: Box<Expr>,
420        end: Option<Box<Expr>>,
421        span: Span,
422    },
423    /// List expression: `[1, 2, 3]`
424    List(Vec<Expr>, Span),
425    /// Dict expression: `{ key: value, ... }`
426    Dict(Vec<(String, Expr)>, Span),
427    /// Parenthesized: `(expr)`
428    Paren(Box<Expr>, Span),
429    /// Block operation — if
430    IfExpr {
431        cond: Box<Expr>,
432        then_branch: Box<Expr>,
433        else_branch: Option<Box<Expr>>,
434        span: Span,
435    },
436    /// Block operation — repeat
437    RepeatExpr {
438        count: Box<Expr>,
439        body: Box<Expr>,
440        span: Span,
441    },
442    /// Closure: `|a, b| { expr }`
443    Closure {
444        params: Vec<String>,
445        body: Box<Expr>,
446        span: Span,
447    },
448}
449
450impl Expr {
451    pub fn span(&self) -> Span {
452        match self {
453            Expr::Int(_, s)
454            | Expr::Float(_, s)
455            | Expr::Str(_, s)
456            | Expr::Bool(_, s)
457            | Expr::Null(s)
458            | Expr::Ident(_, s) => *s,
459            Expr::Binary { span, .. }
460            | Expr::Unary { span, .. }
461            | Expr::Call { span, .. }
462            | Expr::QualifiedCall { span, .. }
463            | Expr::Member { span, .. }
464            | Expr::Index { span, .. }
465            | Expr::List(_, span)
466            | Expr::Dict(_, span)
467            | Expr::Paren(_, span)
468            | Expr::IfExpr { span, .. }
469            | Expr::RepeatExpr { span, .. }
470            | Expr::Closure { span, .. } => *span,
471        }
472    }
473}
474
475/// Named or positional argument.
476#[derive(Debug, Clone)]
477pub struct Arg {
478    pub name: Option<String>,
479    pub value: Expr,
480    pub span: Span,
481}
482
/// Binary operators.
///
/// Derives `Hash` (alongside `Eq`) so operators can key dispatch tables or sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    Pow,
    Eq,
    Ne,
    Lt,
    Gt,
    Le,
    Ge,
    And,
    Or,
    NullCoalesce,
    BitAnd,
    BitOr,
    BitXor,
    Shl,
    Shr,
}

impl BinOp {
    /// `true` for the comparison operators `Eq`, `Ne`, `Lt`, `Gt`, `Le`, `Ge`.
    pub fn is_comparison(self) -> bool {
        matches!(
            self,
            BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Le | BinOp::Ge
        )
    }
}

/// Unary operators.
///
/// Derives `Hash`, consistent with `BinOp`, so operators can key maps or sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    Neg,
    Not,
    BitNot,
}

516// Literals
517
518#[derive(Debug, Clone)]
519pub enum Literal {
520    Int(i64, Span),
521    Float(f64, Span),
522    Str(String, Span),
523    Bool(bool, Span),
524    Null(Span),
525    List(Vec<Literal>, Span),
526    Dict(Vec<(String, Literal)>, Span),
527}