shrew_nn/
sequential.rs

1// Sequential — A chain of modules applied one after another
2//
3// Sequential is the simplest way to build a neural network: a list of layers
4// applied in order. It's equivalent to PyTorch's nn.Sequential.
5//
6// Example:
7//   let model = Sequential::new()
8//       .add(linear1)
9//       .add(ReLU)
10//       .add(linear2);
11//
12//   let output = model.forward(&input)?;
13//
14// The output of each layer becomes the input to the next.
15
16use shrew_core::backend::Backend;
17use shrew_core::error::Result;
18use shrew_core::tensor::Tensor;
19
20use crate::module::Module;
21
/// A container that chains modules sequentially.
///
/// Each module's output becomes the next module's input, so every layer's
/// output must be an acceptable input for the layer that follows it.
/// `Sequential` itself implements `Module`, so it can be nested inside
/// another `Sequential` (or any other container).
pub struct Sequential<B: Backend> {
    // Child modules in insertion order; `forward` applies them front to back.
    // Boxed trait objects allow heterogeneous layer types in one list.
    layers: Vec<Box<dyn Module<B>>>,
}
29
30impl<B: Backend> Sequential<B> {
31    /// Create an empty Sequential.
32    pub fn new() -> Self {
33        Sequential { layers: Vec::new() }
34    }
35
36    /// Add a layer to the end of the sequence. Returns self for chaining.
37    #[allow(clippy::should_implement_trait)]
38    pub fn add<M: Module<B> + 'static>(mut self, module: M) -> Self {
39        self.layers.push(Box::new(module));
40        self
41    }
42
43    /// Number of layers.
44    pub fn len(&self) -> usize {
45        self.layers.len()
46    }
47
48    /// Whether the sequential is empty.
49    pub fn is_empty(&self) -> bool {
50        self.layers.is_empty()
51    }
52}
53
54impl<B: Backend> Default for Sequential<B> {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl<B: Backend> Module<B> for Sequential<B> {
61    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
62        let mut out = x.clone();
63        for layer in &self.layers {
64            out = layer.forward(&out)?;
65        }
66        Ok(out)
67    }
68
69    fn parameters(&self) -> Vec<Tensor<B>> {
70        self.layers.iter().flat_map(|l| l.parameters()).collect()
71    }
72
73    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
74        let mut named = Vec::new();
75        for (i, layer) in self.layers.iter().enumerate() {
76            for (k, v) in layer.named_parameters() {
77                named.push((format!("layers.{i}.{k}"), v));
78            }
79        }
80        named
81    }
82
83    /// Propagate training mode to all child layers.
84    fn set_training(&self, training: bool) {
85        for layer in &self.layers {
86            layer.set_training(training);
87        }
88    }
89}