pub struct GradAccumulator<B>
where
    B: Backend,
{ /* private fields */ }
Gradient accumulation helper.
Accumulates gradients over multiple mini-batches and provides the averaged gradient for the optimizer step.
§Example
let mut accum = GradAccumulator::<CpuBackend>::new(4); // 4 accumulation steps
for (i, batch) in batches.iter().enumerate() {
    let loss = model.forward(batch)?;
    let grads = loss.backward()?;
    if let Some(avg_grads) = accum.step(&grads, &params)? {
        optimizer.step(&avg_grads)?;
    }
}
Implementations§
impl<B> GradAccumulator<B>
where
    B: Backend,
pub fn new(accum_steps: u64) -> GradAccumulator<B>
Create a new gradient accumulator.
§Arguments
accum_steps: Number of mini-batches to accumulate before stepping
pub fn step(
    &mut self,
    grads: &GradStore<B>,
    params: &[Tensor<B>],
) -> Result<Option<GradStore<B>>, Error>
Add gradients from one mini-batch.
Returns Some(averaged_grads) when accum_steps batches have been
accumulated, otherwise returns None.
The returned gradients are divided by accum_steps to produce
the average gradient.
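The accumulate-then-average behaviour described above can be sketched with plain vectors. This is a minimal illustration, not the crate's implementation: `SimpleAccumulator` and `demo` are hypothetical names, a flat `Vec<f32>` stands in for `GradStore<B>`, the `params` argument is omitted, and resetting the step counter to zero after each full window is an assumption.

```rust
/// Hypothetical stand-in for the real accumulator: gradients are a flat
/// `Vec<f32>` instead of a `GradStore<B>`, and there is no `params` argument.
struct SimpleAccumulator {
    accum_steps: u64,
    current_step: u64,
    sum: Vec<f32>,
}

impl SimpleAccumulator {
    fn new(accum_steps: u64) -> Self {
        assert!(accum_steps > 0, "accum_steps must be at least 1");
        Self { accum_steps, current_step: 0, sum: Vec::new() }
    }

    /// Adds one mini-batch of gradients to the running sum.
    /// Returns `Some(average)` once `accum_steps` batches have been added,
    /// otherwise `None`.
    fn step(&mut self, grads: &[f32]) -> Option<Vec<f32>> {
        // Lazily size the running sum on the first batch of a window.
        if self.sum.is_empty() {
            self.sum = vec![0.0; grads.len()];
        }
        assert_eq!(self.sum.len(), grads.len(), "gradient length changed mid-window");

        // Element-wise accumulation.
        for (s, g) in self.sum.iter_mut().zip(grads) {
            *s += *g;
        }
        self.current_step += 1;

        if self.current_step < self.accum_steps {
            return None;
        }

        // Window is full: divide the sum by accum_steps to get the average.
        let n = self.accum_steps as f32;
        let avg = self.sum.iter().map(|s| *s / n).collect();

        // Assumed behaviour: start a fresh window after emitting the average.
        self.sum.clear();
        self.current_step = 0;
        Some(avg)
    }
}

/// Feeds three mini-batches through a two-step accumulator.
fn demo() -> Vec<Option<Vec<f32>>> {
    let mut accum = SimpleAccumulator::new(2);
    let batches = [[1.0_f32, 2.0], [3.0, 6.0], [5.0, 5.0]];
    batches.iter().map(|g| accum.step(g)).collect()
}
```

With `accum_steps = 2`, `demo()` yields `None` for the first batch, `Some([2.0, 4.0])` for the second (the element-wise mean of the first two batches), and `None` again for the third, which opens a new window.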
pub fn accum_steps(&self) -> u64
Get the number of accumulation steps.
pub fn current_step(&self) -> u64
Get the current step within the accumulation window.
Auto Trait Implementations§
impl<B> Freeze for GradAccumulator<B>
impl<B> RefUnwindSafe for GradAccumulator<B>
impl<B> Send for GradAccumulator<B>
impl<B> Sync for GradAccumulator<B>
impl<B> Unpin for GradAccumulator<B>
impl<B> UnwindSafe for GradAccumulator<B>
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise.