mod kernels;
pub mod pool;

use cudarc::cublas::CudaBlas;
use cudarc::driver::{CudaSlice, DevicePtr, DeviceSlice, LaunchAsync, LaunchConfig};
use cudarc::nvrtc::{compile_ptx_with_opts, CompileOptions};
use half::{bf16, f16};
use pool::CudaMemPool;
use std::fmt;
use std::sync::Arc;

use shrew_core::backend::{
    Backend, BackendDevice, BackendStorage, BinaryOp, CmpOp, ReduceOp, UnaryOp,
};
use shrew_core::dtype::DType;
use shrew_core::error::{Error, Result};
use shrew_core::layout::Layout;
use shrew_core::shape::Shape;

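/// A handle to one CUDA GPU: the `cudarc` driver context, a cuBLAS handle used
/// for matrix multiplication, a caching memory pool, and the device ordinal.
/// `CudaDevice::new` also compiles the backend's kernels with NVRTC for the
/// detected compute capability (falling back to `sm_89` when the attribute
/// query fails) and loads them as a named module.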
pub struct CudaDevice {
    dev: Arc<cudarc::driver::CudaDevice>,
    blas: Arc<CudaBlas>,
    pool: Arc<CudaMemPool>,
    ordinal: usize,
}

impl CudaDevice {
    pub fn new(ordinal: usize) -> Result<Self> {
        let dev = cudarc::driver::CudaDevice::new(ordinal)
            .map_err(|e| Error::msg(format!("CUDA device creation failed: {e}")))?;

        let blas = CudaBlas::new(dev.clone())
            .map_err(|e| Error::msg(format!("cuBLAS init failed: {e}")))?;

        let major = dev
            .attribute(cudarc::driver::sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
            .unwrap_or(8);
        let minor = dev
            .attribute(cudarc::driver::sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
            .unwrap_or(9);
        let arch_str: &'static str = Box::leak(format!("sm_{major}{minor}").into_boxed_str());
        let opts = CompileOptions {
            arch: Some(arch_str),
            ..Default::default()
        };
        let ptx = compile_ptx_with_opts(kernels::KERNEL_SOURCE, opts)
            .map_err(|e| Error::msg(format!("NVRTC compilation failed: {e}")))?;
        dev.load_ptx(ptx, kernels::MODULE_NAME, kernels::KERNEL_NAMES)
            .map_err(|e| Error::msg(format!("PTX load failed: {e}")))?;

        Ok(CudaDevice {
            dev,
            blas: Arc::new(blas),
            pool: Arc::new(CudaMemPool::new()),
            ordinal,
        })
    }

    pub fn device(&self) -> &Arc<cudarc::driver::CudaDevice> {
        &self.dev
    }

    pub fn blas(&self) -> &CudaBlas {
        &self.blas
    }

    fn get_func(&self, name: &str) -> Result<cudarc::driver::CudaFunction> {
        self.dev
            .get_func(kernels::MODULE_NAME, name)
            .ok_or_else(|| Error::msg(format!("CUDA kernel '{name}' not found")))
    }

    pub fn pool(&self) -> &CudaMemPool {
        &self.pool
    }

    pub fn empty_cache(&self) {
        self.pool.empty_cache();
    }

    pub fn pool_stats(&self) -> pool::PoolStats {
        self.pool.stats()
    }

    pub fn reclaim(&self, storage: CudaStorage) {
        self.pool.reclaim_storage(storage);
    }

    pub fn pool_alloc_f32(&self, n: usize) -> Result<CudaSlice<f32>> {
        self.pool
            .alloc_f32(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc f32: {e}")))
    }
    pub fn pool_alloc_f64(&self, n: usize) -> Result<CudaSlice<f64>> {
        self.pool
            .alloc_f64(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc f64: {e}")))
    }
    pub fn pool_alloc_u16(&self, n: usize) -> Result<CudaSlice<u16>> {
        self.pool
            .alloc_u16(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc u16: {e}")))
    }
    pub fn pool_alloc_u8(&self, n: usize) -> Result<CudaSlice<u8>> {
        self.pool
            .alloc_u8(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc u8: {e}")))
    }
    pub fn pool_alloc_u32(&self, n: usize) -> Result<CudaSlice<u32>> {
        self.pool
            .alloc_u32(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc u32: {e}")))
    }
    pub fn pool_alloc_i64(&self, n: usize) -> Result<CudaSlice<i64>> {
        self.pool
            .alloc_i64(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc i64: {e}")))
    }

    pub fn pool_alloc_zeros_f32(&self, n: usize) -> Result<CudaSlice<f32>> {
        self.pool
            .alloc_zeros_f32(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros f32: {e}")))
    }
    pub fn pool_alloc_zeros_f64(&self, n: usize) -> Result<CudaSlice<f64>> {
        self.pool
            .alloc_zeros_f64(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros f64: {e}")))
    }
    pub fn pool_alloc_zeros_u16(&self, n: usize) -> Result<CudaSlice<u16>> {
        self.pool
            .alloc_zeros_u16(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros u16: {e}")))
    }
    pub fn pool_alloc_zeros_u8(&self, n: usize) -> Result<CudaSlice<u8>> {
        self.pool
            .alloc_zeros_u8(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros u8: {e}")))
    }
    pub fn pool_alloc_zeros_u32(&self, n: usize) -> Result<CudaSlice<u32>> {
        self.pool
            .alloc_zeros_u32(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros u32: {e}")))
    }
    pub fn pool_alloc_zeros_i64(&self, n: usize) -> Result<CudaSlice<i64>> {
        self.pool
            .alloc_zeros_i64(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros i64: {e}")))
    }
}

impl Clone for CudaDevice {
    fn clone(&self) -> Self {
        CudaDevice {
            dev: self.dev.clone(),
            blas: self.blas.clone(),
            pool: self.pool.clone(),
            ordinal: self.ordinal,
        }
    }
}

impl fmt::Debug for CudaDevice {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "CudaDevice(cuda:{})", self.ordinal)
    }
}

unsafe impl Send for CudaDevice {}
unsafe impl Sync for CudaDevice {}

impl BackendDevice for CudaDevice {
    fn name(&self) -> String {
        format!("cuda:{}", self.ordinal)
    }
}

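/// Device-resident tensor data, one variant per supported dtype. Half-precision
/// values (`F16`/`BF16`) are stored as raw `u16` bit patterns and converted with
/// the `half` crate on the host side.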
pub enum CudaStorage {
    F16(CudaSlice<u16>),
    BF16(CudaSlice<u16>),
    F32(CudaSlice<f32>),
    F64(CudaSlice<f64>),
    U8(CudaSlice<u8>),
    U32(CudaSlice<u32>),
    I64(CudaSlice<i64>),
}

impl Clone for CudaStorage {
    fn clone(&self) -> Self {
        match self {
            CudaStorage::F16(s) => CudaStorage::F16(s.clone()),
            CudaStorage::BF16(s) => CudaStorage::BF16(s.clone()),
            CudaStorage::F32(s) => CudaStorage::F32(s.clone()),
            CudaStorage::F64(s) => CudaStorage::F64(s.clone()),
            CudaStorage::U8(s) => CudaStorage::U8(s.clone()),
            CudaStorage::U32(s) => CudaStorage::U32(s.clone()),
            CudaStorage::I64(s) => CudaStorage::I64(s.clone()),
        }
    }
}

impl fmt::Debug for CudaStorage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CudaStorage::F16(s) => write!(f, "CudaStorage::F16(len={})", s.len()),
            CudaStorage::BF16(s) => write!(f, "CudaStorage::BF16(len={})", s.len()),
            CudaStorage::F32(s) => write!(f, "CudaStorage::F32(len={})", s.len()),
            CudaStorage::F64(s) => write!(f, "CudaStorage::F64(len={})", s.len()),
            CudaStorage::U8(s) => write!(f, "CudaStorage::U8(len={})", s.len()),
            CudaStorage::U32(s) => write!(f, "CudaStorage::U32(len={})", s.len()),
            CudaStorage::I64(s) => write!(f, "CudaStorage::I64(len={})", s.len()),
        }
    }
}

unsafe impl Send for CudaStorage {}
unsafe impl Sync for CudaStorage {}

impl BackendStorage for CudaStorage {
    fn dtype(&self) -> DType {
        match self {
            CudaStorage::F16(_) => DType::F16,
            CudaStorage::BF16(_) => DType::BF16,
            CudaStorage::F32(_) => DType::F32,
            CudaStorage::F64(_) => DType::F64,
            CudaStorage::U8(_) => DType::U8,
            CudaStorage::U32(_) => DType::U32,
            CudaStorage::I64(_) => DType::I64,
        }
    }

    fn len(&self) -> usize {
        match self {
            CudaStorage::F16(s) => s.len(),
            CudaStorage::BF16(s) => s.len(),
            CudaStorage::F32(s) => s.len(),
            CudaStorage::F64(s) => s.len(),
            CudaStorage::U8(s) => s.len(),
            CudaStorage::U32(s) => s.len(),
            CudaStorage::I64(s) => s.len(),
        }
    }
}

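/// One-dimensional launch configuration for elementwise kernels: 256 threads
/// per block and enough blocks to cover `n` elements (at least one block even
/// when `n == 0`). For example, `launch_cfg(1000)` yields 4 blocks of 256.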
fn launch_cfg(n: usize) -> LaunchConfig {
    const BLOCK: u32 = 256;
    let grid = (n as u32).div_ceil(BLOCK);
    LaunchConfig {
        block_dim: (BLOCK, 1, 1),
        grid_dim: (grid.max(1), 1, 1),
        shared_mem_bytes: 0,
    }
}

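/// Materializes a contiguous, zero-offset copy of `storage` as described by
/// `layout`. Already-contiguous inputs are returned as a cheap handle clone;
/// otherwise the shape and strides are uploaded and a `to_contiguous_*` kernel
/// performs the strided gather.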
fn ensure_contiguous(
    storage: &CudaStorage,
    layout: &Layout,
    device: &CudaDevice,
) -> Result<CudaStorage> {
    if layout.is_contiguous() && layout.offset() == 0 {
        return Ok(storage.clone());
    }

    let n = layout.elem_count();
    let cfg = launch_cfg(n);
    let ndim = layout.rank() as i32;
    let offset = layout.offset() as i32;

    let shape_i32: Vec<i32> = layout.dims().iter().map(|&d| d as i32).collect();
    let strides_i32: Vec<i32> = layout.strides().iter().map(|&s| s as i32).collect();
    let shape_dev = device
        .dev
        .htod_copy(shape_i32)
        .map_err(|e| Error::msg(format!("htod shape: {e}")))?;
    let strides_dev = device
        .dev
        .htod_copy(strides_i32)
        .map_err(|e| Error::msg(format!("htod strides: {e}")))?;

    match storage {
        CudaStorage::F16(src) | CudaStorage::BF16(src) => {
            let mut dst: CudaSlice<u16> = device
                .dev
                .alloc_zeros(n)
                .map_err(|e| Error::msg(format!("alloc: {e}")))?;
            let func = device.get_func("to_contiguous_u16")?;
            unsafe {
                func.launch(
                    cfg,
                    (
                        src,
                        &mut dst,
                        &shape_dev,
                        &strides_dev,
                        offset,
                        ndim,
                        n as u32,
                    ),
                )
            }
            .map_err(|e| Error::msg(format!("launch to_contiguous_u16: {e}")))?;
            match storage {
                CudaStorage::F16(_) => Ok(CudaStorage::F16(dst)),
                _ => Ok(CudaStorage::BF16(dst)),
            }
        }
        CudaStorage::F32(src) => {
            let mut dst: CudaSlice<f32> = device
                .dev
                .alloc_zeros(n)
                .map_err(|e| Error::msg(format!("alloc: {e}")))?;
            let func = device.get_func("to_contiguous_f32")?;
            unsafe {
                func.launch(
                    cfg,
                    (
                        src,
                        &mut dst,
                        &shape_dev,
                        &strides_dev,
                        offset,
                        ndim,
                        n as u32,
                    ),
                )
            }
            .map_err(|e| Error::msg(format!("launch to_contiguous_f32: {e}")))?;
            Ok(CudaStorage::F32(dst))
        }
        CudaStorage::F64(src) => {
            let mut dst: CudaSlice<f64> = device
                .dev
                .alloc_zeros(n)
                .map_err(|e| Error::msg(format!("alloc: {e}")))?;
            let func = device.get_func("to_contiguous_f64")?;
            unsafe {
                func.launch(
                    cfg,
                    (
                        src,
                        &mut dst,
                        &shape_dev,
                        &strides_dev,
                        offset,
                        ndim,
                        n as u32,
                    ),
                )
            }
            .map_err(|e| Error::msg(format!("launch to_contiguous_f64: {e}")))?;
            Ok(CudaStorage::F64(dst))
        }
        CudaStorage::U8(src) => {
            let mut dst: CudaSlice<u8> = device
                .dev
                .alloc_zeros(n)
                .map_err(|e| Error::msg(format!("alloc: {e}")))?;
            let func = device.get_func("to_contiguous_u8")?;
            unsafe {
                func.launch(
                    cfg,
                    (
                        src,
                        &mut dst,
                        &shape_dev,
                        &strides_dev,
                        offset,
                        ndim,
                        n as u32,
                    ),
                )
            }
            .map_err(|e| Error::msg(format!("launch to_contiguous_u8: {e}")))?;
            Ok(CudaStorage::U8(dst))
        }
        _ => Err(Error::msg(
            "to_contiguous not implemented for this dtype on CUDA",
        )),
    }
}

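/// Returns the `cudarc` device that owns the buffer inside `storage`.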
fn device_from_storage(storage: &CudaStorage) -> Arc<cudarc::driver::CudaDevice> {
    match storage {
        CudaStorage::F16(s) => s.device(),
        CudaStorage::BF16(s) => s.device(),
        CudaStorage::F32(s) => s.device(),
        CudaStorage::F64(s) => s.device(),
        CudaStorage::U8(s) => s.device(),
        CudaStorage::U32(s) => s.device(),
        CudaStorage::I64(s) => s.device(),
    }
}

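/// Rebuilds a `CudaDevice` wrapper from the device that owns `storage`. Note
/// that this creates a fresh cuBLAS handle and an empty memory pool on every
/// call and always reports ordinal 0; it recovers the driver context but not
/// the original wrapper's state.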
fn dev_from_storage(storage: &CudaStorage) -> Result<CudaDevice> {
    let raw_dev = device_from_storage(storage);
    let blas = CudaBlas::new(raw_dev.clone()).map_err(|e| Error::msg(format!("blas: {e}")))?;
    Ok(CudaDevice {
        dev: raw_dev,
        blas: Arc::new(blas),
        pool: Arc::new(CudaMemPool::new()),
        ordinal: 0,
    })
}

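/// Marker type implementing the `Backend` trait on top of CUDA.
///
/// A minimal usage sketch. The crate path `shrew_cuda` and the `Shape`
/// constructor used here are assumptions for illustration, not guarantees made
/// by this file:
///
/// ```ignore
/// use shrew_core::backend::Backend;
/// use shrew_core::dtype::DType;
/// use shrew_core::shape::Shape;
/// use shrew_cuda::{CudaBackend, CudaDevice}; // assumed crate name
///
/// let dev = CudaDevice::new(0)?;
/// let shape = Shape::from(vec![2, 3]); // assumed Shape constructor
/// // Fill a 2x3 f32 buffer with ones directly on the GPU.
/// let ones = CudaBackend::ones(&shape, DType::F32, &dev)?;
/// ```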
#[derive(Clone, Debug)]
pub struct CudaBackend;

impl Backend for CudaBackend {
    type Device = CudaDevice;
    type Storage = CudaStorage;

    fn zeros(shape: &Shape, dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        let n = shape.elem_count();
        match dtype {
            DType::F16 => {
                let s: CudaSlice<u16> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros f16: {e}")))?;
                Ok(CudaStorage::F16(s))
            }
            DType::BF16 => {
                let s: CudaSlice<u16> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros bf16: {e}")))?;
                Ok(CudaStorage::BF16(s))
            }
            DType::F32 => {
                let s: CudaSlice<f32> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros f32: {e}")))?;
                Ok(CudaStorage::F32(s))
            }
            DType::F64 => {
                let s: CudaSlice<f64> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros f64: {e}")))?;
                Ok(CudaStorage::F64(s))
            }
            DType::U8 => {
                let s: CudaSlice<u8> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros u8: {e}")))?;
                Ok(CudaStorage::U8(s))
            }
            DType::U32 => {
                let s: CudaSlice<u32> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros u32: {e}")))?;
                Ok(CudaStorage::U32(s))
            }
            DType::I64 => {
                let s: CudaSlice<i64> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros i64: {e}")))?;
                Ok(CudaStorage::I64(s))
            }
        }
    }

    fn ones(shape: &Shape, dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        Self::full(shape, 1.0, dtype, device)
    }

    fn full(shape: &Shape, val: f64, dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        let n = shape.elem_count();
        let cfg = launch_cfg(n);
        match dtype {
            DType::F16 => {
                let mut s: CudaSlice<u16> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_f16")?;
                unsafe { func.launch(cfg, (&mut s, val as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_f16: {e}")))?;
                Ok(CudaStorage::F16(s))
            }
            DType::BF16 => {
                let mut s: CudaSlice<u16> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_bf16")?;
                unsafe { func.launch(cfg, (&mut s, val as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_bf16: {e}")))?;
                Ok(CudaStorage::BF16(s))
            }
            DType::F32 => {
                let mut s: CudaSlice<f32> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_f32")?;
                unsafe { func.launch(cfg, (&mut s, val as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_f32: {e}")))?;
                Ok(CudaStorage::F32(s))
            }
            DType::F64 => {
                let mut s: CudaSlice<f64> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_f64")?;
                unsafe { func.launch(cfg, (&mut s, val, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_f64: {e}")))?;
                Ok(CudaStorage::F64(s))
            }
            DType::U8 => {
                let mut s: CudaSlice<u8> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_u8")?;
                unsafe { func.launch(cfg, (&mut s, val as u8, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_u8: {e}")))?;
                Ok(CudaStorage::U8(s))
            }
            DType::U32 => {
                let mut s: CudaSlice<u32> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_u32")?;
                unsafe { func.launch(cfg, (&mut s, val as u32, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_u32: {e}")))?;
                Ok(CudaStorage::U32(s))
            }
            DType::I64 => {
                let mut s: CudaSlice<i64> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_i64")?;
                unsafe { func.launch(cfg, (&mut s, val as i64, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_i64: {e}")))?;
                Ok(CudaStorage::I64(s))
            }
        }
    }

    fn from_f64_slice(data: &[f64], dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        match dtype {
            DType::F16 => {
                let host: Vec<u16> = data.iter().map(|&v| f16::from_f64(v).to_bits()).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod f16: {e}")))?;
                Ok(CudaStorage::F16(s))
            }
            DType::BF16 => {
                let host: Vec<u16> = data.iter().map(|&v| bf16::from_f64(v).to_bits()).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod bf16: {e}")))?;
                Ok(CudaStorage::BF16(s))
            }
            DType::F32 => {
                let host: Vec<f32> = data.iter().map(|&v| v as f32).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod f32: {e}")))?;
                Ok(CudaStorage::F32(s))
            }
            DType::F64 => {
                let s = device
                    .dev
                    .htod_copy(data.to_vec())
                    .map_err(|e| Error::msg(format!("htod f64: {e}")))?;
                Ok(CudaStorage::F64(s))
            }
            DType::U8 => {
                let host: Vec<u8> = data.iter().map(|&v| v as u8).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod u8: {e}")))?;
                Ok(CudaStorage::U8(s))
            }
            DType::U32 => {
                let host: Vec<u32> = data.iter().map(|&v| v as u32).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod u32: {e}")))?;
                Ok(CudaStorage::U32(s))
            }
            DType::I64 => {
                let host: Vec<i64> = data.iter().map(|&v| v as i64).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod i64: {e}")))?;
                Ok(CudaStorage::I64(s))
            }
        }
    }

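    // The random initializers below draw values with `rand` on the host and
    // copy them to the device; they do not use a device-side RNG.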
    fn rand_uniform(shape: &Shape, dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        use rand::Rng;
        let n = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::F16 => {
                let host: Vec<u16> = (0..n)
                    .map(|_| f16::from_f32(rng.gen::<f32>()).to_bits())
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_uniform f16: {e}")))?;
                Ok(CudaStorage::F16(s))
            }
            DType::BF16 => {
                let host: Vec<u16> = (0..n)
                    .map(|_| bf16::from_f32(rng.gen::<f32>()).to_bits())
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_uniform bf16: {e}")))?;
                Ok(CudaStorage::BF16(s))
            }
            DType::F32 => {
                let host: Vec<f32> = (0..n).map(|_| rng.gen::<f32>()).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_uniform f32: {e}")))?;
                Ok(CudaStorage::F32(s))
            }
            DType::F64 => {
                let host: Vec<f64> = (0..n).map(|_| rng.gen::<f64>()).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_uniform f64: {e}")))?;
                Ok(CudaStorage::F64(s))
            }
            _ => Err(Error::msg(format!(
                "rand_uniform not supported for {:?}",
                dtype
            ))),
        }
    }

    fn rand_normal(shape: &Shape, dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        use rand::Rng;
        use rand_distr::StandardNormal;
        let n = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::F16 => {
                let host: Vec<u16> = (0..n)
                    .map(|_| f16::from_f32(rng.sample::<f32, _>(StandardNormal)).to_bits())
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_normal f16: {e}")))?;
                Ok(CudaStorage::F16(s))
            }
            DType::BF16 => {
                let host: Vec<u16> = (0..n)
                    .map(|_| bf16::from_f32(rng.sample::<f32, _>(StandardNormal)).to_bits())
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_normal bf16: {e}")))?;
                Ok(CudaStorage::BF16(s))
            }
            DType::F32 => {
                let host: Vec<f32> = (0..n)
                    .map(|_| rng.sample::<f32, _>(StandardNormal))
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_normal f32: {e}")))?;
                Ok(CudaStorage::F32(s))
            }
            DType::F64 => {
                let host: Vec<f64> = (0..n)
                    .map(|_| rng.sample::<f64, _>(StandardNormal))
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_normal f64: {e}")))?;
                Ok(CudaStorage::F64(s))
            }
            _ => Err(Error::msg(format!(
                "rand_normal not supported for {:?}",
                dtype
            ))),
        }
    }

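    // Elementwise binary op. Both inputs are first made contiguous; when the
    // shapes differ the broadcast path uploads the output dims and per-operand
    // broadcast strides and launches a `bcast_binary_*` kernel, otherwise the
    // flat `binary_*` kernel is used.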
    fn binary_op(
        op: BinaryOp,
        lhs: &CudaStorage,
        lhs_layout: &Layout,
        rhs: &CudaStorage,
        rhs_layout: &Layout,
    ) -> Result<CudaStorage> {
        let dev = dev_from_storage(lhs)?;

        let lhs_c = ensure_contiguous(lhs, lhs_layout, &dev)?;
        let rhs_c = ensure_contiguous(rhs, rhs_layout, &dev)?;

        let lhs_shape = lhs_layout.shape();
        let rhs_shape = rhs_layout.shape();

        let op_name = match op {
            BinaryOp::Add => "add",
            BinaryOp::Sub => "sub",
            BinaryOp::Mul => "mul",
            BinaryOp::Div => "div",
        };

        let needs_broadcast = lhs_shape.dims() != rhs_shape.dims();

        if needs_broadcast {
            let out_shape = Shape::broadcast_shape(lhs_shape, rhs_shape)?;
            let a_strides = lhs_shape.broadcast_strides(&out_shape);
            let b_strides = rhs_shape.broadcast_strides(&out_shape);
            let n = out_shape.elem_count();
            let cfg = launch_cfg(n);

            let out_dims_u32: Vec<u32> = out_shape.dims().iter().map(|&d| d as u32).collect();
            let a_strides_u32: Vec<u32> = a_strides.iter().map(|&s| s as u32).collect();
            let b_strides_u32: Vec<u32> = b_strides.iter().map(|&s| s as u32).collect();
            let rank = out_shape.dims().len() as u32;

            let dims_gpu = dev
                .dev
                .htod_copy(out_dims_u32)
                .map_err(|e| Error::msg(format!("alloc dims: {e}")))?;
            let a_strides_gpu = dev
                .dev
                .htod_copy(a_strides_u32)
                .map_err(|e| Error::msg(format!("alloc a_strides: {e}")))?;
            let b_strides_gpu = dev
                .dev
                .htod_copy(b_strides_u32)
                .map_err(|e| Error::msg(format!("alloc b_strides: {e}")))?;

            match (&lhs_c, &rhs_c) {
                (CudaStorage::F32(a), CudaStorage::F32(b)) => {
                    let mut out: CudaSlice<f32> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("bcast_binary_{op_name}_f32"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                a,
                                b,
                                &mut out,
                                &dims_gpu,
                                &a_strides_gpu,
                                &b_strides_gpu,
                                rank,
                                n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("bcast binary op: {e}")))?;
                    Ok(CudaStorage::F32(out))
                }
                (CudaStorage::F64(a), CudaStorage::F64(b)) => {
                    let mut out: CudaSlice<f64> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("bcast_binary_{op_name}_f64"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                a,
                                b,
                                &mut out,
                                &dims_gpu,
                                &a_strides_gpu,
                                &b_strides_gpu,
                                rank,
                                n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("bcast binary op: {e}")))?;
                    Ok(CudaStorage::F64(out))
                }
                (CudaStorage::F16(a), CudaStorage::F16(b)) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("bcast_binary_{op_name}_f16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                a,
                                b,
                                &mut out,
                                &dims_gpu,
                                &a_strides_gpu,
                                &b_strides_gpu,
                                rank,
                                n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("bcast binary op: {e}")))?;
                    Ok(CudaStorage::F16(out))
                }
                (CudaStorage::BF16(a), CudaStorage::BF16(b)) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("bcast_binary_{op_name}_bf16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                a,
                                b,
                                &mut out,
                                &dims_gpu,
                                &a_strides_gpu,
                                &b_strides_gpu,
                                rank,
                                n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("bcast binary op: {e}")))?;
                    Ok(CudaStorage::BF16(out))
                }
                _ => Err(Error::msg(
                    "bcast binary_op: dtype mismatch or unsupported dtype",
                )),
            }
        } else {
            let n = lhs_layout.elem_count();
            let cfg = launch_cfg(n);

            match (&lhs_c, &rhs_c) {
                (CudaStorage::F16(a), CudaStorage::F16(b)) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("binary_{op_name}_f16"))?;
                    unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                        .map_err(|e| Error::msg(format!("binary op: {e}")))?;
                    Ok(CudaStorage::F16(out))
                }
                (CudaStorage::BF16(a), CudaStorage::BF16(b)) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("binary_{op_name}_bf16"))?;
                    unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                        .map_err(|e| Error::msg(format!("binary op: {e}")))?;
                    Ok(CudaStorage::BF16(out))
                }
                (CudaStorage::F32(a), CudaStorage::F32(b)) => {
                    let mut out: CudaSlice<f32> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("binary_{op_name}_f32"))?;
                    unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                        .map_err(|e| Error::msg(format!("binary op: {e}")))?;
                    Ok(CudaStorage::F32(out))
                }
                (CudaStorage::F64(a), CudaStorage::F64(b)) => {
                    let mut out: CudaSlice<f64> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("binary_{op_name}_f64"))?;
                    unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                        .map_err(|e| Error::msg(format!("binary op: {e}")))?;
                    Ok(CudaStorage::F64(out))
                }
                _ => Err(Error::msg("binary_op: dtype mismatch or unsupported dtype")),
            }
        }
    }

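    // Elementwise unary op dispatched to the `unary_<op>_<dtype>` kernels;
    // integer dtypes are rejected.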
    fn unary_op(op: UnaryOp, input: &CudaStorage, layout: &Layout) -> Result<CudaStorage> {
        let dev = dev_from_storage(input)?;

        let input_c = ensure_contiguous(input, layout, &dev)?;
        let n = layout.elem_count();
        let cfg = launch_cfg(n);

        let op_name = match op {
            UnaryOp::Neg => "neg",
            UnaryOp::Abs => "abs",
            UnaryOp::Exp => "exp",
            UnaryOp::Log => "log",
            UnaryOp::Sqrt => "sqrt",
            UnaryOp::Relu => "relu",
            UnaryOp::Sigmoid => "sigmoid",
            UnaryOp::Tanh => "tanh",
            UnaryOp::Gelu => "gelu",
            UnaryOp::Silu => "silu",
            UnaryOp::Sin => "sin",
            UnaryOp::Cos => "cos",
            UnaryOp::Square => "square",
            UnaryOp::Floor => "floor",
            UnaryOp::Ceil => "ceil",
            UnaryOp::Round => "round",
        };

        match &input_c {
            CudaStorage::F16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func(&format!("unary_{op_name}_f16"))?;
                unsafe { func.launch(cfg, (inp, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("unary op: {e}")))?;
                Ok(CudaStorage::F16(out))
            }
            CudaStorage::BF16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func(&format!("unary_{op_name}_bf16"))?;
                unsafe { func.launch(cfg, (inp, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("unary op: {e}")))?;
                Ok(CudaStorage::BF16(out))
            }
            CudaStorage::F32(inp) => {
                let mut out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func(&format!("unary_{op_name}_f32"))?;
                unsafe { func.launch(cfg, (inp, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("unary op: {e}")))?;
                Ok(CudaStorage::F32(out))
            }
            CudaStorage::F64(inp) => {
                let mut out: CudaSlice<f64> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func(&format!("unary_{op_name}_f64"))?;
                unsafe { func.launch(cfg, (inp, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("unary op: {e}")))?;
                Ok(CudaStorage::F64(out))
            }
            _ => Err(Error::msg("unary_op: only float types supported")),
        }
    }

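    // Reduction over at most one dimension. The input is viewed as
    // `outer x reduce x inner` around the reduced axis (or flattened entirely
    // when `dims` is empty); each output element covers one (outer, inner)
    // pair. ArgMax/ArgMin produce I64 indices. The `_keep_dim` flag does not
    // change the storage produced here.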
    fn reduce_op(
        op: ReduceOp,
        input: &CudaStorage,
        layout: &Layout,
        dims: &[usize],
        _keep_dim: bool,
    ) -> Result<CudaStorage> {
        let dev = dev_from_storage(input)?;

        let input_c = ensure_contiguous(input, layout, &dev)?;
        let shape_dims = layout.dims();
        let rank = shape_dims.len();

        let reduce_dim = if dims.is_empty() {
            None
        } else if dims.len() == 1 {
            Some(dims[0])
        } else {
            return Err(Error::msg(
                "CUDA reduce_op: multi-dim reduction not yet supported",
            ));
        };

        let (outer_size, reduce_size, inner_size) = if let Some(dim) = reduce_dim {
            if dim >= rank {
                return Err(Error::msg(format!(
                    "dim {dim} out of range for rank {rank}"
                )));
            }
            let outer: usize = shape_dims[..dim].iter().product::<usize>().max(1);
            let red = shape_dims[dim];
            let inner: usize = shape_dims[dim + 1..].iter().product::<usize>().max(1);
            (outer, red, inner)
        } else {
            let total: usize = shape_dims.iter().product();
            (1usize, total, 1usize)
        };

        let out_n = outer_size * inner_size;
        let cfg = launch_cfg(out_n);

        let is_arg = matches!(op, ReduceOp::ArgMax | ReduceOp::ArgMin);

        let op_name = match op {
            ReduceOp::Sum => "sum",
            ReduceOp::Mean => "mean",
            ReduceOp::Max => "max",
            ReduceOp::Min => "min",
            ReduceOp::ArgMax => "argmax",
            ReduceOp::ArgMin => "argmin",
        };

        if is_arg {
            let mut out: CudaSlice<i64> = dev
                .dev
                .alloc_zeros(out_n)
                .map_err(|e| Error::msg(format!("alloc: {e}")))?;
            match &input_c {
                CudaStorage::F16(inp) => {
                    let func = dev.get_func(&format!("reduce_{op_name}_f16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                }
                CudaStorage::BF16(inp) => {
                    let func = dev.get_func(&format!("reduce_{op_name}_bf16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                }
                CudaStorage::F32(inp) => {
                    let func = dev.get_func(&format!("reduce_{op_name}_f32"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                }
                CudaStorage::F64(inp) => {
                    let func = dev.get_func(&format!("reduce_{op_name}_f64"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                }
                _ => return Err(Error::msg("reduce: only float types supported")),
            }
            Ok(CudaStorage::I64(out))
        } else {
            match &input_c {
                CudaStorage::F16(inp) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(out_n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("reduce_{op_name}_f16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                    Ok(CudaStorage::F16(out))
                }
                CudaStorage::BF16(inp) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(out_n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("reduce_{op_name}_bf16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                    Ok(CudaStorage::BF16(out))
                }
                CudaStorage::F32(inp) => {
                    let mut out: CudaSlice<f32> = dev
                        .dev
                        .alloc_zeros(out_n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("reduce_{op_name}_f32"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                    Ok(CudaStorage::F32(out))
                }
                CudaStorage::F64(inp) => {
                    let mut out: CudaSlice<f64> = dev
                        .dev
                        .alloc_zeros(out_n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("reduce_{op_name}_f64"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                    Ok(CudaStorage::F64(out))
                }
                _ => Err(Error::msg("reduce: only float types supported")),
            }
        }
    }

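    // Batched matrix multiply via cuBLAS. cuBLAS GEMM is column-major, so each
    // call passes the operands in (B, A) order with leading dimensions `n` and
    // `k`; the column-major result it writes is exactly the row-major `m x n`
    // product A*B. f16/bf16 inputs are cast to f32, multiplied with sgemm, and
    // cast back. Batches are issued as one GEMM per batch element, and the rhs
    // is assumed to carry the same batch dimensions as the lhs.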
    fn matmul(
        lhs: &CudaStorage,
        lhs_layout: &Layout,
        rhs: &CudaStorage,
        rhs_layout: &Layout,
    ) -> Result<CudaStorage> {
        let dev = dev_from_storage(lhs)?;

        let lhs_c = ensure_contiguous(lhs, lhs_layout, &dev)?;
        let rhs_c = ensure_contiguous(rhs, rhs_layout, &dev)?;

        let lhs_dims = lhs_layout.dims();
        let rhs_dims = rhs_layout.dims();
        let rank = lhs_dims.len();
        let m = lhs_dims[rank - 2];
        let k = lhs_dims[rank - 1];
        let n = rhs_dims[rhs_dims.len() - 1];
        let batch_size: usize = lhs_dims[..rank - 2].iter().product::<usize>().max(1);

        match (&lhs_c, &rhs_c) {
            (CudaStorage::F16(a), CudaStorage::F16(b)) => {
                let a_n = a.len();
                let b_n = b.len();
                let mn = m * n;
                let total = batch_size * mn;

                let mut a_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(a_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let cast_a = dev.get_func("cast_f16_to_f32")?;
                let cfg_a = launch_cfg(a_n);
                unsafe { cast_a.launch(cfg_a, (a, &mut a_f32, a_n as u32)) }
                    .map_err(|e| Error::msg(format!("cast: {e}")))?;

                let mut b_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(b_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let cast_b = dev.get_func("cast_f16_to_f32")?;
                let cfg_b = launch_cfg(b_n);
                unsafe { cast_b.launch(cfg_b, (b, &mut b_f32, b_n as u32)) }
                    .map_err(|e| Error::msg(format!("cast: {e}")))?;

                let out_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(total)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;

                use cudarc::cublas::sys::cublasOperation_t;
                for batch in 0..batch_size {
                    let a_offset = batch * m * k;
                    let b_offset = batch * k * n;
                    let c_offset = batch * mn;
                    let a_slice = a_f32.slice(a_offset..a_offset + m * k);
                    let b_slice = b_f32.slice(b_offset..b_offset + k * n);
                    let c_slice = out_f32.slice(c_offset..c_offset + mn);
                    unsafe {
                        cudarc::cublas::result::sgemm(
                            *dev.blas.handle(),
                            cublasOperation_t::CUBLAS_OP_N,
                            cublasOperation_t::CUBLAS_OP_N,
                            n as i32,
                            m as i32,
                            k as i32,
                            (&1.0f32) as *const f32,
                            *b_slice.device_ptr() as *const f32,
                            n as i32,
                            *a_slice.device_ptr() as *const f32,
                            k as i32,
                            (&0.0f32) as *const f32,
                            *c_slice.device_ptr() as *mut f32,
                            n as i32,
                        )
                    }
                    .map_err(|e| Error::msg(format!("cuBLAS sgemm: {e}")))?;
                }

                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(total)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let cast_out = dev.get_func("cast_f32_to_f16")?;
                let cfg_out = launch_cfg(total);
                unsafe { cast_out.launch(cfg_out, (&out_f32, &mut out, total as u32)) }
                    .map_err(|e| Error::msg(format!("cast: {e}")))?;

                Ok(CudaStorage::F16(out))
            }
            (CudaStorage::BF16(a), CudaStorage::BF16(b)) => {
                let a_n = a.len();
                let b_n = b.len();
                let mn = m * n;
                let total = batch_size * mn;

                let mut a_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(a_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let cast_a = dev.get_func("cast_bf16_to_f32")?;
                let cfg_a = launch_cfg(a_n);
                unsafe { cast_a.launch(cfg_a, (a, &mut a_f32, a_n as u32)) }
                    .map_err(|e| Error::msg(format!("cast: {e}")))?;

                let mut b_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(b_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let cast_b = dev.get_func("cast_bf16_to_f32")?;
                let cfg_b = launch_cfg(b_n);
                unsafe { cast_b.launch(cfg_b, (b, &mut b_f32, b_n as u32)) }
                    .map_err(|e| Error::msg(format!("cast: {e}")))?;

                let out_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(total)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;

                use cudarc::cublas::sys::cublasOperation_t;
                for batch in 0..batch_size {
                    let a_offset = batch * m * k;
                    let b_offset = batch * k * n;
                    let c_offset = batch * mn;
                    let a_slice = a_f32.slice(a_offset..a_offset + m * k);
                    let b_slice = b_f32.slice(b_offset..b_offset + k * n);
                    let c_slice = out_f32.slice(c_offset..c_offset + mn);
                    unsafe {
                        cudarc::cublas::result::sgemm(
                            *dev.blas.handle(),
                            cublasOperation_t::CUBLAS_OP_N,
                            cublasOperation_t::CUBLAS_OP_N,
                            n as i32,
                            m as i32,
                            k as i32,
                            (&1.0f32) as *const f32,
                            *b_slice.device_ptr() as *const f32,
                            n as i32,
                            *a_slice.device_ptr() as *const f32,
                            k as i32,
                            (&0.0f32) as *const f32,
                            *c_slice.device_ptr() as *mut f32,
                            n as i32,
                        )
                    }
                    .map_err(|e| Error::msg(format!("cuBLAS sgemm: {e}")))?;
                }

                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(total)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let cast_out = dev.get_func("cast_f32_to_bf16")?;
                let cfg_out = launch_cfg(total);
                unsafe { cast_out.launch(cfg_out, (&out_f32, &mut out, total as u32)) }
                    .map_err(|e| Error::msg(format!("cast: {e}")))?;

                Ok(CudaStorage::BF16(out))
            }
            (CudaStorage::F32(a), CudaStorage::F32(b)) => {
                let mn = m * n;
                let total = batch_size * mn;
                let out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(total)
                    .map_err(|e| Error::msg(format!("alloc matmul: {e}")))?;

                use cudarc::cublas::sys::cublasOperation_t;

                for batch in 0..batch_size {
                    let a_offset = batch * m * k;
                    let b_offset = batch * k * n;
                    let c_offset = batch * mn;

                    let a_slice = a.slice(a_offset..a_offset + m * k);
                    let b_slice = b.slice(b_offset..b_offset + k * n);
                    let c_slice = out.slice(c_offset..c_offset + mn);

                    unsafe {
                        cudarc::cublas::result::sgemm(
                            *dev.blas.handle(),
                            cublasOperation_t::CUBLAS_OP_N,
                            cublasOperation_t::CUBLAS_OP_N,
                            n as i32,
                            m as i32,
                            k as i32,
                            (&1.0f32) as *const f32,
                            *b_slice.device_ptr() as *const f32,
                            n as i32,
                            *a_slice.device_ptr() as *const f32,
                            k as i32,
                            (&0.0f32) as *const f32,
                            *c_slice.device_ptr() as *mut f32,
                            n as i32,
                        )
                    }
                    .map_err(|e| Error::msg(format!("cuBLAS sgemm: {e}")))?;
                }
                Ok(CudaStorage::F32(out))
            }
            (CudaStorage::F64(a), CudaStorage::F64(b)) => {
                let mn = m * n;
                let total = batch_size * mn;
                let out: CudaSlice<f64> = dev
                    .dev
                    .alloc_zeros(total)
                    .map_err(|e| Error::msg(format!("alloc matmul: {e}")))?;

                use cudarc::cublas::sys::cublasOperation_t;

                for batch in 0..batch_size {
                    let a_offset = batch * m * k;
                    let b_offset = batch * k * n;
                    let c_offset = batch * mn;

                    let a_slice = a.slice(a_offset..a_offset + m * k);
                    let b_slice = b.slice(b_offset..b_offset + k * n);
                    let c_slice = out.slice(c_offset..c_offset + mn);

                    unsafe {
                        cudarc::cublas::result::dgemm(
                            *dev.blas.handle(),
                            cublasOperation_t::CUBLAS_OP_N,
                            cublasOperation_t::CUBLAS_OP_N,
                            n as i32,
                            m as i32,
                            k as i32,
                            (&1.0f64) as *const f64,
                            *b_slice.device_ptr() as *const f64,
                            n as i32,
                            *a_slice.device_ptr() as *const f64,
                            k as i32,
                            (&0.0f64) as *const f64,
                            *c_slice.device_ptr() as *mut f64,
                            n as i32,
                        )
                    }
                    .map_err(|e| Error::msg(format!("cuBLAS dgemm: {e}")))?;
                }
                Ok(CudaStorage::F64(out))
            }
            _ => Err(Error::msg("matmul: only f16/bf16/f32/f64 supported")),
        }
    }

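    // Public entry point for layout materialization; defers to `ensure_contiguous`.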
    fn to_contiguous(input: &CudaStorage, layout: &Layout) -> Result<CudaStorage> {
        let dev = dev_from_storage(input)?;
        ensure_contiguous(input, layout, &dev)
    }

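    // Copies the (contiguous) data back to the host and widens every element
    // to f64, whatever the source dtype.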
    fn to_f64_vec(input: &CudaStorage, layout: &Layout) -> Result<Vec<f64>> {
        let dev = dev_from_storage(input)?;

        let input_c = ensure_contiguous(input, layout, &dev)?;

        match &input_c {
            CudaStorage::F16(s) => {
                let host = dev
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh f16: {e}")))?;
                Ok(host
                    .iter()
                    .map(|&bits| f16::from_bits(bits).to_f64())
                    .collect())
            }
            CudaStorage::BF16(s) => {
                let host = dev
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh bf16: {e}")))?;
                Ok(host
                    .iter()
                    .map(|&bits| bf16::from_bits(bits).to_f64())
                    .collect())
            }
            CudaStorage::F32(s) => {
                let host = dev
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh f32: {e}")))?;
                Ok(host.iter().map(|&v| v as f64).collect())
            }
            CudaStorage::F64(s) => {
                let host = dev
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh f64: {e}")))?;
                Ok(host)
            }
            CudaStorage::U8(s) => {
                let host = dev
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh u8: {e}")))?;
                Ok(host.iter().map(|&v| v as f64).collect())
            }
            CudaStorage::U32(s) => {
                let host = dev
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh u32: {e}")))?;
                Ok(host.iter().map(|&v| v as f64).collect())
            }
            CudaStorage::I64(s) => {
                let host = dev
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh i64: {e}")))?;
                Ok(host.iter().map(|&v| v as f64).collect())
            }
        }
    }

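    // Elementwise comparison of two same-dtype float tensors, producing a u8 mask.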
    fn cmp_op(
        op: CmpOp,
        lhs: &CudaStorage,
        lhs_layout: &Layout,
        rhs: &CudaStorage,
        rhs_layout: &Layout,
    ) -> Result<CudaStorage> {
        let dev = dev_from_storage(lhs)?;

        let lhs_c = ensure_contiguous(lhs, lhs_layout, &dev)?;
        let rhs_c = ensure_contiguous(rhs, rhs_layout, &dev)?;
        let n = lhs_layout.elem_count();
        let cfg = launch_cfg(n);

        let op_name = match op {
            CmpOp::Eq => "eq",
            CmpOp::Ne => "ne",
            CmpOp::Gt => "gt",
            CmpOp::Ge => "ge",
            CmpOp::Lt => "lt",
            CmpOp::Le => "le",
        };

        let mut out: CudaSlice<u8> = dev
            .dev
            .alloc_zeros(n)
            .map_err(|e| Error::msg(format!("alloc: {e}")))?;

        match (&lhs_c, &rhs_c) {
            (CudaStorage::F16(a), CudaStorage::F16(b)) => {
                let func = dev.get_func(&format!("cmp_{op_name}_f16"))?;
                unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("cmp: {e}")))?;
            }
            (CudaStorage::BF16(a), CudaStorage::BF16(b)) => {
                let func = dev.get_func(&format!("cmp_{op_name}_bf16"))?;
                unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("cmp: {e}")))?;
            }
            (CudaStorage::F32(a), CudaStorage::F32(b)) => {
                let func = dev.get_func(&format!("cmp_{op_name}_f32"))?;
                unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("cmp: {e}")))?;
            }
            (CudaStorage::F64(a), CudaStorage::F64(b)) => {
                let func = dev.get_func(&format!("cmp_{op_name}_f64"))?;
                unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("cmp: {e}")))?;
            }
            _ => return Err(Error::msg("cmp_op: dtype mismatch or unsupported")),
        }

        Ok(CudaStorage::U8(out))
    }

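    // Computes `input * mul + add` elementwise; half-precision kernels receive
    // the scalars as f32.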
    fn affine(input: &CudaStorage, layout: &Layout, mul: f64, add: f64) -> Result<CudaStorage> {
        let dev = dev_from_storage(input)?;

        let input_c = ensure_contiguous(input, layout, &dev)?;
        let n = layout.elem_count();
        let cfg = launch_cfg(n);

        match &input_c {
            CudaStorage::F16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("affine_f16")?;
                unsafe { func.launch(cfg, (inp, &mut out, mul as f32, add as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("affine: {e}")))?;
                Ok(CudaStorage::F16(out))
            }
            CudaStorage::BF16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("affine_bf16")?;
                unsafe { func.launch(cfg, (inp, &mut out, mul as f32, add as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("affine: {e}")))?;
                Ok(CudaStorage::BF16(out))
            }
            CudaStorage::F32(inp) => {
                let mut out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("affine_f32")?;
                unsafe { func.launch(cfg, (inp, &mut out, mul as f32, add as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("affine: {e}")))?;
                Ok(CudaStorage::F32(out))
            }
            CudaStorage::F64(inp) => {
                let mut out: CudaSlice<f64> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("affine_f64")?;
                unsafe { func.launch(cfg, (inp, &mut out, mul, add, n as u32)) }
                    .map_err(|e| Error::msg(format!("affine: {e}")))?;
                Ok(CudaStorage::F64(out))
            }
            _ => Err(Error::msg("affine: only float types supported")),
        }
    }

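    // Selects `idx_len` slices along `dim`. U32 indices are converted to i64 by
    // a host round-trip; the kernels index into the `pre x src x post` view of
    // the input.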
    fn index_select(
        input: &CudaStorage,
        input_layout: &Layout,
        indices: &CudaStorage,
        indices_layout: &Layout,
        dim: usize,
    ) -> Result<CudaStorage> {
        let dev = dev_from_storage(input)?;

        let input_c = ensure_contiguous(input, input_layout, &dev)?;
        let indices_c = ensure_contiguous(indices, indices_layout, &dev)?;

        let input_dims = input_layout.dims();

        let pre_dim: usize = input_dims[..dim].iter().product::<usize>().max(1);
        let src_dim = input_dims[dim];
        let post_dim: usize = input_dims[dim + 1..].iter().product::<usize>().max(1);
        let idx_len = indices_layout.elem_count();
        let out_n = pre_dim * idx_len * post_dim;
        let cfg = launch_cfg(out_n);

        let idx_i64 = match &indices_c {
            CudaStorage::I64(s) => s.clone(),
            CudaStorage::U32(s) => {
                let host = dev
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
                let host_i64: Vec<i64> = host.iter().map(|&v| v as i64).collect();
                dev.dev
                    .htod_copy(host_i64)
                    .map_err(|e| Error::msg(format!("htod: {e}")))?
            }
            _ => return Err(Error::msg("index_select: indices must be integer type")),
        };

        match &input_c {
            CudaStorage::F16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(out_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("index_select_f16")?;
                unsafe {
                    func.launch(
                        cfg,
                        (
                            inp,
                            &idx_i64,
                            &mut out,
                            pre_dim as u32,
                            src_dim as u32,
                            post_dim as u32,
                            idx_len as u32,
                            out_n as u32,
                        ),
                    )
                }
                .map_err(|e| Error::msg(format!("index_select: {e}")))?;
                Ok(CudaStorage::F16(out))
            }
            CudaStorage::BF16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(out_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("index_select_bf16")?;
                unsafe {
                    func.launch(
                        cfg,
                        (
                            inp,
                            &idx_i64,
                            &mut out,
                            pre_dim as u32,
                            src_dim as u32,
                            post_dim as u32,
                            idx_len as u32,
                            out_n as u32,
                        ),
                    )
                }
                .map_err(|e| Error::msg(format!("index_select: {e}")))?;
                Ok(CudaStorage::BF16(out))
            }
            CudaStorage::F32(inp) => {
                let mut out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(out_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("index_select_f32")?;
                unsafe {
                    func.launch(
                        cfg,
                        (
                            inp,
                            &idx_i64,
                            &mut out,
                            pre_dim as u32,
                            src_dim as u32,
                            post_dim as u32,
                            idx_len as u32,
                            out_n as u32,
                        ),
                    )
                }
                .map_err(|e| Error::msg(format!("index_select: {e}")))?;
                Ok(CudaStorage::F32(out))
            }
            CudaStorage::F64(inp) => {
                let mut out: CudaSlice<f64> = dev
                    .dev
                    .alloc_zeros(out_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("index_select_f64")?;
                unsafe {
                    func.launch(
                        cfg,
                        (
                            inp,
                            &idx_i64,
                            &mut out,
                            pre_dim as u32,
                            src_dim as u32,
                            post_dim as u32,
                            idx_len as u32,
                            out_n as u32,
                        ),
                    )
                }
                .map_err(|e| Error::msg(format!("index_select: {e}")))?;
                Ok(CudaStorage::F64(out))
            }
            _ => Err(Error::msg("index_select: only float types supported")),
        }
    }

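    // Raises each element to the power `exponent` via the `powf_*` kernels.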
    fn powf(input: &CudaStorage, layout: &Layout, exponent: f64) -> Result<CudaStorage> {
        let dev = dev_from_storage(input)?;

        let input_c = ensure_contiguous(input, layout, &dev)?;
        let n = layout.elem_count();
        let cfg = launch_cfg(n);

        match &input_c {
            CudaStorage::F16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("powf_f16")?;
                unsafe { func.launch(cfg, (inp, &mut out, exponent as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("powf: {e}")))?;
                Ok(CudaStorage::F16(out))
            }
            CudaStorage::BF16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("powf_bf16")?;
                unsafe { func.launch(cfg, (inp, &mut out, exponent as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("powf: {e}")))?;
                Ok(CudaStorage::BF16(out))
            }
            CudaStorage::F32(inp) => {
                let mut out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("powf_f32")?;
                unsafe { func.launch(cfg, (inp, &mut out, exponent as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("powf: {e}")))?;
                Ok(CudaStorage::F32(out))
            }
            CudaStorage::F64(inp) => {
                let mut out: CudaSlice<f64> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("powf_f64")?;
                unsafe { func.launch(cfg, (inp, &mut out, exponent, n as u32)) }
                    .map_err(|e| Error::msg(format!("powf: {e}")))?;
                Ok(CudaStorage::F64(out))
            }
            _ => Err(Error::msg("powf: only float types supported")),
        }
    }

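    // Clamps each element to the `[min, max]` range via the `clamp_*` kernels.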
    fn clamp(input: &CudaStorage, layout: &Layout, min: f64, max: f64) -> Result<CudaStorage> {
        let dev = dev_from_storage(input)?;

        let input_c = ensure_contiguous(input, layout, &dev)?;
        let n = layout.elem_count();
        let cfg = launch_cfg(n);

        match &input_c {
            CudaStorage::F16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("clamp_f16")?;
                unsafe { func.launch(cfg, (inp, &mut out, min as f32, max as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("clamp: {e}")))?;
                Ok(CudaStorage::F16(out))
            }
            CudaStorage::BF16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("clamp_bf16")?;
                unsafe { func.launch(cfg, (inp, &mut out, min as f32, max as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("clamp: {e}")))?;
                Ok(CudaStorage::BF16(out))
            }
            CudaStorage::F32(inp) => {
                let mut out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("clamp_f32")?;
                unsafe { func.launch(cfg, (inp, &mut out, min as f32, max as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("clamp: {e}")))?;
                Ok(CudaStorage::F32(out))
            }
            CudaStorage::F64(inp) => {
                let mut out: CudaSlice<f64> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("clamp_f64")?;
                unsafe { func.launch(cfg, (inp, &mut out, min, max, n as u32)) }
                    .map_err(|e| Error::msg(format!("clamp: {e}")))?;
                Ok(CudaStorage::F64(out))
            }
            _ => Err(Error::msg("clamp: only float types supported")),
        }
    }

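    // Elementwise select driven by a u8 mask; `on_true` and `on_false` must
    // share the same float dtype.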
    fn where_cond(
        mask: &CudaStorage,
        mask_layout: &Layout,
        on_true: &CudaStorage,
        on_true_layout: &Layout,
        on_false: &CudaStorage,
        on_false_layout: &Layout,
    ) -> Result<CudaStorage> {
        let dev = dev_from_storage(mask)?;

        let mask_c = ensure_contiguous(mask, mask_layout, &dev)?;
        let true_c = ensure_contiguous(on_true, on_true_layout, &dev)?;
        let false_c = ensure_contiguous(on_false, on_false_layout, &dev)?;
        let n = mask_layout.elem_count();
        let cfg = launch_cfg(n);

        let mask_u8 = match &mask_c {
            CudaStorage::U8(s) => s,
            _ => return Err(Error::msg("where_cond: mask must be u8")),
        };

        match (&true_c, &false_c) {
            (CudaStorage::F16(t), CudaStorage::F16(f_vals)) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("where_cond_f16")?;
                unsafe { func.launch(cfg, (mask_u8, t, f_vals, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("where_cond: {e}")))?;
                Ok(CudaStorage::F16(out))
            }
            (CudaStorage::BF16(t), CudaStorage::BF16(f_vals)) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("where_cond_bf16")?;
                unsafe { func.launch(cfg, (mask_u8, t, f_vals, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("where_cond: {e}")))?;
                Ok(CudaStorage::BF16(out))
            }
            (CudaStorage::F32(t), CudaStorage::F32(f_vals)) => {
                let mut out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("where_cond_f32")?;
                unsafe { func.launch(cfg, (mask_u8, t, f_vals, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("where_cond: {e}")))?;
                Ok(CudaStorage::F32(out))
            }
            (CudaStorage::F64(t), CudaStorage::F64(f_vals)) => {
                let mut out: CudaSlice<f64> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("where_cond_f64")?;
                unsafe { func.launch(cfg, (mask_u8, t, f_vals, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("where_cond: {e}")))?;
                Ok(CudaStorage::F64(out))
            }
            _ => Err(Error::msg("where_cond: dtype mismatch")),
        }
    }

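    // Gathers values along `dim` using `index`. Non-I64 indices (including
    // float indices) are converted to i64 through a host round-trip before the
    // `gather_*` kernel is launched.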
    fn gather(
        input: &CudaStorage,
        input_layout: &Layout,
        index: &CudaStorage,
        index_layout: &Layout,
        dim: usize,
    ) -> Result<CudaStorage> {
        let dev = dev_from_storage(input)?;

        let input_c = ensure_contiguous(input, input_layout, &dev)?;
        let index_c = ensure_contiguous(index, index_layout, &dev)?;

        let input_dims = input_layout.dims();
        let index_dims = index_layout.dims();

        let pre: usize = input_dims[..dim].iter().product::<usize>().max(1);
        let inp_dim = input_dims[dim];
        let idx_dim = index_dims[dim];
        let post: usize = input_dims[dim + 1..].iter().product::<usize>().max(1);
        let n = index_layout.elem_count();
        let cfg = launch_cfg(n);

        let idx_i64 = match &index_c {
            CudaStorage::I64(s) => s.clone(),
            CudaStorage::U32(s) => {
                let host = dev
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
                let host_i64: Vec<i64> = host.iter().map(|&v| v as i64).collect();
                dev.dev
                    .htod_copy(host_i64)
                    .map_err(|e| Error::msg(format!("htod: {e}")))?
            }
            CudaStorage::F32(s) => {
                let host = dev
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
                let host_i64: Vec<i64> = host.iter().map(|&v| v as i64).collect();
                dev.dev
                    .htod_copy(host_i64)
                    .map_err(|e| Error::msg(format!("htod: {e}")))?
            }
            CudaStorage::F64(s) => {
                let host = dev
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
                let host_i64: Vec<i64> = host.iter().map(|&v| v as i64).collect();
                dev.dev
                    .htod_copy(host_i64)
                    .map_err(|e| Error::msg(format!("htod: {e}")))?
            }
            _ => return Err(Error::msg("gather: unsupported index dtype")),
        };

        match &input_c {
            CudaStorage::F16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("gather_f16")?;
                unsafe {
                    func.launch(
                        cfg,
                        (
                            inp,
                            &idx_i64,
                            &mut out,
                            pre as u32,
                            inp_dim as u32,
                            idx_dim as u32,
                            post as u32,
                            n as u32,
                        ),
                    )
                }
                .map_err(|e| Error::msg(format!("gather: {e}")))?;
                Ok(CudaStorage::F16(out))
            }
            CudaStorage::BF16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("gather_bf16")?;
                unsafe {
                    func.launch(
                        cfg,
                        (
                            inp,
                            &idx_i64,
                            &mut out,
                            pre as u32,
                            inp_dim as u32,
                            idx_dim as u32,
                            post as u32,
                            n as u32,
                        ),
                    )
                }
                .map_err(|e| Error::msg(format!("gather: {e}")))?;
                Ok(CudaStorage::BF16(out))
            }
            CudaStorage::F32(inp) => {
                let mut out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("gather_f32")?;
                unsafe {
                    func.launch(
                        cfg,
                        (
                            inp,
                            &idx_i64,
                            &mut out,
                            pre as u32,
                            inp_dim as u32,
                            idx_dim as u32,
                            post as u32,
                            n as u32,
                        ),
                    )
                }
                .map_err(|e| Error::msg(format!("gather: {e}")))?;
                Ok(CudaStorage::F32(out))
            }
            CudaStorage::F64(inp) => {
                let mut out: CudaSlice<f64> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("gather_f64")?;
                unsafe {
                    func.launch(
                        cfg,
                        (
                            inp,
                            &idx_i64,
                            &mut out,
                            pre as u32,
                            inp_dim as u32,
                            idx_dim as u32,
                            post as u32,
                            n as u32,
                        ),
                    )
                }
                .map_err(|e| Error::msg(format!("gather: {e}")))?;
                Ok(CudaStorage::F64(out))
            }
            _ => Err(Error::msg("gather: only float types supported")),
        }
    }

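    /// Concatenate `inputs` along `dim` into a tensor of shape `out_shape`. Each input
    /// is made contiguous and copied into the preallocated output by a `cat_copy_*`
    /// kernel; `dim_offset` tracks where the next input starts along `dim`.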
    fn cat(
        inputs: &[(&CudaStorage, &Layout)],
        out_shape: &Shape,
        dim: usize,
    ) -> Result<CudaStorage> {
        if inputs.is_empty() {
            return Err(Error::msg("cat: empty input list"));
        }

        let dev = dev_from_storage(inputs[0].0)?;

        let out_dims = out_shape.dims();
        let out_n = out_shape.elem_count();

        let outer: usize = out_dims[..dim].iter().product::<usize>().max(1);
        let total_dim = out_dims[dim];
        let inner: usize = out_dims[dim + 1..].iter().product::<usize>().max(1);

        match inputs[0].0 {
            CudaStorage::F16(_) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(out_n)
                    .map_err(|e| Error::msg(format!("alloc cat: {e}")))?;
                let mut dim_offset = 0u32;

                for &(storage, layout) in inputs {
                    let storage_c = ensure_contiguous(storage, layout, &dev)?;
                    let inp = match &storage_c {
                        CudaStorage::F16(s) => s,
                        _ => return Err(Error::msg("cat: dtype mismatch")),
                    };
                    let t_dims = layout.dims();
                    let this_dim = t_dims[dim];
                    let src_n = layout.elem_count();
                    let cfg = launch_cfg(src_n);

                    let func = dev.get_func("cat_copy_f16")?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer as u32,
                                this_dim as u32,
                                inner as u32,
                                total_dim as u32,
                                dim_offset,
                                src_n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("cat_copy: {e}")))?;

                    dim_offset += this_dim as u32;
                }
                Ok(CudaStorage::F16(out))
            }
            CudaStorage::BF16(_) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(out_n)
                    .map_err(|e| Error::msg(format!("alloc cat: {e}")))?;
                let mut dim_offset = 0u32;

                for &(storage, layout) in inputs {
                    let storage_c = ensure_contiguous(storage, layout, &dev)?;
                    let inp = match &storage_c {
                        CudaStorage::BF16(s) => s,
                        _ => return Err(Error::msg("cat: dtype mismatch")),
                    };
                    let t_dims = layout.dims();
                    let this_dim = t_dims[dim];
                    let src_n = layout.elem_count();
                    let cfg = launch_cfg(src_n);

                    let func = dev.get_func("cat_copy_bf16")?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer as u32,
                                this_dim as u32,
                                inner as u32,
                                total_dim as u32,
                                dim_offset,
                                src_n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("cat_copy: {e}")))?;

                    dim_offset += this_dim as u32;
                }
                Ok(CudaStorage::BF16(out))
            }
            CudaStorage::F32(_) => {
                let mut out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(out_n)
                    .map_err(|e| Error::msg(format!("alloc cat: {e}")))?;
                let mut dim_offset = 0u32;

                for &(storage, layout) in inputs {
                    let storage_c = ensure_contiguous(storage, layout, &dev)?;
                    let inp = match &storage_c {
                        CudaStorage::F32(s) => s,
                        _ => return Err(Error::msg("cat: dtype mismatch")),
                    };
                    let t_dims = layout.dims();
                    let this_dim = t_dims[dim];
                    let src_n = layout.elem_count();
                    let cfg = launch_cfg(src_n);

                    let func = dev.get_func("cat_copy_f32")?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer as u32,
                                this_dim as u32,
                                inner as u32,
                                total_dim as u32,
                                dim_offset,
                                src_n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("cat_copy: {e}")))?;

                    dim_offset += this_dim as u32;
                }
                Ok(CudaStorage::F32(out))
            }
            CudaStorage::F64(_) => {
                let mut out: CudaSlice<f64> = dev
                    .dev
                    .alloc_zeros(out_n)
                    .map_err(|e| Error::msg(format!("alloc cat: {e}")))?;
                let mut dim_offset = 0u32;

                for &(storage, layout) in inputs {
                    let storage_c = ensure_contiguous(storage, layout, &dev)?;
                    let inp = match &storage_c {
                        CudaStorage::F64(s) => s,
                        _ => return Err(Error::msg("cat: dtype mismatch")),
                    };
                    let t_dims = layout.dims();
                    let this_dim = t_dims[dim];
                    let src_n = layout.elem_count();
                    let cfg = launch_cfg(src_n);

                    let func = dev.get_func("cat_copy_f64")?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer as u32,
                                this_dim as u32,
                                inner as u32,
                                total_dim as u32,
                                dim_offset,
                                src_n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("cat_copy: {e}")))?;

                    dim_offset += this_dim as u32;
                }
                Ok(CudaStorage::F64(out))
            }
            _ => Err(Error::msg("cat: only float types supported")),
        }
    }

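    /// Convert `input` to `dtype`. The f16/bf16 <-> f32 pairs use dedicated cast
    /// kernels; the remaining float pairs are routed through an intermediate f32
    /// tensor, and anything else falls back to a host round-trip via `to_f64_vec`.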
    fn cast(
        input: &CudaStorage,
        layout: &Layout,
        dtype: DType,
        device: &CudaDevice,
    ) -> Result<CudaStorage> {
        let src_dtype = input.dtype();
        if src_dtype == dtype {
            return Ok(input.clone());
        }

        let dev = dev_from_storage(input)?;
        let contig = ensure_contiguous(input, layout, &dev)?;
        let n = layout.shape().elem_count();
        let cfg = launch_cfg(n);

        match (src_dtype, dtype) {
            (DType::F16, DType::F32) => {
                let src_slice = contig.as_cuda_slice_u16()?;
                let out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("cast_f16_to_f32")?;
                unsafe { func.launch(cfg, (src_slice, &out, n as u32)) }
                    .map_err(|e| Error::msg(format!("launch cast_f16_to_f32: {e}")))?;
                Ok(CudaStorage::F32(out))
            }
            (DType::F32, DType::F16) => {
                let src_slice = contig.as_cuda_slice_f32()?;
                let out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("cast_f32_to_f16")?;
                unsafe { func.launch(cfg, (src_slice, &out, n as u32)) }
                    .map_err(|e| Error::msg(format!("launch cast_f32_to_f16: {e}")))?;
                Ok(CudaStorage::F16(out))
            }
            (DType::BF16, DType::F32) => {
                let src_slice = contig.as_cuda_slice_u16()?;
                let out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("cast_bf16_to_f32")?;
                unsafe { func.launch(cfg, (src_slice, &out, n as u32)) }
                    .map_err(|e| Error::msg(format!("launch cast_bf16_to_f32: {e}")))?;
                Ok(CudaStorage::F32(out))
            }
            (DType::F32, DType::BF16) => {
                let src_slice = contig.as_cuda_slice_f32()?;
                let out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func("cast_f32_to_bf16")?;
                unsafe { func.launch(cfg, (src_slice, &out, n as u32)) }
                    .map_err(|e| Error::msg(format!("launch cast_f32_to_bf16: {e}")))?;
                Ok(CudaStorage::BF16(out))
            }
            (DType::F16, DType::F64) | (DType::BF16, DType::F64) => {
                let f32_storage = Self::cast(
                    &contig,
                    &Layout::contiguous(layout.shape().clone()),
                    DType::F32,
                    device,
                )?;
                let f32_layout = Layout::contiguous(layout.shape().clone());
                Self::cast(&f32_storage, &f32_layout, DType::F64, device)
            }
            (DType::F64, DType::F16) | (DType::F64, DType::BF16) => {
                let f32_storage = Self::cast(
                    &contig,
                    &Layout::contiguous(layout.shape().clone()),
                    DType::F32,
                    device,
                )?;
                let f32_layout = Layout::contiguous(layout.shape().clone());
                Self::cast(&f32_storage, &f32_layout, dtype, device)
            }
            (DType::F16, DType::BF16) | (DType::BF16, DType::F16) => {
                let f32_storage = Self::cast(
                    &contig,
                    &Layout::contiguous(layout.shape().clone()),
                    DType::F32,
                    device,
                )?;
                let f32_layout = Layout::contiguous(layout.shape().clone());
                Self::cast(&f32_storage, &f32_layout, dtype, device)
            }
            _ => {
                let data = Self::to_f64_vec(&contig, &Layout::contiguous(layout.shape().clone()))?;
                Self::from_f64_slice(&data, dtype, device)
            }
        }
    }
}

impl CudaStorage {
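    /// Borrow the underlying `CudaSlice<f32>`, failing if the storage holds another dtype.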
    pub fn as_cuda_slice_f32(&self) -> Result<&CudaSlice<f32>> {
        match self {
            CudaStorage::F32(s) => Ok(s),
            _ => Err(Error::msg(format!(
                "expected F32 storage, got {:?}",
                self.dtype()
            ))),
        }
    }

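    /// Borrow the raw `u16` bit storage backing an F16 or BF16 tensor.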
    pub fn as_cuda_slice_u16(&self) -> Result<&CudaSlice<u16>> {
        match self {
            CudaStorage::F16(s) | CudaStorage::BF16(s) => Ok(s),
            _ => Err(Error::msg(format!(
                "expected F16/BF16 storage, got {:?}",
                self.dtype()
            ))),
        }
    }

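    /// Upload a host `Vec<f32>` to the device and wrap it as F32 storage.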
    pub fn from_f32_vec(data: Vec<f32>, device: &CudaDevice) -> Result<Self> {
        let s = device
            .dev
            .htod_copy(data)
            .map_err(|e| Error::msg(format!("htod f32: {e}")))?;
        Ok(CudaStorage::F32(s))
    }

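    /// Upload a host `Vec<f64>` to the device and wrap it as F64 storage.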
    pub fn from_f64_vec(data: Vec<f64>, device: &CudaDevice) -> Result<Self> {
        let s = device
            .dev
            .htod_copy(data)
            .map_err(|e| Error::msg(format!("htod f64: {e}")))?;
        Ok(CudaStorage::F64(s))
    }

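    /// Upload a host `Vec<f16>` to the device, storing the raw bit pattern as `u16`.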
    pub fn from_f16_vec(data: Vec<f16>, device: &CudaDevice) -> Result<Self> {
        let bits: Vec<u16> = data.iter().map(|v| v.to_bits()).collect();
        let s = device
            .dev
            .htod_copy(bits)
            .map_err(|e| Error::msg(format!("htod f16: {e}")))?;
        Ok(CudaStorage::F16(s))
    }

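    /// Upload a host `Vec<bf16>` to the device, storing the raw bit pattern as `u16`.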
    pub fn from_bf16_vec(data: Vec<bf16>, device: &CudaDevice) -> Result<Self> {
        let bits: Vec<u16> = data.iter().map(|v| v.to_bits()).collect();
        let s = device
            .dev
            .htod_copy(bits)
            .map_err(|e| Error::msg(format!("htod bf16: {e}")))?;
        Ok(CudaStorage::BF16(s))
    }

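    /// Copy the storage back to the host, widening every supported dtype to `f64`.
    ///
    /// A minimal round-trip sketch (illustrative only, not a doctest):
    ///
    /// ```ignore
    /// let dev = CudaDevice::new(0)?;
    /// let s = CudaStorage::from_f32_vec(vec![1.0, 2.0, 3.0], &dev)?;
    /// assert_eq!(s.to_host_f64(&dev)?, vec![1.0, 2.0, 3.0]);
    /// ```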
    pub fn to_host_f64(&self, device: &CudaDevice) -> Result<Vec<f64>> {
        match self {
            CudaStorage::F16(s) => {
                let host = device
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
                Ok(host
                    .iter()
                    .map(|&bits| f16::from_bits(bits).to_f64())
                    .collect())
            }
            CudaStorage::BF16(s) => {
                let host = device
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
                Ok(host
                    .iter()
                    .map(|&bits| bf16::from_bits(bits).to_f64())
                    .collect())
            }
            CudaStorage::F32(s) => {
                let host = device
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
                Ok(host.iter().map(|&v| v as f64).collect())
            }
            CudaStorage::F64(s) => device
                .dev
                .dtoh_sync_copy(s)
                .map_err(|e| Error::msg(format!("dtoh: {e}"))),
            CudaStorage::U8(s) => {
                let host = device
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
                Ok(host.iter().map(|&v| v as f64).collect())
            }
            CudaStorage::U32(s) => {
                let host = device
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
                Ok(host.iter().map(|&v| v as f64).collect())
            }
            CudaStorage::I64(s) => {
                let host = device
                    .dev
                    .dtoh_sync_copy(s)
                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
                Ok(host.iter().map(|&v| v as f64).collect())
            }
        }
    }
}

pub type CudaTensor = shrew_core::Tensor<CudaBackend>;