1use std::collections::HashMap;
6use std::fmt;
7use std::time::{Duration, Instant};
8
9use shrew_core::{Backend, Result, Tensor};
10use shrew_nn::Module;
11
/// A single timed occurrence recorded by a `Profiler`.
#[derive(Debug, Clone)]
pub struct ProfileEvent {
    /// Event name; `Profiler::report` groups events by this field.
    pub name: String,
    /// Free-form category label (e.g. "compute"); recorded but not used for grouping.
    pub category: String,
    /// Wall-clock time the event took.
    pub duration: Duration,
}
26
/// Collects named timing events and aggregates them into a `ProfileReport`.
pub struct Profiler {
    /// Completed events, in recording order.
    events: Vec<ProfileEvent>,
    /// Start timestamps of events opened via `start_event` and not yet
    /// closed with `end_event`.
    pending: HashMap<String, Instant>,
}
48
49impl Profiler {
50 pub fn new() -> Self {
52 Self {
53 events: Vec::new(),
54 pending: HashMap::new(),
55 }
56 }
57
58 pub fn start_event(&mut self, name: &str, _category: &str) -> Instant {
60 let now = Instant::now();
61 self.pending.insert(name.to_string(), now);
62 now
63 }
64
65 pub fn end_event(&mut self, start: Instant, name: &str, category: &str) {
67 let elapsed = start.elapsed();
68 self.pending.remove(name);
69 self.events.push(ProfileEvent {
70 name: name.to_string(),
71 category: category.to_string(),
72 duration: elapsed,
73 });
74 }
75
76 pub fn measure<F, R>(&mut self, name: &str, category: &str, f: F) -> R
78 where
79 F: FnOnce() -> R,
80 {
81 let start = Instant::now();
82 let result = f();
83 let elapsed = start.elapsed();
84 self.events.push(ProfileEvent {
85 name: name.to_string(),
86 category: category.to_string(),
87 duration: elapsed,
88 });
89 result
90 }
91
92 pub fn events(&self) -> &[ProfileEvent] {
94 &self.events
95 }
96
97 pub fn clear(&mut self) {
99 self.events.clear();
100 self.pending.clear();
101 }
102
103 pub fn total_time(&self) -> Duration {
105 self.events.iter().map(|e| e.duration).sum()
106 }
107
108 pub fn report(&self) -> ProfileReport {
110 let mut by_name: HashMap<String, Vec<Duration>> = HashMap::new();
112 for ev in &self.events {
113 by_name
114 .entry(ev.name.clone())
115 .or_default()
116 .push(ev.duration);
117 }
118
119 let mut entries: Vec<ProfileEntry> = by_name
120 .into_iter()
121 .map(|(name, durations)| {
122 let count = durations.len();
123 let total: Duration = durations.iter().sum();
124 let min = durations.iter().min().copied().unwrap_or_default();
125 let max = durations.iter().max().copied().unwrap_or_default();
126 let avg = total / count as u32;
127 ProfileEntry {
128 name,
129 count,
130 total,
131 min,
132 max,
133 avg,
134 }
135 })
136 .collect();
137
138 entries.sort_by(|a, b| b.total.cmp(&a.total));
140
141 let total = self.total_time();
142
143 ProfileReport { entries, total }
144 }
145}
146
147impl Default for Profiler {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
/// Aggregated statistics for all events that share one name.
#[derive(Debug, Clone)]
pub struct ProfileEntry {
    /// Event name the statistics were grouped under.
    pub name: String,
    /// Number of recorded occurrences.
    pub count: usize,
    /// Sum of all durations for this name.
    pub total: Duration,
    /// Shortest single occurrence.
    pub min: Duration,
    /// Longest single occurrence.
    pub max: Duration,
    /// Mean duration (`total / count`).
    pub avg: Duration,
}
167
/// Aggregated profiling results produced by `Profiler::report`.
#[derive(Debug, Clone)]
pub struct ProfileReport {
    /// Per-name statistics, sorted by total time descending.
    pub entries: Vec<ProfileEntry>,
    /// Sum of the durations of every recorded event.
    pub total: Duration,
}
178
impl fmt::Display for ProfileReport {
    /// Renders the report as a fixed-width, box-drawn table: a header, one
    /// row per entry, and a footer with the grand total.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(
            f,
            "╔══════════════════════════════════════════════════════════════════════════════╗"
        )?;
        writeln!(
            f,
            "║ Shrew Profile Report ║"
        )?;
        writeln!(
            f,
            "╠══════════════════════════════════════════════════════════════════════════════╣"
        )?;
        writeln!(
            f,
            "║ {:<20} {:>6} {:>12} {:>12} {:>12} {:>8} ║",
            "Event", "Count", "Total", "Avg", "Min/Max", "%"
        )?;
        writeln!(
            f,
            "╠══════════════════════════════════════════════════════════════════════════════╣"
        )?;
        for entry in &self.entries {
            // Share of the grand total; guarded so an empty report does not
            // divide by zero.
            let pct = if self.total.as_nanos() > 0 {
                (entry.total.as_nanos() as f64 / self.total.as_nanos() as f64) * 100.0
            } else {
                0.0
            };
            let min_max = format!("{:.2?}/{:.2?}", entry.min, entry.max);
            writeln!(
                f,
                "║ {:<20} {:>6} {:>12.2?} {:>12.2?} {:>12} {:>7.1}% ║",
                truncate_str(&entry.name, 20),
                entry.count,
                entry.total,
                entry.avg,
                min_max,
                pct
            )?;
        }
        writeln!(
            f,
            "╠══════════════════════════════════════════════════════════════════════════════╣"
        )?;
        writeln!(
            f,
            "║ Total: {:>12.2?} ║",
            self.total
        )?;
        writeln!(
            f,
            "╚══════════════════════════════════════════════════════════════════════════════╝"
        )?;
        Ok(())
    }
}
236
/// Truncates `s` for display to at most `max` characters, replacing the
/// tail with `…` when the input is longer.
///
/// Counts Unicode scalar values instead of bytes: the previous
/// implementation sliced at a byte index (`&s[..max - 1]`), which panics
/// when the cut lands inside a multi-byte UTF-8 character and underflows
/// when `max == 0`.
fn truncate_str(s: &str, max: usize) -> String {
    if s.chars().count() > max {
        // Keep max-1 characters so the ellipsis fits within the budget.
        let head: String = s.chars().take(max.saturating_sub(1)).collect();
        format!("{head}…")
    } else {
        s.to_string()
    }
}
244
/// Tracks named, manually-reported allocations and running byte counters.
pub struct MemoryTracker {
    /// Live allocations: name -> size in bytes.
    allocations: HashMap<String, usize>,
    /// Bytes currently held by live allocations.
    current_bytes: usize,
    /// High-water mark of `current_bytes`.
    peak_bytes: usize,
    /// Cumulative bytes ever reported via `alloc` (never decremented).
    total_allocated: usize,
    /// Number of `alloc` calls.
    alloc_count: usize,
    /// Number of `dealloc` calls that matched a live allocation.
    dealloc_count: usize,
}
272
273impl MemoryTracker {
274 pub fn new() -> Self {
276 Self {
277 allocations: HashMap::new(),
278 current_bytes: 0,
279 peak_bytes: 0,
280 total_allocated: 0,
281 alloc_count: 0,
282 dealloc_count: 0,
283 }
284 }
285
286 pub fn alloc(&mut self, name: &str, bytes: usize) {
288 if let Some(old) = self.allocations.remove(name) {
290 self.current_bytes = self.current_bytes.saturating_sub(old);
291 }
292 self.allocations.insert(name.to_string(), bytes);
293 self.current_bytes += bytes;
294 self.total_allocated += bytes;
295 self.alloc_count += 1;
296 if self.current_bytes > self.peak_bytes {
297 self.peak_bytes = self.current_bytes;
298 }
299 }
300
301 pub fn dealloc(&mut self, name: &str) {
303 if let Some(bytes) = self.allocations.remove(name) {
304 self.current_bytes = self.current_bytes.saturating_sub(bytes);
305 self.dealloc_count += 1;
306 }
307 }
308
309 pub fn current_bytes(&self) -> usize {
311 self.current_bytes
312 }
313
314 pub fn peak_bytes(&self) -> usize {
316 self.peak_bytes
317 }
318
319 pub fn total_allocated(&self) -> usize {
321 self.total_allocated
322 }
323
324 pub fn alloc_count(&self) -> usize {
326 self.alloc_count
327 }
328
329 pub fn dealloc_count(&self) -> usize {
331 self.dealloc_count
332 }
333
334 pub fn live_count(&self) -> usize {
336 self.allocations.len()
337 }
338
339 pub fn reset(&mut self) {
341 self.allocations.clear();
342 self.current_bytes = 0;
343 self.peak_bytes = 0;
344 self.total_allocated = 0;
345 self.alloc_count = 0;
346 self.dealloc_count = 0;
347 }
348
349 pub fn summary(&self) -> String {
351 format!(
352 "Memory: current={}, peak={}, total_allocated={}, allocs={}, deallocs={}, live={}",
353 format_bytes(self.current_bytes),
354 format_bytes(self.peak_bytes),
355 format_bytes(self.total_allocated),
356 self.alloc_count,
357 self.dealloc_count,
358 self.live_count(),
359 )
360 }
361}
362
363impl Default for MemoryTracker {
364 fn default() -> Self {
365 Self::new()
366 }
367}
368
/// Formats a byte count with the largest fitting binary unit (B, KB, MB,
/// GB), using two decimal places for scaled values.
pub fn format_bytes(bytes: usize) -> String {
    const KB: usize = 1 << 10;
    const MB: usize = 1 << 20;
    const GB: usize = 1 << 30;
    match bytes {
        b if b >= GB => format!("{:.2} GB", b as f64 / GB as f64),
        b if b >= MB => format!("{:.2} MB", b as f64 / MB as f64),
        b if b >= KB => format!("{:.2} KB", b as f64 / KB as f64),
        b => format!("{b} B"),
    }
}
384
/// Parameter inventory of a model, built by `ModelSummary::from_module`.
#[derive(Debug, Clone)]
pub struct ModelSummary {
    /// Total number of parameter elements across all tensors.
    pub total_params: usize,
    /// Number of parameter tensors.
    pub num_tensors: usize,
    /// Per-parameter detail: (name, shape dims, element count).
    pub param_details: Vec<(String, Vec<usize>, usize)>,
    /// Estimated parameter memory in bytes.
    pub estimated_bytes: usize,
}
401
402impl ModelSummary {
403 pub fn from_module<B: Backend>(module: &dyn Module<B>) -> Self {
405 let named = module.named_parameters();
406 let mut total_params = 0usize;
407 let mut param_details = Vec::new();
408
409 for (name, tensor) in &named {
410 let numel = tensor.shape().elem_count();
411 total_params += numel;
412 param_details.push((name.clone(), tensor.shape().dims().to_vec(), numel));
413 }
414
415 let estimated_bytes = total_params * 4; ModelSummary {
418 total_params,
419 num_tensors: named.len(),
420 param_details,
421 estimated_bytes,
422 }
423 }
424}
425
impl fmt::Display for ModelSummary {
    /// Renders the summary as a box-drawn table: one row per parameter
    /// tensor, then a totals footer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(
            f,
            "┌──────────────────────────────────────────────────────┐"
        )?;
        writeln!(
            f,
            "│ Model Summary │"
        )?;
        writeln!(
            f,
            "├──────────────────────────────────────────────────────┤"
        )?;
        for (name, shape, numel) in &self.param_details {
            writeln!(
                f,
                "│ {:<30} {:>10?} {:>8} │",
                truncate_str(name, 30),
                shape,
                numel
            )?;
        }
        writeln!(
            f,
            "├──────────────────────────────────────────────────────┤"
        )?;
        writeln!(
            f,
            "│ Total params: {:<12} Tensors: {:<6} Mem: {:<8} │",
            format_params(self.total_params),
            self.num_tensors,
            format_bytes(self.estimated_bytes),
        )?;
        writeln!(
            f,
            "└──────────────────────────────────────────────────────┘"
        )?;
        Ok(())
    }
}
467
/// Formats a parameter count with a decimal magnitude suffix
/// (K / M / B), using two decimal places for scaled values.
fn format_params(n: usize) -> String {
    // (threshold, divisor, suffix) from largest to smallest magnitude.
    const SCALES: [(usize, f64, &str); 3] = [
        (1_000_000_000, 1e9, "B"),
        (1_000_000, 1e6, "M"),
        (1_000, 1e3, "K"),
    ];
    for (threshold, divisor, suffix) in SCALES {
        if n >= threshold {
            return format!("{:.2}{}", n as f64 / divisor, suffix);
        }
    }
    n.to_string()
}
479
/// Timing results from `benchmark_forward` / `benchmark_forward_backward`.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Number of timed iterations.
    pub iterations: usize,
    /// Samples per iteration, as reported by the caller.
    pub batch_size: usize,
    /// Wall-clock time of the whole timed loop (includes input generation).
    pub total_time: Duration,
    /// Average time per iteration.
    pub avg_time: Duration,
    /// Fastest single iteration (model call only).
    pub min_time: Duration,
    /// Slowest single iteration (model call only).
    pub max_time: Duration,
    /// Samples per second: iterations * batch_size / total_time.
    pub throughput: f64,
}
502
impl fmt::Display for BenchmarkResult {
    /// Renders a short multi-line, human-readable report.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(
            f,
            "Benchmark: {} iters, batch_size={}",
            self.iterations, self.batch_size
        )?;
        writeln!(f, " Total: {:.2?}", self.total_time)?;
        writeln!(f, " Avg/iter: {:.2?}", self.avg_time)?;
        writeln!(
            f,
            " Min/Max: {:.2?} / {:.2?}",
            self.min_time, self.max_time
        )?;
        writeln!(f, " Throughput: {:.1} samples/sec", self.throughput)?;
        Ok(())
    }
}
521
522pub fn benchmark_forward<B, M, F>(
543 model: &M,
544 input_fn: F,
545 batch_size: usize,
546 warmup: usize,
547 iterations: usize,
548) -> Result<BenchmarkResult>
549where
550 B: Backend,
551 M: Module<B>,
552 F: Fn() -> Tensor<B>,
553{
554 for _ in 0..warmup {
556 let input = input_fn();
557 let _ = model.forward(&input)?;
558 }
559
560 let mut times = Vec::with_capacity(iterations);
561 let total_start = Instant::now();
562
563 for _ in 0..iterations {
564 let input = input_fn();
565 let start = Instant::now();
566 let _ = model.forward(&input)?;
567 times.push(start.elapsed());
568 }
569
570 let total_time = total_start.elapsed();
571 let min_time = times.iter().min().copied().unwrap_or_default();
572 let max_time = times.iter().max().copied().unwrap_or_default();
573 let avg_time = if iterations > 0 {
574 total_time / iterations as u32
575 } else {
576 Duration::ZERO
577 };
578 let throughput = if total_time.as_secs_f64() > 0.0 {
579 (iterations * batch_size) as f64 / total_time.as_secs_f64()
580 } else {
581 0.0
582 };
583
584 Ok(BenchmarkResult {
585 iterations,
586 batch_size,
587 total_time,
588 avg_time,
589 min_time,
590 max_time,
591 throughput,
592 })
593}
594
595pub fn benchmark_forward_backward<B, M, F, L>(
599 model: &M,
600 input_fn: F,
601 loss_fn: L,
602 batch_size: usize,
603 warmup: usize,
604 iterations: usize,
605) -> Result<BenchmarkResult>
606where
607 B: Backend,
608 M: Module<B>,
609 F: Fn() -> Tensor<B>,
610 L: Fn(&Tensor<B>) -> Result<Tensor<B>>,
611{
612 for _ in 0..warmup {
614 let input = input_fn();
615 let out = model.forward(&input)?;
616 let loss = loss_fn(&out)?;
617 let _ = loss.backward()?;
618 }
619
620 let mut times = Vec::with_capacity(iterations);
621 let total_start = Instant::now();
622
623 for _ in 0..iterations {
624 let input = input_fn();
625 let start = Instant::now();
626 let out = model.forward(&input)?;
627 let loss = loss_fn(&out)?;
628 let _ = loss.backward()?;
629 times.push(start.elapsed());
630 }
631
632 let total_time = total_start.elapsed();
633 let min_time = times.iter().min().copied().unwrap_or_default();
634 let max_time = times.iter().max().copied().unwrap_or_default();
635 let avg_time = if iterations > 0 {
636 total_time / iterations as u32
637 } else {
638 Duration::ZERO
639 };
640 let throughput = if total_time.as_secs_f64() > 0.0 {
641 (iterations * batch_size) as f64 / total_time.as_secs_f64()
642 } else {
643 0.0
644 };
645
646 Ok(BenchmarkResult {
647 iterations,
648 batch_size,
649 total_time,
650 avg_time,
651 min_time,
652 max_time,
653 throughput,
654 })
655}
656
/// RAII-style timer: starts on construction and records a `ProfileEvent`
/// into the shared profiler when dropped.
pub struct ScopedTimer {
    /// Shared profiler the event is pushed into on drop.
    profiler: std::sync::Arc<std::sync::Mutex<Profiler>>,
    /// Event name recorded on drop.
    name: String,
    /// Event category recorded on drop.
    category: String,
    /// Timestamp taken at construction.
    start: Instant,
}
682
683impl ScopedTimer {
684 pub fn new(
685 profiler: std::sync::Arc<std::sync::Mutex<Profiler>>,
686 name: &str,
687 category: &str,
688 ) -> Self {
689 Self {
690 profiler,
691 name: name.to_string(),
692 category: category.to_string(),
693 start: Instant::now(),
694 }
695 }
696}
697
698impl Drop for ScopedTimer {
699 fn drop(&mut self) {
700 let elapsed = self.start.elapsed();
701 if let Ok(mut prof) = self.profiler.lock() {
702 prof.events.push(ProfileEvent {
703 name: self.name.clone(),
704 category: self.category.clone(),
705 duration: elapsed,
706 });
707 }
708 }
709}
710
711pub fn estimate_model_memory<B: Backend>(module: &dyn Module<B>) -> usize {
719 module
720 .named_parameters()
721 .iter()
722 .map(|(_, t)| {
723 let numel = t.shape().elem_count();
724 let bytes_per_elem = match t.dtype() {
725 shrew_core::DType::F16 | shrew_core::DType::BF16 => 2,
726 shrew_core::DType::F32 | shrew_core::DType::U32 => 4,
727 shrew_core::DType::F64 | shrew_core::DType::I64 => 8,
728 shrew_core::DType::U8 => 1,
729 };
730 numel * bytes_per_elem
731 })
732 .sum()
733}
734
/// Simple lap-recording stopwatch.
///
/// `start` begins (or restarts) timing and clears previous laps; `lap`
/// records the elapsed time since `start`; `stop` ends timing. When the
/// stopwatch is not running, `lap` and `stop` return a zero duration.
pub struct Stopwatch {
    /// Timestamp of the most recent `start`; `None` while stopped.
    start: Option<Instant>,
    /// Elapsed times recorded by `lap` since the last `start`.
    laps: Vec<Duration>,
}

impl Stopwatch {
    /// Creates a stopped stopwatch with no recorded laps.
    pub fn new() -> Self {
        Self {
            start: None,
            laps: Vec::new(),
        }
    }

    /// (Re)starts timing and discards laps from any previous run.
    pub fn start(&mut self) {
        self.laps.clear();
        self.start = Some(Instant::now());
    }

    /// Records and returns the time since `start`; zero if not running.
    pub fn lap(&mut self) -> Duration {
        let elapsed = match self.start {
            Some(origin) => origin.elapsed(),
            None => Duration::default(),
        };
        self.laps.push(elapsed);
        elapsed
    }

    /// Stops timing and returns the total elapsed time; zero if not running.
    pub fn stop(&mut self) -> Duration {
        // `take` both reads the start time and marks the watch stopped.
        self.start
            .take()
            .map(|origin| origin.elapsed())
            .unwrap_or_default()
    }

    /// All laps recorded since the last `start`.
    pub fn laps(&self) -> &[Duration] {
        self.laps.as_slice()
    }
}

impl Default for Stopwatch {
    /// Equivalent to [`Stopwatch::new`].
    fn default() -> Self {
        Stopwatch::new()
    }
}
798
#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};
    use shrew_nn::{Linear, Module};
    use std::sync::{Arc, Mutex};
    use std::thread;

    // Timing tests sleep N ms but assert >= N-1 ms to tolerate timer
    // granularity on coarse-clock platforms.

    #[test]
    fn test_profiler_measure() {
        let mut prof = Profiler::new();
        let result = prof.measure("op_a", "compute", || {
            thread::sleep(Duration::from_millis(5));
            42
        });
        // measure passes through the closure's return value.
        assert_eq!(result, 42);
        assert_eq!(prof.events().len(), 1);
        assert_eq!(prof.events()[0].name, "op_a");
        assert!(prof.events()[0].duration >= Duration::from_millis(4));
    }

    #[test]
    fn test_profiler_start_end() {
        let mut prof = Profiler::new();
        let start = prof.start_event("forward", "model");
        thread::sleep(Duration::from_millis(5));
        prof.end_event(start, "forward", "model");

        assert_eq!(prof.events().len(), 1);
        assert!(prof.total_time() >= Duration::from_millis(4));
    }

    #[test]
    fn test_profiler_report() {
        let mut prof = Profiler::new();
        for _ in 0..3 {
            prof.measure("matmul", "compute", || {
                thread::sleep(Duration::from_millis(2));
            });
        }
        prof.measure("relu", "compute", || {
            thread::sleep(Duration::from_millis(1));
        });

        let report = prof.report();
        assert_eq!(report.entries.len(), 2);
        // Entries are sorted by total time descending, so the 3x2ms
        // "matmul" group comes first.
        assert_eq!(report.entries[0].name, "matmul");
        assert_eq!(report.entries[0].count, 3);
        assert_eq!(report.entries[1].name, "relu");
        assert_eq!(report.entries[1].count, 1);

        let s = format!("{}", report);
        assert!(s.contains("matmul"));
        assert!(s.contains("relu"));
    }

    #[test]
    fn test_profiler_clear() {
        let mut prof = Profiler::new();
        prof.measure("x", "y", || {});
        assert_eq!(prof.events().len(), 1);
        prof.clear();
        assert_eq!(prof.events().len(), 0);
    }

    #[test]
    fn test_memory_tracker() {
        let mut mem = MemoryTracker::new();
        mem.alloc("weight", 4096);
        mem.alloc("bias", 128);
        assert_eq!(mem.current_bytes(), 4224);
        assert_eq!(mem.peak_bytes(), 4224);
        assert_eq!(mem.alloc_count(), 2);
        assert_eq!(mem.live_count(), 2);

        mem.dealloc("bias");
        assert_eq!(mem.current_bytes(), 4096);
        // Peak is a high-water mark; dealloc does not lower it.
        assert_eq!(mem.peak_bytes(), 4224);
        assert_eq!(mem.dealloc_count(), 1);
        assert_eq!(mem.live_count(), 1);

        // Re-allocating an existing name replaces it, not adds to it.
        mem.alloc("weight", 8192);
        assert_eq!(mem.current_bytes(), 8192);
        assert!(mem.peak_bytes() >= 8192);
    }

    #[test]
    fn test_memory_tracker_summary() {
        let mut mem = MemoryTracker::new();
        mem.alloc("big", 1024 * 1024);
        let s = mem.summary();
        assert!(s.contains("1.00 MB"));
    }

    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(0), "0 B");
        assert_eq!(format_bytes(512), "512 B");
        assert_eq!(format_bytes(1024), "1.00 KB");
        assert_eq!(format_bytes(1024 * 1024), "1.00 MB");
        assert_eq!(format_bytes(2 * 1024 * 1024 * 1024), "2.00 GB");
    }

    #[test]
    fn test_model_summary() {
        let model = Linear::new(16, 8, true, shrew_core::DType::F32, &CpuDevice).unwrap();
        let summary = ModelSummary::from_module::<CpuBackend>(&model);
        // 16x8 weight + 8 bias.
        assert_eq!(summary.total_params, 128 + 8);
        assert_eq!(summary.num_tensors, 2);
        let s = format!("{}", summary);
        assert!(s.contains("Model Summary"));
    }

    #[test]
    fn test_benchmark_forward() {
        let model =
            Linear::<CpuBackend>::new(8, 4, true, shrew_core::DType::F32, &CpuDevice).unwrap();
        let result = benchmark_forward(
            &model,
            || Tensor::rand((2, 8), shrew_core::DType::F32, &CpuDevice).unwrap(),
            2,
            1,
            5,
        )
        .unwrap();
        assert_eq!(result.iterations, 5);
        assert_eq!(result.batch_size, 2);
        assert!(result.throughput > 0.0);
        assert!(result.min_time <= result.max_time);
        let s = format!("{}", result);
        assert!(s.contains("Benchmark"));
    }

    #[test]
    fn test_benchmark_forward_backward() {
        let model =
            Linear::<CpuBackend>::new(8, 4, true, shrew_core::DType::F32, &CpuDevice).unwrap();
        let result = benchmark_forward_backward(
            &model,
            || Tensor::rand((2, 8), shrew_core::DType::F32, &CpuDevice).unwrap(),
            |out| out.mean_all(),
            2,
            1,
            3,
        )
        .unwrap();
        assert_eq!(result.iterations, 3);
        assert!(result.throughput > 0.0);
    }

    #[test]
    fn test_scoped_timer() {
        let profiler = Arc::new(Mutex::new(Profiler::new()));
        {
            // The event is recorded when _t drops at the end of this scope.
            let _t = ScopedTimer::new(profiler.clone(), "scoped_op", "test");
            thread::sleep(Duration::from_millis(3));
        }
        let prof = profiler.lock().unwrap();
        assert_eq!(prof.events().len(), 1);
        assert_eq!(prof.events()[0].name, "scoped_op");
        assert!(prof.events()[0].duration >= Duration::from_millis(2));
    }

    #[test]
    fn test_estimate_model_memory() {
        let model = Linear::new(16, 8, true, shrew_core::DType::F32, &CpuDevice).unwrap();
        let bytes = estimate_model_memory::<CpuBackend>(&model);
        // (16*8 weight + 8 bias) elements * 4 bytes (F32) = 544.
        assert_eq!(bytes, 544);
    }

    #[test]
    fn test_stopwatch() {
        let mut sw = Stopwatch::new();
        sw.start();
        thread::sleep(Duration::from_millis(5));
        let lap = sw.lap();
        assert!(lap >= Duration::from_millis(4));
        let total = sw.stop();
        assert!(total >= lap);
        assert_eq!(sw.laps().len(), 1);
    }

    #[test]
    fn test_format_params() {
        assert_eq!(format_params(500), "500");
        assert_eq!(format_params(1_500), "1.50K");
        assert_eq!(format_params(1_500_000), "1.50M");
        assert_eq!(format_params(2_500_000_000), "2.50B");
    }
}
997}