shrew_nn/flatten.rs
1// Flatten — Flatten spatial dimensions into a single feature dimension
2//
3// Flattens a contiguous range of dimensions of the input tensor.
4// Commonly used between convolutional and fully connected layers.
5//
6// By default, flattens all dimensions except the batch dimension:
7// [N, C, H, W] → [N, C*H*W]
8//
// The `start_dim` parameter (0-indexed) selects the first dimension to
// flatten; every dimension from `start_dim` through the last one is
// collapsed into a single dimension. Default: start_dim=1 (keep the batch dim).
11
12use shrew_core::backend::Backend;
13use shrew_core::error::Result;
14use shrew_core::shape::Shape;
15use shrew_core::tensor::Tensor;
16
17use crate::module::Module;
18
/// Flatten layer: collapses every dimension from `start_dim` (0-indexed)
/// through the last one into a single dimension.
///
/// Default (`start_dim = 1`): `[N, C, H, W]` → `[N, C*H*W]`.
///
/// When `start_dim` is not less than the input's rank, the input is passed
/// through unchanged.
///
/// # Examples
/// ```ignore
/// let flatten = Flatten::new(1); // flatten from dim 1 onward
/// let x: [2, 8, 4, 4] tensor
/// let y = flatten.forward(&x)?; // [2, 128]
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Flatten {
    /// Index of the first dimension to collapse; all later dims are merged into it.
    start_dim: usize,
}
32
33impl Flatten {
34 /// Create a Flatten that collapses from `start_dim` through the last dim.
35 ///
36 /// - `start_dim = 1`: flatten everything except batch → `[N, ...]` → `[N, flat]`
37 /// - `start_dim = 0`: flatten everything → `[total]`
38 pub fn new(start_dim: usize) -> Self {
39 Flatten { start_dim }
40 }
41
42 /// Default flatten: from dim 1 onward (preserves batch dimension).
43 pub fn default_flat() -> Self {
44 Flatten { start_dim: 1 }
45 }
46}
47
48impl Default for Flatten {
49 fn default() -> Self {
50 Self::default_flat()
51 }
52}
53
54impl<B: Backend> Module<B> for Flatten {
55 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
56 let dims = x.dims();
57 if self.start_dim >= dims.len() {
58 return Ok(x.clone()); // nothing to flatten
59 }
60
61 // Keep dims before start_dim, multiply out the rest
62 let mut new_dims: Vec<usize> = dims[..self.start_dim].to_vec();
63 let flat: usize = dims[self.start_dim..].iter().product();
64 new_dims.push(flat);
65
66 x.reshape(Shape::new(new_dims))
67 }
68
69 fn parameters(&self) -> Vec<Tensor<B>> {
70 vec![] // No learnable parameters
71 }
72}