shrew_cpu/
lib.rs

//! # shrew-cpu
//!
//! CPU backend implementation for Shrew.
//!
//! This crate implements the [`Backend`](shrew_core::Backend) trait for CPU execution.
//! Uses [`gemm`] for SIMD-accelerated matrix multiplication (AVX2/AVX-512/FMA)
//! and [`rayon`] for parallel batched matmul and large elementwise ops.
// Reference implementation: everything runs on the CPU. Matmul goes through
// `gemm` for SIMD acceleration; elementwise ops use standard Rust iterators,
// with further SIMD optimizations to come.
//
// Architecture:
//   CpuBackend — the Backend implementor
//   CpuDevice  — trivial device (there's only one CPU)
//   CpuStorage — enum over typed Vec<T> for each DType

pub mod ops;

use shrew_core::backend::{
    Backend, BackendDevice, BackendStorage, BinaryOp, CmpOp, ReduceOp, UnaryOp,
};
use shrew_core::dtype::DType;
use shrew_core::error::{Error, Result};
use shrew_core::layout::Layout;
use shrew_core::shape::Shape;

// CpuDevice — The CPU "device" (trivial: there's only one CPU)

/// The CPU device. Since every machine has exactly one CPU (from our
/// perspective), this is a zero-sized type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CpuDevice;

impl BackendDevice for CpuDevice {
    fn name(&self) -> String {
        "cpu".to_string()
    }
}

// CpuStorage — Typed storage for tensor data in CPU memory
//
// We use an enum over Vec<T> rather than a type-erased Vec<u8> because:
// 1. We can iterate with the correct type without unsafe casting
// 2. Pattern matching makes the code explicit and safe
// 3. The compiler can optimize typed operations better

/// Storage of tensor data in CPU memory.
///
/// Each variant holds a Vec of the corresponding Rust type.
/// Operations pattern-match on this enum to dispatch to typed code.
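///
/// For example, a typed helper dispatches like this (a sketch of the pattern
/// only; `sum_f32` is a hypothetical function, not part of this crate's API):
///
/// ```ignore
/// fn sum_f32(storage: &CpuStorage) -> Option<f32> {
///     match storage {
///         // The match arm yields a typed &Vec<f32>; no unsafe casting needed.
///         CpuStorage::F32(v) => Some(v.iter().sum()),
///         _ => None,
///     }
/// }
/// ```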
#[derive(Debug, Clone)]
pub enum CpuStorage {
    F16(Vec<half::f16>),
    BF16(Vec<half::bf16>),
    F32(Vec<f32>),
    F64(Vec<f64>),
    U8(Vec<u8>),
    U32(Vec<u32>),
    I64(Vec<i64>),
}

impl BackendStorage for CpuStorage {
    fn dtype(&self) -> DType {
        match self {
            CpuStorage::F16(_) => DType::F16,
            CpuStorage::BF16(_) => DType::BF16,
            CpuStorage::F32(_) => DType::F32,
            CpuStorage::F64(_) => DType::F64,
            CpuStorage::U8(_) => DType::U8,
            CpuStorage::U32(_) => DType::U32,
            CpuStorage::I64(_) => DType::I64,
        }
    }

    fn len(&self) -> usize {
        match self {
            CpuStorage::F16(v) => v.len(),
            CpuStorage::BF16(v) => v.len(),
            CpuStorage::F32(v) => v.len(),
            CpuStorage::F64(v) => v.len(),
            CpuStorage::U8(v) => v.len(),
            CpuStorage::U32(v) => v.len(),
            CpuStorage::I64(v) => v.len(),
        }
    }
}

// CpuBackend — The main CPU backend struct

/// CPU backend. Implements Backend by running operations on CPU via iterators.
#[derive(Debug, Clone)]
pub struct CpuBackend;

// Half-precision helpers: promote to F32, compute, demote back

/// If storage is F16 or BF16, return the target half dtype.
fn half_dtype(s: &CpuStorage) -> Option<DType> {
    match s {
        CpuStorage::F16(_) => Some(DType::F16),
        CpuStorage::BF16(_) => Some(DType::BF16),
        _ => None,
    }
}

/// Promote F16/BF16 storage to contiguous F32. Returns unchanged for other types.
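///
/// Together with [`demote_f32`], this gives the promote/compute/demote round
/// trip used throughout the `Backend` impl below (a sketch; `some_f32_op`
/// stands in for any of the `ops::*` functions):
///
/// ```ignore
/// let (s, l) = promote_f32(&half_storage, &layout); // F16/BF16 -> contiguous F32
/// let result = some_f32_op(&s, &l)?;                // compute at F32 precision
/// let out = demote_f32(result, DType::F16)?;        // F32 -> F16
/// ```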
fn promote_f32(s: &CpuStorage, layout: &Layout) -> (CpuStorage, Layout) {
    match s {
        CpuStorage::F16(data) => {
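            // Walk elements in the layout's logical order so the F32 copy is
            // contiguous regardless of the input's strides.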
            let f32_data: Vec<f32> = layout
                .strided_indices()
                .map(|idx| data[idx].to_f32())
                .collect();
            (
                CpuStorage::F32(f32_data),
                Layout::contiguous(layout.shape().clone()),
            )
        }
        CpuStorage::BF16(data) => {
            let f32_data: Vec<f32> = layout
                .strided_indices()
                .map(|idx| data[idx].to_f32())
                .collect();
            (
                CpuStorage::F32(f32_data),
                Layout::contiguous(layout.shape().clone()),
            )
        }
        _ => (s.clone(), layout.clone()),
    }
}

/// Demote F32 result back to the target half dtype.
fn demote_f32(s: CpuStorage, target: DType) -> Result<CpuStorage> {
    match (&s, target) {
        (CpuStorage::F32(data), DType::F16) => Ok(CpuStorage::F16(
            data.iter().map(|&v| half::f16::from_f32(v)).collect(),
        )),
        (CpuStorage::F32(data), DType::BF16) => Ok(CpuStorage::BF16(
            data.iter().map(|&v| half::bf16::from_f32(v)).collect(),
        )),
        _ => Ok(s),
    }
}

impl Backend for CpuBackend {
    type Device = CpuDevice;
    type Storage = CpuStorage;

    fn zeros(shape: &Shape, dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        let n = shape.elem_count();
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(vec![half::f16::ZERO; n]),
            DType::BF16 => CpuStorage::BF16(vec![half::bf16::ZERO; n]),
            DType::F32 => CpuStorage::F32(vec![0.0f32; n]),
            DType::F64 => CpuStorage::F64(vec![0.0f64; n]),
            DType::U8 => CpuStorage::U8(vec![0u8; n]),
            DType::U32 => CpuStorage::U32(vec![0u32; n]),
            DType::I64 => CpuStorage::I64(vec![0i64; n]),
        })
    }

    fn ones(shape: &Shape, dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        let n = shape.elem_count();
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(vec![half::f16::ONE; n]),
            DType::BF16 => CpuStorage::BF16(vec![half::bf16::ONE; n]),
            DType::F32 => CpuStorage::F32(vec![1.0f32; n]),
            DType::F64 => CpuStorage::F64(vec![1.0f64; n]),
            DType::U8 => CpuStorage::U8(vec![1u8; n]),
            DType::U32 => CpuStorage::U32(vec![1u32; n]),
            DType::I64 => CpuStorage::I64(vec![1i64; n]),
        })
    }

    fn full(shape: &Shape, val: f64, dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        let n = shape.elem_count();
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(vec![half::f16::from_f64(val); n]),
            DType::BF16 => CpuStorage::BF16(vec![half::bf16::from_f64(val); n]),
            DType::F32 => CpuStorage::F32(vec![val as f32; n]),
            DType::F64 => CpuStorage::F64(vec![val; n]),
            DType::U8 => CpuStorage::U8(vec![val as u8; n]),
            DType::U32 => CpuStorage::U32(vec![val as u32; n]),
            DType::I64 => CpuStorage::I64(vec![val as i64; n]),
        })
    }

    fn from_f64_slice(data: &[f64], dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(data.iter().map(|&v| half::f16::from_f64(v)).collect()),
            DType::BF16 => {
                CpuStorage::BF16(data.iter().map(|&v| half::bf16::from_f64(v)).collect())
            }
            DType::F32 => CpuStorage::F32(data.iter().map(|&v| v as f32).collect()),
            DType::F64 => CpuStorage::F64(data.to_vec()),
            DType::U8 => CpuStorage::U8(data.iter().map(|&v| v as u8).collect()),
            DType::U32 => CpuStorage::U32(data.iter().map(|&v| v as u32).collect()),
            DType::I64 => CpuStorage::I64(data.iter().map(|&v| v as i64).collect()),
        })
    }

    fn rand_uniform(shape: &Shape, dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        use rand::Rng;
        let n = shape.elem_count();
        let mut rng = rand::thread_rng();
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(
                (0..n)
                    .map(|_| half::f16::from_f32(rng.gen::<f32>()))
                    .collect(),
            ),
            DType::BF16 => CpuStorage::BF16(
                (0..n)
                    .map(|_| half::bf16::from_f32(rng.gen::<f32>()))
                    .collect(),
            ),
            DType::F32 => CpuStorage::F32((0..n).map(|_| rng.gen::<f32>()).collect()),
            DType::F64 => CpuStorage::F64((0..n).map(|_| rng.gen::<f64>()).collect()),
            _ => {
                return Err(Error::msg(format!(
                    "rand_uniform not supported for {:?}",
                    dtype
                )))
            }
        })
    }

    fn rand_normal(shape: &Shape, dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        use rand::Rng;
        use rand_distr::StandardNormal;
        let n = shape.elem_count();
        let mut rng = rand::thread_rng();
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(
                (0..n)
                    .map(|_| half::f16::from_f32(rng.sample::<f32, _>(StandardNormal)))
                    .collect(),
            ),
            DType::BF16 => CpuStorage::BF16(
                (0..n)
                    .map(|_| half::bf16::from_f32(rng.sample::<f32, _>(StandardNormal)))
                    .collect(),
            ),
            DType::F32 => CpuStorage::F32(
                (0..n)
                    .map(|_| rng.sample::<f32, _>(StandardNormal))
                    .collect(),
            ),
            DType::F64 => CpuStorage::F64(
                (0..n)
                    .map(|_| rng.sample::<f64, _>(StandardNormal))
                    .collect(),
            ),
            _ => {
                return Err(Error::msg(format!(
                    "rand_normal not supported for {:?}",
                    dtype
                )))
            }
        })
    }

    fn binary_op(
        op: BinaryOp,
        lhs: &CpuStorage,
        lhs_layout: &Layout,
        rhs: &CpuStorage,
        rhs_layout: &Layout,
    ) -> Result<CpuStorage> {
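        // If either side is half precision, compute in F32 and demote the
        // result back to the half dtype; the same pattern recurs below.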
        let target = half_dtype(lhs).or(half_dtype(rhs));
        if let Some(dt) = target {
            let (l, ll) = promote_f32(lhs, lhs_layout);
            let (r, rl) = promote_f32(rhs, rhs_layout);
            let result = ops::binary_op(op, &l, &ll, &r, &rl)?;
            return demote_f32(result, dt);
        }
        ops::binary_op(op, lhs, lhs_layout, rhs, rhs_layout)
    }

    fn unary_op(op: UnaryOp, input: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, layout);
            let result = ops::unary_op(op, &s, &l)?;
            return demote_f32(result, dt);
        }
        ops::unary_op(op, input, layout)
    }

    fn reduce_op(
        op: ReduceOp,
        input: &CpuStorage,
        layout: &Layout,
        dims: &[usize],
        keep_dim: bool,
    ) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, layout);
            let result = ops::reduce_op(op, &s, &l, dims, keep_dim)?;
            // ArgMax/ArgMin return I64, don't demote those
            if matches!(op, ReduceOp::ArgMax | ReduceOp::ArgMin) {
                return Ok(result);
            }
            return demote_f32(result, dt);
        }
        ops::reduce_op(op, input, layout, dims, keep_dim)
    }

    fn matmul(
        lhs: &CpuStorage,
        lhs_layout: &Layout,
        rhs: &CpuStorage,
        rhs_layout: &Layout,
    ) -> Result<CpuStorage> {
        let target = half_dtype(lhs).or(half_dtype(rhs));
        if let Some(dt) = target {
            let (l, ll) = promote_f32(lhs, lhs_layout);
            let (r, rl) = promote_f32(rhs, rhs_layout);
            let result = ops::matmul(&l, &ll, &r, &rl)?;
            return demote_f32(result, dt);
        }
        ops::matmul(lhs, lhs_layout, rhs, rhs_layout)
    }

    fn to_contiguous(input: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        match input {
            CpuStorage::F16(data) => {
                let out: Vec<half::f16> = layout.strided_indices().map(|i| data[i]).collect();
                Ok(CpuStorage::F16(out))
            }
            CpuStorage::BF16(data) => {
                let out: Vec<half::bf16> = layout.strided_indices().map(|i| data[i]).collect();
                Ok(CpuStorage::BF16(out))
            }
            _ => ops::to_contiguous(input, layout),
        }
    }

    fn to_f64_vec(input: &CpuStorage, layout: &Layout) -> Result<Vec<f64>> {
        match input {
            CpuStorage::F16(data) => {
                Ok(layout.strided_indices().map(|i| data[i].to_f64()).collect())
            }
            CpuStorage::BF16(data) => {
                Ok(layout.strided_indices().map(|i| data[i].to_f64()).collect())
            }
            _ => ops::to_f64_vec(input, layout),
        }
    }

    fn cmp_op(
        op: CmpOp,
        lhs: &CpuStorage,
        lhs_layout: &Layout,
        rhs: &CpuStorage,
        rhs_layout: &Layout,
    ) -> Result<CpuStorage> {
        let target = half_dtype(lhs).or(half_dtype(rhs));
        if target.is_some() {
            let (l, ll) = promote_f32(lhs, lhs_layout);
            let (r, rl) = promote_f32(rhs, rhs_layout);
            // cmp_op returns U8, no demotion needed
            return ops::cmp_op(op, &l, &ll, &r, &rl);
        }
        ops::cmp_op(op, lhs, lhs_layout, rhs, rhs_layout)
    }

    fn affine(input: &CpuStorage, layout: &Layout, mul: f64, add: f64) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, layout);
            let result = ops::affine(&s, &l, mul, add)?;
            return demote_f32(result, dt);
        }
        ops::affine(input, layout, mul, add)
    }

    fn index_select(
        input: &CpuStorage,
        input_layout: &Layout,
        indices: &CpuStorage,
        indices_layout: &Layout,
        dim: usize,
    ) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, input_layout);
            let result = ops::index_select(&s, &l, indices, indices_layout, dim)?;
            return demote_f32(result, dt);
        }
        ops::index_select(input, input_layout, indices, indices_layout, dim)
    }

    fn powf(input: &CpuStorage, layout: &Layout, exponent: f64) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, layout);
            let result = ops::powf(&s, &l, exponent)?;
            return demote_f32(result, dt);
        }
        ops::powf(input, layout, exponent)
    }

    fn clamp(input: &CpuStorage, layout: &Layout, min: f64, max: f64) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, layout);
            let result = ops::clamp(&s, &l, min, max)?;
            return demote_f32(result, dt);
        }
        ops::clamp(input, layout, min, max)
    }

    fn where_cond(
        mask: &CpuStorage,
        mask_layout: &Layout,
        on_true: &CpuStorage,
        on_true_layout: &Layout,
        on_false: &CpuStorage,
        on_false_layout: &Layout,
    ) -> Result<CpuStorage> {
        let target = half_dtype(on_true).or(half_dtype(on_false));
        if let Some(dt) = target {
            let (t, tl) = promote_f32(on_true, on_true_layout);
            let (f, fl) = promote_f32(on_false, on_false_layout);
            let result = ops::where_cond(mask, mask_layout, &t, &tl, &f, &fl)?;
            return demote_f32(result, dt);
        }
        ops::where_cond(
            mask,
            mask_layout,
            on_true,
            on_true_layout,
            on_false,
            on_false_layout,
        )
    }

    fn gather(
        input: &CpuStorage,
        input_layout: &Layout,
        index: &CpuStorage,
        index_layout: &Layout,
        dim: usize,
    ) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, input_layout);
            let result = ops::gather(&s, &l, index, index_layout, dim)?;
            return demote_f32(result, dt);
        }
        ops::gather(input, input_layout, index, index_layout, dim)
    }

    fn cat(inputs: &[(&CpuStorage, &Layout)], out_shape: &Shape, dim: usize) -> Result<CpuStorage> {
        // Check if any input is half
        let target = inputs.iter().find_map(|(s, _)| half_dtype(s));
        if let Some(dt) = target {
            let promoted: Vec<(CpuStorage, Layout)> =
                inputs.iter().map(|(s, l)| promote_f32(s, l)).collect();
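            // `promoted` owns the F32 copies; `refs` re-borrows them in the
            // (&storage, &layout) shape that `ops::cat` expects.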
            let refs: Vec<(&CpuStorage, &Layout)> = promoted.iter().map(|(s, l)| (s, l)).collect();
            let result = ops::cat(&refs, out_shape, dim)?;
            return demote_f32(result, dt);
        }
        ops::cat(inputs, out_shape, dim)
    }
}

/// Convenience type alias for CPU tensors.
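///
/// A usage sketch; the `zeros` constructor and `Shape` conversion shown here
/// are assumptions about `shrew_core`'s tensor API, not verified calls:
///
/// ```ignore
/// let t = CpuTensor::zeros(&Shape::from(vec![2, 3]), DType::F32, &CpuDevice)?;
/// assert_eq!(t.dtype(), DType::F32);
/// ```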
pub type CpuTensor = shrew_core::Tensor<CpuBackend>;
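
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test for the constructors above. `Shape::from(Vec<usize>)` is an
    // assumed conversion; swap in the crate's real `Shape` constructor.
    #[test]
    fn zeros_len_and_dtype() {
        let shape = Shape::from(vec![2, 3]);
        let s = CpuBackend::zeros(&shape, DType::F32, &CpuDevice).unwrap();
        assert_eq!(s.dtype(), DType::F32);
        assert_eq!(s.len(), 6);
    }
}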