pub trait Module<B>
where
    B: Backend,
{
    // Required methods
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>, Error>;
    fn parameters(&self) -> Vec<Tensor<B>>;

    // Provided methods
    fn set_training(&self, _training: bool) { ... }
    fn is_training(&self) -> bool { ... }
    fn train(&self) { ... }
    fn eval(&self) { ... }
    fn num_parameters(&self) -> usize { ... }
    fn trainable_params_count(&self) -> usize { ... }
    fn frozen_parameters(&self) -> Vec<Tensor<B>> { ... }
    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> { ... }
    fn state_dict(&self) -> Vec<(String, Tensor<B>)> { ... }
}
The fundamental trait for all neural network layers.
Every layer in Shrew implements this trait, providing:
- forward(): compute output from input (the actual computation)
- parameters(): list all trainable tensors (for optimizer updates)
- set_training() / is_training(): toggle train/eval mode
- num_parameters() / trainable_params_count(): count parameters
§Example
struct MyLayer<B: Backend> {
    linear: Linear<B>,
}

impl<B: Backend> Module<B> for MyLayer<B> {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>, Error> {
        self.linear.forward(x)?.relu()
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        self.linear.parameters()
    }
}

Required Methods§
fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>, Error>
Compute the output tensor from the input tensor. This defines the layer’s computation (forward pass).
fn parameters(&self) -> Vec<Tensor<B>>
Return all trainable parameters of this module. The optimizer uses these to update weights during training.
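For example, a training loop would typically collect these tensors and hand them to an optimizer. A minimal sketch, assuming some model implementing Module<B> was built elsewhere; the Sgd optimizer in the comments is hypothetical and not part of this trait:

let params = model.parameters();
println!("optimizing {} tensors", params.len());
// let mut opt = Sgd::new(params, 1e-2); // hypothetical optimizer, not defined by this trait
// opt.step();                           // hypothetical update after a backward pass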
Provided Methods§
fn set_training(&self, _training: bool)
Set training or evaluation mode.
Override in modules that behave differently in train vs eval (e.g., Dropout, BatchNorm). Default is a no-op.
Implementations use interior mutability (e.g. Cell<bool>) so that &self suffices.
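A minimal sketch of a layer that overrides both mode methods, assuming the illustrative struct below; the dropout masking itself is omitted, and Tensor<B>: Clone is an assumption:

use std::cell::Cell;
use std::marker::PhantomData;

struct MyDropout<B: Backend> {
    p: f64,
    training: Cell<bool>,
    _backend: PhantomData<B>,
}

impl<B: Backend> Module<B> for MyDropout<B> {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>, Error> {
        if self.is_training() {
            // apply dropout with probability self.p (masking omitted in this sketch)
            unimplemented!()
        } else {
            Ok(x.clone()) // assumes Tensor<B>: Clone
        }
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        Vec::new() // dropout has no trainable parameters
    }

    fn set_training(&self, training: bool) {
        self.training.set(training); // interior mutability: &self is enough
    }

    fn is_training(&self) -> bool {
        self.training.get()
    }
}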
fn is_training(&self) -> bool
Whether the module is in training mode (default: true).
fn num_parameters(&self) -> usize
Total number of scalar parameters in this module.
fn trainable_params_count(&self) -> usize
Number of trainable (variable) parameters.
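For example, to report model size before training (any Module implementor works here; model is assumed to exist):

println!(
    "total params: {}, trainable: {}",
    model.num_parameters(),
    model.trainable_params_count(),
);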
fn frozen_parameters(&self) -> Vec<Tensor<B>>
Freezes all parameters by returning new parameter tensors with
is_variable = false, preventing gradient accumulation.
The caller must rebuild the module with the frozen tensors.
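A hedged sketch of the calling pattern; MyBackbone::rebuild_with_params is a hypothetical constructor standing in for however the concrete module type is reassembled from tensors:

let frozen = backbone.frozen_parameters();
// The trait only hands back the frozen tensors; wiring them back into a
// module is the caller's job, e.g. via some constructor of the concrete type:
// let backbone = MyBackbone::rebuild_with_params(frozen); // hypothetical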
fn named_parameters(&self) -> Vec<(String, Tensor<B>)>
Return all trainable parameters with human-readable names.
Leaf modules (Linear, Conv2d, etc.) override this to provide
meaningful names like "weight" / "bias". Composite modules
should concatenate sub-module names with a "." separator, e.g.
"fc1.weight", "attn.w_q.weight".
The default uses positional indices (param_0, param_1, …).
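A possible override for the MyLayer example above, following the "." convention; the "linear" prefix is simply that example's field name, and the inner names come from whatever Linear reports (e.g. "weight" / "bias"):

// Goes inside the `impl<B: Backend> Module<B> for MyLayer<B>` block shown earlier.
fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
    self.linear
        .named_parameters()
        .into_iter()
        .map(|(name, t)| (format!("linear.{name}"), t))
        .collect()
}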
fn state_dict(&self) -> Vec<(String, Tensor<B>)>
Returns a state_dict-style map of parameter name → tensor.
This is the idiomatic way to serialize a module.
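A minimal sketch of walking the returned pairs; how each tensor's data is actually written out depends on backend-specific accessors that are not part of this trait:

for (name, tensor) in model.state_dict() {
    // `name` might be e.g. "fc1.weight"; serialize `tensor` with whatever
    // backend-specific accessors your format needs (not shown here).
    println!("saving {name}");
    let _ = tensor;
}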