shrew/
quantize.rs

// Quantization — INT8 / INT4 post-training quantization
//
// Quantization reduces model size and inference latency by representing
// weights (and optionally activations) with low-precision integers
// instead of 32-bit floats.
//
// SUPPORTED MODES:
//
//   - INT8 symmetric: weights in [-127, 127], scale = max(|w|) / 127
//   - INT8 asymmetric: weights in [-128, 127], scale + zero_point
//   - INT4 symmetric: weights in [-7, 7], scale = max(|w|) / 7
//   - INT4 packed: two 4-bit values per byte for 8× compression vs. FP32
//
// GRANULARITY:
//
//   - Per-tensor: one scale for the entire tensor (least accurate, most compact)
//   - Per-channel: one scale per output channel (best accuracy/size tradeoff)
//
// WORKFLOW (sketched in the example below):
//
//   1. Train the model normally in FP32
//   2. Call quantize_tensor() or quantize_named_parameters() for post-training quantization
//   3. Run inference using dequantize_tensor() to recover approximate FP32 values
//   4. For deployment, use QuantizedLinear, which dequantizes weights on the fly in forward()
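//
// EXAMPLE (a minimal sketch; the backend/device names follow the tests at the
// bottom of this file):
//
//   let w = Tensor::<CpuBackend>::randn(vec![64, 32], DType::F32, &CpuDevice)?;
//   let qt = quantize_tensor(&w, &QuantConfig::int8())?;              // FP32 -> INT8
//   let w_approx = dequantize_tensor::<CpuBackend>(&qt, &CpuDevice)?; // ≈ w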

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;
use shrew_nn::Module;

// Quantization configuration

/// Bit-width for quantized values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantBits {
    /// 8-bit quantization (range: -127..=127 symmetric, -128..=127 asymmetric).
    Int8,
    /// 4-bit quantization (range: -7..=7 symmetric, -8..=7 asymmetric).
    Int4,
}

/// Quantization mode (symmetric vs. asymmetric).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantMode {
    /// Symmetric: zero_point = 0, range is [-max, +max].
    /// Simplest and fastest; works well for weights.
    Symmetric,
    /// Asymmetric: zero_point can be non-zero, range is [min, max].
    /// Better accuracy for activations with skewed distributions.
    Asymmetric,
}

/// Granularity of quantization parameters (scale / zero_point).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantGranularity {
    /// One scale/zero_point for the entire tensor.
    PerTensor,
    /// One scale/zero_point per output channel (dim 0).
    PerChannel,
}

/// Full quantization configuration.
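///
/// # Example
/// ```ignore
/// // Builder-style configuration (a small sketch using the constructors defined below).
/// let cfg = QuantConfig::int4_per_channel().asymmetric();
/// assert_eq!(cfg.bits, QuantBits::Int4);
/// assert_eq!(cfg.mode, QuantMode::Asymmetric);
/// ```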
#[derive(Debug, Clone)]
pub struct QuantConfig {
    /// Bit width: Int8 or Int4.
    pub bits: QuantBits,
    /// Symmetric or asymmetric quantization.
    pub mode: QuantMode,
    /// Per-tensor or per-channel granularity.
    pub granularity: QuantGranularity,
}

impl Default for QuantConfig {
    fn default() -> Self {
        Self {
            bits: QuantBits::Int8,
            mode: QuantMode::Symmetric,
            granularity: QuantGranularity::PerTensor,
        }
    }
}

impl QuantConfig {
    /// Create INT8 symmetric per-tensor config (most common).
    pub fn int8() -> Self {
        Self::default()
    }

    /// Create INT8 symmetric per-channel config (best accuracy for weights).
    pub fn int8_per_channel() -> Self {
        Self {
            bits: QuantBits::Int8,
            mode: QuantMode::Symmetric,
            granularity: QuantGranularity::PerChannel,
        }
    }

    /// Create INT4 symmetric per-tensor config (maximum compression).
    pub fn int4() -> Self {
        Self {
            bits: QuantBits::Int4,
            mode: QuantMode::Symmetric,
            granularity: QuantGranularity::PerTensor,
        }
    }

    /// Create INT4 per-channel config.
    pub fn int4_per_channel() -> Self {
        Self {
            bits: QuantBits::Int4,
            mode: QuantMode::Symmetric,
            granularity: QuantGranularity::PerChannel,
        }
    }

    /// Set mode to asymmetric.
    pub fn asymmetric(mut self) -> Self {
        self.mode = QuantMode::Asymmetric;
        self
    }

    /// Maximum representable integer for this bit-width.
    fn qmax(&self) -> f64 {
        match self.bits {
            QuantBits::Int8 => 127.0,
            QuantBits::Int4 => 7.0,
        }
    }

    /// Minimum representable integer for this bit-width and mode.
    fn qmin(&self) -> f64 {
        match (self.bits, self.mode) {
            (QuantBits::Int8, QuantMode::Symmetric) => -127.0,
            (QuantBits::Int8, QuantMode::Asymmetric) => -128.0,
            (QuantBits::Int4, QuantMode::Symmetric) => -7.0,
            (QuantBits::Int4, QuantMode::Asymmetric) => -8.0,
        }
    }
}

// QuantizedTensor — holds quantized weights + metadata

/// A quantized tensor storing integer weights with associated scale/zero_point.
///
/// The original float value is recovered by: float = (int - zero_point) * scale
///
/// For per-channel quantization, `scales` and `zero_points` have one entry
/// per channel (dimension 0 of the original tensor).
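///
/// # Example
/// ```ignore
/// // Rough size accounting (sketch; `weight` is any float tensor on some backend).
/// let qt = quantize_tensor(&weight, &QuantConfig::int8())?;
/// assert_eq!(qt.size_bytes(), qt.numel()); // 1 byte per value for INT8
/// println!("{:.1}x smaller than FP32", qt.compression_ratio());
/// ```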
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized integer values, one `i8` per element (INT4 values are not
    /// bit-packed in memory; `size_bytes()` reports the packed size).
    pub data: Vec<i8>,
    /// Scale factor(s). Length 1 for per-tensor, N for per-channel.
    pub scales: Vec<f64>,
    /// Zero point(s). Length 1 for per-tensor, N for per-channel.
    pub zero_points: Vec<f64>,
    /// Original tensor shape.
    pub shape: Vec<usize>,
    /// Original dtype (for dequantization target).
    pub original_dtype: DType,
    /// Quantization config used.
    pub config: QuantConfig,
}

impl QuantizedTensor {
    /// Total number of elements.
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Size in bytes of the quantized representation (assumes INT4 is packed).
    pub fn size_bytes(&self) -> usize {
        match self.config.bits {
            QuantBits::Int8 => self.numel(),
            QuantBits::Int4 => self.numel().div_ceil(2), // packed: 2 values per byte
        }
    }

    /// Compression ratio vs FP32.
    pub fn compression_ratio(&self) -> f64 {
        let fp32_bytes = self.numel() * 4;
        fp32_bytes as f64 / self.size_bytes() as f64
    }
}

// Quantize / Dequantize operations

/// Quantize a float tensor to a `QuantizedTensor`.
///
/// # Arguments
/// - `tensor`: the FP32/FP64 tensor to quantize
/// - `config`: quantization configuration
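///
/// # Example
/// ```ignore
/// // Per-channel INT8 quantization of a 2-D weight (sketch; `weight` is any float tensor).
/// let qt = quantize_tensor(&weight, &QuantConfig::int8_per_channel())?;
/// assert_eq!(qt.scales.len(), weight.dims()[0]); // one scale per output channel
/// ```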
pub fn quantize_tensor<B: Backend>(
    tensor: &Tensor<B>,
    config: &QuantConfig,
) -> Result<QuantizedTensor> {
    let data = tensor.to_f64_vec()?;
    let shape = tensor.dims().to_vec();

    match config.granularity {
        QuantGranularity::PerTensor => quantize_per_tensor(&data, &shape, tensor.dtype(), config),
        QuantGranularity::PerChannel => quantize_per_channel(&data, &shape, tensor.dtype(), config),
    }
}

/// Dequantize a `QuantizedTensor` back to a float tensor.
///
/// The dequantized values are approximate: float ≈ (int - zero_point) * scale.
pub fn dequantize_tensor<B: Backend>(
    qtensor: &QuantizedTensor,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let float_data = match qtensor.config.granularity {
        QuantGranularity::PerTensor => dequantize_per_tensor(qtensor),
        QuantGranularity::PerChannel => dequantize_per_channel(qtensor),
    };

    Tensor::<B>::from_f64_slice(
        &float_data,
        qtensor.shape.clone(),
        qtensor.original_dtype,
        device,
    )
}

// ── Per-tensor quantization ──

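// Worked example (INT8 symmetric): for weights with max |w| = 1.0 the scale is
// 1.0 / 127 ≈ 0.00787, so 0.25 quantizes to round(0.25 / 0.00787) = 32 and
// dequantizes back to 32 * 0.00787 ≈ 0.252 (rounding error is at most scale / 2).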
fn quantize_per_tensor(
    data: &[f64],
    shape: &[usize],
    dtype: DType,
    config: &QuantConfig,
) -> Result<QuantizedTensor> {
    let (scale, zero_point) = compute_scale_zp(data, config);
    let inv_scale = if scale.abs() < 1e-30 {
        0.0
    } else {
        1.0 / scale
    };
    let qmin = config.qmin();
    let qmax = config.qmax();

    let quantized: Vec<i8> = data
        .iter()
        .map(|&v| {
            let q = (v * inv_scale + zero_point).round().clamp(qmin, qmax);
            q as i8
        })
        .collect();

    Ok(QuantizedTensor {
        data: quantized,
        scales: vec![scale],
        zero_points: vec![zero_point],
        shape: shape.to_vec(),
        original_dtype: dtype,
        config: config.clone(),
    })
}

fn dequantize_per_tensor(qt: &QuantizedTensor) -> Vec<f64> {
    let scale = qt.scales[0];
    let zp = qt.zero_points[0];
    qt.data.iter().map(|&q| (q as f64 - zp) * scale).collect()
}

// ── Per-channel quantization ──

fn quantize_per_channel(
    data: &[f64],
    shape: &[usize],
    dtype: DType,
    config: &QuantConfig,
) -> Result<QuantizedTensor> {
    if shape.is_empty() {
        return quantize_per_tensor(data, shape, dtype, config);
    }

    let n_channels = shape[0];
    let channel_size: usize = shape[1..].iter().product();
    let qmin = config.qmin();
    let qmax = config.qmax();

    let mut scales = Vec::with_capacity(n_channels);
    let mut zero_points = Vec::with_capacity(n_channels);
    let mut quantized = vec![0i8; data.len()];

    for ch in 0..n_channels {
        let start = ch * channel_size;
        let end = start + channel_size;
        let channel_data = &data[start..end];

        let (scale, zp) = compute_scale_zp(channel_data, config);
        let inv_scale = if scale.abs() < 1e-30 {
            0.0
        } else {
            1.0 / scale
        };

        for (i, &v) in channel_data.iter().enumerate() {
            let q = (v * inv_scale + zp).round().clamp(qmin, qmax);
            quantized[start + i] = q as i8;
        }

        scales.push(scale);
        zero_points.push(zp);
    }

    Ok(QuantizedTensor {
        data: quantized,
        scales,
        zero_points,
        shape: shape.to_vec(),
        original_dtype: dtype,
        config: config.clone(),
    })
}

fn dequantize_per_channel(qt: &QuantizedTensor) -> Vec<f64> {
    let n_channels = qt.shape[0];
    let channel_size: usize = qt.shape[1..].iter().product();
    let mut result = vec![0.0f64; qt.data.len()];

    for ch in 0..n_channels {
        let start = ch * channel_size;
        let scale = qt.scales[ch];
        let zp = qt.zero_points[ch];
        for i in 0..channel_size {
            result[start + i] = (qt.data[start + i] as f64 - zp) * scale;
        }
    }

    result
}

// ── Scale / zero-point computation ──

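// Worked example (INT8 asymmetric): for data spanning [-1.0, 2.0] the scale is
// (2.0 - (-1.0)) / (127 - (-128)) = 3.0 / 255 ≈ 0.01176 and the zero point is
// round(-128 - (-1.0) / 0.01176) = -43, so -1.0 maps to -128 and 2.0 maps to 127.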
fn compute_scale_zp(data: &[f64], config: &QuantConfig) -> (f64, f64) {
    if data.is_empty() {
        return (1.0, 0.0);
    }

    let qmin = config.qmin();
    let qmax = config.qmax();

    match config.mode {
        QuantMode::Symmetric => {
            let amax = data.iter().fold(0.0f64, |acc, &v| acc.max(v.abs()));
            let scale = if amax < 1e-30 { 1.0 } else { amax / qmax };
            (scale, 0.0)
        }
        QuantMode::Asymmetric => {
            let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
            let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let range = max_val - min_val;
            let scale = if range < 1e-30 {
                1.0
            } else {
                range / (qmax - qmin)
            };
            let zero_point = (qmin - min_val / scale).round();
            (scale, zero_point)
        }
    }
}

// QuantizedLinear — Quantized linear layer for inference

/// A quantized linear layer that stores weights in INT8/INT4.
///
/// During inference, weights are dequantized on-the-fly for computation.
/// This shrinks weight storage (4× for INT8, 8× for packed INT4) at the cost
/// of a small dequantization overhead.
///
/// # Example
/// ```ignore
/// // Quantize a trained linear layer
/// let linear = Linear::new(256, 10, true, DType::F32, &dev)?;
/// // ... train ...
/// let qlinear = QuantizedLinear::from_linear(&linear, &QuantConfig::int8())?;
/// let output = qlinear.forward(&input)?;
/// ```
pub struct QuantizedLinear<B: Backend> {
    /// Quantized weight matrix.
    weight_q: QuantizedTensor,
    /// Bias (kept in FP32 — small relative to weights).
    bias: Option<Tensor<B>>,
    /// Device for dequantization.
    device: B::Device,
    /// Input features.
    pub in_features: usize,
    /// Output features.
    pub out_features: usize,
}

impl<B: Backend> std::fmt::Debug for QuantizedLinear<B> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QuantizedLinear")
            .field("in_features", &self.in_features)
            .field("out_features", &self.out_features)
            .field("bits", &self.weight_q.config.bits)
            .field(
                "compression",
                &format!("{:.1}x", self.weight_q.compression_ratio()),
            )
            .finish()
    }
}

impl<B: Backend> QuantizedLinear<B> {
    /// Create a `QuantizedLinear` from a trained `Linear` layer.
    pub fn from_linear(linear: &shrew_nn::Linear<B>, config: &QuantConfig) -> Result<Self> {
        let params = linear.parameters();
        let weight = &params[0];
        let bias = if params.len() > 1 {
            Some(params[1].clone())
        } else {
            None
        };

        let weight_q = quantize_tensor(weight, config)?;

        let in_features = weight.dims()[1];
        let out_features = weight.dims()[0];

        Ok(Self {
            weight_q,
            bias,
            device: weight.device().clone(),
            in_features,
            out_features,
        })
    }

    /// Create from raw quantized data and optional bias tensor.
    pub fn new(weight_q: QuantizedTensor, bias: Option<Tensor<B>>, device: B::Device) -> Self {
        let in_features = weight_q.shape[1];
        let out_features = weight_q.shape[0];
        Self {
            weight_q,
            bias,
            device,
            in_features,
            out_features,
        }
    }

    /// Get the quantized weight data.
    pub fn weight_quantized(&self) -> &QuantizedTensor {
        &self.weight_q
    }

    /// Memory saved compared to FP32 weight storage.
    pub fn memory_savings_bytes(&self) -> usize {
        let fp32_size = self.weight_q.numel() * 4;
        let quant_size = self.weight_q.size_bytes();
        fp32_size - quant_size
    }
}

impl<B: Backend> Module<B> for QuantizedLinear<B> {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        // Dequantize weight from INT8/INT4 → FP32 on-the-fly
        let weight = dequantize_tensor::<B>(&self.weight_q, &self.device)?;

        // x @ weight^T + bias
        let wt = weight.t()?;
        let out = x.matmul(&wt)?;

        if let Some(ref bias) = self.bias {
            out.add(bias)
        } else {
            Ok(out)
        }
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        // Quantized model is for inference — no trainable parameters
        Vec::new()
    }
}

// Model-level quantization

/// Quantize all weight-like parameters (rank >= 2) in a model's named_parameters.
///
/// Returns a vector of `(name, QuantizedTensor)` for each quantized parameter.
/// Rank-1 tensors (e.g. 1-D biases) are skipped and remain in FP32.
///
/// # Example
/// ```ignore
/// let model = Sequential::new(vec![...]);
/// let quantized = quantize_named_parameters(&model, &QuantConfig::int8())?;
/// for (name, qt) in &quantized {
///     println!("{}: {} → {} bytes ({:.1}x compression)",
///         name, qt.numel() * 4, qt.size_bytes(), qt.compression_ratio());
/// }
/// ```
pub fn quantize_named_parameters<B: Backend, M: Module<B>>(
    module: &M,
    config: &QuantConfig,
) -> Result<Vec<(String, QuantizedTensor)>> {
    let named = module.named_parameters();
    let mut quantized = Vec::new();

    for (name, tensor) in &named {
        // Only quantize weight tensors (typically 2D+), skip biases (1D)
        if tensor.rank() >= 2 {
            let qt = quantize_tensor(tensor, config)?;
            quantized.push((name.clone(), qt));
        }
    }

    Ok(quantized)
}

/// Quantization statistics for a model.
#[derive(Debug, Clone)]
pub struct QuantStats {
    /// Number of parameters quantized.
    pub num_quantized: usize,
    /// Number of parameters kept in FP32 (e.g., 1-D biases).
    pub num_fp32: usize,
    /// Total FP32 size in bytes.
    pub fp32_bytes: usize,
    /// Total quantized size in bytes.
    pub quantized_bytes: usize,
    /// Overall compression ratio.
    pub compression_ratio: f64,
}

/// Compute quantization statistics for a model without actually quantizing.
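///
/// # Example
/// ```ignore
/// // Dry-run size estimate before quantizing (sketch; `model` is any Module).
/// let stats = quantization_stats(&model, &QuantConfig::int8());
/// println!(
///     "{} tensors quantized, {} kept in FP32, {:.1}x overall compression",
///     stats.num_quantized, stats.num_fp32, stats.compression_ratio
/// );
/// ```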
pub fn quantization_stats<B: Backend, M: Module<B>>(
    module: &M,
    config: &QuantConfig,
) -> QuantStats {
    let named = module.named_parameters();
    let mut num_quantized = 0usize;
    let mut num_fp32 = 0usize;
    let mut fp32_bytes = 0usize;
    let mut quantized_bytes = 0usize;

    for (_, tensor) in &named {
        let numel = tensor.elem_count();
        let param_fp32 = numel * 4;
        fp32_bytes += param_fp32;

        if tensor.rank() >= 2 {
            num_quantized += 1;
            let quant_size = match config.bits {
                QuantBits::Int8 => numel,
                QuantBits::Int4 => numel.div_ceil(2),
            };
            // Add scale/zp overhead
            let meta_size = match config.granularity {
                QuantGranularity::PerTensor => 16, // 2 × f64
                QuantGranularity::PerChannel => {
                    if tensor.dims().is_empty() {
                        16
                    } else {
                        tensor.dims()[0] * 16
                    }
                }
            };
            quantized_bytes += quant_size + meta_size;
        } else {
            num_fp32 += 1;
            quantized_bytes += param_fp32; // biases stay FP32
        }
    }

    let compression_ratio = if quantized_bytes > 0 {
        fp32_bytes as f64 / quantized_bytes as f64
    } else {
        1.0
    };

    QuantStats {
        num_quantized,
        num_fp32,
        fp32_bytes,
        quantized_bytes,
        compression_ratio,
    }
}

// Tests

#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};

    type B = CpuBackend;
    type T = Tensor<B>;
    const DEV: CpuDevice = CpuDevice;

    // ── Basic quantize/dequantize round-trip ──

    #[test]
    fn test_int8_symmetric_per_tensor() {
        let data = vec![1.0, -0.5, 0.25, -1.0, 0.0, 0.75];
        let tensor = T::from_f64_slice(&data, vec![2, 3], DType::F32, &DEV).unwrap();

        let config = QuantConfig::int8();
        let qt = quantize_tensor::<B>(&tensor, &config).unwrap();

        assert_eq!(qt.shape, vec![2, 3]);
        assert_eq!(qt.scales.len(), 1);
        assert_eq!(qt.zero_points, vec![0.0]);

        // Dequantize and check approximate recovery
        let recovered = dequantize_tensor::<B>(&qt, &DEV).unwrap();
        let rec_data = recovered.to_f64_vec().unwrap();

        for (orig, rec) in data.iter().zip(rec_data.iter()) {
            assert!(
                (orig - rec).abs() < 0.02,
                "int8 round-trip: {} vs {}",
                orig,
                rec
            );
        }
    }

    #[test]
    fn test_int4_symmetric_per_tensor() {
        let data = vec![1.0, -0.5, 0.25, -1.0];
        let tensor = T::from_f64_slice(&data, vec![2, 2], DType::F32, &DEV).unwrap();

        let config = QuantConfig::int4();
        let qt = quantize_tensor::<B>(&tensor, &config).unwrap();

        // INT4 has max 7, so scale = 1.0/7 ≈ 0.143
        assert_eq!(qt.scales.len(), 1);

        let recovered = dequantize_tensor::<B>(&qt, &DEV).unwrap();
        let rec_data = recovered.to_f64_vec().unwrap();

        for (orig, rec) in data.iter().zip(rec_data.iter()) {
            assert!(
                (orig - rec).abs() < 0.2,
                "int4 round-trip: {} vs {}",
                orig,
                rec
            );
        }
    }

    #[test]
    fn test_int8_per_channel() {
        // 3 channels, 4 elements each
        let data: Vec<f64> = (0..12).map(|i| (i as f64 - 6.0) * 0.1).collect();
        let tensor = T::from_f64_slice(&data, vec![3, 4], DType::F32, &DEV).unwrap();

        let config = QuantConfig::int8_per_channel();
        let qt = quantize_tensor::<B>(&tensor, &config).unwrap();

        assert_eq!(qt.scales.len(), 3); // one per channel

        let recovered = dequantize_tensor::<B>(&qt, &DEV).unwrap();
        let rec_data = recovered.to_f64_vec().unwrap();

        for (orig, rec) in data.iter().zip(rec_data.iter()) {
            assert!(
                (orig - rec).abs() < 0.01,
                "per-channel round-trip: {} vs {}",
                orig,
                rec
            );
        }
    }

    #[test]
    fn test_int8_asymmetric() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
        let tensor = T::from_f64_slice(&data, vec![1, 5], DType::F32, &DEV).unwrap();

        let config = QuantConfig::int8().asymmetric();
        let qt = quantize_tensor::<B>(&tensor, &config).unwrap();

        // Asymmetric with symmetric data should still produce valid results
        assert_eq!(qt.scales.len(), 1);

        let recovered = dequantize_tensor::<B>(&qt, &DEV).unwrap();
        let rec_data = recovered.to_f64_vec().unwrap();

        for (orig, rec) in data.iter().zip(rec_data.iter()) {
            assert!(
                (orig - rec).abs() < 0.02,
                "asymmetric round-trip: {} vs {}",
                orig,
                rec
            );
        }
    }

    // ── Compression ratio ──

    #[test]
    fn test_compression_ratio() {
        let tensor = T::randn(vec![256, 128], DType::F32, &DEV).unwrap();

        let qt8 = quantize_tensor::<B>(&tensor, &QuantConfig::int8()).unwrap();
        assert!((qt8.compression_ratio() - 4.0).abs() < 0.01); // 4:1 for INT8

        let qt4 = quantize_tensor::<B>(&tensor, &QuantConfig::int4()).unwrap();
        assert!((qt4.compression_ratio() - 8.0).abs() < 0.01); // 8:1 for INT4
    }

    // ── QuantizedLinear ──

    #[test]
    fn test_quantized_linear_forward() {
        let linear = shrew_nn::Linear::<B>::new(4, 3, true, DType::F32, &DEV).unwrap();
        let input = T::randn(vec![2, 4], DType::F32, &DEV).unwrap();

        // FP32 reference output
        let fp32_output = linear.forward(&input).unwrap();

        // Quantize and run
        let qlinear = QuantizedLinear::from_linear(&linear, &QuantConfig::int8()).unwrap();
        let quant_output = qlinear.forward(&input).unwrap();

        // Outputs should be close but not identical
        let fp32_data = fp32_output.to_f64_vec().unwrap();
        let quant_data = quant_output.to_f64_vec().unwrap();

        assert_eq!(fp32_data.len(), quant_data.len());
        for (a, b) in fp32_data.iter().zip(quant_data.iter()) {
            assert!(
                (a - b).abs() < 0.5,
                "quantized output diverged: {} vs {}",
                a,
                b
            );
        }
    }

    #[test]
    fn test_quantized_linear_no_trainable_params() {
        let linear = shrew_nn::Linear::<B>::new(4, 2, true, DType::F32, &DEV).unwrap();
        let qlinear = QuantizedLinear::from_linear(&linear, &QuantConfig::int8()).unwrap();
        assert!(qlinear.parameters().is_empty()); // inference-only
    }

    #[test]
    fn test_quantized_linear_memory_savings() {
        let linear = shrew_nn::Linear::<B>::new(256, 128, true, DType::F32, &DEV).unwrap();
        let qlinear = QuantizedLinear::from_linear(&linear, &QuantConfig::int8()).unwrap();
        // 256*128 = 32768 params, fp32 = 131072 bytes, int8 = 32768 bytes
        // Savings = 131072 - 32768 = 98304
        assert_eq!(qlinear.memory_savings_bytes(), 98304);
    }

    // ── Model-level quantization stats ──

    #[test]
    fn test_quantize_named_parameters() {
        let linear = shrew_nn::Linear::<B>::new(8, 4, true, DType::F32, &DEV).unwrap();
        let quantized = quantize_named_parameters(&linear, &QuantConfig::int8()).unwrap();
        // Linear stores weight as [4,8] (rank 2) and bias as [1,4] (rank 2)
        // Both have rank >= 2, so both get quantized
        assert_eq!(quantized.len(), 2);
    }

    #[test]
    fn test_quantization_stats() {
        let linear = shrew_nn::Linear::<B>::new(64, 32, true, DType::F32, &DEV).unwrap();
        let stats = quantization_stats(&linear, &QuantConfig::int8());
        // Both weight [32,64] and bias [1,32] are rank 2 → both quantized
        assert_eq!(stats.num_quantized, 2);
        assert_eq!(stats.num_fp32, 0);
        assert!(stats.compression_ratio > 3.0);
    }

    // ── Edge case: zero tensor ──

    #[test]
    fn test_quantize_zero_tensor() {
        let tensor = T::zeros(vec![2, 3], DType::F32, &DEV).unwrap();
        let qt = quantize_tensor::<B>(&tensor, &QuantConfig::int8()).unwrap();
        let recovered = dequantize_tensor::<B>(&qt, &DEV).unwrap();
        let data = recovered.to_f64_vec().unwrap();
        for &v in &data {
            assert_eq!(v, 0.0);
        }
    }
}