shrew_nn/module.rs

// Module trait — The interface every neural network layer implements
//
// In PyTorch, `nn.Module` is the base class for all neural network layers.
// In Shrew, `Module` is a trait that every layer implements.
//
// The key method is forward() — it takes an input tensor and returns the
// output. The parameters() method returns all trainable tensors (for
// optimizers).
//
// WHY A TRAIT?
//
// Unlike PyTorch's class hierarchy, Rust uses traits for polymorphism.
// Each module (Linear, Embedding, etc.) is a plain struct that implements
// the Module trait. This is idiomatic Rust and enables static dispatch.
//
// GENERIC OVER BACKEND:
//
// All modules are generic over B: Backend, so the same module definition
// can run on CPU or GPU. The tensors inside the module live on whatever
// backend B is.

use shrew_core::backend::Backend;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

/// The fundamental trait for all neural network layers.
///
/// Every layer in Shrew implements this trait, providing:
/// - `forward()`: compute output from input (the actual computation)
/// - `parameters()`: list all trainable tensors (for optimizer updates)
/// - `set_training()` / `is_training()`: toggle train/eval mode
/// - `num_parameters()` / `trainable_params_count()`: count parameters
///
/// # Example
/// ```ignore
/// struct MyLayer<B: Backend> {
///     linear: Linear<B>,
/// }
///
/// impl<B: Backend> Module<B> for MyLayer<B> {
///     fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
///         self.linear.forward(x)?.relu()
///     }
///     fn parameters(&self) -> Vec<Tensor<B>> {
///         self.linear.parameters()
///     }
/// }
/// ```
pub trait Module<B: Backend> {
    /// Compute the output tensor from the input tensor.
    /// This defines the layer's computation (forward pass).
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>>;

    /// Return all trainable parameters of this module.
    /// The optimizer uses these to update weights during training.
    fn parameters(&self) -> Vec<Tensor<B>>;

    /// Set training or evaluation mode.
    ///
    /// Override in modules that behave differently in train vs eval
    /// (e.g., Dropout, BatchNorm). Default is a no-op.
    ///
    /// Implementations typically use interior mutability (e.g. `Cell<bool>`)
    /// so that `&self` suffices.
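    ///
    /// A minimal sketch of an override, assuming a hypothetical dropout
    /// module (not defined in this file) that stores its mode in a
    /// `Cell<bool>`:
    ///
    /// ```ignore
    /// use std::cell::Cell;
    ///
    /// struct Dropout {
    ///     p: f64,
    ///     training: Cell<bool>,
    /// }
    ///
    /// impl<B: Backend> Module<B> for Dropout {
    ///     fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
    ///         // Train mode: zero activations with probability `p`;
    ///         // eval mode: identity. (Elided here.)
    ///         todo!()
    ///     }
    ///     fn parameters(&self) -> Vec<Tensor<B>> {
    ///         vec![] // dropout has no trainable parameters
    ///     }
    ///     fn set_training(&self, training: bool) {
    ///         self.training.set(training); // `&self` works thanks to `Cell`
    ///     }
    ///     fn is_training(&self) -> bool {
    ///         self.training.get()
    ///     }
    /// }
    /// ```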
    fn set_training(&self, _training: bool) {
        // Default: no-op. Override in Dropout, BatchNorm, etc.
    }

    /// Whether the module is in training mode (default: true).
    fn is_training(&self) -> bool {
        true
    }

    /// Convenience: set training mode.
    fn train(&self) {
        self.set_training(true);
    }

    /// Convenience: set evaluation mode.
    fn eval(&self) {
        self.set_training(false);
    }

    /// Total number of scalar parameters in this module.
    fn num_parameters(&self) -> usize {
        self.parameters().iter().map(|p| p.elem_count()).sum()
    }

    /// Number of trainable (variable) parameters.
    fn trainable_params_count(&self) -> usize {
        self.parameters()
            .iter()
            .filter(|p| p.is_variable())
            .map(|p| p.elem_count())
            .sum()
    }

    /// Freeze all parameters: returns new parameter tensors with
    /// `is_variable = false`, preventing gradient accumulation.
    ///
    /// The caller must rebuild the module with the frozen tensors.
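    ///
    /// A usage sketch (the rebuild step depends on the concrete module's
    /// constructor, which this trait does not prescribe):
    ///
    /// ```ignore
    /// let frozen = layer.frozen_parameters();
    /// // Rebuild `layer` from `frozen` so its weights stop receiving
    /// // gradients, e.g. when fine-tuning only a classification head.
    /// ```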
    fn frozen_parameters(&self) -> Vec<Tensor<B>> {
        self.parameters().into_iter().map(|p| p.freeze()).collect()
    }

    /// Return all trainable parameters with human-readable names.
    ///
    /// Leaf modules (Linear, Conv2d, etc.) override this to provide
    /// meaningful names like `"weight"` / `"bias"`. Composite modules
    /// should concatenate sub-module names with a `"."` separator, e.g.
    /// `"fc1.weight"`, `"attn.w_q.weight"`.
    ///
    /// The default uses positional indices (`param_0`, `param_1`, …).
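    ///
    /// A sketch of a composite override, assuming a hypothetical block
    /// that holds a `Linear` sub-module in a field named `fc1`:
    ///
    /// ```ignore
    /// fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
    ///     self.fc1
    ///         .named_parameters()
    ///         .into_iter()
    ///         .map(|(name, p)| (format!("fc1.{name}"), p))
    ///         .collect()
    /// }
    /// ```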
    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        self.parameters()
            .into_iter()
            .enumerate()
            .map(|(i, p)| (format!("param_{i}"), p))
            .collect()
    }

    /// Returns a `state_dict`-style mapping of parameter name → tensor
    /// (as a `Vec` of pairs).
    ///
    /// This is the idiomatic way to serialize a module.
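    ///
    /// A usage sketch, assuming `Tensor` exposes a `shape()` accessor
    /// (hypothetical here):
    ///
    /// ```ignore
    /// for (name, tensor) in model.state_dict() {
    ///     println!("{name}: {:?}", tensor.shape());
    /// }
    /// ```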
    fn state_dict(&self) -> Vec<(String, Tensor<B>)> {
        self.named_parameters()
    }
}