Introduction

Welcome to Shrew, a declarative tensor language and compiler designed for high-performance deep learning.

Shrew provides a clean, domain-specific language (DSL) for defining neural network architectures, training configurations, and inference pipelines. It decouples the model definition from the execution engine, allowing for powerful optimizations and multi-backend support (CPU, CUDA, TPU).

Key Features

  • Declarative Syntax: Define what your model is, not how to execute it.
  • Type Safety: Strong static typing for tensor shapes and data types.
  • Graph Compilation: Automatic graph construction, optimization, and scheduling.
  • Modularity: Composable blocks for models, configuration, and data pipelines.

Getting Started

Check out the DSL Guide to learn the language, or browse the Examples to see Shrew in action.

Introduction to Shrew

Shrew files usually have the extension .sw. A Shrew program consists of a series of directives that define metadata, configuration, types, and computation graphs.

File Structure

A typical .sw file structure looks like this:

// Metadata about the model
@model { ... }

// Training or Inference configuration
@config { ... }

// Type definitions (optional aliases)
@types { ... }

// Computation graphs (functions)
@graph MyGraph(...) { ... }

Comments

Shrew supports C-style comments:

  • Single line: // comment
  • Multi-line: /* comment */

Tensors & Types

Shrew is a strongly-typed language designed for tensor operations. The core type is the Tensor.

Tensor Type

The syntax for a tensor is: Tensor<[Dimensions], DataType>

Dimensions

Dimensions can be:

  • Named Symbolic: Batch, Channels (inferred at runtime)
  • Fixed Integer: 224, 1024
  • Inferred: ? (unknown rank or dimension)

Data Types (dtype)

Supported dtypes:

  • Floating point: f16, bf16, f32, f64
  • Integer: i8, i16, i32, i64, u8, u16, u32, u64
  • Boolean: bool
  • Complex: complex64, complex128
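As a rough reference, the per-element sizes of these dtypes follow the usual conventions. The byte widths below are the standard ones (e.g., two f32 components for complex64), shown as an illustrative Python sketch rather than a Shrew API guarantee:

```python
# Conventional per-element byte widths (illustrative, not Shrew-specific).
DTYPE_SIZES = {
    "f16": 2, "bf16": 2, "f32": 4, "f64": 8,
    "i8": 1, "i16": 2, "i32": 4, "i64": 8,
    "u8": 1, "u16": 2, "u32": 4, "u64": 8,
    "bool": 1, "complex64": 8, "complex128": 16,
}

def tensor_nbytes(shape, dtype):
    """Bytes needed for a dense tensor of the given shape and dtype."""
    n = 1
    for d in shape:
        n *= d
    return n * DTYPE_SIZES[dtype]
```

For example, a `Tensor<[32, 128], f32>` occupies 32 × 128 × 4 = 16384 bytes.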

Examples

// A 2D matrix of shape [32, 128] with float32 elements
Tensor<[32, 128], f32>

// A generic batch of images
Tensor<[Batch, 3, Height, Width], f32>

// A vector of 10 integers
Tensor<[10], i32>

Literals

You can define constant tensors using double brackets [[ ... ]]:

// 1D tensor
[[1, 2, 3]]

// 2D tensor
[[
    [1, 0],
    [0, 1]
]]

Graphs & Operations

The @graph directive defines a computation graph, similar to a function in other languages.

Anatomy of a Graph

@graph GraphName(input1: Type1) -> OutputType {
    // 1. Parameter Declarations (Weights)
    param w: Type2 { init: "start_value"; };

    // 2. Nodes (Operations)
    node x: Type3 {
        op: input1 * w;
    };

    // 3. Output Statement
    output x;
}

Nodes

Nodes represent intermediate computations. They must have an op field defining the operation.

node activation {
    op: relu(x);
};

Standard Operations

Shrew supports standard element-wise arithmetic and matrix operations:

  • +, -, *, /
  • matmul or @ operator
  • pow (**)

And common neural network functions:

  • relu(), sigmoid(), tanh()
  • softmax()
  • conv2d(), maxpool2d()
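For reference, the element-wise semantics of a few of these functions can be sketched in plain Python. These are the standard mathematical definitions, not Shrew's actual kernels:

```python
import math

def relu(xs):
    # max(0, x), applied element-wise
    return [max(0.0, x) for x in xs]

def sigmoid(xs):
    # 1 / (1 + e^-x), applied element-wise
    return [1.0 / (1.0 + math.exp(-x)) for x in xs]

def softmax(xs):
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]
```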

Attributes

Nodes can have additional attributes passed in the block:

node conv_layer {
    op: conv2d(input, weight);
    padding: 1;
    stride: 2;
};

Rust API Setup

To use Shrew in your Rust project, add the dependencies to your Cargo.toml.

Dependencies

[dependencies]
shrew = "0.1"
anyhow = "1.0" # Recommended for error handling

Feature Flags

  • cuda: Enable CUDA backend support.
  • python: Enable Python bindings.
For example, to enable the CUDA backend:

[dependencies]
shrew = { version = "0.1", features = ["cuda"] }

Running Inference

This guide shows how to load a model and run a forward pass using the Rust API.

Full Example

use shrew::prelude::*;
use anyhow::Result;

fn main() -> Result<()> {
    // 1. Load the model from a .sw file
    let model = Model::load("models/linear_regression.sw")?;

    // 2. Prepare input tensor
    // Shape: [Batch=1, Input=3]
    let input_data = vec![1.0, 2.0, 3.0];
    let input = Tensor::new(&[1, 3], input_data)
        .to_device(Device::Cpu)?;

    // 3. Run inference
    // The Input ID "x" matches the `input x:` declaration in the .sw file
    let outputs = model.forward(std::collections::HashMap::from([
        ("x", input),
    ]))?;

    // 4. Get output
    let result = outputs.get("y").expect("Output 'y' not found");
    println!("Result: {:?}", result);

    Ok(())
}

Key Components

  • Model::load: Parses and compiles the .sw file.
  • Tensor::new: Creates a tensor from a shape and a flat data vector.
  • Device: Specifies where the tensor lives (Cpu, Cuda(0), etc.).
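Tensor::new pairs a shape with a flat data vector. Assuming the usual row-major (C-order) layout, the mapping from a multi-dimensional index to a flat offset can be sketched as:

```python
def flat_index(shape, index):
    """Row-major offset of a multi-dimensional index (assumed layout)."""
    offset = 0
    for dim, i in zip(shape, index):
        assert 0 <= i < dim, "index out of bounds"
        offset = offset * dim + i
    return offset

# A [1, 3] tensor built from the flat vector [1.0, 2.0, 3.0]:
data = [1.0, 2.0, 3.0]
value = data[flat_index([1, 3], (0, 2))]  # element at row 0, column 2
```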

Python API Setup

Shrew can be installed as a Python package.

Installation

From Source

Ensure you have Rust and Cargo installed, then run in the root of the repo:

pip install .

From PyPI (Coming Soon)

pip install shrew

Verification

Check that the installation was successful:

import shrew
print(shrew.__version__)

Python Usage

The Python API provides an interface similar to the Rust API, but in a more Pythonic style.

Example

import shrew
import numpy as np

# 1. Load the model
model = shrew.load("models/linear_regression.sw")

# 2. Prepare input (using NumPy)
x = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)

# 3. Run inference
# Inputs are passed as a dictionary mapping input names to NumPy arrays
outputs = model.forward({"x": x})

# 4. Get result
y = outputs["y"]
print("Output shape:", y.shape)
print("Output data:", y)

The Shrew DSL

The Shrew DSL is the core interface for interacting with the Shrew compiler. It is designed to be expressive, concise, and readable.

This section covers the grammar specification and a set of worked examples.

Grammar Specification

The Shrew language grammar is defined below in Extended Backus-Naur Form (EBNF).

SHREW LANGUAGE GRAMMAR v0.1

TOP LEVEL
program ::= { directive | import_stmt }

directive ::= metadata_block
            | config_block
            | types_block
            | graph_block
            | custom_op_block
            | training_block
            | inference_block
            | metrics_block
            | logging_block
            | visualization_block

IMPORTS
import_stmt ::= "@import" string_literal [ "as" identifier ] ";"

METADATA
metadata_block ::= "@model" "{" { metadata_field } "}"
metadata_field ::= identifier ":" literal ";"

CONFIGURATION
config_block ::= "@config" "{" { config_field } "}"
config_field ::= identifier ":" expr ";"

TYPE SYSTEM
types_block ::= "@types" "{" { type_def } "}"
type_def    ::= "type" identifier "=" type_expr ";"

type_expr ::= tensor_type
            | scalar_type
            | tuple_type
            | list_type
            | dict_type
            | identifier
            | "?"
            | integer_literal
            | binary_expr

tensor_type    ::= "Tensor" "<" "[" dimension_list "]" "," dtype ">"
dimension_list ::= dimension { "," dimension }
dimension      ::= identifier | integer_literal | "?" | "_" | binary_expr

dtype ::= "f32" | "f64" | "f16" | "bf16"
        | "i8" | "i16" | "i32" | "i64"
        | "u8" | "u16" | "u32" | "u64"
        | "bool" | "complex64" | "complex128"

scalar_type ::= dtype
tuple_type  ::= "(" type_expr { "," type_expr } ")"
list_type   ::= "[" type_expr "]"
dict_type   ::= "{" identifier ":" type_expr { "," identifier ":" type_expr } "}"

GRAPH DEFINITION
graph_block ::= "@graph" identifier [ "(" param_list ")" [ "->" type_expr ] ] "{" graph_body "}"
param_list  ::= param_def { "," param_def }
param_def   ::= identifier ":" type_expr [ "?" ]

graph_body ::= { graph_stmt }
graph_stmt ::= input_decl
             | output_decl
             | param_decl
             | node_decl
             | assert_stmt
             | check_stmt

input_decl  ::= "input" identifier ":" type_expr [ "?" ] ";"
output_decl ::= "output" [ identifier ":" ] expr ";"

param_decl  ::= "param" identifier ":" type_expr [ param_attrs ] ";"
param_attrs ::= "{" { param_attr } "}"
param_attr  ::= "init" ":" init_expr
              | "frozen" ":" bool_literal
              | "device" ":" device_expr
              | identifier ":" literal

init_expr ::= string_literal | expr

node_decl ::= "node" identifier [ ":" type_expr ] [ node_body ] ";"
node_body ::= "{" { node_stmt } "}"
node_stmt ::= "op" ":" operation ";"
            | "input" ":" expr ";"
            | "output" ":" type_expr ";"
            | hint_directive
            | identifier ":" expr ";"

operation ::= identifier "(" [ arg_list ] ")"
            | block_operation
            | call_operation
            | binary_expr
            | unary_expr

block_operation ::= "if" expr "{" operation "}" [ "else" "{" operation "}" ]
                  | "repeat" "(" expr ")" "{" operation "}"

call_operation       ::= "call" qualified_identifier "(" [ arg_list ] ")"
qualified_identifier ::= identifier { "." identifier | "::" identifier }
arg_list             ::= arg { "," arg }
arg                  ::= [ identifier ":" ] expr

assert_stmt     ::= "@assert" expr [ "," string_literal ] ";"
check_stmt      ::= "@check" identifier "{" { check_condition } "}"
check_condition ::= "assert" expr [ "," string_literal ] ";"

hint_directive ::= "@hint" hint_type ";"
hint_type      ::= "recompute_in_backward" | "must_preserve" | "in_place" | "no_grad" | identifier

CUSTOM OPERATORS
custom_op_block ::= "@custom_op" identifier "{" { custom_op_stmt } "}"
custom_op_stmt  ::= "signature" ":" signature ";"
                  | "impl" identifier "{" { impl_attr } "}"
                  | "gradient" identifier "{" gradient_body "}"

signature     ::= "(" param_list ")" "->" type_expr
impl_attr     ::= identifier ":" expr ";"
gradient_body ::= { gradient_stmt }
gradient_stmt ::= "impl" identifier "{" { impl_attr } "}"
                | "call" operation ";"

TRAINING CONFIGURATION
training_block ::= "@training" "{" { training_stmt } "}"
training_stmt  ::= "model" ":" identifier ";"
                 | "loss" ":" identifier ";"
                 | "optimizer" ":" optimizer_config ";"
                 | "lr_schedule" ":" schedule_config ";"
                 | "grad_clip" ":" clip_config ";"
                 | "precision" ":" string_literal ";"
                 | "accumulation_steps" ":" integer_literal ";"
                 | identifier ":" expr ";"

optimizer_config ::= "{" { optimizer_attr } "}"
optimizer_attr   ::= identifier ":" expr ";"
schedule_config  ::= "{" { schedule_attr } "}"
schedule_attr    ::= identifier ":" expr ";"
clip_config      ::= "{" { clip_attr } "}"
clip_attr        ::= identifier ":" expr ";"

INFERENCE CONFIGURATION
inference_block ::= "@inference" "{" { inference_stmt } "}"
inference_stmt  ::= "model" ":" identifier ";"
                  | "optimizations" ":" list_literal ";"
                  | "quantization" ":" quant_config ";"
                  | "generation" ":" generation_config ";"
                  | identifier ":" expr ";"

quant_config      ::= "{" { quant_attr } "}"
quant_attr        ::= identifier ":" expr ";"
generation_config ::= "{" { generation_attr } "}"
generation_attr   ::= identifier ":" expr ";"

METRICS & LOGGING
metrics_block ::= "@metrics" identifier "{" { metric_def } "}"
metric_def    ::= "track" identifier "{" { metric_attr } "}"
metric_attr   ::= "source" ":" source_expr ";"
                | "compute" ":" compute_expr ";"
                | "aggregate" ":" string_literal ";"
                | "type" ":" string_literal ";"
                | "log_every" ":" integer_literal ";"
                | identifier ":" expr ";"

source_expr    ::= qualified_identifier | "[" qualified_identifier { "," qualified_identifier } "]"
compute_expr   ::= expr | "{" iteration_expr "}"
iteration_expr ::= "for" identifier "in" expr "{" expr "}"

logging_block     ::= "@logging" "{" { logging_stmt } "}"
logging_stmt      ::= "backend" ":" string_literal ";"
                    | identifier ":" config_literal ";"
                    | "checkpoints" ":" checkpoint_config ";"

checkpoint_config ::= "{" { checkpoint_attr } "}"
checkpoint_attr   ::= identifier ":" expr ";"

visualization_block ::= "@visualizations" "{" { visualization_def } "}"
visualization_def   ::= "plot" identifier "{" { plot_attr } "}"
plot_attr           ::= identifier ":" expr ";"

EXPRESSIONS
expr ::= binary_expr | unary_expr | primary_expr

binary_expr ::= expr binary_op expr
binary_op   ::= "+" | "-" | "*" | "/" | "%" | "**"
              | "==" | "!=" | "<" | ">" | "<=" | ">="
              | "&&" | "||" | "??"
              | "&" | "|" | "^" | "<<" | ">>"

unary_expr ::= unary_op expr
unary_op   ::= "-" | "!" | "~"

primary_expr ::= identifier
               | literal
               | function_call
               | member_access
               | index_access
               | range_expr
               | list_expr
               | dict_expr
               | paren_expr
               | tensor_literal

function_call  ::= identifier "(" [ arg_list ] ")"
member_access  ::= expr "." identifier
index_access   ::= expr "[" expr [ ":" expr ] "]"
range_expr     ::= "range" "(" expr [ "," expr [ "," expr ] ] ")"
list_expr      ::= "[" [ expr { "," expr } ] "]"
dict_expr      ::= "{" [ dict_entry { "," dict_entry } ] "}"
dict_entry     ::= identifier ":" expr
paren_expr     ::= "(" expr ")"
tensor_literal ::= "[[" expr { "," expr } "]]"

LITERALS
literal ::= integer_literal
          | float_literal
          | bool_literal
          | string_literal
          | list_literal
          | config_literal
          | "null"

integer_literal ::= digit { digit }
float_literal   ::= digit { digit } "." digit { digit } [ exponent ]
                  | digit { digit } exponent
exponent        ::= ("e" | "E") [ "+" | "-" ] digit { digit }
bool_literal    ::= "true" | "false"
string_literal  ::= '"' { string_char } '"'
list_literal    ::= "[" [ literal { "," literal } ] "]"
config_literal  ::= "{" { identifier ":" literal { "," identifier ":" literal } } "}"

DEVICE EXPRESSIONS
device_expr ::= "cpu"
              | "gpu" [ ":" integer_literal ]
              | "tpu" [ ":" integer_literal ]
              | identifier

IDENTIFIERS
identifier ::= letter { letter | digit | "_" }
letter     ::= "a" .. "z" | "A" .. "Z"
digit      ::= "0" .. "9"
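A toy tokenizer for the lexical rules above (identifiers, integers, floats) can be sketched in Python. This only illustrates the grammar's lexical layer; it is not the actual Shrew lexer:

```python
import re

# Regexes derived from the identifier / number rules above.
# Order matters: try FLOAT before INT so "0.01" isn't split at the dot.
TOKEN_SPEC = [
    ("FLOAT", r"\d+\.\d+(?:[eE][+-]?\d+)?|\d+[eE][+-]?\d+"),
    ("INT", r"\d+"),
    ("IDENT", r"[A-Za-z][A-Za-z0-9_]*"),
    ("SKIP", r"\s+"),
]
TOKEN_RE = re.compile("|".join(f"(?P<{n}>{p})" for n, p in TOKEN_SPEC))

def tokenize(text):
    """Split text into (kind, lexeme) pairs, dropping whitespace."""
    return [(m.lastgroup, m.group())
            for m in TOKEN_RE.finditer(text)
            if m.lastgroup != "SKIP"]
```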

Examples

Simple Linear Regression

@model {
    name: "LinearRegression";
    version: "1.0";
}

@config {
    batch_size: 32;
    learning_rate: 0.01;
}

@graph Forward(x: Tensor<[Batch, In], f32>) -> Tensor<[Batch, Out], f32> {
    param w: Tensor<[In, Out], f32> { init: "normal(0, 0.02)"; };
    param b: Tensor<[Out], f32> { init: "zeros"; };
    
    node y: Tensor<[Batch, Out], f32> {
        op: x @ w + b;
    };
    
    output y;
}

Matrix Multiplication

@graph MatMul(A: Tensor<[M, K], f32>, B: Tensor<[K, N], f32>) {
    node C {
        op: A @ B; 
    };
    output C;
}
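The @ operator in this example denotes an ordinary matrix product. A pure-Python reference of its semantics (illustrative only, not Shrew's kernel):

```python
def matmul(a, b):
    """C[m][n] = sum_k A[m][k] * B[k][n] — reference for the @ op."""
    M, K = len(a), len(a[0])
    K2, N = len(b), len(b[0])
    assert K == K2, "inner dimensions must agree"
    return [[sum(a[m][k] * b[k][n] for k in range(K)) for n in range(N)]
            for m in range(M)]
```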

Architecture Overview

Shrew is built as a modular set of Rust crates. The architecture is designed to separate the frontend language (DSL) from the backend execution engines, with a middle layer for Intermediate Representation (IR) and optimization.

High-Level Flow

  1. Frontend: The shrew-ir crate parses the .sw source code into an Abstract Syntax Tree (AST) and then lowers it into a High-Level IR (Graph).
  2. Optimization: The shrew-optim crate applies graph transformations (e.g., constant folding, operator fusion) to the IR.
  3. Execution: The shrew-core crate orchestrates execution, dispatching tensor operations to specific backends (shrew-cpu, shrew-cuda).
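To give a flavor of the optimization stage, here is a toy constant-folding pass over a tiny expression tree. This is a sketch only; the real shrew-optim passes operate on the Graph IR, not on tuples like these:

```python
# Expressions: ("const", v) | ("var", name) | ("add"|"mul", lhs, rhs)
def fold(expr):
    """Recursively replace operations on two constants with their result."""
    if expr[0] in ("const", "var"):
        return expr
    op, lhs, rhs = expr
    lhs, rhs = fold(lhs), fold(rhs)
    if lhs[0] == "const" and rhs[0] == "const":
        v = lhs[1] + rhs[1] if op == "add" else lhs[1] * rhs[1]
        return ("const", v)
    return (op, lhs, rhs)
```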

Diagram

graph TD
    Source[.sw File] --> Parser[shrew-ir: Parser]
    Parser --> AST[AST]
    AST --> Lower[shrew-ir: Lowering]
    Lower --> IR[Graph IR]
    IR --> Optim[shrew-optim]
    Optim --> Exec[shrew-core: Executor]
    Exec --> CPU[shrew-cpu]
    Exec --> CUDA[shrew-cuda]

Crate Structure

The project is organized as a Cargo workspace with the following members:

  • shrew: The top-level crate that re-exports functionality and provides the main entry point (CLI via shrew-cli).
  • shrew-core: The runtime core. Handles tensor storage, device management, and graph execution APIs.
  • shrew-ir: Contains the parser, AST definitions, and the Intermediate Representation (IR) logic.
  • shrew-optim: Optimization passes for the IR (graph rewriting, fusion).
  • shrew-nn: Implementation of neural network layers and common operators.
  • shrew-data: Data loading and preprocessing utilities.
  • shrew-cpu: CPU backend implementation (using Rayon and SIMD where available).
  • shrew-cuda: CUDA backend implementation (interfaces with cuBLAS, cuDNN).
  • shrew-cli: Command-line interface tool.
  • shrew-python: Python bindings (PyO3) for using Shrew from Python.

IR & Lowering

The Intermediate Representation (IR) is a directed acyclic graph (DAG) where nodes represent operations and edges represent data dependencies (tensors).

Lowering Process

The lowering phase converts the AST (which mirrors the source syntax) into the Graph IR. Key steps include:

  1. Symbol Resolution: Linking identifiers to their definitions.
  2. Type Checking: Verifying tensor shapes and data types.
  3. Graph Construction: Building the node connectivity.

The IR is defined in crates/shrew-ir/src/graph.rs.
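The shape of such a graph and its evaluation order can be sketched with a tiny DAG executor. The node representation below is hypothetical, chosen for illustration; the actual types live in crates/shrew-ir/src/graph.rs:

```python
# Each node: name -> (op, [input names]); the input lists are the edges.
def run_graph(nodes, inputs, output):
    """Evaluate a DAG of ops in dependency order via memoized DFS."""
    cache = dict(inputs)
    def eval_node(name):
        if name in cache:
            return cache[name]
        op, deps = nodes[name]
        args = [eval_node(d) for d in deps]
        cache[name] = op(*args)
        return cache[name]
    return eval_node(output)

# y = x * w; z = y + b  (scalar stand-ins for tensors)
graph = {
    "y": (lambda a, b: a * b, ["x", "w"]),
    "z": (lambda a, b: a + b, ["y", "b"]),
}
result = run_graph(graph, {"x": 3.0, "w": 2.0, "b": 1.0}, "z")
```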

Execution Engine

The execution engine in shrew-core takes an optimized graph and runs it.

Tensors & Storage

Tensors are backed by a storage enum that can hold data on different devices:

pub enum Storage {
    Cpu(Vec<f32>),
    Cuda(CudaSlice<f32>),
    // ...
}

Backend Dispatch

Operations are dispatched dynamically based on the tensor’s device. The Backend trait defines the interface for all supported operations (matmul, add, relu, etc.).
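Dispatching on the tensor's device can be pictured as a registry keyed by (device, op). This is an illustration of the idea in Python, not the actual Backend trait:

```python
KERNELS = {}

def register(device, op):
    """Decorator that records a kernel under its (device, op) key."""
    def deco(fn):
        KERNELS[(device, op)] = fn
        return fn
    return deco

@register("cpu", "add")
def cpu_add(a, b):
    return [x + y for x, y in zip(a, b)]

def dispatch(device, op, *args):
    """Look up and run the kernel for this device, or fail loudly."""
    try:
        kernel = KERNELS[(device, op)]
    except KeyError:
        raise NotImplementedError(f"{op} not implemented on {device}")
    return kernel(*args)
```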