shrew_cuda/lib.rs

// CUDA Backend — GPU-accelerated compute backend using cudarc
//
// This module provides a full CUDA implementation of the Shrew Backend trait.
// All tensor operations run on NVIDIA GPUs via custom CUDA kernels (compiled
// at device creation via NVRTC) and cuBLAS for matrix multiplication.
//
// ARCHITECTURE:
// - CudaDevice wraps cudarc's device handle + cuBLAS handle
// - CudaStorage is an enum over CudaSlice<T> for each supported dtype
// - All kernels operate on contiguous data; non-contiguous inputs are
//   first copied to contiguous layout using a strided-copy kernel
// - Random number generation happens on host and is transferred to device
// - F16 and BF16 are stored as CudaSlice<u16> and computed via promote-to-F32
//   CUDA kernels (portable across all GPU architectures)
//
// USAGE:
//   let device = CudaDevice::new(0)?;  // GPU ordinal 0
//   let tensor = Tensor::<CudaBackend>::zeros(&[2, 3], DType::F32, &device)?;
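//   device.empty_cache();             // optional: release cached pool memory when done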

mod kernels;
pub mod pool;

use cudarc::cublas::CudaBlas;
use cudarc::driver::{CudaSlice, DevicePtr, DeviceSlice, LaunchAsync, LaunchConfig};
use cudarc::nvrtc::{compile_ptx_with_opts, CompileOptions};
use half::{bf16, f16};
use pool::CudaMemPool;
use std::fmt;
use std::sync::Arc;

use shrew_core::backend::{
    Backend, BackendDevice, BackendStorage, BinaryOp, CmpOp, ReduceOp, UnaryOp,
};
use shrew_core::dtype::DType;
use shrew_core::error::{Error, Result};
use shrew_core::layout::Layout;
use shrew_core::shape::Shape;

// CudaDevice — Wraps a cudarc CUDA device + cuBLAS handle

/// A CUDA device handle. Contains the cudarc device and a cuBLAS handle
/// for matrix multiplication. Cloneable (uses Arc internally).
pub struct CudaDevice {
    dev: Arc<cudarc::driver::CudaDevice>,
    blas: Arc<CudaBlas>,
    pool: Arc<CudaMemPool>,
    ordinal: usize,
}

impl CudaDevice {
    /// Create a new CUDA device for the given GPU ordinal (0, 1, ...).
    /// Compiles all Shrew CUDA kernels on first creation.
    pub fn new(ordinal: usize) -> Result<Self> {
        let dev = cudarc::driver::CudaDevice::new(ordinal)
            .map_err(|e| Error::msg(format!("CUDA device creation failed: {e}")))?;

        let blas = CudaBlas::new(dev.clone())
            .map_err(|e| Error::msg(format!("cuBLAS init failed: {e}")))?;

        // Compile and load all kernels.
        // Query the device compute capability and target it with NVRTC.
        // Use sm_XX (native SASS) instead of compute_XX (PTX) to avoid
        // PTX version mismatches between toolkit and driver versions.
        let major = dev
            .attribute(cudarc::driver::sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
            .unwrap_or(8);
        let minor = dev
            .attribute(cudarc::driver::sys::CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
            .unwrap_or(9);
        // NVRTC's options want a &'static str for the arch; leaking one short
        // string per device creation is harmless.
        let arch_str: &'static str = Box::leak(format!("sm_{major}{minor}").into_boxed_str());
        let opts = CompileOptions {
            arch: Some(arch_str),
            ..Default::default()
        };
        let ptx = compile_ptx_with_opts(kernels::KERNEL_SOURCE, opts)
            .map_err(|e| Error::msg(format!("NVRTC compilation failed: {e}")))?;
        dev.load_ptx(ptx, kernels::MODULE_NAME, kernels::KERNEL_NAMES)
            .map_err(|e| Error::msg(format!("PTX load failed: {e}")))?;

        Ok(CudaDevice {
            dev,
            blas: Arc::new(blas),
            pool: Arc::new(CudaMemPool::new()),
            ordinal,
        })
    }

    /// Get the underlying cudarc device handle.
    pub fn device(&self) -> &Arc<cudarc::driver::CudaDevice> {
        &self.dev
    }

    /// Get the cuBLAS handle.
    pub fn blas(&self) -> &CudaBlas {
        &self.blas
    }

    /// Get a compiled kernel function by name.
    fn get_func(&self, name: &str) -> Result<cudarc::driver::CudaFunction> {
        self.dev
            .get_func(kernels::MODULE_NAME, name)
            .ok_or_else(|| Error::msg(format!("CUDA kernel '{name}' not found")))
    }

    // ── Memory pool helpers ──────────────────────────────────────────────

    /// Get the memory pool.
    pub fn pool(&self) -> &CudaMemPool {
        &self.pool
    }

    /// Release all cached GPU memory back to the CUDA driver.
    pub fn empty_cache(&self) {
        self.pool.empty_cache();
    }

    /// Return pool statistics (cached bytes, hits, misses, etc.).
    pub fn pool_stats(&self) -> pool::PoolStats {
        self.pool.stats()
    }

    /// Reclaim a CudaStorage buffer into the pool for future reuse.
    pub fn reclaim(&self, storage: CudaStorage) {
        self.pool.reclaim_storage(storage);
    }

    // ── Pool-aware allocation helpers ────────────────────────────────────

    /// Allocate `n` elements from the pool (content undefined).
    pub fn pool_alloc_f32(&self, n: usize) -> Result<CudaSlice<f32>> {
        self.pool
            .alloc_f32(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc f32: {e}")))
    }
    pub fn pool_alloc_f64(&self, n: usize) -> Result<CudaSlice<f64>> {
        self.pool
            .alloc_f64(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc f64: {e}")))
    }
    pub fn pool_alloc_u16(&self, n: usize) -> Result<CudaSlice<u16>> {
        self.pool
            .alloc_u16(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc u16: {e}")))
    }
    pub fn pool_alloc_u8(&self, n: usize) -> Result<CudaSlice<u8>> {
        self.pool
            .alloc_u8(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc u8: {e}")))
    }
    pub fn pool_alloc_u32(&self, n: usize) -> Result<CudaSlice<u32>> {
        self.pool
            .alloc_u32(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc u32: {e}")))
    }
    pub fn pool_alloc_i64(&self, n: usize) -> Result<CudaSlice<i64>> {
        self.pool
            .alloc_i64(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc i64: {e}")))
    }

    /// Allocate `n` elements from the pool, zeroed.
    pub fn pool_alloc_zeros_f32(&self, n: usize) -> Result<CudaSlice<f32>> {
        self.pool
            .alloc_zeros_f32(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros f32: {e}")))
    }
    pub fn pool_alloc_zeros_f64(&self, n: usize) -> Result<CudaSlice<f64>> {
        self.pool
            .alloc_zeros_f64(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros f64: {e}")))
    }
    pub fn pool_alloc_zeros_u16(&self, n: usize) -> Result<CudaSlice<u16>> {
        self.pool
            .alloc_zeros_u16(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros u16: {e}")))
    }
    pub fn pool_alloc_zeros_u8(&self, n: usize) -> Result<CudaSlice<u8>> {
        self.pool
            .alloc_zeros_u8(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros u8: {e}")))
    }
    pub fn pool_alloc_zeros_u32(&self, n: usize) -> Result<CudaSlice<u32>> {
        self.pool
            .alloc_zeros_u32(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros u32: {e}")))
    }
    pub fn pool_alloc_zeros_i64(&self, n: usize) -> Result<CudaSlice<i64>> {
        self.pool
            .alloc_zeros_i64(&self.dev, n)
            .map_err(|e| Error::msg(format!("pool alloc zeros i64: {e}")))
    }
}
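
// Pool usage sketch (illustrative only, built from this module's own API;
// `device` is a CudaDevice from `CudaDevice::new`):
//   let buf = device.pool_alloc_f32(1 << 20)?;   // may reuse a cached block
//   // ... launch kernels that read/write `buf` ...
//   device.reclaim(CudaStorage::F32(buf));       // return the block for reuse
//   device.empty_cache();                        // or hand everything back to the driver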

impl Clone for CudaDevice {
    fn clone(&self) -> Self {
        CudaDevice {
            dev: self.dev.clone(),
            blas: self.blas.clone(),
            pool: self.pool.clone(),
            ordinal: self.ordinal,
        }
    }
}

impl fmt::Debug for CudaDevice {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "CudaDevice(cuda:{})", self.ordinal)
    }
}

// Safety: cudarc's device is thread-safe (CUDA runtime is thread-safe)
unsafe impl Send for CudaDevice {}
unsafe impl Sync for CudaDevice {}

impl BackendDevice for CudaDevice {
    fn name(&self) -> String {
        format!("cuda:{}", self.ordinal)
    }
}

// CudaStorage — Device memory for each supported dtype

/// GPU-side storage. Each variant wraps a cudarc CudaSlice for the corresponding dtype.
/// F16 and BF16 are stored as CudaSlice<u16> (bit-level representation).
pub enum CudaStorage {
    F16(CudaSlice<u16>),
    BF16(CudaSlice<u16>),
    F32(CudaSlice<f32>),
    F64(CudaSlice<f64>),
    U8(CudaSlice<u8>),
    U32(CudaSlice<u32>),
    I64(CudaSlice<i64>),
}
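
// F16/BF16 buffers hold raw IEEE bit patterns. A minimal host-side sketch of
// decoding an F16 buffer after a device-to-host copy, using the `half` crate
// imported above (documentation aid; the kernels never call this):
#[allow(dead_code)]
fn f16_bits_to_f32(bits: &[u16]) -> Vec<f32> {
    bits.iter().map(|&b| f16::from_bits(b).to_f32()).collect()
}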

impl Clone for CudaStorage {
    fn clone(&self) -> Self {
        match self {
            CudaStorage::F16(s) => CudaStorage::F16(s.clone()),
            CudaStorage::BF16(s) => CudaStorage::BF16(s.clone()),
            CudaStorage::F32(s) => CudaStorage::F32(s.clone()),
            CudaStorage::F64(s) => CudaStorage::F64(s.clone()),
            CudaStorage::U8(s) => CudaStorage::U8(s.clone()),
            CudaStorage::U32(s) => CudaStorage::U32(s.clone()),
            CudaStorage::I64(s) => CudaStorage::I64(s.clone()),
        }
    }
}

impl fmt::Debug for CudaStorage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CudaStorage::F16(s) => write!(f, "CudaStorage::F16(len={})", s.len()),
            CudaStorage::BF16(s) => write!(f, "CudaStorage::BF16(len={})", s.len()),
            CudaStorage::F32(s) => write!(f, "CudaStorage::F32(len={})", s.len()),
            CudaStorage::F64(s) => write!(f, "CudaStorage::F64(len={})", s.len()),
            CudaStorage::U8(s) => write!(f, "CudaStorage::U8(len={})", s.len()),
            CudaStorage::U32(s) => write!(f, "CudaStorage::U32(len={})", s.len()),
            CudaStorage::I64(s) => write!(f, "CudaStorage::I64(len={})", s.len()),
        }
    }
}

unsafe impl Send for CudaStorage {}
unsafe impl Sync for CudaStorage {}

impl BackendStorage for CudaStorage {
    fn dtype(&self) -> DType {
        match self {
            CudaStorage::F16(_) => DType::F16,
            CudaStorage::BF16(_) => DType::BF16,
            CudaStorage::F32(_) => DType::F32,
            CudaStorage::F64(_) => DType::F64,
            CudaStorage::U8(_) => DType::U8,
            CudaStorage::U32(_) => DType::U32,
            CudaStorage::I64(_) => DType::I64,
        }
    }

    fn len(&self) -> usize {
        match self {
            CudaStorage::F16(s) => s.len(),
            CudaStorage::BF16(s) => s.len(),
            CudaStorage::F32(s) => s.len(),
            CudaStorage::F64(s) => s.len(),
            CudaStorage::U8(s) => s.len(),
            CudaStorage::U32(s) => s.len(),
            CudaStorage::I64(s) => s.len(),
        }
    }
}

// Helpers

/// Standard CUDA launch configuration for N elements.
fn launch_cfg(n: usize) -> LaunchConfig {
    const BLOCK: u32 = 256;
    let grid = (n as u32).div_ceil(BLOCK);
    LaunchConfig {
        block_dim: (BLOCK, 1, 1),
        grid_dim: (grid.max(1), 1, 1),
        shared_mem_bytes: 0,
    }
}
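
// Host-side sanity check of the launch geometry (no GPU required). A small
// sketch; the expected numbers follow from the 256-thread block size above.
#[cfg(test)]
mod launch_cfg_tests {
    use super::*;

    #[test]
    fn grid_covers_all_elements() {
        let cfg = launch_cfg(1000);
        assert_eq!(cfg.block_dim, (256, 1, 1));
        assert_eq!(cfg.grid_dim, (4, 1, 1)); // 4 * 256 = 1024 >= 1000
        assert_eq!(launch_cfg(0).grid_dim.0, 1); // grid is clamped to at least 1
    }
}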

/// Make a CudaStorage contiguous according to the given layout.
/// If already contiguous with offset 0, returns a clone (cheap Arc bump).
/// Otherwise launches a strided-copy kernel.
fn ensure_contiguous(
    storage: &CudaStorage,
    layout: &Layout,
    device: &CudaDevice,
) -> Result<CudaStorage> {
    if layout.is_contiguous() && layout.offset() == 0 {
        return Ok(storage.clone());
    }

    let n = layout.elem_count();
    let cfg = launch_cfg(n);
    let ndim = layout.rank() as i32;
    let offset = layout.offset() as i32;

    // Upload shape and strides to device
    let shape_i32: Vec<i32> = layout.dims().iter().map(|&d| d as i32).collect();
    let strides_i32: Vec<i32> = layout.strides().iter().map(|&s| s as i32).collect();
    let shape_dev = device
        .dev
        .htod_copy(shape_i32)
        .map_err(|e| Error::msg(format!("htod shape: {e}")))?;
    let strides_dev = device
        .dev
        .htod_copy(strides_i32)
        .map_err(|e| Error::msg(format!("htod strides: {e}")))?;

    match storage {
        CudaStorage::F16(src) | CudaStorage::BF16(src) => {
            let mut dst: CudaSlice<u16> = device
                .dev
                .alloc_zeros(n)
                .map_err(|e| Error::msg(format!("alloc: {e}")))?;
            let func = device.get_func("to_contiguous_u16")?;
            unsafe {
                func.launch(
                    cfg,
                    (
                        src,
                        &mut dst,
                        &shape_dev,
                        &strides_dev,
                        offset,
                        ndim,
                        n as u32,
                    ),
                )
            }
            .map_err(|e| Error::msg(format!("launch to_contiguous_u16: {e}")))?;
            match storage {
                CudaStorage::F16(_) => Ok(CudaStorage::F16(dst)),
                _ => Ok(CudaStorage::BF16(dst)),
            }
        }
        CudaStorage::F32(src) => {
            let mut dst: CudaSlice<f32> = device
                .dev
                .alloc_zeros(n)
                .map_err(|e| Error::msg(format!("alloc: {e}")))?;
            let func = device.get_func("to_contiguous_f32")?;
            unsafe {
                func.launch(
                    cfg,
                    (
                        src,
                        &mut dst,
                        &shape_dev,
                        &strides_dev,
                        offset,
                        ndim,
                        n as u32,
                    ),
                )
            }
            .map_err(|e| Error::msg(format!("launch to_contiguous_f32: {e}")))?;
            Ok(CudaStorage::F32(dst))
        }
        CudaStorage::F64(src) => {
            let mut dst: CudaSlice<f64> = device
                .dev
                .alloc_zeros(n)
                .map_err(|e| Error::msg(format!("alloc: {e}")))?;
            let func = device.get_func("to_contiguous_f64")?;
            unsafe {
                func.launch(
                    cfg,
                    (
                        src,
                        &mut dst,
                        &shape_dev,
                        &strides_dev,
                        offset,
                        ndim,
                        n as u32,
                    ),
                )
            }
            .map_err(|e| Error::msg(format!("launch to_contiguous_f64: {e}")))?;
            Ok(CudaStorage::F64(dst))
        }
        CudaStorage::U8(src) => {
            let mut dst: CudaSlice<u8> = device
                .dev
                .alloc_zeros(n)
                .map_err(|e| Error::msg(format!("alloc: {e}")))?;
            let func = device.get_func("to_contiguous_u8")?;
            unsafe {
                func.launch(
                    cfg,
                    (
                        src,
                        &mut dst,
                        &shape_dev,
                        &strides_dev,
                        offset,
                        ndim,
                        n as u32,
                    ),
                )
            }
            .map_err(|e| Error::msg(format!("launch to_contiguous_u8: {e}")))?;
            Ok(CudaStorage::U8(dst))
        }
        _ => Err(Error::msg(
            "to_contiguous not implemented for this dtype on CUDA",
        )),
    }
}
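
// Reference for the index math the to_contiguous_* kernels implement (a
// host-side sketch assumed to mirror the device code, not used by it): output
// element `i` is decomposed into coordinates against `shape`, then re-projected
// through `strides` plus `offset` to find the source element.
#[allow(dead_code)]
fn strided_src_index(mut i: usize, shape: &[usize], strides: &[usize], offset: usize) -> usize {
    let mut src = offset;
    for d in (0..shape.len()).rev() {
        src += (i % shape[d]) * strides[d];
        i /= shape[d];
    }
    src
}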

/// Get the underlying cudarc device from a CudaStorage (via whichever
/// CudaSlice variant it holds); needed to find the device for operations.
fn device_from_storage(storage: &CudaStorage) -> Arc<cudarc::driver::CudaDevice> {
    match storage {
        CudaStorage::F16(s) => s.device(),
        CudaStorage::BF16(s) => s.device(),
        CudaStorage::F32(s) => s.device(),
        CudaStorage::F64(s) => s.device(),
        CudaStorage::U8(s) => s.device(),
        CudaStorage::U32(s) => s.device(),
        CudaStorage::I64(s) => s.device(),
    }
}

/// Reconstruct a CudaDevice from a storage reference (for Backend trait methods
/// that don't receive the device explicitly). Note: this builds a fresh cuBLAS
/// handle and an empty memory pool on every call and reports ordinal 0, so it
/// should only be used where the real device handle is unavailable.
fn dev_from_storage(storage: &CudaStorage) -> Result<CudaDevice> {
    let raw_dev = device_from_storage(storage);
    let blas = CudaBlas::new(raw_dev.clone()).map_err(|e| Error::msg(format!("blas: {e}")))?;
    Ok(CudaDevice {
        dev: raw_dev,
        blas: Arc::new(blas),
        pool: Arc::new(CudaMemPool::new()),
        ordinal: 0,
    })
}

// CudaBackend — The Backend trait implementation

/// The CUDA GPU backend. This is a zero-sized marker type.
#[derive(Clone, Debug)]
pub struct CudaBackend;

impl Backend for CudaBackend {
    type Device = CudaDevice;
    type Storage = CudaStorage;

    // ---- Creation ----

    fn zeros(shape: &Shape, dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        let n = shape.elem_count();
        match dtype {
            DType::F16 => {
                let s: CudaSlice<u16> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros f16: {e}")))?;
                Ok(CudaStorage::F16(s))
            }
            DType::BF16 => {
                let s: CudaSlice<u16> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros bf16: {e}")))?;
                Ok(CudaStorage::BF16(s))
            }
            DType::F32 => {
                let s: CudaSlice<f32> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros f32: {e}")))?;
                Ok(CudaStorage::F32(s))
            }
            DType::F64 => {
                let s: CudaSlice<f64> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros f64: {e}")))?;
                Ok(CudaStorage::F64(s))
            }
            DType::U8 => {
                let s: CudaSlice<u8> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros u8: {e}")))?;
                Ok(CudaStorage::U8(s))
            }
            DType::U32 => {
                let s: CudaSlice<u32> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros u32: {e}")))?;
                Ok(CudaStorage::U32(s))
            }
            DType::I64 => {
                let s: CudaSlice<i64> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc zeros i64: {e}")))?;
                Ok(CudaStorage::I64(s))
            }
        }
    }

    fn ones(shape: &Shape, dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        Self::full(shape, 1.0, dtype, device)
    }

    fn full(shape: &Shape, val: f64, dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        let n = shape.elem_count();
        let cfg = launch_cfg(n);
        match dtype {
            DType::F16 => {
                let mut s: CudaSlice<u16> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_f16")?;
                unsafe { func.launch(cfg, (&mut s, val as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_f16: {e}")))?;
                Ok(CudaStorage::F16(s))
            }
            DType::BF16 => {
                let mut s: CudaSlice<u16> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_bf16")?;
                unsafe { func.launch(cfg, (&mut s, val as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_bf16: {e}")))?;
                Ok(CudaStorage::BF16(s))
            }
            DType::F32 => {
                let mut s: CudaSlice<f32> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_f32")?;
                unsafe { func.launch(cfg, (&mut s, val as f32, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_f32: {e}")))?;
                Ok(CudaStorage::F32(s))
            }
            DType::F64 => {
                let mut s: CudaSlice<f64> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_f64")?;
                unsafe { func.launch(cfg, (&mut s, val, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_f64: {e}")))?;
                Ok(CudaStorage::F64(s))
            }
            DType::U8 => {
                let mut s: CudaSlice<u8> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_u8")?;
                unsafe { func.launch(cfg, (&mut s, val as u8, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_u8: {e}")))?;
                Ok(CudaStorage::U8(s))
            }
            DType::U32 => {
                let mut s: CudaSlice<u32> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_u32")?;
                unsafe { func.launch(cfg, (&mut s, val as u32, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_u32: {e}")))?;
                Ok(CudaStorage::U32(s))
            }
            DType::I64 => {
                let mut s: CudaSlice<i64> = device
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = device.get_func("fill_i64")?;
                unsafe { func.launch(cfg, (&mut s, val as i64, n as u32)) }
                    .map_err(|e| Error::msg(format!("fill_i64: {e}")))?;
                Ok(CudaStorage::I64(s))
            }
        }
    }

    fn from_f64_slice(data: &[f64], dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        match dtype {
            DType::F16 => {
                let host: Vec<u16> = data.iter().map(|&v| f16::from_f64(v).to_bits()).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod f16: {e}")))?;
                Ok(CudaStorage::F16(s))
            }
            DType::BF16 => {
                let host: Vec<u16> = data.iter().map(|&v| bf16::from_f64(v).to_bits()).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod bf16: {e}")))?;
                Ok(CudaStorage::BF16(s))
            }
            DType::F32 => {
                let host: Vec<f32> = data.iter().map(|&v| v as f32).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod f32: {e}")))?;
                Ok(CudaStorage::F32(s))
            }
            DType::F64 => {
                let s = device
                    .dev
                    .htod_copy(data.to_vec())
                    .map_err(|e| Error::msg(format!("htod f64: {e}")))?;
                Ok(CudaStorage::F64(s))
            }
            DType::U8 => {
                let host: Vec<u8> = data.iter().map(|&v| v as u8).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod u8: {e}")))?;
                Ok(CudaStorage::U8(s))
            }
            DType::U32 => {
                let host: Vec<u32> = data.iter().map(|&v| v as u32).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod u32: {e}")))?;
                Ok(CudaStorage::U32(s))
            }
            DType::I64 => {
                let host: Vec<i64> = data.iter().map(|&v| v as i64).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod i64: {e}")))?;
                Ok(CudaStorage::I64(s))
            }
        }
    }

    fn rand_uniform(shape: &Shape, dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        // Generate on host, transfer to device
        use rand::Rng;
        let n = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::F16 => {
                let host: Vec<u16> = (0..n)
                    .map(|_| f16::from_f32(rng.gen::<f32>()).to_bits())
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_uniform f16: {e}")))?;
                Ok(CudaStorage::F16(s))
            }
            DType::BF16 => {
                let host: Vec<u16> = (0..n)
                    .map(|_| bf16::from_f32(rng.gen::<f32>()).to_bits())
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_uniform bf16: {e}")))?;
                Ok(CudaStorage::BF16(s))
            }
            DType::F32 => {
                let host: Vec<f32> = (0..n).map(|_| rng.gen::<f32>()).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_uniform f32: {e}")))?;
                Ok(CudaStorage::F32(s))
            }
            DType::F64 => {
                let host: Vec<f64> = (0..n).map(|_| rng.gen::<f64>()).collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_uniform f64: {e}")))?;
                Ok(CudaStorage::F64(s))
            }
            _ => Err(Error::msg(format!(
                "rand_uniform not supported for {:?}",
                dtype
            ))),
        }
    }

    fn rand_normal(shape: &Shape, dtype: DType, device: &CudaDevice) -> Result<CudaStorage> {
        use rand::Rng;
        use rand_distr::StandardNormal;
        let n = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::F16 => {
                let host: Vec<u16> = (0..n)
                    .map(|_| f16::from_f32(rng.sample::<f32, _>(StandardNormal)).to_bits())
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_normal f16: {e}")))?;
                Ok(CudaStorage::F16(s))
            }
            DType::BF16 => {
                let host: Vec<u16> = (0..n)
                    .map(|_| bf16::from_f32(rng.sample::<f32, _>(StandardNormal)).to_bits())
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_normal bf16: {e}")))?;
                Ok(CudaStorage::BF16(s))
            }
            DType::F32 => {
                let host: Vec<f32> = (0..n)
                    .map(|_| rng.sample::<f32, _>(StandardNormal))
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_normal f32: {e}")))?;
                Ok(CudaStorage::F32(s))
            }
            DType::F64 => {
                let host: Vec<f64> = (0..n)
                    .map(|_| rng.sample::<f64, _>(StandardNormal))
                    .collect();
                let s = device
                    .dev
                    .htod_copy(host)
                    .map_err(|e| Error::msg(format!("htod rand_normal f64: {e}")))?;
                Ok(CudaStorage::F64(s))
            }
            _ => Err(Error::msg(format!(
                "rand_normal not supported for {:?}",
                dtype
            ))),
        }
    }

    // ---- Binary ops ----

    fn binary_op(
        op: BinaryOp,
        lhs: &CudaStorage,
        lhs_layout: &Layout,
        rhs: &CudaStorage,
        rhs_layout: &Layout,
    ) -> Result<CudaStorage> {
        let dev = dev_from_storage(lhs)?;

        // Make contiguous
        let lhs_c = ensure_contiguous(lhs, lhs_layout, &dev)?;
        let rhs_c = ensure_contiguous(rhs, rhs_layout, &dev)?;

        let lhs_shape = lhs_layout.shape();
        let rhs_shape = rhs_layout.shape();

        let op_name = match op {
            BinaryOp::Add => "add",
            BinaryOp::Sub => "sub",
            BinaryOp::Mul => "mul",
            BinaryOp::Div => "div",
        };

        // Check if broadcasting is needed
        let needs_broadcast = lhs_shape.dims() != rhs_shape.dims();

        if needs_broadcast {
            // Broadcast path: compute output shape and per-operand strides
            let out_shape = Shape::broadcast_shape(lhs_shape, rhs_shape)?;
            let a_strides = lhs_shape.broadcast_strides(&out_shape);
            let b_strides = rhs_shape.broadcast_strides(&out_shape);
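            // e.g. lhs [2, 3] + rhs [3]: out_shape = [2, 3] and, assuming
            // broadcast_strides zeroes broadcast dimensions, b_strides = [0, 1],
            // so the kernel re-reads rhs's single row for every row of lhs.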
            let n = out_shape.elem_count();
            let cfg = launch_cfg(n);

            // Upload shape & strides to GPU
            let out_dims_u32: Vec<u32> = out_shape.dims().iter().map(|&d| d as u32).collect();
            let a_strides_u32: Vec<u32> = a_strides.iter().map(|&s| s as u32).collect();
            let b_strides_u32: Vec<u32> = b_strides.iter().map(|&s| s as u32).collect();
            let rank = out_shape.dims().len() as u32;

            let dims_gpu = dev
                .dev
                .htod_copy(out_dims_u32)
                .map_err(|e| Error::msg(format!("alloc dims: {e}")))?;
            let a_strides_gpu = dev
                .dev
                .htod_copy(a_strides_u32)
                .map_err(|e| Error::msg(format!("alloc a_strides: {e}")))?;
            let b_strides_gpu = dev
                .dev
                .htod_copy(b_strides_u32)
                .map_err(|e| Error::msg(format!("alloc b_strides: {e}")))?;

            match (&lhs_c, &rhs_c) {
                (CudaStorage::F32(a), CudaStorage::F32(b)) => {
                    let mut out: CudaSlice<f32> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("bcast_binary_{op_name}_f32"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                a,
                                b,
                                &mut out,
                                &dims_gpu,
                                &a_strides_gpu,
                                &b_strides_gpu,
                                rank,
                                n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("bcast binary op: {e}")))?;
                    Ok(CudaStorage::F32(out))
                }
                (CudaStorage::F64(a), CudaStorage::F64(b)) => {
                    let mut out: CudaSlice<f64> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("bcast_binary_{op_name}_f64"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                a,
                                b,
                                &mut out,
                                &dims_gpu,
                                &a_strides_gpu,
                                &b_strides_gpu,
                                rank,
                                n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("bcast binary op: {e}")))?;
                    Ok(CudaStorage::F64(out))
                }
                (CudaStorage::F16(a), CudaStorage::F16(b)) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("bcast_binary_{op_name}_f16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                a,
                                b,
                                &mut out,
                                &dims_gpu,
                                &a_strides_gpu,
                                &b_strides_gpu,
                                rank,
                                n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("bcast binary op: {e}")))?;
                    Ok(CudaStorage::F16(out))
                }
                (CudaStorage::BF16(a), CudaStorage::BF16(b)) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("bcast_binary_{op_name}_bf16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                a,
                                b,
                                &mut out,
                                &dims_gpu,
                                &a_strides_gpu,
                                &b_strides_gpu,
                                rank,
                                n as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("bcast binary op: {e}")))?;
                    Ok(CudaStorage::BF16(out))
                }
                _ => Err(Error::msg(
                    "bcast binary_op: dtype mismatch or unsupported dtype",
                )),
            }
        } else {
            // Fast path: same shape, element-wise
            let n = lhs_layout.elem_count();
            let cfg = launch_cfg(n);

            match (&lhs_c, &rhs_c) {
                (CudaStorage::F16(a), CudaStorage::F16(b)) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("binary_{op_name}_f16"))?;
                    unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                        .map_err(|e| Error::msg(format!("binary op: {e}")))?;
                    Ok(CudaStorage::F16(out))
                }
                (CudaStorage::BF16(a), CudaStorage::BF16(b)) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("binary_{op_name}_bf16"))?;
                    unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                        .map_err(|e| Error::msg(format!("binary op: {e}")))?;
                    Ok(CudaStorage::BF16(out))
                }
                (CudaStorage::F32(a), CudaStorage::F32(b)) => {
                    let mut out: CudaSlice<f32> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("binary_{op_name}_f32"))?;
                    unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                        .map_err(|e| Error::msg(format!("binary op: {e}")))?;
                    Ok(CudaStorage::F32(out))
                }
                (CudaStorage::F64(a), CudaStorage::F64(b)) => {
                    let mut out: CudaSlice<f64> = dev
                        .dev
                        .alloc_zeros(n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("binary_{op_name}_f64"))?;
                    unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
                        .map_err(|e| Error::msg(format!("binary op: {e}")))?;
                    Ok(CudaStorage::F64(out))
                }
                _ => Err(Error::msg("binary_op: dtype mismatch or unsupported dtype")),
            }
        }
    }

    // ---- Unary ops ----

    fn unary_op(op: UnaryOp, input: &CudaStorage, layout: &Layout) -> Result<CudaStorage> {
        let dev = dev_from_storage(input)?;

        let input_c = ensure_contiguous(input, layout, &dev)?;
        let n = layout.elem_count();
        let cfg = launch_cfg(n);

        let op_name = match op {
            UnaryOp::Neg => "neg",
            UnaryOp::Abs => "abs",
            UnaryOp::Exp => "exp",
            UnaryOp::Log => "log",
            UnaryOp::Sqrt => "sqrt",
            UnaryOp::Relu => "relu",
            UnaryOp::Sigmoid => "sigmoid",
            UnaryOp::Tanh => "tanh",
            UnaryOp::Gelu => "gelu",
            UnaryOp::Silu => "silu",
            UnaryOp::Sin => "sin",
            UnaryOp::Cos => "cos",
            UnaryOp::Square => "square",
            UnaryOp::Floor => "floor",
            UnaryOp::Ceil => "ceil",
            UnaryOp::Round => "round",
        };

        match &input_c {
            CudaStorage::F16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func(&format!("unary_{op_name}_f16"))?;
                unsafe { func.launch(cfg, (inp, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("unary op: {e}")))?;
                Ok(CudaStorage::F16(out))
            }
            CudaStorage::BF16(inp) => {
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func(&format!("unary_{op_name}_bf16"))?;
                unsafe { func.launch(cfg, (inp, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("unary op: {e}")))?;
                Ok(CudaStorage::BF16(out))
            }
            CudaStorage::F32(inp) => {
                let mut out: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func(&format!("unary_{op_name}_f32"))?;
                unsafe { func.launch(cfg, (inp, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("unary op: {e}")))?;
                Ok(CudaStorage::F32(out))
            }
            CudaStorage::F64(inp) => {
                let mut out: CudaSlice<f64> = dev
                    .dev
                    .alloc_zeros(n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let func = dev.get_func(&format!("unary_{op_name}_f64"))?;
                unsafe { func.launch(cfg, (inp, &mut out, n as u32)) }
                    .map_err(|e| Error::msg(format!("unary op: {e}")))?;
                Ok(CudaStorage::F64(out))
            }
            _ => Err(Error::msg("unary_op: only float types supported")),
        }
    }

    // ---- Reduction ----

    fn reduce_op(
        op: ReduceOp,
        input: &CudaStorage,
        layout: &Layout,
        dims: &[usize],
        _keep_dim: bool,
    ) -> Result<CudaStorage> {
        let dev = dev_from_storage(input)?;

        let input_c = ensure_contiguous(input, layout, &dev)?;
        let shape_dims = layout.dims();
        let rank = shape_dims.len();

        // Determine reduction dimension (or all)
        let reduce_dim = if dims.is_empty() {
            None
        } else if dims.len() == 1 {
            Some(dims[0])
        } else {
            return Err(Error::msg(
                "CUDA reduce_op: multi-dim reduction not yet supported",
            ));
        };

        let (outer_size, reduce_size, inner_size) = if let Some(dim) = reduce_dim {
            if dim >= rank {
                return Err(Error::msg(format!(
                    "dim {dim} out of range for rank {rank}"
                )));
            }
            let outer: usize = shape_dims[..dim].iter().product::<usize>().max(1);
            let red = shape_dims[dim];
            let inner: usize = shape_dims[dim + 1..].iter().product::<usize>().max(1);
            (outer, red, inner)
        } else {
            let total: usize = shape_dims.iter().product();
            (1usize, total, 1usize)
        };
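        // e.g. shape [2, 3, 4] reduced over dim 1: outer = 2, reduce = 3,
        // inner = 4, and the output holds outer * inner = 8 elements.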

        let out_n = outer_size * inner_size;
        let cfg = launch_cfg(out_n);

        let is_arg = matches!(op, ReduceOp::ArgMax | ReduceOp::ArgMin);

        let op_name = match op {
            ReduceOp::Sum => "sum",
            ReduceOp::Mean => "mean",
            ReduceOp::Max => "max",
            ReduceOp::Min => "min",
            ReduceOp::ArgMax => "argmax",
            ReduceOp::ArgMin => "argmin",
        };

        if is_arg {
            // ArgMax/ArgMin → output I64
            let mut out: CudaSlice<i64> = dev
                .dev
                .alloc_zeros(out_n)
                .map_err(|e| Error::msg(format!("alloc: {e}")))?;
            match &input_c {
                CudaStorage::F16(inp) => {
                    let func = dev.get_func(&format!("reduce_{op_name}_f16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                }
                CudaStorage::BF16(inp) => {
                    let func = dev.get_func(&format!("reduce_{op_name}_bf16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                }
                CudaStorage::F32(inp) => {
                    let func = dev.get_func(&format!("reduce_{op_name}_f32"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                }
                CudaStorage::F64(inp) => {
                    let func = dev.get_func(&format!("reduce_{op_name}_f64"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                }
                _ => return Err(Error::msg("reduce: only float types supported")),
            }
            Ok(CudaStorage::I64(out))
        } else {
            // Sum/Mean/Max/Min → output same type
            match &input_c {
                CudaStorage::F16(inp) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(out_n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("reduce_{op_name}_f16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                    Ok(CudaStorage::F16(out))
                }
                CudaStorage::BF16(inp) => {
                    let mut out: CudaSlice<u16> = dev
                        .dev
                        .alloc_zeros(out_n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("reduce_{op_name}_bf16"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                    Ok(CudaStorage::BF16(out))
                }
                CudaStorage::F32(inp) => {
                    let mut out: CudaSlice<f32> = dev
                        .dev
                        .alloc_zeros(out_n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("reduce_{op_name}_f32"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                    Ok(CudaStorage::F32(out))
                }
                CudaStorage::F64(inp) => {
                    let mut out: CudaSlice<f64> = dev
                        .dev
                        .alloc_zeros(out_n)
                        .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                    let func = dev.get_func(&format!("reduce_{op_name}_f64"))?;
                    unsafe {
                        func.launch(
                            cfg,
                            (
                                inp,
                                &mut out,
                                outer_size as u32,
                                reduce_size as u32,
                                inner_size as u32,
                            ),
                        )
                    }
                    .map_err(|e| Error::msg(format!("reduce: {e}")))?;
                    Ok(CudaStorage::F64(out))
                }
                _ => Err(Error::msg("reduce: only float types supported")),
            }
        }
    }

    // ---- Matmul (cuBLAS) ----
    // F16/BF16: promote to F32 → sgemm → demote back
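    // cuBLAS is column-major while these tensors are row-major; the sgemm calls
    // below use (A·B)^T = B^T·A^T, passing B before A with m and n swapped, so
    // the row-major product comes out without any explicit transpose.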

    fn matmul(
        lhs: &CudaStorage,
        lhs_layout: &Layout,
        rhs: &CudaStorage,
        rhs_layout: &Layout,
    ) -> Result<CudaStorage> {
        let dev = dev_from_storage(lhs)?;

        // Make contiguous
        let lhs_c = ensure_contiguous(lhs, lhs_layout, &dev)?;
        let rhs_c = ensure_contiguous(rhs, rhs_layout, &dev)?;

        let lhs_dims = lhs_layout.dims();
        let rhs_dims = rhs_layout.dims();
        let rank = lhs_dims.len();
        let m = lhs_dims[rank - 2];
        let k = lhs_dims[rank - 1];
        let n = rhs_dims[rhs_dims.len() - 1];
        let batch_size: usize = lhs_dims[..rank - 2].iter().product::<usize>().max(1);

        match (&lhs_c, &rhs_c) {
            (CudaStorage::F16(a), CudaStorage::F16(b)) => {
                // Promote F16 → F32, matmul, demote back
                let a_n = a.len();
                let b_n = b.len();
                let mn = m * n;
                let total = batch_size * mn;

                // Cast A to F32
                let mut a_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(a_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let cast_a = dev.get_func("cast_f16_to_f32")?;
                let cfg_a = launch_cfg(a_n);
                unsafe { cast_a.launch(cfg_a, (a, &mut a_f32, a_n as u32)) }
                    .map_err(|e| Error::msg(format!("cast: {e}")))?;

                // Cast B to F32
                let mut b_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(b_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let cast_b = dev.get_func("cast_f16_to_f32")?;
                let cfg_b = launch_cfg(b_n);
                unsafe { cast_b.launch(cfg_b, (b, &mut b_f32, b_n as u32)) }
                    .map_err(|e| Error::msg(format!("cast: {e}")))?;

                // sgemm
                let out_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(total)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;

                use cudarc::cublas::sys::cublasOperation_t;
                for batch in 0..batch_size {
                    let a_offset = batch * m * k;
                    let b_offset = batch * k * n;
                    let c_offset = batch * mn;
                    let a_slice = a_f32.slice(a_offset..a_offset + m * k);
                    let b_slice = b_f32.slice(b_offset..b_offset + k * n);
                    let c_slice = out_f32.slice(c_offset..c_offset + mn);
                    unsafe {
                        cudarc::cublas::result::sgemm(
                            *dev.blas.handle(),
                            cublasOperation_t::CUBLAS_OP_N,
                            cublasOperation_t::CUBLAS_OP_N,
                            n as i32,
                            m as i32,
                            k as i32,
                            (&1.0f32) as *const f32,
                            *b_slice.device_ptr() as *const f32,
                            n as i32,
                            *a_slice.device_ptr() as *const f32,
                            k as i32,
                            (&0.0f32) as *const f32,
                            *c_slice.device_ptr() as *mut f32,
                            n as i32,
                        )
                    }
                    .map_err(|e| Error::msg(format!("cuBLAS sgemm: {e}")))?;
                }

                // Demote F32 → F16
                let mut out: CudaSlice<u16> = dev
                    .dev
                    .alloc_zeros(total)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let cast_out = dev.get_func("cast_f32_to_f16")?;
                let cfg_out = launch_cfg(total);
                unsafe { cast_out.launch(cfg_out, (&out_f32, &mut out, total as u32)) }
                    .map_err(|e| Error::msg(format!("cast: {e}")))?;

                Ok(CudaStorage::F16(out))
            }
            (CudaStorage::BF16(a), CudaStorage::BF16(b)) => {
                // Promote BF16 → F32, matmul, demote back
                let a_n = a.len();
                let b_n = b.len();
                let mn = m * n;
                let total = batch_size * mn;

                let mut a_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(a_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let cast_a = dev.get_func("cast_bf16_to_f32")?;
                let cfg_a = launch_cfg(a_n);
                unsafe { cast_a.launch(cfg_a, (a, &mut a_f32, a_n as u32)) }
                    .map_err(|e| Error::msg(format!("cast: {e}")))?;

                let mut b_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(b_n)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
                let cast_b = dev.get_func("cast_bf16_to_f32")?;
                let cfg_b = launch_cfg(b_n);
                unsafe { cast_b.launch(cfg_b, (b, &mut b_f32, b_n as u32)) }
                    .map_err(|e| Error::msg(format!("cast: {e}")))?;

                let out_f32: CudaSlice<f32> = dev
                    .dev
                    .alloc_zeros(total)
                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;

                use cudarc::cublas::sys::cublasOperation_t;
                for batch in 0..batch_size {
                    let a_offset = batch * m * k;
                    let b_offset = batch * k * n;
                    let c_offset = batch * mn;
                    let a_slice = a_f32.slice(a_offset..a_offset + m * k);
                    let b_slice = b_f32.slice(b_offset..b_offset + k * n);
                    let c_slice = out_f32.slice(c_offset..c_offset + mn);
                    unsafe {
                        cudarc::cublas::result::sgemm(
                            *dev.blas.handle(),
                            cublasOperation_t::CUBLAS_OP_N,
                            cublasOperation_t::CUBLAS_OP_N,
                            n as i32,
                            m as i32,
                            k as i32,
                            (&1.0f32) as *const f32,
                            *b_slice.device_ptr() as *const f32,
                            n as i32,
                            *a_slice.device_ptr() as *const f32,
                            k as i32,
                            (&0.0f32) as *const f32,
                            *c_slice.device_ptr() as *mut f32,
1420                            n as i32,
1421                        )
1422                    }
1423                    .map_err(|e| Error::msg(format!("cuBLAS sgemm: {e}")))?;
1424                }
1425
1426                let mut out: CudaSlice<u16> = dev
1427                    .dev
1428                    .alloc_zeros(total)
1429                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1430                let cast_out = dev.get_func("cast_f32_to_bf16")?;
1431                let cfg_out = launch_cfg(total);
1432                unsafe { cast_out.launch(cfg_out, (&out_f32, &mut out, total as u32)) }
1433                    .map_err(|e| Error::msg(format!("cast: {e}")))?;
1434
1435                Ok(CudaStorage::BF16(out))
1436            }
1437            (CudaStorage::F32(a), CudaStorage::F32(b)) => {
1438                let mn = m * n;
1439                let total = batch_size * mn;
1440                let out: CudaSlice<f32> = dev
1441                    .dev
1442                    .alloc_zeros(total)
1443                    .map_err(|e| Error::msg(format!("alloc matmul: {e}")))?;
1444
1445                use cudarc::cublas::sys::cublasOperation_t;
1446
1447                for batch in 0..batch_size {
1448                    let a_offset = batch * m * k;
1449                    let b_offset = batch * k * n;
1450                    let c_offset = batch * mn;
1451
1452                    let a_slice = a.slice(a_offset..a_offset + m * k);
1453                    let b_slice = b.slice(b_offset..b_offset + k * n);
1454                    let c_slice = out.slice(c_offset..c_offset + mn);
1455
1456                    unsafe {
1457                        cudarc::cublas::result::sgemm(
1458                            *dev.blas.handle(),
1459                            cublasOperation_t::CUBLAS_OP_N,
1460                            cublasOperation_t::CUBLAS_OP_N,
1461                            n as i32,
1462                            m as i32,
1463                            k as i32,
1464                            (&1.0f32) as *const f32,
1465                            *b_slice.device_ptr() as *const f32,
1466                            n as i32,
1467                            *a_slice.device_ptr() as *const f32,
1468                            k as i32,
1469                            (&0.0f32) as *const f32,
1470                            *c_slice.device_ptr() as *mut f32,
1471                            n as i32,
1472                        )
1473                    }
1474                    .map_err(|e| Error::msg(format!("cuBLAS sgemm: {e}")))?;
1475                }
1476                Ok(CudaStorage::F32(out))
1477            }
1478            (CudaStorage::F64(a), CudaStorage::F64(b)) => {
1479                let mn = m * n;
1480                let total = batch_size * mn;
1481                let out: CudaSlice<f64> = dev
1482                    .dev
1483                    .alloc_zeros(total)
1484                    .map_err(|e| Error::msg(format!("alloc matmul: {e}")))?;
1485
1486                use cudarc::cublas::sys::cublasOperation_t;
1487
1488                for batch in 0..batch_size {
1489                    let a_offset = batch * m * k;
1490                    let b_offset = batch * k * n;
1491                    let c_offset = batch * mn;
1492
1493                    let a_slice = a.slice(a_offset..a_offset + m * k);
1494                    let b_slice = b.slice(b_offset..b_offset + k * n);
1495                    let c_slice = out.slice(c_offset..c_offset + mn);
1496
1497                    unsafe {
1498                        cudarc::cublas::result::dgemm(
1499                            *dev.blas.handle(),
1500                            cublasOperation_t::CUBLAS_OP_N,
1501                            cublasOperation_t::CUBLAS_OP_N,
1502                            n as i32,
1503                            m as i32,
1504                            k as i32,
1505                            (&1.0f64) as *const f64,
1506                            *b_slice.device_ptr() as *const f64,
1507                            n as i32,
1508                            *a_slice.device_ptr() as *const f64,
1509                            k as i32,
1510                            (&0.0f64) as *const f64,
1511                            *c_slice.device_ptr() as *mut f64,
1512                            n as i32,
1513                        )
1514                    }
1515                    .map_err(|e| Error::msg(format!("cuBLAS dgemm: {e}")))?;
1516                }
1517                Ok(CudaStorage::F64(out))
1518            }
1519            _ => Err(Error::msg("matmul: dtype mismatch or unsupported (f16/bf16/f32/f64 only)")),
1520        }
1521    }
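
    // Worked shape example for the batched matmul above: for lhs of shape
    // [4, 2, 3] and rhs of shape [4, 3, 5], batch_size = 4 and
    // (m, k, n) = (2, 3, 5); the loop issues 4 sgemm calls, each multiplying
    // a (2, 3) block by a (3, 5) block into a (2, 5) slot of the [4, 2, 5]
    // output.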
1522
1523    // ---- to_contiguous ----
1524
1525    fn to_contiguous(input: &CudaStorage, layout: &Layout) -> Result<CudaStorage> {
1526        let dev = dev_from_storage(input)?;
1527        ensure_contiguous(input, layout, &dev)
1528    }
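
    // Example (a sketch; how the transposed layout is produced is elided):
    // a transposed view is the typical non-contiguous input here.
    //
    //   // layout: shape [2, 3], strides [1, 2], i.e. a transposed 3x2 buffer
    //   let packed = CudaBackend::to_contiguous(&storage, &layout)?;
    //   // result: same [2, 3] data, now densely packed with strides [3, 1]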
1529
1530    // ---- to_f64_vec (device → host) ----
1531
1532    fn to_f64_vec(input: &CudaStorage, layout: &Layout) -> Result<Vec<f64>> {
1533        let dev = dev_from_storage(input)?;
1534
1535        // Make contiguous first
1536        let input_c = ensure_contiguous(input, layout, &dev)?;
1537
1538        match &input_c {
1539            CudaStorage::F16(s) => {
1540                let host = dev
1541                    .dev
1542                    .dtoh_sync_copy(s)
1543                    .map_err(|e| Error::msg(format!("dtoh f16: {e}")))?;
1544                Ok(host
1545                    .iter()
1546                    .map(|&bits| f16::from_bits(bits).to_f64())
1547                    .collect())
1548            }
1549            CudaStorage::BF16(s) => {
1550                let host = dev
1551                    .dev
1552                    .dtoh_sync_copy(s)
1553                    .map_err(|e| Error::msg(format!("dtoh bf16: {e}")))?;
1554                Ok(host
1555                    .iter()
1556                    .map(|&bits| bf16::from_bits(bits).to_f64())
1557                    .collect())
1558            }
1559            CudaStorage::F32(s) => {
1560                let host = dev
1561                    .dev
1562                    .dtoh_sync_copy(s)
1563                    .map_err(|e| Error::msg(format!("dtoh f32: {e}")))?;
1564                Ok(host.iter().map(|&v| v as f64).collect())
1565            }
1566            CudaStorage::F64(s) => {
1567                let host = dev
1568                    .dev
1569                    .dtoh_sync_copy(s)
1570                    .map_err(|e| Error::msg(format!("dtoh f64: {e}")))?;
1571                Ok(host)
1572            }
1573            CudaStorage::U8(s) => {
1574                let host = dev
1575                    .dev
1576                    .dtoh_sync_copy(s)
1577                    .map_err(|e| Error::msg(format!("dtoh u8: {e}")))?;
1578                Ok(host.iter().map(|&v| v as f64).collect())
1579            }
1580            CudaStorage::U32(s) => {
1581                let host = dev
1582                    .dev
1583                    .dtoh_sync_copy(s)
1584                    .map_err(|e| Error::msg(format!("dtoh u32: {e}")))?;
1585                Ok(host.iter().map(|&v| v as f64).collect())
1586            }
1587            CudaStorage::I64(s) => {
1588                let host = dev
1589                    .dev
1590                    .dtoh_sync_copy(s)
1591                    .map_err(|e| Error::msg(format!("dtoh i64: {e}")))?;
1592                Ok(host.iter().map(|&v| v as f64).collect())
1593            }
1594        }
1595    }
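
    // Round-trip sketch for the read-back path above, assuming `Shape` can be
    // built from a dim slice (construction not shown in this module):
    //
    //   let storage = CudaStorage::from_f32_vec(vec![1.0, 2.0, 3.0], &device)?;
    //   let layout = Layout::contiguous(Shape::from(&[3usize][..]));
    //   assert_eq!(CudaBackend::to_f64_vec(&storage, &layout)?, vec![1.0, 2.0, 3.0]);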
1596
1597    // ---- Comparison ops ----
1598
1599    fn cmp_op(
1600        op: CmpOp,
1601        lhs: &CudaStorage,
1602        lhs_layout: &Layout,
1603        rhs: &CudaStorage,
1604        rhs_layout: &Layout,
1605    ) -> Result<CudaStorage> {
1606        let dev = dev_from_storage(lhs)?;
1607
1608        let lhs_c = ensure_contiguous(lhs, lhs_layout, &dev)?;
1609        let rhs_c = ensure_contiguous(rhs, rhs_layout, &dev)?;
1610        let n = lhs_layout.elem_count();
1611        let cfg = launch_cfg(n);
1612
1613        let op_name = match op {
1614            CmpOp::Eq => "eq",
1615            CmpOp::Ne => "ne",
1616            CmpOp::Gt => "gt",
1617            CmpOp::Ge => "ge",
1618            CmpOp::Lt => "lt",
1619            CmpOp::Le => "le",
1620        };
1621
1622        let mut out: CudaSlice<u8> = dev
1623            .dev
1624            .alloc_zeros(n)
1625            .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1626
1627        match (&lhs_c, &rhs_c) {
1628            (CudaStorage::F16(a), CudaStorage::F16(b)) => {
1629                let func = dev.get_func(&format!("cmp_{op_name}_f16"))?;
1630                unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
1631                    .map_err(|e| Error::msg(format!("cmp: {e}")))?;
1632            }
1633            (CudaStorage::BF16(a), CudaStorage::BF16(b)) => {
1634                let func = dev.get_func(&format!("cmp_{op_name}_bf16"))?;
1635                unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
1636                    .map_err(|e| Error::msg(format!("cmp: {e}")))?;
1637            }
1638            (CudaStorage::F32(a), CudaStorage::F32(b)) => {
1639                let func = dev.get_func(&format!("cmp_{op_name}_f32"))?;
1640                unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
1641                    .map_err(|e| Error::msg(format!("cmp: {e}")))?;
1642            }
1643            (CudaStorage::F64(a), CudaStorage::F64(b)) => {
1644                let func = dev.get_func(&format!("cmp_{op_name}_f64"))?;
1645                unsafe { func.launch(cfg, (a, b, &mut out, n as u32)) }
1646                    .map_err(|e| Error::msg(format!("cmp: {e}")))?;
1647            }
1648            _ => return Err(Error::msg("cmp_op: dtype mismatch or unsupported")),
1649        }
1650
1651        Ok(CudaStorage::U8(out))
1652    }
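
    // The cmp kernels are assumed to write 1 for true and 0 for false, so the
    // U8 result can be fed straight into `where_cond` below as a mask:
    //
    //   let mask = CudaBackend::cmp_op(CmpOp::Gt, &a, &la, &b, &lb)?;  // a > b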
1653
1654    // ---- Affine ----
1655
1656    fn affine(input: &CudaStorage, layout: &Layout, mul: f64, add: f64) -> Result<CudaStorage> {
1657        let dev = dev_from_storage(input)?;
1658
1659        let input_c = ensure_contiguous(input, layout, &dev)?;
1660        let n = layout.elem_count();
1661        let cfg = launch_cfg(n);
1662
1663        match &input_c {
1664            CudaStorage::F16(inp) => {
1665                let mut out: CudaSlice<u16> = dev
1666                    .dev
1667                    .alloc_zeros(n)
1668                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1669                let func = dev.get_func("affine_f16")?;
1670                unsafe { func.launch(cfg, (inp, &mut out, mul as f32, add as f32, n as u32)) }
1671                    .map_err(|e| Error::msg(format!("affine: {e}")))?;
1672                Ok(CudaStorage::F16(out))
1673            }
1674            CudaStorage::BF16(inp) => {
1675                let mut out: CudaSlice<u16> = dev
1676                    .dev
1677                    .alloc_zeros(n)
1678                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1679                let func = dev.get_func("affine_bf16")?;
1680                unsafe { func.launch(cfg, (inp, &mut out, mul as f32, add as f32, n as u32)) }
1681                    .map_err(|e| Error::msg(format!("affine: {e}")))?;
1682                Ok(CudaStorage::BF16(out))
1683            }
1684            CudaStorage::F32(inp) => {
1685                let mut out: CudaSlice<f32> = dev
1686                    .dev
1687                    .alloc_zeros(n)
1688                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1689                let func = dev.get_func("affine_f32")?;
1690                unsafe { func.launch(cfg, (inp, &mut out, mul as f32, add as f32, n as u32)) }
1691                    .map_err(|e| Error::msg(format!("affine: {e}")))?;
1692                Ok(CudaStorage::F32(out))
1693            }
1694            CudaStorage::F64(inp) => {
1695                let mut out: CudaSlice<f64> = dev
1696                    .dev
1697                    .alloc_zeros(n)
1698                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1699                let func = dev.get_func("affine_f64")?;
1700                unsafe { func.launch(cfg, (inp, &mut out, mul, add, n as u32)) }
1701                    .map_err(|e| Error::msg(format!("affine: {e}")))?;
1702                Ok(CudaStorage::F64(out))
1703            }
1704            _ => Err(Error::msg("affine: only float types supported")),
1705        }
1706    }
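
    // `affine` computes out[i] = mul * in[i] + add element-wise (note that the
    // F16/BF16/F32 paths apply the scalars at f32 precision). For example,
    // rescaling values from [0, 255] into [0, 1]:
    //
    //   let scaled = CudaBackend::affine(&storage, &layout, 1.0 / 255.0, 0.0)?;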
1707
1708    // ---- Index select ----
1709
1710    fn index_select(
1711        input: &CudaStorage,
1712        input_layout: &Layout,
1713        indices: &CudaStorage,
1714        indices_layout: &Layout,
1715        dim: usize,
1716    ) -> Result<CudaStorage> {
1717        let dev = dev_from_storage(input)?;
1718
1719        let input_c = ensure_contiguous(input, input_layout, &dev)?;
1720        let indices_c = ensure_contiguous(indices, indices_layout, &dev)?;
1721
1722        let input_dims = input_layout.dims();
1723
1724        let pre_dim: usize = input_dims[..dim].iter().product::<usize>().max(1);
1725        let src_dim = input_dims[dim];
1726        let post_dim: usize = input_dims[dim + 1..].iter().product::<usize>().max(1);
1727        let idx_len = indices_layout.elem_count();
1728        let out_n = pre_dim * idx_len * post_dim;
1729        let cfg = launch_cfg(out_n);
1730
1731        // The index_select kernel expects i64 indices on the device; widen U32 via a host round-trip.
1732        let idx_i64 = match &indices_c {
1733            CudaStorage::I64(s) => s.clone(),
1734            CudaStorage::U32(s) => {
1735                let host = dev
1736                    .dev
1737                    .dtoh_sync_copy(s)
1738                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
1739                let host_i64: Vec<i64> = host.iter().map(|&v| v as i64).collect();
1740                dev.dev
1741                    .htod_copy(host_i64)
1742                    .map_err(|e| Error::msg(format!("htod: {e}")))?
1743            }
1744            _ => return Err(Error::msg("index_select: indices must be integer type")),
1745        };
1746
1747        match &input_c {
1748            CudaStorage::F16(inp) => {
1749                let mut out: CudaSlice<u16> = dev
1750                    .dev
1751                    .alloc_zeros(out_n)
1752                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1753                let func = dev.get_func("index_select_f16")?;
1754                unsafe {
1755                    func.launch(
1756                        cfg,
1757                        (
1758                            inp,
1759                            &idx_i64,
1760                            &mut out,
1761                            pre_dim as u32,
1762                            src_dim as u32,
1763                            post_dim as u32,
1764                            idx_len as u32,
1765                            out_n as u32,
1766                        ),
1767                    )
1768                }
1769                .map_err(|e| Error::msg(format!("index_select: {e}")))?;
1770                Ok(CudaStorage::F16(out))
1771            }
1772            CudaStorage::BF16(inp) => {
1773                let mut out: CudaSlice<u16> = dev
1774                    .dev
1775                    .alloc_zeros(out_n)
1776                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1777                let func = dev.get_func("index_select_bf16")?;
1778                unsafe {
1779                    func.launch(
1780                        cfg,
1781                        (
1782                            inp,
1783                            &idx_i64,
1784                            &mut out,
1785                            pre_dim as u32,
1786                            src_dim as u32,
1787                            post_dim as u32,
1788                            idx_len as u32,
1789                            out_n as u32,
1790                        ),
1791                    )
1792                }
1793                .map_err(|e| Error::msg(format!("index_select: {e}")))?;
1794                Ok(CudaStorage::BF16(out))
1795            }
1796            CudaStorage::F32(inp) => {
1797                let mut out: CudaSlice<f32> = dev
1798                    .dev
1799                    .alloc_zeros(out_n)
1800                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1801                let func = dev.get_func("index_select_f32")?;
1802                unsafe {
1803                    func.launch(
1804                        cfg,
1805                        (
1806                            inp,
1807                            &idx_i64,
1808                            &mut out,
1809                            pre_dim as u32,
1810                            src_dim as u32,
1811                            post_dim as u32,
1812                            idx_len as u32,
1813                            out_n as u32,
1814                        ),
1815                    )
1816                }
1817                .map_err(|e| Error::msg(format!("index_select: {e}")))?;
1818                Ok(CudaStorage::F32(out))
1819            }
1820            CudaStorage::F64(inp) => {
1821                let mut out: CudaSlice<f64> = dev
1822                    .dev
1823                    .alloc_zeros(out_n)
1824                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1825                let func = dev.get_func("index_select_f64")?;
1826                unsafe {
1827                    func.launch(
1828                        cfg,
1829                        (
1830                            inp,
1831                            &idx_i64,
1832                            &mut out,
1833                            pre_dim as u32,
1834                            src_dim as u32,
1835                            post_dim as u32,
1836                            idx_len as u32,
1837                            out_n as u32,
1838                        ),
1839                    )
1840                }
1841                .map_err(|e| Error::msg(format!("index_select: {e}")))?;
1842                Ok(CudaStorage::F64(out))
1843            }
1844            _ => Err(Error::msg("index_select: only float types supported")),
1845        }
1846    }
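
    // Worked example of the pre/src/post decomposition above: for input shape
    // [2, 5, 3] and dim = 1, pre_dim = 2, src_dim = 5, post_dim = 3. Selecting
    // indices [4, 0] (idx_len = 2) gives out_n = 2 * 2 * 3 = 12, i.e. shape
    // [2, 2, 3], with out[a, j, c] = in[a, idx[j], c].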
1847
1848    // ---- Powf ----
1849
1850    fn powf(input: &CudaStorage, layout: &Layout, exponent: f64) -> Result<CudaStorage> {
1851        let dev = dev_from_storage(input)?;
1852
1853        let input_c = ensure_contiguous(input, layout, &dev)?;
1854        let n = layout.elem_count();
1855        let cfg = launch_cfg(n);
1856
1857        match &input_c {
1858            CudaStorage::F16(inp) => {
1859                let mut out: CudaSlice<u16> = dev
1860                    .dev
1861                    .alloc_zeros(n)
1862                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1863                let func = dev.get_func("powf_f16")?;
1864                unsafe { func.launch(cfg, (inp, &mut out, exponent as f32, n as u32)) }
1865                    .map_err(|e| Error::msg(format!("powf: {e}")))?;
1866                Ok(CudaStorage::F16(out))
1867            }
1868            CudaStorage::BF16(inp) => {
1869                let mut out: CudaSlice<u16> = dev
1870                    .dev
1871                    .alloc_zeros(n)
1872                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1873                let func = dev.get_func("powf_bf16")?;
1874                unsafe { func.launch(cfg, (inp, &mut out, exponent as f32, n as u32)) }
1875                    .map_err(|e| Error::msg(format!("powf: {e}")))?;
1876                Ok(CudaStorage::BF16(out))
1877            }
1878            CudaStorage::F32(inp) => {
1879                let mut out: CudaSlice<f32> = dev
1880                    .dev
1881                    .alloc_zeros(n)
1882                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1883                let func = dev.get_func("powf_f32")?;
1884                unsafe { func.launch(cfg, (inp, &mut out, exponent as f32, n as u32)) }
1885                    .map_err(|e| Error::msg(format!("powf: {e}")))?;
1886                Ok(CudaStorage::F32(out))
1887            }
1888            CudaStorage::F64(inp) => {
1889                let mut out: CudaSlice<f64> = dev
1890                    .dev
1891                    .alloc_zeros(n)
1892                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1893                let func = dev.get_func("powf_f64")?;
1894                unsafe { func.launch(cfg, (inp, &mut out, exponent, n as u32)) }
1895                    .map_err(|e| Error::msg(format!("powf: {e}")))?;
1896                Ok(CudaStorage::F64(out))
1897            }
1898            _ => Err(Error::msg("powf: only float types supported")),
1899        }
1900    }
1901
1902    // ---- Clamp ----
1903
1904    fn clamp(input: &CudaStorage, layout: &Layout, min: f64, max: f64) -> Result<CudaStorage> {
1905        let dev = dev_from_storage(input)?;
1906
1907        let input_c = ensure_contiguous(input, layout, &dev)?;
1908        let n = layout.elem_count();
1909        let cfg = launch_cfg(n);
1910
1911        match &input_c {
1912            CudaStorage::F16(inp) => {
1913                let mut out: CudaSlice<u16> = dev
1914                    .dev
1915                    .alloc_zeros(n)
1916                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1917                let func = dev.get_func("clamp_f16")?;
1918                unsafe { func.launch(cfg, (inp, &mut out, min as f32, max as f32, n as u32)) }
1919                    .map_err(|e| Error::msg(format!("clamp: {e}")))?;
1920                Ok(CudaStorage::F16(out))
1921            }
1922            CudaStorage::BF16(inp) => {
1923                let mut out: CudaSlice<u16> = dev
1924                    .dev
1925                    .alloc_zeros(n)
1926                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1927                let func = dev.get_func("clamp_bf16")?;
1928                unsafe { func.launch(cfg, (inp, &mut out, min as f32, max as f32, n as u32)) }
1929                    .map_err(|e| Error::msg(format!("clamp: {e}")))?;
1930                Ok(CudaStorage::BF16(out))
1931            }
1932            CudaStorage::F32(inp) => {
1933                let mut out: CudaSlice<f32> = dev
1934                    .dev
1935                    .alloc_zeros(n)
1936                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1937                let func = dev.get_func("clamp_f32")?;
1938                unsafe { func.launch(cfg, (inp, &mut out, min as f32, max as f32, n as u32)) }
1939                    .map_err(|e| Error::msg(format!("clamp: {e}")))?;
1940                Ok(CudaStorage::F32(out))
1941            }
1942            CudaStorage::F64(inp) => {
1943                let mut out: CudaSlice<f64> = dev
1944                    .dev
1945                    .alloc_zeros(n)
1946                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1947                let func = dev.get_func("clamp_f64")?;
1948                unsafe { func.launch(cfg, (inp, &mut out, min, max, n as u32)) }
1949                    .map_err(|e| Error::msg(format!("clamp: {e}")))?;
1950                Ok(CudaStorage::F64(out))
1951            }
1952            _ => Err(Error::msg("clamp: only float types supported")),
1953        }
1954    }
1955
1956    // ---- Where / conditional select ----
1957
1958    fn where_cond(
1959        mask: &CudaStorage,
1960        mask_layout: &Layout,
1961        on_true: &CudaStorage,
1962        on_true_layout: &Layout,
1963        on_false: &CudaStorage,
1964        on_false_layout: &Layout,
1965    ) -> Result<CudaStorage> {
1966        let dev = dev_from_storage(mask)?;
1967
1968        let mask_c = ensure_contiguous(mask, mask_layout, &dev)?;
1969        let true_c = ensure_contiguous(on_true, on_true_layout, &dev)?;
1970        let false_c = ensure_contiguous(on_false, on_false_layout, &dev)?;
1971        let n = mask_layout.elem_count();
1972        let cfg = launch_cfg(n);
1973
1974        let mask_u8 = match &mask_c {
1975            CudaStorage::U8(s) => s,
1976            _ => return Err(Error::msg("where_cond: mask must be u8")),
1977        };
1978
1979        match (&true_c, &false_c) {
1980            (CudaStorage::F16(t), CudaStorage::F16(f_vals)) => {
1981                let mut out: CudaSlice<u16> = dev
1982                    .dev
1983                    .alloc_zeros(n)
1984                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1985                let func = dev.get_func("where_cond_f16")?;
1986                unsafe { func.launch(cfg, (mask_u8, t, f_vals, &mut out, n as u32)) }
1987                    .map_err(|e| Error::msg(format!("where_cond: {e}")))?;
1988                Ok(CudaStorage::F16(out))
1989            }
1990            (CudaStorage::BF16(t), CudaStorage::BF16(f_vals)) => {
1991                let mut out: CudaSlice<u16> = dev
1992                    .dev
1993                    .alloc_zeros(n)
1994                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
1995                let func = dev.get_func("where_cond_bf16")?;
1996                unsafe { func.launch(cfg, (mask_u8, t, f_vals, &mut out, n as u32)) }
1997                    .map_err(|e| Error::msg(format!("where_cond: {e}")))?;
1998                Ok(CudaStorage::BF16(out))
1999            }
2000            (CudaStorage::F32(t), CudaStorage::F32(f_vals)) => {
2001                let mut out: CudaSlice<f32> = dev
2002                    .dev
2003                    .alloc_zeros(n)
2004                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
2005                let func = dev.get_func("where_cond_f32")?;
2006                unsafe { func.launch(cfg, (mask_u8, t, f_vals, &mut out, n as u32)) }
2007                    .map_err(|e| Error::msg(format!("where_cond: {e}")))?;
2008                Ok(CudaStorage::F32(out))
2009            }
2010            (CudaStorage::F64(t), CudaStorage::F64(f_vals)) => {
2011                let mut out: CudaSlice<f64> = dev
2012                    .dev
2013                    .alloc_zeros(n)
2014                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
2015                let func = dev.get_func("where_cond_f64")?;
2016                unsafe { func.launch(cfg, (mask_u8, t, f_vals, &mut out, n as u32)) }
2017                    .map_err(|e| Error::msg(format!("where_cond: {e}")))?;
2018                Ok(CudaStorage::F64(out))
2019            }
2020            _ => Err(Error::msg("where_cond: dtype mismatch")),
2021        }
2022    }
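
    // Sketch: a masked select built from `cmp_op` + `where_cond`, assuming all
    // three operands share one contiguous shape/layout `lx` (the kernels index
    // mask, on_true, and on_false with the same flat offset):
    //
    //   // relu(x) = where(x > 0, x, 0)
    //   let mask = CudaBackend::cmp_op(CmpOp::Gt, &x, &lx, &zeros, &lx)?;
    //   let y = CudaBackend::where_cond(&mask, &lx, &x, &lx, &zeros, &lx)?;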
2023
2024    // ---- Gather ----
2025
2026    fn gather(
2027        input: &CudaStorage,
2028        input_layout: &Layout,
2029        index: &CudaStorage,
2030        index_layout: &Layout,
2031        dim: usize,
2032    ) -> Result<CudaStorage> {
2033        let dev = dev_from_storage(input)?;
2034
2035        let input_c = ensure_contiguous(input, input_layout, &dev)?;
2036        let index_c = ensure_contiguous(index, index_layout, &dev)?;
2037
2038        let input_dims = input_layout.dims();
2039        let index_dims = index_layout.dims();
2040
2041        let pre: usize = input_dims[..dim].iter().product::<usize>().max(1);
2042        let inp_dim = input_dims[dim];
2043        let idx_dim = index_dims[dim];
2044        let post: usize = input_dims[dim + 1..].iter().product::<usize>().max(1);
2045        let n = index_layout.elem_count();
2046        let cfg = launch_cfg(n);
2047
2048        // Convert index to i64
2049        let idx_i64 = match &index_c {
2050            CudaStorage::I64(s) => s.clone(),
2051            CudaStorage::U32(s) => {
2052                let host = dev
2053                    .dev
2054                    .dtoh_sync_copy(s)
2055                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
2056                let host_i64: Vec<i64> = host.iter().map(|&v| v as i64).collect();
2057                dev.dev
2058                    .htod_copy(host_i64)
2059                    .map_err(|e| Error::msg(format!("htod: {e}")))?
2060            }
2061            CudaStorage::F32(s) => {
2062                let host = dev
2063                    .dev
2064                    .dtoh_sync_copy(s)
2065                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
2066                let host_i64: Vec<i64> = host.iter().map(|&v| v as i64).collect();
2067                dev.dev
2068                    .htod_copy(host_i64)
2069                    .map_err(|e| Error::msg(format!("htod: {e}")))?
2070            }
2071            CudaStorage::F64(s) => {
2072                let host = dev
2073                    .dev
2074                    .dtoh_sync_copy(s)
2075                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
2076                let host_i64: Vec<i64> = host.iter().map(|&v| v as i64).collect();
2077                dev.dev
2078                    .htod_copy(host_i64)
2079                    .map_err(|e| Error::msg(format!("htod: {e}")))?
2080            }
2081            _ => return Err(Error::msg("gather: unsupported index dtype")),
2082        };
2083
2084        match &input_c {
2085            CudaStorage::F16(inp) => {
2086                let mut out: CudaSlice<u16> = dev
2087                    .dev
2088                    .alloc_zeros(n)
2089                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
2090                let func = dev.get_func("gather_f16")?;
2091                unsafe {
2092                    func.launch(
2093                        cfg,
2094                        (
2095                            inp,
2096                            &idx_i64,
2097                            &mut out,
2098                            pre as u32,
2099                            inp_dim as u32,
2100                            idx_dim as u32,
2101                            post as u32,
2102                            n as u32,
2103                        ),
2104                    )
2105                }
2106                .map_err(|e| Error::msg(format!("gather: {e}")))?;
2107                Ok(CudaStorage::F16(out))
2108            }
2109            CudaStorage::BF16(inp) => {
2110                let mut out: CudaSlice<u16> = dev
2111                    .dev
2112                    .alloc_zeros(n)
2113                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
2114                let func = dev.get_func("gather_bf16")?;
2115                unsafe {
2116                    func.launch(
2117                        cfg,
2118                        (
2119                            inp,
2120                            &idx_i64,
2121                            &mut out,
2122                            pre as u32,
2123                            inp_dim as u32,
2124                            idx_dim as u32,
2125                            post as u32,
2126                            n as u32,
2127                        ),
2128                    )
2129                }
2130                .map_err(|e| Error::msg(format!("gather: {e}")))?;
2131                Ok(CudaStorage::BF16(out))
2132            }
2133            CudaStorage::F32(inp) => {
2134                let mut out: CudaSlice<f32> = dev
2135                    .dev
2136                    .alloc_zeros(n)
2137                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
2138                let func = dev.get_func("gather_f32")?;
2139                unsafe {
2140                    func.launch(
2141                        cfg,
2142                        (
2143                            inp,
2144                            &idx_i64,
2145                            &mut out,
2146                            pre as u32,
2147                            inp_dim as u32,
2148                            idx_dim as u32,
2149                            post as u32,
2150                            n as u32,
2151                        ),
2152                    )
2153                }
2154                .map_err(|e| Error::msg(format!("gather: {e}")))?;
2155                Ok(CudaStorage::F32(out))
2156            }
2157            CudaStorage::F64(inp) => {
2158                let mut out: CudaSlice<f64> = dev
2159                    .dev
2160                    .alloc_zeros(n)
2161                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
2162                let func = dev.get_func("gather_f64")?;
2163                unsafe {
2164                    func.launch(
2165                        cfg,
2166                        (
2167                            inp,
2168                            &idx_i64,
2169                            &mut out,
2170                            pre as u32,
2171                            inp_dim as u32,
2172                            idx_dim as u32,
2173                            post as u32,
2174                            n as u32,
2175                        ),
2176                    )
2177                }
2178                .map_err(|e| Error::msg(format!("gather: {e}")))?;
2179                Ok(CudaStorage::F64(out))
2180            }
2181            _ => Err(Error::msg("gather: only float types supported")),
2182        }
2183    }
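
    // Worked example: input shape [2, 5], index shape [2, 3], dim = 1 gives
    // pre = 2, inp_dim = 5, idx_dim = 3, post = 1, n = 6; each output element
    // is out[a, j] = in[a, idx[a, j]], i.e. a per-row lookup. Note that F32/F64
    // index storages are accepted and truncated toward zero by the `as i64`
    // conversion above.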
2184
2185    // ---- Concatenation ----
2186
2187    fn cat(
2188        inputs: &[(&CudaStorage, &Layout)],
2189        out_shape: &Shape,
2190        dim: usize,
2191    ) -> Result<CudaStorage> {
2192        if inputs.is_empty() {
2193            return Err(Error::msg("cat: empty input list"));
2194        }
2195
2196        let dev = dev_from_storage(inputs[0].0)?;
2197
2198        let out_dims = out_shape.dims();
2199        let out_n = out_shape.elem_count();
2200
2201        let outer: usize = out_dims[..dim].iter().product::<usize>().max(1);
2202        let total_dim = out_dims[dim];
2203        let inner: usize = out_dims[dim + 1..].iter().product::<usize>().max(1);
2204
2205        match inputs[0].0 {
2206            CudaStorage::F16(_) => {
2207                let mut out: CudaSlice<u16> = dev
2208                    .dev
2209                    .alloc_zeros(out_n)
2210                    .map_err(|e| Error::msg(format!("alloc cat: {e}")))?;
2211                let mut dim_offset = 0u32;
2212
2213                for &(storage, layout) in inputs {
2214                    let storage_c = ensure_contiguous(storage, layout, &dev)?;
2215                    let inp = match &storage_c {
2216                        CudaStorage::F16(s) => s,
2217                        _ => return Err(Error::msg("cat: dtype mismatch")),
2218                    };
2219                    let t_dims = layout.dims();
2220                    let this_dim = t_dims[dim];
2221                    let src_n = layout.elem_count();
2222                    let cfg = launch_cfg(src_n);
2223
2224                    let func = dev.get_func("cat_copy_f16")?;
2225                    unsafe {
2226                        func.launch(
2227                            cfg,
2228                            (
2229                                inp,
2230                                &mut out,
2231                                outer as u32,
2232                                this_dim as u32,
2233                                inner as u32,
2234                                total_dim as u32,
2235                                dim_offset,
2236                                src_n as u32,
2237                            ),
2238                        )
2239                    }
2240                    .map_err(|e| Error::msg(format!("cat_copy: {e}")))?;
2241
2242                    dim_offset += this_dim as u32;
2243                }
2244                Ok(CudaStorage::F16(out))
2245            }
2246            CudaStorage::BF16(_) => {
2247                let mut out: CudaSlice<u16> = dev
2248                    .dev
2249                    .alloc_zeros(out_n)
2250                    .map_err(|e| Error::msg(format!("alloc cat: {e}")))?;
2251                let mut dim_offset = 0u32;
2252
2253                for &(storage, layout) in inputs {
2254                    let storage_c = ensure_contiguous(storage, layout, &dev)?;
2255                    let inp = match &storage_c {
2256                        CudaStorage::BF16(s) => s,
2257                        _ => return Err(Error::msg("cat: dtype mismatch")),
2258                    };
2259                    let t_dims = layout.dims();
2260                    let this_dim = t_dims[dim];
2261                    let src_n = layout.elem_count();
2262                    let cfg = launch_cfg(src_n);
2263
2264                    let func = dev.get_func("cat_copy_bf16")?;
2265                    unsafe {
2266                        func.launch(
2267                            cfg,
2268                            (
2269                                inp,
2270                                &mut out,
2271                                outer as u32,
2272                                this_dim as u32,
2273                                inner as u32,
2274                                total_dim as u32,
2275                                dim_offset,
2276                                src_n as u32,
2277                            ),
2278                        )
2279                    }
2280                    .map_err(|e| Error::msg(format!("cat_copy: {e}")))?;
2281
2282                    dim_offset += this_dim as u32;
2283                }
2284                Ok(CudaStorage::BF16(out))
2285            }
2286            CudaStorage::F32(_) => {
2287                let mut out: CudaSlice<f32> = dev
2288                    .dev
2289                    .alloc_zeros(out_n)
2290                    .map_err(|e| Error::msg(format!("alloc cat: {e}")))?;
2291                let mut dim_offset = 0u32;
2292
2293                for &(storage, layout) in inputs {
2294                    let storage_c = ensure_contiguous(storage, layout, &dev)?;
2295                    let inp = match &storage_c {
2296                        CudaStorage::F32(s) => s,
2297                        _ => return Err(Error::msg("cat: dtype mismatch")),
2298                    };
2299                    let t_dims = layout.dims();
2300                    let this_dim = t_dims[dim];
2301                    let src_n = layout.elem_count();
2302                    let cfg = launch_cfg(src_n);
2303
2304                    let func = dev.get_func("cat_copy_f32")?;
2305                    unsafe {
2306                        func.launch(
2307                            cfg,
2308                            (
2309                                inp,
2310                                &mut out,
2311                                outer as u32,
2312                                this_dim as u32,
2313                                inner as u32,
2314                                total_dim as u32,
2315                                dim_offset,
2316                                src_n as u32,
2317                            ),
2318                        )
2319                    }
2320                    .map_err(|e| Error::msg(format!("cat_copy: {e}")))?;
2321
2322                    dim_offset += this_dim as u32;
2323                }
2324                Ok(CudaStorage::F32(out))
2325            }
2326            CudaStorage::F64(_) => {
2327                let mut out: CudaSlice<f64> = dev
2328                    .dev
2329                    .alloc_zeros(out_n)
2330                    .map_err(|e| Error::msg(format!("alloc cat: {e}")))?;
2331                let mut dim_offset = 0u32;
2332
2333                for &(storage, layout) in inputs {
2334                    let storage_c = ensure_contiguous(storage, layout, &dev)?;
2335                    let inp = match &storage_c {
2336                        CudaStorage::F64(s) => s,
2337                        _ => return Err(Error::msg("cat: dtype mismatch")),
2338                    };
2339                    let t_dims = layout.dims();
2340                    let this_dim = t_dims[dim];
2341                    let src_n = layout.elem_count();
2342                    let cfg = launch_cfg(src_n);
2343
2344                    let func = dev.get_func("cat_copy_f64")?;
2345                    unsafe {
2346                        func.launch(
2347                            cfg,
2348                            (
2349                                inp,
2350                                &mut out,
2351                                outer as u32,
2352                                this_dim as u32,
2353                                inner as u32,
2354                                total_dim as u32,
2355                                dim_offset,
2356                                src_n as u32,
2357                            ),
2358                        )
2359                    }
2360                    .map_err(|e| Error::msg(format!("cat_copy: {e}")))?;
2361
2362                    dim_offset += this_dim as u32;
2363                }
2364                Ok(CudaStorage::F64(out))
2365            }
2366            _ => Err(Error::msg("cat: only float types supported")),
2367        }
2368    }
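
    // Worked example of the outer/total/inner decomposition: concatenating
    // shapes [2, 3, 4] and [2, 5, 4] along dim = 1 gives out_dims = [2, 8, 4],
    // so outer = 2, total_dim = 8, inner = 4. Each input is copied into its
    // slot along the concat dimension at a running dim_offset (0, then 3).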
2369
2370    fn cast(
2371        input: &CudaStorage,
2372        layout: &Layout,
2373        dtype: DType,
2374        device: &CudaDevice,
2375    ) -> Result<CudaStorage> {
2376        let src_dtype = input.dtype();
2377        if src_dtype == dtype {
2378            return Ok(input.clone());
2379        }
2380
2381        // For F16↔F32 and BF16↔F32 we have dedicated CUDA kernels.
2382        let dev = dev_from_storage(input)?;
2383        let contig = ensure_contiguous(input, layout, &dev)?;
2384        let n = layout.shape().elem_count();
2385        let cfg = launch_cfg(n);
2386
2387        match (src_dtype, dtype) {
2388            (DType::F16, DType::F32) => {
2389                let src_slice = contig.as_cuda_slice_u16()?;
2390                let mut out: CudaSlice<f32> = dev
2391                    .dev
2392                    .alloc_zeros(n)
2393                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
2394                let func = dev.get_func("cast_f16_to_f32")?;
2395                unsafe { func.launch(cfg, (src_slice, &mut out, n as u32)) }
2396                    .map_err(|e| Error::msg(format!("launch cast_f16_to_f32: {e}")))?;
2397                Ok(CudaStorage::F32(out))
2398            }
2399            (DType::F32, DType::F16) => {
2400                let src_slice = contig.as_cuda_slice_f32()?;
2401                let mut out: CudaSlice<u16> = dev
2402                    .dev
2403                    .alloc_zeros(n)
2404                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
2405                let func = dev.get_func("cast_f32_to_f16")?;
2406                unsafe { func.launch(cfg, (src_slice, &mut out, n as u32)) }
2407                    .map_err(|e| Error::msg(format!("launch cast_f32_to_f16: {e}")))?;
2408                Ok(CudaStorage::F16(out))
2409            }
2410            (DType::BF16, DType::F32) => {
2411                let src_slice = contig.as_cuda_slice_u16()?;
2412                let mut out: CudaSlice<f32> = dev
2413                    .dev
2414                    .alloc_zeros(n)
2415                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
2416                let func = dev.get_func("cast_bf16_to_f32")?;
2417                unsafe { func.launch(cfg, (src_slice, &mut out, n as u32)) }
2418                    .map_err(|e| Error::msg(format!("launch cast_bf16_to_f32: {e}")))?;
2419                Ok(CudaStorage::F32(out))
2420            }
2421            (DType::F32, DType::BF16) => {
2422                let src_slice = contig.as_cuda_slice_f32()?;
2423                let mut out: CudaSlice<u16> = dev
2424                    .dev
2425                    .alloc_zeros(n)
2426                    .map_err(|e| Error::msg(format!("alloc: {e}")))?;
2427                let func = dev.get_func("cast_f32_to_bf16")?;
2428                unsafe { func.launch(cfg, (src_slice, &mut out, n as u32)) }
2429                    .map_err(|e| Error::msg(format!("launch cast_f32_to_bf16: {e}")))?;
2430                Ok(CudaStorage::BF16(out))
2431            }
2432            // F16↔F64, BF16↔F64, and F16↔BF16 route through an F32 intermediate;
2433            // integer and all other remaining casts use the host round-trip in the `_` arm below.
2434            (DType::F16, DType::F64) | (DType::BF16, DType::F64) => {
2435                let f32_storage = Self::cast(
2436                    &contig,
2437                    &Layout::contiguous(layout.shape().clone()),
2438                    DType::F32,
2439                    device,
2440                )?;
2441                let f32_layout = Layout::contiguous(layout.shape().clone());
2442                Self::cast(&f32_storage, &f32_layout, DType::F64, device)
2443            }
2444            (DType::F64, DType::F16) | (DType::F64, DType::BF16) => {
2445                let f32_storage = Self::cast(
2446                    &contig,
2447                    &Layout::contiguous(layout.shape().clone()),
2448                    DType::F32,
2449                    device,
2450                )?;
2451                let f32_layout = Layout::contiguous(layout.shape().clone());
2452                Self::cast(&f32_storage, &f32_layout, dtype, device)
2453            }
2454            (DType::F16, DType::BF16) | (DType::BF16, DType::F16) => {
2455                let f32_storage = Self::cast(
2456                    &contig,
2457                    &Layout::contiguous(layout.shape().clone()),
2458                    DType::F32,
2459                    device,
2460                )?;
2461                let f32_layout = Layout::contiguous(layout.shape().clone());
2462                Self::cast(&f32_storage, &f32_layout, dtype, device)
2463            }
2464            _ => {
2465                // Fallback: host round-trip for integer ↔ float, F32↔F64, etc.
2466                let data = Self::to_f64_vec(&contig, &Layout::contiguous(layout.shape().clone()))?;
2467                Self::from_f64_slice(&data, dtype, device)
2468            }
2469        }
2470    }
2471}
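
// Precision note on the two-hop casts above: F16 → F32 and BF16 → F32 are
// exact, so e.g. F16 → F64 loses nothing, while F16 ↔ BF16 rounds exactly
// once, at the final F32 → target step (the same rounding a direct
// conversion would perform). A hedged usage sketch:
//
//   let wide = CudaBackend::cast(&f16_storage, &layout, DType::F64, &device)?;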
2472
2473// Host ↔ Device transfer helpers
2474
2475impl CudaStorage {
2476    /// Get the underlying CudaSlice<f32> (returns error if dtype doesn't match).
2477    pub fn as_cuda_slice_f32(&self) -> Result<&CudaSlice<f32>> {
2478        match self {
2479            CudaStorage::F32(s) => Ok(s),
2480            _ => Err(Error::msg(format!(
2481                "expected F32 storage, got {:?}",
2482                self.dtype()
2483            ))),
2484        }
2485    }
2486
2487    /// Get the underlying CudaSlice<u16> for F16 or BF16 storage.
2488    pub fn as_cuda_slice_u16(&self) -> Result<&CudaSlice<u16>> {
2489        match self {
2490            CudaStorage::F16(s) | CudaStorage::BF16(s) => Ok(s),
2491            _ => Err(Error::msg(format!(
2492                "expected F16/BF16 storage, got {:?}",
2493                self.dtype()
2494            ))),
2495        }
2496    }
2497
2498    /// Transfer data from host Vec<f32> to a new CudaStorage on the given device.
2499    pub fn from_f32_vec(data: Vec<f32>, device: &CudaDevice) -> Result<Self> {
2500        let s = device
2501            .dev
2502            .htod_copy(data)
2503            .map_err(|e| Error::msg(format!("htod f32: {e}")))?;
2504        Ok(CudaStorage::F32(s))
2505    }
2506
2507    /// Transfer data from host Vec<f64> to a new CudaStorage on the given device.
2508    pub fn from_f64_vec(data: Vec<f64>, device: &CudaDevice) -> Result<Self> {
2509        let s = device
2510            .dev
2511            .htod_copy(data)
2512            .map_err(|e| Error::msg(format!("htod f64: {e}")))?;
2513        Ok(CudaStorage::F64(s))
2514    }
2515
2516    /// Transfer data from host Vec<f16> to a new F16 CudaStorage on the given device.
2517    pub fn from_f16_vec(data: Vec<f16>, device: &CudaDevice) -> Result<Self> {
2518        let bits: Vec<u16> = data.iter().map(|v| v.to_bits()).collect();
2519        let s = device
2520            .dev
2521            .htod_copy(bits)
2522            .map_err(|e| Error::msg(format!("htod f16: {e}")))?;
2523        Ok(CudaStorage::F16(s))
2524    }
2525
2526    /// Transfer data from host Vec<bf16> to a new BF16 CudaStorage on the given device.
2527    pub fn from_bf16_vec(data: Vec<bf16>, device: &CudaDevice) -> Result<Self> {
2528        let bits: Vec<u16> = data.iter().map(|v| v.to_bits()).collect();
2529        let s = device
2530            .dev
2531            .htod_copy(bits)
2532            .map_err(|e| Error::msg(format!("htod bf16: {e}")))?;
2533        Ok(CudaStorage::BF16(s))
2534    }
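
    // F16/BF16 storages hold raw u16 bit patterns on the device, so upload and
    // download are pure bit-casts on the host side. For example (1.5 is exactly
    // representable in f16):
    //
    //   let storage = CudaStorage::from_f16_vec(vec![f16::from_f32(1.5)], &device)?;
    //   assert_eq!(storage.to_host_f64(&device)?, vec![1.5]);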
2535
2536    /// Copy all data to host as Vec<f64>.
2537    pub fn to_host_f64(&self, device: &CudaDevice) -> Result<Vec<f64>> {
2538        match self {
2539            CudaStorage::F16(s) => {
2540                let host = device
2541                    .dev
2542                    .dtoh_sync_copy(s)
2543                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
2544                Ok(host
2545                    .iter()
2546                    .map(|&bits| f16::from_bits(bits).to_f64())
2547                    .collect())
2548            }
2549            CudaStorage::BF16(s) => {
2550                let host = device
2551                    .dev
2552                    .dtoh_sync_copy(s)
2553                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
2554                Ok(host
2555                    .iter()
2556                    .map(|&bits| bf16::from_bits(bits).to_f64())
2557                    .collect())
2558            }
2559            CudaStorage::F32(s) => {
2560                let host = device
2561                    .dev
2562                    .dtoh_sync_copy(s)
2563                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
2564                Ok(host.iter().map(|&v| v as f64).collect())
2565            }
2566            CudaStorage::F64(s) => device
2567                .dev
2568                .dtoh_sync_copy(s)
2569                .map_err(|e| Error::msg(format!("dtoh: {e}"))),
2570            CudaStorage::U8(s) => {
2571                let host = device
2572                    .dev
2573                    .dtoh_sync_copy(s)
2574                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
2575                Ok(host.iter().map(|&v| v as f64).collect())
2576            }
2577            CudaStorage::U32(s) => {
2578                let host = device
2579                    .dev
2580                    .dtoh_sync_copy(s)
2581                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
2582                Ok(host.iter().map(|&v| v as f64).collect())
2583            }
2584            CudaStorage::I64(s) => {
2585                let host = device
2586                    .dev
2587                    .dtoh_sync_copy(s)
2588                    .map_err(|e| Error::msg(format!("dtoh: {e}")))?;
2589                Ok(host.iter().map(|&v| v as f64).collect())
2590            }
2591        }
2592    }
2593}
2594
2595/// Convenience type alias for CUDA tensors.
2596pub type CudaTensor = shrew_core::Tensor<CudaBackend>;