shrew_optim/optimizer.rs
// Optimizer trait — The interface all optimizers implement
//
// Every optimizer holds the parameters it manages and, given their
// gradients, produces updated parameter values. The trait is simple:
//
// fn step(&mut self, grads: &GradStore<B>) → updated parameters
//
// DESIGN DECISION: Immutable parameter update
//
// Since our tensors are immutable (Arc-wrapped), optimizers can't modify
// parameters in-place. Instead, step() returns new tensors with updated values.
// The training loop is responsible for replacing the old parameters.
//
// This is cleaner than PyTorch's in-place mutation approach:
// old_params → optimizer.step(grads) → new_params (functional update)
//
// The trade-off is slightly more allocation, but it avoids complex mutation
// semantics and plays well with Rust's ownership model.
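//
// As a sketch (the `model` and `loss` names below are illustrative, not a
// confirmed API of this crate), a training loop using this pattern might be:
//
//     let grads = loss.backward()?;               // gradients per parameter
//     let new_params = optimizer.step(&grads)?;   // functional update
//     model.set_parameters(new_params);           // swap in the new tensors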

use std::collections::HashMap;

use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

/// Trait that all optimizers implement.
///
/// Optimizers update model parameters given their gradients.
///
/// # Type Parameters
/// - `B`: the compute backend
pub trait Optimizer<B: Backend> {
    /// Perform one optimization step.
    ///
    /// Given the gradients produced by backward(), compute and return
    /// updated values for the parameters this optimizer manages.
    ///
    /// Returns a vector of updated parameters in the same order as
    /// the parameters passed to the optimizer's constructor.
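    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an SGD optimizer with a
    /// `Sgd::new(params, lr)` constructor (not defined by this trait):
    ///
    /// ```ignore
    /// let mut opt = Sgd::new(params.clone(), 0.01);
    /// let grads = loss.backward()?;
    /// let updated = opt.step(&grads)?; // same order as `params`
    /// ```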
    fn step(&mut self, grads: &GradStore<B>) -> Result<Vec<Tensor<B>>>;

    /// Return the current learning rate.
    fn learning_rate(&self) -> f64;

    /// Set a new learning rate (for learning rate scheduling).
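    ///
    /// A scheduling sketch (the decay factor is illustrative):
    ///
    /// ```ignore
    /// // Exponential decay, applied once per epoch:
    /// opt.set_learning_rate(opt.learning_rate() * 0.95);
    /// ```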
    fn set_learning_rate(&mut self, lr: f64);
}

// OptimizerState — Serializable state dictionary for checkpoint save/load

/// A serializable snapshot of an optimizer's internal state.
///
/// Contains named scalar values (e.g., step count, learning rate)
/// and named f64 buffers (e.g., momentum vectors, second moment estimates).
///
/// This follows the PyTorch `state_dict()` / `load_state_dict()` pattern
/// but is Rust-native and format-agnostic.
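///
/// # Example
///
/// A minimal sketch of populating state for a two-parameter Adam optimizer
/// (buffer lengths and values are placeholders):
///
/// ```ignore
/// let mut state = OptimizerState::new("Adam");
/// state.set_scalar("step", 100.0);
/// state.set_buffer("m.0", vec![0.0; 64]); // first moment, parameter 0
/// state.set_buffer("v.0", vec![0.0; 64]); // second moment, parameter 0
/// state.set_buffer("m.1", vec![0.0; 16]); // first moment, parameter 1
/// state.set_buffer("v.1", vec![0.0; 16]); // second moment, parameter 1
/// assert_eq!(state.get_scalar("step"), Some(100.0));
/// ```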
#[derive(Debug, Clone)]
pub struct OptimizerState {
    /// Optimizer type name (e.g., "Adam", "SGD") for validation on load.
    pub optimizer_type: String,
    /// Named scalar values (step count, hyperparameters, etc.)
    pub scalars: HashMap<String, f64>,
    /// Named f64 buffers (momentum vectors, second moment estimates, etc.)
    /// Each buffer is flattened to a single Vec<f64>; the key encodes
    /// the parameter index: e.g., "m.0", "m.1", "v.0", "v.1".
    pub buffers: HashMap<String, Vec<f64>>,
}

impl OptimizerState {
    /// Create an empty state dict for the given optimizer type.
    pub fn new(optimizer_type: impl Into<String>) -> Self {
        OptimizerState {
            optimizer_type: optimizer_type.into(),
            scalars: HashMap::new(),
            buffers: HashMap::new(),
        }
    }

    /// Insert a scalar value.
    pub fn set_scalar(&mut self, key: impl Into<String>, value: f64) {
        self.scalars.insert(key.into(), value);
    }

    /// Insert a buffer.
    pub fn set_buffer(&mut self, key: impl Into<String>, data: Vec<f64>) {
        self.buffers.insert(key.into(), data);
    }

    /// Get a scalar value.
    pub fn get_scalar(&self, key: &str) -> Option<f64> {
        self.scalars.get(key).copied()
    }

    /// Get a buffer.
    pub fn get_buffer(&self, key: &str) -> Option<&Vec<f64>> {
        self.buffers.get(key)
    }
}

/// Trait for optimizers that can save and restore their internal state.
///
/// This enables training checkpoint save/load — not just model weights,
/// but also the optimizer's momentum buffers, step counters, etc.,
/// allowing training to resume exactly where it left off.
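///
/// A checkpoint round-trip sketch (`save_state` / `load_state` stand in for
/// whatever serialization the caller chooses; they are not provided here):
///
/// ```ignore
/// let state = optimizer.state_dict();
/// save_state("checkpoint.opt", &state)?;
/// // ... later, in a fresh process ...
/// let state = load_state("checkpoint.opt")?;
/// optimizer.load_state_dict(&state)?;
/// ```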
pub trait Stateful {
    /// Export the optimizer's internal state as a serializable dictionary.
    fn state_dict(&self) -> OptimizerState;

    /// Restore the optimizer's internal state from a previously saved dictionary.
    ///
    /// Returns an error if the state dict is incompatible (wrong optimizer type,
    /// missing keys, wrong buffer sizes).
    fn load_state_dict(&mut self, state: &OptimizerState) -> Result<()>;
}