1use shrew_core::backend::Backend;
17use shrew_core::error::Result;
18use shrew_core::tensor::Tensor;
19
20use crate::module::Module;
21
22pub struct Sequential<B: Backend> {
27 layers: Vec<Box<dyn Module<B>>>,
28}
29
30impl<B: Backend> Sequential<B> {
31 pub fn new() -> Self {
33 Sequential { layers: Vec::new() }
34 }
35
36 #[allow(clippy::should_implement_trait)]
38 pub fn add<M: Module<B> + 'static>(mut self, module: M) -> Self {
39 self.layers.push(Box::new(module));
40 self
41 }
42
43 pub fn len(&self) -> usize {
45 self.layers.len()
46 }
47
48 pub fn is_empty(&self) -> bool {
50 self.layers.is_empty()
51 }
52}
53
54impl<B: Backend> Default for Sequential<B> {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl<B: Backend> Module<B> for Sequential<B> {
61 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
62 let mut out = x.clone();
63 for layer in &self.layers {
64 out = layer.forward(&out)?;
65 }
66 Ok(out)
67 }
68
69 fn parameters(&self) -> Vec<Tensor<B>> {
70 self.layers.iter().flat_map(|l| l.parameters()).collect()
71 }
72
73 fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
74 let mut named = Vec::new();
75 for (i, layer) in self.layers.iter().enumerate() {
76 for (k, v) in layer.named_parameters() {
77 named.push((format!("layers.{i}.{k}"), v));
78 }
79 }
80 named
81 }
82
83 fn set_training(&self, training: bool) {
85 for layer in &self.layers {
86 layer.set_training(training);
87 }
88 }
89}