1use std::collections::HashMap;
37use std::fs::File;
38use std::io::{BufReader, BufWriter, Read, Write};
39use std::path::Path;
40
41use shrew_core::backend::Backend;
42use shrew_core::tensor::Tensor;
43use shrew_core::DType;
44
45fn dtype_to_st(dtype: DType) -> &'static str {
50 match dtype {
51 DType::F16 => "F16",
52 DType::BF16 => "BF16",
53 DType::F32 => "F32",
54 DType::F64 => "F64",
55 DType::U8 => "U8",
56 DType::U32 => "U32",
57 DType::I64 => "I64",
58 }
59}
60
61fn st_to_dtype(s: &str) -> shrew_core::Result<DType> {
62 match s {
63 "F16" => Ok(DType::F16),
64 "BF16" => Ok(DType::BF16),
65 "F32" => Ok(DType::F32),
66 "F64" => Ok(DType::F64),
67 "U8" | "BOOL" => Ok(DType::U8),
68 "U32" | "U16" | "I32" | "I16" | "I8" => Ok(DType::U32),
69 "I64" => Ok(DType::I64),
70 _ => Err(shrew_core::Error::msg(format!(
71 "Unsupported safetensors dtype: {s}"
72 ))),
73 }
74}
75
76fn dtype_elem_size(dtype: DType) -> usize {
77 match dtype {
78 DType::F16 => 2,
79 DType::BF16 => 2,
80 DType::F32 => 4,
81 DType::F64 => 8,
82 DType::U8 => 1,
83 DType::U32 => 4,
84 DType::I64 => 8,
85 }
86}
87
88fn tensor_to_bytes<B: Backend>(tensor: &Tensor<B>) -> shrew_core::Result<Vec<u8>> {
93 let t = tensor.contiguous()?;
94 let data = t.to_f64_vec()?;
95 let dtype = t.dtype();
96
97 Ok(match dtype {
98 DType::F16 => data
99 .iter()
100 .flat_map(|&v| half::f16::from_f64(v).to_le_bytes())
101 .collect(),
102 DType::BF16 => data
103 .iter()
104 .flat_map(|&v| half::bf16::from_f64(v).to_le_bytes())
105 .collect(),
106 DType::F32 => data
107 .iter()
108 .flat_map(|&v| (v as f32).to_le_bytes())
109 .collect(),
110 DType::F64 => data.iter().flat_map(|&v| v.to_le_bytes()).collect(),
111 DType::U8 => data.iter().map(|&v| v as u8).collect(),
112 DType::U32 => data
113 .iter()
114 .flat_map(|&v| (v as u32).to_le_bytes())
115 .collect(),
116 DType::I64 => data
117 .iter()
118 .flat_map(|&v| (v as i64).to_le_bytes())
119 .collect(),
120 })
121}
122
123fn tensor_from_bytes<B: Backend>(
124 raw: &[u8],
125 dims: Vec<usize>,
126 dtype: DType,
127 device: &B::Device,
128) -> shrew_core::Result<Tensor<B>> {
129 let elem_size = dtype_elem_size(dtype);
130 let num_elems: usize = dims.iter().product();
131 let expected = num_elems * elem_size;
132 if raw.len() != expected {
133 return Err(shrew_core::Error::msg(format!(
134 "safetensors: expected {expected} bytes for {num_elems} elements of {:?}, got {}",
135 dtype,
136 raw.len()
137 )));
138 }
139
140 let data_f64: Vec<f64> = match dtype {
141 DType::F16 => raw
142 .chunks_exact(2)
143 .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f64())
144 .collect(),
145 DType::BF16 => raw
146 .chunks_exact(2)
147 .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f64())
148 .collect(),
149 DType::F32 => raw
150 .chunks_exact(4)
151 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f64)
152 .collect(),
153 DType::F64 => raw
154 .chunks_exact(8)
155 .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
156 .collect(),
157 DType::U8 => raw.iter().map(|&v| v as f64).collect(),
158 DType::U32 => raw
159 .chunks_exact(4)
160 .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f64)
161 .collect(),
162 DType::I64 => raw
163 .chunks_exact(8)
164 .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f64)
165 .collect(),
166 };
167
168 let shape = shrew_core::Shape::new(dims);
169 Tensor::from_f64_slice(&data_f64, shape, dtype, device)
170}
171
/// Wraps `s` in double quotes and escapes it per the JSON string rules:
/// quote, backslash, and the common whitespace escapes get two-character
/// forms; every other control character becomes a `\uXXXX` escape.
fn json_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len() + 2);
    escaped.push('"');
    for ch in s.chars() {
        let short_escape = match ch {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(esc) = short_escape {
            escaped.push_str(esc);
        } else if ch.is_control() {
            escaped.push_str(&format!("\\u{:04x}", ch as u32));
        } else {
            escaped.push(ch);
        }
    }
    escaped.push('"');
    escaped
}
196
/// Describes one tensor's entry in the safetensors JSON header.
struct TensorMeta {
    // Tensor name; used (escaped) as the JSON key in the header.
    name: String,
    // Element dtype, serialized via `dtype_to_st`.
    dtype: DType,
    // Dimensions as written to the header's "shape" array.
    shape: Vec<usize>,
    // Byte range [start, end) of this tensor's data, relative to the
    // start of the data section that follows the header.
    data_offset_start: usize,
    data_offset_end: usize,
}
204
205fn build_header_json(metas: &[TensorMeta], metadata: Option<&HashMap<String, String>>) -> String {
206 let mut json = String::from("{");
207
208 if let Some(md) = metadata {
210 json.push_str("\"__metadata__\":{");
211 for (i, (k, v)) in md.iter().enumerate() {
212 if i > 0 {
213 json.push(',');
214 }
215 json.push_str(&json_escape(k));
216 json.push(':');
217 json.push_str(&json_escape(v));
218 }
219 json.push('}');
220 if !metas.is_empty() {
221 json.push(',');
222 }
223 }
224
225 for (i, meta) in metas.iter().enumerate() {
227 if i > 0 {
228 json.push(',');
229 }
230 json.push_str(&json_escape(&meta.name));
231 json.push_str(":{\"dtype\":\"");
232 json.push_str(dtype_to_st(meta.dtype));
233 json.push_str("\",\"shape\":[");
234 for (j, &d) in meta.shape.iter().enumerate() {
235 if j > 0 {
236 json.push(',');
237 }
238 json.push_str(&d.to_string());
239 }
240 json.push_str("],\"data_offsets\":[");
241 json.push_str(&meta.data_offset_start.to_string());
242 json.push(',');
243 json.push_str(&meta.data_offset_end.to_string());
244 json.push_str("]}");
245 }
246
247 json.push('}');
248 json
249}
250
/// One tensor entry decoded from a safetensors JSON header.
struct ParsedEntry {
    // JSON key of the entry (the tensor name).
    name: String,
    // Raw dtype tag (e.g. "F32"); converted later via `st_to_dtype`.
    dtype_str: String,
    // Dimensions from the header's "shape" array.
    shape: Vec<usize>,
    // Byte range [start, end) into the data section following the header.
    data_offset_start: usize,
    data_offset_end: usize,
}
263
264fn parse_header(json_str: &str) -> shrew_core::Result<Vec<ParsedEntry>> {
266 let value: serde_json::Value = serde_json::from_str(json_str)
267 .map_err(|e| shrew_core::Error::msg(format!("safetensors: invalid JSON header: {e}")))?;
268
269 let obj = value
270 .as_object()
271 .ok_or_else(|| shrew_core::Error::msg("safetensors: header is not a JSON object"))?;
272
273 let mut entries = Vec::new();
274
275 for (key, val) in obj {
276 if key == "__metadata__" {
278 continue;
279 }
280
281 let tensor_obj = val.as_object().ok_or_else(|| {
282 shrew_core::Error::msg(format!("safetensors: entry '{key}' is not an object"))
283 })?;
284
285 let dtype_str = tensor_obj
286 .get("dtype")
287 .and_then(|v| v.as_str())
288 .ok_or_else(|| shrew_core::Error::msg(format!("safetensors: '{key}' missing dtype")))?
289 .to_string();
290
291 let shape_arr = tensor_obj
292 .get("shape")
293 .and_then(|v| v.as_array())
294 .ok_or_else(|| shrew_core::Error::msg(format!("safetensors: '{key}' missing shape")))?;
295
296 let shape: Vec<usize> = shape_arr
297 .iter()
298 .map(|v| v.as_u64().unwrap_or(0) as usize)
299 .collect();
300
301 let offsets = tensor_obj
302 .get("data_offsets")
303 .and_then(|v| v.as_array())
304 .ok_or_else(|| {
305 shrew_core::Error::msg(format!("safetensors: '{key}' missing data_offsets"))
306 })?;
307
308 if offsets.len() != 2 {
309 return Err(shrew_core::Error::msg(format!(
310 "safetensors: '{key}' data_offsets must have exactly 2 elements"
311 )));
312 }
313
314 let start = offsets[0].as_u64().unwrap_or(0) as usize;
315 let end = offsets[1].as_u64().unwrap_or(0) as usize;
316
317 entries.push(ParsedEntry {
318 name: key.clone(),
319 dtype_str,
320 shape,
321 data_offset_start: start,
322 data_offset_end: end,
323 });
324 }
325
326 Ok(entries)
327}
328
329pub fn write_safetensors<B: Backend>(
335 writer: &mut impl Write,
336 tensors: &[(String, Tensor<B>)],
337) -> shrew_core::Result<()> {
338 let mut all_data: Vec<u8> = Vec::new();
340 let mut metas: Vec<TensorMeta> = Vec::with_capacity(tensors.len());
341
342 for (name, tensor) in tensors {
343 let bytes = tensor_to_bytes(tensor)?;
344 let start = all_data.len();
345 let end = start + bytes.len();
346 all_data.extend_from_slice(&bytes);
347
348 metas.push(TensorMeta {
349 name: name.clone(),
350 dtype: tensor.dtype(),
351 shape: tensor.dims().to_vec(),
352 data_offset_start: start,
353 data_offset_end: end,
354 });
355 }
356
357 let mut metadata = HashMap::new();
359 metadata.insert("format".to_string(), "shrew".to_string());
360 let header_json = build_header_json(&metas, Some(&metadata));
361 let header_bytes = header_json.as_bytes();
362
363 let header_size = header_bytes.len() as u64;
365 writer
366 .write_all(&header_size.to_le_bytes())
367 .map_err(io_err)?;
368
369 writer.write_all(header_bytes).map_err(io_err)?;
371
372 writer.write_all(&all_data).map_err(io_err)?;
374
375 Ok(())
376}
377
378pub fn read_safetensors<B: Backend>(
380 reader: &mut impl Read,
381 device: &B::Device,
382) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
383 let mut size_buf = [0u8; 8];
385 reader.read_exact(&mut size_buf).map_err(io_err)?;
386 let header_size = u64::from_le_bytes(size_buf) as usize;
387
388 if header_size > 100_000_000 {
390 return Err(shrew_core::Error::msg(format!(
391 "safetensors: header size {header_size} bytes is unreasonably large"
392 )));
393 }
394
395 let mut header_bytes = vec![0u8; header_size];
397 reader.read_exact(&mut header_bytes).map_err(io_err)?;
398 let header_str = std::str::from_utf8(&header_bytes)
399 .map_err(|e| shrew_core::Error::msg(format!("safetensors: invalid UTF-8 header: {e}")))?;
400
401 let entries = parse_header(header_str)?;
403
404 let max_offset = entries.iter().map(|e| e.data_offset_end).max().unwrap_or(0);
406 let mut all_data = vec![0u8; max_offset];
407 if max_offset > 0 {
408 reader.read_exact(&mut all_data).map_err(io_err)?;
409 }
410
411 let mut tensors = Vec::with_capacity(entries.len());
413 for entry in &entries {
414 let dtype = st_to_dtype(&entry.dtype_str)?;
415 let raw = &all_data[entry.data_offset_start..entry.data_offset_end];
416 let tensor = tensor_from_bytes::<B>(raw, entry.shape.clone(), dtype, device)?;
417 tensors.push((entry.name.clone(), tensor));
418 }
419
420 Ok(tensors)
421}
422
423fn io_err(e: std::io::Error) -> shrew_core::Error {
424 shrew_core::Error::msg(format!("IO error: {e}"))
425}
426
427pub fn save<B: Backend>(
442 path: impl AsRef<Path>,
443 tensors: &[(String, Tensor<B>)],
444) -> shrew_core::Result<()> {
445 let file = File::create(path.as_ref()).map_err(io_err)?;
446 let mut writer = BufWriter::new(file);
447 write_safetensors(&mut writer, tensors)?;
448 writer.flush().map_err(io_err)?;
449 Ok(())
450}
451
452pub fn load<B: Backend>(
464 path: impl AsRef<Path>,
465 device: &B::Device,
466) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
467 let file = File::open(path.as_ref()).map_err(io_err)?;
468 let mut reader = BufReader::new(file);
469 read_safetensors(&mut reader, device)
470}
471
472pub fn to_bytes<B: Backend>(tensors: &[(String, Tensor<B>)]) -> shrew_core::Result<Vec<u8>> {
478 let mut buf = Vec::new();
479 write_safetensors(&mut buf, tensors)?;
480 Ok(buf)
481}
482
483pub fn from_bytes<B: Backend>(
485 data: &[u8],
486 device: &B::Device,
487) -> shrew_core::Result<Vec<(String, Tensor<B>)>> {
488 let mut cursor = std::io::Cursor::new(data);
489 read_safetensors(&mut cursor, device)
490}
491
492pub fn save_module<B: Backend>(
507 path: impl AsRef<Path>,
508 module: &dyn shrew_nn::Module<B>,
509) -> shrew_core::Result<()> {
510 let named = module.named_parameters();
511 save(path, &named)
512}
513
514pub fn load_state_dict<B: Backend>(
518 path: impl AsRef<Path>,
519 device: &B::Device,
520) -> shrew_core::Result<HashMap<String, Tensor<B>>> {
521 let tensors = load(path, device)?;
522 Ok(tensors.into_iter().collect())
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};

    type CpuTensor = Tensor<CpuBackend>;

    // In-memory round-trip of a 2x2 F32 tensor: name, dims, dtype, and
    // values must all survive serialization.
    #[test]
    fn test_roundtrip_f32() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0, 4.0], (2, 2), DType::F32, &dev).unwrap();

        let tensors = vec![("weight".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].0, "weight");
        assert_eq!(loaded[0].1.dims(), &[2, 2]);
        assert_eq!(loaded[0].1.dtype(), DType::F32);

        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // F64 round-trips bit-exactly: pi and e must compare equal, not just
    // approximately.
    #[test]
    fn test_roundtrip_f64() {
        let dev = CpuDevice;
        let vals = vec![std::f64::consts::PI, std::f64::consts::E, 0.0, -1.5];
        let t = CpuTensor::from_f64_slice(&vals, (4,), DType::F64, &dev).unwrap();

        let tensors = vec![("precision".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // U8 covers the single-byte (no endianness) code path, including the
    // extremes 0 and 255.
    #[test]
    fn test_roundtrip_u8() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[0.0, 128.0, 255.0], (3,), DType::U8, &dev).unwrap();

        let tensors = vec![("pixels".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded[0].1.dtype(), DType::U8);
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // Multiple tensors in one file: both entries must come back with
    // their own names and shapes intact.
    #[test]
    fn test_roundtrip_multiple_tensors() {
        let dev = CpuDevice;
        let w =
            CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], (2, 3), DType::F32, &dev)
                .unwrap();
        let b = CpuTensor::from_f64_slice(&[0.1, 0.2, 0.3], (3,), DType::F32, &dev).unwrap();

        let tensors = vec![
            ("layer.weight".to_string(), w.clone()),
            ("layer.bias".to_string(), b.clone()),
        ];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded.len(), 2);

        let map: HashMap<String, CpuTensor> = loaded.into_iter().collect();
        assert!(map.contains_key("layer.weight"));
        assert!(map.contains_key("layer.bias"));
        assert_eq!(map["layer.weight"].dims(), &[2, 3]);
        assert_eq!(map["layer.bias"].dims(), &[3]);
    }

    // Rank-3 shapes round-trip; F32 storage means values are compared
    // with a tolerance rather than exactly.
    #[test]
    fn test_3d_tensor_roundtrip() {
        let dev = CpuDevice;
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let t = CpuTensor::from_f64_slice(&data, (2, 3, 4), DType::F32, &dev).unwrap();

        let tensors = vec![("volume".to_string(), t.clone())];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &dev).unwrap();

        assert_eq!(loaded[0].1.dims(), &[2, 3, 4]);
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        for (a, b) in orig.iter().zip(restored.iter()) {
            assert!((a - b).abs() < 1e-6, "mismatch: {a} vs {b}");
        }
    }

    // An empty tensor list still produces a valid (header-only) file.
    #[test]
    fn test_empty() {
        let tensors: Vec<(String, CpuTensor)> = vec![];
        let bytes = to_bytes(&tensors).unwrap();
        let loaded = from_bytes::<CpuBackend>(&bytes, &CpuDevice).unwrap();
        assert_eq!(loaded.len(), 0);
    }

    // save()/load() round-trip through an actual file on disk.
    #[test]
    fn test_file_roundtrip() {
        let dev = CpuDevice;
        let t = CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0], (3,), DType::F32, &dev).unwrap();
        let tensors = vec![("test_param".to_string(), t.clone())];

        let path = std::env::temp_dir().join("shrew_test_safetensors.safetensors");
        save(&path, &tensors).unwrap();
        let loaded = load::<CpuBackend>(&path, &dev).unwrap();
        std::fs::remove_file(&path).ok();

        assert_eq!(loaded.len(), 1);
        let orig = t.to_f64_vec().unwrap();
        let restored = loaded[0].1.to_f64_vec().unwrap();
        assert_eq!(orig, restored);
    }

    // The hand-built header must be valid JSON with the expected
    // dtype/shape/data_offsets fields.
    #[test]
    fn test_header_format() {
        let metas = vec![TensorMeta {
            name: "layer.weight".to_string(),
            dtype: DType::F32,
            shape: vec![3, 4],
            data_offset_start: 0,
            data_offset_end: 48,
        }];
        let json = build_header_json(&metas, None);

        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        let entry = parsed.get("layer.weight").unwrap();
        assert_eq!(entry["dtype"].as_str().unwrap(), "F32");
        assert_eq!(entry["shape"][0].as_u64().unwrap(), 3);
        assert_eq!(entry["shape"][1].as_u64().unwrap(), 4);
        assert_eq!(entry["data_offsets"][0].as_u64().unwrap(), 0);
        assert_eq!(entry["data_offsets"][1].as_u64().unwrap(), 48);
    }

    // Quote, backslash, and newline escaping in header keys/values.
    #[test]
    fn test_json_escape() {
        assert_eq!(json_escape("hello"), "\"hello\"");
        assert_eq!(json_escape("a\"b"), "\"a\\\"b\"");
        assert_eq!(json_escape("a\\b"), "\"a\\\\b\"");
        assert_eq!(json_escape("a\nb"), "\"a\\nb\"");
    }

    // load_state_dict() returns the tensors keyed by name.
    #[test]
    fn test_state_dict_roundtrip() {
        let dev = CpuDevice;
        let w = CpuTensor::from_f64_slice(&[1.0, 2.0], (1, 2), DType::F32, &dev).unwrap();
        let b = CpuTensor::from_f64_slice(&[0.5], (1,), DType::F32, &dev).unwrap();

        let tensors = vec![("fc.weight".to_string(), w), ("fc.bias".to_string(), b)];

        let path = std::env::temp_dir().join("shrew_test_state_dict.safetensors");
        save(&path, &tensors).unwrap();
        let sd = load_state_dict::<CpuBackend>(&path, &dev).unwrap();
        std::fs::remove_file(&path).ok();

        assert!(sd.contains_key("fc.weight"));
        assert!(sd.contains_key("fc.bias"));
        assert_eq!(sd["fc.weight"].dims(), &[1, 2]);
        assert_eq!(sd["fc.bias"].dims(), &[1]);
    }

    // save_module() persists a Linear layer's named parameters
    // ("weight"/"bias") with their expected shapes.
    #[test]
    fn test_save_module_linear() {
        use shrew_nn::{Linear, Module};

        let dev = CpuDevice;
        let linear = Linear::<CpuBackend>::new(3, 2, true, DType::F32, &dev).unwrap();

        let path = std::env::temp_dir().join("shrew_test_module_save.safetensors");
        save_module(&path, &linear).unwrap();
        let loaded = load::<CpuBackend>(&path, &dev).unwrap();
        std::fs::remove_file(&path).ok();

        let map: HashMap<String, CpuTensor> = loaded.into_iter().collect();
        assert!(map.contains_key("weight"), "missing 'weight' key");
        assert!(map.contains_key("bias"), "missing 'bias' key");
        assert_eq!(map["weight"].dims(), &[2, 3]);
        assert_eq!(map["bias"].dims(), &[1, 2]);
    }
}