// shrew_ir/lexer.rs
//
// Lexer — Converts .sw source text into a stream of Tokens
//
// The lexer is a hand-written scanner (no regex, no generator). It processes
// the source one byte at a time, producing Token values.
//
// DESIGN DECISIONS:
//
//   1. We allow identifiers to contain underscores (e.g. `lr_schedule`),
//      and these map to keyword tokens via `keyword_lookup`.
//
//   2. @ followed by an identifier is lexed as a single directive token.
//
//   3. Numbers: we support integers and floats (with optional exponent).
//      We don't try to distinguish i64 vs u64 at lex time — that's the
//      parser's job. All ints are stored as i64, floats as f64.
//
//   4. Strings use double quotes and support basic escapes: \n \t \\ \"
//
//   5. Comments: // line comments and /* block comments */ (no nesting).
21use crate::error::{Error, ErrorKind, Result};
22use crate::token::{keyword_lookup, Span, Token, TokenKind};
23
/// Lexer state over a source string.
///
/// Keeps both the `&str` view (for slicing lexeme text out of the source)
/// and the raw byte view (for cheap byte-at-a-time scanning), plus the
/// current byte offset and the 1-based line/column used in reported Spans.
pub struct Lexer<'src> {
    src: &'src str,       // original source text; sliced for lexeme strings
    bytes: &'src [u8],    // byte view of `src` used by the scanning loops
    pos: usize,           // current byte offset into `bytes`
    line: usize,          // 1-based line number of `pos`
    col: usize,           // 1-based column number of `pos`
}
32
33impl<'src> Lexer<'src> {
34    pub fn new(src: &'src str) -> Self {
35        Self {
36            src,
37            bytes: src.as_bytes(),
38            pos: 0,
39            line: 1,
40            col: 1,
41        }
42    }
43
44    /// Tokenize the entire source, returning a Vec of Tokens.
45    /// The last token is always Eof.
46    pub fn tokenize(&mut self) -> Result<Vec<Token>> {
47        let mut tokens = Vec::new();
48        loop {
49            let tok = self.next_token()?;
50            let is_eof = tok.kind == TokenKind::Eof;
51            tokens.push(tok);
52            if is_eof {
53                break;
54            }
55        }
56        Ok(tokens)
57    }
58
    /// Read the next token.
    ///
    /// Dispatch order matters: whitespace/comments are skipped first, then we
    /// try directives (`@ident`), string literals, numbers, identifiers /
    /// keywords, and finally operators & punctuation (longest match first for
    /// the two-character operators). Returns an `Eof` token at end of input,
    /// and errors for unknown directives or unexpected characters.
    fn next_token(&mut self) -> Result<Token> {
        self.skip_whitespace_and_comments()?;

        // End of input: a zero-length Eof span at the current position.
        if self.pos >= self.bytes.len() {
            return Ok(Token::new(TokenKind::Eof, self.span(0)));
        }

        // Remember where this token starts so its Span covers the whole lexeme.
        let start_pos = self.pos;
        let start_line = self.line;
        let start_col = self.col;
        // NOTE(review): byte-to-char cast is only correct for ASCII; a
        // non-ASCII lead byte fails every class test below and is reported
        // as UnexpectedChar with a mojibake char — confirm that is intended.
        let ch = self.bytes[self.pos] as char;

        // --- Directive (@keyword) ---
        if ch == '@' {
            self.advance();
            let ident_start = self.pos;
            // Scan the directive name: [A-Za-z0-9_]* (may be empty, which
            // then fails the match below as UnknownDirective("")).
            while self.pos < self.bytes.len()
                && (self.bytes[self.pos].is_ascii_alphanumeric() || self.bytes[self.pos] == b'_')
            {
                self.advance();
            }
            let name = &self.src[ident_start..self.pos];
            // Directives form a closed set; anything else is an error.
            let kind = match name {
                "model" => TokenKind::AtModel,
                "config" => TokenKind::AtConfig,
                "types" => TokenKind::AtTypes,
                "graph" => TokenKind::AtGraph,
                "custom_op" => TokenKind::AtCustomOp,
                "training" => TokenKind::AtTraining,
                "inference" => TokenKind::AtInference,
                "metrics" => TokenKind::AtMetrics,
                "logging" => TokenKind::AtLogging,
                "visualizations" => TokenKind::AtVisualizations,
                "import" => TokenKind::AtImport,
                "assert" => TokenKind::AtAssert,
                "check" => TokenKind::AtCheck,
                "hint" => TokenKind::AtHint,
                _ => {
                    return Err(Error::new(
                        ErrorKind::UnknownDirective(name.to_string()),
                        Span::new(start_pos, self.pos - start_pos, start_line, start_col),
                    ))
                }
            };
            return Ok(Token::new(
                kind,
                Span::new(start_pos, self.pos - start_pos, start_line, start_col),
            ));
        }

        // --- String literal ---
        if ch == '"' {
            return self.lex_string(start_pos, start_line, start_col);
        }

        // --- Number literal ---
        // Note: a leading '-' is NOT part of the number; `-3` lexes as
        // Minus + IntLit(3) (see test_negative_number_as_minus_int).
        if ch.is_ascii_digit() {
            return self.lex_number(start_pos, start_line, start_col);
        }

        // --- Identifier or keyword ---
        if ch.is_ascii_alphabetic() || ch == '_' {
            return self.lex_ident(start_pos, start_line, start_col);
        }

        // --- Multi-char operators & punctuation ---
        // Each arm consumes its characters via advance(); two-char operators
        // peek ahead so e.g. `->` wins over `-` and `::` over `:`.
        let kind = match ch {
            '{' => {
                self.advance();
                TokenKind::LBrace
            }
            '}' => {
                self.advance();
                TokenKind::RBrace
            }
            '(' => {
                self.advance();
                TokenKind::LParen
            }
            ')' => {
                self.advance();
                TokenKind::RParen
            }
            '[' => {
                self.advance();
                TokenKind::LBracket
            }
            ']' => {
                self.advance();
                TokenKind::RBracket
            }
            ';' => {
                self.advance();
                TokenKind::Semi
            }
            ',' => {
                self.advance();
                TokenKind::Comma
            }
            '~' => {
                self.advance();
                TokenKind::Tilde
            }
            ':' => {
                self.advance();
                if self.peek() == Some(':') {
                    self.advance();
                    TokenKind::ColonColon
                } else {
                    TokenKind::Colon
                }
            }
            '.' => {
                self.advance();
                TokenKind::Dot
            }
            '+' => {
                self.advance();
                TokenKind::Plus
            }
            '%' => {
                self.advance();
                TokenKind::Percent
            }
            '^' => {
                self.advance();
                TokenKind::Caret
            }
            '-' => {
                self.advance();
                if self.peek() == Some('>') {
                    self.advance();
                    TokenKind::Arrow
                } else {
                    TokenKind::Minus
                }
            }
            '*' => {
                self.advance();
                if self.peek() == Some('*') {
                    self.advance();
                    TokenKind::StarStar
                } else {
                    TokenKind::Star
                }
            }
            // A bare '/' can only be Slash here: '//' and '/*' were already
            // consumed by skip_whitespace_and_comments().
            '/' => {
                self.advance();
                TokenKind::Slash
            }
            '=' => {
                self.advance();
                if self.peek() == Some('=') {
                    self.advance();
                    TokenKind::EqEq
                } else {
                    TokenKind::Eq
                }
            }
            '!' => {
                self.advance();
                if self.peek() == Some('=') {
                    self.advance();
                    TokenKind::BangEq
                } else {
                    TokenKind::Bang
                }
            }
            '<' => {
                self.advance();
                if self.peek() == Some('=') {
                    self.advance();
                    TokenKind::LtEq
                } else if self.peek() == Some('<') {
                    self.advance();
                    TokenKind::LtLt
                } else {
                    TokenKind::Lt
                }
            }
            '>' => {
                self.advance();
                if self.peek() == Some('=') {
                    self.advance();
                    TokenKind::GtEq
                } else if self.peek() == Some('>') {
                    self.advance();
                    TokenKind::GtGt
                } else {
                    TokenKind::Gt
                }
            }
            '&' => {
                self.advance();
                if self.peek() == Some('&') {
                    self.advance();
                    TokenKind::AmpAmp
                } else {
                    TokenKind::Amp
                }
            }
            '|' => {
                self.advance();
                if self.peek() == Some('|') {
                    self.advance();
                    TokenKind::PipePipe
                } else {
                    TokenKind::Pipe
                }
            }
            '?' => {
                self.advance();
                if self.peek() == Some('?') {
                    self.advance();
                    TokenKind::QuestionQuestion
                } else {
                    TokenKind::Question
                }
            }
            _ => {
                return Err(Error::new(
                    ErrorKind::UnexpectedChar(ch),
                    Span::new(start_pos, 1, start_line, start_col),
                ));
            }
        };

        Ok(Token::new(
            kind,
            Span::new(start_pos, self.pos - start_pos, start_line, start_col),
        ))
    }
292
293    // Helpers 
294
295    fn advance(&mut self) {
296        if self.pos < self.bytes.len() {
297            if self.bytes[self.pos] == b'\n' {
298                self.line += 1;
299                self.col = 1;
300            } else {
301                self.col += 1;
302            }
303            self.pos += 1;
304        }
305    }
306
307    fn peek(&self) -> Option<char> {
308        if self.pos < self.bytes.len() {
309            Some(self.bytes[self.pos] as char)
310        } else {
311            None
312        }
313    }
314
    /// Build a Span of `len` bytes anchored at the lexer's current
    /// position/line/column.
    fn span(&self, len: usize) -> Span {
        Span::new(self.pos, len, self.line, self.col)
    }
318
319    /// Skip whitespace, single-line comments (//), and block comments (/* */).
320    fn skip_whitespace_and_comments(&mut self) -> Result<()> {
321        loop {
322            // Skip whitespace
323            while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_whitespace() {
324                self.advance();
325            }
326
327            if self.pos + 1 < self.bytes.len()
328                && self.bytes[self.pos] == b'/'
329                && self.bytes[self.pos + 1] == b'/'
330            {
331                // Single-line comment: skip to end of line
332                while self.pos < self.bytes.len() && self.bytes[self.pos] != b'\n' {
333                    self.advance();
334                }
335                continue;
336            }
337
338            if self.pos + 1 < self.bytes.len()
339                && self.bytes[self.pos] == b'/'
340                && self.bytes[self.pos + 1] == b'*'
341            {
342                // Block comment: skip to */
343                let start_pos = self.pos;
344                let start_line = self.line;
345                let start_col = self.col;
346                self.advance(); // /
347                self.advance(); // *
348                loop {
349                    if self.pos >= self.bytes.len() {
350                        return Err(Error::new(
351                            ErrorKind::UnterminatedComment,
352                            Span::new(start_pos, 2, start_line, start_col),
353                        ));
354                    }
355                    if self.bytes[self.pos] == b'*'
356                        && self.pos + 1 < self.bytes.len()
357                        && self.bytes[self.pos + 1] == b'/'
358                    {
359                        self.advance(); // *
360                        self.advance(); // /
361                        break;
362                    }
363                    self.advance();
364                }
365                continue;
366            }
367
368            break;
369        }
370        Ok(())
371    }
372
373    /// Lex a string literal (starting after the opening `"`).
374    fn lex_string(
375        &mut self,
376        start_pos: usize,
377        start_line: usize,
378        start_col: usize,
379    ) -> Result<Token> {
380        self.advance(); // skip opening "
381        let mut value = String::new();
382        loop {
383            if self.pos >= self.bytes.len() {
384                return Err(Error::new(
385                    ErrorKind::UnterminatedString,
386                    Span::new(start_pos, self.pos - start_pos, start_line, start_col),
387                ));
388            }
389            let ch = self.bytes[self.pos] as char;
390            if ch == '"' {
391                self.advance(); // skip closing "
392                break;
393            }
394            if ch == '\\' {
395                self.advance();
396                if self.pos >= self.bytes.len() {
397                    return Err(Error::new(
398                        ErrorKind::UnterminatedString,
399                        Span::new(start_pos, self.pos - start_pos, start_line, start_col),
400                    ));
401                }
402                let esc = self.bytes[self.pos] as char;
403                match esc {
404                    'n' => value.push('\n'),
405                    't' => value.push('\t'),
406                    '\\' => value.push('\\'),
407                    '"' => value.push('"'),
408                    _ => {
409                        value.push('\\');
410                        value.push(esc);
411                    }
412                }
413                self.advance();
414            } else {
415                value.push(ch);
416                self.advance();
417            }
418        }
419        Ok(Token::new(
420            TokenKind::StringLit(value),
421            Span::new(start_pos, self.pos - start_pos, start_line, start_col),
422        ))
423    }
424
425    /// Lex a number: integer or float.
426    fn lex_number(
427        &mut self,
428        start_pos: usize,
429        start_line: usize,
430        start_col: usize,
431    ) -> Result<Token> {
432        let num_start = self.pos;
433        while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_digit() {
434            self.advance();
435        }
436
437        let mut is_float = false;
438
439        // Check for decimal point
440        if self.pos < self.bytes.len()
441            && self.bytes[self.pos] == b'.'
442            && self.pos + 1 < self.bytes.len()
443            && self.bytes[self.pos + 1].is_ascii_digit()
444        {
445            is_float = true;
446            self.advance(); // skip .
447            while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_digit() {
448                self.advance();
449            }
450        }
451
452        // Check for exponent
453        if self.pos < self.bytes.len()
454            && (self.bytes[self.pos] == b'e' || self.bytes[self.pos] == b'E')
455        {
456            is_float = true;
457            self.advance(); // skip e/E
458            if self.pos < self.bytes.len()
459                && (self.bytes[self.pos] == b'+' || self.bytes[self.pos] == b'-')
460            {
461                self.advance();
462            }
463            if self.pos >= self.bytes.len() || !self.bytes[self.pos].is_ascii_digit() {
464                let raw = &self.src[num_start..self.pos];
465                return Err(Error::new(
466                    ErrorKind::InvalidNumber(raw.to_string()),
467                    Span::new(start_pos, self.pos - start_pos, start_line, start_col),
468                ));
469            }
470            while self.pos < self.bytes.len() && self.bytes[self.pos].is_ascii_digit() {
471                self.advance();
472            }
473        }
474
475        let raw = &self.src[num_start..self.pos];
476        let span = Span::new(start_pos, self.pos - start_pos, start_line, start_col);
477
478        if is_float {
479            let val: f64 = raw
480                .parse()
481                .map_err(|_| Error::new(ErrorKind::InvalidNumber(raw.to_string()), span))?;
482            Ok(Token::new(TokenKind::FloatLit(val), span))
483        } else {
484            let val: i64 = raw
485                .parse()
486                .map_err(|_| Error::new(ErrorKind::InvalidNumber(raw.to_string()), span))?;
487            Ok(Token::new(TokenKind::IntLit(val), span))
488        }
489    }
490
491    /// Lex an identifier or keyword.
492    fn lex_ident(
493        &mut self,
494        start_pos: usize,
495        start_line: usize,
496        start_col: usize,
497    ) -> Result<Token> {
498        let id_start = self.pos;
499        while self.pos < self.bytes.len()
500            && (self.bytes[self.pos].is_ascii_alphanumeric() || self.bytes[self.pos] == b'_')
501        {
502            self.advance();
503        }
504        let word = &self.src[id_start..self.pos];
505        let span = Span::new(start_pos, self.pos - start_pos, start_line, start_col);
506
507        // Check if it's _ alone (Underscore token for inferred dims)
508        if word == "_" {
509            return Ok(Token::new(TokenKind::Underscore, span));
510        }
511
512        let kind = keyword_lookup(word).unwrap_or_else(|| TokenKind::Ident(word.to_string()));
513        Ok(Token::new(kind, span))
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience: lex `src` and strip spans, leaving just the TokenKinds.
    // Panics (unwrap) on lexer errors, which is fine inside tests.
    fn lex(src: &str) -> Vec<TokenKind> {
        Lexer::new(src)
            .tokenize()
            .unwrap()
            .into_iter()
            .map(|t| t.kind)
            .collect()
    }

    // `@model` lexes as a single directive token, not `@` + ident.
    #[test]
    fn test_directive() {
        let kinds = lex("@model { }");
        assert_eq!(
            kinds,
            vec![
                TokenKind::AtModel,
                TokenKind::LBrace,
                TokenKind::RBrace,
                TokenKind::Eof,
            ]
        );
    }

    // Keywords (via keyword_lookup) vs plain identifiers.
    #[test]
    fn test_ident_and_keyword() {
        let kinds = lex("input foo");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Input,
                TokenKind::Ident("foo".into()),
                TokenKind::Eof,
            ]
        );
    }

    // Integers, decimal floats, and exponent-form floats.
    #[test]
    fn test_number_literals() {
        let kinds = lex("42 3.14 1e-4");
        assert_eq!(
            kinds,
            vec![
                TokenKind::IntLit(42),
                TokenKind::FloatLit(3.14),
                TokenKind::FloatLit(1e-4),
                TokenKind::Eof,
            ]
        );
    }

    // Basic double-quoted string with no escapes.
    #[test]
    fn test_string_literal() {
        let kinds = lex(r#""hello world""#);
        assert_eq!(
            kinds,
            vec![TokenKind::StringLit("hello world".into()), TokenKind::Eof,]
        );
    }

    // `\n` inside a string literal becomes a real newline in the value.
    #[test]
    fn test_string_escape() {
        let kinds = lex(r#""line\none""#);
        assert_eq!(
            kinds,
            vec![TokenKind::StringLit("line\none".into()), TokenKind::Eof,]
        );
    }

    // Two-character operators win over their one-character prefixes.
    #[test]
    fn test_operators() {
        let kinds = lex("+ - * / ** == != <= >= && || -> :: ?? <<");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Plus,
                TokenKind::Minus,
                TokenKind::Star,
                TokenKind::Slash,
                TokenKind::StarStar,
                TokenKind::EqEq,
                TokenKind::BangEq,
                TokenKind::LtEq,
                TokenKind::GtEq,
                TokenKind::AmpAmp,
                TokenKind::PipePipe,
                TokenKind::Arrow,
                TokenKind::ColonColon,
                TokenKind::QuestionQuestion,
                TokenKind::LtLt,
                TokenKind::Eof,
            ]
        );
    }

    // Dtype names are keywords, not identifiers.
    #[test]
    fn test_dtype_keywords() {
        let kinds = lex("f32 f64 bf16 i64 u8 bool");
        assert_eq!(
            kinds,
            vec![
                TokenKind::F32,
                TokenKind::F64,
                TokenKind::Bf16,
                TokenKind::I64,
                TokenKind::U8,
                TokenKind::Bool,
                TokenKind::Eof,
            ]
        );
    }

    // `//` comments run to end of line and produce no tokens.
    #[test]
    fn test_comment_skipping() {
        let kinds = lex("input // this is a comment\noutput");
        assert_eq!(
            kinds,
            vec![TokenKind::Input, TokenKind::Output, TokenKind::Eof,]
        );
    }

    // `/* ... */` comments are skipped entirely.
    #[test]
    fn test_block_comment() {
        let kinds = lex("input /* skip this */ output");
        assert_eq!(
            kinds,
            vec![TokenKind::Input, TokenKind::Output, TokenKind::Eof,]
        );
    }

    // A tensor type annotation lexes into its component tokens; note `<`/`>`
    // are plain Lt/Gt here — bracket pairing is the parser's job.
    #[test]
    fn test_tensor_type_tokens() {
        let kinds = lex("Tensor<[Batch, 768], f32>");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Tensor,
                TokenKind::Lt,
                TokenKind::LBracket,
                TokenKind::Ident("Batch".into()),
                TokenKind::Comma,
                TokenKind::IntLit(768),
                TokenKind::RBracket,
                TokenKind::Comma,
                TokenKind::F32,
                TokenKind::Gt,
                TokenKind::Eof,
            ]
        );
    }

    // End-to-end: a realistic @model block with strings and punctuation.
    // Note "1.0" stays a string — quoting protects it from number lexing.
    #[test]
    fn test_full_model_block() {
        let src = r#"
            @model {
                name: "GPT-2";
                version: "1.0";
            }
        "#;
        let kinds = lex(src);
        assert_eq!(
            kinds,
            vec![
                TokenKind::AtModel,
                TokenKind::LBrace,
                TokenKind::Ident("name".into()),
                TokenKind::Colon,
                TokenKind::StringLit("GPT-2".into()),
                TokenKind::Semi,
                TokenKind::Ident("version".into()),
                TokenKind::Colon,
                TokenKind::StringLit("1.0".into()),
                TokenKind::Semi,
                TokenKind::RBrace,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_negative_number_as_minus_int() {
        // -3 is lexed as Minus + IntLit(3), NOT IntLit(-3)
        let kinds = lex("-3");
        assert_eq!(
            kinds,
            vec![TokenKind::Minus, TokenKind::IntLit(3), TokenKind::Eof,]
        );
    }

    // Bare `_` is its own token (inferred dim); `?`/`??` disambiguate.
    #[test]
    fn test_underscore_and_question() {
        let kinds = lex("_ ? ??");
        assert_eq!(
            kinds,
            vec![
                TokenKind::Underscore,
                TokenKind::Question,
                TokenKind::QuestionQuestion,
                TokenKind::Eof,
            ]
        );
    }

    // Columns are 1-based: in "ab cd", `cd` starts at column 4.
    #[test]
    fn test_span_tracking() {
        let tokens = Lexer::new("ab cd").tokenize().unwrap();
        assert_eq!(tokens[0].span.line, 1);
        assert_eq!(tokens[0].span.col, 1);
        assert_eq!(tokens[1].span.line, 1);
        assert_eq!(tokens[1].span.col, 4);
    }

    // A newline bumps the line counter and resets the column to 1.
    #[test]
    fn test_multiline_span() {
        let tokens = Lexer::new("ab\ncd").tokenize().unwrap();
        assert_eq!(tokens[0].span.line, 1);
        assert_eq!(tokens[1].span.line, 2);
        assert_eq!(tokens[1].span.col, 1);
    }
}