// shrew/safetensors.rs

1// =============================================================================
2// Safetensors — Interoperable tensor serialization (HuggingFace format)
3// =============================================================================
4//
5// The safetensors format stores tensors in a single flat file:
6//
7//   ┌──────────────┬──────────────────────┬───────────────────────┐
8//   │ 8 bytes      │ N bytes              │ raw data bytes        │
9//   │ header size  │ JSON header (UTF-8)  │ (contiguous, LE)      │
10//   │ (u64 LE)     │                      │                       │
11//   └──────────────┴──────────────────────┴───────────────────────┘
12//
13// JSON header example:
14//   {
15//     "__metadata__": { "format": "shrew" },
16//     "layer.weight": {
17//       "dtype": "F32",
18//       "shape": [64, 128],
19//       "data_offsets": [0, 32768]
20//     }
21//   }
22//
23// Supported dtypes: F16, BF16, F32, F64, U8, U32, I64.
24//
25// This implementation intentionally avoids external safetensors crates —
26// the format is simple enough to implement from scratch, and this keeps
27// the dependency tree lean.
28//
29// Usage:
30//   safetensors::save("model.safetensors", &named_tensors)?;
31//   let tensors = safetensors::load::<CpuBackend>("model.safetensors", &device)?;
32//
33//   // Module-level convenience:
34//   safetensors::save_module("model.safetensors", &my_module)?;
35
36use std::collections::HashMap;
37use std::fs::File;
38use std::io::{BufReader, BufWriter, Read, Write};
39use std::path::Path;
40
41use shrew_core::backend::Backend;
42use shrew_core::tensor::Tensor;
43use shrew_core::DType;
44
45// ─────────────────────────────────────────────────────────────────────────────
46// DType ↔ safetensors string
47// ─────────────────────────────────────────────────────────────────────────────
48
49fn dtype_to_st(dtype: DType) -> &'static str {
50    match dtype {
51        DType::F16 => "F16",
52        DType::BF16 => "BF16",
53        DType::F32 => "F32",
54        DType::F64 => "F64",
55        DType::U8 => "U8",
56        DType::U32 => "U32",
57        DType::I64 => "I64",
58    }
59}
60
61fn st_to_dtype(s: &str) -> shrew_core::Result<DType> {
62    match s {
63        "F16" => Ok(DType::F16),
64        "BF16" => Ok(DType::BF16),
65        "F32" => Ok(DType::F32),
66        "F64" => Ok(DType::F64),
67        "U8" | "BOOL" => Ok(DType::U8),
68        "U32" | "U16" | "I32" | "I16" | "I8" => Ok(DType::U32),
69        "I64" => Ok(DType::I64),
70        _ => Err(shrew_core::Error::msg(format!(
71            "Unsupported safetensors dtype: {s}"
72        ))),
73    }
74}
75
76fn dtype_elem_size(dtype: DType) -> usize {
77    match dtype {
78        DType::F16 => 2,
79        DType::BF16 => 2,
80        DType::F32 => 4,
81        DType::F64 => 8,
82        DType::U8 => 1,
83        DType::U32 => 4,
84        DType::I64 => 8,
85    }
86}
87
88// ─────────────────────────────────────────────────────────────────────────────
89// Raw bytes extraction / reconstruction
90// ─────────────────────────────────────────────────────────────────────────────
91
92fn tensor_to_bytes<B: Backend>(tensor: &Tensor<B>) -> shrew_core::Result<Vec<u8>> {
93    let t = tensor.contiguous()?;
94    let data = t.to_f64_vec()?;
95    let dtype = t.dtype();
96
97    Ok(match dtype {
98        DType::F16 => data
99            .iter()
100            .flat_map(|&v| half::f16::from_f64(v).to_le_bytes())
101            .collect(),
102        DType::BF16 => data
103            .iter()
104            .flat_map(|&v| half::bf16::from_f64(v).to_le_bytes())
105            .collect(),
106        DType::F32 => data
107            .iter()
108            .flat_map(|&v| (v as f32).to_le_bytes())
109            .collect(),
110        DType::F64 => data.iter().flat_map(|&v| v.to_le_bytes()).collect(),
111        DType::U8 => data.iter().map(|&v| v as u8).collect(),
112        DType::U32 => data
113            .iter()
114            .flat_map(|&v| (v as u32).to_le_bytes())
115            .collect(),
116        DType::I64 => data
117            .iter()
118            .flat_map(|&v| (v as i64).to_le_bytes())
119            .collect(),
120    })
121}
122
123fn tensor_from_bytes<B: Backend>(
124    raw: &[u8],
125    dims: Vec<usize>,
126    dtype: DType,
127    device: &B::Device,
128) -> shrew_core::Result<Tensor<B>> {
129    let elem_size = dtype_elem_size(dtype);
130    let num_elems: usize = dims.iter().product();
131    let expected = num_elems * elem_size;
132    if raw.len() != expected {
133        return Err(shrew_core::Error::msg(format!(
134            "safetensors: expected {expected} bytes for {num_elems} elements of {:?}, got {}",
135            dtype,
136            raw.len()
137        )));
138    }
139
140    let data_f64: Vec<f64> = match dtype {
141        DType::F16 => raw
142            .chunks_exact(2)
143            .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f64())
144            .collect(),
145        DType::BF16 => raw
146            .chunks_exact(2)
147            .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f64())
148            .collect(),
149        DType::F32 => raw
150            .chunks_exact(4)
151            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f64)
152            .collect(),
153        DType::F64 => raw
154            .chunks_exact(8)
155            .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
156            .collect(),
157        DType::U8 => raw.iter().map(|&v| v as f64).collect(),
158        DType::U32 => raw
159            .chunks_exact(4)
160            .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f64)
161            .collect(),
162        DType::I64 => raw
163            .chunks_exact(8)
164            .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f64)
165            .collect(),
166    };
167
168    let shape = shrew_core::Shape::new(dims);
169    Tensor::from_f64_slice(&data_f64, shape, dtype, device)
170}
171
172// ─────────────────────────────────────────────────────────────────────────────
173// JSON header builder (no serde dependency)
174// ─────────────────────────────────────────────────────────────────────────────
175
/// Escape a string as a quoted JSON string literal.
///
/// Handles backslash, double-quote, the common whitespace escapes, and
/// all other control characters (emitted as `\u00xx`). The surrounding
/// double quotes are included in the returned string.
fn json_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len() + 2);
    escaped.push('"');
    for ch in s.chars() {
        // Named short escapes take precedence over the generic \u form.
        let shorthand = match ch {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        match shorthand {
            Some(esc) => escaped.push_str(esc),
            None if ch.is_control() => escaped.push_str(&format!("\\u{:04x}", ch as u32)),
            None => escaped.push(ch),
        }
    }
    escaped.push('"');
    escaped
}
196
/// Header metadata for one tensor being written: its name, dtype, shape,
/// and the byte range it occupies within the file's data section.
struct TensorMeta {
    // Tensor name; used as the JSON key in the header.
    name: String,
    // On-disk element dtype.
    dtype: DType,
    // Tensor dimensions, outermost first.
    shape: Vec<usize>,
    // Start byte offset within the data section (inclusive).
    data_offset_start: usize,
    // End byte offset within the data section (exclusive).
    data_offset_end: usize,
}
204
205fn build_header_json(metas: &[TensorMeta], metadata: Option<&HashMap<String, String>>) -> String {
206    let mut json = String::from("{");
207
208    // __metadata__ (optional)
209    if let Some(md) = metadata {
210        json.push_str("\"__metadata__\":{");
211        for (i, (k, v)) in md.iter().enumerate() {
212            if i > 0 {
213                json.push(',');
214            }
215            json.push_str(&json_escape(k));
216            json.push(':');
217            json.push_str(&json_escape(v));
218        }
219        json.push('}');
220        if !metas.is_empty() {
221            json.push(',');
222        }
223    }
224
225    // Tensor entries
226    for (i, meta) in metas.iter().enumerate() {
227        if i > 0 {
228            json.push(',');
229        }
230        json.push_str(&json_escape(&meta.name));
231        json.push_str(":{\"dtype\":\"");
232        json.push_str(dtype_to_st(meta.dtype));
233        json.push_str("\",\"shape\":[");
234        for (j, &d) in meta.shape.iter().enumerate() {
235            if j > 0 {
236                json.push(',');
237            }
238            json.push_str(&d.to_string());
239        }
240        json.push_str("],\"data_offsets\":[");
241        json.push_str(&meta.data_offset_start.to_string());
242        json.push(',');
243        json.push_str(&meta.data_offset_end.to_string());
244        json.push_str("]}");
245    }
246
247    json.push('}');
248    json
249}
250
251// ─────────────────────────────────────────────────────────────────────────────
252// JSON header parser (minimal; JSON decoding via serde_json)
253// ─────────────────────────────────────────────────────────────────────────────
254
/// Parsed tensor entry from safetensors header.
struct ParsedEntry {
    // Tensor name (the JSON key in the header).
    name: String,
    // Raw dtype string as found in the header; mapped to `DType` later
    // by `st_to_dtype` so unknown dtypes fail at load time, not parse time.
    dtype_str: String,
    // Tensor dimensions, outermost first.
    shape: Vec<usize>,
    // Start byte offset within the data section (inclusive).
    data_offset_start: usize,
    // End byte offset within the data section (exclusive).
    data_offset_end: usize,
}
263
264/// Parse the safetensors JSON header using serde_json.
265fn parse_header(json_str: &str) -> shrew_core::Result<Vec<ParsedEntry>> {
266    let value: serde_json::Value = serde_json::from_str(json_str)
267        .map_err(|e| shrew_core::Error::msg(format!("safetensors: invalid JSON header: {e}")))?;
268
269    let obj = value
270        .as_object()
271        .ok_or_else(|| shrew_core::Error::msg("safetensors: header is not a JSON object"))?;
272
273    let mut entries = Vec::new();
274
275    for (key, val) in obj {
276        // Skip __metadata__
277        if key == "__metadata__" {
278            continue;
279        }
280
281        let tensor_obj = val.as_object().ok_or_else(|| {
282            shrew_core::Error::msg(format!("safetensors: entry '{key}' is not an object"))
283        })?;
284
285        let dtype_str = tensor_obj
286            .get("dtype")
287            .and_then(|v| v.as_str())
288            .ok_or_else(|| shrew_core::Error::msg(format!("safetensors: '{key}' missing dtype")))?
289            .to_string();
290
291        let shape_arr = tensor_obj
292            .get("shape")
293            .and_then(|v| v.as_array())
294            .ok_or_else(|| shrew_core::Error::msg(format!("safetensors: '{key}' missing shape")))?;
295
296        let shape: Vec<usize> = shape_arr
297            .iter()
298            .map(|v| v.as_u64().unwrap_or(0) as usize)
299            .collect();
300
301        let offsets = tensor_obj
302            .get("data_offsets")
303            .and_then(|v| v.as_array())
304            .ok_or_else(|| {
305                shrew_core::Error::msg(format!("safetensors: '{key}' missing data_offsets"))
306            })?;
307
308        if offsets.len() != 2 {
309            return Err(shrew_core::Error::msg(format!(
310                "safetensors: '{key}' data_offsets must have exactly 2 elements"
311            )));
312        }
313
314        let start = offsets[0].as_u64().unwrap_or(0) as usize;
315        let end = offsets[1].as_u64().unwrap_or(0) as usize;
316
317        entries.push(ParsedEntry {
318            name: key.clone(),
319            dtype_str,
320            shape,
321            data_offset_start: start,
322            data_offset_end: end,
323        });
324    }
325
326    Ok(entries)
327}
328
329// ─────────────────────────────────────────────────────────────────────────────
330// Write safetensors
331// ─────────────────────────────────────────────────────────────────────────────
332
333/// Write named tensors in safetensors format to a writer.
334pub fn write_safetensors<B: Backend>(
335    writer: &mut impl Write,
336    tensors: &[(String, Tensor<B>)],
337) -> shrew_core::Result<()> {
338    // Step 1: Serialize all tensor data and compute offsets
339    let mut all_data: Vec<u8> = Vec::new();
340    let mut metas: Vec<TensorMeta> = Vec::with_capacity(tensors.len());
341
342    for (name, tensor) in tensors {
343        let bytes = tensor_to_bytes(tensor)?;
344        let start = all_data.len();
345        let end = start + bytes.len();
346        all_data.extend_from_slice(&bytes);
347
348        metas.push(TensorMeta {
349            name: name.clone(),
350            dtype: tensor.dtype(),
351            shape: tensor.dims().to_vec(),
352            data_offset_start: start,
353            data_offset_end: end,
354        });
355    }
356
357    // Step 2: Build JSON header
358    let mut metadata = HashMap::new();
359    metadata.insert("format".to_string(), "shrew".to_string());
360    let header_json = build_header_json(&metas, Some(&metadata));
361    let header_bytes = header_json.as_bytes();
362
363    // Step 3: Write header size (u64 LE)
364    let header_size = header_bytes.len() as u64;
365    writer
366        .write_all(&header_size.to_le_bytes())
367        .map_err(io_err)?;
368
369    // Step 4: Write JSON header
370    writer.write_all(header_bytes).map_err(io_err)?;
371
372    // Step 5: Write raw tensor data
373    writer.write_all(&all_data).map_err(io_err)?;
374
375    Ok(())
376}
377
378/// Read named tensors from safetensors format.
379pub fn read_safetensors<B: Backend>(
380    reader: &mut impl Read,
381    device: &B::Device,
382) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
383    // Step 1: Read header size
384    let mut size_buf = [0u8; 8];
385    reader.read_exact(&mut size_buf).map_err(io_err)?;
386    let header_size = u64::from_le_bytes(size_buf) as usize;
387
388    // Sanity check: header shouldn't be unreasonably large
389    if header_size > 100_000_000 {
390        return Err(shrew_core::Error::msg(format!(
391            "safetensors: header size {header_size} bytes is unreasonably large"
392        )));
393    }
394
395    // Step 2: Read JSON header
396    let mut header_bytes = vec![0u8; header_size];
397    reader.read_exact(&mut header_bytes).map_err(io_err)?;
398    let header_str = std::str::from_utf8(&header_bytes)
399        .map_err(|e| shrew_core::Error::msg(format!("safetensors: invalid UTF-8 header: {e}")))?;
400
401    // Step 3: Parse header
402    let entries = parse_header(header_str)?;
403
404    // Step 4: Read all raw data
405    let max_offset = entries.iter().map(|e| e.data_offset_end).max().unwrap_or(0);
406    let mut all_data = vec![0u8; max_offset];
407    if max_offset > 0 {
408        reader.read_exact(&mut all_data).map_err(io_err)?;
409    }
410
411    // Step 5: Reconstruct tensors
412    let mut tensors = Vec::with_capacity(entries.len());
413    for entry in &entries {
414        let dtype = st_to_dtype(&entry.dtype_str)?;
415        let raw = &all_data[entry.data_offset_start..entry.data_offset_end];
416        let tensor = tensor_from_bytes::<B>(raw, entry.shape.clone(), dtype, device)?;
417        tensors.push((entry.name.clone(), tensor));
418    }
419
420    Ok(tensors)
421}
422
423fn io_err(e: std::io::Error) -> shrew_core::Error {
424    shrew_core::Error::msg(format!("IO error: {e}"))
425}
426
427// ─────────────────────────────────────────────────────────────────────────────
428// High-level file API
429// ─────────────────────────────────────────────────────────────────────────────
430
431/// Save named tensors to a `.safetensors` file.
432///
433/// ```rust,no_run
434/// use shrew::safetensors;
435/// use shrew::prelude::*;
436///
437/// let w = Tensor::<CpuBackend>::zeros((2, 3), DType::F32, &CpuDevice).unwrap();
438/// let tensors = vec![("weight".to_string(), w)];
439/// safetensors::save("model.safetensors", &tensors).unwrap();
440/// ```
441pub fn save<B: Backend>(
442    path: impl AsRef<Path>,
443    tensors: &[(String, Tensor<B>)],
444) -> shrew_core::Result<()> {
445    let file = File::create(path.as_ref()).map_err(io_err)?;
446    let mut writer = BufWriter::new(file);
447    write_safetensors(&mut writer, tensors)?;
448    writer.flush().map_err(io_err)?;
449    Ok(())
450}
451
452/// Load named tensors from a `.safetensors` file.
453///
454/// ```rust,no_run
455/// use shrew::safetensors;
456/// use shrew::prelude::*;
457///
458/// let tensors = safetensors::load::<CpuBackend>("model.safetensors", &CpuDevice).unwrap();
459/// for (name, tensor) in &tensors {
460///     println!("{name}: {:?}", tensor.dims());
461/// }
462/// ```
463pub fn load<B: Backend>(
464    path: impl AsRef<Path>,
465    device: &B::Device,
466) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
467    let file = File::open(path.as_ref()).map_err(io_err)?;
468    let mut reader = BufReader::new(file);
469    read_safetensors(&mut reader, device)
470}
471
472// ─────────────────────────────────────────────────────────────────────────────
473// In-memory API (for testing)
474// ─────────────────────────────────────────────────────────────────────────────
475
476/// Serialize named tensors to an in-memory byte vector in safetensors format.
477pub fn to_bytes<B: Backend>(tensors: &[(String, Tensor<B>)]) -> shrew_core::Result<Vec<u8>> {
478    let mut buf = Vec::new();
479    write_safetensors(&mut buf, tensors)?;
480    Ok(buf)
481}
482
483/// Deserialize named tensors from an in-memory safetensors byte slice.
484pub fn from_bytes<B: Backend>(
485    data: &[u8],
486    device: &B::Device,
487) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
488    let mut cursor = std::io::Cursor::new(data);
489    read_safetensors(&mut cursor, device)
490}
491
492// ─────────────────────────────────────────────────────────────────────────────
493// Module-level convenience
494// ─────────────────────────────────────────────────────────────────────────────
495
496/// Save a module's parameters to a `.safetensors` file using its
497/// `named_parameters()`.
498///
499/// ```rust,no_run
500/// use shrew::safetensors;
501/// use shrew::prelude::*;
502///
503/// let linear = Linear::<CpuBackend>::new(3, 2, true, DType::F32, &CpuDevice).unwrap();
504/// safetensors::save_module("linear.safetensors", &linear).unwrap();
505/// ```
506pub fn save_module<B: Backend>(
507    path: impl AsRef<Path>,
508    module: &dyn shrew_nn::Module<B>,
509) -> shrew_core::Result<()> {
510    let named = module.named_parameters();
511    save(path, &named)
512}
513
514/// Load parameters from a `.safetensors` file into a state-dict map.
515///
516/// Returns a `HashMap<String, Tensor<B>>` for flexible parameter loading.
517pub fn load_state_dict<B: Backend>(
518    path: impl AsRef<Path>,
519    device: &B::Device,
520) -> shrew_core::Result<HashMap<String, Tensor<B>>> {
521    let tensors = load(path, device)?;
522    Ok(tensors.into_iter().collect())
523}
524
525// ─────────────────────────────────────────────────────────────────────────────
526// Tests
527// ─────────────────────────────────────────────────────────────────────────────
528
// Round-trip unit tests: serialize via the in-memory/file APIs, load the
// bytes back, and compare shapes, dtypes, and values against the source
// tensors. All run on the CPU backend.
#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};

    type CpuTensor = Tensor<CpuBackend>;

    // F32 round-trip: name, dims, dtype, and exact values survive.
    #[test]
    fn test_roundtrip_f32() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0, 4.0], (2, 2), DType::F32, &dev).unwrap();

        let tensors = vec![("weight".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].0, "weight");
        assert_eq!(loaded[0].1.dims(), &[2, 2]);
        assert_eq!(loaded[0].1.dtype(), DType::F32);

        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // F64 round-trip: irrational constants must be preserved bit-exactly
    // (no narrowing anywhere on the F64 path).
    #[test]
    fn test_roundtrip_f64() {
        let dev = CpuDevice;
        let vals = vec![std::f64::consts::PI, std::f64::consts::E, 0.0, -1.5];
        let t = CpuTensor::from_f64_slice(&vals, (4,), DType::F64, &dev).unwrap();

        let tensors = vec![("precision".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // U8 round-trip including both range endpoints (0 and 255).
    #[test]
    fn test_roundtrip_u8() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[0.0, 128.0, 255.0], (3,), DType::U8, &dev).unwrap();

        let tensors = vec![("pixels".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded[0].1.dtype(), DType::U8);
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // Two tensors in one file; looked up by name since serde_json may
    // return header entries in any order.
    #[test]
    fn test_roundtrip_multiple_tensors() {
        let dev = CpuDevice;
        let w =
            CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], (2, 3), DType::F32, &dev)
                .unwrap();
        let b = CpuTensor::from_f64_slice(&[0.1, 0.2, 0.3], (3,), DType::F32, &dev).unwrap();

        let tensors = vec![
            ("layer.weight".to_string(), w.clone()),
            ("layer.bias".to_string(), b.clone()),
        ];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded.len(), 2);

        // Find by name (JSON object order may vary)
        let map: HashMap<String, CpuTensor> = loaded.into_iter().collect();
        assert!(map.contains_key("layer.weight"));
        assert!(map.contains_key("layer.bias"));
        assert_eq!(map["layer.weight"].dims(), &[2, 3]);
        assert_eq!(map["layer.bias"].dims(), &[3]);
    }

    // Rank-3 shape survives; values compared with an F32-level tolerance.
    #[test]
    fn test_3d_tensor_roundtrip() {
        let dev = CpuDevice;
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let t = CpuTensor::from_f64_slice(&data, (2, 3, 4), DType::F32, &dev).unwrap();

        let tensors = vec![("volume".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded[0].1.dims(), &[2, 3, 4]);
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        for (a, b) in orig.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 1e-6, "mismatch: {a} vs {b}");
        }
    }

    // Zero tensors: a header-only file must round-trip to an empty list.
    #[test]
    fn test_empty() {
        let tensors: Vec<(String, CpuTensor)> = vec![];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &CpuDevice).unwrap();
        assert_eq!(loaded.len(), 0);
    }

    // save()/load() through a real temp file; cleans up afterwards.
    #[test]
    fn test_file_roundtrip() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0], (3,), DType::F32, &dev).unwrap();
        let tensors = vec![("test_param".to_string(), t.clone())];

        let path = std::env::temp_dir().join("shrew_test_safetensors.safetensors");
        save(&path, &tensors).unwrap();
        let loaded = load::<CpuBackend>(&path, &dev).unwrap();
        std::fs::remove_file(&path).ok();

        assert_eq!(loaded.len(), 1);
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // The hand-built header must be valid JSON with the expected fields
    // when parsed by an independent JSON parser.
    #[test]
    fn test_header_format() {
        // Verify the header can be parsed independently
        let metas = vec![TensorMeta {
            name: "layer.weight".to_string(),
            dtype: DType::F32,
            shape: vec![3, 4],
            data_offset_start: 0,
            data_offset_end: 48,
        }];
        let json = build_header_json(&metas, None);

        // Should be valid JSON parsable by serde_json
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        let entry = parsed.get("layer.weight").unwrap();
        assert_eq!(entry["dtype"].as_str().unwrap(), "F32");
        assert_eq!(entry["shape"][0].as_u64().unwrap(), 3);
        assert_eq!(entry["shape"][1].as_u64().unwrap(), 4);
        assert_eq!(entry["data_offsets"][0].as_u64().unwrap(), 0);
        assert_eq!(entry["data_offsets"][1].as_u64().unwrap(), 48);
    }

    // Escaping of quotes, backslashes, and newlines (output includes the
    // surrounding double quotes).
    #[test]
    fn test_json_escape() {
        assert_eq!(json_escape("hello"), "\"hello\"");
        assert_eq!(json_escape("a\"b"), "\"a\\\"b\"");
        assert_eq!(json_escape("a\\b"), "\"a\\\\b\"");
        assert_eq!(json_escape("a\nb"), "\"a\\nb\"");
    }

    // load_state_dict() returns a name → tensor map with correct shapes.
    #[test]
    fn test_state_dict_roundtrip() {
        let dev = CpuDevice;
        let w = CpuTensor::from_f64_slice(&[1.0, 2.0], (1, 2), DType::F32, &dev).unwrap();
        let b = CpuTensor::from_f64_slice(&[0.5], (1,), DType::F32, &dev).unwrap();

        let tensors = vec![("fc.weight".to_string(), w), ("fc.bias".to_string(), b)];

        let path = std::env::temp_dir().join("shrew_test_state_dict.safetensors");
        save(&path, &tensors).unwrap();
        let sd = load_state_dict::<CpuBackend>(&path, &dev).unwrap();
        std::fs::remove_file(&path).ok();

        assert!(sd.contains_key("fc.weight"));
        assert!(sd.contains_key("fc.bias"));
        assert_eq!(sd["fc.weight"].dims(), &[1, 2]);
        assert_eq!(sd["fc.bias"].dims(), &[1]);
    }

    // save_module() persists a Linear's named parameters under the names
    // reported by named_parameters() ("weight"/"bias").
    #[test]
    fn test_save_module_linear() {
        use shrew_nn::{Linear, Module};

        let dev = CpuDevice;
        let linear = Linear::<CpuBackend>::new(3, 2, true, DType::F32, &dev).unwrap();

        let path = std::env::temp_dir().join("shrew_test_module_save.safetensors");
        save_module(&path, &linear).unwrap();
        let loaded = load::<CpuBackend>(&path, &dev).unwrap();
        std::fs::remove_file(&path).ok();

        let map: HashMap<String, CpuTensor> = loaded.into_iter().collect();
        assert!(map.contains_key("weight"), "missing 'weight' key");
        assert!(map.contains_key("bias"), "missing 'bias' key");
        assert_eq!(map["weight"].dims(), &[2, 3]);
        assert_eq!(map["bias"].dims(), &[1, 2]);
    }
}