shrew_nn/linear.rs

// Linear — Fully-connected (dense) layer
//
// The most fundamental neural network layer: y = xW^T + b
//
// Linear(in_features, out_features) transforms an input of shape
// [..., in_features] to [..., out_features] — a matrix multiplication
// followed by an optional bias addition.
//
// WEIGHT INITIALIZATION:
//
// We use Kaiming (He) uniform initialization, which is designed for ReLU
// networks. The weights are drawn from U(-k, k) where k = sqrt(1/in_features).
// This prevents the signal from vanishing or exploding as it passes through
// many layers — critical for training deep networks.
//
// PARAMETER SHAPES:
//
//   weight: [out_features, in_features]  — one row of weights per output feature; transposed in forward()
//   bias:   [1, out_features]            — broadcast across the batch dimension
//
// COMPUTATION:
//
//   y = x @ weight^T + bias
//   Input:  [batch, in_features]
//   Output: [batch, out_features]
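//
// WORKED EXAMPLE (illustrative numbers, matching the doc example on the struct below):
//
//   Linear(784, 128): in_features = 784, out_features = 128
//   k = sqrt(1/784) = 1/28 ≈ 0.0357, so weight and bias start in U(-0.0357, 0.0357)
//   weight: [128, 784], bias: [1, 128]
//   x: [32, 784]  →  x @ weight^T: [32, 128]  →  + bias (broadcast): [32, 128]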

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// A fully-connected (dense) layer: y = xW^T + b.
///
/// # Type Parameters
/// - `B`: the compute backend
///
/// # Examples
/// ```ignore
/// let linear = Linear::<CpuBackend>::new(784, 128, true, DType::F32, &dev)?;
/// let x = CpuTensor::rand((32, 784), DType::F32, &dev)?; // batch of 32
/// let y = linear.forward(&x)?; // shape: [32, 128]
/// ```
pub struct Linear<B: Backend> {
    /// Weight matrix: [out_features, in_features]
    weight: Tensor<B>,
    /// Optional bias vector: [1, out_features]
    bias: Option<Tensor<B>>,
    in_features: usize,
    out_features: usize,
}

impl<B: Backend> Linear<B> {
    /// Create a new Linear layer with Kaiming uniform initialization.
    ///
    /// # Arguments
    /// - `in_features`: size of each input sample
    /// - `out_features`: size of each output sample
    /// - `use_bias`: whether to add a learnable bias
    /// - `dtype`: data type for parameters
    /// - `device`: device to create parameters on
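    ///
    /// # Examples
    /// ```ignore
    /// // Sketch only; reuses the `CpuBackend` and `dev` handle from the struct-level example.
    /// let layer = Linear::<CpuBackend>::new(784, 128, true, DType::F32, &dev)?;
    /// // layer.weight().dims() == [128, 784]; layer.bias() has shape [1, 128]
    /// ```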
    pub fn new(
        in_features: usize,
        out_features: usize,
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        // Kaiming uniform: U(-k, k) where k = sqrt(1/in_features)
        // This is the standard initialization for layers followed by ReLU.
        let k = (1.0 / in_features as f64).sqrt();

        // weight = rand_uniform * 2k - k  →  uniform in [-k, k]
        let weight = Tensor::<B>::rand((out_features, in_features), dtype, device)?
            .affine(2.0 * k, -k)?
            .set_variable();

        let bias = if use_bias {
            // Bias initialized to uniform [-k, k] as well
            let b = Tensor::<B>::rand((1, out_features), dtype, device)?
                .affine(2.0 * k, -k)?
                .set_variable();
            Some(b)
        } else {
            None
        };

        Ok(Linear {
            weight,
            bias,
            in_features,
            out_features,
        })
    }

    /// Create a Linear layer from existing weight and bias tensors.
    /// Useful for loading pre-trained models.
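    ///
    /// # Examples
    /// ```ignore
    /// // Sketch only; `w` and `b` are tensors obtained elsewhere (e.g. read from a checkpoint).
    /// // `w` must be [out_features, in_features]; `b`, if present, [1, out_features].
    /// let layer = Linear::<CpuBackend>::from_tensors(w, Some(b))?;
    /// ```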
    pub fn from_tensors(weight: Tensor<B>, bias: Option<Tensor<B>>) -> Result<Self> {
        let dims = weight.dims();
        if dims.len() != 2 {
            return Err(shrew_core::Error::msg(format!(
                "Linear weight must be 2D, got shape {:?}",
                dims
            )));
        }
        let out_features = dims[0];
        let in_features = dims[1];
        Ok(Linear {
            weight: weight.set_variable(),
            bias: bias.map(|b| b.set_variable()),
            in_features,
            out_features,
        })
    }

    /// The input feature dimension.
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// The output feature dimension.
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Direct access to the weight tensor.
    pub fn weight(&self) -> &Tensor<B> {
        &self.weight
    }

    /// Direct access to the bias tensor (if any).
    pub fn bias(&self) -> Option<&Tensor<B>> {
        self.bias.as_ref()
    }
}

impl<B: Backend> Module<B> for Linear<B> {
    /// Forward pass: y = x @ W^T + b
    ///
    /// Input shape:  [batch, in_features]
    /// Output shape: [batch, out_features]
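    ///
    /// ```ignore
    /// // Sketch only; continues the struct-level example where x is [32, 784].
    /// let y = linear.forward(&x)?;  // [32, 784] @ [784, 128] -> [32, 128]
    /// // the [1, 128] bias broadcasts over the 32-row batch dimension
    /// ```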
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        // x: [batch, in_features]
        // weight: [out_features, in_features]
        // weight^T: [in_features, out_features]
        // x @ weight^T: [batch, out_features]
        let wt = self.weight.t()?.contiguous()?;
        let output = x.matmul(&wt)?;

        match &self.bias {
            Some(bias) => {
                // bias shape: [1, out_features] — broadcasts over batch dim
                output.add(bias)
            }
            None => Ok(output),
        }
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref b) = self.bias {
            params.push(b.clone());
        }
        params
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = vec![("weight".to_string(), self.weight.clone())];
        if let Some(ref b) = self.bias {
            named.push(("bias".to_string(), b.clone()));
        }
        named
    }
}