use std::marker::PhantomData;

use shrew_core::backend::Backend;
use shrew_core::backprop::GradStore;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use shrew_nn::Module;
use shrew_optim::Optimizer;

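/// How gradients from multiple workers are combined by [`reduce_gradients`]:
/// `Sum` adds the per-worker gradients element-wise, while `Average` divides
/// that sum by the number of workers that contributed a gradient.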
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllReduceOp {
    Sum,
    Average,
}

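/// Merges per-worker gradient stores into a single [`GradStore`].
///
/// For each parameter, the gradients found in the individual stores are
/// summed; with [`AllReduceOp::Average`] the sum is additionally divided by
/// the number of stores that actually held a gradient for that parameter.
/// Parameters with no gradient in any store are skipped.
///
/// Illustrative sketch (mirrors `test_reduce_gradients_average` at the bottom
/// of this file; error handling elided, CPU backend assumed):
///
/// ```ignore
/// let p = Tensor::<CpuBackend>::randn(vec![4], DType::F32, &CpuDevice)?.set_variable();
/// let g1 = p.sum_all()?.backward()?;                   // d/dp = 1
/// let g2 = p.affine(2.0, 0.0)?.sum_all()?.backward()?; // d/dp = 2
/// let merged = reduce_gradients(&[g1, g2], &[p.clone()], AllReduceOp::Average)?;
/// // merged gradient for `p` is (1 + 2) / 2 = 1.5 everywhere
/// ```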
pub fn reduce_gradients<B: Backend>(
    grad_stores: &[GradStore<B>],
    params: &[Tensor<B>],
    strategy: AllReduceOp,
) -> Result<GradStore<B>> {
    let n = grad_stores.len();
    if n == 0 {
        return Ok(GradStore::new());
    }
    if n == 1 {
        return Ok(grad_stores[0].clone());
    }

    let mut merged = GradStore::new();

    for param in params {
        let mut grads: Vec<&Tensor<B>> = Vec::new();
        for store in grad_stores {
            if let Some(g) = store.get(param) {
                grads.push(g);
            }
        }
        if grads.is_empty() {
            continue;
        }

        let mut acc = grads[0].clone();
        for g in &grads[1..] {
            acc = acc.add(g)?;
        }

        if strategy == AllReduceOp::Average && grads.len() > 1 {
            let scale = 1.0 / grads.len() as f64;
            acc = acc.affine(scale, 0.0)?;
        }

        merged.accumulate(param.id(), acc)?;
    }

    Ok(merged)
}

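/// Single-process emulation of data parallelism: the input batch is split
/// into `num_workers` chunks along dimension 0, the wrapped module is run on
/// each chunk in turn, and the outputs are concatenated back along
/// dimension 0. If the batch is smaller than `num_workers` (or
/// `num_workers == 1`), the input is forwarded unchanged.
///
/// Illustrative sketch (mirrors `test_data_parallel_forward` below; error
/// handling elided, CPU backend assumed):
///
/// ```ignore
/// let linear = shrew_nn::Linear::<CpuBackend>::new(4, 2, true, DType::F32, &CpuDevice)?;
/// let dp = DataParallel::new(linear, 2);
/// let input = Tensor::<CpuBackend>::randn(vec![6, 4], DType::F32, &CpuDevice)?;
/// let out = dp.forward(&input)?; // shape [6, 2]
/// ```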
pub struct DataParallel<M> {
    pub module: M,
    pub num_workers: usize,
}

impl<M: Clone> Clone for DataParallel<M> {
    fn clone(&self) -> Self {
        Self {
            module: self.module.clone(),
            num_workers: self.num_workers,
        }
    }
}

impl<M: std::fmt::Debug> std::fmt::Debug for DataParallel<M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DataParallel")
            .field("module", &self.module)
            .field("num_workers", &self.num_workers)
            .finish()
    }
}

impl<M> DataParallel<M> {
    pub fn new(module: M, num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be > 0");
        Self {
            module,
            num_workers,
        }
    }

    pub fn inner(&self) -> &M {
        &self.module
    }

    pub fn inner_mut(&mut self) -> &mut M {
        &mut self.module
    }

    pub fn into_inner(self) -> M {
        self.module
    }
}

impl<M, B> Module<B> for DataParallel<M>
where
    M: Module<B> + Send + Sync,
    B: Backend,
{
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let batch_size = x.dims()[0];
        let effective_workers = self.num_workers.min(batch_size);

        if effective_workers <= 1 {
            return self.module.forward(x);
        }

        let chunks = x.chunk(effective_workers, 0)?;

        let mut outputs = Vec::with_capacity(chunks.len());
        for chunk in &chunks {
            outputs.push(self.module.forward(chunk)?);
        }

        Tensor::cat(&outputs, 0)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        self.module.parameters()
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        self.module.named_parameters()
    }

    fn set_training(&self, training: bool) {
        self.module.set_training(training);
    }

    fn is_training(&self) -> bool {
        self.module.is_training()
    }
}

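/// Dynamic loss-scaling settings for [`MixedPrecisionTrainer`].
///
/// Training starts at `init_scale` (default 65536). After `growth_interval`
/// consecutive overflow-free steps the scale is multiplied by
/// `scale_growth_factor`; whenever a gradient overflow is detected the scale
/// is divided by `scale_backoff_factor` (never dropping below 1.0) and the
/// step is skipped.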
#[derive(Debug, Clone)]
pub struct LossScaleConfig {
    pub init_scale: f64,
    pub scale_growth_factor: f64,
    pub scale_backoff_factor: f64,
    pub growth_interval: u64,
}

impl Default for LossScaleConfig {
    fn default() -> Self {
        Self {
            init_scale: 65536.0,
            scale_growth_factor: 2.0,
            scale_backoff_factor: 2.0,
            growth_interval: 2000,
        }
    }
}

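/// Mixed-precision training loop with dynamic loss scaling: inputs are cast
/// to `compute_dtype` when the model parameters already use that dtype, the
/// loss is scaled before backprop, steps whose gradients contain NaN/Inf are
/// skipped, and surviving gradients are unscaled and cast back to the
/// parameter dtype before the optimizer step.
///
/// Illustrative sketch (mirrors `test_mixed_precision_basic` below; error
/// handling elided, CPU backend and `shrew_optim::SGD` assumed):
///
/// ```ignore
/// let model = shrew_nn::Linear::<CpuBackend>::new(4, 1, true, DType::F32, &CpuDevice)?;
/// let optim = shrew_optim::SGD::new(model.parameters(), 0.01, 0.0, 0.0);
/// let mut trainer =
///     MixedPrecisionTrainer::new(model, optim, DType::F16, LossScaleConfig::default());
/// // `input`/`target` are f32 tensors of shape [2, 4] and [2, 1].
/// let metrics = trainer.train_step(&input, &target, |p, t| shrew_nn::mse_loss(p, t))?;
/// assert!(!metrics.skipped);
/// ```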
pub struct MixedPrecisionTrainer<M, O, B: Backend> {
    model: M,
    optimizer: O,
    compute_dtype: DType,
    loss_scale: f64,
    config: LossScaleConfig,
    good_steps: u64,
    skipped_steps: u64,
    _phantom: PhantomData<B>,
}

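/// Per-step report returned by [`MixedPrecisionTrainer::train_step`]: the
/// unscaled loss value, whether the step was skipped due to overflow, the
/// loss scale after any adjustment, the running count of skipped steps, and
/// the compute dtype in use.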
#[derive(Debug, Clone)]
pub struct MixedPrecisionMetrics {
    pub loss: f64,
    pub skipped: bool,
    pub loss_scale: f64,
    pub total_skipped: u64,
    pub compute_dtype: DType,
}

impl<M, O, B> MixedPrecisionTrainer<M, O, B>
where
    M: Module<B>,
    O: Optimizer<B>,
    B: Backend,
{
    pub fn new(model: M, optimizer: O, compute_dtype: DType, config: LossScaleConfig) -> Self {
        let loss_scale = config.init_scale;
        Self {
            model,
            optimizer,
            compute_dtype,
            loss_scale,
            config,
            good_steps: 0,
            skipped_steps: 0,
            _phantom: PhantomData,
        }
    }

    pub fn model(&self) -> &M {
        &self.model
    }

    pub fn model_mut(&mut self) -> &mut M {
        &mut self.model
    }

    pub fn optimizer(&self) -> &O {
        &self.optimizer
    }

    pub fn loss_scale(&self) -> f64 {
        self.loss_scale
    }

    pub fn compute_dtype(&self) -> DType {
        self.compute_dtype
    }

    pub fn skipped_steps(&self) -> u64 {
        self.skipped_steps
    }

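    /// Runs one mixed-precision training step:
    ///
    /// 1. cast `input`/`target` to the compute dtype when the model itself
    ///    is stored in that dtype,
    /// 2. forward, compute the loss, and scale it by the current loss scale,
    /// 3. backprop and check the gradients for NaN/Inf,
    /// 4. on overflow: back off the loss scale and skip the optimizer step,
    /// 5. otherwise: unscale the gradients, cast them to the parameter dtype,
    ///    step the optimizer, and grow the scale after `growth_interval`
    ///    clean steps.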
    pub fn train_step<F>(
        &mut self,
        input: &Tensor<B>,
        target: &Tensor<B>,
        loss_fn: F,
    ) -> Result<MixedPrecisionMetrics>
    where
        F: Fn(&Tensor<B>, &Tensor<B>) -> Result<Tensor<B>>,
    {
        let model_dtype = self
            .model
            .parameters()
            .first()
            .map(|p| p.dtype())
            .unwrap_or(DType::F32);
        let should_cast = self.compute_dtype != DType::F32 && self.compute_dtype == model_dtype;

        let input_cast = if should_cast && input.dtype() != self.compute_dtype {
            input.to_dtype(self.compute_dtype)?
        } else {
            input.clone()
        };
        let target_cast = if should_cast && target.dtype() != self.compute_dtype {
            target.to_dtype(self.compute_dtype)?
        } else {
            target.clone()
        };

        let output = self.model.forward(&input_cast)?;

        let loss = loss_fn(&output, &target_cast)?;
        let loss_val = loss.to_scalar_f64()?;

        let scaled_loss = loss.affine(self.loss_scale, 0.0)?;

        let grads = scaled_loss.backward()?;

        let has_overflow = self.check_overflow(&grads)?;

        if has_overflow {
            self.loss_scale /= self.config.scale_backoff_factor;
            self.loss_scale = self.loss_scale.max(1.0);
            self.good_steps = 0;
            self.skipped_steps += 1;

            return Ok(MixedPrecisionMetrics {
                loss: loss_val,
                skipped: true,
                loss_scale: self.loss_scale,
                total_skipped: self.skipped_steps,
                compute_dtype: self.compute_dtype,
            });
        }

        let unscaled = self.unscale_and_cast_gradients(&grads)?;

        self.optimizer.step(&unscaled)?;

        self.good_steps += 1;
        if self.good_steps >= self.config.growth_interval {
            self.loss_scale *= self.config.scale_growth_factor;
            self.good_steps = 0;
        }

        Ok(MixedPrecisionMetrics {
            loss: loss_val,
            skipped: false,
            loss_scale: self.loss_scale,
            total_skipped: self.skipped_steps,
            compute_dtype: self.compute_dtype,
        })
    }

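    /// Returns `true` if any gradient for the model's parameters contains a
    /// NaN or infinite value, which signals that the scaled backward pass
    /// overflowed the compute dtype.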
    fn check_overflow(&self, grads: &GradStore<B>) -> Result<bool> {
        for param in self.model.parameters() {
            if let Some(g) = grads.get(&param) {
                let data = g.to_f64_vec()?;
                for &v in &data {
                    if v.is_nan() || v.is_infinite() {
                        return Ok(true);
                    }
                }
            }
        }
        Ok(false)
    }

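    /// Divides every gradient by the current loss scale and casts it back to
    /// the dtype of its parameter before the optimizer step.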
    fn unscale_and_cast_gradients(&self, grads: &GradStore<B>) -> Result<GradStore<B>> {
        let inv_scale = 1.0 / self.loss_scale;
        let mut unscaled = GradStore::new();
        for param in self.model.parameters() {
            if let Some(g) = grads.get(&param) {
                let g_unscaled = g.affine(inv_scale, 0.0)?;
                let g_fp32 = if g_unscaled.dtype() != param.dtype() {
                    g_unscaled.to_dtype(param.dtype())?
                } else {
                    g_unscaled
                };
                unscaled.accumulate(param.id(), g_fp32)?;
            }
        }
        Ok(unscaled)
    }
}

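/// One stage of a model pipeline: an ordered list of boxed [`Module`]s that
/// [`PipelineStage::forward`] applies in sequence. Stages are built with the
/// chaining [`PipelineStage::add_layer`] method.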
pub struct PipelineStage<B: Backend> {
    layers: Vec<Box<dyn Module<B>>>,
    stage_id: usize,
}

impl<B: Backend> PipelineStage<B> {
    pub fn new(stage_id: usize) -> Self {
        Self {
            layers: Vec::new(),
            stage_id,
        }
    }

    pub fn add_layer(mut self, layer: Box<dyn Module<B>>) -> Self {
        self.layers.push(layer);
        self
    }

    pub fn stage_id(&self) -> usize {
        self.stage_id
    }

    pub fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let mut out = x.clone();
        for layer in &self.layers {
            out = layer.forward(&out)?;
        }
        Ok(out)
    }

    pub fn parameters(&self) -> Vec<Tensor<B>> {
        self.layers.iter().flat_map(|l| l.parameters()).collect()
    }
}

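/// Single-process emulation of pipeline parallelism: the batch is split into
/// `num_micro_batches` micro-batches along dimension 0, each micro-batch is
/// pushed through every stage in order, and the results are concatenated
/// along dimension 0. Batches smaller than `num_micro_batches` fall back to
/// a single pass through the stages.
///
/// Illustrative sketch (mirrors `test_pipeline_forward` below; error
/// handling elided, CPU backend assumed):
///
/// ```ignore
/// let stage0 = PipelineStage::<CpuBackend>::new(0).add_layer(Box::new(
///     shrew_nn::Linear::<CpuBackend>::new(4, 8, true, DType::F32, &CpuDevice)?,
/// ));
/// let stage1 = PipelineStage::<CpuBackend>::new(1).add_layer(Box::new(
///     shrew_nn::Linear::<CpuBackend>::new(8, 2, true, DType::F32, &CpuDevice)?,
/// ));
/// let pipeline = PipelineParallel::new(vec![stage0, stage1], 2);
/// let out = pipeline.forward(&input)?; // shape [batch, 2]
/// ```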
pub struct PipelineParallel<B: Backend> {
    stages: Vec<PipelineStage<B>>,
    num_micro_batches: usize,
}

impl<B: Backend> PipelineParallel<B> {
    pub fn new(stages: Vec<PipelineStage<B>>, num_micro_batches: usize) -> Self {
        assert!(!stages.is_empty(), "pipeline needs at least one stage");
        assert!(num_micro_batches > 0, "num_micro_batches must be > 0");
        Self {
            stages,
            num_micro_batches,
        }
    }

    pub fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let batch_size = x.dims()[0];
        let effective_micros = self.num_micro_batches.min(batch_size);

        if effective_micros <= 1 {
            let mut out = x.clone();
            for stage in &self.stages {
                out = stage.forward(&out)?;
            }
            return Ok(out);
        }

        let micro_batches = x.chunk(effective_micros, 0)?;

        let mut outputs = Vec::with_capacity(micro_batches.len());
        for mb in &micro_batches {
            let mut out = mb.clone();
            for stage in &self.stages {
                out = stage.forward(&out)?;
            }
            outputs.push(out);
        }

        Tensor::cat(&outputs, 0)
    }

    pub fn parameters(&self) -> Vec<Tensor<B>> {
        self.stages.iter().flat_map(|s| s.parameters()).collect()
    }

    pub fn num_stages(&self) -> usize {
        self.stages.len()
    }

    pub fn stage(&self, idx: usize) -> Option<&PipelineStage<B>> {
        self.stages.get(idx)
    }
}

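/// Trainer with gradient accumulation: each call to
/// [`ParallelTrainer::accumulate_step`] adds the micro-batch gradients to a
/// running [`GradStore`], and every `accumulation_steps` calls the summed
/// gradients are averaged and applied with a single optimizer step.
///
/// Illustrative sketch (mirrors `test_parallel_trainer_accumulation` below;
/// error handling elided, `model`, `optimizer`, and the micro-batches
/// `x1`/`y1`, `x2`/`y2` assumed to be set up as in that test):
///
/// ```ignore
/// let mut trainer = ParallelTrainer::new(model, optimizer, 2);
/// // First micro-batch only accumulates and returns None.
/// assert!(trainer.accumulate_step(&x1, &y1, |p, t| shrew_nn::mse_loss(p, t))?.is_none());
/// // Second micro-batch triggers the optimizer step and returns the average loss.
/// assert!(trainer.accumulate_step(&x2, &y2, |p, t| shrew_nn::mse_loss(p, t))?.is_some());
/// ```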
pub struct ParallelTrainer<M, O, B: Backend> {
    pub model: M,
    pub optimizer: O,
    accumulation_steps: usize,
    accumulated: Option<GradStore<B>>,
    current_step: usize,
    loss_sum: f64,
    _phantom: PhantomData<B>,
}

impl<M, O, B> ParallelTrainer<M, O, B>
where
    M: Module<B>,
    O: Optimizer<B>,
    B: Backend,
{
    pub fn new(model: M, optimizer: O, accumulation_steps: usize) -> Self {
        assert!(accumulation_steps > 0);
        Self {
            model,
            optimizer,
            accumulation_steps,
            accumulated: None,
            current_step: 0,
            loss_sum: 0.0,
            _phantom: PhantomData,
        }
    }

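    /// Runs forward/backward on one micro-batch and folds its gradients into
    /// the accumulator. Returns `Ok(Some(average_loss))` and performs an
    /// optimizer step once `accumulation_steps` micro-batches have been seen;
    /// otherwise returns `Ok(None)`.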
    pub fn accumulate_step<F>(
        &mut self,
        input: &Tensor<B>,
        target: &Tensor<B>,
        loss_fn: F,
    ) -> Result<Option<f64>>
    where
        F: Fn(&Tensor<B>, &Tensor<B>) -> Result<Tensor<B>>,
    {
        let output = self.model.forward(input)?;
        let loss = loss_fn(&output, target)?;
        let loss_val = loss.to_scalar_f64()?;
        self.loss_sum += loss_val;

        let grads = loss.backward()?;

        let params = self.model.parameters();
        match self.accumulated.take() {
            Some(prev) => {
                let merged = reduce_gradients(&[prev, grads], &params, AllReduceOp::Sum)?;
                self.accumulated = Some(merged);
            }
            None => {
                self.accumulated = Some(grads);
            }
        }

        self.current_step += 1;

        if self.current_step >= self.accumulation_steps {
            let avg_grads = {
                let acc = self.accumulated.take().unwrap();
                let mut averaged = GradStore::new();
                let scale = 1.0 / self.accumulation_steps as f64;
                for param in &params {
                    if let Some(g) = acc.get(param) {
                        let g_avg = g.affine(scale, 0.0)?;
                        averaged.accumulate(param.id(), g_avg)?;
                    }
                }
                averaged
            };

            self.optimizer.step(&avg_grads)?;

            let avg_loss = self.loss_sum / self.accumulation_steps as f64;
            self.current_step = 0;
            self.loss_sum = 0.0;
            self.accumulated = None;

            Ok(Some(avg_loss))
        } else {
            Ok(None)
        }
    }

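    /// Applies whatever gradients have been accumulated so far, averaged over
    /// the number of micro-batches actually seen, and resets the accumulator.
    /// Returns `Ok(None)` when nothing has been accumulated. Useful at the end
    /// of an epoch whose size is not a multiple of `accumulation_steps`.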
    pub fn flush(&mut self) -> Result<Option<f64>> {
        if self.current_step == 0 || self.accumulated.is_none() {
            return Ok(None);
        }

        let params = self.model.parameters();
        let acc = self.accumulated.take().unwrap();
        let scale = 1.0 / self.current_step as f64;
        let mut averaged = GradStore::new();
        for param in &params {
            if let Some(g) = acc.get(param) {
                let g_avg = g.affine(scale, 0.0)?;
                averaged.accumulate(param.id(), g_avg)?;
            }
        }

        self.optimizer.step(&averaged)?;

        let avg_loss = self.loss_sum / self.current_step as f64;
        self.current_step = 0;
        self.loss_sum = 0.0;
        self.accumulated = None;

        Ok(Some(avg_loss))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};

    type B = CpuBackend;
    type T = Tensor<B>;
    const DEV: CpuDevice = CpuDevice;

    #[test]
    fn test_reduce_gradients_average() {
        let p = T::randn(vec![4], DType::F32, &DEV).unwrap().set_variable();
        let loss1 = p.sum_all().unwrap();
        let g1 = loss1.backward().unwrap();

        let loss2 = p.affine(2.0, 0.0).unwrap().sum_all().unwrap();
        let g2 = loss2.backward().unwrap();

        let merged = reduce_gradients(&[g1, g2], &[p.clone()], AllReduceOp::Average).unwrap();
        let avg = merged.get(&p).unwrap().to_f64_vec().unwrap();
        for &v in &avg {
            assert!((v - 1.5).abs() < 1e-5, "expected 1.5, got {v}");
        }
    }

    #[test]
    fn test_reduce_gradients_sum() {
        let p = T::randn(vec![3], DType::F32, &DEV).unwrap().set_variable();
        let loss1 = p.sum_all().unwrap();
        let g1 = loss1.backward().unwrap();

        let loss2 = p.sum_all().unwrap();
        let g2 = loss2.backward().unwrap();

        let merged = reduce_gradients(&[g1, g2], &[p.clone()], AllReduceOp::Sum).unwrap();
        let summed = merged.get(&p).unwrap().to_f64_vec().unwrap();
        for &v in &summed {
            assert!((v - 2.0).abs() < 1e-5, "expected 2.0, got {v}");
        }
    }

    #[test]
    fn test_data_parallel_forward() {
        let linear = shrew_nn::Linear::<B>::new(4, 2, true, DType::F32, &DEV).unwrap();
        let dp = DataParallel::new(linear, 2);

        let input = T::randn(vec![6, 4], DType::F32, &DEV).unwrap();
        let output = dp.forward(&input).unwrap();
        assert_eq!(output.dims(), &[6, 2]);
    }

    #[test]
    fn test_data_parallel_single_worker() {
        let linear = shrew_nn::Linear::<B>::new(3, 2, true, DType::F32, &DEV).unwrap();
        let dp = DataParallel::new(linear, 1);

        let input = T::randn(vec![4, 3], DType::F32, &DEV).unwrap();
        let output = dp.forward(&input).unwrap();
        assert_eq!(output.dims(), &[4, 2]);
    }

    #[test]
    fn test_data_parallel_parameters() {
        let linear = shrew_nn::Linear::<B>::new(4, 2, true, DType::F32, &DEV).unwrap();
        let n_params = linear.parameters().len();
        let dp = DataParallel::new(linear, 4);
        assert_eq!(dp.parameters().len(), n_params);
    }

    #[test]
    fn test_mixed_precision_basic() {
        let linear = shrew_nn::Linear::<B>::new(4, 1, true, DType::F32, &DEV).unwrap();
        let optimizer = shrew_optim::SGD::new(linear.parameters(), 0.01, 0.0, 0.0);
        let mut trainer =
            MixedPrecisionTrainer::new(linear, optimizer, DType::F16, LossScaleConfig::default());

        let input = T::randn(vec![2, 4], DType::F32, &DEV).unwrap();
        let target = T::zeros(vec![2, 1], DType::F32, &DEV).unwrap();

        let metrics = trainer
            .train_step(&input, &target, |pred, tgt| shrew_nn::mse_loss(pred, tgt))
            .unwrap();

        assert!(!metrics.skipped);
        assert!(metrics.loss >= 0.0);
        assert_eq!(metrics.loss_scale, 65536.0);
    }

    #[test]
    fn test_pipeline_forward() {
        let stage0 = PipelineStage::<B>::new(0).add_layer(Box::new(
            shrew_nn::Linear::<B>::new(4, 8, true, DType::F32, &DEV).unwrap(),
        ));
        let stage1 = PipelineStage::<B>::new(1).add_layer(Box::new(
            shrew_nn::Linear::<B>::new(8, 2, true, DType::F32, &DEV).unwrap(),
        ));

        let pipeline = PipelineParallel::new(vec![stage0, stage1], 2);
        let input = T::randn(vec![4, 4], DType::F32, &DEV).unwrap();
        let output = pipeline.forward(&input).unwrap();
        assert_eq!(output.dims(), &[4, 2]);
    }

    #[test]
    fn test_pipeline_parameters() {
        let stage0 = PipelineStage::<B>::new(0).add_layer(Box::new(
            shrew_nn::Linear::<B>::new(4, 8, true, DType::F32, &DEV).unwrap(),
        ));
        let stage1 = PipelineStage::<B>::new(1).add_layer(Box::new(
            shrew_nn::Linear::<B>::new(8, 2, true, DType::F32, &DEV).unwrap(),
        ));

        let pipeline = PipelineParallel::new(vec![stage0, stage1], 1);
        let total: usize = pipeline.parameters().iter().map(|p| p.elem_count()).sum();
        assert_eq!(total, 40 + 18);
    }

    #[test]
    fn test_parallel_trainer_accumulation() {
        let linear = shrew_nn::Linear::<B>::new(3, 1, true, DType::F32, &DEV).unwrap();
        let optimizer = shrew_optim::SGD::new(linear.parameters(), 0.01, 0.0, 0.0);
        let mut trainer = ParallelTrainer::new(linear, optimizer, 2);

        let x1 = T::randn(vec![1, 3], DType::F32, &DEV).unwrap();
        let y1 = T::zeros(vec![1, 1], DType::F32, &DEV).unwrap();
        let x2 = T::randn(vec![1, 3], DType::F32, &DEV).unwrap();
        let y2 = T::zeros(vec![1, 1], DType::F32, &DEV).unwrap();

        let result1 = trainer
            .accumulate_step(&x1, &y1, |p, t| shrew_nn::mse_loss(p, t))
            .unwrap();
        assert!(result1.is_none());

        let result2 = trainer
            .accumulate_step(&x2, &y2, |p, t| shrew_nn::mse_loss(p, t))
            .unwrap();
        assert!(result2.is_some());
    }

    #[test]
    fn test_parallel_trainer_flush() {
        let linear = shrew_nn::Linear::<B>::new(3, 1, true, DType::F32, &DEV).unwrap();
        let optimizer = shrew_optim::SGD::new(linear.parameters(), 0.01, 0.0, 0.0);
        let mut trainer = ParallelTrainer::new(linear, optimizer, 4);

        let x = T::randn(vec![1, 3], DType::F32, &DEV).unwrap();
        let y = T::zeros(vec![1, 1], DType::F32, &DEV).unwrap();

        trainer
            .accumulate_step(&x, &y, |p, t| shrew_nn::mse_loss(p, t))
            .unwrap();

        let flushed = trainer.flush().unwrap();
        assert!(flushed.is_some());
    }
}