pub struct DataParallel<M> {
    pub module: M,
    pub num_workers: usize,
}
Wraps a Module and splits each input batch into num_workers chunks that are processed in parallel.
The forward pass:
- Split the input along dimension 0 into num_workers chunks.
- Run each chunk through the module in parallel (rayon).
- Concatenate the outputs along dimension 0.
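The three steps above can be sketched with the standard library alone. This is a minimal sketch, not this crate's implementation: it assumes a batch is a `Vec<Vec<f32>>` of rows, uses a plain closure in place of the wrapped Module, and uses `std::thread::scope` where the crate uses rayon. The helper name `data_parallel_forward` is hypothetical.

```rust
use std::thread;

/// Hypothetical sketch of the DataParallel forward pass.
/// Assumptions: a batch is a Vec of rows (dimension 0 = rows), and the
/// "module" is any closure mapping a slice of rows to output rows.
/// The real crate works on Tensor<B> and uses rayon instead of scoped threads.
fn data_parallel_forward<F>(rows: &[Vec<f32>], num_workers: usize, module: F) -> Vec<Vec<f32>>
where
    F: Fn(&[Vec<f32>]) -> Vec<Vec<f32>> + Sync,
{
    if rows.is_empty() {
        return Vec::new();
    }
    // Treat 0 workers as 1 so the chunk length below is never zero.
    let workers = num_workers.max(1);
    // Ceiling division so every row lands in some chunk.
    let chunk_len = (rows.len() + workers - 1) / workers;
    let module = &module;

    thread::scope(|s| {
        // Step 1 and 2: split along dimension 0 and spawn one thread per chunk.
        let handles: Vec<_> = rows
            .chunks(chunk_len)
            .map(|chunk| s.spawn(move || module(chunk)))
            .collect();
        // Step 3: join in spawn order and concatenate the per-chunk outputs.
        handles
            .into_iter()
            .flat_map(|h| h.join().expect("worker panicked"))
            .collect()
    })
}
```

With ceiling division, a 10-row batch and `num_workers = 4` gives chunks of 3, 3, 3 and 1 rows. Because the outputs are joined in spawn order, the result matches running the closure on the whole batch at once.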
Because all workers share the same parameters (Tensor uses Arc, so cloning a
handle does not copy the data), every chunk's operations are recorded against
the same parameter tensors in the autograd graph. After computing the loss on
the concatenated output and calling .backward(), the gradients from all chunks
are accumulated into those shared parameters automatically.
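For a loss that sums (or averages) per-sample terms, the gradient over the whole batch equals the sum of the per-chunk gradients, which is why accumulating into shared parameters gives the same result as an unsplit batch. The sketch below checks this numerically for a hypothetical one-weight model `y = w * x` with squared error. It is plain arithmetic, not the crate's autograd, and `grad_w` and `accumulated_grad_w` are illustrative names.

```rust
/// Gradient of the squared-error loss L = sum((w*x - y)^2) with respect to the
/// scalar weight w (hypothetical one-weight model): dL/dw = sum(2*(w*x - y)*x).
fn grad_w(w: f64, xs: &[f64], ys: &[f64]) -> f64 {
    xs.iter().zip(ys).map(|(x, y)| 2.0 * (w * x - y) * x).sum()
}

/// Sum of per-chunk gradients, mimicking what a shared parameter accumulates
/// when each chunk's forward pass was recorded against the same weight.
fn accumulated_grad_w(w: f64, xs: &[f64], ys: &[f64], num_workers: usize) -> f64 {
    let workers = num_workers.max(1);
    // Ceiling division; clamp to 1 so `chunks` never receives 0.
    let chunk_len = ((xs.len() + workers - 1) / workers).max(1);
    xs.chunks(chunk_len)
        .zip(ys.chunks(chunk_len))
        .map(|(xc, yc)| grad_w(w, xc, yc))
        .sum()
}
```

If the loss is a mean over the concatenated output instead of a sum, both sides are scaled by the same 1/N, so the equality still holds.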
Example
let model = Sequential::new(vec![...]);
let dp = DataParallel::new(model, 4); // 4 workers
let output = dp.forward(&big_batch)?; // splits into 4 chunks
Fields
module: M
The underlying module (shared across workers).
num_workers: usize
Number of parallel workers.
Implementations
impl<M> DataParallel<M>
pub fn new(module: M, num_workers: usize) -> Self
Create a DataParallel wrapper with the given number of workers.
num_workers controls how many chunks the batch is split into.
For CPU, this maps to rayon thread-pool parallelism.
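The documentation does not say how a batch whose length is not a multiple of num_workers is divided. One common choice, shown here as an assumption rather than the crate's behaviour, is a balanced split in which the first `batch % num_workers` chunks receive one extra row. `chunk_sizes` is a hypothetical helper.

```rust
/// Hypothetical helper (not part of the crate): sizes of the chunks a batch of
/// `batch` rows would be split into for `num_workers` workers, using a balanced
/// split where the first `batch % num_workers` chunks get one extra row.
/// Chunks that would be empty (when batch < num_workers) are dropped.
fn chunk_sizes(batch: usize, num_workers: usize) -> Vec<usize> {
    let workers = num_workers.max(1);
    let base = batch / workers;
    let extra = batch % workers;
    (0..workers)
        .map(|i| base + usize::from(i < extra))
        .filter(|&n| n > 0)
        .collect()
}
```

For 10 rows and 4 workers this yields 3, 3, 2, 2. A ceiling-division split (as `slice::chunks` produces) would instead yield 3, 3, 3, 1, and for 5 rows it would create only 3 chunks, so fewer workers than requested would run.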
pub fn into_inner(self) -> M
Unwrap the DataParallel, returning the inner module.
Trait Implementations
impl<M: Clone> Clone for DataParallel<M>
impl<M: Debug> Debug for DataParallel<M>
impl<M, B> Module<B> for DataParallel<M>
fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>>
Compute the output tensor from the input tensor.
This defines the layer’s computation (forward pass).
fn parameters(&self) -> Vec<Tensor<B>>
Return all trainable parameters of this module.
The optimizer uses these to update weights during training.
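Because the workers share one copy of the weights, DataParallel presumably reports the wrapped module's parameters unchanged rather than one set per worker. The sketch below illustrates that delegation with simplified, hypothetical stand-ins: `Param` is an `Arc<Vec<f32>>` in place of the crate's Tensor, `Linear` is a toy module, and this `Module` trait keeps only `parameters()`.

```rust
use std::sync::Arc;

/// Hypothetical stand-in for an Arc-backed Tensor.
type Param = Arc<Vec<f32>>;

/// Hypothetical, trimmed-down Module trait with only the parameter accessor.
trait Module {
    fn parameters(&self) -> Vec<Param>;
}

/// Toy module with a weight and a bias.
struct Linear {
    weight: Param,
    bias: Param,
}

impl Module for Linear {
    fn parameters(&self) -> Vec<Param> {
        // Cloning an Arc copies the handle, not the underlying data.
        vec![Arc::clone(&self.weight), Arc::clone(&self.bias)]
    }
}

#[allow(dead_code)] // num_workers is unused in this parameter-only sketch
struct DataParallel<M> {
    module: M,
    num_workers: usize,
}

impl<M: Module> Module for DataParallel<M> {
    fn parameters(&self) -> Vec<Param> {
        // Assumed behaviour: delegate to the wrapped module, so the optimizer
        // sees a single set of parameters regardless of num_workers.
        self.module.parameters()
    }
}
```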
fn named_parameters(&self) -> Vec<(String, Tensor<B>)>
Return all trainable parameters with human-readable names.
fn set_training(&self, training: bool)
Set training or evaluation mode.
fn is_training(&self) -> bool
Whether the module is in training mode (default: true).
fn num_parameters(&self) -> usize
Total number of scalar parameters in this module.
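Counting scalar parameters amounts to summing the element counts of the parameter tensors, where each count is the product of the tensor's dimensions. This sketch works on shapes alone; `num_parameters` here is a hypothetical free function, not the trait method, and the shapes used are illustrative.

```rust
/// Hypothetical helper: total scalar parameters given each parameter tensor's
/// shape. A tensor's element count is the product of its dimensions, so a
/// 0-dimensional (scalar) shape counts as 1.
fn num_parameters(shapes: &[Vec<usize>]) -> usize {
    shapes.iter().map(|dims| dims.iter().product::<usize>()).sum()
}
```

A linear layer mapping 784 inputs to 128 outputs, with a [128, 784] weight and a [128] bias, has 128 * 784 + 128 = 100,480 scalar parameters. If DataParallel delegates parameters() to the inner module, wrapping it should leave this count unchanged.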
fn trainable_params_count(&self) -> usize
Number of trainable (variable) parameters.
Auto Trait Implementations
impl<M> Freeze for DataParallel<M>
where
    M: Freeze,
impl<M> RefUnwindSafe for DataParallel<M>
where
    M: RefUnwindSafe,
impl<M> Send for DataParallel<M>
where
    M: Send,
impl<M> Sync for DataParallel<M>
where
    M: Sync,
impl<M> Unpin for DataParallel<M>
where
    M: Unpin,
impl<M> UnwindSafe for DataParallel<M>
where
    M: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise.