Re-export neural network modules.
Modules§
- activation
- attention
- batchnorm
- conv
- dropout
- embedding
- flatten
- groupnorm
- init
- layernorm
- linear
- loss
- metrics
- module
- rmsnorm
- rnn
- sequential
- transformer
Structs§
- AdaptiveAvgPool2d - Adaptive 2D Average Pooling — pools to a fixed output size.
- AvgPool2d - 2D average-pooling layer.
- BatchNorm2d - 2D Batch Normalization layer for convolutional networks.
- ClassMetrics - Per-class metric report entry.
- ConfusionMatrix - NxN confusion matrix. Entry [i][j] = count of samples with true class i predicted as class j.
- Conv1d - 1D convolution layer.
- Conv2d - 2D convolutional layer.
- Dropout - Applies dropout regularization.
- ELU - ELU activation: x if x > 0, alpha * (exp(x) - 1) otherwise.
- Embedding - A learnable lookup table mapping integer indices to dense vectors.
- Flatten - Flatten layer: collapses dimensions [start_dim..=end_dim] into one.
- GRU - A multi-step GRU that unrolls a GRUCell over the sequence dimension.
- GRUCell - A single-step GRU cell.
- GeLU - GELU activation (Gaussian Error Linear Unit). Used in Transformers (BERT, GPT, etc.).
- GroupNorm - Group Normalization layer.
- LSTM - A multi-step LSTM that unrolls an LSTMCell over the sequence dimension.
- LSTMCell - A single-step LSTM cell.
- LayerNorm - Layer Normalization: normalizes over the last N dimensions.
- LeakyReLU - LeakyReLU activation: max(negative_slope * x, x).
- Linear - A fully-connected (dense) layer: y = xW^T + b.
- MaxPool2d - 2D max-pooling layer.
- Mish - Mish activation: x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))).
- MultiHeadAttention - Multi-Head Self-Attention module.
- RMSNorm - RMS Normalization layer (used in LLaMA, Mistral, etc.).
- RNN - A multi-step vanilla RNN that unrolls an RNNCell over the sequence dimension.
- RNNCell - A single-step vanilla RNN cell.
- ReLU - ReLU activation: max(0, x).
- Sequential - A container that chains modules sequentially.
- SiLU - SiLU / Swish activation: x * σ(x). Used in modern architectures (EfficientNet, LLaMA, etc.).
- Sigmoid - Sigmoid activation: 1 / (1 + e^(-x)).
- Tanh - Tanh activation.
- TransformerBlock - A single Transformer block (pre-norm style).
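Several of the activation structs above state their formulas directly in their summaries. As a standalone illustration of those formulas only (plain Rust on f64 scalars, not this crate's tensor API; the alpha and negative_slope parameter names simply follow the summaries), a minimal sketch:

```rust
// Scalar versions of the activation formulas listed above, for illustration.
// The crate's own structs operate on tensors and implement the Module trait.

fn relu(x: f64) -> f64 {
    x.max(0.0) // max(0, x)
}

fn leaky_relu(x: f64, negative_slope: f64) -> f64 {
    // max(negative_slope * x, x), assuming negative_slope < 1
    if x > 0.0 { x } else { negative_slope * x }
}

fn elu(x: f64, alpha: f64) -> f64 {
    // x if x > 0, alpha * (exp(x) - 1) otherwise
    if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }
}

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp()) // 1 / (1 + e^(-x))
}

fn silu(x: f64) -> f64 {
    x * sigmoid(x) // x * σ(x)
}

fn mish(x: f64) -> f64 {
    x * x.exp().ln_1p().tanh() // x * tanh(ln(1 + exp(x)))
}

fn main() {
    for &x in &[-2.0, -0.5, 0.0, 0.5, 2.0] {
        println!(
            "x = {x:5.1}: relu {:7.4}, leaky {:7.4}, elu {:7.4}, silu {:7.4}, mish {:7.4}",
            relu(x),
            leaky_relu(x, 0.01),
            elu(x, 1.0),
            silu(x),
            mish(x),
        );
    }
}
```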
Enums§
- Average - How to average per-class metrics for multi-class problems.
- Reduction - Reduction mode for loss functions.
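The exact variants of Reduction aren't shown in this index; the sketch below assumes the conventional Mean / Sum / None trio and uses plain Rust types rather than this crate's, purely to illustrate what a reduction mode does to per-element losses:

```rust
// Hypothetical reduction modes; the crate's actual Reduction variants may differ.
enum Reduction {
    Mean, // average the per-element losses (a common default)
    Sum,  // sum the per-element losses
    None, // return the per-element losses unchanged
}

fn reduce(per_element: &[f64], mode: &Reduction) -> Vec<f64> {
    match mode {
        Reduction::Mean => {
            vec![per_element.iter().sum::<f64>() / per_element.len() as f64]
        }
        Reduction::Sum => vec![per_element.iter().sum::<f64>()],
        Reduction::None => per_element.to_vec(),
    }
}

fn main() {
    let losses = [0.25, 1.0, 0.5, 0.25];
    println!("{:?}", reduce(&losses, &Reduction::Mean)); // [0.5]
    println!("{:?}", reduce(&losses, &Reduction::Sum));  // [2.0]
    println!("{:?}", reduce(&losses, &Reduction::None)); // [0.25, 1.0, 0.5, 0.25]
}
```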
Traits§
- Module - The fundamental trait for all neural network layers.
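The Module trait's actual methods live in the module submodule and aren't reproduced in this index, so the following is only a hypothetical sketch of the shape such a trait typically takes (a forward pass plus parameter access), with a stand-in Tensor type rather than this crate's:

```rust
// Hypothetical sketch; the real Module trait and tensor type may differ.
#[derive(Debug)]
struct Tensor(Vec<f64>); // stand-in tensor for illustration only

trait Module {
    /// Run the layer's forward pass.
    fn forward(&self, input: &Tensor) -> Tensor;

    /// Expose learnable parameters (empty by default for stateless layers).
    fn parameters(&self) -> Vec<&Tensor> {
        Vec::new()
    }
}

struct Scale(f64); // toy stateless "layer"

impl Module for Scale {
    fn forward(&self, input: &Tensor) -> Tensor {
        Tensor(input.0.iter().map(|v| v * self.0).collect())
    }
}

fn main() {
    let layer = Scale(2.0);
    let out = layer.forward(&Tensor(vec![1.0, -3.0]));
    println!("{out:?}, {} parameters", layer.parameters().len()); // Tensor([2.0, -6.0]), 0 parameters
}
```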
Functions§
- accuracy - Classification accuracy: fraction of correct predictions.
- argmax_classes - Compute argmax along the last axis, returning class indices.
- bce_loss - Binary Cross-Entropy loss for probabilities in [0, 1].
- bce_with_logits_loss - Binary Cross-Entropy with Logits (numerically stable).
- classification_report - Per-class precision, recall, F1, and support — like sklearn’s classification_report.
- cross_entropy_loss - Cross-entropy loss with log-softmax for numerical stability.
- f1_score - F1 Score — harmonic mean of precision and recall.
- l1_loss - L1 Loss (Mean Absolute Error): mean(|prediction - target|).
- l1_loss_with_reduction - L1 Loss with configurable reduction.
- mae - Mean Absolute Error: mean(|y_true - y_pred|).
- mape - Mean Absolute Percentage Error: mean(|y_true - y_pred| / |y_true|) * 100.
- mse_loss - Mean Squared Error loss: mean((prediction - target)²).
- mse_loss_with_reduction - MSE Loss with configurable reduction.
- nll_loss - Negative Log-Likelihood Loss with integer class indices.
- perplexity - Perplexity from cross-entropy loss: exp(loss).
- perplexity_from_log_probs - Perplexity from a flat array of per-token log-probabilities.
- precision - Precision for multi-class classification.
- r2_score - R² (coefficient of determination).
- recall - Recall for multi-class classification.
- rmse - Root Mean Squared Error: sqrt(mean((y_true - y_pred)²)).
- smooth_l1_loss - Smooth L1 Loss (Huber Loss).
- tensor_accuracy - Compute accuracy directly from logit tensors and one-hot/class-index targets.
- top_k_accuracy - Top-K accuracy: fraction of samples where the true class is in the top-K predictions.
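Most of the regression metrics above state their formulas in their summaries. As a self-contained illustration of those formulas (plain Rust over slices, independent of this crate's tensor types and function signatures), a minimal sketch:

```rust
// Standalone versions of a few formulas listed above, for illustration only.

fn mse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    // mean((prediction - target)^2)
    let n = y_true.len() as f64;
    y_true.iter().zip(y_pred).map(|(t, p)| (t - p).powi(2)).sum::<f64>() / n
}

fn mae(y_true: &[f64], y_pred: &[f64]) -> f64 {
    // mean(|y_true - y_pred|)
    let n = y_true.len() as f64;
    y_true.iter().zip(y_pred).map(|(t, p)| (t - p).abs()).sum::<f64>() / n
}

fn rmse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    mse(y_true, y_pred).sqrt() // sqrt(mean((y_true - y_pred)^2))
}

fn r2_score(y_true: &[f64], y_pred: &[f64]) -> f64 {
    // 1 - SS_res / SS_tot
    let mean = y_true.iter().sum::<f64>() / y_true.len() as f64;
    let ss_res: f64 = y_true.iter().zip(y_pred).map(|(t, p)| (t - p).powi(2)).sum();
    let ss_tot: f64 = y_true.iter().map(|t| (t - mean).powi(2)).sum();
    1.0 - ss_res / ss_tot
}

fn accuracy(true_classes: &[usize], pred_classes: &[usize]) -> f64 {
    // fraction of correct predictions
    let correct = true_classes.iter().zip(pred_classes).filter(|(t, p)| t == p).count();
    correct as f64 / true_classes.len() as f64
}

fn main() {
    let y_true = [3.0, -0.5, 2.0, 7.0];
    let y_pred = [2.5, 0.0, 2.0, 8.0];
    println!("mse  = {:.4}", mse(&y_true, &y_pred));
    println!("mae  = {:.4}", mae(&y_true, &y_pred));
    println!("rmse = {:.4}", rmse(&y_true, &y_pred));
    println!("r2   = {:.4}", r2_score(&y_true, &y_pred));
    println!("acc  = {:.2}", accuracy(&[1, 0, 2, 1], &[1, 0, 1, 1]));
    // Perplexity from a cross-entropy loss value is exp(loss):
    println!("ppl  = {:.4}", (1.5f64).exp());
}
```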