1use std::collections::HashMap;
31use std::fs::File;
32use std::io::{BufReader, BufWriter, Read, Write};
33use std::path::Path;
34
35use shrew_core::backend::Backend;
36use shrew_core::tensor::Tensor;
37use shrew_core::DType;
38use shrew_optim::OptimizerState;
39
40use crate::exec::Executor;
41
/// File magic written at the start of every checkpoint file ("SHRW").
const MAGIC: &[u8; 4] = b"SHRW";
/// Format version for plain tensor checkpoints (see `write_checkpoint`).
const VERSION: u32 = 1;
48
49fn dtype_to_u8(dtype: DType) -> u8 {
54 match dtype {
55 DType::F32 => 0,
56 DType::F64 => 1,
57 DType::U8 => 2,
58 DType::U32 => 3,
59 DType::I64 => 4,
60 DType::F16 => 5,
61 DType::BF16 => 6,
62 }
63}
64
65fn u8_to_dtype(v: u8) -> shrew_core::Result<DType> {
66 match v {
67 0 => Ok(DType::F32),
68 1 => Ok(DType::F64),
69 2 => Ok(DType::U8),
70 3 => Ok(DType::U32),
71 4 => Ok(DType::I64),
72 5 => Ok(DType::F16),
73 6 => Ok(DType::BF16),
74 _ => Err(shrew_core::Error::msg(format!("Unknown dtype tag: {v}"))),
75 }
76}
77
78fn tensor_to_bytes<B: Backend>(tensor: &Tensor<B>) -> shrew_core::Result<Vec<u8>> {
84 let t = tensor.contiguous()?;
86 let data = t.to_f64_vec()?;
87 let dtype = t.dtype();
88
89 Ok(match dtype {
90 DType::F16 => data
91 .iter()
92 .flat_map(|&v| half::f16::from_f64(v).to_le_bytes())
93 .collect(),
94 DType::BF16 => data
95 .iter()
96 .flat_map(|&v| half::bf16::from_f64(v).to_le_bytes())
97 .collect(),
98 DType::F32 => data
99 .iter()
100 .flat_map(|&v| (v as f32).to_le_bytes())
101 .collect(),
102 DType::F64 => data.iter().flat_map(|&v| v.to_le_bytes()).collect(),
103 DType::U8 => data.iter().map(|&v| v as u8).collect(),
104 DType::U32 => data
105 .iter()
106 .flat_map(|&v| (v as u32).to_le_bytes())
107 .collect(),
108 DType::I64 => data
109 .iter()
110 .flat_map(|&v| (v as i64).to_le_bytes())
111 .collect(),
112 })
113}
114
115fn tensor_from_bytes<B: Backend>(
117 bytes: &[u8],
118 shape: Vec<usize>,
119 dtype: DType,
120 device: &B::Device,
121) -> shrew_core::Result<Tensor<B>> {
122 let data_f64: Vec<f64> = match dtype {
123 DType::F16 => bytes
124 .chunks_exact(2)
125 .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f64())
126 .collect(),
127 DType::BF16 => bytes
128 .chunks_exact(2)
129 .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f64())
130 .collect(),
131 DType::F32 => bytes
132 .chunks_exact(4)
133 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f64)
134 .collect(),
135 DType::F64 => bytes
136 .chunks_exact(8)
137 .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
138 .collect(),
139 DType::U8 => bytes.iter().map(|&b| b as f64).collect(),
140 DType::U32 => bytes
141 .chunks_exact(4)
142 .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f64)
143 .collect(),
144 DType::I64 => bytes
145 .chunks_exact(8)
146 .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f64)
147 .collect(),
148 };
149
150 Tensor::<B>::from_f64_slice(&data_f64, shape, dtype, device)
151}
152
/// Writes a single byte to `w`.
fn write_u8(w: &mut impl Write, v: u8) -> std::io::Result<()> {
    let byte = [v; 1];
    w.write_all(&byte)
}
160
/// Writes a u32 to `w` in little-endian byte order.
fn write_u32(w: &mut impl Write, v: u32) -> std::io::Result<()> {
    let le = v.to_le_bytes();
    w.write_all(&le)
}
164
/// Writes a u64 to `w` in little-endian byte order.
fn write_u64(w: &mut impl Write, v: u64) -> std::io::Result<()> {
    let le = v.to_le_bytes();
    w.write_all(&le)
}
168
/// Writes a raw byte slice; thin pass-through kept for symmetry with the
/// `read_*` helpers.
fn write_bytes(w: &mut impl Write, data: &[u8]) -> std::io::Result<()> {
    w.write_all(data)?;
    Ok(())
}
172
/// Reads exactly one byte from `r`; errors on EOF.
fn read_u8(r: &mut impl Read) -> std::io::Result<u8> {
    let mut byte = [0u8];
    r.read_exact(&mut byte)?;
    Ok(byte[0])
}
178
/// Reads a little-endian u32 from `r`; errors on short reads.
fn read_u32(r: &mut impl Read) -> std::io::Result<u32> {
    let mut le = [0u8; 4];
    r.read_exact(&mut le)?;
    Ok(u32::from_le_bytes(le))
}
184
/// Reads a little-endian u64 from `r`; errors on short reads.
fn read_u64(r: &mut impl Read) -> std::io::Result<u64> {
    let mut le = [0u8; 8];
    r.read_exact(&mut le)?;
    Ok(u64::from_le_bytes(le))
}
190
/// Reads exactly `len` bytes from `r` into a freshly allocated buffer.
///
/// NOTE(review): `len` comes straight from the file, so a corrupt header can
/// request a huge allocation before the short read fails — acceptable for
/// trusted checkpoints, worth a cap if files are ever untrusted.
fn read_bytes(r: &mut impl Read, len: usize) -> std::io::Result<Vec<u8>> {
    let mut out = vec![0u8; len];
    r.read_exact(out.as_mut_slice())?;
    Ok(out)
}
196
197pub fn write_checkpoint<B: Backend>(
203 writer: &mut impl Write,
204 tensors: &[(String, Tensor<B>)],
205) -> shrew_core::Result<()> {
206 write_bytes(writer, MAGIC).map_err(io_err)?;
208 write_u32(writer, VERSION).map_err(io_err)?;
209 write_u32(writer, tensors.len() as u32).map_err(io_err)?;
210
211 for (key, tensor) in tensors {
213 let key_bytes = key.as_bytes();
214 write_u32(writer, key_bytes.len() as u32).map_err(io_err)?;
215 write_bytes(writer, key_bytes).map_err(io_err)?;
216
217 write_u8(writer, dtype_to_u8(tensor.dtype())).map_err(io_err)?;
218
219 let dims = tensor.dims();
220 write_u32(writer, dims.len() as u32).map_err(io_err)?;
221 for &d in dims {
222 write_u32(writer, d as u32).map_err(io_err)?;
223 }
224
225 let data = tensor_to_bytes(tensor)?;
226 write_u64(writer, data.len() as u64).map_err(io_err)?;
227 write_bytes(writer, &data).map_err(io_err)?;
228 }
229
230 Ok(())
231}
232
233pub fn read_checkpoint<B: Backend>(
235 reader: &mut impl Read,
236 device: &B::Device,
237) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
238 let mut magic = [0u8; 4];
240 reader.read_exact(&mut magic).map_err(io_err)?;
241 if &magic != MAGIC {
242 return Err(shrew_core::Error::msg(format!(
243 "Invalid checkpoint: expected magic {:?}, got {:?}",
244 MAGIC, magic
245 )));
246 }
247
248 let version = read_u32(reader).map_err(io_err)?;
249 if version != VERSION {
250 return Err(shrew_core::Error::msg(format!(
251 "Unsupported checkpoint version: {} (expected {})",
252 version, VERSION
253 )));
254 }
255
256 let count = read_u32(reader).map_err(io_err)? as usize;
257 let mut tensors = Vec::with_capacity(count);
258
259 for _ in 0..count {
260 let key_len = read_u32(reader).map_err(io_err)? as usize;
261 let key_bytes = read_bytes(reader, key_len).map_err(io_err)?;
262 let key = String::from_utf8(key_bytes)
263 .map_err(|e| shrew_core::Error::msg(format!("Invalid UTF-8 key: {e}")))?;
264
265 let dtype = u8_to_dtype(read_u8(reader).map_err(io_err)?)?;
266
267 let ndim = read_u32(reader).map_err(io_err)? as usize;
268 let mut dims = Vec::with_capacity(ndim);
269 for _ in 0..ndim {
270 dims.push(read_u32(reader).map_err(io_err)? as usize);
271 }
272
273 let data_len = read_u64(reader).map_err(io_err)? as usize;
274 let data = read_bytes(reader, data_len).map_err(io_err)?;
275
276 let tensor = tensor_from_bytes::<B>(&data, dims, dtype, device)?;
277 tensors.push((key, tensor));
278 }
279
280 Ok(tensors)
281}
282
283fn io_err(e: std::io::Error) -> shrew_core::Error {
284 shrew_core::Error::msg(format!("IO error: {e}"))
285}
286
287pub fn save_tensors<B: Backend>(
306 path: impl AsRef<Path>,
307 tensors: &[(String, Tensor<B>)],
308) -> shrew_core::Result<()> {
309 let file = File::create(path.as_ref()).map_err(io_err)?;
310 let mut writer = BufWriter::new(file);
311 write_checkpoint(&mut writer, tensors)?;
312 writer.flush().map_err(io_err)?;
313 Ok(())
314}
315
316pub fn load_tensors<B: Backend>(
328 path: impl AsRef<Path>,
329 device: &B::Device,
330) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
331 let file = File::open(path.as_ref()).map_err(io_err)?;
332 let mut reader = BufReader::new(file);
333 read_checkpoint(&mut reader, device)
334}
335
336pub fn save<B: Backend>(path: impl AsRef<Path>, executor: &Executor<B>) -> shrew_core::Result<()> {
344 let named = executor.named_params();
345 save_tensors(path, &named)
346}
347
348pub fn load<B: Backend>(
355 path: impl AsRef<Path>,
356 executor: &mut Executor<B>,
357) -> shrew_core::Result<usize> {
358 let tensors = load_tensors::<B>(path, executor.device())?;
359 let loaded: HashMap<String, Tensor<B>> = tensors.into_iter().collect();
360
361 let mut count = 0;
362 for (key, tensor) in &loaded {
363 if executor.set_param_by_key(key, tensor.clone()) {
364 count += 1;
365 }
366 }
367
368 Ok(count)
369}
370
/// Saves only the trainer's model parameters (no optimizer state).
///
/// For a full resumable snapshot including optimizer state and progress
/// counters, use `save_training` with a `TrainingCheckpoint` instead.
pub fn save_trainer<B: Backend>(
    path: impl AsRef<Path>,
    trainer: &crate::exec::Trainer<B>,
) -> shrew_core::Result<()> {
    save(path, &trainer.executor)
}
378
/// Loads model parameters from `path` into the trainer's executor.
///
/// Returns the number of parameters updated. Optimizer state is NOT restored
/// here; use `load_training` for full resume.
pub fn load_trainer<B: Backend>(
    path: impl AsRef<Path>,
    trainer: &mut crate::exec::Trainer<B>,
) -> shrew_core::Result<usize> {
    load(path, &mut trainer.executor)
}
388
389pub fn to_bytes<B: Backend>(tensors: &[(String, Tensor<B>)]) -> shrew_core::Result<Vec<u8>> {
395 let mut buf = Vec::new();
396 write_checkpoint(&mut buf, tensors)?;
397 Ok(buf)
398}
399
400pub fn from_bytes<B: Backend>(
402 data: &[u8],
403 device: &B::Device,
404) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
405 let mut cursor = std::io::Cursor::new(data);
406 read_checkpoint(&mut cursor, device)
407}
408
/// Format version for training checkpoints (independent of plain `VERSION`).
const TRAINING_VERSION: u32 = 2;
/// Section tag: model parameter table.
const TAG_MODEL: u8 = 0x01;
/// Section tag: serialized optimizer state.
const TAG_OPTIMIZER: u8 = 0x02;
/// Section tag: epoch/step/loss metadata.
const TAG_METADATA: u8 = 0x03;
/// Sentinel tag marking the end of the section stream.
const TAG_EOF: u8 = 0xFF;
448
/// Full training snapshot: model parameters plus optional optimizer state
/// and progress metadata, serialized as a tagged-section file.
#[derive(Debug, Clone)]
pub struct TrainingCheckpoint<B: Backend> {
    /// Named model parameters, in the order they will be serialized.
    pub model_params: Vec<(String, Tensor<B>)>,
    /// Serialized optimizer state, if one was attached.
    pub optimizer_state: Option<OptimizerState>,
    /// Epoch counter at snapshot time.
    pub epoch: u64,
    /// Global step counter at snapshot time.
    pub global_step: u64,
    /// Best loss seen so far; `f64::INFINITY` when not tracked.
    pub best_loss: f64,
    /// Recorded loss values, in insertion order.
    pub loss_history: Vec<f64>,
}
468
impl<B: Backend> Default for TrainingCheckpoint<B> {
    // Delegate to `new()` so the "empty checkpoint" definition lives in one place.
    fn default() -> Self {
        Self::new()
    }
}
474
475impl<B: Backend> TrainingCheckpoint<B> {
476 pub fn new() -> Self {
478 TrainingCheckpoint {
479 model_params: Vec::new(),
480 optimizer_state: None,
481 epoch: 0,
482 global_step: 0,
483 best_loss: f64::INFINITY,
484 loss_history: Vec::new(),
485 }
486 }
487
488 pub fn from_executor(executor: &Executor<B>) -> Self {
490 TrainingCheckpoint {
491 model_params: executor.named_params(),
492 optimizer_state: None,
493 epoch: 0,
494 global_step: 0,
495 best_loss: f64::INFINITY,
496 loss_history: Vec::new(),
497 }
498 }
499
500 pub fn with_optimizer_state(mut self, state: OptimizerState) -> Self {
502 self.optimizer_state = Some(state);
503 self
504 }
505
506 pub fn with_epoch(mut self, epoch: u64) -> Self {
508 self.epoch = epoch;
509 self
510 }
511
512 pub fn with_global_step(mut self, step: u64) -> Self {
514 self.global_step = step;
515 self
516 }
517
518 pub fn with_best_loss(mut self, loss: f64) -> Self {
520 self.best_loss = loss;
521 self
522 }
523
524 pub fn with_loss_history(mut self, history: Vec<f64>) -> Self {
526 self.loss_history = history;
527 self
528 }
529}
530
531fn write_optimizer_state(w: &mut impl Write, state: &OptimizerState) -> std::io::Result<()> {
533 write_u8(w, TAG_OPTIMIZER)?;
534
535 let type_bytes = state.optimizer_type.as_bytes();
537 write_u32(w, type_bytes.len() as u32)?;
538 write_bytes(w, type_bytes)?;
539
540 let scalars: Vec<_> = state.scalars.iter().collect();
542 write_u32(w, scalars.len() as u32)?;
543 for (key, &value) in &scalars {
544 let key_bytes = key.as_bytes();
545 write_u32(w, key_bytes.len() as u32)?;
546 write_bytes(w, key_bytes)?;
547 write_bytes(w, &value.to_le_bytes())?;
548 }
549
550 let buffers: Vec<_> = state.buffers.iter().collect();
552 write_u32(w, buffers.len() as u32)?;
553 for (key, data) in &buffers {
554 let key_bytes = key.as_bytes();
555 write_u32(w, key_bytes.len() as u32)?;
556 write_bytes(w, key_bytes)?;
557 write_u64(w, data.len() as u64)?;
558 for &val in data.iter() {
559 write_bytes(w, &val.to_le_bytes())?;
560 }
561 }
562
563 Ok(())
564}
565
566fn read_optimizer_state(r: &mut impl Read) -> std::io::Result<OptimizerState> {
568 let type_len = read_u32(r)? as usize;
570 let type_bytes = read_bytes(r, type_len)?;
571 let type_name = String::from_utf8(type_bytes)
572 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
573
574 let mut state = OptimizerState::new(type_name);
575
576 let n_scalars = read_u32(r)? as usize;
578 for _ in 0..n_scalars {
579 let key_len = read_u32(r)? as usize;
580 let key_bytes = read_bytes(r, key_len)?;
581 let key = String::from_utf8(key_bytes)
582 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
583 let value_bytes = read_bytes(r, 8)?;
584 let value = f64::from_le_bytes([
585 value_bytes[0],
586 value_bytes[1],
587 value_bytes[2],
588 value_bytes[3],
589 value_bytes[4],
590 value_bytes[5],
591 value_bytes[6],
592 value_bytes[7],
593 ]);
594 state.set_scalar(key, value);
595 }
596
597 let n_buffers = read_u32(r)? as usize;
599 for _ in 0..n_buffers {
600 let key_len = read_u32(r)? as usize;
601 let key_bytes = read_bytes(r, key_len)?;
602 let key = String::from_utf8(key_bytes)
603 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
604 let buf_len = read_u64(r)? as usize;
605 let mut data = Vec::with_capacity(buf_len);
606 for _ in 0..buf_len {
607 let val_bytes = read_bytes(r, 8)?;
608 data.push(f64::from_le_bytes([
609 val_bytes[0],
610 val_bytes[1],
611 val_bytes[2],
612 val_bytes[3],
613 val_bytes[4],
614 val_bytes[5],
615 val_bytes[6],
616 val_bytes[7],
617 ]));
618 }
619 state.set_buffer(key, data);
620 }
621
622 Ok(state)
623}
624
/// Writes a full training checkpoint as a tagged-section stream.
///
/// Layout: magic, `TRAINING_VERSION`, then sections — `TAG_MODEL` with the
/// parameter table, an optional `TAG_OPTIMIZER` section, `TAG_METADATA`
/// (epoch, step, best loss, loss history), and a terminating `TAG_EOF`.
/// The byte layout is order-critical; the per-tensor record format must
/// stay in sync with the `TAG_MODEL` branch of `read_training_checkpoint`.
pub fn write_training_checkpoint<B: Backend>(
    writer: &mut impl Write,
    checkpoint: &TrainingCheckpoint<B>,
) -> shrew_core::Result<()> {
    // Header.
    write_bytes(writer, MAGIC).map_err(io_err)?;
    write_u32(writer, TRAINING_VERSION).map_err(io_err)?;

    // Model section: count, then one record per tensor (same record shape as
    // the plain checkpoint format).
    write_u8(writer, TAG_MODEL).map_err(io_err)?;
    write_u32(writer, checkpoint.model_params.len() as u32).map_err(io_err)?;
    for (key, tensor) in &checkpoint.model_params {
        let key_bytes = key.as_bytes();
        write_u32(writer, key_bytes.len() as u32).map_err(io_err)?;
        write_bytes(writer, key_bytes).map_err(io_err)?;
        write_u8(writer, dtype_to_u8(tensor.dtype())).map_err(io_err)?;
        let dims = tensor.dims();
        write_u32(writer, dims.len() as u32).map_err(io_err)?;
        for &d in dims {
            write_u32(writer, d as u32).map_err(io_err)?;
        }
        let data = tensor_to_bytes(tensor)?;
        write_u64(writer, data.len() as u64).map_err(io_err)?;
        write_bytes(writer, &data).map_err(io_err)?;
    }

    // Optimizer section is optional: absent entirely when there is no state
    // (the helper writes its own TAG_OPTIMIZER byte).
    if let Some(ref opt_state) = checkpoint.optimizer_state {
        write_optimizer_state(writer, opt_state).map_err(io_err)?;
    }

    // Metadata section: counters, best loss, then the loss history.
    write_u8(writer, TAG_METADATA).map_err(io_err)?;
    write_u64(writer, checkpoint.epoch).map_err(io_err)?;
    write_u64(writer, checkpoint.global_step).map_err(io_err)?;
    write_bytes(writer, &checkpoint.best_loss.to_le_bytes()).map_err(io_err)?;
    write_u32(writer, checkpoint.loss_history.len() as u32).map_err(io_err)?;
    for &loss in &checkpoint.loss_history {
        write_bytes(writer, &loss.to_le_bytes()).map_err(io_err)?;
    }

    // Terminator so the reader knows when to stop scanning sections.
    write_u8(writer, TAG_EOF).map_err(io_err)?;

    Ok(())
}
672
673pub fn read_training_checkpoint<B: Backend>(
675 reader: &mut impl Read,
676 device: &B::Device,
677) -> shrew_core::Result<TrainingCheckpoint<B>> {
678 let mut magic = [0u8; 4];
680 reader.read_exact(&mut magic).map_err(io_err)?;
681 if &magic != MAGIC {
682 return Err(shrew_core::Error::msg(format!(
683 "Invalid checkpoint: expected magic {:?}, got {:?}",
684 MAGIC, magic
685 )));
686 }
687
688 let version = read_u32(reader).map_err(io_err)?;
689 if version != TRAINING_VERSION {
690 return Err(shrew_core::Error::msg(format!(
691 "Unsupported training checkpoint version: {} (expected {})",
692 version, TRAINING_VERSION
693 )));
694 }
695
696 let mut model_params = Vec::new();
697 let mut optimizer_state = None;
698 let mut epoch = 0u64;
699 let mut global_step = 0u64;
700 let mut best_loss = f64::INFINITY;
701 let mut loss_history = Vec::new();
702
703 loop {
704 let tag = read_u8(reader).map_err(io_err)?;
705
706 match tag {
707 TAG_MODEL => {
708 let count = read_u32(reader).map_err(io_err)? as usize;
709 for _ in 0..count {
710 let key_len = read_u32(reader).map_err(io_err)? as usize;
711 let key_bytes = read_bytes(reader, key_len).map_err(io_err)?;
712 let key = String::from_utf8(key_bytes)
713 .map_err(|e| shrew_core::Error::msg(format!("Invalid UTF-8 key: {e}")))?;
714
715 let dtype = u8_to_dtype(read_u8(reader).map_err(io_err)?)?;
716 let ndim = read_u32(reader).map_err(io_err)? as usize;
717 let mut dims = Vec::with_capacity(ndim);
718 for _ in 0..ndim {
719 dims.push(read_u32(reader).map_err(io_err)? as usize);
720 }
721 let data_len = read_u64(reader).map_err(io_err)? as usize;
722 let data = read_bytes(reader, data_len).map_err(io_err)?;
723 let tensor = tensor_from_bytes::<B>(&data, dims, dtype, device)?;
724 model_params.push((key, tensor));
725 }
726 }
727 TAG_OPTIMIZER => {
728 optimizer_state =
729 Some(read_optimizer_state(reader).map_err(|e| {
730 shrew_core::Error::msg(format!("Optimizer state error: {e}"))
731 })?);
732 }
733 TAG_METADATA => {
734 epoch = read_u64(reader).map_err(io_err)?;
735 global_step = read_u64(reader).map_err(io_err)?;
736 let bl_bytes = read_bytes(reader, 8).map_err(io_err)?;
737 best_loss = f64::from_le_bytes([
738 bl_bytes[0],
739 bl_bytes[1],
740 bl_bytes[2],
741 bl_bytes[3],
742 bl_bytes[4],
743 bl_bytes[5],
744 bl_bytes[6],
745 bl_bytes[7],
746 ]);
747 let n_losses = read_u32(reader).map_err(io_err)? as usize;
748 loss_history = Vec::with_capacity(n_losses);
749 for _ in 0..n_losses {
750 let lb = read_bytes(reader, 8).map_err(io_err)?;
751 loss_history.push(f64::from_le_bytes([
752 lb[0], lb[1], lb[2], lb[3], lb[4], lb[5], lb[6], lb[7],
753 ]));
754 }
755 }
756 TAG_EOF => break,
757 other => {
758 return Err(shrew_core::Error::msg(format!(
759 "Unknown section tag in training checkpoint: 0x{other:02X}"
760 )));
761 }
762 }
763 }
764
765 Ok(TrainingCheckpoint {
766 model_params,
767 optimizer_state,
768 epoch,
769 global_step,
770 best_loss,
771 loss_history,
772 })
773}
774
775pub fn save_training<B: Backend>(
797 path: impl AsRef<Path>,
798 checkpoint: &TrainingCheckpoint<B>,
799) -> shrew_core::Result<()> {
800 let file = File::create(path.as_ref()).map_err(io_err)?;
801 let mut writer = BufWriter::new(file);
802 write_training_checkpoint(&mut writer, checkpoint)?;
803 writer.flush().map_err(io_err)?;
804 Ok(())
805}
806
807pub fn load_training<B: Backend>(
821 path: impl AsRef<Path>,
822 device: &B::Device,
823) -> shrew_core::Result<TrainingCheckpoint<B>> {
824 let file = File::open(path.as_ref()).map_err(io_err)?;
825 let mut reader = BufReader::new(file);
826 read_training_checkpoint(&mut reader, device)
827}
828
829pub fn training_to_bytes<B: Backend>(
831 checkpoint: &TrainingCheckpoint<B>,
832) -> shrew_core::Result<Vec<u8>> {
833 let mut buf = Vec::new();
834 write_training_checkpoint(&mut buf, checkpoint)?;
835 Ok(buf)
836}
837
838pub fn training_from_bytes<B: Backend>(
840 data: &[u8],
841 device: &B::Device,
842) -> shrew_core::Result<TrainingCheckpoint<B>> {
843 let mut cursor = std::io::Cursor::new(data);
844 read_training_checkpoint(&mut cursor, device)
845}
846
#[cfg(test)]
mod tests {
    //! Round-trip tests for the plain tensor checkpoint format and the
    //! extended training checkpoint format, all on the CPU backend.
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};

    type CpuTensor = Tensor<CpuBackend>;

    // f32 data survives a bytes round trip with key, shape and dtype intact.
    #[test]
    fn test_roundtrip_f32() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0, 4.0], (2, 2), DType::F32, &dev).unwrap();

        let tensors = vec![("w".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].0, "w");
        assert_eq!(loaded[0].1.dims(), &[2, 2]);
        assert_eq!(loaded[0].1.dtype(), DType::F32);

        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // f64 values (including irrational constants) must round-trip exactly —
    // no narrowing conversion anywhere on the f64 path.
    #[test]
    fn test_roundtrip_f64() {
        let dev = CpuDevice;
        let vals = vec![std::f64::consts::PI, std::f64::consts::E, 0.0, -1.5];
        let t = CpuTensor::from_f64_slice(&vals, (4,), DType::F64, &dev).unwrap();

        let tensors = vec![("precision_test".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // U8 tensors exercise the single-byte (no endianness) code path,
    // including both range extremes.
    #[test]
    fn test_roundtrip_u8() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[0.0, 128.0, 255.0], (3,), DType::U8, &dev).unwrap();

        let tensors = vec![("pixels".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded[0].1.dtype(), DType::U8);
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // Multiple tensors must come back in the same order with keys and shapes
    // preserved.
    #[test]
    fn test_roundtrip_multiple_tensors() {
        let dev = CpuDevice;
        let w =
            CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], (2, 3), DType::F64, &dev)
                .unwrap();
        let b = CpuTensor::from_f64_slice(&[0.1, 0.2, 0.3], (1, 3), DType::F64, &dev).unwrap();

        let tensors = vec![
            ("Forward/w1".to_string(), w.clone()),
            ("Forward/b1".to_string(), b.clone()),
        ];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].0, "Forward/w1");
        assert_eq!(loaded[0].1.dims(), &[2, 3]);
        assert_eq!(loaded[1].0, "Forward/b1");
        assert_eq!(loaded[1].1.dims(), &[1, 3]);
    }

    // Data without the "SHRW" magic must be rejected with a descriptive error.
    #[test]
    fn test_invalid_magic() {
        let data = b"BADXsomejunk";
        let result = from_bytes::<CpuBackend>(data, &CpuDevice);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Invalid checkpoint"));
    }

    // End-to-end through the filesystem helpers (save_tensors/load_tensors).
    #[test]
    fn test_file_roundtrip() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0], (3,), DType::F32, &dev).unwrap();
        let tensors = vec![("test".to_string(), t.clone())];

        let path = std::env::temp_dir().join("shrew_test_checkpoint.shrew");
        save_tensors(&path, &tensors).unwrap();
        let loaded = load_tensors::<CpuBackend>(&path, &dev).unwrap();
        // Best-effort cleanup; a leftover temp file is harmless.
        std::fs::remove_file(&path).ok();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].0, "test");
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // Zero tensors is a valid checkpoint (header only).
    #[test]
    fn test_empty_checkpoint() {
        let tensors: Vec<(String, CpuTensor)> = vec![];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &CpuDevice).unwrap();
        assert_eq!(loaded.len(), 0);
    }

    // Higher-rank shapes survive serialization; f32 comparisons use a
    // tolerance since the values pass through an f32 narrowing.
    #[test]
    fn test_3d_tensor_roundtrip() {
        let dev = CpuDevice;
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let t = CpuTensor::from_f64_slice(&data, (2, 3, 4), DType::F32, &dev).unwrap();

        let tensors = vec![("volume".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded[0].1.dims(), &[2, 3, 4]);
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        for (a, b) in orig.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 1e-6, "mismatch: {a} vs {b}");
        }
    }

    // Full training checkpoint: params + optimizer scalars/buffers + metadata
    // all round-trip through the tagged-section format.
    #[test]
    fn test_training_checkpoint_roundtrip() {
        use shrew_optim::OptimizerState;

        let dev = CpuDevice;
        let w = CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0, 4.0], (2, 2), DType::F32, &dev).unwrap();
        let b = CpuTensor::from_f64_slice(&[0.1, 0.2], (2,), DType::F32, &dev).unwrap();

        let mut opt_state = OptimizerState::new("Adam");
        opt_state.set_scalar("t", 100.0);
        opt_state.set_scalar("lr", 0.001);
        opt_state.set_scalar("beta1", 0.9);
        opt_state.set_buffer("m.0", vec![0.01, 0.02, 0.03, 0.04]);
        opt_state.set_buffer("v.0", vec![0.001, 0.002, 0.003, 0.004]);

        let ckpt = TrainingCheckpoint {
            model_params: vec![
                ("layer1/weight".to_string(), w),
                ("layer1/bias".to_string(), b),
            ],
            optimizer_state: Some(opt_state),
            epoch: 10,
            global_step: 5000,
            best_loss: 0.032,
            loss_history: vec![1.5, 0.8, 0.4, 0.1, 0.05, 0.032],
        };

        let bytes = training_to_bytes(&ckpt).unwrap();
        let loaded = training_from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded.model_params.len(), 2);
        assert_eq!(loaded.model_params[0].0, "layer1/weight");
        assert_eq!(loaded.model_params[0].1.dims(), &[2, 2]);
        assert_eq!(loaded.model_params[1].0, "layer1/bias");

        assert_eq!(loaded.epoch, 10);
        assert_eq!(loaded.global_step, 5000);
        assert!((loaded.best_loss - 0.032).abs() < 1e-10);
        assert_eq!(loaded.loss_history.len(), 6);
        assert!((loaded.loss_history[0] - 1.5).abs() < 1e-10);
        assert!((loaded.loss_history[5] - 0.032).abs() < 1e-10);

        let opt = loaded.optimizer_state.unwrap();
        assert_eq!(opt.optimizer_type, "Adam");
        assert_eq!(opt.get_scalar("t"), Some(100.0));
        assert_eq!(opt.get_scalar("lr"), Some(0.001));
        assert_eq!(opt.get_scalar("beta1"), Some(0.9));
        let m0 = opt.get_buffer("m.0").unwrap();
        assert_eq!(m0.len(), 4);
        assert!((m0[2] - 0.03).abs() < 1e-10);
    }

    // The optimizer section is optional; its absence must deserialize to None.
    #[test]
    fn test_training_checkpoint_no_optimizer() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[1.0, 2.0], (2,), DType::F64, &dev).unwrap();

        let ckpt = TrainingCheckpoint {
            model_params: vec![("w".to_string(), t)],
            optimizer_state: None,
            epoch: 5,
            global_step: 250,
            best_loss: 0.1,
            loss_history: vec![0.5, 0.3, 0.1],
        };

        let bytes = training_to_bytes(&ckpt).unwrap();
        let loaded = training_from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded.model_params.len(), 1);
        assert!(loaded.optimizer_state.is_none());
        assert_eq!(loaded.epoch, 5);
        assert_eq!(loaded.global_step, 250);
    }
}