1pub use shrew_core::{
37 backend::{Backend, BackendDevice, BackendStorage, BinaryOp, CmpOp, ReduceOp, UnaryOp},
38 op::{Op, TensorId},
39 DType, Error, GradStore, Layout, Result, Shape, Tensor, WithDType,
40};
41
42pub use shrew_cpu::{CpuBackend, CpuDevice, CpuStorage, CpuTensor};
44
45#[cfg(feature = "cuda")]
47pub use shrew_cuda::{CudaBackend, CudaDevice, CudaStorage, CudaTensor};
48
/// Neural-network building blocks, re-exported wholesale from the `shrew_nn` crate.
///
/// Everything public in `shrew_nn` is reachable as `nn::...`; the most common layers and
/// loss functions (e.g. `Linear`, `Conv2d`, `cross_entropy_loss`) are also in [`prelude`].
pub mod nn {
    // Glob re-export: the public surface of this module tracks `shrew_nn` exactly.
    pub use shrew_nn::*;
}
53
/// Optimizers, learning-rate schedulers and gradient utilities, re-exported wholesale from
/// the `shrew_optim` crate.
///
/// Everything public in `shrew_optim` is reachable as `optim::...`; the common items
/// (e.g. `Adam`, `SGD`, `StepLR`, `clip_grad_norm`) are also in [`prelude`].
pub mod optim {
    // Glob re-export: the public surface of this module tracks `shrew_optim` exactly.
    pub use shrew_optim::*;
}
58
/// Public API of the `shrew_ir` crate, re-exported wholesale.
///
/// NOTE(review): presumably the intermediate representation consumed by [`exec`] — confirm
/// against `shrew_ir` before documenting further. Nothing from this module is in the prelude.
pub mod ir {
    // Glob re-export: the public surface of this module tracks `shrew_ir` exactly.
    pub use shrew_ir::*;
}
63
/// Executors, runtime configuration and the trainer (`Executor`, `JitExecutor`,
/// `JitResult`, `CompileStats`, `RuntimeConfig`, `Trainer`).
pub mod exec;

/// Training-state checkpoints (`TrainingCheckpoint`).
pub mod checkpoint;

/// NOTE(review): presumably support for the safetensors file format — confirm against the
/// module; none of its items are re-exported in the prelude.
pub mod safetensors;

/// Multi-device / multi-stage training helpers (`DataParallel`, `PipelineParallel`,
/// `PipelineStage`, `ParallelTrainer`, `MixedPrecisionTrainer`, `reduce_gradients`, ...).
pub mod distributed;

/// Tensor and layer quantization (`quantize_tensor`, `dequantize_tensor`, `QuantConfig`,
/// `QuantizedLinear`, `QuantizedTensor`, ...).
pub mod quantize;

/// ONNX weight import/export and model types (`export_weights`, `export_tensors`,
/// `load_onnx_weights`, `OnnxModel`, ...).
pub mod onnx;

/// Profiling, benchmarking and memory-estimation utilities (`Profiler`, `MemoryTracker`,
/// `benchmark_forward`, `estimate_model_memory`, ...).
pub mod profiler;
84
85pub mod prelude {
87 pub use crate::checkpoint::TrainingCheckpoint;
88 pub use crate::distributed::{
89 reduce_gradients, AllReduceOp, DataParallel, LossScaleConfig, MixedPrecisionTrainer,
90 ParallelTrainer, PipelineParallel, PipelineStage,
91 };
92 pub use crate::exec::{CompileStats, JitExecutor, JitResult};
93 pub use crate::exec::{Executor, RuntimeConfig, Trainer};
94 pub use crate::nn::{
95 bce_loss, bce_with_logits_loss, cross_entropy_loss, l1_loss, mse_loss, nll_loss,
96 smooth_l1_loss,
97 };
98 pub use crate::nn::{
99 AdaptiveAvgPool2d, AvgPool2d, BatchNorm2d, Conv1d, Conv2d, Dropout, Flatten, GRUCell, GeLU,
100 GroupNorm, LSTMCell, LayerNorm, LeakyReLU, Linear, MaxPool2d, Mish, Module,
101 MultiHeadAttention, RMSNorm, RNNCell, ReLU, Sequential, SiLU, TransformerBlock, ELU, GRU,
102 LSTM, RNN,
103 };
104 pub use crate::onnx::{
105 export_tensors as export_onnx_tensors, export_weights as export_onnx, load_onnx_weights,
106 OnnxAttribute, OnnxModel, OnnxNode, OnnxTensor,
107 };
108 pub use crate::optim::EMA;
109 pub use crate::optim::{clip_grad_norm, clip_grad_value, grad_norm, GradAccumulator};
110 pub use crate::optim::{Adam, AdamW, Optimizer, OptimizerState, RAdam, RMSProp, Stateful, SGD};
111 pub use crate::optim::{
112 CosineAnnealingLR, CosineWarmupLR, ExponentialLR, LinearLR, LrScheduler, ReduceLROnPlateau,
113 StepLR,
114 };
115 pub use crate::profiler::{
116 benchmark_forward, benchmark_forward_backward, estimate_model_memory, format_bytes,
117 BenchmarkResult, MemoryTracker, ModelSummary, ProfileEntry, ProfileEvent, ProfileReport,
118 Profiler, ScopedTimer, Stopwatch,
119 };
120 pub use crate::quantize::{
121 dequantize_tensor, quantization_stats, quantize_named_parameters, quantize_tensor,
122 QuantBits, QuantConfig, QuantGranularity, QuantMode, QuantStats, QuantizedLinear,
123 QuantizedTensor,
124 };
125 pub use crate::{CpuBackend, CpuDevice, CpuTensor, DType, GradStore, Shape, Tensor};
126}