1use std::fmt;
17
/// A region of source text: where it starts, how long it is, and the
/// line/column at which it begins.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Span {
    /// Start position within the source (presumably a byte offset — confirm against the lexer).
    pub offset: usize,
    /// Length of the region, presumably in the same units as `offset`.
    pub len: usize,
    /// Line of the first character (0- or 1-based is not established here — TODO confirm).
    pub line: usize,
    /// Column of the first character (same base caveat as `line`).
    pub col: usize,
}
30
31impl Span {
32 pub fn new(offset: usize, len: usize, line: usize, col: usize) -> Self {
33 Self {
34 offset,
35 len,
36 line,
37 col,
38 }
39 }
40}
41
42impl fmt::Display for Span {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 write!(f, "{}:{}", self.line, self.col)
45 }
46}
47
48#[derive(Debug, Clone, PartialEq)]
50pub struct Token {
51 pub kind: TokenKind,
52 pub span: Span,
53}
54
55impl Token {
56 pub fn new(kind: TokenKind, span: Span) -> Self {
57 Self { kind, span }
58 }
59}
60
/// Every kind of token, together with the payload carried by literals and
/// identifiers.
///
/// `PartialEq` but not `Eq`, because `FloatLit` holds an `f64`.
#[derive(Clone, Debug, PartialEq)]
pub enum TokenKind {
    // `@`-prefixed directives (`@model`, `@config`, ..., `@hint`).
    AtModel,
    AtConfig,
    AtTypes,
    AtGraph,
    AtCustomOp,
    AtTraining,
    AtInference,
    AtMetrics,
    AtLogging,
    AtVisualizations,
    AtImport,
    AtAssert,
    AtCheck,
    AtHint,

    // Reserved words, spelled in lowercase snake_case in source (e.g. `lr_schedule`).
    Input,
    Output,
    Param,
    Node,
    Op,
    Call,
    If,
    Else,
    Repeat,
    Scan,
    For,
    In,
    Type,
    Init,
    Gradient,
    Impl,
    Signature,
    As,
    Track,
    Plot,
    Range,
    Model,
    Loss,
    Optimizer,
    LrSchedule,
    GradClip,
    Precision,
    AccumulationSteps,
    Optimizations,
    Quantization,
    Generation,
    Backend,
    Checkpoints,
    Frozen,
    Device,
    Source,
    Compute,
    Aggregate,
    LogEvery,
    RecomputeInBackward,
    MustPreserve,
    InPlace,
    NoGrad,

    // Element type names (`f16`, `bf16`, `i32`, `u8`, `bool`, `complex64`, ...).
    F16,
    F32,
    F64,
    Bf16,
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    Bool,
    Complex64,
    Complex128,

    // Device names (`cpu`, `gpu`, `tpu`).
    Cpu,
    Gpu,
    Tpu,

    /// The `Tensor` keyword (capitalized in source).
    Tensor,

    /// Integer literal.
    IntLit(i64),
    /// Floating-point literal.
    FloatLit(f64),
    /// String literal; the payload excludes the surrounding quotes
    /// (presumably already unescaped — confirm against the lexer).
    StringLit(String),
    /// `true`
    True,
    /// `false`
    False,
    /// `null`
    Null,

    // Operators and punctuation.
    Plus,
    Minus,
    Star,
    Slash,
    Percent,
    StarStar,
    EqEq,
    BangEq,
    Lt,
    Gt,
    LtEq,
    GtEq,
    AmpAmp,
    PipePipe,
    QuestionQuestion,
    Amp,
    Pipe,
    Caret,
    LtLt,
    GtGt,
    Bang,
    Tilde,
    Eq,
    LBrace,
    RBrace,
    LParen,
    RParen,
    LBracket,
    RBracket,
    Colon,
    ColonColon,
    Semi,
    Comma,
    Dot,
    Arrow,
    Question,
    Underscore,

    /// A name that is not reserved (presumably any word `keyword_lookup` rejects).
    Ident(String),

    /// End of input.
    Eof,
}
207
208impl fmt::Display for TokenKind {
209 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210 match self {
211 TokenKind::AtModel => write!(f, "@model"),
212 TokenKind::AtConfig => write!(f, "@config"),
213 TokenKind::AtTypes => write!(f, "@types"),
214 TokenKind::AtGraph => write!(f, "@graph"),
215 TokenKind::AtCustomOp => write!(f, "@custom_op"),
216 TokenKind::AtTraining => write!(f, "@training"),
217 TokenKind::AtInference => write!(f, "@inference"),
218 TokenKind::AtMetrics => write!(f, "@metrics"),
219 TokenKind::AtLogging => write!(f, "@logging"),
220 TokenKind::AtVisualizations => write!(f, "@visualizations"),
221 TokenKind::AtImport => write!(f, "@import"),
222 TokenKind::AtAssert => write!(f, "@assert"),
223 TokenKind::AtCheck => write!(f, "@check"),
224 TokenKind::AtHint => write!(f, "@hint"),
225 TokenKind::Ident(s) => write!(f, "{s}"),
226 TokenKind::IntLit(n) => write!(f, "{n}"),
227 TokenKind::FloatLit(n) => write!(f, "{n}"),
228 TokenKind::StringLit(s) => write!(f, "\"{s}\""),
229 TokenKind::True => write!(f, "true"),
230 TokenKind::False => write!(f, "false"),
231 TokenKind::Null => write!(f, "null"),
232 TokenKind::Eof => write!(f, "<eof>"),
233 other => write!(f, "{other:?}"),
234 }
235 }
236}
237
238impl TokenKind {
239 pub fn keyword_str(&self) -> Option<&'static str> {
242 match self {
243 TokenKind::Input => Some("input"),
244 TokenKind::Output => Some("output"),
245 TokenKind::Param => Some("param"),
246 TokenKind::Node => Some("node"),
247 TokenKind::Op => Some("op"),
248 TokenKind::Call => Some("call"),
249 TokenKind::If => Some("if"),
250 TokenKind::Else => Some("else"),
251 TokenKind::Repeat => Some("repeat"),
252 TokenKind::Scan => Some("scan"),
253 TokenKind::For => Some("for"),
254 TokenKind::In => Some("in"),
255 TokenKind::Type => Some("type"),
256 TokenKind::Init => Some("init"),
257 TokenKind::Gradient => Some("gradient"),
258 TokenKind::Impl => Some("impl"),
259 TokenKind::Signature => Some("signature"),
260 TokenKind::As => Some("as"),
261 TokenKind::Track => Some("track"),
262 TokenKind::Plot => Some("plot"),
263 TokenKind::Range => Some("range"),
264 TokenKind::Model => Some("model"),
265 TokenKind::Loss => Some("loss"),
266 TokenKind::Optimizer => Some("optimizer"),
267 TokenKind::LrSchedule => Some("lr_schedule"),
268 TokenKind::GradClip => Some("grad_clip"),
269 TokenKind::Precision => Some("precision"),
270 TokenKind::AccumulationSteps => Some("accumulation_steps"),
271 TokenKind::Optimizations => Some("optimizations"),
272 TokenKind::Quantization => Some("quantization"),
273 TokenKind::Generation => Some("generation"),
274 TokenKind::Backend => Some("backend"),
275 TokenKind::Checkpoints => Some("checkpoints"),
276 TokenKind::Frozen => Some("frozen"),
277 TokenKind::Device => Some("device"),
278 TokenKind::Source => Some("source"),
279 TokenKind::Compute => Some("compute"),
280 TokenKind::Aggregate => Some("aggregate"),
281 TokenKind::LogEvery => Some("log_every"),
282 TokenKind::RecomputeInBackward => Some("recompute_in_backward"),
283 TokenKind::MustPreserve => Some("must_preserve"),
284 TokenKind::InPlace => Some("in_place"),
285 TokenKind::NoGrad => Some("no_grad"),
286 TokenKind::Cpu => Some("cpu"),
287 TokenKind::Gpu => Some("gpu"),
288 TokenKind::Tpu => Some("tpu"),
289 _ => None,
290 }
291 }
292}
293
294pub fn keyword_lookup(s: &str) -> Option<TokenKind> {
297 match s {
298 "input" => Some(TokenKind::Input),
300 "output" => Some(TokenKind::Output),
301 "param" => Some(TokenKind::Param),
302 "node" => Some(TokenKind::Node),
303 "op" => Some(TokenKind::Op),
304 "call" => Some(TokenKind::Call),
305 "if" => Some(TokenKind::If),
306 "else" => Some(TokenKind::Else),
307 "repeat" => Some(TokenKind::Repeat),
308 "scan" => Some(TokenKind::Scan),
309 "for" => Some(TokenKind::For),
310 "in" => Some(TokenKind::In),
311 "type" => Some(TokenKind::Type),
312 "init" => Some(TokenKind::Init),
313 "gradient" => Some(TokenKind::Gradient),
314 "impl" => Some(TokenKind::Impl),
315 "signature" => Some(TokenKind::Signature),
316 "as" => Some(TokenKind::As),
317 "track" => Some(TokenKind::Track),
318 "plot" => Some(TokenKind::Plot),
319 "range" => Some(TokenKind::Range),
320 "model" => Some(TokenKind::Model),
321 "loss" => Some(TokenKind::Loss),
322 "optimizer" => Some(TokenKind::Optimizer),
323 "lr_schedule" => Some(TokenKind::LrSchedule),
324 "grad_clip" => Some(TokenKind::GradClip),
325 "precision" => Some(TokenKind::Precision),
326 "accumulation_steps" => Some(TokenKind::AccumulationSteps),
327 "optimizations" => Some(TokenKind::Optimizations),
328 "quantization" => Some(TokenKind::Quantization),
329 "generation" => Some(TokenKind::Generation),
330 "backend" => Some(TokenKind::Backend),
331 "checkpoints" => Some(TokenKind::Checkpoints),
332 "frozen" => Some(TokenKind::Frozen),
333 "device" => Some(TokenKind::Device),
334 "source" => Some(TokenKind::Source),
335 "compute" => Some(TokenKind::Compute),
336 "aggregate" => Some(TokenKind::Aggregate),
337 "log_every" => Some(TokenKind::LogEvery),
338
339 "recompute_in_backward" => Some(TokenKind::RecomputeInBackward),
341 "must_preserve" => Some(TokenKind::MustPreserve),
342 "in_place" => Some(TokenKind::InPlace),
343 "no_grad" => Some(TokenKind::NoGrad),
344
345 "f16" => Some(TokenKind::F16),
347 "f32" => Some(TokenKind::F32),
348 "f64" => Some(TokenKind::F64),
349 "bf16" => Some(TokenKind::Bf16),
350 "i8" => Some(TokenKind::I8),
351 "i16" => Some(TokenKind::I16),
352 "i32" => Some(TokenKind::I32),
353 "i64" => Some(TokenKind::I64),
354 "u8" => Some(TokenKind::U8),
355 "u16" => Some(TokenKind::U16),
356 "u32" => Some(TokenKind::U32),
357 "u64" => Some(TokenKind::U64),
358 "bool" => Some(TokenKind::Bool),
359 "complex64" => Some(TokenKind::Complex64),
360 "complex128" => Some(TokenKind::Complex128),
361
362 "cpu" => Some(TokenKind::Cpu),
364 "gpu" => Some(TokenKind::Gpu),
365 "tpu" => Some(TokenKind::Tpu),
366
367 "Tensor" => Some(TokenKind::Tensor),
369
370 "true" => Some(TokenKind::True),
372 "false" => Some(TokenKind::False),
373 "null" => Some(TokenKind::Null),
374
375 _ => None,
376 }
377}