shrew_cuda/
pool.rs

// CUDA Memory Pool: caching allocator for GPU buffer reuse
//
// Avoids repeated cudaMalloc/cudaFree round-trips by maintaining per-type,
// per-size free lists of previously allocated CudaSlice buffers.
//
// When a buffer is "returned" to the pool it is not freed to the CUDA driver;
// instead it is cached. Future allocations of the same element type and count
// will reuse these cached buffers, eliminating the allocation overhead.
//
// This is conceptually similar to PyTorch's CUDA caching allocator.
//
// Usage (through CudaDevice helpers):
//
//   let buf: CudaSlice<f32> = device.pool_alloc::<f32>(1024)?;  // from pool
//   device.pool_reclaim_f32(buf);                               // return
//   let stats = device.pool_stats();                            // query
//   device.empty_cache();                                       // release
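//
// The pool can also be used directly; a minimal sketch (error handling elided,
// `dev` here is an `Arc<cudarc::driver::CudaDevice>` owned by the caller):
//
//   let pool = CudaMemPool::new();
//   let a = pool.alloc_f32(&dev, 1024)?;   // cache miss -> cudaMalloc
//   pool.reclaim_f32(a);                   // cached, not freed
//   let b = pool.alloc_f32(&dev, 1024)?;   // cache hit -> reused buffer
//   assert_eq!(pool.stats().hits, 1);
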
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;

use cudarc::driver::{CudaSlice, DeviceSlice};

// Pool statistics

/// Snapshot of the pool's allocation statistics.
#[derive(Debug, Clone, Copy)]
pub struct PoolStats {
    /// Total bytes currently held in the cache (not in use by tensors).
    pub cached_bytes: usize,
    /// Number of individual buffers currently in the cache.
    pub cached_buffers: usize,
    /// Cumulative cache hits (allocations served from the cache).
    pub hits: u64,
    /// Cumulative cache misses (allocations that fell through to cudaMalloc).
    pub misses: u64,
}

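// For example, a hit rate can be derived from a `PoolStats` snapshot
// (illustrative only; `pool` is a `CudaMemPool`):
//
//   let s = pool.stats();
//   let total = s.hits + s.misses;
//   let hit_rate = if total > 0 { s.hits as f64 / total as f64 } else { 0.0 };
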
// Typed free-list bucket

/// A per-type free-list: maps element count → stack of free CudaSlice<T>.
struct TypedPool<T> {
    buckets: Mutex<HashMap<usize, Vec<CudaSlice<T>>>>,
}

impl<T> TypedPool<T> {
    fn new() -> Self {
        TypedPool {
            buckets: Mutex::new(HashMap::new()),
        }
    }

    /// Try to pop a cached buffer of exactly `n` elements.
    fn try_pop(&self, n: usize) -> Option<CudaSlice<T>> {
        let mut map = self.buckets.lock().unwrap();
        if let Some(stack) = map.get_mut(&n) {
            stack.pop()
        } else {
            None
        }
    }

    /// Push a buffer back into the cache.
    fn push(&self, slice: CudaSlice<T>)
    where
        CudaSlice<T>: DeviceSlice<T>,
    {
        let n = slice.len();
        let mut map = self.buckets.lock().unwrap();
        map.entry(n).or_default().push(slice);
    }

    /// Drain all cached buffers, returning the count and total elements freed.
    fn drain(&self) -> (usize, usize) {
        let mut map = self.buckets.lock().unwrap();
        let mut count = 0usize;
        let mut elems = 0usize;
        for (n, stack) in map.drain() {
            count += stack.len();
            elems += n * stack.len();
        }
        (count, elems)
    }

    /// Count of cached buffers and total cached elements.
    fn stats(&self) -> (usize, usize) {
        let map = self.buckets.lock().unwrap();
        let mut count = 0usize;
        let mut elems = 0usize;
        for (n, stack) in map.iter() {
            count += stack.len();
            elems += *n * stack.len();
        }
        (count, elems)
    }
}

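// Note that buckets are keyed by the exact element count: a cached 1024-element
// buffer is never reused for a 1025-element request. A sketch of the consequence
// (hypothetical sizes):
//
//   pool.reclaim_f32(buf_1024);            // cached under key 1024
//   let c = pool.alloc_f32(&dev, 1025)?;   // miss: different bucket, new cudaMalloc
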
// CudaMemPool

/// A CUDA memory caching allocator.
///
/// Maintains per-dtype free lists keyed by element count. Reuses buffers
/// when possible, falling back to `cudaMalloc` on cache miss.
pub struct CudaMemPool {
    pool_u8: TypedPool<u8>,
    pool_u16: TypedPool<u16>,
    pool_u32: TypedPool<u32>,
    pool_f32: TypedPool<f32>,
    pool_f64: TypedPool<f64>,
    pool_i64: TypedPool<i64>,

    // Atomic counters: no lock contention on the hot path
    hits: AtomicU64,
    misses: AtomicU64,
}

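// The pool is `Send + Sync` (see the unsafe impls at the bottom of this file),
// so a single instance can be shared across threads, for example behind an
// `Arc` (a usage sketch, not a requirement of the API):
//
//   let pool = std::sync::Arc::new(CudaMemPool::new());
//   let p2 = std::sync::Arc::clone(&pool);   // hand to another thread
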
impl CudaMemPool {
    /// Create a new empty memory pool.
    pub fn new() -> Self {
        CudaMemPool {
            pool_u8: TypedPool::new(),
            pool_u16: TypedPool::new(),
            pool_u32: TypedPool::new(),
            pool_f32: TypedPool::new(),
            pool_f64: TypedPool::new(),
            pool_i64: TypedPool::new(),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    // Allocation helpers (per type)

    /// Allocate `n` elements of type `f32`, reusing a cached buffer if available.
    /// The returned buffer content is **undefined** (not zeroed).
    pub fn alloc_f32(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<f32>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_f32.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<f32>(n) }
        }
    }

    /// Allocate `n` elements of `f32` and zero them.
    pub fn alloc_zeros_f32(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<f32>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_f32.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<f32>(n)
        }
    }

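    // On a cache hit, the `alloc_zeros_*` helpers re-zero the recycled buffer with
    // `memset_zeros`, so callers get the same guarantee on both paths. For example
    // (a sketch; `dirty` is any previously used f32 buffer of length `dirty_len`):
    //
    //   pool.reclaim_f32(dirty);
    //   let z = pool.alloc_zeros_f32(&dev, dirty_len)?;  // contents are all zero
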
    /// Allocate `n` elements of type `f64`.
    /// The returned buffer content is **undefined** (not zeroed).
    pub fn alloc_f64(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<f64>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_f64.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<f64>(n) }
        }
    }

    /// Allocate `n` elements of `f64` and zero them.
    pub fn alloc_zeros_f64(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<f64>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_f64.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<f64>(n)
        }
    }

    /// Allocate `n` elements of type `u16` (used for F16/BF16 storage).
    /// The returned buffer content is **undefined** (not zeroed).
    pub fn alloc_u16(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u16>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_u16.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<u16>(n) }
        }
    }

    /// Allocate `n` elements of `u16` and zero them.
    pub fn alloc_zeros_u16(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u16>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_u16.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<u16>(n)
        }
    }

    /// Allocate `n` elements of type `u8`.
    /// The returned buffer content is **undefined** (not zeroed).
    pub fn alloc_u8(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u8>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_u8.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<u8>(n) }
        }
    }

    /// Allocate `n` elements of `u8` and zero them.
    pub fn alloc_zeros_u8(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u8>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_u8.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<u8>(n)
        }
    }

    /// Allocate `n` elements of type `u32`.
    /// The returned buffer content is **undefined** (not zeroed).
    pub fn alloc_u32(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u32>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_u32.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<u32>(n) }
        }
    }

    /// Allocate `n` elements of `u32` and zero them.
    pub fn alloc_zeros_u32(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u32>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_u32.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<u32>(n)
        }
    }

    /// Allocate `n` elements of type `i64`.
    /// The returned buffer content is **undefined** (not zeroed).
    pub fn alloc_i64(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<i64>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_i64.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<i64>(n) }
        }
    }

    /// Allocate `n` elements of `i64` and zero them.
    pub fn alloc_zeros_i64(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<i64>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_i64.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<i64>(n)
        }
    }

    // Reclaim (return buffer to pool)

    /// Return an `f32` buffer to the pool for later reuse.
    pub fn reclaim_f32(&self, s: CudaSlice<f32>) {
        self.pool_f32.push(s);
    }
    /// Return an `f64` buffer to the pool for later reuse.
    pub fn reclaim_f64(&self, s: CudaSlice<f64>) {
        self.pool_f64.push(s);
    }
    /// Return a `u16` buffer to the pool for later reuse.
    pub fn reclaim_u16(&self, s: CudaSlice<u16>) {
        self.pool_u16.push(s);
    }
    /// Return a `u8` buffer to the pool for later reuse.
    pub fn reclaim_u8(&self, s: CudaSlice<u8>) {
        self.pool_u8.push(s);
    }
    /// Return a `u32` buffer to the pool for later reuse.
    pub fn reclaim_u32(&self, s: CudaSlice<u32>) {
        self.pool_u32.push(s);
    }
    /// Return an `i64` buffer to the pool for later reuse.
    pub fn reclaim_i64(&self, s: CudaSlice<i64>) {
        self.pool_i64.push(s);
    }

    /// Reclaim the buffer backing a `CudaStorage`, returning it to the pool.
    pub fn reclaim_storage(&self, storage: super::CudaStorage) {
        match storage {
            super::CudaStorage::F16(s) => self.pool_u16.push(s),
            super::CudaStorage::BF16(s) => self.pool_u16.push(s),
            super::CudaStorage::F32(s) => self.pool_f32.push(s),
            super::CudaStorage::F64(s) => self.pool_f64.push(s),
            super::CudaStorage::U8(s) => self.pool_u8.push(s),
            super::CudaStorage::U32(s) => self.pool_u32.push(s),
            super::CudaStorage::I64(s) => self.pool_i64.push(s),
        }
    }

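    // Since F16 and BF16 are both stored as raw `u16` words, their buffers land in
    // the same `u16` pool: reclaiming an F16 tensor's storage can later satisfy a
    // BF16 (or plain u16) allocation of the same element count. Sketch:
    //
    //   pool.reclaim_storage(f16_storage);   // cached in the u16 pool
    //   let b = pool.alloc_u16(&dev, n)?;    // may reuse that buffer
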
    // Cache management

    /// Release all cached buffers back to the CUDA driver.
    /// This actually frees GPU memory.
    pub fn empty_cache(&self) {
        self.pool_u8.drain();
        self.pool_u16.drain();
        self.pool_u32.drain();
        self.pool_f32.drain();
        self.pool_f64.drain();
        self.pool_i64.drain();
    }

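    // A typical pattern is to call `empty_cache` after a memory-heavy phase so the
    // cached buffers are handed back to the driver instead of sitting idle
    // (a sketch, not a requirement of the API):
    //
    //   pool.reclaim_f32(big_scratch);   // cached, still occupying GPU memory
    //   pool.empty_cache();              // dropped and freed by the CUDA driver
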
    /// Return a snapshot of pool statistics.
    pub fn stats(&self) -> PoolStats {
        let (c_u8, e_u8) = self.pool_u8.stats();
        let (c_u16, e_u16) = self.pool_u16.stats();
        let (c_u32, e_u32) = self.pool_u32.stats();
        let (c_f32, e_f32) = self.pool_f32.stats();
        let (c_f64, e_f64) = self.pool_f64.stats();
        let (c_i64, e_i64) = self.pool_i64.stats();

        let cached_buffers = c_u8 + c_u16 + c_u32 + c_f32 + c_f64 + c_i64;
        let cached_bytes = e_u8 * std::mem::size_of::<u8>()
            + e_u16 * std::mem::size_of::<u16>()
            + e_u32 * std::mem::size_of::<u32>()
            + e_f32 * std::mem::size_of::<f32>()
            + e_f64 * std::mem::size_of::<f64>()
            + e_i64 * std::mem::size_of::<i64>();

        PoolStats {
            cached_bytes,
            cached_buffers,
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
        }
    }

    /// Reset the hit/miss counters.
    pub fn reset_stats(&self) {
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
    }
}

impl Default for CudaMemPool {
    fn default() -> Self {
        Self::new()
    }
}

// Safety: all interior mutability goes through `Mutex` and atomics; the cached
// `CudaSlice` buffers are only moved in and out of the pool while the
// corresponding bucket lock is held.
unsafe impl Send for CudaMemPool {}
unsafe impl Sync for CudaMemPool {}
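
// A minimal smoke-test sketch of the alloc/reclaim/stats lifecycle. Illustrative
// only: it assumes a CUDA-capable machine on which `CudaDevice::new(0)` succeeds,
// so it is marked `#[ignore]` and must be run explicitly.
#[cfg(test)]
mod tests {
    use super::*;
    use cudarc::driver::CudaDevice;

    #[test]
    #[ignore = "requires a CUDA device"]
    fn alloc_reclaim_roundtrip_hits_cache() -> Result<(), cudarc::driver::DriverError> {
        let dev = CudaDevice::new(0)?;
        let pool = CudaMemPool::new();

        // First allocation of this size misses and falls through to cudaMalloc.
        let buf = pool.alloc_f32(&dev, 1024)?;
        assert_eq!(pool.stats().misses, 1);

        // Returning the buffer caches it instead of freeing it.
        pool.reclaim_f32(buf);
        let s = pool.stats();
        assert_eq!(s.cached_buffers, 1);
        assert_eq!(s.cached_bytes, 1024 * std::mem::size_of::<f32>());

        // A same-sized allocation is now served from the cache.
        let _buf2 = pool.alloc_f32(&dev, 1024)?;
        assert_eq!(pool.stats().hits, 1);
        assert_eq!(pool.stats().cached_buffers, 0);

        // `empty_cache` drops any remaining cached buffers back to the driver.
        pool.empty_cache();
        assert_eq!(pool.stats().cached_buffers, 0);
        Ok(())
    }
}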