shrew/
lib.rs

//! # Shrew
//!
//! A deep learning library built from scratch in Rust.
//!
//! This is the top-level facade crate that re-exports everything you need.
//!
//! ## Usage
//!
//! ```rust
//! use shrew::prelude::*;
//! ```
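//!
//! A slightly fuller sketch of a single training step is shown below. It is
//! illustrative only: constructor and method names such as `Tensor::randn`,
//! `Sequential::new`, `Linear::new`, `model.forward`, `loss.backward`, and
//! `SGD::new` are assumptions about the API surface rather than a verbatim
//! reference, so the block is not compiled as a doctest.
//!
//! ```rust,ignore
//! use shrew::prelude::*;
//!
//! // Hypothetical: a tiny regression model trained for one step on the CPU backend.
//! let x = Tensor::randn(&[64, 10])?;                 // assumed constructor
//! let y = Tensor::randn(&[64, 1])?;
//!
//! let model = Sequential::new()
//!     .add(Linear::new(10, 32)?)                     // assumed builder-style API
//!     .add(ReLU)
//!     .add(Linear::new(32, 1)?);
//!
//! let mut opt = SGD::new(model.parameters(), 1e-2);  // assumed signature
//!
//! let pred = model.forward(&x)?;
//! let loss = mse_loss(&pred, &y)?;
//! loss.backward()?;                                  // autograd from `shrew-core`
//! opt.step()?;
//! ```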
//!
//! ## Architecture
//!
//! | Crate | Purpose |
//! |-------|----------|
//! | `shrew-core` | Tensor, Shape, DType, Layout, Backend trait, Autograd |
//! | `shrew-cpu` | CPU backend with SIMD matmul and rayon parallelism |
//! | `shrew-nn` | Neural network layers (Linear, Conv2d, RNN, LSTM, Transformer, etc.) |
//! | `shrew-optim` | Optimizers (SGD, Adam, AdamW, RAdam, RMSProp), LR schedulers, EMA |
//! | `shrew-data` | Dataset, DataLoader, MNIST, transforms |
//! | `shrew-cuda` | CUDA GPU backend (feature-gated) |
//! | `shrew-ir` | `.sw` IR format: lexer, parser, AST, Graph IR |
//!
//! ## Modules
//!
//! - [`distributed`] — DataParallel, MixedPrecisionTrainer, PipelineParallel
//! - [`quantize`] — INT8/INT4 post-training quantization
//! - [`onnx`] — ONNX import/export
//! - [`profiler`] — Timing, memory tracking, model benchmarks
//! - [`exec`] — Graph executor for `.sw` programs
//! - [`checkpoint`] — Save/load model parameters
//! - [`safetensors`] — HuggingFace-compatible serialization

/// Re-export core types.
pub use shrew_core::{
    backend::{Backend, BackendDevice, BackendStorage, BinaryOp, CmpOp, ReduceOp, UnaryOp},
    op::{Op, TensorId},
    DType, Error, GradStore, Layout, Result, Shape, Tensor, WithDType,
};

/// Re-export CPU backend.
pub use shrew_cpu::{CpuBackend, CpuDevice, CpuStorage, CpuTensor};

/// Re-export CUDA backend (requires `cuda` feature + NVIDIA CUDA Toolkit).
#[cfg(feature = "cuda")]
pub use shrew_cuda::{CudaBackend, CudaDevice, CudaStorage, CudaTensor};

/// Re-export neural network modules.
pub mod nn {
    pub use shrew_nn::*;
}

/// Re-export optimizers.
pub mod optim {
    pub use shrew_optim::*;
}

/// Re-export the .sw IR parser and AST.
pub mod ir {
    pub use shrew_ir::*;
}

/// Graph executor — runs .sw programs on the tensor runtime.
pub mod exec;

/// Checkpoint — save and load model parameters.
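///
/// A minimal sketch of the intended save/load round trip. The constructor and
/// method names (`TrainingCheckpoint::new`, `save`, `load`) are assumptions
/// for illustration, so the block is not compiled as a doctest; only the
/// `TrainingCheckpoint` type itself comes from the prelude re-exports.
///
/// ```rust,ignore
/// use shrew::prelude::*;
///
/// // Hypothetical: persist model/optimizer state after an epoch, then restore it.
/// let ckpt = TrainingCheckpoint::new(&model, &opt, epoch);  // assumed constructor
/// ckpt.save("model.ckpt")?;                                 // assumed method
/// let restored = TrainingCheckpoint::load("model.ckpt")?;   // assumed method
/// ```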
pub mod checkpoint;

/// Safetensors — interoperable tensor serialization (HuggingFace format).
pub mod safetensors;

/// Distributed training — DataParallel, MixedPrecision, Pipeline, gradient sync.
pub mod distributed;

/// Quantization — INT8/INT4 post-training quantization for inference.
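///
/// A minimal sketch using names re-exported through the prelude
/// (`QuantConfig`, `QuantBits`, `quantize_tensor`, `dequantize_tensor`). The
/// field names and call signatures are assumptions for illustration, so the
/// block is not compiled as a doctest.
///
/// ```rust,ignore
/// use shrew::prelude::*;
///
/// // Hypothetical: round-trip one weight tensor through INT8 quantization.
/// let cfg = QuantConfig { bits: QuantBits::Int8, ..Default::default() }; // assumed fields
/// let q: QuantizedTensor = quantize_tensor(&weights, &cfg)?;             // assumed signature
/// let deq = dequantize_tensor(&q)?;                                      // assumed signature
/// ```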
pub mod quantize;

/// ONNX — Import/Export for interoperability with other frameworks.
pub mod onnx;

/// Profiling & Benchmarking — op-level timing, memory tracking, model summaries.
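///
/// A minimal sketch using names re-exported through the prelude (`Stopwatch`,
/// `benchmark_forward`, `estimate_model_memory`, `format_bytes`). The call
/// signatures are assumptions for illustration, so the block is not compiled
/// as a doctest.
///
/// ```rust,ignore
/// use shrew::prelude::*;
///
/// // Hypothetical: time a forward pass and report estimated parameter memory.
/// let sw = Stopwatch::start();                         // assumed constructor
/// let result = benchmark_forward(&model, &input, 10);  // assumed signature
/// println!("elapsed: {:?}", sw.elapsed());             // assumed method
/// println!("memory: {}", format_bytes(estimate_model_memory(&model))); // assumed signatures
/// ```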
pub mod profiler;

/// Prelude: import this for the most common types.
pub mod prelude {
    pub use crate::checkpoint::TrainingCheckpoint;
    pub use crate::distributed::{
        reduce_gradients, AllReduceOp, DataParallel, LossScaleConfig, MixedPrecisionTrainer,
        ParallelTrainer, PipelineParallel, PipelineStage,
    };
    pub use crate::exec::{CompileStats, JitExecutor, JitResult};
    pub use crate::exec::{Executor, RuntimeConfig, Trainer};
    pub use crate::nn::{
        bce_loss, bce_with_logits_loss, cross_entropy_loss, l1_loss, mse_loss, nll_loss,
        smooth_l1_loss,
    };
    pub use crate::nn::{
        AdaptiveAvgPool2d, AvgPool2d, BatchNorm2d, Conv1d, Conv2d, Dropout, Flatten, GRUCell, GeLU,
        GroupNorm, LSTMCell, LayerNorm, LeakyReLU, Linear, MaxPool2d, Mish, Module,
        MultiHeadAttention, RMSNorm, RNNCell, ReLU, Sequential, SiLU, TransformerBlock, ELU, GRU,
        LSTM, RNN,
    };
    pub use crate::onnx::{
        export_tensors as export_onnx_tensors, export_weights as export_onnx, load_onnx_weights,
        OnnxAttribute, OnnxModel, OnnxNode, OnnxTensor,
    };
    pub use crate::optim::EMA;
    pub use crate::optim::{clip_grad_norm, clip_grad_value, grad_norm, GradAccumulator};
    pub use crate::optim::{Adam, AdamW, Optimizer, OptimizerState, RAdam, RMSProp, Stateful, SGD};
    pub use crate::optim::{
        CosineAnnealingLR, CosineWarmupLR, ExponentialLR, LinearLR, LrScheduler, ReduceLROnPlateau,
        StepLR,
    };
    pub use crate::profiler::{
        benchmark_forward, benchmark_forward_backward, estimate_model_memory, format_bytes,
        BenchmarkResult, MemoryTracker, ModelSummary, ProfileEntry, ProfileEvent, ProfileReport,
        Profiler, ScopedTimer, Stopwatch,
    };
    pub use crate::quantize::{
        dequantize_tensor, quantization_stats, quantize_named_parameters, quantize_tensor,
        QuantBits, QuantConfig, QuantGranularity, QuantMode, QuantStats, QuantizedLinear,
        QuantizedTensor,
    };
    pub use crate::{CpuBackend, CpuDevice, CpuTensor, DType, GradStore, Shape, Tensor};
}