use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;
use shrew_nn::Module;

/// Bit width of the quantized representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantBits {
    /// 8-bit integers: one `i8` code per element.
    Int8,
    /// 4-bit integers: codes are held in `i8` values and counted two-per-byte by `size_bytes`.
    Int4,
}

/// How real values are mapped onto the integer grid.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantMode {
    /// Zero point fixed at 0; the grid is symmetric around zero.
    Symmetric,
    /// A zero point shifts the grid so it covers the observed [min, max] range.
    Asymmetric,
}

/// Scope over which scales and zero points are computed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantGranularity {
    /// One scale/zero point for the whole tensor.
    PerTensor,
    /// One scale/zero point per channel (dimension 0).
    PerChannel,
}

/// Configuration for a quantization pass.
#[derive(Debug, Clone)]
pub struct QuantConfig {
    pub bits: QuantBits,
    pub mode: QuantMode,
    pub granularity: QuantGranularity,
}

impl Default for QuantConfig {
    fn default() -> Self {
        Self {
            bits: QuantBits::Int8,
            mode: QuantMode::Symmetric,
            granularity: QuantGranularity::PerTensor,
        }
    }
}

impl QuantConfig {
    /// Int8, symmetric, per-tensor (the default).
    pub fn int8() -> Self {
        Self::default()
    }

    /// Int8, symmetric, with one scale per channel.
    pub fn int8_per_channel() -> Self {
        Self {
            bits: QuantBits::Int8,
            mode: QuantMode::Symmetric,
            granularity: QuantGranularity::PerChannel,
        }
    }

    /// Int4, symmetric, per-tensor.
    pub fn int4() -> Self {
        Self {
            bits: QuantBits::Int4,
            mode: QuantMode::Symmetric,
            granularity: QuantGranularity::PerTensor,
        }
    }

    /// Int4, symmetric, with one scale per channel.
    pub fn int4_per_channel() -> Self {
        Self {
            bits: QuantBits::Int4,
            mode: QuantMode::Symmetric,
            granularity: QuantGranularity::PerChannel,
        }
    }

    /// Switches the configuration to asymmetric (affine) quantization.
    pub fn asymmetric(mut self) -> Self {
        self.mode = QuantMode::Asymmetric;
        self
    }

    fn qmax(&self) -> f64 {
        match self.bits {
            QuantBits::Int8 => 127.0,
            QuantBits::Int4 => 7.0,
        }
    }

    fn qmin(&self) -> f64 {
        // Symmetric mode keeps the range symmetric around zero (e.g. [-127, 127]);
        // asymmetric mode uses the full signed range (e.g. [-128, 127]).
        match (self.bits, self.mode) {
            (QuantBits::Int8, QuantMode::Symmetric) => -127.0,
            (QuantBits::Int8, QuantMode::Asymmetric) => -128.0,
            (QuantBits::Int4, QuantMode::Symmetric) => -7.0,
            (QuantBits::Int4, QuantMode::Asymmetric) => -8.0,
        }
    }
}


/// A tensor whose values have been quantized to low-bit integer codes.
///
/// `data` holds one `i8` code per element; `scales` and `zero_points` hold one
/// entry for the whole tensor (per-tensor) or one per channel (per-channel).
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    pub data: Vec<i8>,
    pub scales: Vec<f64>,
    pub zero_points: Vec<f64>,
    pub shape: Vec<usize>,
    pub original_dtype: DType,
    pub config: QuantConfig,
}

impl QuantizedTensor {
    /// Number of elements in the original tensor.
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Storage size of the quantized codes in bytes (int4 counts two codes per byte).
    pub fn size_bytes(&self) -> usize {
        match self.config.bits {
            QuantBits::Int8 => self.numel(),
            QuantBits::Int4 => self.numel().div_ceil(2),
        }
    }

    /// Ratio of the fp32 footprint to the quantized footprint.
    pub fn compression_ratio(&self) -> f64 {
        let fp32_bytes = self.numel() * 4;
        fp32_bytes as f64 / self.size_bytes() as f64
    }
}

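/// Quantizes a floating-point tensor into `i8` codes with per-tensor or
/// per-channel scales and zero points, according to `config`.
///
/// A minimal round-trip sketch, assuming the CPU backend from `shrew_cpu`
/// used by the tests below:
///
/// ```ignore
/// use shrew_cpu::{CpuBackend, CpuDevice};
///
/// let t = Tensor::<CpuBackend>::from_f64_slice(&[1.0, -0.5], vec![1, 2], DType::F32, &CpuDevice)?;
/// let qt = quantize_tensor::<CpuBackend>(&t, &QuantConfig::int8())?;
/// let back = dequantize_tensor::<CpuBackend>(&qt, &CpuDevice)?;
/// ```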
pub fn quantize_tensor<B: Backend>(
    tensor: &Tensor<B>,
    config: &QuantConfig,
) -> Result<QuantizedTensor> {
    let data = tensor.to_f64_vec()?;
    let shape = tensor.dims().to_vec();

    match config.granularity {
        QuantGranularity::PerTensor => quantize_per_tensor(&data, &shape, tensor.dtype(), config),
        QuantGranularity::PerChannel => quantize_per_channel(&data, &shape, tensor.dtype(), config),
    }
}

/// Reconstructs a floating-point tensor from quantized codes on the given device.
pub fn dequantize_tensor<B: Backend>(
    qtensor: &QuantizedTensor,
    device: &B::Device,
) -> Result<Tensor<B>> {
    let float_data = match qtensor.config.granularity {
        QuantGranularity::PerTensor => dequantize_per_tensor(qtensor),
        QuantGranularity::PerChannel => dequantize_per_channel(qtensor),
    };

    Tensor::<B>::from_f64_slice(
        &float_data,
        qtensor.shape.clone(),
        qtensor.original_dtype,
        device,
    )
}

fn quantize_per_tensor(
    data: &[f64],
    shape: &[usize],
    dtype: DType,
    config: &QuantConfig,
) -> Result<QuantizedTensor> {
    let (scale, zero_point) = compute_scale_zp(data, config);
    let inv_scale = if scale.abs() < 1e-30 {
        0.0
    } else {
        1.0 / scale
    };
    let qmin = config.qmin();
    let qmax = config.qmax();

    // q = clamp(round(v / scale + zero_point), qmin, qmax)
    let quantized: Vec<i8> = data
        .iter()
        .map(|&v| {
            let q = (v * inv_scale + zero_point).round().clamp(qmin, qmax);
            q as i8
        })
        .collect();

    Ok(QuantizedTensor {
        data: quantized,
        scales: vec![scale],
        zero_points: vec![zero_point],
        shape: shape.to_vec(),
        original_dtype: dtype,
        config: config.clone(),
    })
}

fn dequantize_per_tensor(qt: &QuantizedTensor) -> Vec<f64> {
    // v ≈ (q - zero_point) * scale
    let scale = qt.scales[0];
    let zp = qt.zero_points[0];
    qt.data.iter().map(|&q| (q as f64 - zp) * scale).collect()
}

fn quantize_per_channel(
    data: &[f64],
    shape: &[usize],
    dtype: DType,
    config: &QuantConfig,
) -> Result<QuantizedTensor> {
    if shape.is_empty() {
        return quantize_per_tensor(data, shape, dtype, config);
    }

    // Channels run along dimension 0 (output features for an [out, in] weight).
    let n_channels = shape[0];
    let channel_size: usize = shape[1..].iter().product();
    let qmin = config.qmin();
    let qmax = config.qmax();

    let mut scales = Vec::with_capacity(n_channels);
    let mut zero_points = Vec::with_capacity(n_channels);
    let mut quantized = vec![0i8; data.len()];

    for ch in 0..n_channels {
        let start = ch * channel_size;
        let end = start + channel_size;
        let channel_data = &data[start..end];

        let (scale, zp) = compute_scale_zp(channel_data, config);
        let inv_scale = if scale.abs() < 1e-30 {
            0.0
        } else {
            1.0 / scale
        };

        for (i, &v) in channel_data.iter().enumerate() {
            let q = (v * inv_scale + zp).round().clamp(qmin, qmax);
            quantized[start + i] = q as i8;
        }

        scales.push(scale);
        zero_points.push(zp);
    }

    Ok(QuantizedTensor {
        data: quantized,
        scales,
        zero_points,
        shape: shape.to_vec(),
        original_dtype: dtype,
        config: config.clone(),
    })
}

fn dequantize_per_channel(qt: &QuantizedTensor) -> Vec<f64> {
    let n_channels = qt.shape[0];
    let channel_size: usize = qt.shape[1..].iter().product();
    let mut result = vec![0.0f64; qt.data.len()];

    for ch in 0..n_channels {
        let start = ch * channel_size;
        let scale = qt.scales[ch];
        let zp = qt.zero_points[ch];
        for i in 0..channel_size {
            result[start + i] = (qt.data[start + i] as f64 - zp) * scale;
        }
    }

    result
}

fn compute_scale_zp(data: &[f64], config: &QuantConfig) -> (f64, f64) {
    if data.is_empty() {
        return (1.0, 0.0);
    }

    let qmin = config.qmin();
    let qmax = config.qmax();

    match config.mode {
        QuantMode::Symmetric => {
            // Scale maps the largest absolute value onto qmax; the zero point is 0.
            let amax = data.iter().fold(0.0f64, |acc, &v| acc.max(v.abs()));
            let scale = if amax < 1e-30 { 1.0 } else { amax / qmax };
            (scale, 0.0)
        }
        QuantMode::Asymmetric => {
            // Scale spreads [min, max] over [qmin, qmax]; the zero point shifts
            // the grid so that min maps near qmin.
            let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
            let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let range = max_val - min_val;
            let scale = if range < 1e-30 {
                1.0
            } else {
                range / (qmax - qmin)
            };
            let zero_point = (qmin - min_val / scale).round();
            (scale, zero_point)
        }
    }
}
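
// Worked example of the symmetric case above (illustrative numbers, not taken
// from the tests): with int8 and amax = 2.54, scale = 2.54 / 127 = 0.02, so a
// value of 1.0 quantizes to round(1.0 / 0.02) = 50 and dequantizes back to
// 50 * 0.02 = 1.0; values between grid points incur up to scale / 2 of
// rounding error.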

/// A linear layer whose weight is stored quantized and dequantized on the fly
/// in `forward` (weight-only quantization; any bias stays in floating point).
pub struct QuantizedLinear<B: Backend> {
    weight_q: QuantizedTensor,
    bias: Option<Tensor<B>>,
    device: B::Device,
    pub in_features: usize,
    pub out_features: usize,
}

impl<B: Backend> std::fmt::Debug for QuantizedLinear<B> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QuantizedLinear")
            .field("in_features", &self.in_features)
            .field("out_features", &self.out_features)
            .field("bits", &self.weight_q.config.bits)
            .field(
                "compression",
                &format!("{:.1}x", self.weight_q.compression_ratio()),
            )
            .finish()
    }
}

impl<B: Backend> QuantizedLinear<B> {
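    /// Builds a quantized copy of an existing `shrew_nn::Linear`, quantizing its
    /// weight with `config` and keeping the bias (if present) in floating point.
    ///
    /// A usage sketch, assuming the CPU backend from `shrew_cpu` used in the
    /// tests below (`x` stands in for any input of shape `[batch, in_features]`):
    ///
    /// ```ignore
    /// let linear = shrew_nn::Linear::<CpuBackend>::new(4, 3, true, DType::F32, &CpuDevice)?;
    /// let qlinear = QuantizedLinear::from_linear(&linear, &QuantConfig::int8())?;
    /// let y = qlinear.forward(&x)?; // approximates linear.forward(&x)
    /// ```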
    pub fn from_linear(linear: &shrew_nn::Linear<B>, config: &QuantConfig) -> Result<Self> {
        let params = linear.parameters();
        let weight = &params[0];
        let bias = if params.len() > 1 {
            Some(params[1].clone())
        } else {
            None
        };

        let weight_q = quantize_tensor(weight, config)?;

        // Linear weights are stored as [out_features, in_features].
        let in_features = weight.dims()[1];
        let out_features = weight.dims()[0];

        Ok(Self {
            weight_q,
            bias,
            device: weight.device().clone(),
            in_features,
            out_features,
        })
    }

    /// Assembles a quantized linear layer from an already-quantized weight.
    pub fn new(weight_q: QuantizedTensor, bias: Option<Tensor<B>>, device: B::Device) -> Self {
        let in_features = weight_q.shape[1];
        let out_features = weight_q.shape[0];
        Self {
            weight_q,
            bias,
            device,
            in_features,
            out_features,
        }
    }

    /// Read-only access to the quantized weight.
    pub fn weight_quantized(&self) -> &QuantizedTensor {
        &self.weight_q
    }

    /// Bytes saved on the weight relative to storing it in fp32.
    pub fn memory_savings_bytes(&self) -> usize {
        let fp32_size = self.weight_q.numel() * 4;
        let quant_size = self.weight_q.size_bytes();
        fp32_size - quant_size
    }
}

impl<B: Backend> Module<B> for QuantizedLinear<B> {
    fn forward(&self, x: &Tensor<B>) -> Result<Tensor<B>> {
        // Dequantize the weight on the fly, then run a standard x @ W^T (+ bias).
        let weight = dequantize_tensor::<B>(&self.weight_q, &self.device)?;

        let wt = weight.t()?;
        let out = x.matmul(&wt)?;

        if let Some(ref bias) = self.bias {
            out.add(bias)
        } else {
            Ok(out)
        }
    }

    fn parameters(&self) -> Vec<Tensor<B>> {
        // Quantized weights are not trainable, so no parameters are exposed.
        Vec::new()
    }
}

/// Quantizes every parameter of `module` with rank >= 2, returning
/// `(name, quantized tensor)` pairs; lower-rank parameters are left untouched.
pub fn quantize_named_parameters<B: Backend, M: Module<B>>(
    module: &M,
    config: &QuantConfig,
) -> Result<Vec<(String, QuantizedTensor)>> {
    let named = module.named_parameters();
    let mut quantized = Vec::new();

    for (name, tensor) in &named {
        if tensor.rank() >= 2 {
            let qt = quantize_tensor(tensor, config)?;
            quantized.push((name.clone(), qt));
        }
    }

    Ok(quantized)
}

/// Summary of projected memory usage for a quantized module.
#[derive(Debug, Clone)]
pub struct QuantStats {
    pub num_quantized: usize,
    pub num_fp32: usize,
    pub fp32_bytes: usize,
    pub quantized_bytes: usize,
    pub compression_ratio: f64,
}

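/// Estimates the memory footprint of quantizing `module` under `config`,
/// without modifying the module or allocating quantized tensors.
///
/// A reporting sketch (`model` is a placeholder for any module):
///
/// ```ignore
/// let stats = quantization_stats(&model, &QuantConfig::int8());
/// println!(
///     "{} quantized / {} fp32 params, {:.1}x smaller",
///     stats.num_quantized, stats.num_fp32, stats.compression_ratio
/// );
/// ```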
pub fn quantization_stats<B: Backend, M: Module<B>>(
    module: &M,
    config: &QuantConfig,
) -> QuantStats {
    let named = module.named_parameters();
    let mut num_quantized = 0usize;
    let mut num_fp32 = 0usize;
    let mut fp32_bytes = 0usize;
    let mut quantized_bytes = 0usize;

    for (_, tensor) in &named {
        let numel = tensor.elem_count();
        let param_fp32 = numel * 4;
        fp32_bytes += param_fp32;

        if tensor.rank() >= 2 {
            num_quantized += 1;
            let quant_size = match config.bits {
                QuantBits::Int8 => numel,
                QuantBits::Int4 => numel.div_ceil(2),
            };
            // Metadata: one f64 scale + one f64 zero point (16 bytes) per tensor
            // or per channel.
            let meta_size = match config.granularity {
                QuantGranularity::PerTensor => 16,
                QuantGranularity::PerChannel => {
                    if tensor.dims().is_empty() {
                        16
                    } else {
                        tensor.dims()[0] * 16
                    }
                }
            };
            quantized_bytes += quant_size + meta_size;
        } else {
            num_fp32 += 1;
            quantized_bytes += param_fp32;
        }
    }

    let compression_ratio = if quantized_bytes > 0 {
        fp32_bytes as f64 / quantized_bytes as f64
    } else {
        1.0
    };

    QuantStats {
        num_quantized,
        num_fp32,
        fp32_bytes,
        quantized_bytes,
        compression_ratio,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};

    type B = CpuBackend;
    type T = Tensor<B>;
    const DEV: CpuDevice = CpuDevice;

    #[test]
    fn test_int8_symmetric_per_tensor() {
        let data = vec![1.0, -0.5, 0.25, -1.0, 0.0, 0.75];
        let tensor = T::from_f64_slice(&data, vec![2, 3], DType::F32, &DEV).unwrap();

        let config = QuantConfig::int8();
        let qt = quantize_tensor::<B>(&tensor, &config).unwrap();

        assert_eq!(qt.shape, vec![2, 3]);
        assert_eq!(qt.scales.len(), 1);
        assert_eq!(qt.zero_points, vec![0.0]);

        let recovered = dequantize_tensor::<B>(&qt, &DEV).unwrap();
        let rec_data = recovered.to_f64_vec().unwrap();

        for (orig, rec) in data.iter().zip(rec_data.iter()) {
            assert!(
                (orig - rec).abs() < 0.02,
                "int8 round-trip: {} vs {}",
                orig,
                rec
            );
        }
    }

    #[test]
    fn test_int4_symmetric_per_tensor() {
        let data = vec![1.0, -0.5, 0.25, -1.0];
        let tensor = T::from_f64_slice(&data, vec![2, 2], DType::F32, &DEV).unwrap();

        let config = QuantConfig::int4();
        let qt = quantize_tensor::<B>(&tensor, &config).unwrap();

        assert_eq!(qt.scales.len(), 1);

        let recovered = dequantize_tensor::<B>(&qt, &DEV).unwrap();
        let rec_data = recovered.to_f64_vec().unwrap();

        for (orig, rec) in data.iter().zip(rec_data.iter()) {
            assert!(
                (orig - rec).abs() < 0.2,
                "int4 round-trip: {} vs {}",
                orig,
                rec
            );
        }
    }

    #[test]
    fn test_int8_per_channel() {
        let data: Vec<f64> = (0..12).map(|i| (i as f64 - 6.0) * 0.1).collect();
        let tensor = T::from_f64_slice(&data, vec![3, 4], DType::F32, &DEV).unwrap();

        let config = QuantConfig::int8_per_channel();
        let qt = quantize_tensor::<B>(&tensor, &config).unwrap();

        assert_eq!(qt.scales.len(), 3);

        let recovered = dequantize_tensor::<B>(&qt, &DEV).unwrap();
        let rec_data = recovered.to_f64_vec().unwrap();

        for (orig, rec) in data.iter().zip(rec_data.iter()) {
            assert!(
                (orig - rec).abs() < 0.01,
                "per-channel round-trip: {} vs {}",
                orig,
                rec
            );
        }
    }

    #[test]
    fn test_int8_asymmetric() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
        let tensor = T::from_f64_slice(&data, vec![1, 5], DType::F32, &DEV).unwrap();

        let config = QuantConfig::int8().asymmetric();
        let qt = quantize_tensor::<B>(&tensor, &config).unwrap();

        assert_eq!(qt.scales.len(), 1);

        let recovered = dequantize_tensor::<B>(&qt, &DEV).unwrap();
        let rec_data = recovered.to_f64_vec().unwrap();

        for (orig, rec) in data.iter().zip(rec_data.iter()) {
            assert!(
                (orig - rec).abs() < 0.02,
                "asymmetric round-trip: {} vs {}",
                orig,
                rec
            );
        }
    }

    #[test]
    fn test_compression_ratio() {
        let tensor = T::randn(vec![256, 128], DType::F32, &DEV).unwrap();

        let qt8 = quantize_tensor::<B>(&tensor, &QuantConfig::int8()).unwrap();
        assert!((qt8.compression_ratio() - 4.0).abs() < 0.01);

        let qt4 = quantize_tensor::<B>(&tensor, &QuantConfig::int4()).unwrap();
        assert!((qt4.compression_ratio() - 8.0).abs() < 0.01);
    }

    #[test]
    fn test_quantized_linear_forward() {
        let linear = shrew_nn::Linear::<B>::new(4, 3, true, DType::F32, &DEV).unwrap();
        let input = T::randn(vec![2, 4], DType::F32, &DEV).unwrap();

        let fp32_output = linear.forward(&input).unwrap();

        let qlinear = QuantizedLinear::from_linear(&linear, &QuantConfig::int8()).unwrap();
        let quant_output = qlinear.forward(&input).unwrap();

        let fp32_data = fp32_output.to_f64_vec().unwrap();
        let quant_data = quant_output.to_f64_vec().unwrap();

        assert_eq!(fp32_data.len(), quant_data.len());
        for (a, b) in fp32_data.iter().zip(quant_data.iter()) {
            assert!(
                (a - b).abs() < 0.5,
                "quantized output diverged: {} vs {}",
                a,
                b
            );
        }
    }

    #[test]
    fn test_quantized_linear_no_trainable_params() {
        let linear = shrew_nn::Linear::<B>::new(4, 2, true, DType::F32, &DEV).unwrap();
        let qlinear = QuantizedLinear::from_linear(&linear, &QuantConfig::int8()).unwrap();
        assert!(qlinear.parameters().is_empty());
    }

    #[test]
    fn test_quantized_linear_memory_savings() {
        let linear = shrew_nn::Linear::<B>::new(256, 128, true, DType::F32, &DEV).unwrap();
        let qlinear = QuantizedLinear::from_linear(&linear, &QuantConfig::int8()).unwrap();
        // 256 * 128 weights: 131072 bytes in fp32 vs 32768 bytes in int8.
        assert_eq!(qlinear.memory_savings_bytes(), 98304);
    }

    #[test]
    fn test_quantize_named_parameters() {
        let linear = shrew_nn::Linear::<B>::new(8, 4, true, DType::F32, &DEV).unwrap();
        let quantized = quantize_named_parameters(&linear, &QuantConfig::int8()).unwrap();
        assert_eq!(quantized.len(), 2);
    }

    #[test]
    fn test_quantization_stats() {
        let linear = shrew_nn::Linear::<B>::new(64, 32, true, DType::F32, &DEV).unwrap();
        let stats = quantization_stats(&linear, &QuantConfig::int8());
        assert_eq!(stats.num_quantized, 2);
        assert_eq!(stats.num_fp32, 0);
        assert!(stats.compression_ratio > 3.0);
    }

    #[test]
    fn test_quantize_zero_tensor() {
        let tensor = T::zeros(vec![2, 3], DType::F32, &DEV).unwrap();
        let qt = quantize_tensor::<B>(&tensor, &QuantConfig::int8()).unwrap();
        let recovered = dequantize_tensor::<B>(&qt, &DEV).unwrap();
        let data = recovered.to_f64_vec().unwrap();
        for &v in &data {
            assert_eq!(v, 0.0);
        }
    }
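
    // Sketch of an extra invariant check: symmetric int4 codes should stay
    // within [-7, 7], since quantize_per_tensor clamps to qmin/qmax before
    // casting to i8.
    #[test]
    fn test_int4_codes_within_range() {
        let tensor = T::randn(vec![4, 4], DType::F32, &DEV).unwrap();
        let qt = quantize_tensor::<B>(&tensor, &QuantConfig::int4()).unwrap();
        for &q in &qt.data {
            assert!((-7i8..=7i8).contains(&q), "int4 code out of range: {}", q);
        }
    }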
}