save_tensors

Function save_tensors 

Source
/// Save a list of named tensors to a file.
///
/// * `path` - destination file; accepts anything implementing `AsRef<Path>`
///   (e.g. `&str`, `String`, `PathBuf`).
/// * `tensors` - borrowed slice of `(name, tensor)` pairs to write; every
///   tensor shares the same backend `B`.
///
/// # Errors
///
/// Returns `Err` if saving fails. The specific failure cases are not visible
/// from this signature — NOTE(review): confirm against the implementation
/// (e.g. I/O errors creating or writing the file).
pub fn save_tensors<B: Backend>(
    path: impl AsRef<Path>,
    tensors: &[(String, Tensor<B>)],
) -> Result<()>
Expand description

Save a list of named tensors (`(name, tensor)` pairs) to the file at `path`. Returns an `Err` if saving fails.

use shrew::checkpoint;
use shrew::prelude::*;

// Create two F32 tensors with `zeros` on the CPU backend, with dimensions
// (2, 3) and (2,).
let w1 = Tensor::<CpuBackend>::zeros((2, 3), DType::F32, &CpuDevice).unwrap();
let b1 = Tensor::<CpuBackend>::zeros((2,), DType::F32, &CpuDevice).unwrap();

// Each entry pairs the name it is saved under with the tensor itself.
let tensors = vec![
    ("w1".to_string(), w1),
    ("b1".to_string(), b1),
];

// Write into the OS temp directory so running this example (e.g. as a
// doctest) does not leave a file behind in the current working directory.
// `save_tensors` takes `impl AsRef<Path>`, so a `PathBuf` works directly.
let path = std::env::temp_dir().join("weights.shrew");
checkpoint::save_tensors(&path, &tensors).unwrap();