Module

Trait Module 

Source
pub trait Module<B>
where B: Backend,
{
    // Required methods
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>, Error>;
    fn parameters(&self) -> Vec<Tensor<B>>;

    // Provided methods
    fn set_training(&self, _training: bool) { ... }
    fn is_training(&self) -> bool { ... }
    fn train(&self) { ... }
    fn eval(&self) { ... }
    fn num_parameters(&self) -> usize { ... }
    fn trainable_params_count(&self) -> usize { ... }
    fn frozen_parameters(&self) -> Vec<Tensor<B>> { ... }
    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> { ... }
    fn state_dict(&self) -> Vec<(String, Tensor<B>)> { ... }
}
Expand description

The fundamental trait for all neural network layers.

Every layer in Shrew implements this trait, providing:

  • forward(): compute output from input (the actual computation)
  • parameters(): list all trainable tensors (for optimizer updates)
  • set_training() / is_training(): toggle train/eval mode
  • num_parameters() / trainable_params_count(): count parameters

§Example

struct MyLayer<B: Backend> {
    linear: Linear<B>,
}

impl<B: Backend> Module<B> for MyLayer<B> {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>, Error> {
        self.linear.forward(x)?.relu()
    }
    fn parameters(&self) -> Vec<Tensor<B>> {
        self.linear.parameters()
    }
}

Required Methods§

Source

fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>, Error>

Compute the output tensor from the input tensor. This defines the layer’s computation (forward pass).

Source

fn parameters(&self) -> Vec<Tensor<B>>

Return all trainable parameters of this module. The optimizer uses these to update weights during training.

Provided Methods§

Source

fn set_training(&self, _training: bool)

Set training or evaluation mode.

Override in modules that behave differently in train vs eval (e.g., Dropout, BatchNorm). Default is a no-op.

Uses interior mutability (Cell<bool>) so &self suffices.

Source

fn is_training(&self) -> bool

Whether the module is in training mode (default: true).

Source

fn train(&self)

Convenience: set training mode.

Source

fn eval(&self)

Convenience: set evaluation mode.

Source

fn num_parameters(&self) -> usize

Total number of scalar parameters in this module.

Source

fn trainable_params_count(&self) -> usize

Number of trainable (variable) parameters.

Source

fn frozen_parameters(&self) -> Vec<Tensor<B>>

Produce frozen copies of all parameters: returns new parameter tensors with is_variable = false, preventing gradient accumulation. The module itself is not modified in place.

The caller must rebuild the module with the frozen tensors.

Source

fn named_parameters(&self) -> Vec<(String, Tensor<B>)>

Return all trainable parameters with human-readable names.

Leaf modules (Linear, Conv2d, etc.) override this to provide meaningful names like "weight" / "bias". Composite modules should concatenate sub-module names with a "." separator, e.g. "fc1.weight", "attn.w_q.weight".

The default uses positional indices (param_0, param_1, …).

Source

fn state_dict(&self) -> Vec<(String, Tensor<B>)>

Returns a state_dict-style map of parameter name → tensor.

This is the idiomatic way to serialize a module.

Implementors§

Source§

impl<B> Module<B> for AdaptiveAvgPool2d
where B: Backend,

Source§

impl<B> Module<B> for AvgPool2d
where B: Backend,

Source§

impl<B> Module<B> for BatchNorm2d<B>
where B: Backend,

Source§

impl<B> Module<B> for Conv1d<B>
where B: Backend,

Source§

impl<B> Module<B> for Conv2d<B>
where B: Backend,

Source§

impl<B> Module<B> for Dropout
where B: Backend,

Source§

impl<B> Module<B> for ELU
where B: Backend,

Source§

impl<B> Module<B> for Embedding<B>
where B: Backend,

Source§

impl<B> Module<B> for Flatten
where B: Backend,

Source§

impl<B> Module<B> for GeLU
where B: Backend,

Source§

impl<B> Module<B> for GroupNorm<B>
where B: Backend,

Source§

impl<B> Module<B> for LayerNorm<B>
where B: Backend,

Source§

impl<B> Module<B> for LeakyReLU
where B: Backend,

Source§

impl<B> Module<B> for Linear<B>
where B: Backend,

Source§

impl<B> Module<B> for MaxPool2d
where B: Backend,

Source§

impl<B> Module<B> for Mish
where B: Backend,

Source§

impl<B> Module<B> for MultiHeadAttention<B>
where B: Backend,

Source§

impl<B> Module<B> for RMSNorm<B>
where B: Backend,

Source§

impl<B> Module<B> for ReLU
where B: Backend,

Source§

impl<B> Module<B> for Sequential<B>
where B: Backend,

Source§

impl<B> Module<B> for SiLU
where B: Backend,

Source§

impl<B> Module<B> for Sigmoid
where B: Backend,

Source§

impl<B> Module<B> for Tanh
where B: Backend,

Source§

impl<B> Module<B> for TransformerBlock<B>
where B: Backend,

Source§

impl<B: Backend> Module<B> for QuantizedLinear<B>

Source§

impl<M, B> Module<B> for DataParallel<M>
where M: Module<B> + Send + Sync, B: Backend,