1use std::env;
15use std::fs;
16use std::process;
17use std::time::Instant;
18
19fn main() {
20 let args: Vec<String> = env::args().collect();
21
22 if args.len() < 2 {
23 print_usage();
24 process::exit(1);
25 }
26
27 let command = args[1].as_str();
28
29 match command {
30 "dump" | "validate" | "bench" | "info" => {
31 if args.len() < 3 {
32 eprintln!("Error: missing .sw file path");
33 eprintln!("Usage: shrew {command} <file.sw> [options]");
34 process::exit(1);
35 }
36 let file_path = &args[2];
37 let opts = parse_options(&args[3..]);
38
39 match run_command(command, file_path, &opts) {
40 Ok(()) => {}
41 Err(e) => {
42 eprintln!("Error: {e}");
43 process::exit(1);
44 }
45 }
46 }
47 "--help" | "-h" | "help" => {
48 print_usage();
49 }
50 "--version" | "-V" | "version" => {
51 println!("shrew {}", env!("CARGO_PKG_VERSION"));
52 }
53 other => {
54 eprintln!("Unknown command: {other}");
55 print_usage();
56 process::exit(1);
57 }
58 }
59}
60
/// Options accepted by the file-based subcommands.
///
/// Built by `parse_options` from the arguments that follow the file path.
#[derive(Debug, Clone, PartialEq, Eq)]
struct CliOptions {
    /// Value given to `--batch`; `parse_options` starts it at 1.
    batch_size: usize,
    /// Value given to `--dtype`, kept as raw text; `parse_options` starts it at `"f32"`.
    dtype: String,
    /// Set by `--verbose` / `-v`.
    verbose: bool,
}
68
69fn parse_options(args: &[String]) -> CliOptions {
70 let mut opts = CliOptions {
71 batch_size: 1,
72 dtype: "f32".to_string(),
73 verbose: false,
74 };
75
76 let mut i = 0;
77 while i < args.len() {
78 match args[i].as_str() {
79 "--batch" => {
80 i += 1;
81 if i < args.len() {
82 opts.batch_size = args[i].parse().unwrap_or(1);
83 }
84 }
85 "--dtype" => {
86 i += 1;
87 if i < args.len() {
88 opts.dtype = args[i].clone();
89 }
90 }
91 "--verbose" | "-v" => {
92 opts.verbose = true;
93 }
94 other => {
95 eprintln!("Warning: unknown option '{other}'");
96 }
97 }
98 i += 1;
99 }
100
101 opts
102}
103
104fn run_command(command: &str, file_path: &str, opts: &CliOptions) -> Result<(), String> {
107 let source =
108 fs::read_to_string(file_path).map_err(|e| format!("Cannot read '{file_path}': {e}"))?;
109
110 match command {
111 "dump" => cmd_dump(&source, file_path, opts),
112 "validate" => cmd_validate(&source, file_path),
113 "bench" => cmd_bench(&source, file_path, opts),
114 "info" => cmd_info(&source, file_path, opts),
115 _ => Err(format!("Unknown command: {command}")),
116 }
117}
118
119fn cmd_dump(source: &str, file_path: &str, opts: &CliOptions) -> Result<(), String> {
122 let ast = shrew_ir::parse(source).map_err(|e| format!("Parse error: {e}"))?;
123 let mut ir = shrew_ir::lower(&ast).map_err(|e| format!("Lowering error: {e}"))?;
124
125 if let Err(errors) = shrew_ir::validate(&ir) {
127 for e in &errors {
128 eprintln!("Warning: {e}");
129 }
130 }
131
132 shrew_ir::infer_shapes(&mut ir);
133
134 println!("=== IR Dump: {file_path} ===");
135 println!();
136
137 for graph in &ir.graphs {
138 println!("graph {} {{", graph.name);
139
140 for input_id in &graph.inputs {
142 let node = &graph.nodes[input_id.0];
143 println!(" input {:10} : {:?}", node.name, node.output_type);
144 }
145
146 for node in &graph.nodes {
148 let inputs_str: Vec<&str> = node
149 .inputs
150 .iter()
151 .map(|id| graph.nodes[id.0].name.as_str())
152 .collect();
153 println!(
154 " {:10} = {:?}({})",
155 node.name,
156 node.op,
157 inputs_str.join(", ")
158 );
159 }
160
161 for output in &graph.outputs {
163 let node = &graph.nodes[output.node_id.0];
164 println!(" output {:10} : {:?}", output.name, node.output_type);
165 }
166
167 println!("}}");
168 println!();
169 }
170
171 if opts.verbose {
173 let stats = shrew_ir::optimize::optimize_graph_with_stats(&mut ir.graphs[0]);
174 println!("Optimization stats:");
175 println!(" Dead code removed: {}", stats.dead_code_removed);
176 println!(" Identities removed: {}", stats.identities_removed);
177 println!(" Constants folded: {}", stats.constants_folded);
178 println!(" CSE eliminated: {}", stats.cse_eliminated);
179 println!(" Operators fused: {}", stats.ops_fused);
180 }
181
182 Ok(())
183}
184
185fn cmd_validate(source: &str, file_path: &str) -> Result<(), String> {
188 println!("=== Validating: {file_path} ===");
189
190 let ast = match shrew_ir::parse(source) {
192 Ok(a) => {
193 println!(" [OK] Parse: {} item(s) found", a.items.len());
194 a
195 }
196 Err(e) => {
197 println!(" [FAIL] Parse error: {e}");
198 return Err("Validation failed at parse stage".to_string());
199 }
200 };
201
202 let mut ir = match shrew_ir::lower(&ast) {
204 Ok(ir) => {
205 println!(
206 " [OK] Lower: {} graph(s), {} total nodes",
207 ir.graphs.len(),
208 ir.graphs.iter().map(|g| g.nodes.len()).sum::<usize>()
209 );
210 ir
211 }
212 Err(e) => {
213 println!(" [FAIL] Lowering error: {e}");
214 return Err("Validation failed at lowering stage".to_string());
215 }
216 };
217
218 match shrew_ir::validate(&ir) {
220 Ok(()) => println!(" [OK] Validate: no errors"),
221 Err(errors) => {
222 println!(" [WARN] {} validation error(s):", errors.len());
223 for e in &errors {
224 println!(" - {e}");
225 }
226 }
227 }
228
229 shrew_ir::infer_shapes(&mut ir);
231 let shaped = ir
232 .graphs
233 .iter()
234 .flat_map(|g| g.nodes.iter())
235 .filter(|n| !matches!(n.output_type, shrew_ir::graph::IrType::Unknown))
236 .count();
237 let total = ir.graphs.iter().map(|g| g.nodes.len()).sum::<usize>();
238 println!(" [OK] Shapes: {shaped}/{total} nodes have resolved types");
239
240 let removed = shrew_ir::optimize(&mut ir);
242 println!(" [OK] Optimize: {removed} redundant ops removed");
243
244 println!();
245 println!("Validation passed!");
246 Ok(())
247}
248
249fn cmd_bench(source: &str, file_path: &str, opts: &CliOptions) -> Result<(), String> {
252 use shrew::prelude::*;
253
254 let dtype = parse_dtype(&opts.dtype)?;
255 let config = RuntimeConfig::default()
256 .set_dim("batch", opts.batch_size)
257 .set_dim("Batch", opts.batch_size)
258 .with_dtype(dtype)
259 .with_training(false);
260
261 let exec = shrew::exec::load_program::<CpuBackend>(source, CpuDevice, config.clone())
262 .map_err(|e| format!("Load error: {e}"))?;
263
264 let graph_names: Vec<String> = exec
265 .program()
266 .graphs
267 .iter()
268 .map(|g| g.name.clone())
269 .collect();
270 if graph_names.is_empty() {
271 return Err("No graphs found in program".to_string());
272 }
273
274 println!("=== Benchmark: {file_path} ===");
275 println!("Batch: {}, DType: {:?}", opts.batch_size, dtype);
276 println!();
277
278 let warmup = 3;
279 let iterations = 10;
280
281 for gname in &graph_names {
282 let graph = exec
284 .program()
285 .graphs
286 .iter()
287 .find(|g| g.name == *gname)
288 .ok_or_else(|| format!("Graph '{gname}' not found"))?;
289
290 let mut inputs = std::collections::HashMap::new();
291 for &input_id in &graph.inputs {
292 let node = &graph.nodes[input_id.0];
293 if let shrew_ir::graph::IrType::Tensor {
294 shape,
295 dtype: ir_dt,
296 } = &node.output_type
297 {
298 let dims: Vec<usize> = shape
299 .iter()
300 .map(|d| match d {
301 shrew_ir::graph::Dim::Fixed(n) => *n as usize,
302 shrew_ir::graph::Dim::Symbolic(s) => config
303 .dims
304 .get(s.as_str())
305 .copied()
306 .unwrap_or(opts.batch_size),
307 shrew_ir::graph::Dim::Dynamic => opts.batch_size,
308 })
309 .collect();
310 let core_dt = match ir_dt {
311 shrew_ir::graph::DType::F32 => shrew_core::DType::F32,
312 shrew_ir::graph::DType::F64 => shrew_core::DType::F64,
313 _ => dtype,
314 };
315 let tensor = CpuTensor::rand(shrew_core::Shape::new(dims), core_dt, &CpuDevice)
316 .map_err(|e| format!("Failed to create input '{}': {e}", node.name))?;
317 inputs.insert(node.name.clone(), tensor);
318 }
319 }
320
321 for _ in 0..warmup {
323 let _ = exec.run(gname, &inputs);
324 }
325
326 let mut times = Vec::with_capacity(iterations);
328 for _ in 0..iterations {
329 let t0 = Instant::now();
330 let _ = exec.run(gname, &inputs);
331 times.push(t0.elapsed());
332 }
333
334 let total_ms: f64 = times.iter().map(|t| t.as_secs_f64() * 1000.0).sum();
335 let avg_ms = total_ms / iterations as f64;
336 let min_ms = times
337 .iter()
338 .map(|t| t.as_secs_f64() * 1000.0)
339 .fold(f64::INFINITY, f64::min);
340 let max_ms = times
341 .iter()
342 .map(|t| t.as_secs_f64() * 1000.0)
343 .fold(0.0f64, f64::max);
344
345 println!("Graph: {gname}");
346 println!(" Iterations: {iterations} (+ {warmup} warmup)");
347 println!(" Avg: {avg_ms:.3} ms");
348 println!(" Min: {min_ms:.3} ms");
349 println!(" Max: {max_ms:.3} ms");
350 println!(
351 " Throughput: {:.1} samples/sec",
352 opts.batch_size as f64 / (avg_ms / 1000.0)
353 );
354 println!();
355 }
356
357 Ok(())
358}
359
360fn cmd_info(source: &str, file_path: &str, _opts: &CliOptions) -> Result<(), String> {
363 let ast = shrew_ir::parse(source).map_err(|e| format!("Parse error: {e}"))?;
364 let mut ir = shrew_ir::lower(&ast).map_err(|e| format!("Lowering error: {e}"))?;
365
366 shrew_ir::infer_shapes(&mut ir);
367
368 println!("=== Model Info: {file_path} ===");
369 println!();
370
371 for graph in &ir.graphs {
372 println!("Graph: {}", graph.name);
373 println!(" Inputs: {}", graph.inputs.len());
374 println!(" Outputs: {}", graph.outputs.len());
375 println!(" Nodes: {}", graph.nodes.len());
376
377 let mut op_counts: std::collections::HashMap<String, usize> =
379 std::collections::HashMap::new();
380 for node in &graph.nodes {
381 *op_counts.entry(format!("{:?}", node.op)).or_insert(0) += 1;
382 }
383
384 println!(" Operations:");
385 let mut sorted: Vec<_> = op_counts.into_iter().collect();
386 sorted.sort_by(|a, b| b.1.cmp(&a.1));
387 for (op, count) in &sorted {
388 println!(" {op}: {count}");
389 }
390
391 let param_count = graph.params.len();
393 println!(" Parameters: {param_count}");
394
395 println!();
396 }
397
398 if let Some(ref t) = ir.training {
400 println!("Training config:");
401 println!(" Optimizer: {}", t.optimizer.kind);
402 println!(" LR: {}", t.optimizer.lr);
403 println!(" Loss: {}", t.loss);
404 println!(" Epochs: {}", t.epochs);
405 println!(" Batch: {}", t.batch_size);
406 } else {
407 println!("No @training block.");
408 }
409
410 Ok(())
411}
412
413fn parse_dtype(s: &str) -> Result<shrew_core::DType, String> {
416 match s.to_lowercase().as_str() {
417 "f32" | "float32" => Ok(shrew_core::DType::F32),
418 "f64" | "float64" => Ok(shrew_core::DType::F64),
419 "f16" | "float16" => Ok(shrew_core::DType::F16),
420 "bf16" | "bfloat16" => Ok(shrew_core::DType::BF16),
421 "u8" | "uint8" => Ok(shrew_core::DType::U8),
422 "u32" | "uint32" => Ok(shrew_core::DType::U32),
423 "i64" | "int64" => Ok(shrew_core::DType::I64),
424 _ => Err(format!("Unknown dtype: {s}")),
425 }
426}
427
/// Prints the command-line help text to stdout, one line at a time.
fn print_usage() {
    // Each entry is one output line; empty entries produce blank separator lines.
    const USAGE_LINES: &[&str] = &[
        "Shrew — Deep Learning CLI",
        "",
        "USAGE:",
        " shrew <command> <file.sw> [options]",
        "",
        "COMMANDS:",
        " dump Print the lowered IR graph",
        " validate Check a .sw program for errors",
        " bench Benchmark forward pass performance",
        " info Show model summary (params, ops, shapes)",
        " version Print version",
        " help Show this help",
        "",
        "OPTIONS:",
        " --batch N Set batch dimension (default: 1)",
        " --dtype <type> Set default dtype: f32, f64, f16 (default: f32)",
        " --verbose, -v Print detailed output",
    ];

    for line in USAGE_LINES {
        println!("{line}");
    }
}