Introduction
Welcome to Shrew, a declarative tensor language and compiler designed for high-performance deep learning.
Shrew provides a clean, domain-specific language (DSL) for defining neural network architectures, training configurations, and inference pipelines. It decouples the model definition from the execution engine, allowing for powerful optimizations and multi-backend support (CPU, CUDA, TPU).
Key Features
- Declarative Syntax: Define what your model is, not how to execute it.
- Type Safety: Strong static typing for tensor shapes and data types.
- Graph Compilation: Automatic graph construction, optimization, and scheduling.
- Modularity: Composable blocks for models, configuration, and data pipelines.
Getting Started
Check out the DSL Guide to learn the language, or browse the Examples to see Shrew in action.
Introduction to Shrew
Shrew files usually have the extension .sw. A Shrew program consists of a series of directives that define metadata, configuration, types, and computation graphs.
File Structure
A typical .sw file structure looks like this:
// Metadata about the model
@model { ... }
// Training or Inference configuration
@config { ... }
// Type definitions (optional aliases)
@types { ... }
// Computation graphs (functions)
@graph MyGraph(...) { ... }
Comments
Shrew supports C-style comments:
- Single line: `// comment`
- Multi-line: `/* comment */`
Tensors & Types
Shrew is a strongly-typed language designed for tensor operations. The core type is the Tensor.
Tensor Type
The syntax for a tensor is:
Tensor<[Dimensions], DataType>
Dimensions
Dimensions can be:
- Named Symbolic: `Batch`, `Channels` (inferred at runtime)
- Fixed Integer: `224`, `1024`
- Inferred: `?` (unknown rank or dimension)
Data Types (dtype)
Supported dtypes:
- Floating point:
f16,bf16,f32,f64 - Integer:
i8,i16,i32,i64,u8,u16,u32,u64 - Boolean:
bool - Complex:
complex64,complex128
Examples
// A 2D matrix of shape [32, 128] with float32 elements
Tensor<[32, 128], f32>
// A generic batch of images
Tensor<[Batch, 3, Height, Width], f32>
// A vector of 10 integers
Tensor<[10], i32>
Literals
You can define constant tensors using double brackets [[ ... ]]:
// 1D tensor
[[1, 2, 3]]
// 2D tensor
[[
[1, 0],
[0, 1]
]]
Graphs & Operations
The @graph directive defines a computation graph, similar to a function in other languages.
Anatomy of a Graph
@graph GraphName(input1: Type1) -> OutputType {
// 1. Parameter Declarations (Weights)
param w: Type2 { init: "start_value"; };
// 2. Nodes (Operations)
node x: Type3 {
op: input1 * w;
};
// 3. Output Statement
output x;
}
Nodes
Nodes represent intermediate computations. They must have an op field defining the operation.
node activation {
op: relu(x);
}
Standard Operations
Shrew supports standard element-wise arithmetic and matrix operations:
- Element-wise arithmetic: `+`, `-`, `*`, `/`
- Matrix multiplication: `matmul` or the `@` operator
- Power: `pow` (`**`)
And common neural network functions:
- Activations: `relu()`, `sigmoid()`, `tanh()`, `softmax()`
- Convolution and pooling: `conv2d()`, `maxpool2d()`
Attributes
Nodes can have additional attributes passed in the block:
node conv_layer {
op: conv2d(input, weight);
padding: 1;
stride: 2;
}
Rust API Setup
To use Shrew in your Rust project, add the dependencies to your Cargo.toml.
Dependencies
[dependencies]
shrew = "0.1"
anyhow = "1.0" # Recommended for error handling
Feature Flags
- `cuda`: Enable CUDA backend support.
- `python`: Enable Python bindings.
[dependencies]
shrew = { version = "0.1", features = ["cuda"] }
Running Inference
This guide shows how to load a model and run a forward pass using the Rust API.
Full Example
use shrew::prelude::*;
use anyhow::Result;
fn main() -> Result<()> {
// 1. Load the model from a .sw file
let model = Model::load("models/linear_regression.sw")?;
// 2. Prepare input tensor
// Shape: [Batch=1, Input=3]
let input_data = vec![1.0, 2.0, 3.0];
let input = Tensor::new(&[1, 3], input_data)
.to_device(Device::Cpu)?;
// 3. Run inference
// The Input ID "x" matches the `input x:` declaration in the .sw file
let outputs = model.forward(hashmap! {
"x" => input
})?;
// 4. Get output
let result = outputs.get("y").expect("Output 'y' not found");
println!("Result: {:?}", result);
Ok(())
}
Key Components
- `Model::load`: Parses and compiles the `.sw` file.
- `Tensor::new`: Creates a tensor from a shape and a flat data vector.
- `Device`: Specifies where the tensor lives (`Cpu`, `Cuda(0)`, etc.).
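To make "a shape and a flat data vector" concrete, here is a minimal sketch (not Shrew's actual implementation) of how such a tensor can pair a shape with a row-major buffer and compute element offsets:

```rust
// Illustrative sketch of a shape + flat-buffer tensor, as Tensor::new describes.
// Not the real shrew type; scalar f32 storage only, no devices.
struct Tensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

impl Tensor {
    fn new(shape: &[usize], data: Vec<f32>) -> Tensor {
        // The flat buffer must hold exactly product(shape) elements.
        assert_eq!(shape.iter().product::<usize>(), data.len());
        Tensor { shape: shape.to_vec(), data }
    }

    // Row-major offset: the last dimension varies fastest.
    fn get(&self, index: &[usize]) -> f32 {
        let mut offset = 0;
        for (i, &idx) in index.iter().enumerate() {
            offset = offset * self.shape[i] + idx;
        }
        self.data[offset]
    }
}

fn main() {
    // Shape [1, 3] holding [1.0, 2.0, 3.0], matching the inference example above.
    let t = Tensor::new(&[1, 3], vec![1.0, 2.0, 3.0]);
    assert_eq!(t.get(&[0, 2]), 3.0);
    println!("t[0][2] = {}", t.get(&[0, 2]));
}
```

The same shape/buffer split is why the Rust example can pass a plain `Vec<f32>` alongside `&[1, 3]`: the layout is implied by the shape.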
Python API Setup
Shrew can be installed as a Python package.
Installation
From Source
Ensure you have Rust and Cargo installed, then run in the root of the repo:
pip install .
From PyPI (Coming Soon)
pip install shrew
Verification
Check that the installation was successful:
import shrew
print(shrew.__version__)
Python Usage
The Python API mirrors the Rust API, but with a more Pythonic interface.
Example
import shrew
import numpy as np
# 1. Load the model
model = shrew.load("models/linear_regression.sw")
# 2. Prepare input (using NumPy)
x = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
# 3. Run inference
# Inputs are passed as a dictionary mapping input names to NumPy arrays
outputs = model.forward({"x": x})
# 4. Get result
y = outputs["y"]
print("Output shape:", y.shape)
print("Output data:", y)
The Shrew DSL
The Shrew DSL is the core interface for interacting with the Shrew compiler. It is designed to be expressive, concise, and readable.
This section covers the full grammar specification and a set of worked examples.
Grammar Specification
The Shrew language grammar is defined below in Extended Backus-Naur Form (EBNF).
SHREW LANGUAGE GRAMMAR v0.1
TOP LEVEL
program ::= { directive | import_stmt }
directive ::= metadata_block
| config_block
| types_block
| graph_block
| custom_op_block
| training_block
| inference_block
| metrics_block
| logging_block
| visualization_block
IMPORTS
import_stmt ::= "@import" string_literal [ "as" identifier ] ";"
METADATA
metadata_block ::= "@model" "{" { metadata_field } "}"
metadata_field ::= identifier ":" literal ";"
CONFIGURATION
config_block ::= "@config" "{" { config_field } "}"
config_field ::= identifier ":" expr ";"
TYPE SYSTEM
types_block ::= "@types" "{" { type_def } "}"
type_def ::= "type" identifier "=" type_expr ";"
type_expr ::= tensor_type
| scalar_type
| tuple_type
| list_type
| dict_type
| identifier
| "?"
| integer_literal
| binary_expr
tensor_type ::= "Tensor" "<" "[" dimension_list "]" "," dtype ">"
dimension_list ::= dimension { "," dimension }
dimension ::= identifier | integer_literal | "?" | "_" | binary_expr
dtype ::= "f32" | "f64" | "f16" | "bf16"
| "i8" | "i16" | "i32" | "i64"
| "u8" | "u16" | "u32" | "u64"
| "bool" | "complex64" | "complex128"
scalar_type ::= dtype
tuple_type ::= "(" type_expr { "," type_expr } ")"
list_type ::= "[" type_expr "]"
dict_type ::= "{" identifier ":" type_expr { "," identifier ":" type_expr } "}"
GRAPH DEFINITION
graph_block ::= "@graph" identifier [ "(" param_list ")" [ "->" type_expr ] ] "{" graph_body "}"
param_list ::= param_def { "," param_def }
param_def ::= identifier ":" type_expr [ "?" ]
graph_body ::= { graph_stmt }
graph_stmt ::= input_decl
| output_decl
| param_decl
| node_decl
| assert_stmt
| check_stmt
input_decl ::= "input" identifier ":" type_expr [ "?" ] ";"
output_decl ::= "output" [ identifier ":" ] expr ";"
param_decl ::= "param" identifier ":" type_expr [ param_attrs ] ";"
param_attrs ::= "{" { param_attr } "}"
param_attr ::= "init" ":" init_expr
| "frozen" ":" bool_literal
| "device" ":" device_expr
| identifier ":" literal
init_expr ::= string_literal | expr
node_decl ::= "node" identifier [ ":" type_expr ] [ node_body ] ";"
node_body ::= "{" { node_stmt } "}"
node_stmt ::= "op" ":" operation ";"
| "input" ":" expr ";"
| "output" ":" type_expr ";"
| hint_directive
| identifier ":" expr ";"
operation ::= identifier "(" [ arg_list ] ")"
| block_operation
| call_operation
| binary_expr
| unary_expr
block_operation ::= "if" expr "{" operation "}" [ "else" "{" operation "}" ]
| "repeat" "(" expr ")" "{" operation "}"
call_operation ::= "call" qualified_identifier "(" [ arg_list ] ")"
qualified_identifier ::= identifier { "." identifier | "::" identifier }
arg_list ::= arg { "," arg }
arg ::= [ identifier ":" ] expr
assert_stmt ::= "@assert" expr [ "," string_literal ] ";"
check_stmt ::= "@check" identifier "{" { check_condition } "}"
check_condition ::= "assert" expr [ "," string_literal ] ";"
hint_directive ::= "@hint" hint_type ";"
hint_type ::= "recompute_in_backward" | "must_preserve" | "in_place" | "no_grad" | identifier
CUSTOM OPERATORS
custom_op_block ::= "@custom_op" identifier "{" { custom_op_stmt } "}"
custom_op_stmt ::= "signature" ":" signature ";"
| "impl" identifier "{" { impl_attr } "}"
| "gradient" identifier "{" gradient_body "}"
signature ::= "(" param_list ")" "->" type_expr
impl_attr ::= identifier ":" expr ";"
gradient_body ::= { gradient_stmt }
gradient_stmt ::= "impl" identifier "{" { impl_attr } "}"
| "call" operation ";"
TRAINING CONFIGURATION
training_block ::= "@training" "{" { training_stmt } "}"
training_stmt ::= "model" ":" identifier ";"
| "loss" ":" identifier ";"
| "optimizer" ":" optimizer_config ";"
| "lr_schedule" ":" schedule_config ";"
| "grad_clip" ":" clip_config ";"
| "precision" ":" string_literal ";"
| "accumulation_steps" ":" integer_literal ";"
| identifier ":" expr ";"
optimizer_config ::= "{" { optimizer_attr } "}"
optimizer_attr ::= identifier ":" expr ";"
schedule_config ::= "{" { schedule_attr } "}"
schedule_attr ::= identifier ":" expr ";"
clip_config ::= "{" { clip_attr } "}"
clip_attr ::= identifier ":" expr ";"
INFERENCE CONFIGURATION
inference_block ::= "@inference" "{" { inference_stmt } "}"
inference_stmt ::= "model" ":" identifier ";"
| "optimizations" ":" list_literal ";"
| "quantization" ":" quant_config ";"
| "generation" ":" generation_config ";"
| identifier ":" expr ";"
quant_config ::= "{" { quant_attr } "}"
quant_attr ::= identifier ":" expr ";"
generation_config ::= "{" { generation_attr } "}"
generation_attr ::= identifier ":" expr ";"
METRICS & LOGGING
metrics_block ::= "@metrics" identifier "{" { metric_def } "}"
metric_def ::= "track" identifier "{" { metric_attr } "}"
metric_attr ::= "source" ":" source_expr ";"
| "compute" ":" compute_expr ";"
| "aggregate" ":" string_literal ";"
| "type" ":" string_literal ";"
| "log_every" ":" integer_literal ";"
| identifier ":" expr ";"
source_expr ::= qualified_identifier | "[" qualified_identifier { "," qualified_identifier } "]"
compute_expr ::= expr | "{" iteration_expr "}"
iteration_expr ::= "for" identifier "in" expr "{" expr "}"
logging_block ::= "@logging" "{" { logging_stmt } "}"
logging_stmt ::= "backend" ":" string_literal ";"
| identifier ":" config_literal ";"
| "checkpoints" ":" checkpoint_config ";"
checkpoint_config ::= "{" { checkpoint_attr } "}"
checkpoint_attr ::= identifier ":" expr ";"
visualization_block ::= "@visualizations" "{" { visualization_def } "}"
visualization_def ::= "plot" identifier "{" { plot_attr } "}"
plot_attr ::= identifier ":" expr ";"
EXPRESSIONS
expr ::= binary_expr | unary_expr | primary_expr
binary_expr ::= expr binary_op expr
binary_op ::= "+" | "-" | "*" | "/" | "%" | "**"
| "==" | "!=" | "<" | ">" | "<=" | ">="
| "&&" | "||" | "??"
| "&" | "|" | "^" | "<<" | ">>"
unary_expr ::= unary_op expr
unary_op ::= "-" | "!" | "~"
primary_expr ::= identifier
| literal
| function_call
| member_access
| index_access
| range_expr
| list_expr
| dict_expr
| paren_expr
| tensor_literal
function_call ::= identifier "(" [ arg_list ] ")"
member_access ::= expr "." identifier
index_access ::= expr "[" expr [ ":" expr ] "]"
range_expr ::= "range" "(" expr [ "," expr [ "," expr ] ] ")"
list_expr ::= "[" [ expr { "," expr } ] "]"
dict_expr ::= "{" [ dict_entry { "," dict_entry } ] "}"
dict_entry ::= identifier ":" expr
paren_expr ::= "(" expr ")"
tensor_literal ::= "[[" expr { "," expr } "]]"
LITERALS
literal ::= integer_literal
| float_literal
| bool_literal
| string_literal
| list_literal
| config_literal
| "null"
integer_literal ::= digit { digit }
float_literal ::= digit { digit } "." digit { digit } [ exponent ]
| digit { digit } exponent
exponent ::= ("e" | "E") [ "+" | "-" ] digit { digit }
bool_literal ::= "true" | "false"
string_literal ::= '"' { string_char } '"'
list_literal ::= "[" [ literal { "," literal } ] "]"
config_literal ::= "{" { identifier ":" literal { "," identifier ":" literal } } "}"
DEVICE EXPRESSIONS
device_expr ::= "cpu"
| "gpu" [ ":" integer_literal ]
| "tpu" [ ":" integer_literal ]
| identifier
IDENTIFIERS
identifier ::= letter { letter | digit | "_" }
letter ::= "a" .. "z" | "A" .. "Z"
digit ::= "0" .. "9"
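To make the `tensor_type` production concrete, here is a toy hand-written parser for just that rule. This is a sketch for illustration only, not the parser in `shrew-ir`:

```rust
// Toy parser for the production:
//   tensor_type ::= "Tensor" "<" "[" dimension_list "]" "," dtype ">"
// Handles named, fixed, and inferred ("?") dimensions; ignores "_" and
// binary dimension expressions for brevity.
#[derive(Debug, PartialEq)]
enum Dim {
    Named(String), // symbolic, e.g. Batch
    Fixed(usize),  // integer, e.g. 224
    Inferred,      // "?"
}

fn parse_tensor_type(src: &str) -> Option<(Vec<Dim>, String)> {
    let rest = src.trim().strip_prefix("Tensor<[")?;
    let (dims_str, rest) = rest.split_once("],")?;
    let dtype = rest.trim().strip_suffix('>')?.trim().to_string();
    let dims = dims_str
        .split(',')
        .map(|d| {
            let d = d.trim();
            if d == "?" {
                Dim::Inferred
            } else if let Ok(n) = d.parse::<usize>() {
                Dim::Fixed(n)
            } else {
                Dim::Named(d.to_string())
            }
        })
        .collect();
    Some((dims, dtype))
}

fn main() {
    let (dims, dtype) = parse_tensor_type("Tensor<[Batch, 3, ?, 224], f32>").unwrap();
    assert_eq!(dims[0], Dim::Named("Batch".into()));
    assert_eq!(dims[2], Dim::Inferred);
    assert_eq!(dims[3], Dim::Fixed(224));
    assert_eq!(dtype, "f32");
    println!("parsed {} dims, dtype {}", dims.len(), dtype);
}
```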
Examples
Simple Linear Regression
@model {
name: "LinearRegression";
version: "1.0";
}
@config {
batch_size: 32;
learning_rate: 0.01;
}
@graph Forward(x: Tensor<[Batch, In], f32>) -> Tensor<[Batch, Out], f32> {
param w: Tensor<[In, Out], f32> { init: "normal(0, 0.02)"; };
param b: Tensor<[Out], f32> { init: "zeros"; };
node y: Tensor<[Batch, Out], f32> {
op: x * w + b;
};
output y;
}
Matrix Multiplication
@graph MatMul(A: Tensor<[M, K], f32>, B: Tensor<[K, N], f32>) {
node C {
op: A @ B;
};
output C;
}
Architecture Overview
Shrew is built as a modular set of Rust crates. The architecture is designed to separate the frontend language (DSL) from the backend execution engines, with a middle layer for Intermediate Representation (IR) and optimization.
High-Level Flow
- Frontend: The `shrew-ir` crate parses the `.sw` source code into an Abstract Syntax Tree (AST) and then lowers it into a High-Level IR (Graph).
- Optimization: The `shrew-optim` crate applies graph transformations (e.g., constant folding, operator fusion) to the IR.
- Execution: The `shrew-core` crate orchestrates execution, dispatching tensor operations to specific backends (`shrew-cpu`, `shrew-cuda`).
Diagram
graph TD
Source[.sw File] --> Parser[shrew-ir: Parser]
Parser --> AST[AST]
AST --> Lower[shrew-ir: Lowering]
Lower --> IR[Graph IR]
IR --> Optim[shrew-optim]
Optim --> Exec[shrew-core: Executor]
Exec --> CPU[shrew-cpu]
Exec --> CUDA[shrew-cuda]
Crate Structure
The project is organized as a Cargo workspace with the following members:
- `shrew`: The top-level crate that re-exports functionality and provides the main entry point (CLI via `shrew-cli`).
- `shrew-core`: The runtime core. Handles tensor storage, device management, and graph execution APIs.
- `shrew-ir`: Contains the parser, AST definitions, and the Intermediate Representation (IR) logic.
- `shrew-optim`: Optimization passes for the IR (graph rewriting, fusion).
- `shrew-nn`: Implementations of neural network layers and common operators.
- `shrew-data`: Data loading and preprocessing utilities.
- `shrew-cpu`: CPU backend implementation (using Rayon and SIMD where available).
- `shrew-cuda`: CUDA backend implementation (interfaces with cuBLAS and cuDNN).
- `shrew-cli`: Command-line interface tool.
- `shrew-python`: Python bindings (PyO3) for using Shrew from Python.
IR & Lowering
The Intermediate Representation (IR) is a directed acyclic graph (DAG) where nodes represent operations and edges represent data dependencies (tensors).
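The "nodes are operations, edges are tensor dependencies" structure can be sketched with a few lines of Rust. This is an illustrative toy, not Shrew's actual IR; scalar `f32` values stand in for tensors, and node indices stand in for edges:

```rust
// Minimal op-graph sketch: each node is an operation whose operands are
// the indices of earlier nodes (the DAG's edges).
enum Op {
    Const(f32),
    Add(usize, usize),
    Mul(usize, usize),
}

// Since nodes may only reference earlier nodes, the Vec itself is a valid
// topological order, so one forward pass evaluates the whole DAG.
fn evaluate(graph: &[Op]) -> Vec<f32> {
    let mut values = Vec::with_capacity(graph.len());
    for op in graph {
        let v = match *op {
            Op::Const(c) => c,
            Op::Add(a, b) => values[a] + values[b],
            Op::Mul(a, b) => values[a] * values[b],
        };
        values.push(v);
    }
    values
}

fn main() {
    // y = x * w + b with x = 2, w = 3, b = 1
    let graph = vec![
        Op::Const(2.0), // 0: x
        Op::Const(3.0), // 1: w
        Op::Mul(0, 1),  // 2: x * w
        Op::Const(1.0), // 3: b
        Op::Add(2, 3),  // 4: output
    ];
    let values = evaluate(&graph);
    assert_eq!(values[4], 7.0);
    println!("y = {}", values[4]);
}
```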
Lowering Process
The lowering phase converts the AST (which mirrors the syntax tree) into the Graph IR. Key steps include:
- Symbol Resolution: Linking identifiers to their definitions.
- Type Checking: Verifying tensor shapes and data types.
- Graph Construction: Building the node connectivity.
The IR is defined in crates/shrew-ir/src/graph.rs.
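The type-checking step can be illustrated with the shape rule lowering must enforce for a matmul node. The sketch below is hypothetical (symbolic dimensions modeled as strings), not the code in `graph.rs`:

```rust
// Sketch of the shape check for matmul during lowering:
// [M, K] x [K2, N] is valid only if K == K2; the result is [M, N].
#[derive(Clone, Debug, PartialEq)]
enum Dim {
    Sym(String), // symbolic, e.g. Batch
    Fixed(usize),
}

fn check_matmul(lhs: &[Dim], rhs: &[Dim]) -> Result<Vec<Dim>, String> {
    if lhs.len() != 2 || rhs.len() != 2 {
        return Err("matmul expects rank-2 operands".into());
    }
    if lhs[1] != rhs[0] {
        return Err(format!(
            "inner dimensions do not match: {:?} vs {:?}",
            lhs[1], rhs[0]
        ));
    }
    Ok(vec![lhs[0].clone(), rhs[1].clone()])
}

fn main() {
    let a = vec![Dim::Sym("M".into()), Dim::Fixed(128)];
    let b = vec![Dim::Fixed(128), Dim::Sym("N".into())];
    // [M, 128] x [128, N] -> [M, N]
    assert_eq!(
        check_matmul(&a, &b).unwrap(),
        vec![Dim::Sym("M".into()), Dim::Sym("N".into())]
    );
    // Mismatched inner dimensions are rejected before execution.
    assert!(check_matmul(&b, &b).is_err());
    println!("matmul shape check passed");
}
```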
Execution Engine
The execution engine in shrew-core takes an optimized graph and runs it.
Tensors & Storage
Tensors are backed by a storage enum that can hold data on different devices:
pub enum Storage {
    Cpu(Vec<f32>),
    Cuda(CudaSlice<f32>),
    // ...
}
Backend Dispatch
Operations are dispatched dynamically based on the tensor’s device. The Backend trait defines the interface for all supported operations (matmul, add, relu, etc.).
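A minimal sketch of this pattern, assuming a much smaller trait surface than `shrew-core` actually exposes: each device implements a common `Backend` trait, and operations are routed through a trait object chosen from the tensor's device tag.

```rust
// Illustrative trait-based backend dispatch; the real Backend trait in
// shrew-core covers many more ops (matmul, conv2d, etc.).
trait Backend {
    fn name(&self) -> &'static str;
    fn add(&self, a: &[f32], b: &[f32]) -> Vec<f32>;
    fn relu(&self, a: &[f32]) -> Vec<f32>;
}

struct CpuBackend;

impl Backend for CpuBackend {
    fn name(&self) -> &'static str { "cpu" }
    fn add(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b).map(|(x, y)| x + y).collect()
    }
    fn relu(&self, a: &[f32]) -> Vec<f32> {
        a.iter().map(|x| x.max(0.0)).collect()
    }
}

// Device tag drives the dispatch; a CUDA backend would slot in here.
enum Device { Cpu /*, Cuda(usize) */ }

fn backend_for(device: &Device) -> Box<dyn Backend> {
    match device {
        Device::Cpu => Box::new(CpuBackend),
    }
}

fn main() {
    let backend = backend_for(&Device::Cpu);
    let out = backend.relu(&backend.add(&[1.0, -3.0], &[0.5, 1.0]));
    assert_eq!(out, vec![1.5, 0.0]);
    println!("{} backend output: {:?}", backend.name(), out);
}
```

Dispatching through a trait object keeps graph execution code device-agnostic: the executor only sees the trait, and adding a backend means adding one `impl`.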