use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;

use cudarc::driver::{CudaSlice, DeviceSlice};

/// Snapshot of the pool's cached memory and allocation counters.
#[derive(Debug, Clone, Copy)]
pub struct PoolStats {
    /// Total bytes currently held in the pool's free lists.
    pub cached_bytes: usize,
    /// Number of buffers currently held in the pool's free lists.
    pub cached_buffers: usize,
    /// Allocations that were served from a cached buffer.
    pub hits: u64,
    /// Allocations that fell through to the CUDA driver.
    pub misses: u64,
}

/// Free lists of `CudaSlice<T>` buffers, bucketed by exact element count.
struct TypedPool<T> {
    buckets: Mutex<HashMap<usize, Vec<CudaSlice<T>>>>,
}

impl<T> TypedPool<T> {
    fn new() -> Self {
        TypedPool {
            buckets: Mutex::new(HashMap::new()),
        }
    }

    /// Pops a cached buffer of exactly `n` elements, if one is available.
    fn try_pop(&self, n: usize) -> Option<CudaSlice<T>> {
        let mut map = self.buckets.lock().unwrap();
        if let Some(stack) = map.get_mut(&n) {
            stack.pop()
        } else {
            None
        }
    }

    /// Returns a buffer to the bucket matching its length.
    fn push(&self, slice: CudaSlice<T>)
    where
        CudaSlice<T>: DeviceSlice<T>,
    {
        let n = slice.len();
        let mut map = self.buckets.lock().unwrap();
        map.entry(n).or_default().push(slice);
    }

    /// Drops every cached buffer, returning `(buffers, total elements)` released.
    fn drain(&self) -> (usize, usize) {
        let mut map = self.buckets.lock().unwrap();
        let mut count = 0usize;
        let mut elems = 0usize;
        for (n, stack) in map.drain() {
            count += stack.len();
            elems += n * stack.len();
        }
        (count, elems)
    }

    /// Returns `(buffers, total elements)` currently cached.
    fn stats(&self) -> (usize, usize) {
        let map = self.buckets.lock().unwrap();
        let mut count = 0usize;
        let mut elems = 0usize;
        for (n, stack) in map.iter() {
            count += stack.len();
            elems += *n * stack.len();
        }
        (count, elems)
    }
}
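
/// A pool of reusable CUDA device buffers with a size-bucketed free list per
/// supported element type, plus atomic hit/miss counters for pooled allocations.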
pub struct CudaMemPool {
    pool_u8: TypedPool<u8>,
    pool_u16: TypedPool<u16>,
    pool_u32: TypedPool<u32>,
    pool_f32: TypedPool<f32>,
    pool_f64: TypedPool<f64>,
    pool_i64: TypedPool<i64>,

    hits: AtomicU64,
    misses: AtomicU64,
}

impl CudaMemPool {
    pub fn new() -> Self {
        CudaMemPool {
            pool_u8: TypedPool::new(),
            pool_u16: TypedPool::new(),
            pool_u32: TypedPool::new(),
            pool_f32: TypedPool::new(),
            pool_f64: TypedPool::new(),
            pool_i64: TypedPool::new(),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    /// Allocates an `f32` buffer of `n` elements, reusing a cached buffer of
    /// the same length when available; reused buffers keep their old contents.
    pub fn alloc_f32(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<f32>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_f32.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<f32>(n) }
        }
    }

    /// Allocates a zeroed `f32` buffer of `n` elements; a reused cached buffer
    /// is re-zeroed before being returned.
    pub fn alloc_zeros_f32(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<f32>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_f32.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<f32>(n)
        }
    }

    pub fn alloc_f64(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<f64>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_f64.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<f64>(n) }
        }
    }

    pub fn alloc_zeros_f64(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<f64>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_f64.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<f64>(n)
        }
    }

    pub fn alloc_u16(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u16>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_u16.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<u16>(n) }
        }
    }

    pub fn alloc_zeros_u16(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u16>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_u16.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<u16>(n)
        }
    }

    pub fn alloc_u8(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u8>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_u8.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<u8>(n) }
        }
    }

    pub fn alloc_zeros_u8(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u8>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_u8.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<u8>(n)
        }
    }

    pub fn alloc_u32(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u32>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_u32.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<u32>(n) }
        }
    }

    pub fn alloc_zeros_u32(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<u32>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_u32.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<u32>(n)
        }
    }

    pub fn alloc_i64(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<i64>, cudarc::driver::DriverError> {
        if let Some(buf) = self.pool_i64.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            unsafe { dev.alloc::<i64>(n) }
        }
    }

    pub fn alloc_zeros_i64(
        &self,
        dev: &std::sync::Arc<cudarc::driver::CudaDevice>,
        n: usize,
    ) -> std::result::Result<CudaSlice<i64>, cudarc::driver::DriverError> {
        if let Some(mut buf) = self.pool_i64.try_pop(n) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            dev.memset_zeros(&mut buf)?;
            Ok(buf)
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            dev.alloc_zeros::<i64>(n)
        }
    }

    // Reclaim methods: return a finished buffer to the matching typed pool so
    // later allocations of the same length can reuse it.
    pub fn reclaim_f32(&self, s: CudaSlice<f32>) {
        self.pool_f32.push(s);
    }
    pub fn reclaim_f64(&self, s: CudaSlice<f64>) {
        self.pool_f64.push(s);
    }
    pub fn reclaim_u16(&self, s: CudaSlice<u16>) {
        self.pool_u16.push(s);
    }
    pub fn reclaim_u8(&self, s: CudaSlice<u8>) {
        self.pool_u8.push(s);
    }
    pub fn reclaim_u32(&self, s: CudaSlice<u32>) {
        self.pool_u32.push(s);
    }
    pub fn reclaim_i64(&self, s: CudaSlice<i64>) {
        self.pool_i64.push(s);
    }

    /// Returns the buffer behind a `CudaStorage` value to the matching typed
    /// pool; both half-precision variants (`F16`, `BF16`) are cached in the
    /// `u16` pool.
    pub fn reclaim_storage(&self, storage: super::CudaStorage) {
        match storage {
            super::CudaStorage::F16(s) => self.pool_u16.push(s),
            super::CudaStorage::BF16(s) => self.pool_u16.push(s),
            super::CudaStorage::F32(s) => self.pool_f32.push(s),
            super::CudaStorage::F64(s) => self.pool_f64.push(s),
            super::CudaStorage::U8(s) => self.pool_u8.push(s),
            super::CudaStorage::U32(s) => self.pool_u32.push(s),
            super::CudaStorage::I64(s) => self.pool_i64.push(s),
        }
    }

    /// Drops all cached buffers, releasing their device memory.
    pub fn empty_cache(&self) {
        self.pool_u8.drain();
        self.pool_u16.drain();
        self.pool_u32.drain();
        self.pool_f32.drain();
        self.pool_f64.drain();
        self.pool_i64.drain();
    }

    /// Aggregates cached buffer and byte counts across all typed pools,
    /// together with the cumulative hit/miss counters.
    pub fn stats(&self) -> PoolStats {
        let (c_u8, e_u8) = self.pool_u8.stats();
        let (c_u16, e_u16) = self.pool_u16.stats();
        let (c_u32, e_u32) = self.pool_u32.stats();
        let (c_f32, e_f32) = self.pool_f32.stats();
        let (c_f64, e_f64) = self.pool_f64.stats();
        let (c_i64, e_i64) = self.pool_i64.stats();

        let cached_buffers = c_u8 + c_u16 + c_u32 + c_f32 + c_f64 + c_i64;
        let cached_bytes = e_u8 * std::mem::size_of::<u8>()
            + e_u16 * std::mem::size_of::<u16>()
            + e_u32 * std::mem::size_of::<u32>()
            + e_f32 * std::mem::size_of::<f32>()
            + e_f64 * std::mem::size_of::<f64>()
            + e_i64 * std::mem::size_of::<i64>();

        PoolStats {
            cached_bytes,
            cached_buffers,
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
        }
    }

    /// Resets the hit/miss counters to zero.
    pub fn reset_stats(&self) {
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
    }
}

impl Default for CudaMemPool {
    fn default() -> Self {
        Self::new()
    }
}

// SAFETY: the pool's mutable state lives behind `Mutex`-guarded maps and
// atomic counters; cached `CudaSlice` handles are only moved in and out
// while holding the corresponding lock.
unsafe impl Send for CudaMemPool {}
unsafe impl Sync for CudaMemPool {}
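
// Minimal usage sketch: it assumes a CUDA-capable device 0 and `cudarc`'s
// `CudaDevice::new` constructor, and is marked `#[ignore]` so it only runs
// when a GPU is explicitly available.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "requires a CUDA device"]
    fn reuse_round_trip() -> Result<(), cudarc::driver::DriverError> {
        let dev = cudarc::driver::CudaDevice::new(0)?;
        let pool = CudaMemPool::new();

        // The first allocation misses the pool and falls through to the driver.
        let buf = pool.alloc_f32(&dev, 1024)?;
        pool.reclaim_f32(buf);

        // A second allocation of the same length is served from the cache.
        let buf = pool.alloc_f32(&dev, 1024)?;
        let stats = pool.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        // Returning the buffer and emptying the cache drops the device memory.
        pool.reclaim_f32(buf);
        assert_eq!(pool.stats().cached_buffers, 1);
        pool.empty_cache();
        assert_eq!(pool.stats().cached_buffers, 0);
        Ok(())
    }
}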