shrew_optim/lib.rs

//! # shrew-optim
//!
//! Optimizers for gradient-based training.
//!
//! Optimizers update model parameters using gradients computed by `backward()`.
//! The training loop is:
//!
//! 1. `output = model.forward(input)`
//! 2. `loss = loss_fn(output, target)`
//! 3. `grads = loss.backward()` — autograd computes gradients
//! 4. `optimizer.step(&grads)` — optimizer updates parameters (see the sketch below)
//!
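//! A minimal sketch of that loop (the `SGD` constructor and the `model` /
//! `loss_fn` / `data` values below are assumed for illustration, not a fixed API):
//!
//! ```ignore
//! let mut optimizer = SGD::new(model.parameters(), 0.01); // assumed constructor
//! for (input, target) in data {
//!     let output = model.forward(&input);
//!     let loss = loss_fn(&output, &target);
//!     let grads = loss.backward(); // autograd computes gradients
//!     optimizer.step(&grads);      // optimizer updates parameters
//! }
//! ```
//!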
//! Implemented optimizers:
//! - **SGD**: Stochastic Gradient Descent (with optional momentum)
//! - **Adam**: Adaptive Moment Estimation
//! - **AdamW**: Adam with decoupled weight decay (see the update sketch below)
//! - **RMSProp**: Root Mean Square Propagation
//! - **RAdam**: Rectified Adam (no warmup needed)
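//!
//! As a worked illustration of the AdamW update (the "decoupled weight decay"
//! above), here is a self-contained sketch on plain `f32` slices; the function
//! name and signature are illustrative only and are not part of this crate's API:
//!
//! ```
//! /// One AdamW step: Adam moment estimates plus decoupled weight decay.
//! fn adamw_step(
//!     theta: &mut [f32], grad: &[f32], m: &mut [f32], v: &mut [f32],
//!     lr: f32, beta1: f32, beta2: f32, eps: f32, weight_decay: f32, t: i32,
//! ) {
//!     for i in 0..theta.len() {
//!         // Biased first and second moment estimates.
//!         m[i] = beta1 * m[i] + (1.0 - beta1) * grad[i];
//!         v[i] = beta2 * v[i] + (1.0 - beta2) * grad[i] * grad[i];
//!         // Bias correction for step `t` (1-based).
//!         let m_hat = m[i] / (1.0 - beta1.powi(t));
//!         let v_hat = v[i] / (1.0 - beta2.powi(t));
//!         // Adaptive step plus decoupled weight decay (the "W" in AdamW).
//!         theta[i] -= lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * theta[i]);
//!     }
//! }
//!
//! let (mut theta, grad) = (vec![1.0_f32, -2.0], vec![0.1_f32, 0.3]);
//! let (mut m, mut v) = (vec![0.0_f32; 2], vec![0.0_f32; 2]);
//! adamw_step(&mut theta, &grad, &mut m, &mut v, 1e-3, 0.9, 0.999, 1e-8, 0.01, 1);
//! ```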

pub mod adam;
pub mod clip;
pub mod ema;
pub mod optimizer;
pub mod radam;
pub mod rmsprop;
pub mod scheduler;
pub mod sgd;

pub use adam::{Adam, AdamW};
pub use clip::{clip_grad_norm, clip_grad_value, grad_norm, GradAccumulator};
pub use ema::EMA;
pub use optimizer::{Optimizer, OptimizerState, Stateful};
pub use radam::RAdam;
pub use rmsprop::RMSProp;
pub use scheduler::{
    CosineAnnealingLR, CosineWarmupLR, ExponentialLR, LinearLR, LrScheduler, ReduceLROnPlateau,
    StepLR,
};
pub use sgd::SGD;