1use shrew_core::backend::Backend;
23use shrew_core::dtype::DType;
24use shrew_core::error::Result;
25use shrew_core::shape::Shape;
26use shrew_core::tensor::Tensor;
27
/// Which fan of a weight tensor the Kaiming initializers divide by.
///
/// Both values are produced by `compute_fans`; this enum only picks one of them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FanMode {
    /// Use the first element returned by `compute_fans` (the input fan).
    FanIn,
    /// Use the second element returned by `compute_fans` (the output fan).
    FanOut,
}
36
37fn compute_fans(shape: &Shape) -> (f64, f64) {
45 let dims = shape.dims();
46 match dims.len() {
47 0 => (1.0, 1.0),
48 1 => (dims[0] as f64, dims[0] as f64),
49 2 => (dims[1] as f64, dims[0] as f64),
50 _ => {
51 let receptive_field: usize = dims[2..].iter().product();
52 let fan_in = dims[1] as f64 * receptive_field as f64;
53 let fan_out = dims[0] as f64 * receptive_field as f64;
54 (fan_in, fan_out)
55 }
56 }
57}
58
59pub fn uniform<B: Backend>(
61 shape: impl Into<Shape>,
62 low: f64,
63 high: f64,
64 dtype: DType,
65 device: &B::Device,
66) -> Result<Tensor<B>> {
67 let shape = shape.into();
68 let range = high - low;
69 let t = Tensor::<B>::rand(shape, dtype, device)?
70 .affine(range, low)?
71 .set_variable();
72 Ok(t)
73}
74
75pub fn normal<B: Backend>(
77 shape: impl Into<Shape>,
78 mean: f64,
79 std: f64,
80 dtype: DType,
81 device: &B::Device,
82) -> Result<Tensor<B>> {
83 let shape = shape.into();
84 let t = Tensor::<B>::randn(shape, dtype, device)?
85 .affine(std, mean)?
86 .set_variable();
87 Ok(t)
88}
89
90pub fn constant<B: Backend>(
92 shape: impl Into<Shape>,
93 val: f64,
94 dtype: DType,
95 device: &B::Device,
96) -> Result<Tensor<B>> {
97 let t = Tensor::<B>::full(shape, val, dtype, device)?.set_variable();
98 Ok(t)
99}
100
101pub fn zeros<B: Backend>(
103 shape: impl Into<Shape>,
104 dtype: DType,
105 device: &B::Device,
106) -> Result<Tensor<B>> {
107 let t = Tensor::<B>::zeros(shape, dtype, device)?.set_variable();
108 Ok(t)
109}
110
111pub fn ones<B: Backend>(
113 shape: impl Into<Shape>,
114 dtype: DType,
115 device: &B::Device,
116) -> Result<Tensor<B>> {
117 let t = Tensor::<B>::ones(shape, dtype, device)?.set_variable();
118 Ok(t)
119}
120
121pub fn xavier_uniform<B: Backend>(
129 shape: impl Into<Shape>,
130 gain: f64,
131 dtype: DType,
132 device: &B::Device,
133) -> Result<Tensor<B>> {
134 let shape = shape.into();
135 let (fan_in, fan_out) = compute_fans(&shape);
136 let a = gain * (6.0 / (fan_in + fan_out)).sqrt();
137 uniform::<B>(shape, -a, a, dtype, device)
138}
139
140pub fn xavier_normal<B: Backend>(
144 shape: impl Into<Shape>,
145 gain: f64,
146 dtype: DType,
147 device: &B::Device,
148) -> Result<Tensor<B>> {
149 let shape = shape.into();
150 let (fan_in, fan_out) = compute_fans(&shape);
151 let std = gain * (2.0 / (fan_in + fan_out)).sqrt();
152 normal::<B>(shape, 0.0, std, dtype, device)
153}
154
155pub fn kaiming_uniform<B: Backend>(
164 shape: impl Into<Shape>,
165 a: f64,
166 mode: FanMode,
167 dtype: DType,
168 device: &B::Device,
169) -> Result<Tensor<B>> {
170 let shape = shape.into();
171 let (fan_in, fan_out) = compute_fans(&shape);
172 let fan = match mode {
173 FanMode::FanIn => fan_in,
174 FanMode::FanOut => fan_out,
175 };
176 let gain_sq = 2.0 / (1.0 + a * a);
177 let bound = (3.0 * gain_sq / fan).sqrt();
178 uniform::<B>(shape, -bound, bound, dtype, device)
179}
180
181pub fn kaiming_normal<B: Backend>(
189 shape: impl Into<Shape>,
190 a: f64,
191 mode: FanMode,
192 dtype: DType,
193 device: &B::Device,
194) -> Result<Tensor<B>> {
195 let shape = shape.into();
196 let (fan_in, fan_out) = compute_fans(&shape);
197 let fan = match mode {
198 FanMode::FanIn => fan_in,
199 FanMode::FanOut => fan_out,
200 };
201 let gain_sq = 2.0 / (1.0 + a * a);
202 let std = (gain_sq / fan).sqrt();
203 normal::<B>(shape, 0.0, std, dtype, device)
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};

    type T = Tensor<CpuBackend>;

    /// Asserts every value is within `1e-10` of `target`.
    fn assert_all_near(values: &[f64], target: f64) {
        for &x in values {
            assert!((x - target).abs() < 1e-10);
        }
    }

    #[test]
    fn test_xavier_uniform_shape() {
        let device = CpuDevice;
        let w: T = xavier_uniform((128, 64), 1.0, DType::F32, &device).unwrap();
        assert_eq!(w.dims(), &[128, 64]);
        assert_eq!(w.dtype(), DType::F32);
    }

    #[test]
    fn test_xavier_normal_shape() {
        let device = CpuDevice;
        let w: T = xavier_normal((64, 32), 1.0, DType::F64, &device).unwrap();
        assert_eq!(w.dims(), &[64, 32]);
    }

    #[test]
    fn test_kaiming_uniform_bounds() {
        let device = CpuDevice;
        let w: T = kaiming_uniform((50, 100), 0.0, FanMode::FanIn, DType::F64, &device).unwrap();
        let samples = w.to_f64_vec().unwrap();
        // a = 0 => gain^2 = 2; fan_in = 100 for a (50, 100) weight.
        let bound = (6.0_f64 / 100.0).sqrt();
        let allowed = (-bound - 1e-6)..=(bound + 1e-6);
        for &x in &samples {
            assert!(
                allowed.contains(&x),
                "value {} out of bounds [-{}, {}]",
                x,
                bound,
                bound
            );
        }
    }

    #[test]
    fn test_kaiming_normal_shape() {
        let device = CpuDevice;
        let w: T = kaiming_normal((32, 16), 0.0, FanMode::FanOut, DType::F32, &device).unwrap();
        assert_eq!(w.dims(), &[32, 16]);
    }

    #[test]
    fn test_constant_values() {
        let device = CpuDevice;
        let w: T = constant((3, 4), 7.0, DType::F64, &device).unwrap();
        assert_all_near(&w.to_f64_vec().unwrap(), 7.0);
    }

    #[test]
    fn test_zeros_values() {
        let device = CpuDevice;
        let w: T = zeros(5, DType::F64, &device).unwrap();
        assert_all_near(&w.to_f64_vec().unwrap(), 0.0);
    }

    #[test]
    fn test_ones_values() {
        let device = CpuDevice;
        let w: T = ones((2, 3), DType::F64, &device).unwrap();
        assert_all_near(&w.to_f64_vec().unwrap(), 1.0);
    }

    #[test]
    fn test_uniform_range() {
        let device = CpuDevice;
        let w: T = uniform((1000,), -2.0, 3.0, DType::F64, &device).unwrap();
        let allowed = (-2.0 - 1e-6)..=(3.0 + 1e-6);
        for x in w.to_f64_vec().unwrap() {
            assert!(allowed.contains(&x));
        }
    }

    #[test]
    fn test_normal_stats() {
        let device = CpuDevice;
        let w: T = normal((10000,), 5.0, 0.1, DType::F64, &device).unwrap();
        let samples = w.to_f64_vec().unwrap();
        let count = samples.len() as f64;
        let mean = samples.iter().copied().sum::<f64>() / count;
        assert!((mean - 5.0).abs() < 0.05, "mean {} too far from 5.0", mean);
    }

    #[test]
    fn test_compute_fans_conv() {
        // (out_channels, in_channels, kh, kw) => receptive field 5 * 5 = 25.
        let fans = compute_fans(&Shape::from((16, 3, 5, 5)));
        assert_eq!(fans, (3.0 * 25.0, 16.0 * 25.0));
    }
}