// shrew_nn/module.rs
1// Module trait — The interface every neural network layer implements
2//
3// In PyTorch, `nn.Module` is the base class for all neural network layers.
4// In Shrew, `Module` is a trait that every layer implements.
5//
6// The key method is forward() — it takes input tensor(s) and returns output.
7// The parameters() method returns all trainable tensors (for optimizers).
8//
9// WHY A TRAIT?
10//
11// Unlike PyTorch's class hierarchy, Rust uses traits for polymorphism.
12// Each module (Linear, Embedding, etc.) is a plain struct that implements
13// the Module trait. This is idiomatic Rust and enables static dispatch.
14//
15// GENERIC OVER BACKEND:
16//
17// All modules are generic over B: Backend, so the same module definition
18// can run on CPU or GPU. The tensors inside the module live on whatever
19// backend B is.
20
21use shrew_core::backend::Backend;
22use shrew_core::error::Result;
23use shrew_core::tensor::Tensor;
24
25/// The fundamental trait for all neural network layers.
26///
27/// Every layer in Shrew implements this trait, providing:
28/// - `forward()`: compute output from input (the actual computation)
29/// - `parameters()`: list all trainable tensors (for optimizer updates)
30/// - `set_training()` / `is_training()`: toggle train/eval mode
31/// - `num_parameters()` / `trainable_params_count()`: count parameters
32///
33/// # Example
34/// ```ignore
35/// struct MyLayer<B: Backend> {
36/// linear: Linear<B>,
37/// }
38///
39/// impl<B: Backend> Module<B> for MyLayer<B> {
40/// fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
41/// self.linear.forward(x)?.relu()
42/// }
43/// fn parameters(&self) -> Vec<Tensor<B>> {
44/// self.linear.parameters()
45/// }
46/// }
47/// ```
48pub trait Module<B: Backend> {
49 /// Compute the output tensor from the input tensor.
50 /// This defines the layer's computation (forward pass).
51 fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>>;
52
53 /// Return all trainable parameters of this module.
54 /// The optimizer uses these to update weights during training.
55 fn parameters(&self) -> Vec<Tensor<B>>;
56
57 /// Set training or evaluation mode.
58 ///
59 /// Override in modules that behave differently in train vs eval
60 /// (e.g., Dropout, BatchNorm). Default is a no-op.
61 ///
62 /// Uses interior mutability (`Cell<bool>`) so `&self` suffices.
63 fn set_training(&self, _training: bool) {
64 // Default: no-op. Override in Dropout, BatchNorm, etc.
65 }
66
67 /// Whether the module is in training mode (default: true).
68 fn is_training(&self) -> bool {
69 true
70 }
71
72 /// Convenience: set training mode.
73 fn train(&self) {
74 self.set_training(true);
75 }
76
77 /// Convenience: set evaluation mode.
78 fn eval(&self) {
79 self.set_training(false);
80 }
81
82 /// Total number of scalar parameters in this module.
83 fn num_parameters(&self) -> usize {
84 self.parameters().iter().map(|p| p.elem_count()).sum()
85 }
86
87 /// Number of trainable (variable) parameters.
88 fn trainable_params_count(&self) -> usize {
89 self.parameters()
90 .iter()
91 .filter(|p| p.is_variable())
92 .map(|p| p.elem_count())
93 .sum()
94 }
95
96 /// Freeze all parameters: returns new parameter tensors with
97 /// `is_variable = false`, preventing gradient accumulation.
98 ///
99 /// The caller must rebuild the module with the frozen tensors.
100 fn frozen_parameters(&self) -> Vec<Tensor<B>> {
101 self.parameters().into_iter().map(|p| p.freeze()).collect()
102 }
103
104 /// Return all trainable parameters with human-readable names.
105 ///
106 /// Leaf modules (Linear, Conv2d, etc.) override this to provide
107 /// meaningful names like `"weight"` / `"bias"`. Composite modules
108 /// should concatenate sub-module names with a `"."` separator, e.g.
109 /// `"fc1.weight"`, `"attn.w_q.weight"`.
110 ///
111 /// The default uses positional indices (`param_0`, `param_1`, …).
112 fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
113 self.parameters()
114 .into_iter()
115 .enumerate()
116 .map(|(i, p)| (format!("param_{i}"), p))
117 .collect()
118 }
119
120 /// Returns a `state_dict`-style map of parameter name → tensor.
121 ///
122 /// This is the idiomatic way to serialize a module.
123 fn state_dict(&self) -> Vec<(String, Tensor<B>)> {
124 self.named_parameters()
125 }
126}