pub struct EMA<B>
where
    B: Backend,
{ /* private fields */ }
Exponential Moving Average of model parameters.
Maintains a shadow copy that is a smoothed version of training parameters.
§Example
let mut ema = EMA::new(model.parameters(), 0.999)?; // new() returns a Result
// Training loop:
optimizer.step(&grads)?;
ema.update(&model.parameters())?;
// Evaluation:
ema.apply()?; // Write EMA weights into model
let output = model.forward(input)?;
ema.restore()?; // Restore training weights
Implementations§
impl<B> EMA<B>
where
    B: Backend,
pub fn new(params: Vec<Tensor<B>>, decay: f64) -> Result<EMA<B>, Error>
Create a new EMA tracker.
§Arguments
params: The model parameters to track
decay: Decay rate (typical: 0.999 or 0.9999)
pub fn update(&mut self, current_params: &[Tensor<B>]) -> Result<(), Error>
Update the EMA shadow parameters with current model parameters.
Call this after each optimizer step.
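The page does not spell out the update rule itself. A minimal sketch, assuming the conventional EMA recurrence `shadow = decay * shadow + (1 - decay) * current` and using plain `f64` slices in place of `Tensor<B>`, might look like the following. The helper name `ema_step` is hypothetical and not part of this crate.

```rust
/// Hypothetical helper: one EMA step over plain f64 buffers.
/// Assumes the conventional rule
///     shadow = decay * shadow + (1 - decay) * current
/// which this page does not confirm for EMA::update.
fn ema_step(shadow: &mut [f64], current: &[f64], decay: f64) {
    assert_eq!(shadow.len(), current.len(), "parameter lengths must match");
    for (s, &c) in shadow.iter_mut().zip(current) {
        // Blend the old shadow value with the freshly optimized value.
        *s = decay * *s + (1.0 - decay) * c;
    }
}
```

With a decay close to 1.0 (for example 0.999), each step moves the shadow copy only slightly toward the current parameters, which is what makes it a smoothed version of the training weights.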
pub fn update_with_warmup(&mut self, current_params: &[Tensor<B>]) -> Result<(), Error>
Update using an adjusted decay that ramps up during early training.
The effective decay is: min(decay, (1 + num_updates) / (10 + num_updates))
This keeps the EMA from being overly biased toward the initial parameter values during the first updates.
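To see how the warmup formula behaves, here is a small sketch that evaluates it directly. The function name `effective_decay` is hypothetical, and whether the crate reads `num_updates` before or after incrementing its counter is an assumption not settled by this page.

```rust
/// Hypothetical helper that evaluates the documented warmup formula:
///     min(decay, (1 + num_updates) / (10 + num_updates))
fn effective_decay(decay: f64, num_updates: u64) -> f64 {
    let n = num_updates as f64;
    // The ramp starts at 0.1 and approaches 1.0; `decay` caps it.
    decay.min((1.0 + n) / (10.0 + n))
}
```

For `decay = 0.999` this gives 0.1 at zero updates and 0.91 at 90 updates. The ramp reaches the configured 0.999 at 8,990 updates, after which the cap applies.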
pub fn apply(&mut self) -> Result<(), Error>
Apply EMA parameters to the model (for evaluation).
This saves the current training parameters so they can be restored
with restore().
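The save-then-swap behaviour described above can be sketched on plain vectors. `SwapSketch` is a hypothetical stand-in, not the crate's implementation: the real `apply()` takes no arguments, so `EMA` presumably holds its own handles to the model parameters, whereas this sketch passes the weights in explicitly.

```rust
/// Hypothetical stand-in for the apply/restore swap; not the crate's code.
struct SwapSketch {
    shadow: Vec<f64>,
    backup: Option<Vec<f64>>,
}

impl SwapSketch {
    /// Save the training weights, then overwrite them with the shadow copy.
    fn apply(&mut self, weights: &mut Vec<f64>) {
        self.backup = Some(std::mem::replace(weights, self.shadow.clone()));
    }

    /// Put the saved training weights back.
    /// Returns false if there is nothing to restore.
    fn restore(&mut self, weights: &mut Vec<f64>) -> bool {
        match self.backup.take() {
            Some(saved) => {
                *weights = saved;
                true
            }
            None => false,
        }
    }
}
```

Keeping the backup in an `Option` makes a second `restore` without a matching `apply` detectable. How the actual crate reports that case (for example through its `Error` type) is not stated on this page.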
pub fn num_updates(&self) -> u64
Get the number of updates performed.
pub fn shadow_values(&self, index: usize) -> &[f64]
Get the shadow (EMA) values for a specific parameter index.
Auto Trait Implementations§
impl<B> Freeze for EMA<B>
impl<B> RefUnwindSafe for EMA<B>
impl<B> Send for EMA<B>
impl<B> Sync for EMA<B>
impl<B> Unpin for EMA<B>
impl<B> UnwindSafe for EMA<B>
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.