shrew_optim/ema.rs
// EMA — Exponential Moving Average of Model Parameters
//
// EMA maintains a shadow copy of model parameters that is an exponential
// moving average of the training parameters. This smoothed version of the
// model often generalizes better than the final training weights.
//
// Update rule (after each optimizer step):
//     shadow_θ = decay * shadow_θ + (1 - decay) * θ
//
// Typical decay: 0.999 (close to 1 means slower update → more smoothing)
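//
// For intuition, a worked example with decay = 0.9 (smaller than typical, so
// the numbers move visibly). Suppose the shadow starts at 0.0 and θ then sits
// at 1.0 for three steps:
//
//     step 1: shadow = 0.9 * 0.000 + 0.1 * 1.0 = 0.100
//     step 2: shadow = 0.9 * 0.100 + 0.1 * 1.0 = 0.190
//     step 3: shadow = 0.9 * 0.190 + 0.1 * 1.0 = 0.271
//
// The shadow approaches θ geometrically at a rate set by (1 - decay), which
// is why a decay close to 1 produces a smoother, slower-moving average.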
//
// USAGE:
// - During training: update EMA after each optimizer step
// - During evaluation: use EMA parameters instead of training parameters
//
// This technique is used in:
// - Image generation (DDPM, StyleGAN)
// - Semi-supervised learning (Mean Teacher)
// - Large language models (some fine-tuning recipes)
//
// DESIGN: The EMA stores copies of parameter data (as Vec<f64>) so it
// doesn't interfere with training. Use `apply()` to write EMA weights
// into the model parameters, and `restore()` to put training weights back.

use shrew_core::backend::Backend;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

/// Exponential Moving Average of model parameters.
///
/// Maintains a shadow copy that is a smoothed version of training parameters.
///
/// # Example
/// ```ignore
/// let mut ema = EMA::new(model.parameters(), 0.999)?;
///
/// // Training loop:
/// optimizer.step(&grads)?;
/// ema.update(&model.parameters())?;
///
/// // Evaluation:
/// ema.apply()?; // Write EMA weights into the model
/// let output = model.forward(input)?;
/// ema.restore()?; // Restore the training weights
/// ```
pub struct EMA<B: Backend> {
    /// References to the model parameters (used for apply/restore)
    params: Vec<Tensor<B>>,
    /// Shadow parameters (EMA values)
    shadow: Vec<Vec<f64>>,
    /// Saved training parameters (for restore after apply)
    backup: Vec<Vec<f64>>,
    /// Decay rate (e.g., 0.999)
    decay: f64,
    /// Number of updates performed
    num_updates: u64,
}

impl<B: Backend> EMA<B> {
    /// Create a new EMA tracker.
    ///
    /// The shadow parameters are initialized as copies of `params`.
    ///
    /// # Arguments
    /// - `params`: The model parameters to track
    /// - `decay`: Decay rate (typical: 0.999 or 0.9999)
    pub fn new(params: Vec<Tensor<B>>, decay: f64) -> Result<Self> {
        let shadow: Result<Vec<Vec<f64>>> = params.iter().map(|p| p.to_f64_vec()).collect();
        let shadow = shadow?;

        Ok(EMA {
            params,
            shadow,
            backup: Vec::new(),
            decay,
            num_updates: 0,
        })
    }

    /// Update the EMA shadow parameters with current model parameters.
    ///
    /// Call this after each optimizer step.
    pub fn update(&mut self, current_params: &[Tensor<B>]) -> Result<()> {
        self.num_updates += 1;

        for (i, param) in current_params.iter().enumerate() {
            let data = param.to_f64_vec()?;
            for (s, &d) in self.shadow[i].iter_mut().zip(data.iter()) {
                *s = self.decay * *s + (1.0 - self.decay) * d;
            }
        }

        Ok(())
    }

    /// Update using an adjusted decay that ramps up during early training.
    ///
    /// The effective decay is: min(decay, (1 + num_updates) / (10 + num_updates)).
    /// This prevents the EMA from being too biased toward initial values.
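    ///
    /// For intuition: the effective decay is 2/11 ≈ 0.18 at step 1 and
    /// 101/110 ≈ 0.92 at step 100; once the ramp exceeds `decay` (around
    /// step 9,000 for decay = 0.999), the configured decay takes over.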
    pub fn update_with_warmup(&mut self, current_params: &[Tensor<B>]) -> Result<()> {
        self.num_updates += 1;

        let effective_decay = self
            .decay
            .min((1.0 + self.num_updates as f64) / (10.0 + self.num_updates as f64));

        for (i, param) in current_params.iter().enumerate() {
            let data = param.to_f64_vec()?;
            for (s, &d) in self.shadow[i].iter_mut().zip(data.iter()) {
                *s = effective_decay * *s + (1.0 - effective_decay) * d;
            }
        }

        Ok(())
    }

    /// Apply EMA parameters to the model (for evaluation).
    ///
    /// This saves the current training parameters so they can be restored
    /// with `restore()`.
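    ///
    /// A minimal sketch of the round-trip (`evaluate` is a hypothetical helper
    /// standing in for whatever evaluation code you run):
    ///
    /// ```ignore
    /// ema.apply()?;                    // model now holds the EMA weights
    /// let metrics = evaluate(&model)?; // hypothetical evaluation helper
    /// ema.restore()?;                  // training weights are back in place
    /// ```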
    pub fn apply(&mut self) -> Result<()> {
        // Save current training weights
        self.backup = Vec::with_capacity(self.params.len());
        for param in &self.params {
            self.backup.push(param.to_f64_vec()?);
        }

        // Write EMA weights into model parameters
        for (param, shadow) in self.params.iter().zip(self.shadow.iter()) {
            param.update_data_inplace(shadow)?;
        }

        Ok(())
    }

    /// Restore training parameters after `apply()`.
    ///
    /// Does nothing if `apply()` has not been called (or was already restored).
    pub fn restore(&mut self) -> Result<()> {
        if self.backup.is_empty() {
            return Ok(());
        }

        for (param, backup) in self.params.iter().zip(self.backup.iter()) {
            param.update_data_inplace(backup)?;
        }

        self.backup.clear();
        Ok(())
    }

    /// Get the decay rate.
    pub fn decay(&self) -> f64 {
        self.decay
    }

    /// Get the number of updates performed.
    pub fn num_updates(&self) -> u64 {
        self.num_updates
    }

    /// Get the shadow (EMA) values for a specific parameter index.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds.
    pub fn shadow_values(&self, index: usize) -> &[f64] {
        &self.shadow[index]
    }
}
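
#[cfg(test)]
mod tests {
    // The EMA recurrence and the warmup ramp are backend-independent scalar
    // math, so they can be checked here without constructing any tensors.
    // These are sketches of the formulas documented above, not tests of a
    // particular backend.

    #[test]
    fn scalar_ema_converges_toward_target() {
        // decay = 0.9, shadow starts at 0.0, target held at 1.0:
        // after three steps the shadow should be 0.271 (see module docs).
        let decay = 0.9;
        let mut shadow = 0.0;
        for _ in 0..3 {
            shadow = decay * shadow + (1.0 - decay) * 1.0;
        }
        assert!((shadow - 0.271).abs() < 1e-12);
    }

    #[test]
    fn warmup_ramp_is_capped_by_decay() {
        let decay: f64 = 0.999;
        let eff = |n: u64| decay.min((1.0 + n as f64) / (10.0 + n as f64));
        assert!((eff(1) - 2.0 / 11.0).abs() < 1e-12);
        assert!(eff(100) < decay); // still ramping at step 100
        assert_eq!(eff(10_000), decay); // capped late in training
    }
}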