shrew/
main.rs

// shrew CLI — Command-line runner for .sw deep learning programs
//
// USAGE:
//   shrew dump model.sw         # Print the lowered IR graph
//   shrew validate model.sw     # Validate a .sw program
//   shrew bench model.sw        # Benchmark forward pass
//   shrew info model.sw         # Show model summary (params, ops, shapes)
//
// OPTIONS:
//   --batch N        Set batch dimension (default: 1)
//   --dtype <type>   Set default dtype: f32|f64|f16|bf16|u8|u32|i64 (default: f32)
//   --verbose, -v    Print detailed execution info
13
14use std::env;
15use std::fs;
16use std::process;
17use std::time::Instant;
18
19fn main() {
20    let args: Vec<String> = env::args().collect();
21
22    if args.len() < 2 {
23        print_usage();
24        process::exit(1);
25    }
26
27    let command = args[1].as_str();
28
29    match command {
30        "dump" | "validate" | "bench" | "info" => {
31            if args.len() < 3 {
32                eprintln!("Error: missing .sw file path");
33                eprintln!("Usage: shrew {command} <file.sw> [options]");
34                process::exit(1);
35            }
36            let file_path = &args[2];
37            let opts = parse_options(&args[3..]);
38
39            match run_command(command, file_path, &opts) {
40                Ok(()) => {}
41                Err(e) => {
42                    eprintln!("Error: {e}");
43                    process::exit(1);
44                }
45            }
46        }
47        "--help" | "-h" | "help" => {
48            print_usage();
49        }
50        "--version" | "-V" | "version" => {
51            println!("shrew {}", env!("CARGO_PKG_VERSION"));
52        }
53        other => {
54            eprintln!("Unknown command: {other}");
55            print_usage();
56            process::exit(1);
57        }
58    }
59}
60
61// Options
62
/// Options that may follow `<file.sw>` on the command line.
#[derive(Debug, Clone)]
struct CliOptions {
    /// Size bound to the `batch` / `Batch` dimensions by `bench`; always >= 1.
    batch_size: usize,
    /// Default dtype name; resolved by `parse_dtype` in `bench`.
    dtype: String,
    /// Extra diagnostics (optimization stats in `dump`).
    verbose: bool,
}

impl Default for CliOptions {
    /// Batch size 1, dtype `f32`, not verbose.
    fn default() -> Self {
        CliOptions {
            batch_size: 1,
            dtype: "f32".to_string(),
            verbose: false,
        }
    }
}

/// Parse the options that follow `<file.sw>`.
///
/// Recognized: `--batch N` (positive integer), `--dtype <name>`, and
/// `--verbose` / `-v`. Parsing never fails: unknown options, missing values
/// and invalid batch sizes are reported as warnings on stderr, and the
/// affected setting keeps its current value (initially the default).
///
/// A value that starts with `-` is not consumed as an option argument, so
/// `--batch --verbose` reports the missing batch value and still enables
/// verbose mode.
fn parse_options(args: &[String]) -> CliOptions {
    let mut opts = CliOptions::default();
    let mut rest = args.iter().peekable();

    while let Some(arg) = rest.next() {
        match arg.as_str() {
            "--batch" => match rest.next_if(|v| !v.starts_with('-')) {
                Some(value) => match value.parse::<usize>() {
                    // Zero would create empty tensors in `bench`; reject it.
                    Ok(n) if n > 0 => opts.batch_size = n,
                    _ => eprintln!(
                        "Warning: invalid --batch value '{value}' (expected a positive integer), keeping {}",
                        opts.batch_size
                    ),
                },
                None => eprintln!(
                    "Warning: --batch requires a value, keeping {}",
                    opts.batch_size
                ),
            },
            "--dtype" => match rest.next_if(|v| !v.starts_with('-')) {
                Some(value) => opts.dtype = value.clone(),
                None => eprintln!("Warning: --dtype requires a value, keeping {}", opts.dtype),
            },
            "--verbose" | "-v" => opts.verbose = true,
            other => eprintln!("Warning: unknown option '{other}'"),
        }
    }

    opts
}
103
104// Command dispatch
105
106fn run_command(command: &str, file_path: &str, opts: &CliOptions) -> Result<(), String> {
107    let source =
108        fs::read_to_string(file_path).map_err(|e| format!("Cannot read '{file_path}': {e}"))?;
109
110    match command {
111        "dump" => cmd_dump(&source, file_path, opts),
112        "validate" => cmd_validate(&source, file_path),
113        "bench" => cmd_bench(&source, file_path, opts),
114        "info" => cmd_info(&source, file_path, opts),
115        _ => Err(format!("Unknown command: {command}")),
116    }
117}
118
119// dump — Print the lowered IR graph
120
121fn cmd_dump(source: &str, file_path: &str, opts: &CliOptions) -> Result<(), String> {
122    let ast = shrew_ir::parse(source).map_err(|e| format!("Parse error: {e}"))?;
123    let mut ir = shrew_ir::lower(&ast).map_err(|e| format!("Lowering error: {e}"))?;
124
125    // Validate (but don't fail, just warn)
126    if let Err(errors) = shrew_ir::validate(&ir) {
127        for e in &errors {
128            eprintln!("Warning: {e}");
129        }
130    }
131
132    shrew_ir::infer_shapes(&mut ir);
133
134    println!("=== IR Dump: {file_path} ===");
135    println!();
136
137    for graph in &ir.graphs {
138        println!("graph {} {{", graph.name);
139
140        // Inputs
141        for input_id in &graph.inputs {
142            let node = &graph.nodes[input_id.0];
143            println!("  input {:10} : {:?}", node.name, node.output_type);
144        }
145
146        // Nodes
147        for node in &graph.nodes {
148            let inputs_str: Vec<&str> = node
149                .inputs
150                .iter()
151                .map(|id| graph.nodes[id.0].name.as_str())
152                .collect();
153            println!(
154                "  {:10} = {:?}({})",
155                node.name,
156                node.op,
157                inputs_str.join(", ")
158            );
159        }
160
161        // Outputs
162        for output in &graph.outputs {
163            let node = &graph.nodes[output.node_id.0];
164            println!("  output {:10} : {:?}", output.name, node.output_type);
165        }
166
167        println!("}}");
168        println!();
169    }
170
171    // Show optimization stats if verbose
172    if opts.verbose {
173        let stats = shrew_ir::optimize::optimize_graph_with_stats(&mut ir.graphs[0]);
174        println!("Optimization stats:");
175        println!("  Dead code removed:     {}", stats.dead_code_removed);
176        println!("  Identities removed:    {}", stats.identities_removed);
177        println!("  Constants folded:      {}", stats.constants_folded);
178        println!("  CSE eliminated:        {}", stats.cse_eliminated);
179        println!("  Operators fused:       {}", stats.ops_fused);
180    }
181
182    Ok(())
183}
184
185// validate — Check a .sw program for errors
186
187fn cmd_validate(source: &str, file_path: &str) -> Result<(), String> {
188    println!("=== Validating: {file_path} ===");
189
190    // Step 1: Parse
191    let ast = match shrew_ir::parse(source) {
192        Ok(a) => {
193            println!("  [OK] Parse: {} item(s) found", a.items.len());
194            a
195        }
196        Err(e) => {
197            println!("  [FAIL] Parse error: {e}");
198            return Err("Validation failed at parse stage".to_string());
199        }
200    };
201
202    // Step 2: Lower
203    let mut ir = match shrew_ir::lower(&ast) {
204        Ok(ir) => {
205            println!(
206                "  [OK] Lower: {} graph(s), {} total nodes",
207                ir.graphs.len(),
208                ir.graphs.iter().map(|g| g.nodes.len()).sum::<usize>()
209            );
210            ir
211        }
212        Err(e) => {
213            println!("  [FAIL] Lowering error: {e}");
214            return Err("Validation failed at lowering stage".to_string());
215        }
216    };
217
218    // Step 3: Validate
219    match shrew_ir::validate(&ir) {
220        Ok(()) => println!("  [OK] Validate: no errors"),
221        Err(errors) => {
222            println!("  [WARN] {} validation error(s):", errors.len());
223            for e in &errors {
224                println!("         - {e}");
225            }
226        }
227    }
228
229    // Step 4: Shape inference
230    shrew_ir::infer_shapes(&mut ir);
231    let shaped = ir
232        .graphs
233        .iter()
234        .flat_map(|g| g.nodes.iter())
235        .filter(|n| !matches!(n.output_type, shrew_ir::graph::IrType::Unknown))
236        .count();
237    let total = ir.graphs.iter().map(|g| g.nodes.len()).sum::<usize>();
238    println!("  [OK] Shapes: {shaped}/{total} nodes have resolved types");
239
240    // Step 5: Optimize
241    let removed = shrew_ir::optimize(&mut ir);
242    println!("  [OK] Optimize: {removed} redundant ops removed");
243
244    println!();
245    println!("Validation passed!");
246    Ok(())
247}
248
249// bench — Benchmark forward pass
250
251fn cmd_bench(source: &str, file_path: &str, opts: &CliOptions) -> Result<(), String> {
252    use shrew::prelude::*;
253
254    let dtype = parse_dtype(&opts.dtype)?;
255    let config = RuntimeConfig::default()
256        .set_dim("batch", opts.batch_size)
257        .set_dim("Batch", opts.batch_size)
258        .with_dtype(dtype)
259        .with_training(false);
260
261    let exec = shrew::exec::load_program::<CpuBackend>(source, CpuDevice, config.clone())
262        .map_err(|e| format!("Load error: {e}"))?;
263
264    let graph_names: Vec<String> = exec
265        .program()
266        .graphs
267        .iter()
268        .map(|g| g.name.clone())
269        .collect();
270    if graph_names.is_empty() {
271        return Err("No graphs found in program".to_string());
272    }
273
274    println!("=== Benchmark: {file_path} ===");
275    println!("Batch: {}, DType: {:?}", opts.batch_size, dtype);
276    println!();
277
278    let warmup = 3;
279    let iterations = 10;
280
281    for gname in &graph_names {
282        // Generate synthetic inputs
283        let graph = exec
284            .program()
285            .graphs
286            .iter()
287            .find(|g| g.name == *gname)
288            .ok_or_else(|| format!("Graph '{gname}' not found"))?;
289
290        let mut inputs = std::collections::HashMap::new();
291        for &input_id in &graph.inputs {
292            let node = &graph.nodes[input_id.0];
293            if let shrew_ir::graph::IrType::Tensor {
294                shape,
295                dtype: ir_dt,
296            } = &node.output_type
297            {
298                let dims: Vec<usize> = shape
299                    .iter()
300                    .map(|d| match d {
301                        shrew_ir::graph::Dim::Fixed(n) => *n as usize,
302                        shrew_ir::graph::Dim::Symbolic(s) => config
303                            .dims
304                            .get(s.as_str())
305                            .copied()
306                            .unwrap_or(opts.batch_size),
307                        shrew_ir::graph::Dim::Dynamic => opts.batch_size,
308                    })
309                    .collect();
310                let core_dt = match ir_dt {
311                    shrew_ir::graph::DType::F32 => shrew_core::DType::F32,
312                    shrew_ir::graph::DType::F64 => shrew_core::DType::F64,
313                    _ => dtype,
314                };
315                let tensor = CpuTensor::rand(shrew_core::Shape::new(dims), core_dt, &CpuDevice)
316                    .map_err(|e| format!("Failed to create input '{}': {e}", node.name))?;
317                inputs.insert(node.name.clone(), tensor);
318            }
319        }
320
321        // Warmup
322        for _ in 0..warmup {
323            let _ = exec.run(gname, &inputs);
324        }
325
326        // Timed runs
327        let mut times = Vec::with_capacity(iterations);
328        for _ in 0..iterations {
329            let t0 = Instant::now();
330            let _ = exec.run(gname, &inputs);
331            times.push(t0.elapsed());
332        }
333
334        let total_ms: f64 = times.iter().map(|t| t.as_secs_f64() * 1000.0).sum();
335        let avg_ms = total_ms / iterations as f64;
336        let min_ms = times
337            .iter()
338            .map(|t| t.as_secs_f64() * 1000.0)
339            .fold(f64::INFINITY, f64::min);
340        let max_ms = times
341            .iter()
342            .map(|t| t.as_secs_f64() * 1000.0)
343            .fold(0.0f64, f64::max);
344
345        println!("Graph: {gname}");
346        println!("  Iterations: {iterations} (+ {warmup} warmup)");
347        println!("  Avg:  {avg_ms:.3} ms");
348        println!("  Min:  {min_ms:.3} ms");
349        println!("  Max:  {max_ms:.3} ms");
350        println!(
351            "  Throughput: {:.1} samples/sec",
352            opts.batch_size as f64 / (avg_ms / 1000.0)
353        );
354        println!();
355    }
356
357    Ok(())
358}
359
360// info — Show model summary
361
362fn cmd_info(source: &str, file_path: &str, _opts: &CliOptions) -> Result<(), String> {
363    let ast = shrew_ir::parse(source).map_err(|e| format!("Parse error: {e}"))?;
364    let mut ir = shrew_ir::lower(&ast).map_err(|e| format!("Lowering error: {e}"))?;
365
366    shrew_ir::infer_shapes(&mut ir);
367
368    println!("=== Model Info: {file_path} ===");
369    println!();
370
371    for graph in &ir.graphs {
372        println!("Graph: {}", graph.name);
373        println!("  Inputs:  {}", graph.inputs.len());
374        println!("  Outputs: {}", graph.outputs.len());
375        println!("  Nodes:   {}", graph.nodes.len());
376
377        // Count op types
378        let mut op_counts: std::collections::HashMap<String, usize> =
379            std::collections::HashMap::new();
380        for node in &graph.nodes {
381            *op_counts.entry(format!("{:?}", node.op)).or_insert(0) += 1;
382        }
383
384        println!("  Operations:");
385        let mut sorted: Vec<_> = op_counts.into_iter().collect();
386        sorted.sort_by(|a, b| b.1.cmp(&a.1));
387        for (op, count) in &sorted {
388            println!("    {op}: {count}");
389        }
390
391        // Count parameters from graph.params
392        let param_count = graph.params.len();
393        println!("  Parameters: {param_count}");
394
395        println!();
396    }
397
398    // Training block info
399    if let Some(ref t) = ir.training {
400        println!("Training config:");
401        println!("  Optimizer: {}", t.optimizer.kind);
402        println!("  LR:        {}", t.optimizer.lr);
403        println!("  Loss:      {}", t.loss);
404        println!("  Epochs:    {}", t.epochs);
405        println!("  Batch:     {}", t.batch_size);
406    } else {
407        println!("No @training block.");
408    }
409
410    Ok(())
411}
412
413// Helpers
414
415fn parse_dtype(s: &str) -> Result<shrew_core::DType, String> {
416    match s.to_lowercase().as_str() {
417        "f32" | "float32" => Ok(shrew_core::DType::F32),
418        "f64" | "float64" => Ok(shrew_core::DType::F64),
419        "f16" | "float16" => Ok(shrew_core::DType::F16),
420        "bf16" | "bfloat16" => Ok(shrew_core::DType::BF16),
421        "u8" | "uint8" => Ok(shrew_core::DType::U8),
422        "u32" | "uint32" => Ok(shrew_core::DType::U32),
423        "i64" | "int64" => Ok(shrew_core::DType::I64),
424        _ => Err(format!("Unknown dtype: {s}")),
425    }
426}
427
/// Print command-line help to stdout.
///
/// The dtype list mirrors the names accepted by `parse_dtype`; keep the two in
/// sync. `--batch` and `--dtype` are only read by `bench`, and `--verbose` only
/// by `dump`.
fn print_usage() {
    println!("Shrew — Deep Learning CLI");
    println!();
    println!("USAGE:");
    println!("  shrew <command> <file.sw> [options]");
    println!();
    println!("COMMANDS:");
    println!("  dump       Print the lowered IR graph");
    println!("  validate   Check a .sw program for errors");
    println!("  bench      Benchmark forward pass performance");
    println!("  info       Show model summary (params, ops, shapes)");
    println!("  version    Print version");
    println!("  help       Show this help");
    println!();
    println!("OPTIONS:");
    println!("  --batch N        Set batch dimension for bench (default: 1)");
    println!("  --dtype <type>   Set default dtype for bench (default: f32)");
    println!("                   one of: f32, f64, f16, bf16, u8, u32, i64");
    println!("  --verbose, -v    Print detailed output (dump: optimization stats)");
}