1use crate::error::{Error, ErrorKind, Result};
22use crate::token::{keyword_lookup, Span, Token, TokenKind};
23
/// Hand-written lexer: scans DSL source text into a stream of [`Token`]s.
///
/// Scanning is byte-at-a-time over the UTF-8 source; `src` is kept alongside
/// `bytes` so lexeme text can be sliced back out as `&str`.
pub struct Lexer<'src> {
    // Original source text; used for slicing identifier/number/directive lexemes.
    src: &'src str,
    // Byte view of `src`; all scanning and lookahead works on this.
    bytes: &'src [u8],
    // Current byte offset into `bytes`.
    pos: usize,
    // 1-based line number of `pos` (incremented on each b'\n').
    line: usize,
    // 1-based column of `pos`. NOTE(review): counts bytes, not characters,
    // so columns drift on multi-byte UTF-8 input — confirm this is acceptable.
    col: usize,
}
32
33impl<'src> Lexer<'src> {
34 pub fn new(src: &'src str) -> Self {
35 Self {
36 src,
37 bytes: src.as_bytes(),
38 pos: 0,
39 line: 1,
40 col: 1,
41 }
42 }
43
44 pub fn tokenize(&mut self) -> Result<Vec<Token>> {
47 let mut tokens = Vec::new();
48 loop {
49 let tok = self.next_token()?;
50 let is_eof = tok.kind == TokenKind::Eof;
51 tokens.push(tok);
52 if is_eof {
53 break;
54 }
55 }
56 Ok(tokens)
57 }
58
59 fn next_token(&mut self) -> Result<Token> {
61 self.skip_whitespace_and_comments()?;
62
63 if self.pos >= self.bytes.len() {
64 return Ok(Token::new(TokenKind::Eof, self.span(0)));
65 }
66
67 let start_pos = self.pos;
68 let start_line = self.line;
69 let start_col = self.col;
70 let ch = self.bytes[self.pos] as char;
71
72 if ch == '@' {
74 self.advance();
75 let ident_start = self.pos;
76 while self.pos < self.bytes.len()
77 && (self.bytes[self.pos].is_ascii_alphanumeric() || self.bytes[self.pos] == b'_')
78 {
79 self.advance();
80 }
81 let name = &self.src[ident_start..self.pos];
82 let kind = match name {
83 "model" => TokenKind::AtModel,
84 "config" => TokenKind::AtConfig,
85 "types" => TokenKind::AtTypes,
86 "graph" => TokenKind::AtGraph,
87 "custom_op" => TokenKind::AtCustomOp,
88 "training" => TokenKind::AtTraining,
89 "inference" => TokenKind::AtInference,
90 "metrics" => TokenKind::AtMetrics,
91 "logging" => TokenKind::AtLogging,
92 "visualizations" => TokenKind::AtVisualizations,
93 "import" => TokenKind::AtImport,
94 "assert" => TokenKind::AtAssert,
95 "check" => TokenKind::AtCheck,
96 "hint" => TokenKind::AtHint,
97 _ => {
98 return Err(Error::new(
99 ErrorKind::UnknownDirective(name.to_string()),
100 Span::new(start_pos, self.pos - start_pos, start_line, start_col),
101 ))
102 }
103 };
104 return Ok(Token::new(
105 kind,
106 Span::new(start_pos, self.pos - start_pos, start_line, start_col),
107 ));
108 }
109
110 if ch == '"' {
112 return self.lex_string(start_pos, start_line, start_col);
113 }
114
115 if ch.is_ascii_digit() {
117 return self.lex_number(start_pos, start_line, start_col);
118 }
119
120 if ch.is_ascii_alphabetic() || ch == '_' {
122 return self.lex_ident(start_pos, start_line, start_col);
123 }
124
125 let kind = match ch {
127 '{' => {
128 self.advance();
129 TokenKind::LBrace
130 }
131 '}' => {
132 self.advance();
133 TokenKind::RBrace
134 }
135 '(' => {
136 self.advance();
137 TokenKind::LParen
138 }
139 ')' => {
140 self.advance();
141 TokenKind::RParen
142 }
143 '[' => {
144 self.advance();
145 TokenKind::LBracket
146 }
147 ']' => {
148 self.advance();
149 TokenKind::RBracket
150 }
151 ';' => {
152 self.advance();
153 TokenKind::Semi
154 }
155 ',' => {
156 self.advance();
157 TokenKind::Comma
158 }
159 '~' => {
160 self.advance();
161 TokenKind::Tilde
162 }
163 ':' => {
164 self.advance();
165 if self.peek() == Some(':') {
166 self.advance();
167 TokenKind::ColonColon
168 } else {
169 TokenKind::Colon
170 }
171 }
172 '.' => {
173 self.advance();
174 TokenKind::Dot
175 }
176 '+' => {
177 self.advance();
178 TokenKind::Plus
179 }
180 '%' => {
181 self.advance();
182 TokenKind::Percent
183 }
184 '^' => {
185 self.advance();
186 TokenKind::Caret
187 }
188 '-' => {
189 self.advance();
190 if self.peek() == Some('>') {
191 self.advance();
192 TokenKind::Arrow
193 } else {
194 TokenKind::Minus
195 }
196 }
197 '*' => {
198 self.advance();
199 if self.peek() == Some('*') {
200 self.advance();
201 TokenKind::StarStar
202 } else {
203 TokenKind::Star
204 }
205 }
206 '/' => {
207 self.advance();
208 TokenKind::Slash
209 }
210 '=' => {
211 self.advance();
212 if self.peek() == Some('=') {
213 self.advance();
214 TokenKind::EqEq
215 } else {
216 TokenKind::Eq
217 }
218 }
219 '!' => {
220 self.advance();
221 if self.peek() == Some('=') {
222 self.advance();
223 TokenKind::BangEq
224 } else {
225 TokenKind::Bang
226 }
227 }
228 '<' => {
229 self.advance();
230 if self.peek() == Some('=') {
231 self.advance();
232 TokenKind::LtEq
233 } else if self.peek() == Some('<') {
234 self.advance();
235 TokenKind::LtLt
236 } else {
237 TokenKind::Lt
238 }
239 }
240 '>' => {
241 self.advance();
242 if self.peek() == Some('=') {
243 self.advance();
244 TokenKind::GtEq
245 } else if self.peek() == Some('>') {
246 self.advance();
247 TokenKind::GtGt
248 } else {
249 TokenKind::Gt
250 }
251 }
252 '&' => {
253 self.advance();
254 if self.peek() == Some('&') {
255 self.advance();
256 TokenKind::AmpAmp
257 } else {
258 TokenKind::Amp
259 }
260 }
261 '|' => {
262 self.advance();
263 if self.peek() == Some('|') {
264 self.advance();
265 TokenKind::PipePipe
266 } else {
267 TokenKind::Pipe
268 }
269 }
270 '?' => {
271 self.advance();
272 if self.peek() == Some('?') {
273 self.advance();
274 TokenKind::QuestionQuestion
275 } else {
276 TokenKind::Question
277 }
278 }
279 _ => {
280 return Err(Error::new(
281 ErrorKind::UnexpectedChar(ch),
282 Span::new(start_pos, 1, start_line, start_col),
283 ));
284 }
285 };
286
287 Ok(Token::new(
288 kind,
289 Span::new(start_pos, self.pos - start_pos, start_line, start_col),
290 ))
291 }
292
293 fn advance(&mut self) {
296 if self.pos < self.bytes.len() {
297 if self.bytes[self.pos] == b'\n' {
298 self.line += 1;
299 self.col = 1;
300 } else {
301 self.col += 1;
302 }
303 self.pos += 1;
304 }
305 }
306
307 fn peek(&self) -> Option<char> {
308 if self.pos < self.bytes.len() {
309 Some(self.bytes[self.pos] as char)
310 } else {
311 None
312 }
313 }
314
315 fn span(&self, len: usize) -> Span {
316 Span::new(self.pos, len, self.line, self.col)
317 }
318
319 fn skip_whitespace_and_comments(&mut self) -> Result<()> {
321 loop {
322 while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_whitespace() {
324 self.advance();
325 }
326
327 if self.pos + 1 < self.bytes.len()
328 && self.bytes[self.pos] == b'/'
329 && self.bytes[self.pos + 1] == b'/'
330 {
331 while self.pos < self.bytes.len() && self.bytes[self.pos] != b'\n' {
333 self.advance();
334 }
335 continue;
336 }
337
338 if self.pos + 1 < self.bytes.len()
339 && self.bytes[self.pos] == b'/'
340 && self.bytes[self.pos + 1] == b'*'
341 {
342 let start_pos = self.pos;
344 let start_line = self.line;
345 let start_col = self.col;
346 self.advance(); self.advance(); loop {
349 if self.pos >= self.bytes.len() {
350 return Err(Error::new(
351 ErrorKind::UnterminatedComment,
352 Span::new(start_pos, 2, start_line, start_col),
353 ));
354 }
355 if self.bytes[self.pos] == b'*'
356 && self.pos + 1 < self.bytes.len()
357 && self.bytes[self.pos + 1] == b'/'
358 {
359 self.advance(); self.advance(); break;
362 }
363 self.advance();
364 }
365 continue;
366 }
367
368 break;
369 }
370 Ok(())
371 }
372
373 fn lex_string(
375 &mut self,
376 start_pos: usize,
377 start_line: usize,
378 start_col: usize,
379 ) -> Result<Token> {
380 self.advance(); let mut value = String::new();
382 loop {
383 if self.pos >= self.bytes.len() {
384 return Err(Error::new(
385 ErrorKind::UnterminatedString,
386 Span::new(start_pos, self.pos - start_pos, start_line, start_col),
387 ));
388 }
389 let ch = self.bytes[self.pos] as char;
390 if ch == '"' {
391 self.advance(); break;
393 }
394 if ch == '\\' {
395 self.advance();
396 if self.pos >= self.bytes.len() {
397 return Err(Error::new(
398 ErrorKind::UnterminatedString,
399 Span::new(start_pos, self.pos - start_pos, start_line, start_col),
400 ));
401 }
402 let esc = self.bytes[self.pos] as char;
403 match esc {
404 'n' => value.push('\n'),
405 't' => value.push('\t'),
406 '\\' => value.push('\\'),
407 '"' => value.push('"'),
408 _ => {
409 value.push('\\');
410 value.push(esc);
411 }
412 }
413 self.advance();
414 } else {
415 value.push(ch);
416 self.advance();
417 }
418 }
419 Ok(Token::new(
420 TokenKind::StringLit(value),
421 Span::new(start_pos, self.pos - start_pos, start_line, start_col),
422 ))
423 }
424
425 fn lex_number(
427 &mut self,
428 start_pos: usize,
429 start_line: usize,
430 start_col: usize,
431 ) -> Result<Token> {
432 let num_start = self.pos;
433 while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_digit() {
434 self.advance();
435 }
436
437 let mut is_float = false;
438
439 if self.pos < self.bytes.len()
441 && self.bytes[self.pos] == b'.'
442 && self.pos + 1 < self.bytes.len()
443 && self.bytes[self.pos + 1].is_ascii_digit()
444 {
445 is_float = true;
446 self.advance(); while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_digit() {
448 self.advance();
449 }
450 }
451
452 if self.pos < self.bytes.len()
454 && (self.bytes[self.pos] == b'e' || self.bytes[self.pos] == b'E')
455 {
456 is_float = true;
457 self.advance(); if self.pos < self.bytes.len()
459 && (self.bytes[self.pos] == b'+' || self.bytes[self.pos] == b'-')
460 {
461 self.advance();
462 }
463 if self.pos >= self.bytes.len() || !self.bytes[self.pos].is_ascii_digit() {
464 let raw = &self.src[num_start..self.pos];
465 return Err(Error::new(
466 ErrorKind::InvalidNumber(raw.to_string()),
467 Span::new(start_pos, self.pos - start_pos, start_line, start_col),
468 ));
469 }
470 while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_digit() {
471 self.advance();
472 }
473 }
474
475 let raw = &self.src[num_start..self.pos];
476 let span = Span::new(start_pos, self.pos - start_pos, start_line, start_col);
477
478 if is_float {
479 let val: f64 = raw
480 .parse()
481 .map_err(|_| Error::new(ErrorKind::InvalidNumber(raw.to_string()), span))?;
482 Ok(Token::new(TokenKind::FloatLit(val), span))
483 } else {
484 let val: i64 = raw
485 .parse()
486 .map_err(|_| Error::new(ErrorKind::InvalidNumber(raw.to_string()), span))?;
487 Ok(Token::new(TokenKind::IntLit(val), span))
488 }
489 }
490
491 fn lex_ident(
493 &mut self,
494 start_pos: usize,
495 start_line: usize,
496 start_col: usize,
497 ) -> Result<Token> {
498 let id_start = self.pos;
499 while self.pos < self.bytes.len()
500 && (self.bytes[self.pos].is_ascii_alphanumeric() || self.bytes[self.pos] == b'_')
501 {
502 self.advance();
503 }
504 let word = &self.src[id_start..self.pos];
505 let span = Span::new(start_pos, self.pos - start_pos, start_line, start_col);
506
507 if word == "_" {
509 return Ok(Token::new(TokenKind::Underscore, span));
510 }
511
512 let kind = keyword_lookup(word).unwrap_or_else(|| TokenKind::Ident(word.to_string()));
513 Ok(Token::new(kind, span))
514 }
515}
516
#[cfg(test)]
mod tests {
    //! Lexer unit tests: each case feeds a snippet through
    //! `Lexer::tokenize` and checks the exact token-kind sequence
    //! (every successful lex ends with `TokenKind::Eof`).
    use super::*;

    // Helper: lex `src` to completion and keep only the token kinds;
    // panics if the lexer reports an error.
    fn lex(src: &str) -> Vec<TokenKind> {
        Lexer::new(src)
            .tokenize()
            .unwrap()
            .into_iter()
            .map(|t| t.kind)
            .collect()
    }

    // `@`-directives lex to their dedicated token kinds.
    #[test]
    fn test_directive() {
        let kinds = lex("@model { }");
        assert_eq!(
            kinds,
            vec![
                TokenKind::AtModel,
                TokenKind::LBrace,
                TokenKind::RBrace,
                TokenKind::Eof,
            ]
        );
    }

    // Keywords resolve via `keyword_lookup`; unknown words become `Ident`.
    #[test]
    fn test_ident_and_keyword() {
        let kinds = lex("input foo");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Input,
                TokenKind::Ident("foo".into()),
                TokenKind::Eof,
            ]
        );
    }

    // Plain ints stay `IntLit`; a fraction or exponent makes a `FloatLit`.
    #[test]
    fn test_number_literals() {
        let kinds = lex("42 3.14 1e-4");
        assert_eq!(
            kinds,
            vec![
                TokenKind::IntLit(42),
                TokenKind::FloatLit(3.14),
                TokenKind::FloatLit(1e-4),
                TokenKind::Eof,
            ]
        );
    }

    // Double-quoted literal without escapes.
    #[test]
    fn test_string_literal() {
        let kinds = lex(r#""hello world""#);
        assert_eq!(
            kinds,
            vec![TokenKind::StringLit("hello world".into()), TokenKind::Eof,]
        );
    }

    // `\n` in the source text becomes a real newline in the value.
    #[test]
    fn test_string_escape() {
        let kinds = lex(r#""line\none""#);
        assert_eq!(
            kinds,
            vec![TokenKind::StringLit("line\none".into()), TokenKind::Eof,]
        );
    }

    // One- and two-character operators, including maximal-munch pairs
    // (`**`, `==`, `->`, `::`, `??`, `<<`, …).
    #[test]
    fn test_operators() {
        let kinds = lex("+ - * / ** == != <= >= && || -> :: ?? <<");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Plus,
                TokenKind::Minus,
                TokenKind::Star,
                TokenKind::Slash,
                TokenKind::StarStar,
                TokenKind::EqEq,
                TokenKind::BangEq,
                TokenKind::LtEq,
                TokenKind::GtEq,
                TokenKind::AmpAmp,
                TokenKind::PipePipe,
                TokenKind::Arrow,
                TokenKind::ColonColon,
                TokenKind::QuestionQuestion,
                TokenKind::LtLt,
                TokenKind::Eof,
            ]
        );
    }

    // Dtype names are keywords, not identifiers.
    #[test]
    fn test_dtype_keywords() {
        let kinds = lex("f32 f64 bf16 i64 u8 bool");
        assert_eq!(
            kinds,
            vec![
                TokenKind::F32,
                TokenKind::F64,
                TokenKind::Bf16,
                TokenKind::I64,
                TokenKind::U8,
                TokenKind::Bool,
                TokenKind::Eof,
            ]
        );
    }

    // `//` comments are skipped through to the end of the line.
    #[test]
    fn test_comment_skipping() {
        let kinds = lex("input // this is a comment\noutput");
        assert_eq!(
            kinds,
            vec![TokenKind::Input, TokenKind::Output, TokenKind::Eof,]
        );
    }

    // `/* ... */` comments are skipped entirely.
    #[test]
    fn test_block_comment() {
        let kinds = lex("input /* skip this */ output");
        assert_eq!(
            kinds,
            vec![TokenKind::Input, TokenKind::Output, TokenKind::Eof,]
        );
    }

    // A tensor type annotation lexes to the expected punctuation/keyword mix.
    #[test]
    fn test_tensor_type_tokens() {
        let kinds = lex("Tensor<[Batch, 768], f32>");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Tensor,
                TokenKind::Lt,
                TokenKind::LBracket,
                TokenKind::Ident("Batch".into()),
                TokenKind::Comma,
                TokenKind::IntLit(768),
                TokenKind::RBracket,
                TokenKind::Comma,
                TokenKind::F32,
                TokenKind::Gt,
                TokenKind::Eof,
            ]
        );
    }

    // End-to-end: a small `@model` block over several lines.
    #[test]
    fn test_full_model_block() {
        let src = r#"
        @model {
            name: "GPT-2";
            version: "1.0";
        }
        "#;
        let kinds = lex(src);
        assert_eq!(
            kinds,
            vec![
                TokenKind::AtModel,
                TokenKind::LBrace,
                TokenKind::Ident("name".into()),
                TokenKind::Colon,
                TokenKind::StringLit("GPT-2".into()),
                TokenKind::Semi,
                TokenKind::Ident("version".into()),
                TokenKind::Colon,
                TokenKind::StringLit("1.0".into()),
                TokenKind::Semi,
                TokenKind::RBrace,
                TokenKind::Eof,
            ]
        );
    }

    // The lexer does not fold a sign into the literal; `-3` is two tokens.
    #[test]
    fn test_negative_number_as_minus_int() {
        let kinds = lex("-3");
        assert_eq!(
            kinds,
            vec![TokenKind::Minus, TokenKind::IntLit(3), TokenKind::Eof,]
        );
    }

    // Lone `_` is `Underscore`; `?` vs `??` follow maximal munch.
    #[test]
    fn test_underscore_and_question() {
        let kinds = lex("_ ? ??");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Underscore,
                TokenKind::Question,
                TokenKind::QuestionQuestion,
                TokenKind::Eof,
            ]
        );
    }

    // Spans record 1-based line/column of each token's first byte.
    #[test]
    fn test_span_tracking() {
        let tokens = Lexer::new("ab cd").tokenize().unwrap();
        assert_eq!(tokens[0].span.line, 1);
        assert_eq!(tokens[0].span.col, 1);
        assert_eq!(tokens[1].span.line, 1);
        assert_eq!(tokens[1].span.col, 4);
    }

    // A newline advances the line counter and resets the column to 1.
    #[test]
    fn test_multiline_span() {
        let tokens = Lexer::new("ab\ncd").tokenize().unwrap();
        assert_eq!(tokens[0].span.line, 1);
        assert_eq!(tokens[1].span.line, 2);
        assert_eq!(tokens[1].span.col, 1);
    }
}