shrew/
profiler.rs

1// =============================================================================
2// Profiling & Benchmarking — Op-level timing, memory tracking, model benchmarks
3// =============================================================================
4
5use std::collections::HashMap;
6use std::fmt;
7use std::time::{Duration, Instant};
8
9use shrew_core::{Backend, Result, Tensor};
10use shrew_nn::Module;
11
12// ---------------------------------------------------------------------------
13// ProfileEvent — a single recorded timing event
14// ---------------------------------------------------------------------------
15
/// A single profiling event with a name, duration, and optional metadata.
#[derive(Debug, Clone)]
pub struct ProfileEvent {
    /// Name / label for this event.
    pub name: String,
    /// Category (e.g. "forward", "backward", "data", "optimizer").
    pub category: String,
    /// Wall-clock duration.
    pub duration: Duration,
}

// ---------------------------------------------------------------------------
// Profiler — collects named timing events
// ---------------------------------------------------------------------------

/// A lightweight profiler that collects named timing events.
///
/// # Example
/// ```
/// use shrew::profiler::Profiler;
///
/// let mut prof = Profiler::new();
/// let t = prof.start_event("forward", "compute");
/// // ... do work ...
/// prof.end_event(t, "forward", "compute");
/// let report = prof.report();
/// println!("{}", report);
/// ```
pub struct Profiler {
    /// Completed events in the order they were recorded.
    events: Vec<ProfileEvent>,
    /// Start instants for events begun via [`start_event`](Self::start_event)
    /// that have not been ended yet. Keyed by event name, so starting two
    /// events with the same name overwrites the earlier pending entry.
    pending: HashMap<String, Instant>,
}

impl Profiler {
    /// Create a new empty profiler.
    pub fn new() -> Self {
        Self {
            events: Vec::new(),
            pending: HashMap::new(),
        }
    }

    /// Push a completed event. Single point of truth for event construction,
    /// shared by [`end_event`](Self::end_event) and [`measure`](Self::measure)
    /// (previously duplicated in both).
    fn record(&mut self, name: &str, category: &str, duration: Duration) {
        self.events.push(ProfileEvent {
            name: name.to_string(),
            category: category.to_string(),
            duration,
        });
    }

    /// Mark the start of a named event. Returns the [`Instant`].
    pub fn start_event(&mut self, name: &str, _category: &str) -> Instant {
        let now = Instant::now();
        self.pending.insert(name.to_string(), now);
        now
    }

    /// End an event started with [`start_event`](Self::start_event). Records elapsed time.
    pub fn end_event(&mut self, start: Instant, name: &str, category: &str) {
        let elapsed = start.elapsed();
        self.pending.remove(name);
        self.record(name, category, elapsed);
    }

    /// Measure a closure and record it as a named event, passing the
    /// closure's return value through unchanged.
    pub fn measure<F, R>(&mut self, name: &str, category: &str, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let start = Instant::now();
        let result = f();
        self.record(name, category, start.elapsed());
        result
    }

    /// Return all recorded events.
    pub fn events(&self) -> &[ProfileEvent] {
        &self.events
    }

    /// Reset / clear all recorded events (including still-pending starts).
    pub fn clear(&mut self) {
        self.events.clear();
        self.pending.clear();
    }

    /// Total wall-clock time across all events.
    pub fn total_time(&self) -> Duration {
        self.events.iter().map(|e| e.duration).sum()
    }

    /// Generate a human-readable [`ProfileReport`], aggregating events by
    /// name and sorting entries by total time (descending).
    pub fn report(&self) -> ProfileReport {
        // Aggregate durations by event name.
        let mut by_name: HashMap<String, Vec<Duration>> = HashMap::new();
        for ev in &self.events {
            by_name
                .entry(ev.name.clone())
                .or_default()
                .push(ev.duration);
        }

        let mut entries: Vec<ProfileEntry> = by_name
            .into_iter()
            .map(|(name, durations)| {
                // `durations` is non-empty by construction (one push per
                // event), so the average division is well-defined.
                let count = durations.len();
                let total: Duration = durations.iter().sum();
                let min = durations.iter().min().copied().unwrap_or_default();
                let max = durations.iter().max().copied().unwrap_or_default();
                let avg = total / count as u32;
                ProfileEntry {
                    name,
                    count,
                    total,
                    min,
                    max,
                    avg,
                }
            })
            .collect();

        // Sort by total time descending so the most expensive names come first.
        entries.sort_by(|a, b| b.total.cmp(&a.total));

        let total = self.total_time();

        ProfileReport { entries, total }
    }
}

impl Default for Profiler {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// ProfileEntry — aggregated stats for one event name
// ---------------------------------------------------------------------------

/// Aggregated statistics for a single event name.
#[derive(Debug, Clone)]
pub struct ProfileEntry {
    /// Event name these statistics were aggregated over.
    pub name: String,
    /// Number of recorded occurrences.
    pub count: usize,
    /// Sum of all durations for this name.
    pub total: Duration,
    /// Shortest single duration.
    pub min: Duration,
    /// Longest single duration.
    pub max: Duration,
    /// Mean duration (`total / count`).
    pub avg: Duration,
}

// ---------------------------------------------------------------------------
// ProfileReport — pretty-printable summary
// ---------------------------------------------------------------------------

/// A formatted profiling report, printed with `Display`.
#[derive(Debug, Clone)]
pub struct ProfileReport {
    /// Per-name aggregated entries, sorted by total time descending.
    pub entries: Vec<ProfileEntry>,
    /// Total wall-clock time across all events.
    pub total: Duration,
}
178
179impl fmt::Display for ProfileReport {
180    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        writeln!(
182            f,
183            "╔══════════════════════════════════════════════════════════════════════════════╗"
184        )?;
185        writeln!(
186            f,
187            "║                          Shrew Profile Report                               ║"
188        )?;
189        writeln!(
190            f,
191            "╠══════════════════════════════════════════════════════════════════════════════╣"
192        )?;
193        writeln!(
194            f,
195            "║ {:<20} {:>6} {:>12} {:>12} {:>12} {:>8} ║",
196            "Event", "Count", "Total", "Avg", "Min/Max", "%"
197        )?;
198        writeln!(
199            f,
200            "╠══════════════════════════════════════════════════════════════════════════════╣"
201        )?;
202        for entry in &self.entries {
203            let pct = if self.total.as_nanos() > 0 {
204                (entry.total.as_nanos() as f64 / self.total.as_nanos() as f64) * 100.0
205            } else {
206                0.0
207            };
208            let min_max = format!("{:.2?}/{:.2?}", entry.min, entry.max);
209            writeln!(
210                f,
211                "║ {:<20} {:>6} {:>12.2?} {:>12.2?} {:>12} {:>7.1}% ║",
212                truncate_str(&entry.name, 20),
213                entry.count,
214                entry.total,
215                entry.avg,
216                min_max,
217                pct
218            )?;
219        }
220        writeln!(
221            f,
222            "╠══════════════════════════════════════════════════════════════════════════════╣"
223        )?;
224        writeln!(
225            f,
226            "║ Total: {:>12.2?}                                                        ║",
227            self.total
228        )?;
229        writeln!(
230            f,
231            "╚══════════════════════════════════════════════════════════════════════════════╝"
232        )?;
233        Ok(())
234    }
235}
236
/// Truncate `s` to at most `max` visible characters, replacing the tail with
/// a single `…` when truncation occurs.
///
/// Operates on `char` boundaries rather than byte indices: the previous
/// implementation sliced with `&s[..max - 1]`, which panics when that byte
/// offset falls inside a multi-byte UTF-8 character (e.g. a non-ASCII event
/// name), and compared byte length rather than character count.
fn truncate_str(s: &str, max: usize) -> String {
    if s.chars().count() > max {
        // Keep max-1 characters and spend the final column on the ellipsis.
        let kept: String = s.chars().take(max.saturating_sub(1)).collect();
        format!("{}…", kept)
    } else {
        s.to_string()
    }
}
244
245// ---------------------------------------------------------------------------
246// MemoryTracker — track tensor allocations and peak memory
247// ---------------------------------------------------------------------------
248
/// Tracks tensor memory allocations for profiling.
///
/// Call [`alloc`] when a tensor is created and [`dealloc`] when freed.
///
/// # Example
/// ```
/// use shrew::profiler::MemoryTracker;
///
/// let mut mem = MemoryTracker::new();
/// mem.alloc("weight", 1024 * 4);
/// mem.alloc("bias", 128 * 4);
/// assert_eq!(mem.current_bytes(), 1024 * 4 + 128 * 4);
/// mem.dealloc("bias");
/// assert_eq!(mem.current_bytes(), 1024 * 4);
/// ```
#[derive(Default)]
pub struct MemoryTracker {
    // Live allocations, keyed by tensor name.
    allocations: HashMap<String, usize>,
    // Bytes currently live.
    current_bytes: usize,
    // High-water mark of `current_bytes`.
    peak_bytes: usize,
    // Cumulative bytes ever allocated; never decremented.
    total_allocated: usize,
    // Number of `alloc` calls recorded.
    alloc_count: usize,
    // Number of successful `dealloc` calls recorded.
    dealloc_count: usize,
}

impl MemoryTracker {
    /// Create a new tracker with zero allocations.
    pub fn new() -> Self {
        Self::default()
    }

    /// Record a tensor allocation.
    ///
    /// Re-allocating an existing name releases the old size first, so
    /// `current_bytes` never double-counts a tensor.
    pub fn alloc(&mut self, name: &str, bytes: usize) {
        // `insert` hands back any previous size stored under this name.
        if let Some(previous) = self.allocations.insert(name.to_string(), bytes) {
            self.current_bytes = self.current_bytes.saturating_sub(previous);
        }
        self.current_bytes += bytes;
        self.total_allocated += bytes;
        self.alloc_count += 1;
        self.peak_bytes = self.peak_bytes.max(self.current_bytes);
    }

    /// Record a tensor deallocation. Unknown names are ignored.
    pub fn dealloc(&mut self, name: &str) {
        if let Some(bytes) = self.allocations.remove(name) {
            self.current_bytes = self.current_bytes.saturating_sub(bytes);
            self.dealloc_count += 1;
        }
    }

    /// Current live memory in bytes.
    pub fn current_bytes(&self) -> usize {
        self.current_bytes
    }

    /// Peak memory usage in bytes.
    pub fn peak_bytes(&self) -> usize {
        self.peak_bytes
    }

    /// Total bytes allocated over the tracker's lifetime.
    pub fn total_allocated(&self) -> usize {
        self.total_allocated
    }

    /// Number of allocations recorded.
    pub fn alloc_count(&self) -> usize {
        self.alloc_count
    }

    /// Number of deallocations recorded.
    pub fn dealloc_count(&self) -> usize {
        self.dealloc_count
    }

    /// Number of currently live allocations.
    pub fn live_count(&self) -> usize {
        self.allocations.len()
    }

    /// Reset the tracker to its freshly-constructed state.
    pub fn reset(&mut self) {
        *self = Self::default();
    }

    /// Format a human-readable memory summary.
    pub fn summary(&self) -> String {
        format!(
            "Memory: current={}, peak={}, total_allocated={}, allocs={}, deallocs={}, live={}",
            format_bytes(self.current_bytes),
            format_bytes(self.peak_bytes),
            format_bytes(self.total_allocated),
            self.alloc_count,
            self.dealloc_count,
            self.live_count(),
        )
    }
}

/// Format bytes into a human-readable string (B, KB, MB, GB).
pub fn format_bytes(bytes: usize) -> String {
    // Largest unit first; fall through to plain bytes.
    const SCALES: [(usize, &str); 3] = [(1 << 30, "GB"), (1 << 20, "MB"), (1 << 10, "KB")];
    for (unit, label) in SCALES {
        if bytes >= unit {
            return format!("{:.2} {}", bytes as f64 / unit as f64, label);
        }
    }
    format!("{} B", bytes)
}
384
385// ---------------------------------------------------------------------------
386// ModelSummary — compute model stats (params, layers, etc.)
387// ---------------------------------------------------------------------------
388
/// Summary statistics for a model.
#[derive(Debug, Clone)]
pub struct ModelSummary {
    /// Total number of trainable parameters.
    pub total_params: usize,
    /// Number of named parameter tensors.
    pub num_tensors: usize,
    /// Per-parameter detail: `(name, shape dims, element count)`.
    pub param_details: Vec<(String, Vec<usize>, usize)>,
    /// Estimated memory in bytes (assuming F32, i.e. 4 bytes per element).
    pub estimated_bytes: usize,
}
401
402impl ModelSummary {
403    /// Compute a summary for the given module.
404    pub fn from_module<B: Backend>(module: &dyn Module<B>) -> Self {
405        let named = module.named_parameters();
406        let mut total_params = 0usize;
407        let mut param_details = Vec::new();
408
409        for (name, tensor) in &named {
410            let numel = tensor.shape().elem_count();
411            total_params += numel;
412            param_details.push((name.clone(), tensor.shape().dims().to_vec(), numel));
413        }
414
415        let estimated_bytes = total_params * 4; // assume F32
416
417        ModelSummary {
418            total_params,
419            num_tensors: named.len(),
420            param_details,
421            estimated_bytes,
422        }
423    }
424}
425
426impl fmt::Display for ModelSummary {
427    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
428        writeln!(
429            f,
430            "┌──────────────────────────────────────────────────────┐"
431        )?;
432        writeln!(
433            f,
434            "│                   Model Summary                      │"
435        )?;
436        writeln!(
437            f,
438            "├──────────────────────────────────────────────────────┤"
439        )?;
440        for (name, shape, numel) in &self.param_details {
441            writeln!(
442                f,
443                "│ {:<30} {:>10?} {:>8} │",
444                truncate_str(name, 30),
445                shape,
446                numel
447            )?;
448        }
449        writeln!(
450            f,
451            "├──────────────────────────────────────────────────────┤"
452        )?;
453        writeln!(
454            f,
455            "│ Total params: {:<12} Tensors: {:<6} Mem: {:<8} │",
456            format_params(self.total_params),
457            self.num_tensors,
458            format_bytes(self.estimated_bytes),
459        )?;
460        writeln!(
461            f,
462            "└──────────────────────────────────────────────────────┘"
463        )?;
464        Ok(())
465    }
466}
467
/// Format a parameter count with a K/M/B suffix (two decimal places), or as
/// the plain number below one thousand.
fn format_params(n: usize) -> String {
    // Largest scale first; fall through to the unscaled value.
    const SCALES: [(usize, f64, char); 3] = [
        (1_000_000_000, 1e9, 'B'),
        (1_000_000, 1e6, 'M'),
        (1_000, 1e3, 'K'),
    ];
    for (threshold, divisor, suffix) in SCALES {
        if n >= threshold {
            return format!("{:.2}{}", n as f64 / divisor, suffix);
        }
    }
    n.to_string()
}
479
480// ---------------------------------------------------------------------------
481// Benchmark — time a model's forward pass over multiple iterations
482// ---------------------------------------------------------------------------
483
/// Result of benchmarking a model's forward pass.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Number of iterations run.
    pub iterations: usize,
    /// Batch size used.
    pub batch_size: usize,
    /// Total wall time for all iterations.
    pub total_time: Duration,
    /// Average time per iteration.
    pub avg_time: Duration,
    /// Minimum time across iterations.
    pub min_time: Duration,
    /// Maximum time across iterations.
    pub max_time: Duration,
    /// Throughput in samples per second.
    pub throughput: f64,
}

impl fmt::Display for BenchmarkResult {
    /// Multi-line human-readable summary of the benchmark run.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // One format string; each "\n" matches a writeln! boundary of the
        // original layout.
        write!(
            f,
            "Benchmark: {} iters, batch_size={}\n  Total:      {:.2?}\n  Avg/iter:   {:.2?}\n  Min/Max:    {:.2?} / {:.2?}\n  Throughput: {:.1} samples/sec\n",
            self.iterations,
            self.batch_size,
            self.total_time,
            self.avg_time,
            self.min_time,
            self.max_time,
            self.throughput
        )
    }
}
521
522/// Benchmark a model's forward pass.
523///
524/// Runs `warmup` untimed iterations, then `iterations` timed iterations,
525/// calling `input_fn` on each iteration to produce the input tensor.
526///
527/// # Example
528/// ```no_run
529/// use shrew::prelude::*;
530/// use shrew::profiler::benchmark_forward;
531///
532/// let model = Linear::<CpuBackend>::new(16, 8, true, DType::F32, &CpuDevice).unwrap();
533/// let result = benchmark_forward(
534///     &model,
535///     || Tensor::<CpuBackend>::rand((4, 16), DType::F32, &CpuDevice).unwrap(),
536///     4, // batch_size (for throughput calc)
537///     3, // warmup
538///     10, // iterations
539/// ).unwrap();
540/// println!("{}", result);
541/// ```
542pub fn benchmark_forward<B, M, F>(
543    model: &M,
544    input_fn: F,
545    batch_size: usize,
546    warmup: usize,
547    iterations: usize,
548) -> Result<BenchmarkResult>
549where
550    B: Backend,
551    M: Module<B>,
552    F: Fn() -> Tensor<B>,
553{
554    // Warmup
555    for _ in 0..warmup {
556        let input = input_fn();
557        let _ = model.forward(&input)?;
558    }
559
560    let mut times = Vec::with_capacity(iterations);
561    let total_start = Instant::now();
562
563    for _ in 0..iterations {
564        let input = input_fn();
565        let start = Instant::now();
566        let _ = model.forward(&input)?;
567        times.push(start.elapsed());
568    }
569
570    let total_time = total_start.elapsed();
571    let min_time = times.iter().min().copied().unwrap_or_default();
572    let max_time = times.iter().max().copied().unwrap_or_default();
573    let avg_time = if iterations > 0 {
574        total_time / iterations as u32
575    } else {
576        Duration::ZERO
577    };
578    let throughput = if total_time.as_secs_f64() > 0.0 {
579        (iterations * batch_size) as f64 / total_time.as_secs_f64()
580    } else {
581        0.0
582    };
583
584    Ok(BenchmarkResult {
585        iterations,
586        batch_size,
587        total_time,
588        avg_time,
589        min_time,
590        max_time,
591        throughput,
592    })
593}
594
595/// Benchmark forward + backward pass.
596///
597/// Same as [`benchmark_forward`] but also runs backpropagation on each iteration.
598pub fn benchmark_forward_backward<B, M, F, L>(
599    model: &M,
600    input_fn: F,
601    loss_fn: L,
602    batch_size: usize,
603    warmup: usize,
604    iterations: usize,
605) -> Result<BenchmarkResult>
606where
607    B: Backend,
608    M: Module<B>,
609    F: Fn() -> Tensor<B>,
610    L: Fn(&Tensor<B>) -> Result<Tensor<B>>,
611{
612    // Warmup
613    for _ in 0..warmup {
614        let input = input_fn();
615        let out = model.forward(&input)?;
616        let loss = loss_fn(&out)?;
617        let _ = loss.backward()?;
618    }
619
620    let mut times = Vec::with_capacity(iterations);
621    let total_start = Instant::now();
622
623    for _ in 0..iterations {
624        let input = input_fn();
625        let start = Instant::now();
626        let out = model.forward(&input)?;
627        let loss = loss_fn(&out)?;
628        let _ = loss.backward()?;
629        times.push(start.elapsed());
630    }
631
632    let total_time = total_start.elapsed();
633    let min_time = times.iter().min().copied().unwrap_or_default();
634    let max_time = times.iter().max().copied().unwrap_or_default();
635    let avg_time = if iterations > 0 {
636        total_time / iterations as u32
637    } else {
638        Duration::ZERO
639    };
640    let throughput = if total_time.as_secs_f64() > 0.0 {
641        (iterations * batch_size) as f64 / total_time.as_secs_f64()
642    } else {
643        0.0
644    };
645
646    Ok(BenchmarkResult {
647        iterations,
648        batch_size,
649        total_time,
650        avg_time,
651        min_time,
652        max_time,
653        throughput,
654    })
655}
656
657// ---------------------------------------------------------------------------
658// ScopedTimer — RAII timer that records to a profiler on drop
659// ---------------------------------------------------------------------------
660
/// An RAII timer guard. Drops into a profiler automatically.
///
/// # Example
/// ```
/// use shrew::profiler::{Profiler, ScopedTimer};
/// use std::sync::{Arc, Mutex};
///
/// let profiler = Arc::new(Mutex::new(Profiler::new()));
/// {
///     let _t = ScopedTimer::new(profiler.clone(), "my_op", "compute");
///     // ... do work ...
/// } // timer records elapsed time on drop
/// let prof = profiler.lock().unwrap();
/// assert_eq!(prof.events().len(), 1);
/// ```
pub struct ScopedTimer {
    // Shared profiler the event is recorded into on drop.
    profiler: std::sync::Arc<std::sync::Mutex<Profiler>>,
    // Event name recorded on drop.
    name: String,
    // Event category recorded on drop.
    category: String,
    // Creation instant; the recorded duration is measured from here.
    start: Instant,
}
682
683impl ScopedTimer {
684    pub fn new(
685        profiler: std::sync::Arc<std::sync::Mutex<Profiler>>,
686        name: &str,
687        category: &str,
688    ) -> Self {
689        Self {
690            profiler,
691            name: name.to_string(),
692            category: category.to_string(),
693            start: Instant::now(),
694        }
695    }
696}
697
impl Drop for ScopedTimer {
    /// Record the elapsed time as a [`ProfileEvent`] when the guard goes out
    /// of scope.
    fn drop(&mut self) {
        let elapsed = self.start.elapsed();
        // If the mutex is poisoned the event is silently dropped: panicking
        // inside `drop` (possibly during another panic's unwind) would abort,
        // so best-effort recording is the safe choice here.
        if let Ok(mut prof) = self.profiler.lock() {
            prof.events.push(ProfileEvent {
                name: self.name.clone(),
                category: self.category.clone(),
                duration: elapsed,
            });
        }
    }
}
710
711// ---------------------------------------------------------------------------
712// estimate_model_memory — compute memory usage of a module's parameters
713// ---------------------------------------------------------------------------
714
715/// Estimate the memory usage of a module's parameters in bytes.
716///
717/// Accounts for the actual dtype of each parameter tensor.
718pub fn estimate_model_memory<B: Backend>(module: &dyn Module<B>) -> usize {
719    module
720        .named_parameters()
721        .iter()
722        .map(|(_, t)| {
723            let numel = t.shape().elem_count();
724            let bytes_per_elem = match t.dtype() {
725                shrew_core::DType::F16 | shrew_core::DType::BF16 => 2,
726                shrew_core::DType::F32 | shrew_core::DType::U32 => 4,
727                shrew_core::DType::F64 | shrew_core::DType::I64 => 8,
728                shrew_core::DType::U8 => 1,
729            };
730            numel * bytes_per_elem
731        })
732        .sum()
733}
734
735// ---------------------------------------------------------------------------
736// Stopwatch — simple reusable timer
737// ---------------------------------------------------------------------------
738
/// A simple stopwatch for manual timing.
///
/// # Example
/// ```
/// use shrew::profiler::Stopwatch;
///
/// let mut sw = Stopwatch::new();
/// sw.start();
/// // ... do work ...
/// let elapsed = sw.stop();
/// assert!(elapsed.as_nanos() > 0);
/// sw.start();
/// let lap = sw.lap();
/// let elapsed2 = sw.stop();
/// ```
#[derive(Default)]
pub struct Stopwatch {
    // Instant of the most recent `start`; `None` while stopped.
    start: Option<Instant>,
    // Lap splits recorded since the last `start`.
    laps: Vec<Duration>,
}

impl Stopwatch {
    /// Create a stopped stopwatch with no recorded laps.
    pub fn new() -> Self {
        Self::default()
    }

    /// Start (or restart) the stopwatch, discarding any previous laps.
    pub fn start(&mut self) {
        self.laps.clear();
        self.start = Some(Instant::now());
    }

    /// Record a lap split without stopping. Returns a zero duration (and
    /// still records it) if the stopwatch is not running.
    pub fn lap(&mut self) -> Duration {
        let split = self.start.map_or_else(Duration::default, |t0| t0.elapsed());
        self.laps.push(split);
        split
    }

    /// Stop and return total elapsed time (zero if not running).
    pub fn stop(&mut self) -> Duration {
        // `take` clears the running state and yields the start instant.
        self.start.take().map(|t0| t0.elapsed()).unwrap_or_default()
    }

    /// Get all recorded laps.
    pub fn laps(&self) -> &[Duration] {
        &self.laps
    }
}
798
799// =============================================================================
800// Tests
801// =============================================================================
802
#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};
    use shrew_nn::{Linear, Module};
    use std::sync::{Arc, Mutex};
    use std::thread;

    // Timing assertions compare against thresholds slightly below the sleeps
    // they measure (e.g. sleep 5ms, assert >= 4ms) to tolerate scheduler and
    // timer granularity on slow CI machines.

    #[test]
    fn test_profiler_measure() {
        let mut prof = Profiler::new();
        let result = prof.measure("op_a", "compute", || {
            thread::sleep(Duration::from_millis(5));
            42
        });
        // measure() passes the closure's return value through untouched.
        assert_eq!(result, 42);
        assert_eq!(prof.events().len(), 1);
        assert_eq!(prof.events()[0].name, "op_a");
        assert!(prof.events()[0].duration >= Duration::from_millis(4));
    }

    #[test]
    fn test_profiler_start_end() {
        let mut prof = Profiler::new();
        let start = prof.start_event("forward", "model");
        thread::sleep(Duration::from_millis(5));
        prof.end_event(start, "forward", "model");

        assert_eq!(prof.events().len(), 1);
        assert!(prof.total_time() >= Duration::from_millis(4));
    }

    #[test]
    fn test_profiler_report() {
        let mut prof = Profiler::new();
        for _ in 0..3 {
            prof.measure("matmul", "compute", || {
                thread::sleep(Duration::from_millis(2));
            });
        }
        prof.measure("relu", "compute", || {
            thread::sleep(Duration::from_millis(1));
        });

        let report = prof.report();
        assert_eq!(report.entries.len(), 2);
        // matmul should be first (more total time)
        assert_eq!(report.entries[0].name, "matmul");
        assert_eq!(report.entries[0].count, 3);
        assert_eq!(report.entries[1].name, "relu");
        assert_eq!(report.entries[1].count, 1);

        // Test Display works
        let s = format!("{}", report);
        assert!(s.contains("matmul"));
        assert!(s.contains("relu"));
    }

    #[test]
    fn test_profiler_clear() {
        let mut prof = Profiler::new();
        prof.measure("x", "y", || {});
        assert_eq!(prof.events().len(), 1);
        prof.clear();
        assert_eq!(prof.events().len(), 0);
    }

    #[test]
    fn test_memory_tracker() {
        let mut mem = MemoryTracker::new();
        mem.alloc("weight", 4096);
        mem.alloc("bias", 128);
        assert_eq!(mem.current_bytes(), 4224);
        assert_eq!(mem.peak_bytes(), 4224);
        assert_eq!(mem.alloc_count(), 2);
        assert_eq!(mem.live_count(), 2);

        mem.dealloc("bias");
        assert_eq!(mem.current_bytes(), 4096);
        assert_eq!(mem.peak_bytes(), 4224); // peak unchanged
        assert_eq!(mem.dealloc_count(), 1);
        assert_eq!(mem.live_count(), 1);

        // Re-alloc same name replaces
        mem.alloc("weight", 8192);
        assert_eq!(mem.current_bytes(), 8192);
        assert!(mem.peak_bytes() >= 8192);
    }

    #[test]
    fn test_memory_tracker_summary() {
        let mut mem = MemoryTracker::new();
        mem.alloc("big", 1024 * 1024);
        let s = mem.summary();
        assert!(s.contains("1.00 MB"));
    }

    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(0), "0 B");
        assert_eq!(format_bytes(512), "512 B");
        assert_eq!(format_bytes(1024), "1.00 KB");
        assert_eq!(format_bytes(1024 * 1024), "1.00 MB");
        assert_eq!(format_bytes(2 * 1024 * 1024 * 1024), "2.00 GB");
    }

    #[test]
    fn test_model_summary() {
        let model = Linear::new(16, 8, true, shrew_core::DType::F32, &CpuDevice).unwrap();
        let summary = ModelSummary::from_module::<CpuBackend>(&model);
        // weight: 16*8=128, bias: 1*8=8
        assert_eq!(summary.total_params, 128 + 8);
        assert_eq!(summary.num_tensors, 2);
        let s = format!("{}", summary);
        assert!(s.contains("Model Summary"));
    }

    #[test]
    fn test_benchmark_forward() {
        let model =
            Linear::<CpuBackend>::new(8, 4, true, shrew_core::DType::F32, &CpuDevice).unwrap();
        let result = benchmark_forward(
            &model,
            || Tensor::rand((2, 8), shrew_core::DType::F32, &CpuDevice).unwrap(),
            2,
            1,
            5,
        )
        .unwrap();
        assert_eq!(result.iterations, 5);
        assert_eq!(result.batch_size, 2);
        assert!(result.throughput > 0.0);
        assert!(result.min_time <= result.max_time);
        let s = format!("{}", result);
        assert!(s.contains("Benchmark"));
    }

    #[test]
    fn test_benchmark_forward_backward() {
        let model =
            Linear::<CpuBackend>::new(8, 4, true, shrew_core::DType::F32, &CpuDevice).unwrap();
        let result = benchmark_forward_backward(
            &model,
            || Tensor::rand((2, 8), shrew_core::DType::F32, &CpuDevice).unwrap(),
            |out| out.mean_all(),
            2,
            1,
            3,
        )
        .unwrap();
        assert_eq!(result.iterations, 3);
        assert!(result.throughput > 0.0);
    }

    #[test]
    fn test_scoped_timer() {
        let profiler = Arc::new(Mutex::new(Profiler::new()));
        {
            let _t = ScopedTimer::new(profiler.clone(), "scoped_op", "test");
            thread::sleep(Duration::from_millis(3));
        }
        // Dropping the guard above must have recorded exactly one event.
        let prof = profiler.lock().unwrap();
        assert_eq!(prof.events().len(), 1);
        assert_eq!(prof.events()[0].name, "scoped_op");
        assert!(prof.events()[0].duration >= Duration::from_millis(2));
    }

    #[test]
    fn test_estimate_model_memory() {
        let model = Linear::new(16, 8, true, shrew_core::DType::F32, &CpuDevice).unwrap();
        let bytes = estimate_model_memory::<CpuBackend>(&model);
        // weight: 16*8*4=512 bytes, bias: 1*8*4=32 bytes
        assert_eq!(bytes, 544);
    }

    #[test]
    fn test_stopwatch() {
        let mut sw = Stopwatch::new();
        sw.start();
        thread::sleep(Duration::from_millis(5));
        let lap = sw.lap();
        assert!(lap >= Duration::from_millis(4));
        let total = sw.stop();
        assert!(total >= lap);
        assert_eq!(sw.laps().len(), 1);
    }

    #[test]
    fn test_format_params() {
        assert_eq!(format_params(500), "500");
        assert_eq!(format_params(1_500), "1.50K");
        assert_eq!(format_params(1_500_000), "1.50M");
        assert_eq!(format_params(2_500_000_000), "2.50B");
    }
}