// shrew_ir/token.rs

// Token — All lexical tokens of the .sw language
//
// The .sw format uses @ as the directive prefix (like Rust attributes).
// Tokens fall into these categories:
//
//   1. Directives    — @model, @graph, @training, etc.
//   2. Keywords      — input, output, param, node, op, type, etc.
//                      (also hint values such as no_grad, and device names cpu/gpu/tpu)
//   3. Operators     — +, -, *, /, **, ==, !=, &&, ||, etc.
//   4. Punctuation   — { } ( ) [ ] : ; , . -> ? _ ::
//   5. Literals      — integers, floats, strings, booleans, null
//   6. DType names   — f32, f64, bf16, etc.
//   7. Identifiers   — user-defined names
//
// Each token carries a Span (byte offset, byte length, and 1-based line/column)
// for error reporting.
15
16use std::fmt;
17
/// Where a token sits in the source text.
///
/// The location is recorded twice: as a byte range (`offset` .. `offset + len`)
/// and as a human-readable line/column pair for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Span {
    /// Index of the first byte, counted from the start of the source.
    pub offset: usize,
    /// Number of bytes the span covers.
    pub len: usize,
    /// Line number, starting at 1.
    pub line: usize,
    /// Column number, starting at 1 and measured in bytes.
    pub col: usize,
}

impl Span {
    /// Build a span from a byte offset, a byte length, and a 1-based line/column.
    pub fn new(offset: usize, len: usize, line: usize, col: usize) -> Self {
        Span { offset, len, line, col }
    }
}

/// Renders as `line:col` (for example `3:14`); the byte range is not shown.
impl fmt::Display for Span {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Span { line, col, .. } = *self;
        write!(f, "{line}:{col}")
    }
}
47
48/// A token with its kind and source location.
49#[derive(Debug, Clone, PartialEq)]
50pub struct Token {
51    pub kind: TokenKind,
52    pub span: Span,
53}
54
55impl Token {
56    pub fn new(kind: TokenKind, span: Span) -> Self {
57        Self { kind, span }
58    }
59}
60
/// Every possible token kind in the .sw language.
///
/// Unit variants have a fixed spelling in source. The payload-carrying
/// variants (`IntLit`, `FloatLit`, `StringLit`, `Ident`) hold the lexed value.
#[derive(Debug, Clone, PartialEq)]
pub enum TokenKind {
    // Directives (@-prefixed blocks)
    AtModel,          // @model
    AtConfig,         // @config
    AtTypes,          // @types
    AtGraph,          // @graph
    AtCustomOp,       // @custom_op
    AtTraining,       // @training
    AtInference,      // @inference
    AtMetrics,        // @metrics
    AtLogging,        // @logging
    AtVisualizations, // @visualizations
    AtImport,         // @import
    AtAssert,         // @assert
    AtCheck,          // @check
    AtHint,           // @hint

    // Keywords
    Input,
    Output,
    Param,
    Node,
    Op,
    Call,
    If,
    Else,
    Repeat,
    Scan,
    For,
    In,
    Type,
    Init,
    Gradient,
    Impl,
    Signature,
    As,
    Track,
    Plot,
    Range,
    Model,
    Loss,
    Optimizer,
    LrSchedule, // lr_schedule
    GradClip,   // grad_clip
    Precision,
    AccumulationSteps, // accumulation_steps
    Optimizations,
    Quantization,
    Generation,
    Backend,
    Checkpoints,
    Frozen,
    Device,
    Source,
    Compute,
    Aggregate,
    LogEvery, // log_every

    // Hint values
    RecomputeInBackward, // recompute_in_backward
    MustPreserve,        // must_preserve
    InPlace,             // in_place
    NoGrad,              // no_grad

    // DType keywords
    F16,
    F32,
    F64,
    Bf16,
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    Bool,
    Complex64,
    Complex128,

    // Device keywords
    Cpu,
    Gpu,
    Tpu,

    // Tensor keyword
    Tensor,

    // Literals
    IntLit(i64),
    FloatLit(f64),
    StringLit(String),
    True,
    False,
    Null,

    // Operators
    Plus,             // +
    Minus,            // -
    Star,             // *
    Slash,            // /
    Percent,          // %
    StarStar,         // **
    EqEq,             // ==
    BangEq,           // !=
    Lt,               // <
    Gt,               // >
    LtEq,             // <=
    GtEq,             // >=
    AmpAmp,           // &&
    PipePipe,         // ||
    QuestionQuestion, // ??
    Amp,              // &
    Pipe,             // |
    Caret,            // ^
    LtLt,             // <<
    GtGt,             // >>
    Bang,             // !
    Tilde,            // ~
    Eq,               // =

    // Punctuation
    LBrace,     // {
    RBrace,     // }
    LParen,     // (
    RParen,     // )
    LBracket,   // [
    RBracket,   // ]
    Colon,      // :
    ColonColon, // ::
    Semi,       // ;
    Comma,      // ,
    Dot,        // .
    Arrow,      // ->
    Question,   // ?
    Underscore, // _

    // Identifiers
    Ident(String),

    // Special
    Eof,
}

/// Renders a token as it is spelled in .sw source, for diagnostics such as
/// "expected `->`, found `{`".
///
/// - Keywords, directives, dtype names, operators and punctuation print their
///   fixed spelling (`input`, `@model`, `f32`, `**`, `::`, ...). Previously
///   these fell back to the Rust variant name (`Input`, `Plus`, `Arrow`).
/// - Identifiers and integer literals print their value.
/// - Float literals use the `{:?}` form, so `1.0` prints as `1.0` rather than
///   `1`, which would look like an integer literal.
/// - String literals are wrapped in double quotes; the contents are not escaped.
/// - `Eof` prints as `<eof>`.
impl fmt::Display for TokenKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenKind::Ident(s) => f.write_str(s),
            TokenKind::IntLit(n) => write!(f, "{n}"),
            TokenKind::FloatLit(n) => write!(f, "{n:?}"),
            TokenKind::StringLit(s) => write!(f, "\"{s}\""),
            TokenKind::Eof => f.write_str("<eof>"),
            other => match other.keyword_str().or_else(|| other.symbol_str()) {
                Some(text) => f.write_str(text),
                // Safety net: `keyword_str` and `symbol_str` together cover every
                // unit variant. If a new variant is added without a spelling, fall
                // back to its Rust name instead of printing nothing.
                None => write!(f, "{other:?}"),
            },
        }
    }
}

impl TokenKind {
    /// Return the source spelling of a keyword token.
    ///
    /// "Keyword" here means the reserved words (`input`, `lr_schedule`, ...),
    /// the hint values (`no_grad`, `in_place`, ...) and the device names
    /// (`cpu`, `gpu`, `tpu`).
    ///
    /// Returns `None` for every other token. That includes dtype names,
    /// `Tensor`, and `true`/`false`/`null`, even though `keyword_lookup`
    /// recognises those spellings. It also includes operators, punctuation,
    /// directives, literals, identifiers and `Eof`.
    pub fn keyword_str(&self) -> Option<&'static str> {
        match self {
            TokenKind::Input => Some("input"),
            TokenKind::Output => Some("output"),
            TokenKind::Param => Some("param"),
            TokenKind::Node => Some("node"),
            TokenKind::Op => Some("op"),
            TokenKind::Call => Some("call"),
            TokenKind::If => Some("if"),
            TokenKind::Else => Some("else"),
            TokenKind::Repeat => Some("repeat"),
            TokenKind::Scan => Some("scan"),
            TokenKind::For => Some("for"),
            TokenKind::In => Some("in"),
            TokenKind::Type => Some("type"),
            TokenKind::Init => Some("init"),
            TokenKind::Gradient => Some("gradient"),
            TokenKind::Impl => Some("impl"),
            TokenKind::Signature => Some("signature"),
            TokenKind::As => Some("as"),
            TokenKind::Track => Some("track"),
            TokenKind::Plot => Some("plot"),
            TokenKind::Range => Some("range"),
            TokenKind::Model => Some("model"),
            TokenKind::Loss => Some("loss"),
            TokenKind::Optimizer => Some("optimizer"),
            TokenKind::LrSchedule => Some("lr_schedule"),
            TokenKind::GradClip => Some("grad_clip"),
            TokenKind::Precision => Some("precision"),
            TokenKind::AccumulationSteps => Some("accumulation_steps"),
            TokenKind::Optimizations => Some("optimizations"),
            TokenKind::Quantization => Some("quantization"),
            TokenKind::Generation => Some("generation"),
            TokenKind::Backend => Some("backend"),
            TokenKind::Checkpoints => Some("checkpoints"),
            TokenKind::Frozen => Some("frozen"),
            TokenKind::Device => Some("device"),
            TokenKind::Source => Some("source"),
            TokenKind::Compute => Some("compute"),
            TokenKind::Aggregate => Some("aggregate"),
            TokenKind::LogEvery => Some("log_every"),
            TokenKind::RecomputeInBackward => Some("recompute_in_backward"),
            TokenKind::MustPreserve => Some("must_preserve"),
            TokenKind::InPlace => Some("in_place"),
            TokenKind::NoGrad => Some("no_grad"),
            TokenKind::Cpu => Some("cpu"),
            TokenKind::Gpu => Some("gpu"),
            TokenKind::Tpu => Some("tpu"),
            _ => None,
        }
    }

    /// Return the fixed source spelling of a non-keyword token.
    ///
    /// Covers directives, dtype names, `Tensor`, `true`/`false`/`null`,
    /// operators and punctuation.
    ///
    /// Returns `None` for:
    /// - keywords, which are handled by [`TokenKind::keyword_str`];
    /// - payload-carrying tokens (identifiers and numeric/string literals);
    /// - `Eof`.
    fn symbol_str(&self) -> Option<&'static str> {
        let text = match self {
            // Directives
            TokenKind::AtModel => "@model",
            TokenKind::AtConfig => "@config",
            TokenKind::AtTypes => "@types",
            TokenKind::AtGraph => "@graph",
            TokenKind::AtCustomOp => "@custom_op",
            TokenKind::AtTraining => "@training",
            TokenKind::AtInference => "@inference",
            TokenKind::AtMetrics => "@metrics",
            TokenKind::AtLogging => "@logging",
            TokenKind::AtVisualizations => "@visualizations",
            TokenKind::AtImport => "@import",
            TokenKind::AtAssert => "@assert",
            TokenKind::AtCheck => "@check",
            TokenKind::AtHint => "@hint",

            // DType names
            TokenKind::F16 => "f16",
            TokenKind::F32 => "f32",
            TokenKind::F64 => "f64",
            TokenKind::Bf16 => "bf16",
            TokenKind::I8 => "i8",
            TokenKind::I16 => "i16",
            TokenKind::I32 => "i32",
            TokenKind::I64 => "i64",
            TokenKind::U8 => "u8",
            TokenKind::U16 => "u16",
            TokenKind::U32 => "u32",
            TokenKind::U64 => "u64",
            TokenKind::Bool => "bool",
            TokenKind::Complex64 => "complex64",
            TokenKind::Complex128 => "complex128",

            // Tensor keyword
            TokenKind::Tensor => "Tensor",

            // Boolean & null literals
            TokenKind::True => "true",
            TokenKind::False => "false",
            TokenKind::Null => "null",

            // Operators
            TokenKind::Plus => "+",
            TokenKind::Minus => "-",
            TokenKind::Star => "*",
            TokenKind::Slash => "/",
            TokenKind::Percent => "%",
            TokenKind::StarStar => "**",
            TokenKind::EqEq => "==",
            TokenKind::BangEq => "!=",
            TokenKind::Lt => "<",
            TokenKind::Gt => ">",
            TokenKind::LtEq => "<=",
            TokenKind::GtEq => ">=",
            TokenKind::AmpAmp => "&&",
            TokenKind::PipePipe => "||",
            TokenKind::QuestionQuestion => "??",
            TokenKind::Amp => "&",
            TokenKind::Pipe => "|",
            TokenKind::Caret => "^",
            TokenKind::LtLt => "<<",
            TokenKind::GtGt => ">>",
            TokenKind::Bang => "!",
            TokenKind::Tilde => "~",
            TokenKind::Eq => "=",

            // Punctuation
            TokenKind::LBrace => "{",
            TokenKind::RBrace => "}",
            TokenKind::LParen => "(",
            TokenKind::RParen => ")",
            TokenKind::LBracket => "[",
            TokenKind::RBracket => "]",
            TokenKind::Colon => ":",
            TokenKind::ColonColon => "::",
            TokenKind::Semi => ";",
            TokenKind::Comma => ",",
            TokenKind::Dot => ".",
            TokenKind::Arrow => "->",
            TokenKind::Question => "?",
            TokenKind::Underscore => "_",

            // Keywords (see `keyword_str`), payload tokens, and Eof.
            _ => return None,
        };
        Some(text)
    }
}
293
294/// Look up keyword/dtype/device from an identifier string.
295/// Returns None if the string is a plain identifier.
296pub fn keyword_lookup(s: &str) -> Option<TokenKind> {
297    match s {
298        // Keywords
299        "input" => Some(TokenKind::Input),
300        "output" => Some(TokenKind::Output),
301        "param" => Some(TokenKind::Param),
302        "node" => Some(TokenKind::Node),
303        "op" => Some(TokenKind::Op),
304        "call" => Some(TokenKind::Call),
305        "if" => Some(TokenKind::If),
306        "else" => Some(TokenKind::Else),
307        "repeat" => Some(TokenKind::Repeat),
308        "scan" => Some(TokenKind::Scan),
309        "for" => Some(TokenKind::For),
310        "in" => Some(TokenKind::In),
311        "type" => Some(TokenKind::Type),
312        "init" => Some(TokenKind::Init),
313        "gradient" => Some(TokenKind::Gradient),
314        "impl" => Some(TokenKind::Impl),
315        "signature" => Some(TokenKind::Signature),
316        "as" => Some(TokenKind::As),
317        "track" => Some(TokenKind::Track),
318        "plot" => Some(TokenKind::Plot),
319        "range" => Some(TokenKind::Range),
320        "model" => Some(TokenKind::Model),
321        "loss" => Some(TokenKind::Loss),
322        "optimizer" => Some(TokenKind::Optimizer),
323        "lr_schedule" => Some(TokenKind::LrSchedule),
324        "grad_clip" => Some(TokenKind::GradClip),
325        "precision" => Some(TokenKind::Precision),
326        "accumulation_steps" => Some(TokenKind::AccumulationSteps),
327        "optimizations" => Some(TokenKind::Optimizations),
328        "quantization" => Some(TokenKind::Quantization),
329        "generation" => Some(TokenKind::Generation),
330        "backend" => Some(TokenKind::Backend),
331        "checkpoints" => Some(TokenKind::Checkpoints),
332        "frozen" => Some(TokenKind::Frozen),
333        "device" => Some(TokenKind::Device),
334        "source" => Some(TokenKind::Source),
335        "compute" => Some(TokenKind::Compute),
336        "aggregate" => Some(TokenKind::Aggregate),
337        "log_every" => Some(TokenKind::LogEvery),
338
339        // Hint values
340        "recompute_in_backward" => Some(TokenKind::RecomputeInBackward),
341        "must_preserve" => Some(TokenKind::MustPreserve),
342        "in_place" => Some(TokenKind::InPlace),
343        "no_grad" => Some(TokenKind::NoGrad),
344
345        // DTypes
346        "f16" => Some(TokenKind::F16),
347        "f32" => Some(TokenKind::F32),
348        "f64" => Some(TokenKind::F64),
349        "bf16" => Some(TokenKind::Bf16),
350        "i8" => Some(TokenKind::I8),
351        "i16" => Some(TokenKind::I16),
352        "i32" => Some(TokenKind::I32),
353        "i64" => Some(TokenKind::I64),
354        "u8" => Some(TokenKind::U8),
355        "u16" => Some(TokenKind::U16),
356        "u32" => Some(TokenKind::U32),
357        "u64" => Some(TokenKind::U64),
358        "bool" => Some(TokenKind::Bool),
359        "complex64" => Some(TokenKind::Complex64),
360        "complex128" => Some(TokenKind::Complex128),
361
362        // Devices
363        "cpu" => Some(TokenKind::Cpu),
364        "gpu" => Some(TokenKind::Gpu),
365        "tpu" => Some(TokenKind::Tpu),
366
367        // Tensor keyword
368        "Tensor" => Some(TokenKind::Tensor),
369
370        // Boolean & null literals
371        "true" => Some(TokenKind::True),
372        "false" => Some(TokenKind::False),
373        "null" => Some(TokenKind::Null),
374
375        _ => None,
376    }
377}