load_training

Function load_training

Source
pub fn load_training<B: Backend>(
    path: impl AsRef<Path>,
    device: &B::Device,
) -> Result<TrainingCheckpoint<B>>

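Load a complete training checkpoint from a file: the model parameters, the optimizer state, and the training progress (epoch, global step, and best loss so far) needed to resume training.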

use shrew::checkpoint;
use shrew::prelude::*;

let ckpt = checkpoint::load_training::<CpuBackend>("training.shrew", &CpuDevice).unwrap();
println!("Resuming from epoch {}, step {}", ckpt.epoch, ckpt.global_step);
println!("Best loss so far: {}", ckpt.best_loss);

// Restore model params into executor...
// Restore optimizer state with optimizer.load_state_dict(ckpt.optimizer_state)...