Prelude: import this for the most common types.
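A minimal import sketch (the crate name `mynn` below is a placeholder, not the crate's real name):

```rust
// Hypothetical crate name; substitute the actual crate when using this.
use mynn::prelude::*;

// With the prelude in scope, common items such as Tensor, Linear, Adam,
// and mse_loss are available without individual `use` statements.
```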
Re-exports§
pub use crate::checkpoint::TrainingCheckpoint;
pub use crate::distributed::reduce_gradients;
pub use crate::distributed::AllReduceOp;
pub use crate::distributed::DataParallel;
pub use crate::distributed::LossScaleConfig;
pub use crate::distributed::MixedPrecisionTrainer;
pub use crate::distributed::ParallelTrainer;
pub use crate::distributed::PipelineParallel;
pub use crate::distributed::PipelineStage;
pub use crate::exec::CompileStats;
pub use crate::exec::JitExecutor;
pub use crate::exec::JitResult;
pub use crate::exec::Executor;
pub use crate::exec::RuntimeConfig;
pub use crate::exec::Trainer;
pub use crate::onnx::export_tensors as export_onnx_tensors;
pub use crate::onnx::export_weights as export_onnx;
pub use crate::onnx::load_onnx_weights;
pub use crate::onnx::OnnxAttribute;
pub use crate::onnx::OnnxModel;
pub use crate::onnx::OnnxNode;
pub use crate::onnx::OnnxTensor;
pub use crate::profiler::benchmark_forward;
pub use crate::profiler::benchmark_forward_backward;
pub use crate::profiler::estimate_model_memory;
pub use crate::profiler::format_bytes;
pub use crate::profiler::BenchmarkResult;
pub use crate::profiler::MemoryTracker;
pub use crate::profiler::ModelSummary;
pub use crate::profiler::ProfileEntry;
pub use crate::profiler::ProfileEvent;
pub use crate::profiler::ProfileReport;
pub use crate::profiler::Profiler;
pub use crate::profiler::ScopedTimer;
pub use crate::profiler::Stopwatch;
pub use crate::quantize::dequantize_tensor;
pub use crate::quantize::quantization_stats;
pub use crate::quantize::quantize_named_parameters;
pub use crate::quantize::quantize_tensor;
pub use crate::quantize::QuantBits;
pub use crate::quantize::QuantConfig;
pub use crate::quantize::QuantGranularity;
pub use crate::quantize::QuantMode;
pub use crate::quantize::QuantStats;
pub use crate::quantize::QuantizedLinear;
pub use crate::quantize::QuantizedTensor;
Structs§
- Adam - Adam optimizer (Adaptive Moment Estimation).
- AdamW - AdamW optimizer (Adam with decoupled weight decay).
- AdaptiveAvgPool2d - Adaptive 2D average pooling: pools to a fixed output size.
- AvgPool2d - 2D average-pooling layer.
- BatchNorm2d - 2D Batch Normalization layer for convolutional networks.
- Conv1d - 1D convolution layer.
- Conv2d - 2D convolutional layer.
- CosineAnnealingLR - Cosine annealing from `initial_lr` to `min_lr` over `total_steps`.
- CosineWarmupLR - Linear warmup from 0 to `initial_lr` over `warmup_steps`, then cosine decay from `initial_lr` to `min_lr` over the remaining steps.
- CpuBackend - Re-exported CPU backend. Implements `Backend` by running operations on the CPU via iterators.
- CpuDevice - Re-exported from the CPU backend. The CPU device; since every machine has exactly one CPU (from our perspective), this is a zero-sized type.
- Dropout - Applies dropout regularization.
- ELU - ELU activation: `x` if `x > 0`, `alpha * (exp(x) - 1)` otherwise.
- EMA - Exponential Moving Average of model parameters.
- ExponentialLR - Multiply the learning rate by `gamma` every step.
- Flatten - Flatten layer: collapses dimensions `[start_dim..=end_dim]` into one.
- GRU - A multi-step GRU that unrolls a `GRUCell` over the sequence dimension.
- GRUCell - A single-step GRU cell.
- GeLU - GELU activation (Gaussian Error Linear Unit), used in Transformers (BERT, GPT, etc.).
- GradAccumulator - Gradient accumulation helper.
- GradStore - Re-exported core type. Stores gradients for all tensors in a computation graph.
- GroupNorm - Group Normalization layer.
- LSTM - A multi-step LSTM that unrolls an `LSTMCell` over the sequence dimension.
- LSTMCell - A single-step LSTM cell.
- LayerNorm - Layer Normalization: normalizes over the last N dimensions.
- LeakyReLU - LeakyReLU activation: `max(negative_slope * x, x)`.
- Linear - A fully-connected (dense) layer: `y = xW^T + b`.
- LinearLR - Linearly interpolate the learning rate from `start_factor * initial_lr` to `end_factor * initial_lr` over `total_steps` steps.
- MaxPool2d - 2D max-pooling layer.
- Mish - Mish activation: `x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))`.
- MultiHeadAttention - Multi-Head Self-Attention module.
- OptimizerState - A serializable snapshot of an optimizer's internal state.
- RAdam - Rectified Adam (RAdam) optimizer.
- RMSNorm - RMS Normalization layer (used in LLaMA, Mistral, etc.).
- RMSProp - RMSProp optimizer.
- RNN - A multi-step vanilla RNN that unrolls an `RNNCell` over the sequence dimension.
- RNNCell - A single-step vanilla RNN cell.
- ReLU - ReLU activation: `max(0, x)`.
- ReduceLROnPlateau - Reduce the learning rate when a monitored metric plateaus.
- SGD - Stochastic Gradient Descent optimizer with optional momentum.
- Sequential - A container that chains modules sequentially (see the sketch after this list).
- Shape - Re-exported core type. N-dimensional shape of a tensor.
- SiLU - SiLU / Swish activation: `x * σ(x)`. Used in modern architectures (EfficientNet, LLaMA, etc.).
- StepLR - Multiply the learning rate by `gamma` every `step_size` steps.
- Tensor - Re-exported core type. An n-dimensional array of numbers on a specific backend.
- TransformerBlock - A single Transformer block (pre-norm style).
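The layers, optimizers, and schedulers above are meant to be composed. The sketch below shows one plausible way to assemble a small model; the constructor signatures (`Linear::new(in, out)`, `Sequential::new(Vec<Box<dyn Module>>)`) and the unit-struct `ReLU` are assumptions for illustration, not the crate's confirmed API.

```rust
use mynn::prelude::*; // hypothetical crate name

// A small MLP built from re-exported layer types. All constructor
// signatures here are assumed; check the individual struct docs.
fn build_model() -> Sequential {
    let layers: Vec<Box<dyn Module>> = vec![
        Box::new(Linear::new(784, 128)), // assumed Linear::new(in_features, out_features)
        Box::new(ReLU),                  // assumed to be a unit struct
        Box::new(Linear::new(128, 10)),
    ];
    Sequential::new(layers)              // assumed Sequential::new(Vec<Box<dyn Module>>)
}
```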
Enums§
- DType - Re-exported core type. Enum of all supported element data types.
Traits§
- LrScheduler - Trait for learning rate schedulers.
- Module - The fundamental trait for all neural network layers (see the sketch after this list).
- Optimizer - Trait that all optimizers implement.
- Stateful - Trait for optimizers that can save and restore their internal state.
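A rough sketch of how these traits fit together in a training step, assuming `Module` exposes a `forward(&Tensor) -> Tensor` method and the loss functions take tensor references; the backward and optimizer/scheduler calls are left as comments because their exact signatures are not shown in this listing.

```rust
use mynn::prelude::*; // hypothetical crate name

// One training step expressed against the Module trait. `forward` and the
// mse_loss signature are assumptions for illustration.
fn train_step<M: Module>(model: &M, x: &Tensor, y: &Tensor) -> Tensor {
    let pred = model.forward(x); // assumed Module method
    mse_loss(&pred, y)           // re-exported loss function
    // A full loop would follow with the backward pass, Optimizer::step(),
    // and LrScheduler updates, using the crate's actual APIs.
}
```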
Functions§
- bce_loss - Binary Cross-Entropy loss for probabilities in `[0, 1]`.
- bce_with_logits_loss - Binary Cross-Entropy with Logits (numerically stable).
- clip_grad_norm - Clip gradients by their global L2 norm (see the sketch after this list).
- clip_grad_value - Clamp each gradient element to `[-max_value, max_value]`.
- cross_entropy_loss - Cross-entropy loss with log-softmax for numerical stability.
- grad_norm - Compute the global L2 norm of all gradients without clipping.
- l1_loss - L1 Loss (Mean Absolute Error): `mean(|prediction - target|)`.
- mse_loss - Mean Squared Error loss: `mean((prediction - target)²)`.
- nll_loss - Negative Log-Likelihood Loss with integer class indices.
- smooth_l1_loss - Smooth L1 Loss (Huber Loss).
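A hedged sketch of gradient clipping before an optimizer step, assuming `grad_norm` and `clip_grad_norm` operate on a `GradStore` and an `f32` threshold; the real signatures may differ.

```rust
use mynn::prelude::*; // hypothetical crate name

// Clip gradients only when the global L2 norm exceeds a threshold.
// The GradStore-based signatures are assumptions for illustration.
fn clip_before_step(grads: &mut GradStore) {
    let total = grad_norm(grads);   // global L2 norm across all gradients
    if total > 1.0 {
        clip_grad_norm(grads, 1.0); // rescale so the global norm is at most 1.0
    }
}
```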
Type Aliases§
- CpuTensor - Re-exported from the CPU backend. Convenience type alias for CPU tensors.
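Closing with a tiny sketch of the alias in a function signature; the assumption (not confirmed above) is that `CpuTensor` is `Tensor` specialised to `CpuBackend` and exposes a `shape()` accessor whose result can be Debug-printed.

```rust
use mynn::prelude::*; // hypothetical crate name

// Using the alias keeps CPU-only helper signatures short.
fn describe(t: &CpuTensor) -> String {
    // `shape()` returning a Debug-printable Shape is an assumption.
    format!("CPU tensor with shape {:?}", t.shape())
}
```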