Struct MixedPrecisionTrainer 

pub struct MixedPrecisionTrainer<M, O, B: Backend> { /* private fields */ }

Mixed-precision training: reduced-precision forward/backward with FP32 master weights.

Why mixed precision?

  • FP16/BF16 arithmetic is roughly 2× faster than FP32 on GPUs with tensor cores (V100, A100, H100)
  • Half-precision uses half the memory for activations, enabling larger batches
  • FP32 master weights prevent precision loss during gradient updates

How it works:

  1. Inputs and targets are cast to compute_dtype (F16 or BF16) before the forward pass
  2. The forward pass runs with reduced-precision activations
  3. Dynamic loss scaling prevents gradient underflow in half precision (see the sketch after this list):
    • The loss is multiplied by a scale factor before the backward pass
    • Gradients are divided by the same factor afterwards
    • If overflow (NaN/Inf) is detected, the step is skipped and the scale is reduced
  4. Gradients are cast back to FP32 and applied to the FP32 master weights
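
The scaling logic can be pictured with a minimal, standalone sketch. It works on plain f32 slices rather than this crate's Tensor<B> type, and the 2× backoff/growth factors and the 2000-step growth interval are illustrative values, not necessarily what LossScaleConfig uses:

// Hypothetical sketch of one dynamic loss-scaling update (not this crate's API).
// `grads` arrive already multiplied by `scale`, because the loss was scaled
// before the backward pass. Returns whether the optimizer step should run.
fn apply_loss_scaling(grads: &mut [f32], scale: &mut f64, good_steps: &mut u32) -> bool {
    // Any NaN/Inf gradient means this step must be skipped.
    if grads.iter().any(|g| !g.is_finite()) {
        *scale /= 2.0;   // back off the scale
        *good_steps = 0;
        return false;    // skip the optimizer step
    }
    // Unscale before the gradients touch the FP32 master weights.
    for g in grads.iter_mut() {
        *g /= *scale as f32;
    }
    // After a long enough run of clean steps, probe a larger scale again.
    *good_steps += 1;
    if *good_steps >= 2000 {
        *scale *= 2.0;
        *good_steps = 0;
    }
    true
}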

Compute dtype options:

  • DType::F16: 16-bit IEEE float, range ±65504, good for most training
  • DType::BF16: bfloat16, same range as F32 but with less precision; preferred when hardware supports it
  • DType::F32: standard precision (no casting; only loss scaling is applied)
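
To make the range difference concrete, here is a small standalone demonstration using the half crate (an assumed third-party dependency, not something this crate provides). F16 overflows past ~65504 and loses very small values to underflow, while BF16 keeps the full F32 exponent range at reduced precision:

use half::{bf16, f16};

fn main() {
    // Above ~65504, F16 rounds to infinity; BF16 keeps the F32 exponent range.
    println!("{}", f16::from_f32(70_000.0));   // inf
    println!("{}", bf16::from_f32(70_000.0));  // 70144 (about 3 significant digits)

    // Tiny values underflow to zero in F16; this is what loss scaling guards against.
    println!("{}", f16::from_f32(1e-8).to_f32());          // 0
    println!("{}", f16::from_f32(1e-8 * 1024.0).to_f32()); // ~1.0e-5, still representable
}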

Example

// FP32 master weights; forward/backward will run in the reduced compute dtype.
let model = Linear::<CpuBackend>::new(784, 10, true, DType::F32, &CpuDevice)?;
let optimizer = Adam::new(model.parameters(), 1e-3);
let mut trainer = MixedPrecisionTrainer::new(
    model, optimizer, DType::F16, Default::default(),
);

for (input, target) in data {
    let metrics = trainer.train_step(&input, &target, mse_loss)?;
    println!("loss={:.4}, scale={}", metrics.loss, metrics.loss_scale);
}

Implementations

impl<M, O, B> MixedPrecisionTrainer<M, O, B>
where M: Module<B>, O: Optimizer<B>, B: Backend,

pub fn new(model: M, optimizer: O, compute_dtype: DType, config: LossScaleConfig) -> Self

Create a new mixed-precision trainer.

The model and optimizer should use FP32 parameters. compute_dtype sets the precision for forward/backward (F16, BF16, or F32).
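
For example, the same setup as the example above but with BF16 compute and the loss-scale config spelled out (LossScaleConfig is assumed to implement Default, as the Default::default() call in that example suggests):

let model = Linear::<CpuBackend>::new(784, 10, true, DType::F32, &CpuDevice)?;
let optimizer = Adam::new(model.parameters(), 1e-3);
let mut trainer = MixedPrecisionTrainer::new(
    model,
    optimizer,
    DType::BF16,                // reduced-precision compute dtype
    LossScaleConfig::default(), // dynamic loss scaling with default settings
);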

pub fn model(&self) -> &M

Reference to the model.

pub fn model_mut(&mut self) -> &mut M

Mutable reference to the model.

pub fn optimizer(&self) -> &O

Reference to the optimizer.

pub fn loss_scale(&self) -> f64

Current loss scale.

pub fn compute_dtype(&self) -> DType

The compute dtype (F16, BF16, or F32).

pub fn skipped_steps(&self) -> u64

Total number of skipped steps.

pub fn train_step<F>(&mut self, input: &Tensor<B>, target: &Tensor<B>, loss_fn: F) -> Result<MixedPrecisionMetrics>
where F: Fn(&Tensor<B>, &Tensor<B>) -> Result<Tensor<B>>,

Perform one mixed-precision training step.

The input and target are cast to compute_dtype for the forward pass. Dynamic loss scaling is applied to prevent gradient underflow. Gradients are cast back to FP32 and applied to FP32 master weights.

Arguments
  • input: input tensor (any dtype, will be cast to compute_dtype)
  • target: target tensor (any dtype, will be cast to compute_dtype)
  • loss_fn: function computing scalar loss from (prediction, target)
Returns

MixedPrecisionMetrics with loss value and scaling info.
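
A hedged usage sketch, reusing data and mse_loss from the example above: alongside loss and loss_scale from the returned metrics, skipped_steps() can be polled to see how often overflow forced a skipped update.

for (step, (input, target)) in data.into_iter().enumerate() {
    let metrics = trainer.train_step(&input, &target, mse_loss)?;
    if step % 100 == 0 {
        println!(
            "step {step}: loss={:.4} scale={} skipped={}",
            metrics.loss,
            metrics.loss_scale,
            trainer.skipped_steps(), // steps dropped due to NaN/Inf gradients
        );
    }
}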

Auto Trait Implementations

impl<M, O, B> Freeze for MixedPrecisionTrainer<M, O, B>
where M: Freeze, O: Freeze,

impl<M, O, B> RefUnwindSafe for MixedPrecisionTrainer<M, O, B>

impl<M, O, B> Send for MixedPrecisionTrainer<M, O, B>
where M: Send, O: Send,

impl<M, O, B> Sync for MixedPrecisionTrainer<M, O, B>
where M: Sync, O: Sync,

impl<M, O, B> Unpin for MixedPrecisionTrainer<M, O, B>
where M: Unpin, O: Unpin, B: Unpin,

impl<M, O, B> UnwindSafe for MixedPrecisionTrainer<M, O, B>
where M: UnwindSafe, O: UnwindSafe, B: UnwindSafe,
