Function reduce_gradients

pub fn reduce_gradients<B: Backend>(
    grad_stores: &[GradStore<B>],
    params: &[Tensor<B>],
    strategy: AllReduceOp,
) -> Result<GradStore<B>>

Average (or sum) multiple GradStores into a single GradStore.

This is the core AllReduce primitive. Each worker produces a GradStore from its backward pass; this function merges them.

§Arguments

  • grad_stores: one GradStore per replica/worker
  • params: the shared parameter tensors (used to enumerate keys)
  • strategy: whether to Sum the per-replica gradients or Average them