shrew_nn/
lib.rs

//! # shrew-nn
//!
//! Neural network layers, activations, and loss functions for Shrew.
//!
//! Provides reusable building blocks following the [`Module`] trait pattern
//! (similar to PyTorch's `nn.Module`):
//!
//! 1. **Module trait** — every layer implements `forward()`
//! 2. **Linear** — fully connected: `y = xW^T + b`
//! 3. **Embedding** — lookup table for discrete tokens
//! 4. **Dropout** — regularization via random zeroing
//! 5. **Activations** — ReLU, GELU, SiLU, etc. as modules
//! 6. **Loss functions** — `mse_loss`, `cross_entropy_loss`, and related helpers
//!
//! Modules are generic over the `Backend` type parameter (just as `Tensor<B>` is),
//! so the same network definition runs on CPU, CUDA, or any future backend.
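//!
//! # Example
//!
//! A minimal sketch of how these pieces compose. The constructor and call
//! signatures below are illustrative assumptions, not the exact API; see the
//! individual module docs for the real signatures.
//!
//! ```ignore
//! use shrew_nn::{cross_entropy_loss, Linear, Module, ReLU, Sequential};
//!
//! // Hypothetical two-layer MLP: `Linear::new(in_features, out_features)`
//! // and `Sequential`'s builder-style `add` are assumed for illustration.
//! let model = Sequential::new()
//!     .add(Linear::new(784, 128))
//!     .add(ReLU::new())
//!     .add(Linear::new(128, 10));
//!
//! // `forward` comes from the `Module` trait; `input` is a `Tensor<B>`
//! // built on whichever backend the model uses.
//! let logits = model.forward(&input);
//!
//! // Loss helpers are plain functions (argument order assumed here).
//! let loss = cross_entropy_loss(&logits, &targets);
//! ```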

pub mod activation;
pub mod attention;
pub mod batchnorm;
pub mod conv;
pub mod dropout;
pub mod embedding;
pub mod flatten;
pub mod groupnorm;
pub mod init;
pub mod layernorm;
pub mod linear;
pub mod loss;
pub mod metrics;
pub mod module;
pub mod rmsnorm;
pub mod rnn;
pub mod sequential;
pub mod transformer;

pub use activation::{GeLU, LeakyReLU, Mish, ReLU, SiLU, Sigmoid, Tanh, ELU};
pub use attention::MultiHeadAttention;
pub use batchnorm::BatchNorm2d;
pub use conv::{AdaptiveAvgPool2d, AvgPool2d, Conv1d, Conv2d, MaxPool2d};
pub use dropout::Dropout;
pub use embedding::Embedding;
pub use flatten::Flatten;
pub use groupnorm::GroupNorm;
pub use layernorm::LayerNorm;
pub use linear::Linear;
pub use loss::{
    bce_loss, bce_with_logits_loss, cross_entropy_loss, l1_loss, l1_loss_with_reduction, mse_loss,
    mse_loss_with_reduction, nll_loss, smooth_l1_loss, Reduction,
};
pub use metrics::{
    accuracy, argmax_classes, classification_report, f1_score, mae, mape, perplexity,
    perplexity_from_log_probs, precision, r2_score, recall, rmse, tensor_accuracy, top_k_accuracy,
    Average, ClassMetrics, ConfusionMatrix,
};
pub use module::Module;
pub use rmsnorm::RMSNorm;
pub use rnn::{GRUCell, LSTMCell, RNNCell, GRU, LSTM, RNN};
pub use sequential::Sequential;
pub use transformer::TransformerBlock;