shrew_nn/flatten.rs

// Flatten — Flatten spatial dimensions into a single feature dimension
//
// Flattens a contiguous range of dimensions of the input tensor.
// Commonly used between convolutional and fully connected layers.
//
// By default, flattens all dimensions except the batch dimension:
//   [N, C, H, W] → [N, C*H*W]
//
// The `start_dim` parameter selects the first dimension to flatten
// (0-indexed); all dimensions from `start_dim` through the last are
// collapsed into one. Default: start_dim = 1.

use shrew_core::backend::Backend;
use shrew_core::error::Result;
use shrew_core::shape::Shape;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// Flatten layer: collapses dimensions `[start_dim..]` into one.
///
/// Default (start_dim = 1): `[N, C, H, W]` → `[N, C*H*W]`.
///
/// # Examples
/// ```ignore
/// let flatten = Flatten::new(1); // flatten from dim 1 onward
/// // x: a tensor of shape [2, 8, 4, 4]
/// let y = flatten.forward(&x)?; // [2, 128]
/// ```
pub struct Flatten {
    start_dim: usize,
}

impl Flatten {
    /// Create a Flatten that collapses from `start_dim` through the last dim.
    ///
    /// - `start_dim = 1`: flatten everything except batch → `[N, ...]` → `[N, flat]`
    /// - `start_dim = 0`: flatten everything → `[total]`
    pub fn new(start_dim: usize) -> Self {
        Flatten { start_dim }
    }

    /// Default flatten: from dim 1 onward (preserves batch dimension).
    pub fn default_flat() -> Self {
        Flatten { start_dim: 1 }
    }
}

impl Default for Flatten {
    fn default() -> Self {
        Self::default_flat()
    }
}

impl<B: Backend> Module<B> for Flatten {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let dims = x.dims();
        if self.start_dim >= dims.len() {
            return Ok(x.clone()); // nothing to flatten
        }

        // Keep dims before start_dim, multiply out the rest
        let mut new_dims: Vec<usize> = dims[..self.start_dim].to_vec();
        let flat: usize = dims[self.start_dim..].iter().product();
        new_dims.push(flat);
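        // e.g. dims = [2, 8, 4, 4] with start_dim = 1:
        //   new_dims = [2], flat = 8 * 4 * 4 = 128, so we reshape to [2, 128]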

        x.reshape(Shape::new(new_dims))
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![] // No learnable parameters
    }
}
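
// A minimal sketch of the shape arithmetic performed by `forward`, factored
// into a standalone helper so it can be unit-tested without constructing a
// Backend or Tensor. `flattened_dims` is illustrative only and is not part of
// the Flatten API.
#[cfg(test)]
mod tests {
    fn flattened_dims(dims: &[usize], start_dim: usize) -> Vec<usize> {
        if start_dim >= dims.len() {
            return dims.to_vec(); // nothing to flatten
        }
        let mut new_dims: Vec<usize> = dims[..start_dim].to_vec();
        new_dims.push(dims[start_dim..].iter().product());
        new_dims
    }

    #[test]
    fn flatten_from_dim_1_preserves_batch() {
        // [N, C, H, W] → [N, C*H*W]
        assert_eq!(flattened_dims(&[2, 8, 4, 4], 1), vec![2, 128]);
    }

    #[test]
    fn flatten_from_dim_0_collapses_everything() {
        assert_eq!(flattened_dims(&[2, 8, 4, 4], 0), vec![256]);
    }

    #[test]
    fn start_dim_past_rank_is_a_no_op() {
        assert_eq!(flattened_dims(&[2, 8], 5), vec![2, 8]);
    }
}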