// shrew/distributed.rs

// Distributed Training — Data Parallelism, Mixed Precision, Pipeline Stages
//
// This module provides primitives for scaling training across multiple
// workers, mixed-precision (FP16/FP32) training, and model-parallel
// pipeline execution.
//
// COMPONENTS:
//
//   DataParallel<M>       — Splits input batches into N chunks, runs each
//                           chunk through the shared module, and concatenates
//                           the outputs. Implements Module, so it's a drop-in
//                           replacement.
//
//   MixedPrecisionTrainer — Maintains FP32 "master" weights, runs forward and
//                           backward in a reduced-precision compute dtype
//                           (F16/BF16), and applies dynamic loss scaling to
//                           prevent underflow in half-precision gradients.
//
//   PipelineParallel      — Splits a sequential model into stages and runs
//                           micro-batches through them (GPipe-style
//                           micro-batching).
//
//   reduce_gradients()    — Averages or sums multiple GradStores (the core
//                           AllReduce primitive). Usable standalone for
//                           custom loops.

use std::marker::PhantomData;

use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use shrew_nn::Module;
use shrew_optim::Optimizer;

// AllReduce strategy

/// Strategy for combining gradients from multiple replicas.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllReduceOp {
    /// Sum all gradients (caller divides by N if needed).
    Sum,
    /// Average gradients across replicas (most common).
    Average,
}

// Gradient averaging

/// Average (or sum) multiple `GradStore`s into a single `GradStore`.
///
/// This is the core AllReduce primitive. Each worker produces a `GradStore`
/// from its backward pass; this function merges them.
///
/// # Arguments
/// - `grad_stores`: one `GradStore` per replica/worker
/// - `params`: the shared parameter tensors (used to enumerate keys)
/// - `strategy`: `Sum` or `Average`
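///
/// # Example
/// A minimal sketch, assuming `g0` and `g1` are the `GradStore`s produced by
/// two workers' backward passes and `params` are the shared parameter tensors:
/// ```ignore
/// let merged = reduce_gradients(&[g0, g1], &params, AllReduceOp::Average)?;
/// optimizer.step(&merged)?;
/// ```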
pub fn reduce_gradients<B: Backend>(
    grad_stores: &[GradStore<B>],
    params: &[Tensor<B>],
    strategy: AllReduceOp,
) -> Result<GradStore<B>> {
    let n = grad_stores.len();
    if n == 0 {
        return Ok(GradStore::new());
    }
    if n == 1 {
        return Ok(grad_stores[0].clone());
    }

    let mut merged = GradStore::new();

    for param in params {
        // Collect gradients from all stores for this parameter
        let mut grads: Vec<&Tensor<B>> = Vec::new();
        for store in grad_stores {
            if let Some(g) = store.get(param) {
                grads.push(g);
            }
        }
        if grads.is_empty() {
            continue;
        }

        // Sum all gradients
        let mut acc = grads[0].clone();
        for g in &grads[1..] {
            acc = acc.add(g)?;
        }

        // Average if requested
        if strategy == AllReduceOp::Average && grads.len() > 1 {
            let scale = 1.0 / grads.len() as f64;
            acc = acc.affine(scale, 0.0)?;
        }

        merged.accumulate(param.id(), acc)?;
    }

    Ok(merged)
}

// DataParallel — Module wrapper for batch-parallel forward passes

/// Wraps a `Module` and splits each input batch into `num_workers` chunks.
///
/// The forward pass:
///   1. Split the input along dimension 0 into `num_workers` chunks
///   2. Run each chunk through the shared module (sequentially in this CPU
///      implementation; see `forward` for notes on true parallelism)
///   3. Concatenate the outputs
///
/// Because all workers share the same parameters (`Tensor` uses `Arc`), the
/// autograd graph correctly tracks all operations. After computing the loss
/// on the concatenated output and calling `.backward()`, the gradients
/// are automatically accumulated across all chunks.
///
/// # Example
/// ```ignore
/// let model = Sequential::new(vec![...]);
/// let dp = DataParallel::new(model, 4);  // 4 workers
/// let output = dp.forward(&big_batch)?;  // splits into 4 chunks
/// ```
pub struct DataParallel<M> {
    /// The underlying module (shared across workers).
    pub module: M,
    /// Number of parallel workers.
    pub num_workers: usize,
}

impl<M: Clone> Clone for DataParallel<M> {
    fn clone(&self) -> Self {
        Self {
            module: self.module.clone(),
            num_workers: self.num_workers,
        }
    }
}

impl<M: std::fmt::Debug> std::fmt::Debug for DataParallel<M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DataParallel")
            .field("module", &self.module)
            .field("num_workers", &self.num_workers)
            .finish()
    }
}

impl<M> DataParallel<M> {
    /// Create a `DataParallel` wrapper with the given number of workers.
    ///
    /// `num_workers` controls how many chunks the batch is split into.
    /// The current CPU implementation processes chunks sequentially; a rayon
    /// thread pool would turn this into thread-level parallelism.
    pub fn new(module: M, num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be > 0");
        Self {
            module,
            num_workers,
        }
    }

    /// Get a reference to the underlying module.
    pub fn inner(&self) -> &M {
        &self.module
    }

    /// Get a mutable reference to the underlying module.
    pub fn inner_mut(&mut self) -> &mut M {
        &mut self.module
    }

    /// Unwrap the `DataParallel`, returning the inner module.
    pub fn into_inner(self) -> M {
        self.module
    }
}

impl<M, B> Module<B> for DataParallel<M>
where
    M: Module<B> + Send + Sync,
    B: Backend,
{
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let batch_size = x.dims()[0];
        let effective_workers = self.num_workers.min(batch_size);

        if effective_workers <= 1 {
            return self.module.forward(x);
        }

        // Split into chunks along batch dimension
        let chunks = x.chunk(effective_workers, 0)?;

        // Run forward on each chunk. Chunks are processed sequentially here;
        // a rayon parallel iterator over the chunks would give thread-level
        // parallelism on CPU (the module is already required to be Send + Sync).
        // NOTE: True multi-device data parallelism requires each replica on a
        // separate device.
        let mut outputs = Vec::with_capacity(chunks.len());
        for chunk in &chunks {
            outputs.push(self.module.forward(chunk)?);
        }

        // Concatenate results
        Tensor::cat(&outputs, 0)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        self.module.parameters()
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        self.module.named_parameters()
    }

    fn set_training(&self, training: bool) {
        self.module.set_training(training);
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}

// MixedPrecisionTrainer — FP16 forward + FP32 master weights

/// Configuration for dynamic loss scaling in mixed-precision training.
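///
/// # Example
/// A minimal sketch of overriding a single field (the value shown is
/// illustrative, not a recommendation):
/// ```ignore
/// let cfg = LossScaleConfig { init_scale: 1024.0, ..Default::default() };
/// ```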
#[derive(Debug, Clone)]
pub struct LossScaleConfig {
    /// Initial loss scale factor (default: 2^16 = 65536).
    pub init_scale: f64,
    /// Multiply scale by this when no overflow (default: 2.0).
    pub scale_growth_factor: f64,
    /// Divide scale by this on overflow (default: 2.0).
    pub scale_backoff_factor: f64,
    /// Number of consecutive good steps before increasing scale (default: 2000).
    pub growth_interval: u64,
}

impl Default for LossScaleConfig {
    fn default() -> Self {
        Self {
            init_scale: 65536.0,
            scale_growth_factor: 2.0,
            scale_backoff_factor: 2.0,
            growth_interval: 2000,
        }
    }
}

/// Mixed-precision training: reduced-precision forward/backward with FP32 master weights.
///
/// **Why mixed precision?**
/// - FP16/BF16 math is substantially faster on GPUs with tensor cores (V100, A100, H100)
/// - Half-precision uses half the memory for activations, enabling larger batches
/// - FP32 master weights prevent precision loss during gradient updates
///
/// **How it works:**
/// 1. Inputs and targets are cast to `compute_dtype` (F16 or BF16) before the forward pass
/// 2. The forward pass runs with reduced-precision activations
/// 3. Dynamic loss scaling prevents gradient underflow in half-precision:
///    - The loss is multiplied by a scale factor before backward
///    - Gradients are divided by the same factor afterwards
///    - If overflow (NaN/Inf) is detected, the step is skipped and the scale is reduced
/// 4. Gradients are cast back to FP32 and applied to the FP32 master weights
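///
/// For example, with a scale of 65536 a true gradient of 1e-7 is scaled to
/// roughly 6.6e-3, comfortably inside F16's normal range (smallest normal
/// ~6.1e-5); dividing by the same factor after backward recovers the
/// original magnitude.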
///
/// **Compute dtype options:**
/// - `DType::F16`: 16-bit IEEE float, range ±65504, good for most training
/// - `DType::BF16`: bfloat16, same range as F32 but less precision, preferred when available
/// - `DType::F32`: Standard precision (disables casting, only does loss scaling)
///
/// # Example
/// ```ignore
/// let model = Linear::<CpuBackend>::new(784, 10, true, DType::F32, &CpuDevice)?;
/// let optimizer = Adam::new(model.parameters(), 1e-3);
/// let mut trainer = MixedPrecisionTrainer::new(
///     model, optimizer, DType::F16, Default::default(),
/// );
///
/// for (input, target) in data {
///     let metrics = trainer.train_step(&input, &target, mse_loss)?;
///     println!("loss={:.4}, scale={}", metrics.loss, metrics.loss_scale);
/// }
/// ```
pub struct MixedPrecisionTrainer<M, O, B: Backend> {
    /// The model (with FP32 parameters as master copies).
    model: M,
    /// The optimizer operating on FP32 parameters.
    optimizer: O,
    /// The dtype for forward/backward computation (F16, BF16, or F32).
    compute_dtype: DType,
    /// Current loss scale factor.
    loss_scale: f64,
    /// Loss scale configuration.
    config: LossScaleConfig,
    /// Number of consecutive successful steps (no overflow).
    good_steps: u64,
    /// Total skipped steps (overflow detected).
    skipped_steps: u64,
    _phantom: PhantomData<B>,
}

/// Metrics from a single mixed-precision training step.
#[derive(Debug, Clone)]
pub struct MixedPrecisionMetrics {
    /// The unscaled loss value.
    pub loss: f64,
    /// Whether this step was skipped (overflow detected).
    pub skipped: bool,
    /// Current loss scale factor.
    pub loss_scale: f64,
    /// Total skipped steps so far.
    pub total_skipped: u64,
    /// The compute dtype used for this step.
    pub compute_dtype: DType,
}

impl<M, O, B> MixedPrecisionTrainer<M, O, B>
where
    M: Module<B>,
    O: Optimizer<B>,
    B: Backend,
{
    /// Create a new mixed-precision trainer.
    ///
    /// The model and optimizer should use FP32 parameters.
    /// `compute_dtype` sets the precision for forward/backward (F16, BF16, or F32).
    pub fn new(model: M, optimizer: O, compute_dtype: DType, config: LossScaleConfig) -> Self {
        let loss_scale = config.init_scale;
        Self {
            model,
            optimizer,
            compute_dtype,
            loss_scale,
            config,
            good_steps: 0,
            skipped_steps: 0,
            _phantom: PhantomData,
        }
    }

    /// Reference to the model.
    pub fn model(&self) -> &M {
        &self.model
    }

    /// Mutable reference to the model.
    pub fn model_mut(&mut self) -> &mut M {
        &mut self.model
    }

    /// Reference to the optimizer.
    pub fn optimizer(&self) -> &O {
        &self.optimizer
    }

    /// Current loss scale.
    pub fn loss_scale(&self) -> f64 {
        self.loss_scale
    }

    /// The compute dtype (F16, BF16, or F32).
    pub fn compute_dtype(&self) -> DType {
        self.compute_dtype
    }

    /// Total number of skipped steps.
    pub fn skipped_steps(&self) -> u64 {
        self.skipped_steps
    }

    /// Perform one mixed-precision training step.
    ///
    /// The input and target are cast to `compute_dtype` for the forward pass
    /// (only when the model's parameters already use `compute_dtype`; otherwise
    /// the cast is skipped and only loss scaling is applied).
    /// Dynamic loss scaling is applied to prevent gradient underflow.
    /// Gradients are cast back to FP32 and applied to the FP32 master weights.
    ///
    /// # Arguments
    /// - `input`: input tensor (cast to `compute_dtype` when applicable)
    /// - `target`: target tensor (cast to `compute_dtype` when applicable)
    /// - `loss_fn`: function computing a scalar loss from (prediction, target)
    ///
    /// # Returns
    /// `MixedPrecisionMetrics` with the loss value and scaling info.
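    ///
    /// # Example
    /// A minimal sketch of a single step, assuming an `mse_loss` function is
    /// available in the caller's scope:
    /// ```ignore
    /// let m = trainer.train_step(&x, &y, |p, t| mse_loss(p, t))?;
    /// if m.skipped {
    ///     // Overflow was detected: the scale was reduced and no optimizer step ran.
    /// }
    /// ```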
    pub fn train_step<F>(
        &mut self,
        input: &Tensor<B>,
        target: &Tensor<B>,
        loss_fn: F,
    ) -> Result<MixedPrecisionMetrics>
    where
        F: Fn(&Tensor<B>, &Tensor<B>) -> Result<Tensor<B>>,
    {
        // 1. Determine if we should cast inputs to compute_dtype.
        // Only cast if the model's parameters already match compute_dtype,
        // otherwise auto-casting inputs would cause dtype mismatches with weights.
        let model_dtype = self
            .model
            .parameters()
            .first()
            .map(|p| p.dtype())
            .unwrap_or(DType::F32);
        let should_cast = self.compute_dtype != DType::F32 && self.compute_dtype == model_dtype;

        let input_cast = if should_cast && input.dtype() != self.compute_dtype {
            input.to_dtype(self.compute_dtype)?
        } else {
            input.clone()
        };
        let target_cast = if should_cast && target.dtype() != self.compute_dtype {
            target.to_dtype(self.compute_dtype)?
        } else {
            target.clone()
        };

        // 2. Forward pass
        let output = self.model.forward(&input_cast)?;

        // 3. Compute loss (in compute_dtype)
        let loss = loss_fn(&output, &target_cast)?;
        let loss_val = loss.to_scalar_f64()?;

        // 4. Scale loss for backward (prevents gradient underflow in F16)
        let scaled_loss = loss.affine(self.loss_scale, 0.0)?;

        // 5. Backward on scaled loss
        let grads = scaled_loss.backward()?;

        // 6. Check for overflow in gradients
        let has_overflow = self.check_overflow(&grads)?;

        if has_overflow {
            // Skip this step, reduce the scale
            self.loss_scale /= self.config.scale_backoff_factor;
            self.loss_scale = self.loss_scale.max(1.0); // don't go below 1
            self.good_steps = 0;
            self.skipped_steps += 1;

            return Ok(MixedPrecisionMetrics {
                loss: loss_val,
                skipped: true,
                loss_scale: self.loss_scale,
                total_skipped: self.skipped_steps,
                compute_dtype: self.compute_dtype,
            });
        }

        // 7. Unscale gradients and cast back to FP32 for master weight update
        let unscaled = self.unscale_and_cast_gradients(&grads)?;

        // 8. Optimizer step with FP32 unscaled gradients
        self.optimizer.step(&unscaled)?;

        // 9. Update loss scale (possibly increase after consecutive good steps)
        self.good_steps += 1;
        if self.good_steps >= self.config.growth_interval {
            self.loss_scale *= self.config.scale_growth_factor;
            self.good_steps = 0;
        }

        Ok(MixedPrecisionMetrics {
            loss: loss_val,
            skipped: false,
            loss_scale: self.loss_scale,
            total_skipped: self.skipped_steps,
            compute_dtype: self.compute_dtype,
        })
    }

    /// Check if any gradient contains NaN or Inf.
    fn check_overflow(&self, grads: &GradStore<B>) -> Result<bool> {
        for param in self.model.parameters() {
            if let Some(g) = grads.get(&param) {
                let data = g.to_f64_vec()?;
                for &v in &data {
                    if v.is_nan() || v.is_infinite() {
                        return Ok(true);
                    }
                }
            }
        }
        Ok(false)
    }

    /// Unscale gradients by the loss scale factor and cast to FP32.
    ///
    /// This ensures the optimizer always sees FP32 gradients, regardless
    /// of the compute dtype used during forward/backward.
    fn unscale_and_cast_gradients(&self, grads: &GradStore<B>) -> Result<GradStore<B>> {
        let inv_scale = 1.0 / self.loss_scale;
        let mut unscaled = GradStore::new();
        for param in self.model.parameters() {
            if let Some(g) = grads.get(&param) {
                // Unscale the gradient
                let g_unscaled = g.affine(inv_scale, 0.0)?;
                // Cast back to the master weight dtype (FP32) if needed
                let g_fp32 = if g_unscaled.dtype() != param.dtype() {
                    g_unscaled.to_dtype(param.dtype())?
                } else {
                    g_unscaled
                };
                unscaled.accumulate(param.id(), g_fp32)?;
            }
        }
        Ok(unscaled)
    }
}

// PipelineParallel — GPipe-style micro-batch pipelining

/// A stage in a pipeline-parallel model.
///
/// Each stage holds a sub-model (one or more layers). Micro-batches flow
/// through the stages in order; in a full multi-device pipeline schedule,
/// the work of different micro-batches overlaps across stages.
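///
/// # Example
/// A minimal sketch of assembling a stage (`layer_a` and `layer_b` stand in
/// for any `Module` implementations):
/// ```ignore
/// let stage = PipelineStage::new(0)
///     .add_layer(Box::new(layer_a))
///     .add_layer(Box::new(layer_b));
/// ```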
pub struct PipelineStage<B: Backend> {
    /// The layers in this stage (as boxed Module).
    layers: Vec<Box<dyn Module<B>>>,
    /// Stage index (0-based).
    stage_id: usize,
}

impl<B: Backend> PipelineStage<B> {
    /// Create a new pipeline stage.
    pub fn new(stage_id: usize) -> Self {
        Self {
            layers: Vec::new(),
            stage_id,
        }
    }

    /// Add a layer to this stage.
    pub fn add_layer(mut self, layer: Box<dyn Module<B>>) -> Self {
        self.layers.push(layer);
        self
    }

    /// Stage identifier.
    pub fn stage_id(&self) -> usize {
        self.stage_id
    }

    /// Forward pass through all layers in this stage.
    pub fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let mut out = x.clone();
        for layer in &self.layers {
            out = layer.forward(&out)?;
        }
        Ok(out)
    }

    /// Collect all parameters from all layers in this stage.
    pub fn parameters(&self) -> Vec<Tensor<B>> {
        self.layers.iter().flat_map(|l| l.parameters()).collect()
    }
}

/// Pipeline-parallel executor using GPipe-style micro-batching.
///
/// Splits a model into sequential stages and processes micro-batches
/// through the pipeline. On multiple devices this schedule overlaps
/// computation across stages; this single-process implementation runs
/// the micro-batches sequentially (see `forward`).
///
/// # Example
/// ```ignore
/// let stage0 = PipelineStage::new(0)
///     .add_layer(Box::new(Linear::new(784, 256, true, DType::F32, &dev)?));
/// let stage1 = PipelineStage::new(1)
///     .add_layer(Box::new(Linear::new(256, 10, true, DType::F32, &dev)?));
///
/// let pipeline = PipelineParallel::new(vec![stage0, stage1], 4);
/// let output = pipeline.forward(&input)?;
/// ```
pub struct PipelineParallel<B: Backend> {
    /// Ordered stages of the model.
    stages: Vec<PipelineStage<B>>,
    /// Number of micro-batches to split each input into.
    num_micro_batches: usize,
}

impl<B: Backend> PipelineParallel<B> {
    /// Create a pipeline with the given stages and micro-batch count.
    pub fn new(stages: Vec<PipelineStage<B>>, num_micro_batches: usize) -> Self {
        assert!(!stages.is_empty(), "pipeline needs at least one stage");
        assert!(num_micro_batches > 0, "num_micro_batches must be > 0");
        Self {
            stages,
            num_micro_batches,
        }
    }

    /// Full forward pass through all stages.
    ///
    /// Splits the input into `num_micro_batches` micro-batches, runs each
    /// through the pipeline sequentially, and concatenates the results.
    ///
    /// In a multi-device setting, stages would run on different devices
    /// with inter-device transfers between stages.
    pub fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let batch_size = x.dims()[0];
        let effective_micros = self.num_micro_batches.min(batch_size);

        if effective_micros <= 1 {
            // No micro-batching — sequential pass through all stages
            let mut out = x.clone();
            for stage in &self.stages {
                out = stage.forward(&out)?;
            }
            return Ok(out);
        }

        // Split into micro-batches
        let micro_batches = x.chunk(effective_micros, 0)?;

        // Run each micro-batch through all stages
        let mut outputs = Vec::with_capacity(micro_batches.len());
        for mb in &micro_batches {
            let mut out = mb.clone();
            for stage in &self.stages {
                out = stage.forward(&out)?;
            }
            outputs.push(out);
        }

        // Concatenate micro-batch outputs
        Tensor::cat(&outputs, 0)
    }

    /// Collect all parameters from all stages (for optimizer).
    pub fn parameters(&self) -> Vec<Tensor<B>> {
        self.stages.iter().flat_map(|s| s.parameters()).collect()
    }

    /// Number of stages.
    pub fn num_stages(&self) -> usize {
        self.stages.len()
    }

    /// Get a reference to a specific stage.
    pub fn stage(&self, idx: usize) -> Option<&PipelineStage<B>> {
        self.stages.get(idx)
    }
}

// ParallelTrainer — High-level training loop with gradient accumulation

/// High-level training loop with gradient accumulation.
///
/// Splits a large effective batch into `accumulation_steps` micro-batches,
/// accumulates gradients across all of them, then performs a single
/// optimizer step. This simulates a larger batch size without requiring
/// more memory.
///
/// # Example
/// ```ignore
/// let model = Sequential::new(vec![...]);
/// let optimizer = Adam::new(model.parameters(), 1e-3);
/// let mut trainer = ParallelTrainer::new(model, optimizer, 4);
///
/// // Each call accumulates 1/4 of the gradient; every 4th call steps.
/// for (i, (x, y)) in data.iter().enumerate() {
///     if let Some(loss) = trainer.accumulate_step(&x, &y, mse_loss)? {
///         println!("step {}: loss = {:.4}", i, loss);
///     }
/// }
/// ```
pub struct ParallelTrainer<M, O, B: Backend> {
    /// The model.
    pub model: M,
    /// The optimizer.
    pub optimizer: O,
    /// Number of micro-batches to accumulate before stepping.
    accumulation_steps: usize,
    /// Current accumulated gradients.
    accumulated: Option<GradStore<B>>,
    /// Current micro-batch index (0 .. accumulation_steps - 1).
    current_step: usize,
    /// Running loss sum for the current accumulation window.
    loss_sum: f64,
    _phantom: PhantomData<B>,
}

impl<M, O, B> ParallelTrainer<M, O, B>
where
    M: Module<B>,
    O: Optimizer<B>,
    B: Backend,
{
    /// Create a new `ParallelTrainer`.
    ///
    /// `accumulation_steps`: number of micro-batches before each optimizer step.
    pub fn new(model: M, optimizer: O, accumulation_steps: usize) -> Self {
        assert!(accumulation_steps > 0);
        Self {
            model,
            optimizer,
            accumulation_steps,
            accumulated: None,
            current_step: 0,
            loss_sum: 0.0,
            _phantom: PhantomData,
        }
    }

    /// Process one micro-batch. Returns `Some(avg_loss)` when an optimizer
    /// step was performed (every `accumulation_steps` calls), else `None`.
    pub fn accumulate_step<F>(
        &mut self,
        input: &Tensor<B>,
        target: &Tensor<B>,
        loss_fn: F,
    ) -> Result<Option<f64>>
    where
        F: Fn(&Tensor<B>, &Tensor<B>) -> Result<Tensor<B>>,
    {
        // Forward
        let output = self.model.forward(input)?;
        let loss = loss_fn(&output, target)?;
        let loss_val = loss.to_scalar_f64()?;
        self.loss_sum += loss_val;

        // Backward
        let grads = loss.backward()?;

        // Accumulate gradients
        let params = self.model.parameters();
        match self.accumulated.take() {
            Some(prev) => {
                let merged = reduce_gradients(&[prev, grads], &params, AllReduceOp::Sum)?;
                self.accumulated = Some(merged);
            }
            None => {
                self.accumulated = Some(grads);
            }
        }

        self.current_step += 1;

        // Step when we've accumulated enough
        if self.current_step >= self.accumulation_steps {
            let avg_grads = {
                let acc = self.accumulated.take().unwrap();
                // Average by accumulation_steps
                let mut averaged = GradStore::new();
                let scale = 1.0 / self.accumulation_steps as f64;
                for param in &params {
                    if let Some(g) = acc.get(param) {
                        let g_avg = g.affine(scale, 0.0)?;
                        averaged.accumulate(param.id(), g_avg)?;
                    }
                }
                averaged
            };

            self.optimizer.step(&avg_grads)?;

            let avg_loss = self.loss_sum / self.accumulation_steps as f64;
            self.current_step = 0;
            self.loss_sum = 0.0;
            self.accumulated = None;

            Ok(Some(avg_loss))
        } else {
            Ok(None)
        }
    }

    /// Force an optimizer step with whatever gradients have been accumulated so far.
    /// Useful at the end of an epoch when remaining micro-batches < accumulation_steps.
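    ///
    /// # Example
    /// A minimal sketch of draining the leftover micro-batches at the end of
    /// an epoch:
    /// ```ignore
    /// if let Some(avg_loss) = trainer.flush()? {
    ///     println!("tail step: avg loss = {avg_loss:.4}");
    /// }
    /// ```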
    pub fn flush(&mut self) -> Result<Option<f64>> {
        if self.current_step == 0 || self.accumulated.is_none() {
            return Ok(None);
        }

        let params = self.model.parameters();
        let acc = self.accumulated.take().unwrap();
        let scale = 1.0 / self.current_step as f64;
        let mut averaged = GradStore::new();
        for param in &params {
            if let Some(g) = acc.get(param) {
                let g_avg = g.affine(scale, 0.0)?;
                averaged.accumulate(param.id(), g_avg)?;
            }
        }

        self.optimizer.step(&averaged)?;

        let avg_loss = self.loss_sum / self.current_step as f64;
        self.current_step = 0;
        self.loss_sum = 0.0;
        self.accumulated = None;

        Ok(Some(avg_loss))
    }
}

// Tests

#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};

    type B = CpuBackend;
    type T = Tensor<B>;
    const DEV: CpuDevice = CpuDevice;

    // ── AllReduce / gradient averaging ──

    #[test]
    fn test_reduce_gradients_average() {
        let p = T::randn(vec![4], DType::F32, &DEV).unwrap().set_variable();
        let loss1 = p.sum_all().unwrap();
        let g1 = loss1.backward().unwrap();

        let loss2 = p.affine(2.0, 0.0).unwrap().sum_all().unwrap();
        let g2 = loss2.backward().unwrap();

        let merged = reduce_gradients(&[g1, g2], &[p.clone()], AllReduceOp::Average).unwrap();
        let avg = merged.get(&p).unwrap().to_f64_vec().unwrap();
        // g1 = all 1s, g2 = all 2s, average = all 1.5s
        for &v in &avg {
            assert!((v - 1.5).abs() < 1e-5, "expected 1.5, got {v}");
        }
    }

    #[test]
    fn test_reduce_gradients_sum() {
        let p = T::randn(vec![3], DType::F32, &DEV).unwrap().set_variable();
        let loss1 = p.sum_all().unwrap();
        let g1 = loss1.backward().unwrap();

        let loss2 = p.sum_all().unwrap();
        let g2 = loss2.backward().unwrap();

        let merged = reduce_gradients(&[g1, g2], &[p.clone()], AllReduceOp::Sum).unwrap();
        let summed = merged.get(&p).unwrap().to_f64_vec().unwrap();
        for &v in &summed {
            assert!((v - 2.0).abs() < 1e-5, "expected 2.0, got {v}");
        }
    }

    // ── DataParallel ──

    #[test]
    fn test_data_parallel_forward() {
        let linear = shrew_nn::Linear::<B>::new(4, 2, true, DType::F32, &DEV).unwrap();
        let dp = DataParallel::new(linear, 2);

        let input = T::randn(vec![6, 4], DType::F32, &DEV).unwrap();
        let output = dp.forward(&input).unwrap();
        assert_eq!(output.dims(), &[6, 2]);
    }

    #[test]
    fn test_data_parallel_single_worker() {
        let linear = shrew_nn::Linear::<B>::new(3, 2, true, DType::F32, &DEV).unwrap();
        let dp = DataParallel::new(linear, 1);

        let input = T::randn(vec![4, 3], DType::F32, &DEV).unwrap();
        let output = dp.forward(&input).unwrap();
        assert_eq!(output.dims(), &[4, 2]);
    }

    #[test]
    fn test_data_parallel_parameters() {
        let linear = shrew_nn::Linear::<B>::new(4, 2, true, DType::F32, &DEV).unwrap();
        let n_params = linear.parameters().len();
        let dp = DataParallel::new(linear, 4);
        assert_eq!(dp.parameters().len(), n_params);
    }

    // ── MixedPrecisionTrainer ──

    #[test]
    fn test_mixed_precision_basic() {
        let linear = shrew_nn::Linear::<B>::new(4, 1, true, DType::F32, &DEV).unwrap();
        let optimizer = shrew_optim::SGD::new(linear.parameters(), 0.01, 0.0, 0.0);
        let mut trainer =
            MixedPrecisionTrainer::new(linear, optimizer, DType::F16, LossScaleConfig::default());

        let input = T::randn(vec![2, 4], DType::F32, &DEV).unwrap();
        let target = T::zeros(vec![2, 1], DType::F32, &DEV).unwrap();

        let metrics = trainer
            .train_step(&input, &target, |pred, tgt| shrew_nn::mse_loss(pred, tgt))
            .unwrap();

        assert!(!metrics.skipped);
        assert!(metrics.loss >= 0.0);
        assert_eq!(metrics.loss_scale, 65536.0);
    }

    // ── Pipeline ──

    #[test]
    fn test_pipeline_forward() {
        let stage0 = PipelineStage::<B>::new(0).add_layer(Box::new(
            shrew_nn::Linear::<B>::new(4, 8, true, DType::F32, &DEV).unwrap(),
        ));
        let stage1 = PipelineStage::<B>::new(1).add_layer(Box::new(
            shrew_nn::Linear::<B>::new(8, 2, true, DType::F32, &DEV).unwrap(),
        ));

        let pipeline = PipelineParallel::new(vec![stage0, stage1], 2);
        let input = T::randn(vec![4, 4], DType::F32, &DEV).unwrap();
        let output = pipeline.forward(&input).unwrap();
        assert_eq!(output.dims(), &[4, 2]);
    }

    #[test]
    fn test_pipeline_parameters() {
        let stage0 = PipelineStage::<B>::new(0).add_layer(Box::new(
            shrew_nn::Linear::<B>::new(4, 8, true, DType::F32, &DEV).unwrap(),
        ));
        let stage1 = PipelineStage::<B>::new(1).add_layer(Box::new(
            shrew_nn::Linear::<B>::new(8, 2, true, DType::F32, &DEV).unwrap(),
        ));

        let pipeline = PipelineParallel::new(vec![stage0, stage1], 1);
        // stage0: 4*8 + 8 = 40, stage1: 8*2 + 2 = 18, total = 58
        let total: usize = pipeline.parameters().iter().map(|p| p.elem_count()).sum();
        assert_eq!(total, 40 + 18);
    }

    // ── ParallelTrainer (gradient accumulation) ──

    #[test]
    fn test_parallel_trainer_accumulation() {
        let linear = shrew_nn::Linear::<B>::new(3, 1, true, DType::F32, &DEV).unwrap();
        let optimizer = shrew_optim::SGD::new(linear.parameters(), 0.01, 0.0, 0.0);
        let mut trainer = ParallelTrainer::new(linear, optimizer, 2);

        let x1 = T::randn(vec![1, 3], DType::F32, &DEV).unwrap();
        let y1 = T::zeros(vec![1, 1], DType::F32, &DEV).unwrap();
        let x2 = T::randn(vec![1, 3], DType::F32, &DEV).unwrap();
        let y2 = T::zeros(vec![1, 1], DType::F32, &DEV).unwrap();

        // First micro-batch: no step yet
        let result1 = trainer
            .accumulate_step(&x1, &y1, |p, t| shrew_nn::mse_loss(p, t))
            .unwrap();
        assert!(result1.is_none());

        // Second micro-batch: step happens, returns average loss
        let result2 = trainer
            .accumulate_step(&x2, &y2, |p, t| shrew_nn::mse_loss(p, t))
            .unwrap();
        assert!(result2.is_some());
    }

    #[test]
    fn test_parallel_trainer_flush() {
        let linear = shrew_nn::Linear::<B>::new(3, 1, true, DType::F32, &DEV).unwrap();
        let optimizer = shrew_optim::SGD::new(linear.parameters(), 0.01, 0.0, 0.0);
        let mut trainer = ParallelTrainer::new(linear, optimizer, 4);

        let x = T::randn(vec![1, 3], DType::F32, &DEV).unwrap();
        let y = T::zeros(vec![1, 1], DType::F32, &DEV).unwrap();

        // Only 1 of 4 accumulation steps done
        trainer
            .accumulate_step(&x, &y, |p, t| shrew_nn::mse_loss(p, t))
            .unwrap();

        // Flush forces a step with whatever we have
        let flushed = trainer.flush().unwrap();
        assert!(flushed.is_some());
    }
}