shrew/
checkpoint.rs

1// =============================================================================
2// Checkpoint — Save and load model parameters
3// =============================================================================
4//
5// Binary checkpoint format (.shrew):
6//
7//   Header:
8//     magic:   [u8; 4]  = b"SHRW"
9//     version: u32 LE   = 1
10//     count:   u32 LE   = number of tensors
11//
12//   For each tensor:
13//     key_len:  u32 LE
14//     key:      [u8; key_len]  (UTF-8, format: "graph/param")
15//     dtype:    u8             (0=F32, 1=F64, 2=U8, 3=U32, 4=I64, 5=F16, 6=BF16)
16//     ndim:     u32 LE
17//     dims:     [u32 LE; ndim]
18//     data_len: u64 LE         (in bytes)
19//     data:     [u8; data_len] (raw little-endian typed data)
20//
21// Usage:
22//   // Save
23//   checkpoint::save("model.shrew", &executor)?;
24//   checkpoint::save_tensors("weights.shrew", &named_tensors)?;
25//
26//   // Load
27//   checkpoint::load("model.shrew", &mut executor)?;
28//   let tensors = checkpoint::load_tensors::<CpuBackend>("weights.shrew", &device)?;
29
30use std::collections::HashMap;
31use std::fs::File;
32use std::io::{BufReader, BufWriter, Read, Write};
33use std::path::Path;
34
35use shrew_core::backend::Backend;
36use shrew_core::tensor::Tensor;
37use shrew_core::DType;
38use shrew_optim::OptimizerState;
39
40use crate::exec::Executor;
41
42// ─────────────────────────────────────────────────────────────────────────────
43// Constants
44// ─────────────────────────────────────────────────────────────────────────────
45
/// File magic identifying a Shrew checkpoint (first 4 bytes on disk).
const MAGIC: &[u8; 4] = b"SHRW";
/// Format version for plain tensor checkpoints (v2 is the training format below).
const VERSION: u32 = 1;
48
49// ─────────────────────────────────────────────────────────────────────────────
50// DType <-> u8 encoding
51// ─────────────────────────────────────────────────────────────────────────────
52
53fn dtype_to_u8(dtype: DType) -> u8 {
54    match dtype {
55        DType::F32 => 0,
56        DType::F64 => 1,
57        DType::U8 => 2,
58        DType::U32 => 3,
59        DType::I64 => 4,
60        DType::F16 => 5,
61        DType::BF16 => 6,
62    }
63}
64
65fn u8_to_dtype(v: u8) -> shrew_core::Result<DType> {
66    match v {
67        0 => Ok(DType::F32),
68        1 => Ok(DType::F64),
69        2 => Ok(DType::U8),
70        3 => Ok(DType::U32),
71        4 => Ok(DType::I64),
72        5 => Ok(DType::F16),
73        6 => Ok(DType::BF16),
74        _ => Err(shrew_core::Error::msg(format!("Unknown dtype tag: {v}"))),
75    }
76}
77
78// ─────────────────────────────────────────────────────────────────────────────
79// Raw bytes extraction from tensor (via f64 roundtrip for portability)
80// ─────────────────────────────────────────────────────────────────────────────
81
/// Convert a tensor to raw LE bytes, preserving the original dtype.
///
/// The tensor is made contiguous, read out as `f64`, and each value is then
/// re-encoded in the tensor's native dtype as little-endian bytes.
///
/// NOTE(review): the f64 roundtrip means I64 magnitudes above 2^53 cannot be
/// represented exactly — presumably fine for parameter/index data, but worth
/// confirming if exact large integers ever need to be checkpointed.
fn tensor_to_bytes<B: Backend>(tensor: &Tensor<B>) -> shrew_core::Result<Vec<u8>> {
    // Make contiguous first, then extract data as f64 and convert to native bytes
    let t = tensor.contiguous()?;
    let data = t.to_f64_vec()?;
    let dtype = t.dtype();

    Ok(match dtype {
        // Half-precision types are re-encoded via the `half` crate.
        DType::F16 => data
            .iter()
            .flat_map(|&v| half::f16::from_f64(v).to_le_bytes())
            .collect(),
        DType::BF16 => data
            .iter()
            .flat_map(|&v| half::bf16::from_f64(v).to_le_bytes())
            .collect(),
        DType::F32 => data
            .iter()
            .flat_map(|&v| (v as f32).to_le_bytes())
            .collect(),
        DType::F64 => data.iter().flat_map(|&v| v.to_le_bytes()).collect(),
        // One byte per element — no endianness involved.
        DType::U8 => data.iter().map(|&v| v as u8).collect(),
        DType::U32 => data
            .iter()
            .flat_map(|&v| (v as u32).to_le_bytes())
            .collect(),
        DType::I64 => data
            .iter()
            .flat_map(|&v| (v as i64).to_le_bytes())
            .collect(),
    })
}
114
115/// Reconstruct a tensor from raw LE bytes + metadata.
116fn tensor_from_bytes<B: Backend>(
117    bytes: &[u8],
118    shape: Vec<usize>,
119    dtype: DType,
120    device: &B::Device,
121) -> shrew_core::Result<Tensor<B>> {
122    let data_f64: Vec<f64> = match dtype {
123        DType::F16 => bytes
124            .chunks_exact(2)
125            .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f64())
126            .collect(),
127        DType::BF16 => bytes
128            .chunks_exact(2)
129            .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f64())
130            .collect(),
131        DType::F32 => bytes
132            .chunks_exact(4)
133            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f64)
134            .collect(),
135        DType::F64 => bytes
136            .chunks_exact(8)
137            .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
138            .collect(),
139        DType::U8 => bytes.iter().map(|&b| b as f64).collect(),
140        DType::U32 => bytes
141            .chunks_exact(4)
142            .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f64)
143            .collect(),
144        DType::I64 => bytes
145            .chunks_exact(8)
146            .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f64)
147            .collect(),
148    };
149
150    Tensor::<B>::from_f64_slice(&data_f64, shape, dtype, device)
151}
152
153// ─────────────────────────────────────────────────────────────────────────────
154// Low-level IO helpers
155// ─────────────────────────────────────────────────────────────────────────────
156
/// Write a single byte to the writer.
fn write_u8(w: &mut impl Write, v: u8) -> std::io::Result<()> {
    let buf = [v];
    w.write_all(&buf)
}
160
/// Write a `u32` in little-endian byte order.
fn write_u32(w: &mut impl Write, v: u32) -> std::io::Result<()> {
    let bytes = v.to_le_bytes();
    w.write_all(&bytes)
}
164
/// Write a `u64` in little-endian byte order.
fn write_u64(w: &mut impl Write, v: u64) -> std::io::Result<()> {
    let bytes = v.to_le_bytes();
    w.write_all(&bytes)
}
168
/// Write a raw byte slice verbatim (no length prefix — callers write one
/// themselves where the format requires it).
fn write_bytes(w: &mut impl Write, data: &[u8]) -> std::io::Result<()> {
    w.write_all(data)
}
172
/// Read exactly one byte; errors on EOF.
fn read_u8(r: &mut impl Read) -> std::io::Result<u8> {
    let mut byte = [0u8];
    r.read_exact(&mut byte)?;
    Ok(byte[0])
}
178
/// Read a little-endian `u32`; errors if fewer than 4 bytes remain.
fn read_u32(r: &mut impl Read) -> std::io::Result<u32> {
    let mut bytes = [0u8; 4];
    r.read_exact(&mut bytes)?;
    Ok(u32::from_le_bytes(bytes))
}
184
/// Read a little-endian `u64`; errors if fewer than 8 bytes remain.
fn read_u64(r: &mut impl Read) -> std::io::Result<u64> {
    let mut bytes = [0u8; 8];
    r.read_exact(&mut bytes)?;
    Ok(u64::from_le_bytes(bytes))
}
190
/// Read exactly `len` bytes into a fresh buffer.
///
/// NOTE(review): `len` typically comes straight from the file, so a corrupt
/// header can request a very large allocation before `read_exact` fails —
/// acceptable for trusted checkpoints, worth bounding for untrusted input.
fn read_bytes(r: &mut impl Read, len: usize) -> std::io::Result<Vec<u8>> {
    let mut data = vec![0u8; len];
    r.read_exact(&mut data)?;
    Ok(data)
}
196
197// ─────────────────────────────────────────────────────────────────────────────
198// Write checkpoint
199// ─────────────────────────────────────────────────────────────────────────────
200
201/// Write a set of named tensors to a writer in the Shrew checkpoint format.
202pub fn write_checkpoint<B: Backend>(
203    writer: &mut impl Write,
204    tensors: &[(String, Tensor<B>)],
205) -> shrew_core::Result<()> {
206    // Header
207    write_bytes(writer, MAGIC).map_err(io_err)?;
208    write_u32(writer, VERSION).map_err(io_err)?;
209    write_u32(writer, tensors.len() as u32).map_err(io_err)?;
210
211    // Each tensor
212    for (key, tensor) in tensors {
213        let key_bytes = key.as_bytes();
214        write_u32(writer, key_bytes.len() as u32).map_err(io_err)?;
215        write_bytes(writer, key_bytes).map_err(io_err)?;
216
217        write_u8(writer, dtype_to_u8(tensor.dtype())).map_err(io_err)?;
218
219        let dims = tensor.dims();
220        write_u32(writer, dims.len() as u32).map_err(io_err)?;
221        for &d in dims {
222            write_u32(writer, d as u32).map_err(io_err)?;
223        }
224
225        let data = tensor_to_bytes(tensor)?;
226        write_u64(writer, data.len() as u64).map_err(io_err)?;
227        write_bytes(writer, &data).map_err(io_err)?;
228    }
229
230    Ok(())
231}
232
233/// Read named tensors from a reader in the Shrew checkpoint format.
234pub fn read_checkpoint<B: Backend>(
235    reader: &mut impl Read,
236    device: &B::Device,
237) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
238    // Header
239    let mut magic = [0u8; 4];
240    reader.read_exact(&mut magic).map_err(io_err)?;
241    if &magic != MAGIC {
242        return Err(shrew_core::Error::msg(format!(
243            "Invalid checkpoint: expected magic {:?}, got {:?}",
244            MAGIC, magic
245        )));
246    }
247
248    let version = read_u32(reader).map_err(io_err)?;
249    if version != VERSION {
250        return Err(shrew_core::Error::msg(format!(
251            "Unsupported checkpoint version: {} (expected {})",
252            version, VERSION
253        )));
254    }
255
256    let count = read_u32(reader).map_err(io_err)? as usize;
257    let mut tensors = Vec::with_capacity(count);
258
259    for _ in 0..count {
260        let key_len = read_u32(reader).map_err(io_err)? as usize;
261        let key_bytes = read_bytes(reader, key_len).map_err(io_err)?;
262        let key = String::from_utf8(key_bytes)
263            .map_err(|e| shrew_core::Error::msg(format!("Invalid UTF-8 key: {e}")))?;
264
265        let dtype = u8_to_dtype(read_u8(reader).map_err(io_err)?)?;
266
267        let ndim = read_u32(reader).map_err(io_err)? as usize;
268        let mut dims = Vec::with_capacity(ndim);
269        for _ in 0..ndim {
270            dims.push(read_u32(reader).map_err(io_err)? as usize);
271        }
272
273        let data_len = read_u64(reader).map_err(io_err)? as usize;
274        let data = read_bytes(reader, data_len).map_err(io_err)?;
275
276        let tensor = tensor_from_bytes::<B>(&data, dims, dtype, device)?;
277        tensors.push((key, tensor));
278    }
279
280    Ok(tensors)
281}
282
283fn io_err(e: std::io::Error) -> shrew_core::Error {
284    shrew_core::Error::msg(format!("IO error: {e}"))
285}
286
287// ─────────────────────────────────────────────────────────────────────────────
288// High-level API — save/load named tensors
289// ─────────────────────────────────────────────────────────────────────────────
290
291/// Save a list of named tensors to a file.
292///
293/// ```rust,no_run
294/// use shrew::checkpoint;
295/// use shrew::prelude::*;
296///
297/// let w1 = Tensor::<CpuBackend>::zeros((2, 3), DType::F32, &CpuDevice).unwrap();
298/// let b1 = Tensor::<CpuBackend>::zeros((2,), DType::F32, &CpuDevice).unwrap();
299/// let tensors = vec![
300///     ("w1".to_string(), w1),
301///     ("b1".to_string(), b1),
302/// ];
303/// checkpoint::save_tensors("weights.shrew", &tensors).unwrap();
304/// ```
305pub fn save_tensors<B: Backend>(
306    path: impl AsRef<Path>,
307    tensors: &[(String, Tensor<B>)],
308) -> shrew_core::Result<()> {
309    let file = File::create(path.as_ref()).map_err(io_err)?;
310    let mut writer = BufWriter::new(file);
311    write_checkpoint(&mut writer, tensors)?;
312    writer.flush().map_err(io_err)?;
313    Ok(())
314}
315
316/// Load named tensors from a file.
317///
318/// ```rust,no_run
319/// use shrew::checkpoint;
320/// use shrew::prelude::*;
321///
322/// let tensors = checkpoint::load_tensors::<CpuBackend>("weights.shrew", &CpuDevice).unwrap();
323/// for (name, tensor) in &tensors {
324///     println!("{name}: {:?}", tensor.dims());
325/// }
326/// ```
327pub fn load_tensors<B: Backend>(
328    path: impl AsRef<Path>,
329    device: &B::Device,
330) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
331    let file = File::open(path.as_ref()).map_err(io_err)?;
332    let mut reader = BufReader::new(file);
333    read_checkpoint(&mut reader, device)
334}
335
336// ─────────────────────────────────────────────────────────────────────────────
337// High-level API — save/load Executor parameters
338// ─────────────────────────────────────────────────────────────────────────────
339
340/// Save all parameters from an Executor to a checkpoint file.
341///
342/// Parameters are stored with keys in the format `"graph_name/param_name"`.
343pub fn save<B: Backend>(path: impl AsRef<Path>, executor: &Executor<B>) -> shrew_core::Result<()> {
344    let named = executor.named_params();
345    save_tensors(path, &named)
346}
347
348/// Load parameters from a checkpoint file into an Executor.
349///
350/// Only parameters present in the checkpoint will be updated.
351/// Parameters not found in the file keep their current values.
352///
353/// Returns the number of parameters loaded.
354pub fn load<B: Backend>(
355    path: impl AsRef<Path>,
356    executor: &mut Executor<B>,
357) -> shrew_core::Result<usize> {
358    let tensors = load_tensors::<B>(path, executor.device())?;
359    let loaded: HashMap<String, Tensor<B>> = tensors.into_iter().collect();
360
361    let mut count = 0;
362    for (key, tensor) in &loaded {
363        if executor.set_param_by_key(key, tensor.clone()) {
364            count += 1;
365        }
366    }
367
368    Ok(count)
369}
370
/// Save all parameters from a Trainer to a checkpoint file.
///
/// Thin wrapper over [`save`] using the trainer's underlying executor.
pub fn save_trainer<B: Backend>(
    path: impl AsRef<Path>,
    trainer: &crate::exec::Trainer<B>,
) -> shrew_core::Result<()> {
    save(path, &trainer.executor)
}
378
/// Load parameters from a checkpoint file into a Trainer.
///
/// Thin wrapper over [`load`] using the trainer's underlying executor.
///
/// Returns the number of parameters loaded.
pub fn load_trainer<B: Backend>(
    path: impl AsRef<Path>,
    trainer: &mut crate::exec::Trainer<B>,
) -> shrew_core::Result<usize> {
    load(path, &mut trainer.executor)
}
388
389// ─────────────────────────────────────────────────────────────────────────────
390// In-memory checkpoint (for testing and transfer)
391// ─────────────────────────────────────────────────────────────────────────────
392
393/// Serialize named tensors to an in-memory byte vector.
394pub fn to_bytes<B: Backend>(tensors: &[(String, Tensor<B>)]) -> shrew_core::Result<Vec<u8>> {
395    let mut buf = Vec::new();
396    write_checkpoint(&mut buf, tensors)?;
397    Ok(buf)
398}
399
400/// Deserialize named tensors from an in-memory byte slice.
401pub fn from_bytes<B: Backend>(
402    data: &[u8],
403    device: &B::Device,
404) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
405    let mut cursor = std::io::Cursor::new(data);
406    read_checkpoint(&mut cursor, device)
407}
408
409// ─────────────────────────────────────────────────────────────────────────────
410// Training Checkpoint — Full training state (model + optimizer + metadata)
411// ─────────────────────────────────────────────────────────────────────────────
412//
413// Binary training checkpoint format (.shrew v2):
414//
415//   Header:
416//     magic:   [u8; 4]  = b"SHRW"
417//     version: u32 LE   = 2
418//
419//   Section 1: Model Parameters
420//     tag: u8 = 0x01
421//     count: u32 LE
422//     [tensors...]  (same format as v1)
423//
424//   Section 2: Optimizer State
425//     tag: u8 = 0x02
426//     type_len: u32 LE
427//     type_name: [u8; type_len]       (UTF-8, e.g. "Adam")
428//     n_scalars: u32 LE
429//     [key_len: u32, key: [u8], value: f64] × n_scalars
430//     n_buffers: u32 LE
431//     [key_len: u32, key: [u8], buf_len: u64, [f64 LE] × buf_len] × n_buffers
432//
433//   Section 3: Metadata
434//     tag: u8 = 0x03
435//     epoch: u64 LE
436//     global_step: u64 LE
437//     best_loss: f64 LE
438//     n_loss_history: u32 LE
439//     [f64 LE] × n_loss_history
440//
441//   EOF marker: u8 = 0xFF
442
/// Format version for full training checkpoints (model + optimizer + metadata).
const TRAINING_VERSION: u32 = 2;
/// Section tag: model parameters (v1 tensor record layout).
const TAG_MODEL: u8 = 0x01;
/// Section tag: optimizer state (type name, scalars, buffers).
const TAG_OPTIMIZER: u8 = 0x02;
/// Section tag: training metadata (epoch, step, best loss, loss history).
const TAG_METADATA: u8 = 0x03;
/// End-of-stream marker terminating the section list.
const TAG_EOF: u8 = 0xFF;
448
/// Complete training checkpoint: model weights + optimizer state + training metadata.
///
/// Enables full training resume — not just the model parameters, but all the
/// internal state that would be lost if training were interrupted.
#[derive(Debug, Clone)]
pub struct TrainingCheckpoint<B: Backend> {
    /// Named model parameters (same as standard checkpoint).
    pub model_params: Vec<(String, Tensor<B>)>,
    /// Optimizer internal state (momentum buffers, step counters, etc.)
    /// `None` when the checkpoint was saved without an optimizer section.
    pub optimizer_state: Option<OptimizerState>,
    /// Current training epoch (0-indexed).
    pub epoch: u64,
    /// Global optimization step counter.
    pub global_step: u64,
    /// Best loss value seen during training (`f64::INFINITY` until one is recorded).
    pub best_loss: f64,
    /// Per-epoch loss history.
    pub loss_history: Vec<f64>,
}
468
// An empty checkpoint is a sensible default: no params, counters at zero.
impl<B: Backend> Default for TrainingCheckpoint<B> {
    fn default() -> Self {
        Self::new()
    }
}
474
475impl<B: Backend> TrainingCheckpoint<B> {
476    /// Create an empty checkpoint (no model params).
477    pub fn new() -> Self {
478        TrainingCheckpoint {
479            model_params: Vec::new(),
480            optimizer_state: None,
481            epoch: 0,
482            global_step: 0,
483            best_loss: f64::INFINITY,
484            loss_history: Vec::new(),
485        }
486    }
487
488    /// Create a checkpoint from an executor (model params only).
489    pub fn from_executor(executor: &Executor<B>) -> Self {
490        TrainingCheckpoint {
491            model_params: executor.named_params(),
492            optimizer_state: None,
493            epoch: 0,
494            global_step: 0,
495            best_loss: f64::INFINITY,
496            loss_history: Vec::new(),
497        }
498    }
499
500    /// Set the optimizer state in this checkpoint.
501    pub fn with_optimizer_state(mut self, state: OptimizerState) -> Self {
502        self.optimizer_state = Some(state);
503        self
504    }
505
506    /// Set epoch counter.
507    pub fn with_epoch(mut self, epoch: u64) -> Self {
508        self.epoch = epoch;
509        self
510    }
511
512    /// Set global step counter.
513    pub fn with_global_step(mut self, step: u64) -> Self {
514        self.global_step = step;
515        self
516    }
517
518    /// Set best loss.
519    pub fn with_best_loss(mut self, loss: f64) -> Self {
520        self.best_loss = loss;
521        self
522    }
523
524    /// Set loss history.
525    pub fn with_loss_history(mut self, history: Vec<f64>) -> Self {
526        self.loss_history = history;
527        self
528    }
529}
530
531/// Write an OptimizerState section to the writer.
532fn write_optimizer_state(w: &mut impl Write, state: &OptimizerState) -> std::io::Result<()> {
533    write_u8(w, TAG_OPTIMIZER)?;
534
535    // Optimizer type name
536    let type_bytes = state.optimizer_type.as_bytes();
537    write_u32(w, type_bytes.len() as u32)?;
538    write_bytes(w, type_bytes)?;
539
540    // Scalars
541    let scalars: Vec<_> = state.scalars.iter().collect();
542    write_u32(w, scalars.len() as u32)?;
543    for (key, &value) in &scalars {
544        let key_bytes = key.as_bytes();
545        write_u32(w, key_bytes.len() as u32)?;
546        write_bytes(w, key_bytes)?;
547        write_bytes(w, &value.to_le_bytes())?;
548    }
549
550    // Buffers
551    let buffers: Vec<_> = state.buffers.iter().collect();
552    write_u32(w, buffers.len() as u32)?;
553    for (key, data) in &buffers {
554        let key_bytes = key.as_bytes();
555        write_u32(w, key_bytes.len() as u32)?;
556        write_bytes(w, key_bytes)?;
557        write_u64(w, data.len() as u64)?;
558        for &val in data.iter() {
559            write_bytes(w, &val.to_le_bytes())?;
560        }
561    }
562
563    Ok(())
564}
565
566/// Read an OptimizerState section from the reader (tag already consumed).
567fn read_optimizer_state(r: &mut impl Read) -> std::io::Result<OptimizerState> {
568    // Type name
569    let type_len = read_u32(r)? as usize;
570    let type_bytes = read_bytes(r, type_len)?;
571    let type_name = String::from_utf8(type_bytes)
572        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
573
574    let mut state = OptimizerState::new(type_name);
575
576    // Scalars
577    let n_scalars = read_u32(r)? as usize;
578    for _ in 0..n_scalars {
579        let key_len = read_u32(r)? as usize;
580        let key_bytes = read_bytes(r, key_len)?;
581        let key = String::from_utf8(key_bytes)
582            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
583        let value_bytes = read_bytes(r, 8)?;
584        let value = f64::from_le_bytes([
585            value_bytes[0],
586            value_bytes[1],
587            value_bytes[2],
588            value_bytes[3],
589            value_bytes[4],
590            value_bytes[5],
591            value_bytes[6],
592            value_bytes[7],
593        ]);
594        state.set_scalar(key, value);
595    }
596
597    // Buffers
598    let n_buffers = read_u32(r)? as usize;
599    for _ in 0..n_buffers {
600        let key_len = read_u32(r)? as usize;
601        let key_bytes = read_bytes(r, key_len)?;
602        let key = String::from_utf8(key_bytes)
603            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
604        let buf_len = read_u64(r)? as usize;
605        let mut data = Vec::with_capacity(buf_len);
606        for _ in 0..buf_len {
607            let val_bytes = read_bytes(r, 8)?;
608            data.push(f64::from_le_bytes([
609                val_bytes[0],
610                val_bytes[1],
611                val_bytes[2],
612                val_bytes[3],
613                val_bytes[4],
614                val_bytes[5],
615                val_bytes[6],
616                val_bytes[7],
617            ]));
618        }
619        state.set_buffer(key, data);
620    }
621
622    Ok(state)
623}
624
625/// Write a full training checkpoint to a writer.
626pub fn write_training_checkpoint<B: Backend>(
627    writer: &mut impl Write,
628    checkpoint: &TrainingCheckpoint<B>,
629) -> shrew_core::Result<()> {
630    // Header
631    write_bytes(writer, MAGIC).map_err(io_err)?;
632    write_u32(writer, TRAINING_VERSION).map_err(io_err)?;
633
634    // Section 1: Model parameters
635    write_u8(writer, TAG_MODEL).map_err(io_err)?;
636    write_u32(writer, checkpoint.model_params.len() as u32).map_err(io_err)?;
637    for (key, tensor) in &checkpoint.model_params {
638        let key_bytes = key.as_bytes();
639        write_u32(writer, key_bytes.len() as u32).map_err(io_err)?;
640        write_bytes(writer, key_bytes).map_err(io_err)?;
641        write_u8(writer, dtype_to_u8(tensor.dtype())).map_err(io_err)?;
642        let dims = tensor.dims();
643        write_u32(writer, dims.len() as u32).map_err(io_err)?;
644        for &d in dims {
645            write_u32(writer, d as u32).map_err(io_err)?;
646        }
647        let data = tensor_to_bytes(tensor)?;
648        write_u64(writer, data.len() as u64).map_err(io_err)?;
649        write_bytes(writer, &data).map_err(io_err)?;
650    }
651
652    // Section 2: Optimizer state (optional)
653    if let Some(ref opt_state) = checkpoint.optimizer_state {
654        write_optimizer_state(writer, opt_state).map_err(io_err)?;
655    }
656
657    // Section 3: Metadata
658    write_u8(writer, TAG_METADATA).map_err(io_err)?;
659    write_u64(writer, checkpoint.epoch).map_err(io_err)?;
660    write_u64(writer, checkpoint.global_step).map_err(io_err)?;
661    write_bytes(writer, &checkpoint.best_loss.to_le_bytes()).map_err(io_err)?;
662    write_u32(writer, checkpoint.loss_history.len() as u32).map_err(io_err)?;
663    for &loss in &checkpoint.loss_history {
664        write_bytes(writer, &loss.to_le_bytes()).map_err(io_err)?;
665    }
666
667    // EOF
668    write_u8(writer, TAG_EOF).map_err(io_err)?;
669
670    Ok(())
671}
672
/// Read a full training checkpoint from a reader.
///
/// Sections may appear in any order; the loop runs until the EOF tag.
/// A section that appears twice overwrites (optimizer, metadata) or appends
/// to (model params) the previously read values. Sections absent from the
/// file leave their fields at the defaults below.
///
/// NOTE(review): lengths/counts are read from the file, so a corrupt file can
/// request large allocations before reads fail — fine for trusted files.
pub fn read_training_checkpoint<B: Backend>(
    reader: &mut impl Read,
    device: &B::Device,
) -> shrew_core::Result<TrainingCheckpoint<B>> {
    // Header: magic must match before anything else is decoded.
    let mut magic = [0u8; 4];
    reader.read_exact(&mut magic).map_err(io_err)?;
    if &magic != MAGIC {
        return Err(shrew_core::Error::msg(format!(
            "Invalid checkpoint: expected magic {:?}, got {:?}",
            MAGIC, magic
        )));
    }

    // Only the v2 (training) format is accepted here; plain v1 files go
    // through `read_checkpoint` instead.
    let version = read_u32(reader).map_err(io_err)?;
    if version != TRAINING_VERSION {
        return Err(shrew_core::Error::msg(format!(
            "Unsupported training checkpoint version: {} (expected {})",
            version, TRAINING_VERSION
        )));
    }

    // Defaults used for any section missing from the file.
    let mut model_params = Vec::new();
    let mut optimizer_state = None;
    let mut epoch = 0u64;
    let mut global_step = 0u64;
    let mut best_loss = f64::INFINITY;
    let mut loss_history = Vec::new();

    // Section loop: one tag byte, then that section's payload.
    loop {
        let tag = read_u8(reader).map_err(io_err)?;

        match tag {
            TAG_MODEL => {
                // Same per-tensor record layout as the v1 format.
                let count = read_u32(reader).map_err(io_err)? as usize;
                for _ in 0..count {
                    let key_len = read_u32(reader).map_err(io_err)? as usize;
                    let key_bytes = read_bytes(reader, key_len).map_err(io_err)?;
                    let key = String::from_utf8(key_bytes)
                        .map_err(|e| shrew_core::Error::msg(format!("Invalid UTF-8 key: {e}")))?;

                    let dtype = u8_to_dtype(read_u8(reader).map_err(io_err)?)?;
                    let ndim = read_u32(reader).map_err(io_err)? as usize;
                    let mut dims = Vec::with_capacity(ndim);
                    for _ in 0..ndim {
                        dims.push(read_u32(reader).map_err(io_err)? as usize);
                    }
                    let data_len = read_u64(reader).map_err(io_err)? as usize;
                    let data = read_bytes(reader, data_len).map_err(io_err)?;
                    let tensor = tensor_from_bytes::<B>(&data, dims, dtype, device)?;
                    model_params.push((key, tensor));
                }
            }
            TAG_OPTIMIZER => {
                // Optimizer section errors are io::Error; wrap with context.
                optimizer_state =
                    Some(read_optimizer_state(reader).map_err(|e| {
                        shrew_core::Error::msg(format!("Optimizer state error: {e}"))
                    })?);
            }
            TAG_METADATA => {
                epoch = read_u64(reader).map_err(io_err)?;
                global_step = read_u64(reader).map_err(io_err)?;
                // best_loss is a raw little-endian f64.
                let bl_bytes = read_bytes(reader, 8).map_err(io_err)?;
                best_loss = f64::from_le_bytes([
                    bl_bytes[0],
                    bl_bytes[1],
                    bl_bytes[2],
                    bl_bytes[3],
                    bl_bytes[4],
                    bl_bytes[5],
                    bl_bytes[6],
                    bl_bytes[7],
                ]);
                // Loss history: count followed by raw little-endian f64s.
                let n_losses = read_u32(reader).map_err(io_err)? as usize;
                loss_history = Vec::with_capacity(n_losses);
                for _ in 0..n_losses {
                    let lb = read_bytes(reader, 8).map_err(io_err)?;
                    loss_history.push(f64::from_le_bytes([
                        lb[0], lb[1], lb[2], lb[3], lb[4], lb[5], lb[6], lb[7],
                    ]));
                }
            }
            TAG_EOF => break,
            other => {
                return Err(shrew_core::Error::msg(format!(
                    "Unknown section tag in training checkpoint: 0x{other:02X}"
                )));
            }
        }
    }

    Ok(TrainingCheckpoint {
        model_params,
        optimizer_state,
        epoch,
        global_step,
        best_loss,
        loss_history,
    })
}
774
775// ─────────────────────────────────────────────────────────────────────────────
776// High-level API — save/load training checkpoints
777// ─────────────────────────────────────────────────────────────────────────────
778
779/// Save a complete training checkpoint to a file.
780///
781/// This saves model parameters, optimizer state, epoch, and loss history,
782/// enabling full training resume.
783///
784/// ```rust,no_run
785/// use shrew::checkpoint::{self, TrainingCheckpoint};
786/// use shrew::prelude::*;
787///
788/// // During training:
789/// let ckpt = TrainingCheckpoint::<CpuBackend>::new()
790///     .with_epoch(10)
791///     .with_global_step(5000)
792///     .with_best_loss(0.032)
793///     .with_loss_history(vec![0.5, 0.2, 0.1, 0.05, 0.032]);
794/// checkpoint::save_training("training.shrew", &ckpt).unwrap();
795/// ```
796pub fn save_training<B: Backend>(
797    path: impl AsRef<Path>,
798    checkpoint: &TrainingCheckpoint<B>,
799) -> shrew_core::Result<()> {
800    let file = File::create(path.as_ref()).map_err(io_err)?;
801    let mut writer = BufWriter::new(file);
802    write_training_checkpoint(&mut writer, checkpoint)?;
803    writer.flush().map_err(io_err)?;
804    Ok(())
805}
806
807/// Load a complete training checkpoint from a file.
808///
809/// ```rust,no_run
810/// use shrew::checkpoint;
811/// use shrew::prelude::*;
812///
813/// let ckpt = checkpoint::load_training::<CpuBackend>("training.shrew", &CpuDevice).unwrap();
814/// println!("Resuming from epoch {}, step {}", ckpt.epoch, ckpt.global_step);
815/// println!("Best loss so far: {}", ckpt.best_loss);
816///
817/// // Restore model params into executor...
818/// // Restore optimizer state with optimizer.load_state_dict(ckpt.optimizer_state)...
819/// ```
820pub fn load_training<B: Backend>(
821    path: impl AsRef<Path>,
822    device: &B::Device,
823) -> shrew_core::Result<TrainingCheckpoint<B>> {
824    let file = File::open(path.as_ref()).map_err(io_err)?;
825    let mut reader = BufReader::new(file);
826    read_training_checkpoint(&mut reader, device)
827}
828
829/// Serialize a training checkpoint to an in-memory byte vector.
830pub fn training_to_bytes<B: Backend>(
831    checkpoint: &TrainingCheckpoint<B>,
832) -> shrew_core::Result<Vec<u8>> {
833    let mut buf = Vec::new();
834    write_training_checkpoint(&mut buf, checkpoint)?;
835    Ok(buf)
836}
837
838/// Deserialize a training checkpoint from an in-memory byte slice.
839pub fn training_from_bytes<B: Backend>(
840    data: &[u8],
841    device: &B::Device,
842) -> shrew_core::Result<TrainingCheckpoint<B>> {
843    let mut cursor = std::io::Cursor::new(data);
844    read_training_checkpoint(&mut cursor, device)
845}
846
847// ─────────────────────────────────────────────────────────────────────────────
848// Tests
849// ─────────────────────────────────────────────────────────────────────────────
850
#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};

    type CpuTensor = Tensor<CpuBackend>;

    /// F32 values that are exactly representable must survive a
    /// bytes round-trip with key, dims, dtype, and data intact.
    #[test]
    fn test_roundtrip_f32() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0, 4.0], (2, 2), DType::F32, &dev).unwrap();

        let tensors = vec![("w".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].0, "w");
        assert_eq!(loaded[0].1.dims(), &[2, 2]);
        assert_eq!(loaded[0].1.dtype(), DType::F32);

        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    /// F64 tensors are stored as raw little-endian bytes, so the
    /// round-trip must be bit-exact even for irrational constants.
    #[test]
    fn test_roundtrip_f64() {
        let dev = CpuDevice;
        let vals = vec![std::f64::consts::PI, std::f64::consts::E, 0.0, -1.5];
        let t = CpuTensor::from_f64_slice(&vals, (4,), DType::F64, &dev).unwrap();

        let tensors = vec![("precision_test".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        // F64 should be bit-exact
        assert_eq!(orig, restored);
    }

    /// U8 round-trip, covering the min/mid/max of the byte range.
    #[test]
    fn test_roundtrip_u8() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[0.0, 128.0, 255.0], (3,), DType::U8, &dev).unwrap();

        let tensors = vec![("pixels".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded[0].1.dtype(), DType::U8);
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    /// Multiple tensors must be restored in insertion order with
    /// their "graph/param"-style keys and shapes preserved.
    #[test]
    fn test_roundtrip_multiple_tensors() {
        let dev = CpuDevice;
        let w =
            CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], (2, 3), DType::F64, &dev)
                .unwrap();
        let b = CpuTensor::from_f64_slice(&[0.1, 0.2, 0.3], (1, 3), DType::F64, &dev).unwrap();

        let tensors = vec![
            ("Forward/w1".to_string(), w.clone()),
            ("Forward/b1".to_string(), b.clone()),
        ];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].0, "Forward/w1");
        assert_eq!(loaded[0].1.dims(), &[2, 3]);
        assert_eq!(loaded[1].0, "Forward/b1");
        assert_eq!(loaded[1].1.dims(), &[1, 3]);
    }

    /// Input that does not start with the b"SHRW" magic must be
    /// rejected with a descriptive error, not parsed as garbage.
    #[test]
    fn test_invalid_magic() {
        let data = b"BADXsomejunk";
        let result = from_bytes::<CpuBackend>(data, &CpuDevice);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Invalid checkpoint"));
    }

    /// End-to-end save/load through an actual file on disk.
    /// Uses a fixed name in the temp dir; removal failure is ignored.
    #[test]
    fn test_file_roundtrip() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0], (3,), DType::F32, &dev).unwrap();
        let tensors = vec![("test".to_string(), t.clone())];

        let path = std::env::temp_dir().join("shrew_test_checkpoint.shrew");
        save_tensors(&path, &tensors).unwrap();
        let loaded = load_tensors::<CpuBackend>(&path, &dev).unwrap();
        std::fs::remove_file(&path).ok();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].0, "test");
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    /// A checkpoint with zero tensors is valid (header only) and
    /// round-trips to an empty list.
    #[test]
    fn test_empty_checkpoint() {
        let tensors: Vec<(String, CpuTensor)> = vec![];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &CpuDevice).unwrap();
        assert_eq!(loaded.len(), 0);
    }

    /// Rank-3 shapes round-trip; F32 data is compared with a small
    /// tolerance since the source values pass through f64 -> f32.
    #[test]
    fn test_3d_tensor_roundtrip() {
        let dev = CpuDevice;
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let t = CpuTensor::from_f64_slice(&data, (2, 3, 4), DType::F32, &dev).unwrap();

        let tensors = vec![("volume".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded[0].1.dims(), &[2, 3, 4]);
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        for (a, b) in orig.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 1e-6, "mismatch: {a} vs {b}");
        }
    }

    /// Full training checkpoint round-trip: model params, optimizer
    /// scalars/buffers, and training metadata must all be restored.
    /// (OptimizerState is already in scope via the file-level import
    /// re-exported by `use super::*;` — no local import needed.)
    #[test]
    fn test_training_checkpoint_roundtrip() {
        let dev = CpuDevice;
        let w = CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0, 4.0], (2, 2), DType::F32, &dev).unwrap();
        let b = CpuTensor::from_f64_slice(&[0.1, 0.2], (2,), DType::F32, &dev).unwrap();

        // Build optimizer state
        let mut opt_state = OptimizerState::new("Adam");
        opt_state.set_scalar("t", 100.0);
        opt_state.set_scalar("lr", 0.001);
        opt_state.set_scalar("beta1", 0.9);
        opt_state.set_buffer("m.0", vec![0.01, 0.02, 0.03, 0.04]);
        opt_state.set_buffer("v.0", vec![0.001, 0.002, 0.003, 0.004]);

        let ckpt = TrainingCheckpoint {
            model_params: vec![
                ("layer1/weight".to_string(), w),
                ("layer1/bias".to_string(), b),
            ],
            optimizer_state: Some(opt_state),
            epoch: 10,
            global_step: 5000,
            best_loss: 0.032,
            loss_history: vec![1.5, 0.8, 0.4, 0.1, 0.05, 0.032],
        };

        let bytes = training_to_bytes(&ckpt).unwrap();
        let loaded = training_from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        // Verify model params
        assert_eq!(loaded.model_params.len(), 2);
        assert_eq!(loaded.model_params[0].0, "layer1/weight");
        assert_eq!(loaded.model_params[0].1.dims(), &[2, 2]);
        assert_eq!(loaded.model_params[1].0, "layer1/bias");

        // Verify metadata
        assert_eq!(loaded.epoch, 10);
        assert_eq!(loaded.global_step, 5000);
        assert!((loaded.best_loss - 0.032).abs() < 1e-10);
        assert_eq!(loaded.loss_history.len(), 6);
        assert!((loaded.loss_history[0] - 1.5).abs() < 1e-10);
        assert!((loaded.loss_history[5] - 0.032).abs() < 1e-10);

        // Verify optimizer state
        let opt = loaded.optimizer_state.unwrap();
        assert_eq!(opt.optimizer_type, "Adam");
        assert_eq!(opt.get_scalar("t"), Some(100.0));
        assert_eq!(opt.get_scalar("lr"), Some(0.001));
        assert_eq!(opt.get_scalar("beta1"), Some(0.9));
        let m0 = opt.get_buffer("m.0").unwrap();
        assert_eq!(m0.len(), 4);
        assert!((m0[2] - 0.03).abs() < 1e-10);
    }

    /// `optimizer_state: None` must round-trip as None rather than
    /// being dropped or materialized as an empty state.
    #[test]
    fn test_training_checkpoint_no_optimizer() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[1.0, 2.0], (2,), DType::F64, &dev).unwrap();

        let ckpt = TrainingCheckpoint {
            model_params: vec![("w".to_string(), t)],
            optimizer_state: None,
            epoch: 5,
            global_step: 250,
            best_loss: 0.1,
            loss_history: vec![0.5, 0.3, 0.1],
        };

        let bytes = training_to_bytes(&ckpt).unwrap();
        let loaded = training_from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded.model_params.len(), 1);
        assert!(loaded.optimizer_state.is_none());
        assert_eq!(loaded.epoch, 5);
        assert_eq!(loaded.global_step, 250);
    }
}