shrew_nn/linear.rs
// Linear — Fully-connected (dense) layer
//
// The most fundamental neural network layer: y = xW^T + b
//
// Linear(in_features, out_features) transforms an input of shape
// [..., in_features] to [..., out_features] — a matrix multiplication
// followed by an optional bias addition.
//
// WEIGHT INITIALIZATION:
//
// We use Kaiming (He) uniform initialization, which is designed for ReLU
// networks. The weights are drawn from U(-k, k) where k = sqrt(1/in_features).
// This prevents the signal from vanishing or exploding as it passes through
// many layers — critical for training deep networks.
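//
// For example, in_features = 784 gives k = sqrt(1/784) = 1/28 ≈ 0.0357,
// so every parameter starts in [-0.0357, 0.0357].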
//
// PARAMETER SHAPES:
//
// weight: [out_features, in_features] — stored transposed for efficient matmul
// bias: [1, out_features] — broadcast across batch dimension
//
// COMPUTATION:
//
// y = x @ weight^T + bias
// Input: [batch, in_features]
// Output: [batch, out_features]
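//
// For example, x: [32, 784] @ weight^T: [784, 128] → y: [32, 128], after
// which the bias [1, 128] broadcasts across all 32 rows.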

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// A fully-connected (dense) layer: y = xW^T + b.
///
/// # Type Parameters
/// - `B`: the compute backend
///
/// # Examples
/// ```ignore
/// let linear = Linear::<CpuBackend>::new(784, 128, true, DType::F32, &dev)?;
/// let x = CpuTensor::rand((32, 784), DType::F32, &dev)?; // batch of 32
/// let y = linear.forward(&x)?; // shape: [32, 128]
/// ```
pub struct Linear<B: Backend> {
    /// Weight matrix: [out_features, in_features]
    weight: Tensor<B>,
    /// Optional bias vector: [1, out_features]
    bias: Option<Tensor<B>>,
    in_features: usize,
    out_features: usize,
}

impl<B: Backend> Linear<B> {
    /// Create a new Linear layer with Kaiming uniform initialization.
    ///
    /// # Arguments
    /// - `in_features`: size of each input sample
    /// - `out_features`: size of each output sample
    /// - `use_bias`: whether to add a learnable bias
    /// - `dtype`: data type for parameters
    /// - `device`: device to create parameters on
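    ///
    /// # Examples
    /// A minimal construction sketch (`CpuBackend` and `dev` are assumed,
    /// as in the example on [`Linear`]):
    /// ```ignore
    /// let fc1 = Linear::<CpuBackend>::new(784, 128, true, DType::F32, &dev)?;
    /// // fc1.weight() has shape [128, 784]; fc1.bias() has shape [1, 128].
    /// ```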
    pub fn new(
        in_features: usize,
        out_features: usize,
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        // Kaiming uniform: U(-k, k) where k = sqrt(1/in_features)
        // This is the standard initialization for layers followed by ReLU.
        let k = (1.0 / in_features as f64).sqrt();

        // weight = rand_uniform * 2k - k → uniform in [-k, k]
        let weight = Tensor::<B>::rand((out_features, in_features), dtype, device)?
            .affine(2.0 * k, -k)?
            .set_variable();

        let bias = if use_bias {
            // Bias initialized to uniform [-k, k] as well
            let b = Tensor::<B>::rand((1, out_features), dtype, device)?
                .affine(2.0 * k, -k)?
                .set_variable();
            Some(b)
        } else {
            None
        };

        Ok(Linear {
            weight,
            bias,
            in_features,
            out_features,
        })
    }

    /// Create a Linear layer from existing weight and bias tensors.
    /// Useful for loading pre-trained models.
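    ///
    /// A loading sketch (`load_tensor` is a hypothetical helper, not part
    /// of this crate):
    /// ```ignore
    /// let w = load_tensor("fc1.weight")?; // expected shape: [out, in]
    /// let b = load_tensor("fc1.bias")?;   // expected shape: [1, out]
    /// let fc1 = Linear::from_tensors(w, Some(b))?;
    /// ```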
    pub fn from_tensors(weight: Tensor<B>, bias: Option<Tensor<B>>) -> Result<Self> {
        let dims = weight.dims();
        if dims.len() != 2 {
            return Err(shrew_core::Error::msg(format!(
                "Linear weight must be 2D, got shape {:?}",
                dims
            )));
        }
        let out_features = dims[0];
        let in_features = dims[1];
        // Note: the bias, if provided, is expected to have shape
        // [1, out_features]; it is not validated here.
        Ok(Linear {
            weight: weight.set_variable(),
            bias: bias.map(|b| b.set_variable()),
            in_features,
            out_features,
        })
    }

    /// The input feature dimension.
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// The output feature dimension.
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Direct access to the weight tensor.
    pub fn weight(&self) -> &Tensor<B> {
        &self.weight
    }

    /// Direct access to the bias tensor (if any).
    pub fn bias(&self) -> Option<&Tensor<B>> {
        self.bias.as_ref()
    }
}

impl<B: Backend> Module<B> for Linear<B> {
    /// Forward pass: y = x @ W^T + b
    ///
    /// Input shape: [batch, in_features]
    /// Output shape: [batch, out_features]
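    ///
    /// A call sketch (`CpuTensor` and `dev` are assumed, as in the example
    /// on [`Linear`]):
    /// ```ignore
    /// let x = CpuTensor::rand((32, 784), DType::F32, &dev)?;
    /// let y = fc1.forward(&x)?; // [32, 784] x [784, 128] -> [32, 128]
    /// ```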
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        // x: [batch, in_features]
        // weight: [out_features, in_features]
        // weight^T: [in_features, out_features]
        // x @ weight^T: [batch, out_features]
        let wt = self.weight.t()?.contiguous()?;
        let output = x.matmul(&wt)?;

        match &self.bias {
            Some(bias) => {
                // bias shape: [1, out_features] — broadcasts over batch dim
                output.add(bias)
            }
            None => Ok(output),
        }
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref b) = self.bias {
            params.push(b.clone());
        }
        params
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = vec![("weight".to_string(), self.weight.clone())];
        if let Some(ref b) = self.bias {
            named.push(("bias".to_string(), b.clone()));
        }
        named
    }
}
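
// End-to-end usage sketch (not compiled here; `CpuBackend`, `CpuTensor`, and
// the `CpuDevice::default()` constructor are assumed names mirroring the doc
// example on `Linear`):
//
//     let dev = CpuDevice::default();
//     let fc1 = Linear::<CpuBackend>::new(4, 2, true, DType::F32, &dev)?;
//     let x = CpuTensor::rand((3, 4), DType::F32, &dev)?;
//     let y = fc1.forward(&x)?;               // shape: [3, 2]
//     assert_eq!(fc1.parameters().len(), 2);  // weight + bias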