pub struct MixedPrecisionTrainer<M, O, B: Backend> { /* private fields */ }
Mixed-precision training: reduced-precision forward/backward with FP32 master weights.
Why mixed precision?
- FP16/BF16 matmuls run roughly 2× faster on GPUs with tensor cores (FP16 on V100; FP16 and BF16 on A100/H100)
- Half-precision uses half the memory for activations, enabling larger batches
- FP32 master weights prevent precision loss during gradient updates
How it works:
- Inputs and targets are cast to compute_dtype (F16 or BF16) before the forward pass
- The forward pass runs with reduced-precision activations
- Dynamic loss scaling prevents gradient underflow in half-precision:
- Loss is multiplied by a scale factor before backward
- Gradients are divided by the same factor after
- If overflow (NaN/Inf) is detected, the step is skipped and the scale is reduced (see the sketch after this list)
- Gradients are cast back to FP32 and applied to FP32 master weights
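The scaling schedule is simple enough to sketch in isolation. The following standalone example is illustrative only (DynamicScaler and its fields are hypothetical names, not this crate's API); it shows the grow-on-success, back-off-on-overflow behavior described above:

```rust
// Hypothetical sketch of dynamic loss scaling, not this crate's API.
struct DynamicScaler {
    scale: f64,
    growth_factor: f64,   // e.g. 2.0: scale doubles after enough good steps
    backoff_factor: f64,  // e.g. 0.5: scale halves on overflow
    growth_interval: u64, // consecutive good steps required before growing
    good_steps: u64,
}

impl DynamicScaler {
    /// Returns true if the optimizer step should be applied,
    /// false if it must be skipped because gradients overflowed.
    fn update(&mut self, grads_finite: bool) -> bool {
        if grads_finite {
            self.good_steps += 1;
            if self.good_steps >= self.growth_interval {
                self.scale *= self.growth_factor;
                self.good_steps = 0;
            }
            true
        } else {
            // NaN/Inf in gradients: shrink the scale and skip this step.
            self.scale *= self.backoff_factor;
            self.good_steps = 0;
            false
        }
    }
}

fn main() {
    let mut scaler = DynamicScaler {
        scale: 65536.0,
        growth_factor: 2.0,
        backoff_factor: 0.5,
        growth_interval: 2000,
        good_steps: 0,
    };
    // Simulated steps: the second one overflows, halving the scale.
    for &finite in &[true, false, true] {
        let apply = scaler.update(finite);
        println!("apply={apply}, scale={}", scaler.scale);
    }
}
```

This schedule keeps the scale hovering near the largest value that does not overflow, which maximizes how much of the gradient survives half-precision.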
Compute dtype options:
- DType::F16: 16-bit IEEE float, range ±65504, good for most training
- DType::BF16: bfloat16, same range as F32 but less precision, preferred when available (see the range demo below)
- DType::F32: standard precision (disables casting; only loss scaling is applied)
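The range difference is easy to demonstrate with the half crate (a separate, real dependency; not part of this crate):

```rust
// Requires `half = "2"` in Cargo.toml.
use half::{bf16, f16};

fn main() {
    let x = 70_000.0_f32; // beyond f16's max finite value (65504)
    println!("f16:  {}", f16::from_f32(x));  // overflows to inf
    println!("bf16: {}", bf16::from_f32(x)); // finite, rounded to a nearby value
}
```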
§Example
```rust
let model = Linear::<CpuBackend>::new(784, 10, true, DType::F32, &CpuDevice)?;
let optimizer = Adam::new(model.parameters(), 1e-3);
let mut trainer = MixedPrecisionTrainer::new(
    model, optimizer, DType::F16, Default::default(),
);

for (input, target) in data {
    let metrics = trainer.train_step(&input, &target, mse_loss)?;
    println!("loss={:.4}, scale={}", metrics.loss, metrics.loss_scale);
}
```

Implementations§
impl<M, O, B> MixedPrecisionTrainer<M, O, B>
pub fn new(
    model: M,
    optimizer: O,
    compute_dtype: DType,
    config: LossScaleConfig,
) -> Self
Create a new mixed-precision trainer.
The model and optimizer should use FP32 parameters.
compute_dtype sets the precision for forward/backward (F16, BF16, or F32).
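A construction sketch, reusing the Linear and Adam types from the example above (their exact signatures are assumptions carried over from that example):

```rust
let model = Linear::<CpuBackend>::new(784, 10, true, DType::F32, &CpuDevice)?;
let optimizer = Adam::new(model.parameters(), 1e-3);
let trainer = MixedPrecisionTrainer::new(
    model,
    optimizer,
    DType::BF16,                // reduced-precision compute dtype
    LossScaleConfig::default(), // default dynamic loss-scale schedule
);
```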
pub fn loss_scale(&self) -> f64
Current loss scale.
pub fn compute_dtype(&self) -> DType
The compute dtype (F16, BF16, or F32).
pub fn skipped_steps(&self) -> u64
Total number of steps skipped due to gradient overflow.
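Callers can watch this counter to spot scaling problems early. A hedged sketch (step is a hypothetical caller-maintained step count, and the 5% threshold is arbitrary):

```rust
// A persistently high skip ratio suggests the initial scale is too
// large or the model is diverging.
let skipped = trainer.skipped_steps();
if step > 0 && skipped as f64 / step as f64 > 0.05 {
    eprintln!("warning: {skipped}/{step} steps skipped due to overflow");
}
```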
pub fn train_step<F>(
    &mut self,
    input: &Tensor<B>,
    target: &Tensor<B>,
    loss_fn: F,
) -> Result<MixedPrecisionMetrics>
Perform one mixed-precision training step.
The input and target are cast to compute_dtype for the forward pass.
Dynamic loss scaling is applied to prevent gradient underflow.
Gradients are cast back to FP32 and applied to FP32 master weights.
§Arguments
- input: input tensor (any dtype; cast to compute_dtype)
- target: target tensor (any dtype; cast to compute_dtype)
- loss_fn: function computing a scalar loss from (prediction, target)
§Returns
MixedPrecisionMetrics with loss value and scaling info.
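Any callable producing a scalar loss should satisfy F; a sketch reusing mse_loss from the example above (the closure form is an assumption about the exact bound on F):

```rust
// Pass the loss as a closure; equivalent to passing `mse_loss` directly.
let metrics = trainer.train_step(&input, &target, |pred, target| mse_loss(pred, target))?;
println!("loss={:.4}, scale={}", metrics.loss, metrics.loss_scale);
```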
Auto Trait Implementations§
impl<M, O, B> Freeze for MixedPrecisionTrainer<M, O, B>
impl<M, O, B> RefUnwindSafe for MixedPrecisionTrainer<M, O, B>
impl<M, O, B> Send for MixedPrecisionTrainer<M, O, B>
impl<M, O, B> Sync for MixedPrecisionTrainer<M, O, B>
impl<M, O, B> Unpin for MixedPrecisionTrainer<M, O, B>
impl<M, O, B> UnwindSafe for MixedPrecisionTrainer<M, O, B>
Blanket Implementations§

impl<T> BorrowMut<T> for T
where
    T: ?Sized,

fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true; otherwise converts self into a Right variant.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true; otherwise converts self into a Right variant.