use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::shape::Shape;
use shrew_core::tensor::Tensor;

use crate::module::Module;

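/// A 2-D convolution layer over `[N, C, H, W]` input.
///
/// Filters are stored as `[C_out, C_in, kH, kW]`. With stride `s`, padding
/// `p`, and no dilation (this layer does not expose one), the spatial output
/// size follows the usual floor-division convolution arithmetic:
///
/// ```text
/// H_out = (H_in + 2*p_h - kH) / s_h + 1
/// W_out = (W_in + 2*p_w - kW) / s_w + 1
/// ```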
pub struct Conv2d<B: Backend> {
    /// Filter weights of shape `[C_out, C_in, kH, kW]`.
    weight: Tensor<B>,
    /// Optional per-output-channel bias of shape `[C_out]`.
    bias: Option<Tensor<B>>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
}

impl<B: Backend> Conv2d<B> {
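    /// Creates a `Conv2d` with uniformly initialized parameters.
    ///
    /// Weight and bias are drawn from `U[-k, k)` with
    /// `k = 1 / sqrt(in_channels * kH * kW)`, i.e. the reciprocal square
    /// root of the fan-in, assuming `Tensor::rand` samples uniformly from
    /// `[0, 1)` so that `affine(2.0 * k, -k)` rescales it onto `[-k, k)`.
    ///
    /// A sketch of typical usage; `MyBackend` and `device` stand in for
    /// whatever backend and device this crate provides:
    ///
    /// ```ignore
    /// // 3 -> 16 channels, 3x3 kernel, stride 1, padding 1, with bias.
    /// let conv = Conv2d::<MyBackend>::new(3, 16, [3, 3], [1, 1], [1, 1], true, DType::F32, &device)?;
    /// let y = conv.forward(&x)?; // x: [N, 3, H, W] -> y: [N, 16, H, W]
    /// ```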
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        let [kh, kw] = kernel_size;
        let fan_in = in_channels * kh * kw;
        let k = (1.0 / fan_in as f64).sqrt();

        // Rescale the random samples onto [-k, k) and mark as trainable.
        let weight = Tensor::<B>::rand(
            Shape::new(vec![out_channels, in_channels, kh, kw]),
            dtype,
            device,
        )?
        .affine(2.0 * k, -k)?
        .set_variable();

        let bias = if use_bias {
            let b = Tensor::<B>::rand(Shape::new(vec![out_channels]), dtype, device)?
                .affine(2.0 * k, -k)?
                .set_variable();
            Some(b)
        } else {
            None
        };

        Ok(Conv2d {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        })
    }

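    /// Builds a `Conv2d` from pre-existing tensors (e.g. loaded weights).
    ///
    /// `weight` must be 4-D `[C_out, C_in, kH, kW]`; channel counts and
    /// kernel size are inferred from its shape. Both tensors are re-marked
    /// as variables so they are picked up by `parameters()`.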
    pub fn from_tensors(
        weight: Tensor<B>,
        bias: Option<Tensor<B>>,
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> Result<Self> {
        let dims = weight.dims();
        if dims.len() != 4 {
            return Err(shrew_core::Error::msg(format!(
                "Conv2d weight must be 4D [C_out,C_in,kH,kW], got {:?}",
                dims
            )));
        }
        let out_channels = dims[0];
        let in_channels = dims[1];
        let kernel_size = [dims[2], dims[3]];
        Ok(Conv2d {
            weight: weight.set_variable(),
            bias: bias.map(|b| b.set_variable()),
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        })
    }

    pub fn in_channels(&self) -> usize {
        self.in_channels
    }
    pub fn out_channels(&self) -> usize {
        self.out_channels
    }
    pub fn kernel_size(&self) -> [usize; 2] {
        self.kernel_size
    }
    pub fn stride(&self) -> [usize; 2] {
        self.stride
    }
    pub fn padding(&self) -> [usize; 2] {
        self.padding
    }
    pub fn weight(&self) -> &Tensor<B> {
        &self.weight
    }
    pub fn bias(&self) -> Option<&Tensor<B>> {
        self.bias.as_ref()
    }
}

impl<B: Backend> Module<B> for Conv2d<B> {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.conv2d(&self.weight, self.bias.as_ref(), self.stride, self.padding)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref b) = self.bias {
            params.push(b.clone());
        }
        params
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = vec![("weight".to_string(), self.weight.clone())];
        if let Some(ref b) = self.bias {
            named.push(("bias".to_string(), b.clone()));
        }
        named
    }
}

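/// A 2-D max pooling layer; it has no learnable parameters.
///
/// Output spatial size follows the same floor-division arithmetic as
/// `Conv2d`, e.g. `H_out = (H_in + 2*p_h - kH) / s_h + 1`.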
pub struct MaxPool2d {
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
}

impl MaxPool2d {
    pub fn new(kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2]) -> Self {
        MaxPool2d {
            kernel_size,
            stride,
            padding,
        }
    }
}

impl<B: Backend> Module<B> for MaxPool2d {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.max_pool2d(self.kernel_size, self.stride, self.padding)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}

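/// A 2-D average pooling layer; it has no learnable parameters and uses the
/// same output-shape arithmetic as `MaxPool2d`.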
pub struct AvgPool2d {
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
}

impl AvgPool2d {
    pub fn new(kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2]) -> Self {
        AvgPool2d {
            kernel_size,
            stride,
            padding,
        }
    }
}

impl<B: Backend> Module<B> for AvgPool2d {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.avg_pool2d(self.kernel_size, self.stride, self.padding)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}

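/// A 1-D convolution layer, with filters stored as `[C_out, C_in, k]`.
/// Input is presumably expected in `[N, C, L]` layout, matching the
/// `[N, C, H, W]` convention `Conv2d` uses.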
#[allow(dead_code)]
pub struct Conv1d<B: Backend> {
    weight: Tensor<B>,
    bias: Option<Tensor<B>>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
}

impl<B: Backend> Conv1d<B> {
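    /// Creates a `Conv1d` with parameters drawn from `U[-k, k)` where
    /// `k = 1 / sqrt(in_channels * kernel_size)`, the same fan-in rule used
    /// by `Conv2d::new`.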
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        use_bias: bool,
        dtype: DType,
        device: &B::Device,
    ) -> Result<Self> {
        let k = 1.0 / (in_channels as f64 * kernel_size as f64).sqrt();
        let w_shape = Shape::new(vec![out_channels, in_channels, kernel_size]);
        let weight = Tensor::<B>::rand(w_shape, dtype, device)?
            .affine(2.0 * k, -k)?
            .set_variable();

        let bias = if use_bias {
            let b_shape = Shape::new(vec![out_channels]);
            Some(
                Tensor::<B>::rand(b_shape, dtype, device)?
                    .affine(2.0 * k, -k)?
                    .set_variable(),
            )
        } else {
            None
        };

        Ok(Conv1d {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        })
    }
}

impl<B: Backend> Module<B> for Conv1d<B> {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        x.conv1d(&self.weight, self.bias.as_ref(), self.stride, self.padding)
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref b) = self.bias {
            params.push(b.clone());
        }
        params
    }

    fn named_parameters(&self) -> Vec<(String, Tensor<B>)> {
        let mut named = vec![("weight".to_string(), self.weight.clone())];
        if let Some(ref b) = self.bias {
            named.push(("bias".to_string(), b.clone()));
        }
        named
    }
}

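/// Average pooling to a fixed output size: any `[N, C, H, W]` input is
/// reduced to `[N, C, H_out, W_out]`.
///
/// It is built on `avg_pool2d` by deriving a uniform stride and kernel per
/// axis:
///
/// ```text
/// stride = floor(in / out)
/// kernel = in - (out - 1) * stride
/// ```
///
/// Worked example for `in = 7`, `out = 3`: `stride = 2`, `kernel = 3`,
/// giving windows `[0..3)`, `[2..5)`, `[4..7)` that end exactly at the input
/// edge. When `in` is divisible by `out` this matches exact adaptive
/// pooling with non-overlapping windows; otherwise it is a uniform-kernel
/// approximation (true adaptive pooling varies the window per output
/// position).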
pub struct AdaptiveAvgPool2d {
    output_size: [usize; 2],
}

impl AdaptiveAvgPool2d {
    pub fn new(output_size: [usize; 2]) -> Self {
        AdaptiveAvgPool2d { output_size }
    }
}

impl<B: Backend> Module<B> for AdaptiveAvgPool2d {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        let dims = x.dims();
        if dims.len() != 4 {
            return Err(shrew_core::Error::msg(format!(
                "AdaptiveAvgPool2d: expected 4D [N,C,H,W], got {:?}",
                dims
            )));
        }
        let h_in = dims[2];
        let w_in = dims[3];
        let [h_out, w_out] = self.output_size;

        if h_out == 0 || w_out == 0 {
            return Err(shrew_core::Error::msg(
                "AdaptiveAvgPool2d: output_size must be > 0",
            ));
        }

        // stride = floor(in / out); kernel = in - (out - 1) * stride, so the
        // last window ends exactly at the input edge.
        let stride_h = h_in / h_out;
        let stride_w = w_in / w_out;
        let kernel_h = h_in - (h_out - 1) * stride_h;
        let kernel_w = w_in - (w_out - 1) * stride_w;

        x.avg_pool2d([kernel_h, kernel_w], [stride_h, stride_w], [0, 0])
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        vec![]
    }
}