shrew_optim/
ema.rs

// EMA — Exponential Moving Average of Model Parameters
//
// EMA maintains a shadow copy of model parameters that is an exponential
// moving average of the training parameters. This smoothed version of the
// model often generalizes better than the final training weights.
//
// Update rule (after each optimizer step):
//   shadow_θ = decay * shadow_θ + (1 - decay) * θ
//
// Typical decay: 0.999 (closer to 1 means slower updates → more smoothing)
//
// USAGE:
//   - During training: update the EMA after each optimizer step
//   - During evaluation: use the EMA parameters instead of the training parameters
//
// This technique is used in:
//   - Image generation (DDPM, StyleGAN)
//   - Semi-supervised learning (Mean Teacher)
//   - Large language models (some fine-tuning recipes)
//
// DESIGN: The EMA stores copies of parameter data (as Vec<f64>) so it
// doesn't interfere with training. Use `apply()` to write EMA weights
// into the model parameters, and `restore()` to put training weights back.
24
25use shrew_core::backend::Backend;
26use shrew_core::error::Result;
27use shrew_core::tensor::Tensor;
28
29/// Exponential Moving Average of model parameters.
30///
31/// Maintains a shadow copy that is a smoothed version of training parameters.
32///
33/// # Example
34/// ```ignore
35/// let mut ema = EMA::new(model.parameters(), 0.999);
36///
37/// // Training loop:
38/// optimizer.step(&grads)?;
39/// ema.update(&model.parameters())?;
40///
41/// // Evaluation:
42/// ema.apply()?;             // Write EMA weights into model
43/// let output = model.forward(input)?;
44/// ema.restore()?;           // Restore training weights
45/// ```
46pub struct EMA<B: Backend> {
47    /// References to the model parameters (used for apply/restore)
48    params: Vec<Tensor<B>>,
49    /// Shadow parameters (EMA values)
50    shadow: Vec<Vec<f64>>,
51    /// Saved training parameters (for restore after apply)
52    backup: Vec<Vec<f64>>,
53    /// Decay rate (e.g., 0.999)
54    decay: f64,
55    /// Number of updates performed
56    num_updates: u64,
57}
58
59impl<B: Backend> EMA<B> {
60    /// Create a new EMA tracker.
61    ///
62    /// # Arguments
63    /// - `params`: The model parameters to track
64    /// - `decay`: Decay rate (typical: 0.999 or 0.9999)
65    pub fn new(params: Vec<Tensor<B>>, decay: f64) -> Result<Self> {
66        let shadow: Result<Vec<Vec<f64>>> = params.iter().map(|p| p.to_f64_vec()).collect();
67        let shadow = shadow?;
68
69        Ok(EMA {
70            params,
71            shadow,
72            backup: Vec::new(),
73            decay,
74            num_updates: 0,
75        })
76    }
77
78    /// Update the EMA shadow parameters with current model parameters.
79    ///
80    /// Call this after each optimizer step.
81    pub fn update(&mut self, current_params: &[Tensor<B>]) -> Result<()> {
82        self.num_updates += 1;
83
84        for (i, param) in current_params.iter().enumerate() {
85            let data = param.to_f64_vec()?;
86            for (s, &d) in self.shadow[i].iter_mut().zip(data.iter()) {
87                *s = self.decay * *s + (1.0 - self.decay) * d;
88            }
89        }
90
91        Ok(())
92    }
93
94    /// Update using an adjusted decay that ramps up during early training.
95    ///
96    /// The effective decay is: min(decay, (1 + num_updates) / (10 + num_updates))
97    /// This prevents the EMA from being too biased toward initial values.
98    pub fn update_with_warmup(&mut self, current_params: &[Tensor<B>]) -> Result<()> {
99        self.num_updates += 1;
100
101        let effective_decay = self
102            .decay
103            .min((1.0 + self.num_updates as f64) / (10.0 + self.num_updates as f64));
104
105        for (i, param) in current_params.iter().enumerate() {
106            let data = param.to_f64_vec()?;
107            for (s, &d) in self.shadow[i].iter_mut().zip(data.iter()) {
108                *s = effective_decay * *s + (1.0 - effective_decay) * d;
109            }
110        }
111
112        Ok(())
113    }
114
115    /// Apply EMA parameters to the model (for evaluation).
116    ///
117    /// This saves the current training parameters so they can be restored
118    /// with `restore()`.
119    pub fn apply(&mut self) -> Result<()> {
120        // Save current training weights
121        self.backup = Vec::with_capacity(self.params.len());
122        for param in &self.params {
123            self.backup.push(param.to_f64_vec()?);
124        }
125
126        // Write EMA weights into model parameters
127        for (param, shadow) in self.params.iter().zip(self.shadow.iter()) {
128            param.update_data_inplace(shadow)?;
129        }
130
131        Ok(())
132    }
133
134    /// Restore training parameters after `apply()`.
135    pub fn restore(&mut self) -> Result<()> {
136        if self.backup.is_empty() {
137            return Ok(());
138        }
139
140        for (param, backup) in self.params.iter().zip(self.backup.iter()) {
141            param.update_data_inplace(backup)?;
142        }
143
144        self.backup.clear();
145        Ok(())
146    }
147
148    /// Get the decay rate.
149    pub fn decay(&self) -> f64 {
150        self.decay
151    }
152
153    /// Get the number of updates performed.
154    pub fn num_updates(&self) -> u64 {
155        self.num_updates
156    }
157
158    /// Get the shadow (EMA) values for a specific parameter index.
159    pub fn shadow_values(&self, index: usize) -> &[f64] {
160        &self.shadow[index]
161    }
162}