/// Concrete CPU kernels (binary/unary/reduce/matmul/indexing/etc.) that the
/// `Backend` impl below delegates to after any half-precision widening.
pub mod ops;
17
18use shrew_core::backend::{
19 Backend, BackendDevice, BackendStorage, BinaryOp, CmpOp, ReduceOp, UnaryOp,
20};
21use shrew_core::dtype::DType;
22use shrew_core::error::{Error, Result};
23use shrew_core::layout::Layout;
24use shrew_core::shape::Shape;
25
/// Marker type for the (single) CPU device.
///
/// Carries no state — all CPU storage lives in host memory — so every
/// `CpuDevice` value is interchangeable and compares equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CpuDevice;
32
33impl BackendDevice for CpuDevice {
34 fn name(&self) -> String {
35 "cpu".to_string()
36 }
37}
38
/// Flat, contiguous element buffer backing a CPU tensor.
///
/// One variant per supported `DType`; the variant tag is the single source
/// of truth for the buffer's dtype (see `BackendStorage::dtype`). Striding
/// and shape live in the accompanying `Layout`, not here.
#[derive(Debug, Clone)]
pub enum CpuStorage {
    /// IEEE 754 half-precision floats (16-bit).
    F16(Vec<half::f16>),
    /// bfloat16 floats (truncated f32).
    BF16(Vec<half::bf16>),
    F32(Vec<f32>),
    F64(Vec<f64>),
    U8(Vec<u8>),
    U32(Vec<u32>),
    I64(Vec<i64>),
}
60
61impl BackendStorage for CpuStorage {
62 fn dtype(&self) -> DType {
63 match self {
64 CpuStorage::F16(_) => DType::F16,
65 CpuStorage::BF16(_) => DType::BF16,
66 CpuStorage::F32(_) => DType::F32,
67 CpuStorage::F64(_) => DType::F64,
68 CpuStorage::U8(_) => DType::U8,
69 CpuStorage::U32(_) => DType::U32,
70 CpuStorage::I64(_) => DType::I64,
71 }
72 }
73
74 fn len(&self) -> usize {
75 match self {
76 CpuStorage::F16(v) => v.len(),
77 CpuStorage::BF16(v) => v.len(),
78 CpuStorage::F32(v) => v.len(),
79 CpuStorage::F64(v) => v.len(),
80 CpuStorage::U8(v) => v.len(),
81 CpuStorage::U32(v) => v.len(),
82 CpuStorage::I64(v) => v.len(),
83 }
84 }
85}
86
/// CPU implementation of the `Backend` trait.
///
/// Stateless: every operation dispatches to the kernels in [`ops`], with
/// f16/bf16 inputs routed through an f32 promote/compute/demote path.
#[derive(Debug, Clone)]
pub struct CpuBackend;
92
93fn half_dtype(s: &CpuStorage) -> Option<DType> {
97 match s {
98 CpuStorage::F16(_) => Some(DType::F16),
99 CpuStorage::BF16(_) => Some(DType::BF16),
100 _ => None,
101 }
102}
103
104fn promote_f32(s: &CpuStorage, layout: &Layout) -> (CpuStorage, Layout) {
106 match s {
107 CpuStorage::F16(data) => {
108 let f32_data: Vec<f32> = layout
109 .strided_indices()
110 .map(|idx| data[idx].to_f32())
111 .collect();
112 (
113 CpuStorage::F32(f32_data),
114 Layout::contiguous(layout.shape().clone()),
115 )
116 }
117 CpuStorage::BF16(data) => {
118 let f32_data: Vec<f32> = layout
119 .strided_indices()
120 .map(|idx| data[idx].to_f32())
121 .collect();
122 (
123 CpuStorage::F32(f32_data),
124 Layout::contiguous(layout.shape().clone()),
125 )
126 }
127 _ => (s.clone(), layout.clone()),
128 }
129}
130
131fn demote_f32(s: CpuStorage, target: DType) -> Result<CpuStorage> {
133 match (&s, target) {
134 (CpuStorage::F32(data), DType::F16) => Ok(CpuStorage::F16(
135 data.iter().map(|&v| half::f16::from_f32(v)).collect(),
136 )),
137 (CpuStorage::F32(data), DType::BF16) => Ok(CpuStorage::BF16(
138 data.iter().map(|&v| half::bf16::from_f32(v)).collect(),
139 )),
140 _ => Ok(s),
141 }
142}
143
// CPU realization of the `Backend` trait.
//
// Half-precision strategy: the `ops` kernels compute in the storage's own
// dtype; when an input is f16/bf16 it is widened to f32 (`promote_f32`),
// the op runs in f32, and the result is narrowed back (`demote_f32`) to
// the half dtype that was detected. Outputs that are not f32 values
// (comparison masks, argmax/argmin indices) are returned un-narrowed.
impl Backend for CpuBackend {
    type Device = CpuDevice;
    type Storage = CpuStorage;

    /// Allocate `shape.elem_count()` zero-valued elements of `dtype`.
    fn zeros(shape: &Shape, dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        let n = shape.elem_count();
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(vec![half::f16::ZERO; n]),
            DType::BF16 => CpuStorage::BF16(vec![half::bf16::ZERO; n]),
            DType::F32 => CpuStorage::F32(vec![0.0f32; n]),
            DType::F64 => CpuStorage::F64(vec![0.0f64; n]),
            DType::U8 => CpuStorage::U8(vec![0u8; n]),
            DType::U32 => CpuStorage::U32(vec![0u32; n]),
            DType::I64 => CpuStorage::I64(vec![0i64; n]),
        })
    }

    /// Allocate `shape.elem_count()` one-valued elements of `dtype`.
    fn ones(shape: &Shape, dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        let n = shape.elem_count();
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(vec![half::f16::ONE; n]),
            DType::BF16 => CpuStorage::BF16(vec![half::bf16::ONE; n]),
            DType::F32 => CpuStorage::F32(vec![1.0f32; n]),
            DType::F64 => CpuStorage::F64(vec![1.0f64; n]),
            DType::U8 => CpuStorage::U8(vec![1u8; n]),
            DType::U32 => CpuStorage::U32(vec![1u32; n]),
            DType::I64 => CpuStorage::I64(vec![1i64; n]),
        })
    }

    /// Allocate a buffer filled with `val` converted to `dtype`.
    ///
    /// Integer targets use `as` casts from f64, which saturate at the
    /// type's bounds and map NaN to 0 (Rust-defined float-to-int cast).
    fn full(shape: &Shape, val: f64, dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        let n = shape.elem_count();
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(vec![half::f16::from_f64(val); n]),
            DType::BF16 => CpuStorage::BF16(vec![half::bf16::from_f64(val); n]),
            DType::F32 => CpuStorage::F32(vec![val as f32; n]),
            DType::F64 => CpuStorage::F64(vec![val; n]),
            DType::U8 => CpuStorage::U8(vec![val as u8; n]),
            DType::U32 => CpuStorage::U32(vec![val as u32; n]),
            DType::I64 => CpuStorage::I64(vec![val as i64; n]),
        })
    }

    /// Build storage of `dtype` by converting each f64 element of `data`.
    /// Same cast semantics as [`Self::full`] for integer targets.
    fn from_f64_slice(data: &[f64], dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(data.iter().map(|&v| half::f16::from_f64(v)).collect()),
            DType::BF16 => {
                CpuStorage::BF16(data.iter().map(|&v| half::bf16::from_f64(v)).collect())
            }
            DType::F32 => CpuStorage::F32(data.iter().map(|&v| v as f32).collect()),
            DType::F64 => CpuStorage::F64(data.to_vec()),
            DType::U8 => CpuStorage::U8(data.iter().map(|&v| v as u8).collect()),
            DType::U32 => CpuStorage::U32(data.iter().map(|&v| v as u32).collect()),
            DType::I64 => CpuStorage::I64(data.iter().map(|&v| v as i64).collect()),
        })
    }

    /// Uniform samples in [0, 1) (rand's `Standard` distribution for floats).
    /// Half dtypes are sampled in f32 and then narrowed. Integer dtypes
    /// are rejected with an error.
    ///
    /// Uses `thread_rng`, so results are non-deterministic across calls.
    fn rand_uniform(shape: &Shape, dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        use rand::Rng;
        let n = shape.elem_count();
        let mut rng = rand::thread_rng();
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(
                (0..n)
                    .map(|_| half::f16::from_f32(rng.gen::<f32>()))
                    .collect(),
            ),
            DType::BF16 => CpuStorage::BF16(
                (0..n)
                    .map(|_| half::bf16::from_f32(rng.gen::<f32>()))
                    .collect(),
            ),
            DType::F32 => CpuStorage::F32((0..n).map(|_| rng.gen::<f32>()).collect()),
            DType::F64 => CpuStorage::F64((0..n).map(|_| rng.gen::<f64>()).collect()),
            _ => {
                return Err(Error::msg(format!(
                    "rand_uniform not supported for {:?}",
                    dtype
                )))
            }
        })
    }

    /// Standard-normal samples (mean 0, std dev 1) via `rand_distr`.
    /// Half dtypes are sampled in f32 and then narrowed; integer dtypes
    /// are rejected with an error.
    fn rand_normal(shape: &Shape, dtype: DType, _device: &CpuDevice) -> Result<CpuStorage> {
        use rand::Rng;
        use rand_distr::StandardNormal;
        let n = shape.elem_count();
        let mut rng = rand::thread_rng();
        Ok(match dtype {
            DType::F16 => CpuStorage::F16(
                (0..n)
                    .map(|_| half::f16::from_f32(rng.sample::<f32, _>(StandardNormal)))
                    .collect(),
            ),
            DType::BF16 => CpuStorage::BF16(
                (0..n)
                    .map(|_| half::bf16::from_f32(rng.sample::<f32, _>(StandardNormal)))
                    .collect(),
            ),
            DType::F32 => CpuStorage::F32(
                (0..n)
                    .map(|_| rng.sample::<f32, _>(StandardNormal))
                    .collect(),
            ),
            DType::F64 => CpuStorage::F64(
                (0..n)
                    .map(|_| rng.sample::<f64, _>(StandardNormal))
                    .collect(),
            ),
            _ => {
                return Err(Error::msg(format!(
                    "rand_normal not supported for {:?}",
                    dtype
                )))
            }
        })
    }

    /// Element-wise binary op. If either side is half precision, both sides
    /// are promoted to f32 and the result is narrowed to the first half
    /// dtype found (lhs takes precedence over rhs).
    fn binary_op(
        op: BinaryOp,
        lhs: &CpuStorage,
        lhs_layout: &Layout,
        rhs: &CpuStorage,
        rhs_layout: &Layout,
    ) -> Result<CpuStorage> {
        let target = half_dtype(lhs).or(half_dtype(rhs));
        if let Some(dt) = target {
            let (l, ll) = promote_f32(lhs, lhs_layout);
            let (r, rl) = promote_f32(rhs, rhs_layout);
            let result = ops::binary_op(op, &l, &ll, &r, &rl)?;
            return demote_f32(result, dt);
        }
        ops::binary_op(op, lhs, lhs_layout, rhs, rhs_layout)
    }

    /// Element-wise unary op, with the usual f32 round-trip for half input.
    fn unary_op(op: UnaryOp, input: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, layout);
            let result = ops::unary_op(op, &s, &l)?;
            return demote_f32(result, dt);
        }
        ops::unary_op(op, input, layout)
    }

    /// Reduction over `dims` (optionally keeping reduced dims as size 1).
    fn reduce_op(
        op: ReduceOp,
        input: &CpuStorage,
        layout: &Layout,
        dims: &[usize],
        keep_dim: bool,
    ) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, layout);
            let result = ops::reduce_op(op, &s, &l, dims, keep_dim)?;
            if matches!(op, ReduceOp::ArgMax | ReduceOp::ArgMin) {
                // ArgMax/ArgMin produce index values, not floats — return
                // them as-is rather than narrowing to a half dtype.
                return Ok(result);
            }
            return demote_f32(result, dt);
        }
        ops::reduce_op(op, input, layout, dims, keep_dim)
    }

    /// Matrix multiplication; half inputs take the f32 round-trip, with the
    /// output dtype chosen from lhs first, then rhs.
    fn matmul(
        lhs: &CpuStorage,
        lhs_layout: &Layout,
        rhs: &CpuStorage,
        rhs_layout: &Layout,
    ) -> Result<CpuStorage> {
        let target = half_dtype(lhs).or(half_dtype(rhs));
        if let Some(dt) = target {
            let (l, ll) = promote_f32(lhs, lhs_layout);
            let (r, rl) = promote_f32(rhs, rhs_layout);
            let result = ops::matmul(&l, &ll, &r, &rl)?;
            return demote_f32(result, dt);
        }
        ops::matmul(lhs, lhs_layout, rhs, rhs_layout)
    }

    /// Materialize a (possibly strided) view into contiguous storage.
    /// Half dtypes are copied directly here (no f32 round-trip needed,
    /// since no arithmetic is performed).
    fn to_contiguous(input: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        match input {
            CpuStorage::F16(data) => {
                let out: Vec<half::f16> = layout.strided_indices().map(|i| data[i]).collect();
                Ok(CpuStorage::F16(out))
            }
            CpuStorage::BF16(data) => {
                let out: Vec<half::bf16> = layout.strided_indices().map(|i| data[i]).collect();
                Ok(CpuStorage::BF16(out))
            }
            _ => ops::to_contiguous(input, layout),
        }
    }

    /// Read the storage out as f64 values in logical (strided) order.
    fn to_f64_vec(input: &CpuStorage, layout: &Layout) -> Result<Vec<f64>> {
        match input {
            CpuStorage::F16(data) => {
                Ok(layout.strided_indices().map(|i| data[i].to_f64()).collect())
            }
            CpuStorage::BF16(data) => {
                Ok(layout.strided_indices().map(|i| data[i].to_f64()).collect())
            }
            _ => ops::to_f64_vec(input, layout),
        }
    }

    /// Element-wise comparison. Half inputs are compared in f32; the
    /// result is returned as produced by `ops::cmp_op` (it is not an f32
    /// value storage, so no demotion applies).
    fn cmp_op(
        op: CmpOp,
        lhs: &CpuStorage,
        lhs_layout: &Layout,
        rhs: &CpuStorage,
        rhs_layout: &Layout,
    ) -> Result<CpuStorage> {
        let target = half_dtype(lhs).or(half_dtype(rhs));
        if target.is_some() {
            let (l, ll) = promote_f32(lhs, lhs_layout);
            let (r, rl) = promote_f32(rhs, rhs_layout);
            return ops::cmp_op(op, &l, &ll, &r, &rl);
        }
        ops::cmp_op(op, lhs, lhs_layout, rhs, rhs_layout)
    }

    /// Element-wise `x * mul + add`.
    fn affine(input: &CpuStorage, layout: &Layout, mul: f64, add: f64) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, layout);
            let result = ops::affine(&s, &l, mul, add)?;
            return demote_f32(result, dt);
        }
        ops::affine(input, layout, mul, add)
    }

    /// Select rows along `dim` using `indices`. Only `input` is promoted;
    /// the index storage keeps its integer dtype.
    fn index_select(
        input: &CpuStorage,
        input_layout: &Layout,
        indices: &CpuStorage,
        indices_layout: &Layout,
        dim: usize,
    ) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, input_layout);
            let result = ops::index_select(&s, &l, indices, indices_layout, dim)?;
            return demote_f32(result, dt);
        }
        ops::index_select(input, input_layout, indices, indices_layout, dim)
    }

    /// Element-wise power with a scalar exponent.
    fn powf(input: &CpuStorage, layout: &Layout, exponent: f64) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, layout);
            let result = ops::powf(&s, &l, exponent)?;
            return demote_f32(result, dt);
        }
        ops::powf(input, layout, exponent)
    }

    /// Element-wise clamp to [min, max].
    fn clamp(input: &CpuStorage, layout: &Layout, min: f64, max: f64) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, layout);
            let result = ops::clamp(&s, &l, min, max)?;
            return demote_f32(result, dt);
        }
        ops::clamp(input, layout, min, max)
    }

    /// Element-wise select between `on_true` and `on_false` driven by
    /// `mask`. Only the two value branches are promoted; the mask keeps
    /// its dtype. Output dtype follows `on_true` first, then `on_false`.
    fn where_cond(
        mask: &CpuStorage,
        mask_layout: &Layout,
        on_true: &CpuStorage,
        on_true_layout: &Layout,
        on_false: &CpuStorage,
        on_false_layout: &Layout,
    ) -> Result<CpuStorage> {
        let target = half_dtype(on_true).or(half_dtype(on_false));
        if let Some(dt) = target {
            let (t, tl) = promote_f32(on_true, on_true_layout);
            let (f, fl) = promote_f32(on_false, on_false_layout);
            let result = ops::where_cond(mask, mask_layout, &t, &tl, &f, &fl)?;
            return demote_f32(result, dt);
        }
        ops::where_cond(
            mask,
            mask_layout,
            on_true,
            on_true_layout,
            on_false,
            on_false_layout,
        )
    }

    /// Gather elements along `dim` at positions given by `index`. As with
    /// `index_select`, the index storage is never promoted.
    fn gather(
        input: &CpuStorage,
        input_layout: &Layout,
        index: &CpuStorage,
        index_layout: &Layout,
        dim: usize,
    ) -> Result<CpuStorage> {
        if let Some(dt) = half_dtype(input) {
            let (s, l) = promote_f32(input, input_layout);
            let result = ops::gather(&s, &l, index, index_layout, dim)?;
            return demote_f32(result, dt);
        }
        ops::gather(input, input_layout, index, index_layout, dim)
    }

    /// Concatenate along `dim`. If any input is half precision, every
    /// input is promoted (non-half inputs pass through `promote_f32` as
    /// clones) and the result takes the first half dtype encountered.
    fn cat(inputs: &[(&CpuStorage, &Layout)], out_shape: &Shape, dim: usize) -> Result<CpuStorage> {
        let target = inputs.iter().find_map(|(s, _)| half_dtype(s));
        if let Some(dt) = target {
            let promoted: Vec<(CpuStorage, Layout)> =
                inputs.iter().map(|(s, l)| promote_f32(s, l)).collect();
            let refs: Vec<(&CpuStorage, &Layout)> = promoted.iter().map(|(s, l)| (s, l)).collect();
            let result = ops::cat(&refs, out_shape, dim)?;
            return demote_f32(result, dt);
        }
        ops::cat(inputs, out_shape, dim)
    }
}
461
/// Convenience alias: a tensor whose operations run on [`CpuBackend`].
pub type CpuTensor = shrew_core::Tensor<CpuBackend>;