shrew_optim/optimizer.rs

// Optimizer trait — The interface all optimizers implement
//
// Every optimizer takes the current parameters + their gradients and
// produces updated parameter values. The trait is simple:
//
//   fn step(&mut self, grads: &GradStore<B>) → updated parameters
//
// DESIGN DECISION: Immutable parameter update
//
// Since our tensors are immutable (Arc-wrapped), optimizers can't modify
// parameters in-place. Instead, step() returns new tensors with updated values.
// The training loop is responsible for replacing the old parameters.
//
// This is actually cleaner than PyTorch's in-place mutation approach:
//   old_params → optimizer.step(grads) → new_params (functional update)
//
// The trade-off is slightly more allocation, but it avoids complex mutation
// semantics and plays well with Rust's ownership model.
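//
// For intuition, a concrete step() implementation under this design builds
// fresh tensors rather than writing into existing ones. A rough sketch only
// (the element-wise tensor ops below are illustrative placeholders, not
// necessarily the real shrew_core API):
//
//   let mut updated = Vec::with_capacity(self.params.len());
//   for (param, grad) in self.params.iter().zip(per_param_grads) {
//       // new = old - lr * grad, produced as a brand-new tensor
//       let new_param = param.sub(&grad.mul_scalar(self.lr)?)?;
//       updated.push(new_param);
//   }
//   Ok(updated)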

use std::collections::HashMap;

use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

/// Trait that all optimizers implement.
///
/// Optimizers update model parameters given their gradients.
///
/// # Type Parameters
/// - `B`: the compute backend
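///
/// # Example
///
/// A minimal training-loop sketch. `Sgd`, `model.forward()`, `loss.backward()`,
/// and `model.set_parameters()` are hypothetical stand-ins for whatever concrete
/// optimizer and model API this trait is paired with:
///
/// ```ignore
/// let mut opt = Sgd::new(model.parameters(), 0.01);
/// for batch in batches {
///     let loss = model.forward(&batch)?;
///     let grads = loss.backward()?;
///     // Functional update: step() returns new tensors; the loop swaps them in.
///     let new_params = opt.step(&grads)?;
///     model.set_parameters(new_params);
/// }
/// ```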
pub trait Optimizer<B: Backend> {
    /// Perform one optimization step.
    ///
    /// Given the current parameters and their gradients (from backward()),
    /// compute and return the updated parameter values.
    ///
    /// Returns a vector of updated parameters in the same order as
    /// the parameters passed to the optimizer's constructor.
    fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>>;

    /// Return the current learning rate.
    fn learning_rate(&self) -> f64;

    /// Set a new learning rate (for learning rate scheduling).
    fn set_learning_rate(&mut self, lr: f64);
}

// OptimizerState — Serializable state dictionary for checkpoint save/load

/// A serializable snapshot of an optimizer's internal state.
///
/// Contains named scalar values (e.g., step count, learning rate)
/// and named f64 buffers (e.g., momentum vectors, second moment estimates).
///
/// This follows the PyTorch `state_dict()` / `load_state_dict()` pattern
/// but is Rust-native and format-agnostic.
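///
/// # Example
///
/// A small usage sketch; the buffer keys follow the `"m.<param index>"`
/// convention described on the `buffers` field:
///
/// ```ignore
/// let mut state = OptimizerState::new("Adam");
/// state.set_scalar("step", 42.0);
/// state.set_buffer("m.0", vec![0.0; 128]);   // first-moment buffer for param 0
/// assert_eq!(state.get_scalar("step"), Some(42.0));
/// ```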
#[derive(Debug, Clone)]
pub struct OptimizerState {
    /// Optimizer type name (e.g., "Adam", "SGD") for validation on load.
    pub optimizer_type: String,
    /// Named scalar values (step count, hyperparameters, etc.)
    pub scalars: HashMap<String, f64>,
    /// Named `f64` buffers (momentum vectors, second moment estimates, etc.)
    /// Each buffer is flattened to a single `Vec<f64>`; the key encodes
    /// the parameter index: e.g., "m.0", "m.1", "v.0", "v.1".
    pub buffers: HashMap<String, Vec<f64>>,
}

impl OptimizerState {
    /// Create an empty state dict for the given optimizer type.
    pub fn new(optimizer_type: impl Into<String>) -> Self {
        OptimizerState {
            optimizer_type: optimizer_type.into(),
            scalars: HashMap::new(),
            buffers: HashMap::new(),
        }
    }

    /// Insert a scalar value.
    pub fn set_scalar(&mut self, key: impl Into<String>, value: f64) {
        self.scalars.insert(key.into(), value);
    }

    /// Insert a buffer.
    pub fn set_buffer(&mut self, key: impl Into<String>, data: Vec<f64>) {
        self.buffers.insert(key.into(), data);
    }

    /// Get a scalar value.
    pub fn get_scalar(&self, key: &str) -> Option<f64> {
        self.scalars.get(key).copied()
    }

    /// Get a buffer.
    pub fn get_buffer(&self, key: &str) -> Option<&Vec<f64>> {
        self.buffers.get(key)
    }
}

/// Trait for optimizers that can save and restore their internal state.
///
/// This enables training checkpoint save/load — not just model weights,
/// but also the optimizer's momentum buffers, step counters, etc.,
/// allowing training to resume exactly where it left off.
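///
/// # Example
///
/// A checkpoint round-trip sketch. `Adam` here is a hypothetical optimizer
/// implementing both `Optimizer` and `Stateful`; serializing the
/// `OptimizerState` itself is left to the caller:
///
/// ```ignore
/// // At checkpoint time, capture optimizer state alongside the model weights.
/// let saved = opt.state_dict();
///
/// // On resume, rebuild the optimizer with the same parameters, then restore.
/// let mut opt = Adam::new(model.parameters(), 1e-3);
/// opt.load_state_dict(&saved)?;   // momentum buffers and step count come back
/// ```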
pub trait Stateful {
    /// Export the optimizer's internal state as a serializable dictionary.
    fn state_dict(&self) -> OptimizerState;

    /// Restore the optimizer's internal state from a previously saved dictionary.
    ///
    /// Returns an error if the state dict is incompatible (wrong optimizer type,
    /// missing keys, wrong buffer sizes).
    fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()>;
}