shrew_nn/conv.rs

// Convolution and pooling layers: Conv2d, Conv1d, MaxPool2d, AvgPool2d,
// and AdaptiveAvgPool2d.
//
// Conv2d applies a set of learnable 2D convolution filters to an input
// tensor of shape [N, C_in, H, W], producing [N, C_out, H_out, W_out].
//
// MaxPool2d performs 2D max-pooling (spatial down-sampling) by taking the
// maximum value in each sliding window.
//
// WEIGHT INITIALIZATION (Conv2d):
//
//   Kaiming (He) uniform: U(-k, k) where k = sqrt(1 / (C_in * kH * kW)).
//   This is the standard choice for layers followed by ReLU.
//
// PARAMETER SHAPES (Conv2d):
//
//   weight: [C_out, C_in, kH, kW]
//   bias:   [C_out]                 (optional)
//
// OUTPUT SIZE FORMULA:
//
//   H_out = floor((H + 2*padding_h - kernel_h) / stride_h) + 1
//   W_out = floor((W + 2*padding_w - kernel_w) / stride_w) + 1
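//
//   Worked example (illustrative numbers, not from the code): H = 28,
//   padding_h = 1, kernel_h = 3, stride_h = 1 gives
//   H_out = floor((28 + 2*1 - 3) / 1) + 1 = 28, i.e. a 3x3 convolution with
//   padding 1 and stride 1 preserves spatial size.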

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::shape::Shape;
use shrew_core::tensor::Tensor;

use crate::module::Module;

/// 2D convolutional layer.
///
/// Applies a set of learnable filters to a 4D input `[N, C_in, H, W]`,
/// producing output of shape `[N, C_out, H_out, W_out]`.
///
/// # Examples
/// ```ignore
/// let conv = Conv2d::<CpuBackend>::new(1, 16, [3, 3], [1, 1], [1, 1], true, DType::F32, &dev)?;
/// let x = CpuTensor::rand((4, 1, 28, 28), DType::F32, &dev)?;
/// let y = conv.forward(&x)?; // [4, 16, 28, 28]
/// ```
pub struct Conv2d<B: Backend> {
    /// Convolution filters: [C_out, C_in, kH, kW]
    weight: Tensor<B>,
    /// Optional bias: [C_out]
    bias: Option<Tensor<B>>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
}

impl<B: Backend> Conv2d<B> {
    /// Create a new Conv2d layer with Kaiming uniform initialization.
    ///
    /// # Arguments
    /// - `in_channels`:  number of input channels (C_in)
    /// - `out_channels`: number of output channels / filters (C_out)
    /// - `kernel_size`:  `[kH, kW]` spatial size of each filter
    /// - `stride`:       `[sH, sW]` stride of the convolution
    /// - `padding`:      `[pH, pW]` zero-padding added to both sides
    /// - `use_bias`:     whether to include an additive bias
    /// - `dtype`:        data type for parameters
    /// - `device`:       device to create parameters on
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        let [kh, kw] = kernel_size;
        let fan_in = in_channels * kh * kw;
        let k = (1.0 / fan_in as f64).sqrt();
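        // For example (illustrative numbers, not from the code): a 3x3 kernel
        // over 1 input channel gives fan_in = 9 and k = sqrt(1/9) ≈ 0.333.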

        // weight: uniform in [-k, k], shape [C_out, C_in, kH, kW]
        let weight = Tensor::<B>::rand(
            Shape::new(vec![out_channels, in_channels, kh, kw]),
            dtype,
            device,
        )?
        .affine(2.0 * k, -k)?
        .set_variable();

        let bias = if use_bias {
            let b = Tensor::<B>::rand(Shape::new(vec![out_channels]), dtype, device)?
                .affine(2.0 * k, -k)?
                .set_variable();
            Some(b)
        } else {
            None
        };

        Ok(Conv2d {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        })
    }

    /// Create a Conv2d from existing weight and bias tensors (e.g. when
    /// loading saved parameters).
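    ///
    /// # Examples
    /// A minimal sketch, assuming the same `CpuBackend`/`CpuTensor` setup as
    /// the type-level example above:
    /// ```ignore
    /// let w = CpuTensor::rand((16, 1, 3, 3), DType::F32, &dev)?;
    /// let conv = Conv2d::from_tensors(w, None, [1, 1], [1, 1])?;
    /// assert_eq!(conv.kernel_size(), [3, 3]);
    /// ```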
    pub fn from_tensors(
        weight: Tensor<B>,
        bias: Option<Tensor<B>>,
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> Result<Self> {
        let dims = weight.dims();
        if dims.len() != 4 {
            return Err(shrew_core::Error::msg(format!(
                "Conv2d weight must be 4D [C_out,C_in,kH,kW], got {:?}",
                dims
            )));
        }
        let out_channels = dims[0];
        let in_channels = dims[1];
        let kernel_size = [dims[2], dims[3]];
        Ok(Conv2d {
            weight: weight.set_variable(),
            bias: bias.map(|b| b.set_variable()),
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        })
    }

    pub fn in_channels(&self) -> usize {
        self.in_channels
    }
    pub fn out_channels(&self) -> usize {
        self.out_channels
    }
    pub fn kernel_size(&self) -> [usize; 2] {
        self.kernel_size
    }
    pub fn stride(&self) -> [usize; 2] {
        self.stride
    }
    pub fn padding(&self) -> [usize; 2] {
        self.padding
    }
    pub fn weight(&self) -> &Tensor<B> {
        &self.weight
    }
    pub fn bias(&self) -> Option<&Tensor<B>> {
        self.bias.as_ref()
    }
}

impl<B: Backend> Module<B> for Conv2d<B> {
    /// Forward pass: 2D convolution.
    ///
    /// Input:  `[N, C_in, H, W]`
    /// Output: `[N, C_out, H_out, W_out]`
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.conv2d(&self.weight, self.bias.as_ref(), self.stride, self.padding)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref b) = self.bias {
            params.push(b.clone());
        }
        params
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = vec![("weight".to_string(), self.weight.clone())];
        if let Some(ref b) = self.bias {
            named.push(("bias".to_string(), b.clone()));
        }
        named
    }
}

// MaxPool2d

/// 2D max-pooling layer.
///
/// Slides a window of `kernel_size` over the input's spatial dimensions
/// and takes the max in each window.
///
/// Input:  `[N, C, H, W]`
/// Output: `[N, C, H_out, W_out]`
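///
/// # Examples
/// A minimal sketch, assuming the same `CpuBackend`/`CpuTensor` setup as the
/// Conv2d example:
/// ```ignore
/// let pool = MaxPool2d::new([2, 2], [2, 2], [0, 0]);
/// let x = CpuTensor::rand((4, 16, 28, 28), DType::F32, &dev)?;
/// let y = pool.forward(&x)?; // [4, 16, 14, 14]
/// ```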
pub struct MaxPool2d {
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
}

impl MaxPool2d {
    /// Create a new MaxPool2d layer.
    ///
    /// # Arguments
    /// - `kernel_size`: `[kH, kW]`
    /// - `stride`:      `[sH, sW]` (typically equal to `kernel_size`)
    /// - `padding`:     `[pH, pW]`
    pub fn new(kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2]) -> Self {
        MaxPool2d {
            kernel_size,
            stride,
            padding,
        }
    }
}

impl<B: Backend> Module<B> for MaxPool2d {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.max_pool2d(self.kernel_size, self.stride, self.padding)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![] // No learnable parameters
    }
}

// AvgPool2d

/// 2D average-pooling layer.
///
/// Slides a window of `kernel_size` over the input's spatial dimensions
/// and takes the mean in each window.
///
/// Input:  `[N, C, H, W]`
/// Output: `[N, C, H_out, W_out]`
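///
/// # Examples
/// A minimal sketch, assuming the same `CpuBackend`/`CpuTensor` setup as the
/// Conv2d example:
/// ```ignore
/// let pool = AvgPool2d::new([2, 2], [2, 2], [0, 0]);
/// let x = CpuTensor::rand((4, 16, 28, 28), DType::F32, &dev)?;
/// let y = pool.forward(&x)?; // [4, 16, 14, 14]
/// ```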
pub struct AvgPool2d {
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
}

impl AvgPool2d {
    /// Create a new AvgPool2d layer.
    ///
    /// # Arguments
    /// - `kernel_size`: `[kH, kW]`
    /// - `stride`:      `[sH, sW]` (typically equal to `kernel_size`)
    /// - `padding`:     `[pH, pW]`
    pub fn new(kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2]) -> Self {
        AvgPool2d {
            kernel_size,
            stride,
            padding,
        }
    }
}

impl<B: Backend> Module<B> for AvgPool2d {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.avg_pool2d(self.kernel_size, self.stride, self.padding)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![] // No learnable parameters
    }
}

// Conv1d

/// 1D convolutional layer.
///
/// Input:  `[N, C_in, L]`
/// Output: `[N, C_out, L_out]`
///
/// where `L_out = floor((L + 2*padding - kernel_size) / stride) + 1`.
///
/// Weight shape: `[C_out, C_in, K]`
/// Bias shape:   `[C_out]` (optional)
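///
/// # Examples
/// A minimal sketch, assuming the same `CpuBackend`/`CpuTensor` setup as the
/// Conv2d example:
/// ```ignore
/// let conv = Conv1d::<CpuBackend>::new(8, 16, 5, 1, 2, true, DType::F32, &dev)?;
/// let x = CpuTensor::rand((4, 8, 100), DType::F32, &dev)?;
/// let y = conv.forward(&x)?; // [4, 16, 100]: floor((100 + 2*2 - 5) / 1) + 1 = 100
/// ```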
#[allow(dead_code)]
pub struct Conv1d<B: Backend> {
    weight: Tensor<B>,
    bias: Option<Tensor<B>>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
}

impl<B: Backend> Conv1d<B> {
    /// Create a new Conv1d layer with Kaiming-uniform initialization.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        // Kaiming uniform initialization: k = sqrt(1 / fan_in), fan_in = C_in * K
        let k = 1.0 / (in_channels as f64 * kernel_size as f64).sqrt();
        let w_shape = Shape::new(vec![out_channels, in_channels, kernel_size]);
        let weight = Tensor::<B>::rand(w_shape, dtype, device)?
            .affine(2.0 * k, -k)?
            .set_variable();

        let bias = if use_bias {
            let b_shape = Shape::new(vec![out_channels]);
            Some(
                Tensor::<B>::rand(b_shape, dtype, device)?
                    .affine(2.0 * k, -k)?
                    .set_variable(),
            )
        } else {
            None
        };

        Ok(Conv1d {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        })
    }
}

impl<B: Backend> Module<B> for Conv1d<B> {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.conv1d(&self.weight, self.bias.as_ref(), self.stride, self.padding)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref b) = self.bias {
            params.push(b.clone());
        }
        params
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = vec![("weight".to_string(), self.weight.clone())];
        if let Some(ref b) = self.bias {
            named.push(("bias".to_string(), b.clone()));
        }
        named
    }
}

// AdaptiveAvgPool2d

/// Adaptive 2D average-pooling layer: pools to a fixed output size.
///
/// Computes a kernel size and stride (with zero padding) that produce the
/// desired output spatial dimensions for any input at least as large as the
/// output.
///
/// Input:  `[N, C, H_in, W_in]`
/// Output: `[N, C, H_out, W_out]`
///
/// Common use: `AdaptiveAvgPool2d::new([1, 1])` for global average pooling.
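///
/// # Examples
/// A minimal sketch, assuming the same `CpuBackend`/`CpuTensor` setup as the
/// Conv2d example:
/// ```ignore
/// let gap = AdaptiveAvgPool2d::new([1, 1]);
/// let x = CpuTensor::rand((4, 512, 7, 7), DType::F32, &dev)?;
/// let y = gap.forward(&x)?; // [4, 512, 1, 1]
/// ```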
pub struct AdaptiveAvgPool2d {
    output_size: [usize; 2],
}

impl AdaptiveAvgPool2d {
    /// Create an AdaptiveAvgPool2d with the desired output spatial size.
    pub fn new(output_size: [usize; 2]) -> Self {
        AdaptiveAvgPool2d { output_size }
    }
}

impl<B: Backend> Module<B> for AdaptiveAvgPool2d {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let dims = x.dims();
        if dims.len() != 4 {
            return Err(shrew_core::Error::msg(format!(
                "AdaptiveAvgPool2d: expected 4D [N,C,H,W], got {:?}",
                dims
            )));
        }
        let h_in = dims[2];
        let w_in = dims[3];
        let [h_out, w_out] = self.output_size;

        if h_out == 0 || w_out == 0 {
            return Err(shrew_core::Error::msg(
                "AdaptiveAvgPool2d: output_size must be > 0",
            ));
        }
        // Guard against a zero stride below: this scheme can only down-sample.
        if h_out > h_in || w_out > w_in {
            return Err(shrew_core::Error::msg(
                "AdaptiveAvgPool2d: output_size must not exceed input size",
            ));
        }

        // Compute kernel and stride (padding 0) to achieve the desired output:
        //   output = floor((input + 2*pad - kernel) / stride) + 1
        // Simplest choice: stride = input / output (integer division), then
        // kernel = input - (output - 1) * stride covers the remainder.
        let stride_h = h_in / h_out;
        let stride_w = w_in / w_out;
        let kernel_h = h_in - (h_out - 1) * stride_h;
        let kernel_w = w_in - (w_out - 1) * stride_w;
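        // Worked example (illustrative numbers): h_in = 7, h_out = 3 gives
        // stride_h = 2, kernel_h = 7 - 2*2 = 3, and indeed
        // floor((7 - 3) / 2) + 1 = 3.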

        x.avg_pool2d([kernel_h, kernel_w], [stride_h, stride_w], [0, 0])
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}