shrew_nn/
init.rs

// nn::init — Parameter Initialization Utilities
//
// Standalone functions for creating initialized tensors, modeled after
// PyTorch's `torch.nn.init` module. These are useful when building custom
// layers or when you need fine-grained control over initialization.
//
// AVAILABLE INITIALIZERS:
//
//   uniform(shape, low, high)       — U(low, high)
//   normal(shape, mean, std)        — N(mean, std)
//   constant(shape, val)            — all elements = val
//   zeros(shape)                    — all zeros
//   ones(shape)                     — all ones
//   xavier_uniform(shape, gain)     — Glorot uniform
//   xavier_normal(shape, gain)      — Glorot normal
//   kaiming_uniform(shape, a, mode) — He uniform (for ReLU)
//   kaiming_normal(shape, a, mode)  — He normal  (for ReLU)
//
// All functions return a Tensor<B> with `set_variable()` already called,
// making them ready for gradient tracking.

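// EXAMPLE (a minimal usage sketch; assumes the CPU backend from `shrew_cpu`,
// as used in the tests below):
//
//   use shrew_cpu::{CpuBackend, CpuDevice};
//   let dev = CpuDevice;
//   // He-initialize a Linear(784 -> 256) weight for a ReLU network:
//   let w = kaiming_uniform::<CpuBackend>((256, 784), 0.0, FanMode::FanIn, DType::F32, &dev)?;
//   // Glorot-initialize a square weight with the default gain of 1.0:
//   let e = xavier_uniform::<CpuBackend>((512, 512), 1.0, DType::F32, &dev)?;
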
use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::shape::Shape;
use shrew_core::tensor::Tensor;

/// Fan computation mode for Kaiming initialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FanMode {
    /// Use fan_in (input features). Default; preserves variance in the forward pass.
    FanIn,
    /// Use fan_out (output features). Preserves variance in the backward pass.
    FanOut,
}

/// Compute (fan_in, fan_out) from a shape.
///
/// - For 0-D (scalar): fan_in = fan_out = 1
/// - For 1-D: fan_in = fan_out = dims[0]
/// - For 2-D: fan_in = dims[1], fan_out = dims[0]
/// - For 3-D+: fan_in = dims[1] * product(dims[2..]),
///   fan_out = dims[0] * product(dims[2..])
///   (convolution-style: dims[0]=out_channels, dims[1]=in_channels, rest=kernel)
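///
/// e.g. a [16, 3, 5, 5] conv weight has receptive field 5 * 5 = 25, so
/// fan_in = 3 * 25 = 75 and fan_out = 16 * 25 = 400.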
fn compute_fans(shape: &Shape) -> (f64, f64) {
    let dims = shape.dims();
    match dims.len() {
        0 => (1.0, 1.0),
        1 => (dims[0] as f64, dims[0] as f64),
        2 => (dims[1] as f64, dims[0] as f64),
        _ => {
            let receptive_field: usize = dims[2..].iter().product();
            let fan_in = dims[1] as f64 * receptive_field as f64;
            let fan_out = dims[0] as f64 * receptive_field as f64;
            (fan_in, fan_out)
        }
    }
}

/// Initialize a tensor from a uniform distribution U(low, high).
pub fn uniform<B: Backend>(
    shape: impl Into<Shape>,
    low: f64,
    high: f64,
    dtype: DType,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let shape = shape.into();
    let range = high - low;
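    // `rand` samples U(0, 1); affine(range, low) rescales it to U(low, high).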
    let t = Tensor::<B>::rand(shape, dtype, device)?
        .affine(range, low)?
        .set_variable();
    Ok(t)
}

/// Initialize a tensor from a normal distribution N(mean, std).
pub fn normal<B: Backend>(
    shape: impl Into<Shape>,
    mean: f64,
    std: f64,
    dtype: DType,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let shape = shape.into();
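    // `randn` samples N(0, 1); affine(std, mean) rescales it to N(mean, std).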
    let t = Tensor::<B>::randn(shape, dtype, device)?
        .affine(std, mean)?
        .set_variable();
    Ok(t)
}

/// Initialize a tensor with a constant value.
pub fn constant<B: Backend>(
    shape: impl Into<Shape>,
    val: f64,
    dtype: DType,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let t = Tensor::<B>::full(shape, val, dtype, device)?.set_variable();
    Ok(t)
}

/// Initialize a tensor with all zeros (as a variable).
pub fn zeros<B: Backend>(
    shape: impl Into<Shape>,
    dtype: DType,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let t = Tensor::<B>::zeros(shape, dtype, device)?.set_variable();
    Ok(t)
}

/// Initialize a tensor with all ones (as a variable).
pub fn ones<B: Backend>(
    shape: impl Into<Shape>,
    dtype: DType,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let t = Tensor::<B>::ones(shape, dtype, device)?.set_variable();
    Ok(t)
}

/// Xavier (Glorot) uniform initialization.
///
/// Draws from U(-a, a) where a = gain * sqrt(6 / (fan_in + fan_out)).
/// Designed to keep variance constant across layers with linear activations.
///
/// # Arguments
/// - `gain`: scaling factor (1.0 for linear/sigmoid, sqrt(2) for ReLU)
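///
/// e.g. a (256, 128) weight with gain = 1.0 gives a = sqrt(6 / 384) = 0.125.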
pub fn xavier_uniform<B: Backend>(
    shape: impl Into<Shape>,
    gain: f64,
    dtype: DType,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let shape = shape.into();
    let (fan_in, fan_out) = compute_fans(&shape);
    let a = gain * (6.0 / (fan_in + fan_out)).sqrt();
    uniform::<B>(shape, -a, a, dtype, device)
}

/// Xavier (Glorot) normal initialization.
///
/// Draws from N(0, std) where std = gain * sqrt(2 / (fan_in + fan_out)).
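///
/// e.g. a (256, 128) weight with gain = 1.0 gives std = sqrt(2 / 384) ≈ 0.072.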
pub fn xavier_normal<B: Backend>(
    shape: impl Into<Shape>,
    gain: f64,
    dtype: DType,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let shape = shape.into();
    let (fan_in, fan_out) = compute_fans(&shape);
    let std = gain * (2.0 / (fan_in + fan_out)).sqrt();
    normal::<B>(shape, 0.0, std, dtype, device)
}

/// Kaiming (He) uniform initialization.
///
/// Draws from U(-bound, bound) where bound = sqrt(3 * gain² / fan) and
/// gain² = 2 / (1 + a²). Designed for layers followed by ReLU (or variants).
///
/// # Arguments
/// - `a`: negative slope of the rectifier (0 for ReLU, 0.01 for LeakyReLU)
/// - `mode`: `FanIn` or `FanOut`
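///
/// e.g. with a = 0 and fan = 100: gain² = 2, bound = sqrt(6 / 100) ≈ 0.245.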
pub fn kaiming_uniform<B: Backend>(
    shape: impl Into<Shape>,
    a: f64,
    mode: FanMode,
    dtype: DType,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let shape = shape.into();
    let (fan_in, fan_out) = compute_fans(&shape);
    let fan = match mode {
        FanMode::FanIn => fan_in,
        FanMode::FanOut => fan_out,
    };
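    // LeakyReLU gain: gain² = 2 / (1 + a²); a = 0 recovers the ReLU gain sqrt(2).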
    let gain_sq = 2.0 / (1.0 + a * a);
    let bound = (3.0 * gain_sq / fan).sqrt();
    uniform::<B>(shape, -bound, bound, dtype, device)
}

/// Kaiming (He) normal initialization.
///
/// Draws from N(0, std) where std = sqrt(gain² / fan) and gain² = 2 / (1 + a²).
///
/// # Arguments
/// - `a`: negative slope of the rectifier (0 for ReLU)
/// - `mode`: `FanIn` or `FanOut`
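///
/// e.g. with a = 0 and fan = 512: std = sqrt(2 / 512) = 0.0625.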
pub fn kaiming_normal<B: Backend>(
    shape: impl Into<Shape>,
    a: f64,
    mode: FanMode,
    dtype: DType,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let shape = shape.into();
    let (fan_in, fan_out) = compute_fans(&shape);
    let fan = match mode {
        FanMode::FanIn => fan_in,
        FanMode::FanOut => fan_out,
    };
    let gain_sq = 2.0 / (1.0 + a * a);
    let std = (gain_sq / fan).sqrt();
    normal::<B>(shape, 0.0, std, dtype, device)
}

#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};

    type T = Tensor<CpuBackend>;

    #[test]
    fn test_xavier_uniform_shape() {
        let dev = CpuDevice;
        let t = xavier_uniform::<CpuBackend>((128, 64), 1.0, DType::F32, &dev).unwrap();
        assert_eq!(t.dims(), &[128, 64]);
        assert_eq!(t.dtype(), DType::F32);
    }

    #[test]
    fn test_xavier_normal_shape() {
        let dev = CpuDevice;
        let t = xavier_normal::<CpuBackend>((64, 32), 1.0, DType::F64, &dev).unwrap();
        assert_eq!(t.dims(), &[64, 32]);
    }

    #[test]
    fn test_kaiming_uniform_bounds() {
        let dev = CpuDevice;
        // fan_in = 100 for shape (50, 100), gain = sqrt(2) for ReLU (a=0)
        let t: T = kaiming_uniform((50, 100), 0.0, FanMode::FanIn, DType::F64, &dev).unwrap();
        let v = t.to_f64_vec().unwrap();
        let bound = (3.0 * 2.0 / 100.0_f64).sqrt(); // sqrt(6/100)
        for &x in &v {
            assert!(
                x >= -bound - 1e-6 && x <= bound + 1e-6,
                "value {} out of bounds [-{}, {}]",
                x,
                bound,
                bound
            );
        }
    }

    #[test]
    fn test_kaiming_normal_shape() {
        let dev = CpuDevice;
        let t: T = kaiming_normal((32, 16), 0.0, FanMode::FanOut, DType::F32, &dev).unwrap();
        assert_eq!(t.dims(), &[32, 16]);
    }

    #[test]
    fn test_constant_values() {
        let dev = CpuDevice;
        let t: T = constant((3, 4), 7.0, DType::F64, &dev).unwrap();
        let v = t.to_f64_vec().unwrap();
        for &x in &v {
            assert!((x - 7.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_zeros_values() {
        let dev = CpuDevice;
        let t: T = zeros(5, DType::F64, &dev).unwrap();
        let v = t.to_f64_vec().unwrap();
        for &x in &v {
            assert!(x.abs() < 1e-10);
        }
    }

    #[test]
    fn test_ones_values() {
        let dev = CpuDevice;
        let t: T = ones((2, 3), DType::F64, &dev).unwrap();
        let v = t.to_f64_vec().unwrap();
        for &x in &v {
            assert!((x - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_uniform_range() {
        let dev = CpuDevice;
        let t: T = uniform((1000,), -2.0, 3.0, DType::F64, &dev).unwrap();
        let v = t.to_f64_vec().unwrap();
        for &x in &v {
            assert!(x >= -2.0 - 1e-6 && x <= 3.0 + 1e-6);
        }
    }

    #[test]
    fn test_normal_stats() {
        let dev = CpuDevice;
        let t: T = normal((10000,), 5.0, 0.1, DType::F64, &dev).unwrap();
        let v = t.to_f64_vec().unwrap();
        let mean: f64 = v.iter().sum::<f64>() / v.len() as f64;
        assert!((mean - 5.0).abs() < 0.05, "mean {} too far from 5.0", mean);
    }

    #[test]
    fn test_compute_fans_conv() {
        // Conv2d: [out_ch=16, in_ch=3, kh=5, kw=5]
        let shape = Shape::from((16, 3, 5, 5));
        let (fan_in, fan_out) = compute_fans(&shape);
        assert_eq!(fan_in, 3.0 * 25.0); // 75
        assert_eq!(fan_out, 16.0 * 25.0); // 400
    }
}