load_tensors

Function `load_tensors`

Source
pub fn load_tensors<B: Backend>(
    path: impl AsRef<Path>,
    device: &B::Device,
) -> Result<Vec<(String, Tensor<B>)>>
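Description

Load the named tensors stored in the file at `path`, returning them as `(name, tensor)` pairs for backend `B` on `device`.

Returns an error if the file cannot be loaded.

# Examples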

use shrew::checkpoint;
use shrew::prelude::*;

let tensors = checkpoint::load_tensors::<CpuBackend>("weights.shrew", &CpuDevice).unwrap();
for (name, tensor) in &tensors {
    println!("{name}: {:?}", tensor.dims());
}