Function benchmark_forward 

pub fn benchmark_forward<B, M, F>(
    model: &M,
    input_fn: F,
    batch_size: usize,
    warmup: usize,
    iterations: usize,
) -> Result<BenchmarkResult>
where
    B: Backend,
    M: Module<B>,
    F: Fn() -> Tensor<B>,

Benchmark a model’s forward pass.

Runs `warmup` untimed iterations followed by `iterations` timed iterations. `input_fn` is called once per iteration to produce the input tensor, and `batch_size` is used for the throughput calculation.
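
The warmup-then-measure pattern described above can be sketched with the standard library alone. This is not shrew's implementation: `time_closure` and `Timing` are hypothetical names, and both the throughput formula (samples per second) and the choice to build the input inside the timed region are assumptions made for illustration.

```rust
use std::time::{Duration, Instant};

/// Hypothetical result type for this sketch; not shrew's `BenchmarkResult`.
#[derive(Debug)]
struct Timing {
    iterations: usize,
    total: Duration,
    samples_per_sec: f64,
}

/// Hypothetical helper: calls `step` `warmup` times untimed,
/// then `iterations` times under a single timer.
fn time_closure<T>(
    mut make_input: impl FnMut() -> T,
    mut step: impl FnMut(T),
    batch_size: usize,
    warmup: usize,
    iterations: usize,
) -> Timing {
    // Warmup phase: nothing is measured here.
    for _ in 0..warmup {
        step(make_input());
    }

    // Timed phase: a fresh input is produced on every iteration.
    // Assumption: input construction is included in the measured time.
    let start = Instant::now();
    for _ in 0..iterations {
        step(make_input());
    }
    let total = start.elapsed();

    // Assumed throughput definition: samples processed per second.
    let samples_per_sec = (batch_size * iterations) as f64 / total.as_secs_f64();

    Timing { iterations, total, samples_per_sec }
}
```

With `warmup = 3` and `iterations = 10`, `step` runs 13 times in total, and only the last 10 runs contribute to `total`.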

§Example

use shrew::prelude::*;
use shrew::profiler::benchmark_forward;

let model = Linear::<CpuBackend>::new(16, 8, true, DType::F32, &CpuDevice).unwrap();
let result = benchmark_forward(
    &model,
    || Tensor::<CpuBackend>::rand((4, 16), DType::F32, &CpuDevice).unwrap(),
    4, // batch_size (for throughput calc)
    3, // warmup
    10, // iterations
).unwrap();
println!("{}", result);