use std::collections::HashMap;
use std::fs;
use std::path::Path;

use shrew_core::backend::Backend;
use shrew_core::dtype::DType;
use shrew_core::error::Result;
use shrew_core::tensor::Tensor;

use shrew_nn::Module;

/// ONNX IR version written into exported models.
const ONNX_IR_VERSION: i64 = 9;
/// Opset version declared by exported models.
const ONNX_OPSET_VERSION: i64 = 17;
/// Empty domain string selects the standard ONNX operator set.
const ONNX_DOMAIN: &str = "";

// ONNX `TensorProto.DataType` codes.
const ONNX_FLOAT: i32 = 1;
const ONNX_DOUBLE: i32 = 11;
const ONNX_FLOAT16: i32 = 10;
const ONNX_BFLOAT16: i32 = 16;
const ONNX_INT8: i32 = 3;
const ONNX_UINT8: i32 = 2;
const ONNX_INT32: i32 = 6;
const ONNX_INT64: i32 = 7;
const ONNX_UINT32: i32 = 12;

/// Minimal protobuf wire-format encoder; implements just the pieces the ONNX
/// writer needs.
struct PbEncoder {
    buf: Vec<u8>,
}

impl PbEncoder {
    fn new() -> Self {
        Self { buf: Vec::new() }
    }

    fn into_bytes(self) -> Vec<u8> {
        self.buf
    }

    fn write_varint(&mut self, mut val: u64) {
        loop {
            let byte = (val & 0x7F) as u8;
            val >>= 7;
            if val == 0 {
                self.buf.push(byte);
                break;
            } else {
                self.buf.push(byte | 0x80);
            }
        }
    }

    fn write_tag(&mut self, field: u32, wire_type: u32) {
        self.write_varint(((field as u64) << 3) | wire_type as u64);
    }

    fn write_varint_field(&mut self, field: u32, val: u64) {
        self.write_tag(field, 0);
        self.write_varint(val);
    }

    /// Writes an `int64` field as a plain two's-complement varint (this is not
    /// zigzag `sint64` encoding, despite the name).
    fn write_sint64_field(&mut self, field: u32, val: i64) {
        self.write_varint_field(field, val as u64);
    }

    fn write_bytes_field(&mut self, field: u32, data: &[u8]) {
        self.write_tag(field, 2);
        self.write_varint(data.len() as u64);
        self.buf.extend_from_slice(data);
    }

    fn write_string_field(&mut self, field: u32, val: &str) {
        self.write_bytes_field(field, val.as_bytes());
    }

    fn write_message_field(&mut self, field: u32, encoder: &PbEncoder) {
        self.write_bytes_field(field, &encoder.buf);
    }

    #[allow(dead_code)]
    fn write_float_data(&mut self, field: u32, data: &[f32]) {
        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
        self.write_bytes_field(field, &bytes);
    }

    #[allow(dead_code)]
    fn write_double_data(&mut self, field: u32, data: &[f64]) {
        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
        self.write_bytes_field(field, &bytes);
    }
}

struct PbDecoder<'a> {
    data: &'a [u8],
    pos: usize,
}

impl<'a> PbDecoder<'a> {
    fn new(data: &'a [u8]) -> Self {
        Self { data, pos: 0 }
    }

    fn remaining(&self) -> usize {
        self.data.len() - self.pos
    }

    fn read_varint(&mut self) -> Result<u64> {
        let mut result: u64 = 0;
        let mut shift = 0;
        loop {
            if self.pos >= self.data.len() {
                return Err(shrew_core::Error::msg("protobuf: unexpected end of data"));
            }
            let byte = self.data[self.pos];
            self.pos += 1;
            result |= ((byte & 0x7F) as u64) << shift;
            if byte & 0x80 == 0 {
                break;
            }
            shift += 7;
            if shift > 63 {
                return Err(shrew_core::Error::msg("protobuf: varint too long"));
            }
        }
        Ok(result)
    }

    fn read_tag(&mut self) -> Result<(u32, u32)> {
        let val = self.read_varint()?;
        let field = (val >> 3) as u32;
        let wire_type = (val & 0x7) as u32;
        Ok((field, wire_type))
    }

    fn read_bytes(&mut self) -> Result<&'a [u8]> {
        let len = self.read_varint()? as usize;
        if self.pos + len > self.data.len() {
            return Err(shrew_core::Error::msg("protobuf: bytes field exceeds data"));
        }
        let result = &self.data[self.pos..self.pos + len];
        self.pos += len;
        Ok(result)
    }

    fn read_string(&mut self) -> Result<String> {
        let bytes = self.read_bytes()?;
        String::from_utf8(bytes.to_vec())
            .map_err(|_| shrew_core::Error::msg("protobuf: invalid UTF-8 string"))
    }

    fn skip_field(&mut self, wire_type: u32) -> Result<()> {
        match wire_type {
            0 => {
                self.read_varint()?;
            }
            1 => {
                // 64-bit fixed field.
                if self.remaining() < 8 {
                    return Err(shrew_core::Error::msg("protobuf: unexpected end of data"));
                }
                self.pos += 8;
            }
            2 => {
                self.read_bytes()?;
            }
            5 => {
                // 32-bit fixed field.
                if self.remaining() < 4 {
                    return Err(shrew_core::Error::msg("protobuf: unexpected end of data"));
                }
                self.pos += 4;
            }
            _ => {
                return Err(shrew_core::Error::msg(format!(
                    "protobuf: unsupported wire type {wire_type}"
                )))
            }
        }
        Ok(())
    }
}

/// An ONNX `TensorProto`: a named, typed, shaped payload carried in one of
/// `float_data`, `double_data`, or little-endian `raw_data`.
#[derive(Debug, Clone)]
pub struct OnnxTensor {
    pub name: String,
    pub data_type: i32,
    pub dims: Vec<i64>,
    pub float_data: Vec<f32>,
    pub double_data: Vec<f64>,
    pub raw_data: Vec<u8>,
}

impl OnnxTensor {
    fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            data_type: ONNX_FLOAT,
            dims: Vec::new(),
            float_data: Vec::new(),
            double_data: Vec::new(),
            raw_data: Vec::new(),
        }
    }

    fn encode(&self) -> Vec<u8> {
        let mut enc = PbEncoder::new();
        // TensorProto: dims (1), data_type (2), name (8), raw_data (9).
        for &d in &self.dims {
            enc.write_sint64_field(1, d);
        }
        enc.write_varint_field(2, self.data_type as u64);
        if !self.name.is_empty() {
            enc.write_string_field(8, &self.name);
        }
        // The payload is always serialized as little-endian raw_data.
        if !self.float_data.is_empty() {
            let bytes: Vec<u8> = self
                .float_data
                .iter()
                .flat_map(|v| v.to_le_bytes())
                .collect();
            enc.write_bytes_field(9, &bytes);
        } else if !self.double_data.is_empty() {
            let bytes: Vec<u8> = self
                .double_data
                .iter()
                .flat_map(|v| v.to_le_bytes())
                .collect();
            enc.write_bytes_field(9, &bytes);
        } else if !self.raw_data.is_empty() {
            enc.write_bytes_field(9, &self.raw_data);
        }
        enc.into_bytes()
    }

    fn decode(data: &[u8]) -> Result<Self> {
        let mut dec = PbDecoder::new(data);
        let mut tensor = OnnxTensor::new("");
        while dec.remaining() > 0 {
            let (field, wire_type) = dec.read_tag()?;
            match (field, wire_type) {
                (1, 0) => {
                    let v = dec.read_varint()? as i64;
                    tensor.dims.push(v);
                }
                (1, 2) => {
                    // Packed dims.
                    let bytes = dec.read_bytes()?;
                    let mut sub = PbDecoder::new(bytes);
                    while sub.remaining() > 0 {
                        tensor.dims.push(sub.read_varint()? as i64);
                    }
                }
                (2, 0) => {
                    tensor.data_type = dec.read_varint()? as i32;
                }
                (8, 2) => {
                    tensor.name = dec.read_string()?;
                }
                (9, 2) => {
                    tensor.raw_data = dec.read_bytes()?.to_vec();
                }
                (4, 2) => {
                    // Packed float_data.
                    let bytes = dec.read_bytes()?;
                    for chunk in bytes.chunks_exact(4) {
                        let val = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
                        tensor.float_data.push(val);
                    }
                }
                (4, 5) => {
                    // Unpacked float_data entry.
                    if dec.pos + 4 > dec.data.len() {
                        return Err(shrew_core::Error::msg("protobuf: unexpected end of data"));
                    }
                    let bytes = &dec.data[dec.pos..dec.pos + 4];
                    dec.pos += 4;
                    let val = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
                    tensor.float_data.push(val);
                }
                (10, 2) => {
                    // Packed double_data.
                    let bytes = dec.read_bytes()?;
                    for chunk in bytes.chunks_exact(8) {
                        let val = f64::from_le_bytes([
                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
                            chunk[7],
                        ]);
                        tensor.double_data.push(val);
                    }
                }
                _ => {
                    dec.skip_field(wire_type)?;
                }
            }
        }
        Ok(tensor)
    }

    fn to_f64_vec(&self) -> Vec<f64> {
        if !self.double_data.is_empty() {
            return self.double_data.clone();
        }
        if !self.float_data.is_empty() {
            return self.float_data.iter().map(|&v| v as f64).collect();
        }
        if !self.raw_data.is_empty() {
            match self.data_type {
                ONNX_FLOAT => self
                    .raw_data
                    .chunks_exact(4)
                    .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f64)
                    .collect(),
                ONNX_DOUBLE => self
                    .raw_data
                    .chunks_exact(8)
                    .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
                    .collect(),
                ONNX_FLOAT16 => self
                    .raw_data
                    .chunks_exact(2)
                    .map(|c| {
                        let bits = u16::from_le_bytes([c[0], c[1]]);
                        half::f16::from_bits(bits).to_f64()
                    })
                    .collect(),
                _ => Vec::new(),
            }
        } else {
            Vec::new()
        }
    }
}

fn dtype_to_onnx(dtype: DType) -> i32 {
    match dtype {
        DType::F32 => ONNX_FLOAT,
        DType::F64 => ONNX_DOUBLE,
        DType::F16 => ONNX_FLOAT16,
        DType::BF16 => ONNX_BFLOAT16,
        DType::U8 => ONNX_UINT8,
        DType::U32 => ONNX_UINT32,
        DType::I64 => ONNX_INT64,
    }
}

fn onnx_to_dtype(onnx_type: i32) -> Result<DType> {
    match onnx_type {
        ONNX_FLOAT => Ok(DType::F32),
        ONNX_DOUBLE => Ok(DType::F64),
        ONNX_FLOAT16 => Ok(DType::F16),
        ONNX_BFLOAT16 => Ok(DType::BF16),
        ONNX_UINT8 => Ok(DType::U8),
        ONNX_UINT32 => Ok(DType::U32),
        ONNX_INT64 => Ok(DType::I64),
        // Approximate mappings: the backend has no dedicated i8/i32 dtypes.
        ONNX_INT8 => Ok(DType::U8),
        ONNX_INT32 => Ok(DType::I64),
        _ => Err(shrew_core::Error::msg(format!(
            "unsupported ONNX data type: {onnx_type}"
        ))),
    }
}

#[derive(Debug, Clone)]
pub struct OnnxNode {
    pub inputs: Vec<String>,
    pub outputs: Vec<String>,
    pub op_type: String,
    pub name: String,
    pub attributes: HashMap<String, OnnxAttribute>,
}

#[derive(Debug, Clone)]
pub enum OnnxAttribute {
    Int(i64),
    Float(f32),
    String(String),
    Ints(Vec<i64>),
    Floats(Vec<f32>),
}

impl OnnxNode {
    fn encode(&self) -> Vec<u8> {
        let mut enc = PbEncoder::new();
        for input in &self.inputs {
            enc.write_string_field(1, input);
        }
        for output in &self.outputs {
            enc.write_string_field(2, output);
        }
        if !self.name.is_empty() {
            enc.write_string_field(3, &self.name);
        }
        enc.write_string_field(4, &self.op_type);
        for (key, val) in &self.attributes {
            let attr = encode_attribute(key, val);
            enc.write_message_field(5, &attr);
        }
        enc.into_bytes()
    }
}

fn encode_attribute(name: &str, val: &OnnxAttribute) -> PbEncoder {
    let mut enc = PbEncoder::new();
    // AttributeProto: name (1), type (20), then the matching value field.
    enc.write_string_field(1, name);
    match val {
        OnnxAttribute::Int(i) => {
            enc.write_varint_field(20, 2); // AttributeType::INT
            enc.write_sint64_field(3, *i);
        }
        OnnxAttribute::Float(f) => {
            enc.write_varint_field(20, 1); // AttributeType::FLOAT
            enc.write_tag(2, 5);
            enc.buf.extend_from_slice(&f.to_le_bytes());
        }
        OnnxAttribute::String(s) => {
            enc.write_varint_field(20, 3); // AttributeType::STRING
            enc.write_bytes_field(4, s.as_bytes());
        }
        OnnxAttribute::Ints(ints) => {
            enc.write_varint_field(20, 7); // AttributeType::INTS
            for &i in ints {
                enc.write_sint64_field(8, i);
            }
        }
        OnnxAttribute::Floats(floats) => {
            enc.write_varint_field(20, 6); // AttributeType::FLOATS
            for &f in floats {
                enc.write_tag(7, 5);
                enc.buf.extend_from_slice(&f.to_le_bytes());
            }
        }
    }
    enc
}

#[derive(Debug, Clone)]
pub struct OnnxModel {
    pub producer_name: String,
    pub producer_version: String,
    pub graph_name: String,
    pub nodes: Vec<OnnxNode>,
    pub initializers: Vec<OnnxTensor>,
    pub inputs: Vec<(String, Vec<i64>, i32)>,
    pub outputs: Vec<(String, Vec<i64>, i32)>,
}

impl OnnxModel {
    pub fn new(graph_name: &str) -> Self {
        Self {
            producer_name: "Shrew".to_string(),
            producer_version: "0.1.0".to_string(),
            graph_name: graph_name.to_string(),
            nodes: Vec::new(),
            initializers: Vec::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
        }
    }

    pub fn to_bytes(&self) -> Vec<u8> {
        // GraphProto: node (1), name (2), initializer (5), input (11), output (12).
        let mut graph = PbEncoder::new();

        for node in &self.nodes {
            let node_bytes = node.encode();
            graph.write_bytes_field(1, &node_bytes);
        }

        graph.write_string_field(2, &self.graph_name);

        for init in &self.initializers {
            let tensor_bytes = init.encode();
            graph.write_bytes_field(5, &tensor_bytes);
        }

        for (name, dims, dtype) in &self.inputs {
            let vi = encode_value_info(name, dims, *dtype);
            graph.write_message_field(11, &vi);
        }

        for (name, dims, dtype) in &self.outputs {
            let vi = encode_value_info(name, dims, *dtype);
            graph.write_message_field(12, &vi);
        }

        // ModelProto: ir_version (1), producer (2, 3), graph (7), opset_import (8).
        let mut model = PbEncoder::new();
        model.write_varint_field(1, ONNX_IR_VERSION as u64);
        model.write_string_field(2, &self.producer_name);
        model.write_string_field(3, &self.producer_version);
        model.write_message_field(7, &graph);

        let mut opset = PbEncoder::new();
        opset.write_string_field(1, ONNX_DOMAIN);
        opset.write_varint_field(2, ONNX_OPSET_VERSION as u64);
        model.write_message_field(8, &opset);

        model.into_bytes()
    }

    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let bytes = self.to_bytes();
        fs::write(path.as_ref(), &bytes)
            .map_err(|e| shrew_core::Error::msg(format!("failed to write ONNX file: {e}")))
    }
}

fn encode_value_info(name: &str, dims: &[i64], data_type: i32) -> PbEncoder {
    let mut vi = PbEncoder::new();
    vi.write_string_field(1, name);

    let mut type_proto = PbEncoder::new();
    let mut tensor_type = PbEncoder::new();
    tensor_type.write_varint_field(1, data_type as u64);

    let mut shape = PbEncoder::new();
    for &d in dims {
        let mut dim = PbEncoder::new();
        if d >= 0 {
            dim.write_sint64_field(1, d);
        } else {
            // Negative sizes are exported as a symbolic ("dynamic") dimension.
            dim.write_string_field(2, "dynamic");
        }
        shape.write_message_field(1, &dim);
    }
    tensor_type.write_message_field(2, &shape);
    type_proto.write_message_field(1, &tensor_type);
    vi.write_message_field(2, &type_proto);

    vi
}

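/// Exports a module's named parameters as ONNX initializers.
///
/// The resulting graph carries no compute nodes: it declares a single
/// `"input"` value with `input_shape`, stores every parameter returned by
/// `Module::named_parameters` as an initializer (non-f32/f64 parameters are
/// downcast to f32), and declares a placeholder `"output"` with a dynamic
/// shape.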
pub fn export_weights<P, B, M>(
    path: P,
    module: &M,
    model_name: &str,
    input_shape: &[i64],
) -> Result<()>
where
    P: AsRef<Path>,
    B: Backend,
    M: Module<B>,
{
    let named = module.named_parameters();

    let mut model = OnnxModel::new(model_name);

    model
        .inputs
        .push(("input".to_string(), input_shape.to_vec(), ONNX_FLOAT));

    for (name, tensor) in &named {
        let data = tensor.to_f64_vec()?;
        let dims: Vec<i64> = tensor.dims().iter().map(|&d| d as i64).collect();

        let mut onnx_tensor = OnnxTensor::new(name);
        onnx_tensor.data_type = dtype_to_onnx(tensor.dtype());
        onnx_tensor.dims = dims;

        match tensor.dtype() {
            DType::F32 => {
                onnx_tensor.float_data = data.iter().map(|&v| v as f32).collect();
            }
            DType::F64 => {
                onnx_tensor.double_data = data;
            }
            _ => {
                // Other dtypes are stored as f32 for portability.
                onnx_tensor.data_type = ONNX_FLOAT;
                onnx_tensor.float_data = data.iter().map(|&v| v as f32).collect();
            }
        }

        model.initializers.push(onnx_tensor);
    }

    // Placeholder output with a dynamic shape; this export carries weights,
    // not a traced compute graph.
    model.outputs.push(("output".to_string(), vec![-1], ONNX_FLOAT));

    model.save(path)
}

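/// Exports a flat list of named tensors as ONNX initializers (no compute
/// nodes and no declared graph inputs or outputs).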
pub fn export_tensors<P, B>(
    path: P,
    tensors: &[(String, Tensor<B>)],
    model_name: &str,
) -> Result<()>
where
    P: AsRef<Path>,
    B: Backend,
{
    let mut model = OnnxModel::new(model_name);

    for (name, tensor) in tensors {
        let data = tensor.to_f64_vec()?;
        let dims: Vec<i64> = tensor.dims().iter().map(|&d| d as i64).collect();

        let mut onnx_tensor = OnnxTensor::new(name);
        onnx_tensor.data_type = dtype_to_onnx(tensor.dtype());
        onnx_tensor.dims = dims;

        match tensor.dtype() {
            DType::F32 => {
                onnx_tensor.float_data = data.iter().map(|&v| v as f32).collect();
            }
            DType::F64 => {
                onnx_tensor.double_data = data;
            }
            _ => {
                onnx_tensor.data_type = ONNX_FLOAT;
                onnx_tensor.float_data = data.iter().map(|&v| v as f32).collect();
            }
        }

        model.initializers.push(onnx_tensor);
    }

    model.save(path)
}

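/// Reads an ONNX file and returns its initializers as a name-to-tensor map.
///
/// Only the graph's initializer list is consulted; nodes and value infos are
/// ignored, so this is a weights-only loader (the counterpart of
/// `export_weights` / `export_tensors`).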
pub fn load_onnx_weights<B: Backend>(
    path: impl AsRef<Path>,
    device: &B::Device,
) -> Result<HashMap<String, Tensor<B>>> {
    let bytes = fs::read(path.as_ref())
        .map_err(|e| shrew_core::Error::msg(format!("failed to read ONNX file: {e}")))?;

    load_onnx_weights_from_bytes::<B>(&bytes, device)
}

pub fn load_onnx_weights_from_bytes<B: Backend>(
    data: &[u8],
    device: &B::Device,
) -> Result<HashMap<String, Tensor<B>>> {
    let mut dec = PbDecoder::new(data);
    let mut result = HashMap::new();

    while dec.remaining() > 0 {
        let (field, wire_type) = dec.read_tag()?;
        match (field, wire_type) {
            (7, 2) => {
                let graph_bytes = dec.read_bytes()?;
                let tensors = parse_graph_initializers::<B>(graph_bytes, device)?;
                result.extend(tensors);
            }
            _ => {
                dec.skip_field(wire_type)?;
            }
        }
    }

    Ok(result)
}

fn parse_graph_initializers<B: Backend>(
    data: &[u8],
    device: &B::Device,
) -> Result<HashMap<String, Tensor<B>>> {
    let mut dec = PbDecoder::new(data);
    let mut result = HashMap::new();

    while dec.remaining() > 0 {
        let (field, wire_type) = dec.read_tag()?;
        match (field, wire_type) {
            (5, 2) => {
                let tensor_bytes = dec.read_bytes()?;
                let onnx_tensor = OnnxTensor::decode(tensor_bytes)?;

                if !onnx_tensor.name.is_empty() {
                    let dtype = onnx_to_dtype(onnx_tensor.data_type)?;
                    let shape: Vec<usize> = onnx_tensor.dims.iter().map(|&d| d as usize).collect();
                    let f64_data = onnx_tensor.to_f64_vec();

                    if !f64_data.is_empty() {
                        let tensor = Tensor::<B>::from_f64_slice(&f64_data, shape, dtype, device)?;
                        result.insert(onnx_tensor.name.clone(), tensor);
                    }
                }
            }
            _ => {
                dec.skip_field(wire_type)?;
            }
        }
    }

    Ok(result)
}

/// A parsed ONNX graph: its nodes, initializers, and declared I/O names.
#[derive(Debug, Clone)]
pub struct OnnxGraph {
    pub nodes: Vec<OnnxNode>,
    pub initializer_protos: Vec<OnnxTensor>,
    pub input_names: Vec<String>,
    pub output_names: Vec<String>,
    pub name: String,
}

fn decode_node(data: &[u8]) -> Result<OnnxNode> {
    let mut dec = PbDecoder::new(data);
    let mut node = OnnxNode {
        inputs: Vec::new(),
        outputs: Vec::new(),
        op_type: String::new(),
        name: String::new(),
        attributes: HashMap::new(),
    };
    while dec.remaining() > 0 {
        let (field, wire_type) = dec.read_tag()?;
        match (field, wire_type) {
            (1, 2) => node.inputs.push(dec.read_string()?),
            (2, 2) => node.outputs.push(dec.read_string()?),
            (3, 2) => node.name = dec.read_string()?,
            (4, 2) => node.op_type = dec.read_string()?,
            (5, 2) => {
                let attr_bytes = dec.read_bytes()?;
                let (key, val) = decode_attribute(attr_bytes)?;
                node.attributes.insert(key, val);
            }
            _ => dec.skip_field(wire_type)?,
        }
    }
    Ok(node)
}

fn decode_attribute(data: &[u8]) -> Result<(String, OnnxAttribute)> {
    let mut dec = PbDecoder::new(data);
    let mut name = String::new();
    let mut attr_type: u64 = 0;
    let mut int_val: i64 = 0;
    let mut float_val: f32 = 0.0;
    let mut string_val = Vec::new();
    let mut ints_val: Vec<i64> = Vec::new();
    let mut floats_val: Vec<f32> = Vec::new();
    while dec.remaining() > 0 {
        let (field, wire_type) = dec.read_tag()?;
        match (field, wire_type) {
            (1, 2) => name = dec.read_string()?,
            (20, 0) => attr_type = dec.read_varint()?,
            (3, 0) => int_val = dec.read_varint()? as i64,
            (2, 5) => {
                if dec.pos + 4 > dec.data.len() {
                    return Err(shrew_core::Error::msg("attribute: unexpected end"));
                }
                let b = &dec.data[dec.pos..dec.pos + 4];
                float_val = f32::from_le_bytes([b[0], b[1], b[2], b[3]]);
                dec.pos += 4;
            }
            (4, 2) => string_val = dec.read_bytes()?.to_vec(),
            (7, 5) => {
                if dec.pos + 4 > dec.data.len() {
                    return Err(shrew_core::Error::msg("attribute: unexpected end"));
                }
                let b = &dec.data[dec.pos..dec.pos + 4];
                floats_val.push(f32::from_le_bytes([b[0], b[1], b[2], b[3]]));
                dec.pos += 4;
            }
            (7, 2) => {
                // Packed floats.
                let bytes = dec.read_bytes()?;
                for c in bytes.chunks_exact(4) {
                    floats_val.push(f32::from_le_bytes([c[0], c[1], c[2], c[3]]));
                }
            }
            (8, 0) => ints_val.push(dec.read_varint()? as i64),
            (8, 2) => {
                // Packed ints.
                let bytes = dec.read_bytes()?;
                let mut sub = PbDecoder::new(bytes);
                while sub.remaining() > 0 {
                    ints_val.push(sub.read_varint()? as i64);
                }
            }
            _ => dec.skip_field(wire_type)?,
        }
    }
    let val = match attr_type {
        1 => OnnxAttribute::Float(float_val),
        2 => OnnxAttribute::Int(int_val),
        3 => OnnxAttribute::String(String::from_utf8(string_val).unwrap_or_default()),
        6 => OnnxAttribute::Floats(floats_val),
        7 => OnnxAttribute::Ints(ints_val),
        // Unknown or missing type: fall back to the int value.
        _ => OnnxAttribute::Int(int_val),
    };
    Ok((name, val))
}

fn parse_graph_proto(data: &[u8]) -> Result<OnnxGraph> {
    let mut dec = PbDecoder::new(data);
    let mut graph = OnnxGraph {
        nodes: Vec::new(),
        initializer_protos: Vec::new(),
        input_names: Vec::new(),
        output_names: Vec::new(),
        name: String::new(),
    };
    while dec.remaining() > 0 {
        let (field, wire_type) = dec.read_tag()?;
        match (field, wire_type) {
            (1, 2) => {
                let node_bytes = dec.read_bytes()?;
                graph.nodes.push(decode_node(node_bytes)?);
            }
            (2, 2) => graph.name = dec.read_string()?,
            (5, 2) => {
                let tensor_bytes = dec.read_bytes()?;
                graph
                    .initializer_protos
                    .push(OnnxTensor::decode(tensor_bytes)?);
            }
            (11, 2) => {
                let vi_bytes = dec.read_bytes()?;
                let name = extract_value_info_name(vi_bytes)?;
                graph.input_names.push(name);
            }
            (12, 2) => {
                let vi_bytes = dec.read_bytes()?;
                let name = extract_value_info_name(vi_bytes)?;
                graph.output_names.push(name);
            }
            _ => dec.skip_field(wire_type)?,
        }
    }
    Ok(graph)
}

fn extract_value_info_name(data: &[u8]) -> Result<String> {
    let mut dec = PbDecoder::new(data);
    while dec.remaining() > 0 {
        let (field, wire_type) = dec.read_tag()?;
        if field == 1 && wire_type == 2 {
            return dec.read_string();
        }
        dec.skip_field(wire_type)?;
    }
    Ok(String::new())
}

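/// Reads an ONNX file and parses its graph: nodes, initializers, and the
/// declared input/output names.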
pub fn load_onnx_graph(path: impl AsRef<Path>) -> Result<OnnxGraph> {
    let bytes = fs::read(path.as_ref())
        .map_err(|e| shrew_core::Error::msg(format!("failed to read ONNX file: {e}")))?;
    load_onnx_graph_from_bytes(&bytes)
}

pub fn load_onnx_graph_from_bytes(data: &[u8]) -> Result<OnnxGraph> {
    let mut dec = PbDecoder::new(data);
    while dec.remaining() > 0 {
        let (field, wire_type) = dec.read_tag()?;
        if field == 7 && wire_type == 2 {
            let graph_bytes = dec.read_bytes()?;
            return parse_graph_proto(graph_bytes);
        }
        dec.skip_field(wire_type)?;
    }
    Err(shrew_core::Error::msg("ONNX file contains no graph"))
}

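/// Executes a parsed ONNX graph with a simple interpreter.
///
/// Initializers are materialized first, then the caller-provided `inputs` are
/// added (overriding initializers of the same name), nodes are executed in
/// the order they appear (ONNX requires graphs to be topologically sorted),
/// and the tensors named in `output_names` are returned.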
pub fn run_onnx_graph<B: Backend>(
    graph: &OnnxGraph,
    inputs: &HashMap<String, Tensor<B>>,
    device: &B::Device,
) -> Result<HashMap<String, Tensor<B>>> {
    let mut env: HashMap<String, Tensor<B>> = HashMap::new();

    // Materialize initializers (weights) into the environment.
    for init in &graph.initializer_protos {
        if init.name.is_empty() {
            continue;
        }
        let dtype = onnx_to_dtype(init.data_type)?;
        let shape: Vec<usize> = init.dims.iter().map(|&d| d as usize).collect();
        let f64_data = init.to_f64_vec();
        if !f64_data.is_empty() {
            let tensor = Tensor::<B>::from_f64_slice(&f64_data, shape, dtype, device)?;
            env.insert(init.name.clone(), tensor);
        }
    }

    // Caller-provided inputs take precedence over initializers of the same name.
    for (name, tensor) in inputs {
        env.insert(name.clone(), tensor.clone());
    }

    for node in &graph.nodes {
        execute_node(node, &mut env, device)?;
    }

    let mut outputs = HashMap::new();
    for name in &graph.output_names {
        if let Some(t) = env.get(name) {
            outputs.insert(name.clone(), t.clone());
        }
    }
    Ok(outputs)
}

fn get_tensor<'a, B: Backend>(
    env: &'a HashMap<String, Tensor<B>>,
    name: &str,
) -> Result<&'a Tensor<B>> {
    env.get(name)
        .ok_or_else(|| shrew_core::Error::msg(format!("ONNX runtime: tensor '{name}' not found")))
}

fn attr_i(node: &OnnxNode, key: &str, default: i64) -> i64 {
    match node.attributes.get(key) {
        Some(OnnxAttribute::Int(v)) => *v,
        _ => default,
    }
}

fn attr_ints(node: &OnnxNode, key: &str) -> Vec<i64> {
    match node.attributes.get(key) {
        Some(OnnxAttribute::Ints(v)) => v.clone(),
        _ => Vec::new(),
    }
}

fn attr_f(node: &OnnxNode, key: &str, default: f32) -> f32 {
    match node.attributes.get(key) {
        Some(OnnxAttribute::Float(v)) => *v,
        _ => default,
    }
}

fn execute_node<B: Backend>(
    node: &OnnxNode,
    env: &mut HashMap<String, Tensor<B>>,
    device: &B::Device,
) -> Result<()> {
    match node.op_type.as_str() {
        "Add" => {
            let a = get_tensor(env, &node.inputs[0])?;
            let b = get_tensor(env, &node.inputs[1])?;
            let out = a.add(b)?;
            env.insert(node.outputs[0].clone(), out);
        }
        "Sub" => {
            let a = get_tensor(env, &node.inputs[0])?;
            let b = get_tensor(env, &node.inputs[1])?;
            let out = a.sub(b)?;
            env.insert(node.outputs[0].clone(), out);
        }
        "Mul" => {
            let a = get_tensor(env, &node.inputs[0])?;
            let b = get_tensor(env, &node.inputs[1])?;
            let out = a.mul(b)?;
            env.insert(node.outputs[0].clone(), out);
        }
        "Div" => {
            let a = get_tensor(env, &node.inputs[0])?;
            let b = get_tensor(env, &node.inputs[1])?;
            let out = a.div(b)?;
            env.insert(node.outputs[0].clone(), out);
        }
        "Pow" => {
            let a = get_tensor(env, &node.inputs[0])?;
            let b = get_tensor(env, &node.inputs[1])?;
            let exp_val = b.to_f64_vec()?;
            if exp_val.len() == 1 {
                let out = a.powf(exp_val[0])?;
                env.insert(node.outputs[0].clone(), out);
            } else {
                return Err(shrew_core::Error::msg(
                    "ONNX Pow: only scalar exponent supported",
                ));
            }
        }

        "MatMul" => {
            let a = get_tensor(env, &node.inputs[0])?;
            let b = get_tensor(env, &node.inputs[1])?;
            let out = a.matmul(b)?;
            env.insert(node.outputs[0].clone(), out);
        }
        "Gemm" => {
            // Y = alpha * A' * B' + beta * C, where A'/B' are optionally transposed.
            let alpha = attr_f(node, "alpha", 1.0) as f64;
            let beta = attr_f(node, "beta", 1.0) as f64;
            let trans_a = attr_i(node, "transA", 0) != 0;
            let trans_b = attr_i(node, "transB", 0) != 0;

            let mut a = get_tensor(env, &node.inputs[0])?.clone();
            let mut b = get_tensor(env, &node.inputs[1])?.clone();

            if trans_a {
                a = a.t()?;
            }
            if trans_b {
                b = b.t()?;
            }

            let mut out = a.matmul(&b)?;
            if (alpha - 1.0).abs() > 1e-7 {
                out = out.affine(alpha, 0.0)?;
            }
            if node.inputs.len() > 2 && !node.inputs[2].is_empty() {
                let c = get_tensor(env, &node.inputs[2])?;
                if (beta - 1.0).abs() > 1e-7 {
                    let bc = c.affine(beta, 0.0)?;
                    out = out.add(&bc)?;
                } else {
                    out = out.add(c)?;
                }
            }
            env.insert(node.outputs[0].clone(), out);
        }

        "Relu" => {
            let x = get_tensor(env, &node.inputs[0])?;
            env.insert(node.outputs[0].clone(), x.relu()?);
        }
        "Sigmoid" => {
            let x = get_tensor(env, &node.inputs[0])?;
            env.insert(node.outputs[0].clone(), x.sigmoid()?);
        }
        "Tanh" => {
            let x = get_tensor(env, &node.inputs[0])?;
            env.insert(node.outputs[0].clone(), x.tanh()?);
        }
        "Neg" => {
            let x = get_tensor(env, &node.inputs[0])?;
            env.insert(node.outputs[0].clone(), x.neg()?);
        }
        "Sqrt" => {
            let x = get_tensor(env, &node.inputs[0])?;
            env.insert(node.outputs[0].clone(), x.sqrt()?);
        }
        "Exp" => {
            let x = get_tensor(env, &node.inputs[0])?;
            env.insert(node.outputs[0].clone(), x.exp()?);
        }
        "Log" => {
            let x = get_tensor(env, &node.inputs[0])?;
            env.insert(node.outputs[0].clone(), x.log()?);
        }
        "Abs" => {
            let x = get_tensor(env, &node.inputs[0])?;
            env.insert(node.outputs[0].clone(), x.abs()?);
        }

        "Softmax" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let axis = attr_i(node, "axis", -1);
            let dim = if axis < 0 {
                (x.rank() as i64 + axis) as usize
            } else {
                axis as usize
            };
            env.insert(node.outputs[0].clone(), x.softmax(dim)?);
        }
        "LogSoftmax" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let axis = attr_i(node, "axis", -1);
            let dim = if axis < 0 {
                (x.rank() as i64 + axis) as usize
            } else {
                axis as usize
            };
            env.insert(node.outputs[0].clone(), x.log_softmax(dim)?);
        }

        "Clip" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let min_val = if node.inputs.len() > 1 && !node.inputs[1].is_empty() {
                get_tensor(env, &node.inputs[1])?.to_f64_vec()?[0]
            } else {
                f64::NEG_INFINITY
            };
            let max_val = if node.inputs.len() > 2 && !node.inputs[2].is_empty() {
                get_tensor(env, &node.inputs[2])?.to_f64_vec()?[0]
            } else {
                f64::INFINITY
            };
            env.insert(node.outputs[0].clone(), x.clamp(min_val, max_val)?);
        }

        "Reshape" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let shape_tensor = get_tensor(env, &node.inputs[1])?;
            let shape_vals = shape_tensor.to_f64_vec()?;

            let total = x.elem_count();
            // A `-1` entry becomes `usize::MAX` after the cast; that dimension
            // is then inferred from the element count.
            let mut new_shape: Vec<usize> = shape_vals.iter().map(|&v| v as i64 as usize).collect();
            let neg_idx = new_shape.iter().position(|&s| s == usize::MAX);
            if let Some(idx) = neg_idx {
                let known: usize = new_shape
                    .iter()
                    .enumerate()
                    .filter(|&(i, _)| i != idx)
                    .map(|(_, &s)| s)
                    .product();
                if known > 0 {
                    new_shape[idx] = total / known;
                }
            }
            env.insert(node.outputs[0].clone(), x.reshape(new_shape)?);
        }
        "Transpose" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let perm = attr_ints(node, "perm");
            if perm.is_empty() {
                // Default: reverse all dimensions.
                let rank = x.rank();
                let rev: Vec<usize> = (0..rank).rev().collect();
                env.insert(node.outputs[0].clone(), x.permute(&rev)?);
            } else {
                let perm_usize: Vec<usize> = perm.iter().map(|&p| p as usize).collect();
                env.insert(node.outputs[0].clone(), x.permute(&perm_usize)?);
            }
        }
        "Flatten" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let axis = attr_i(node, "axis", 1) as usize;
            env.insert(node.outputs[0].clone(), x.flatten(axis, x.rank() - 1)?);
        }
        "Squeeze" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let axes = attr_ints(node, "axes");
            if axes.is_empty() {
                env.insert(node.outputs[0].clone(), x.squeeze_all());
            } else {
                let mut result = x.clone();
                let mut sorted_axes: Vec<usize> = axes.iter().map(|&a| a as usize).collect();
                sorted_axes.sort_unstable();
                sorted_axes.reverse();
                for ax in sorted_axes {
                    result = result.squeeze(ax)?;
                }
                env.insert(node.outputs[0].clone(), result);
            }
        }
        "Unsqueeze" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let axes = if node.inputs.len() > 1 && !node.inputs[1].is_empty() {
                let axes_t = get_tensor(env, &node.inputs[1])?;
                axes_t
                    .to_f64_vec()?
                    .iter()
                    .map(|&v| v as i64)
                    .collect::<Vec<_>>()
            } else {
                attr_ints(node, "axes")
            };
            let mut result = x.clone();
            let mut sorted_axes: Vec<usize> = axes
                .iter()
                .map(|&a| {
                    if a < 0 {
                        (result.rank() as i64 + a + 1) as usize
                    } else {
                        a as usize
                    }
                })
                .collect();
            sorted_axes.sort_unstable();
            for ax in sorted_axes {
                result = result.unsqueeze(ax)?;
            }
            env.insert(node.outputs[0].clone(), result);
        }
        "Concat" => {
            let axis = attr_i(node, "axis", 0) as usize;
            let tensors: Vec<Tensor<B>> = node
                .inputs
                .iter()
                .map(|n| get_tensor(env, n).cloned())
                .collect::<Result<Vec<_>>>()?;
            let out = Tensor::<B>::cat(&tensors, axis)?;
            env.insert(node.outputs[0].clone(), out);
        }

        "ReduceSum" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let axes = attr_ints(node, "axes");
            let keepdims = attr_i(node, "keepdims", 1) != 0;
            let mut result = x.clone();
            if axes.is_empty() {
                result = result.sum_all()?;
            } else {
                let mut sorted: Vec<usize> = axes.iter().map(|&a| a as usize).collect();
                sorted.sort_unstable();
                sorted.reverse();
                for ax in sorted {
                    result = result.sum(ax, keepdims)?;
                }
            }
            env.insert(node.outputs[0].clone(), result);
        }
        "ReduceMean" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let axes = attr_ints(node, "axes");
            let keepdims = attr_i(node, "keepdims", 1) != 0;
            let mut result = x.clone();
            if axes.is_empty() {
                result = result.mean_all()?;
            } else {
                let mut sorted: Vec<usize> = axes.iter().map(|&a| a as usize).collect();
                sorted.sort_unstable();
                sorted.reverse();
                for ax in sorted {
                    result = result.mean(ax, keepdims)?;
                }
            }
            env.insert(node.outputs[0].clone(), result);
        }
        "ReduceMax" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let axes = attr_ints(node, "axes");
            let keepdims = attr_i(node, "keepdims", 1) != 0;
            let mut result = x.clone();
            let mut sorted: Vec<usize> = axes.iter().map(|&a| a as usize).collect();
            sorted.sort_unstable();
            sorted.reverse();
            for ax in sorted {
                result = result.max(ax, keepdims)?;
            }
            env.insert(node.outputs[0].clone(), result);
        }
        "ReduceMin" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let axes = attr_ints(node, "axes");
            let keepdims = attr_i(node, "keepdims", 1) != 0;
            let mut result = x.clone();
            let mut sorted: Vec<usize> = axes.iter().map(|&a| a as usize).collect();
            sorted.sort_unstable();
            sorted.reverse();
            for ax in sorted {
                result = result.min(ax, keepdims)?;
            }
            env.insert(node.outputs[0].clone(), result);
        }

        "Gather" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let indices = get_tensor(env, &node.inputs[1])?;
            let axis = attr_i(node, "axis", 0) as usize;
            env.insert(node.outputs[0].clone(), x.gather(axis, indices)?);
        }

        "BatchNormalization" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let scale = get_tensor(env, &node.inputs[1])?;
            let bias = get_tensor(env, &node.inputs[2])?;
            let mean = get_tensor(env, &node.inputs[3])?;
            let var = get_tensor(env, &node.inputs[4])?;
            let eps = attr_f(node, "epsilon", 1e-5) as f64;

            // (x - mean) / sqrt(var + eps) * scale + bias, relying on the
            // backend's broadcasting of the per-channel statistics against `x`.
            let x_sub = x.sub(mean)?;
            let std_inv = var.affine(1.0, eps)?.sqrt()?.reciprocal()?;
            let normed = x_sub.mul(&std_inv)?;
            let scaled = normed.mul(scale)?;
            let out = scaled.add(bias)?;
            env.insert(node.outputs[0].clone(), out);
        }

        "Dropout" => {
            // Inference mode: Dropout is a pass-through.
            let x = get_tensor(env, &node.inputs[0])?.clone();
            env.insert(node.outputs[0].clone(), x.clone());
            if node.outputs.len() > 1 && !node.outputs[1].is_empty() {
                // The optional mask output is filled with the input as a placeholder.
                env.insert(node.outputs[1].clone(), x);
            }
        }

        "Identity" => {
            let x = get_tensor(env, &node.inputs[0])?;
            env.insert(node.outputs[0].clone(), x.clone());
        }

        "Shape" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let shape: Vec<f64> = x.dims().iter().map(|&d| d as f64).collect();
            let n = shape.len();
            let out = Tensor::<B>::from_f64_slice(&shape, vec![n], DType::I64, device)?;
            env.insert(node.outputs[0].clone(), out);
        }

        "Cast" => {
            let x = get_tensor(env, &node.inputs[0])?;
            let to = attr_i(node, "to", ONNX_FLOAT as i64) as i32;
            let target_dtype = onnx_to_dtype(to)?;
            env.insert(node.outputs[0].clone(), x.to_dtype(target_dtype)?);
        }

        "Constant" => {
            if let Some(OnnxAttribute::Float(v)) = node.attributes.get("value_float") {
                let out = Tensor::<B>::from_f64_slice(&[*v as f64], vec![1], DType::F32, device)?;
                env.insert(node.outputs[0].clone(), out);
            } else if let Some(OnnxAttribute::Int(v)) = node.attributes.get("value_int") {
                let out = Tensor::<B>::from_f64_slice(&[*v as f64], vec![1], DType::I64, device)?;
                env.insert(node.outputs[0].clone(), out);
            } else if let Some(OnnxAttribute::Floats(v)) = node.attributes.get("value_floats") {
                let data: Vec<f64> = v.iter().map(|f| *f as f64).collect();
                let n = data.len();
                let out = Tensor::<B>::from_f64_slice(&data, vec![n], DType::F32, device)?;
                env.insert(node.outputs[0].clone(), out);
            } else if let Some(OnnxAttribute::Ints(v)) = node.attributes.get("value_ints") {
                let data: Vec<f64> = v.iter().map(|i| *i as f64).collect();
                let n = data.len();
                let out = Tensor::<B>::from_f64_slice(&data, vec![n], DType::I64, device)?;
                env.insert(node.outputs[0].clone(), out);
            } else {
                return Err(shrew_core::Error::msg(format!(
                    "ONNX Constant: unsupported value attribute in node '{}'",
                    node.name
                )));
            }
        }

        other => {
            return Err(shrew_core::Error::msg(format!(
                "ONNX runtime: unsupported op '{other}' (node '{}')",
                node.name
            )));
        }
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use shrew_cpu::{CpuBackend, CpuDevice};

    type B = CpuBackend;
    type T = Tensor<B>;
    const DEV: CpuDevice = CpuDevice;

    #[test]
    fn test_protobuf_varint_roundtrip() {
        let mut enc = PbEncoder::new();
        enc.write_varint(0);
        enc.write_varint(1);
        enc.write_varint(127);
        enc.write_varint(128);
        enc.write_varint(300);
        enc.write_varint(16384);

        let mut dec = PbDecoder::new(&enc.buf);
        assert_eq!(dec.read_varint().unwrap(), 0);
        assert_eq!(dec.read_varint().unwrap(), 1);
        assert_eq!(dec.read_varint().unwrap(), 127);
        assert_eq!(dec.read_varint().unwrap(), 128);
        assert_eq!(dec.read_varint().unwrap(), 300);
        assert_eq!(dec.read_varint().unwrap(), 16384);
    }

    #[test]
    fn test_onnx_tensor_encode_decode() {
        let mut tensor = OnnxTensor::new("test_weight");
        tensor.data_type = ONNX_FLOAT;
        tensor.dims = vec![2, 3];
        tensor.float_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let encoded = tensor.encode();
        let decoded = OnnxTensor::decode(&encoded).unwrap();

        assert_eq!(decoded.name, "test_weight");
        assert_eq!(decoded.data_type, ONNX_FLOAT);
        assert_eq!(decoded.dims, vec![2, 3]);
        let data = decoded.to_f64_vec();
        assert_eq!(data.len(), 6);
        for (a, b) in tensor.float_data.iter().zip(data.iter()) {
            assert!((*a as f64 - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_export_import_roundtrip() {
        let linear = shrew_nn::Linear::<B>::new(4, 3, true, DType::F32, &DEV).unwrap();

        let path = std::env::temp_dir().join("shrew_test_onnx.onnx");
        export_weights(&path, &linear, "test_model", &[1, 4]).unwrap();

        let weights = load_onnx_weights::<B>(&path, &DEV).unwrap();

        assert_eq!(weights.len(), 2);

        let named = linear.named_parameters();
        for (name, original) in &named {
            let loaded = weights.get(name).expect(&format!("missing: {name}"));
            assert_eq!(original.dims(), loaded.dims());

            let orig_data = original.to_f64_vec().unwrap();
            let load_data = loaded.to_f64_vec().unwrap();
            for (a, b) in orig_data.iter().zip(load_data.iter()) {
                assert!((a - b).abs() < 1e-5, "mismatch for {name}: {a} vs {b}");
            }
        }

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_export_tensors() {
        let t1 = T::randn(vec![2, 3], DType::F32, &DEV).unwrap();
        let t2 = T::ones(vec![5], DType::F32, &DEV).unwrap();

        let tensors = vec![
            ("weight".to_string(), t1.clone()),
            ("bias".to_string(), t2.clone()),
        ];

        let path = std::env::temp_dir().join("shrew_test_tensors.onnx");
        export_tensors(&path, &tensors, "tensors_model").unwrap();

        let loaded = load_onnx_weights::<B>(&path, &DEV).unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.get("weight").unwrap().dims(), &[2, 3]);
        assert_eq!(loaded.get("bias").unwrap().dims(), &[5]);

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_onnx_model_builder() {
        let mut model = OnnxModel::new("test_graph");
        model
            .inputs
            .push(("X".to_string(), vec![1, 784], ONNX_FLOAT));
        model
            .outputs
            .push(("Y".to_string(), vec![1, 10], ONNX_FLOAT));

        model.nodes.push(OnnxNode {
            inputs: vec!["X".to_string(), "weight".to_string()],
            outputs: vec!["matmul_out".to_string()],
            op_type: "MatMul".to_string(),
            name: "matmul_0".to_string(),
            attributes: HashMap::new(),
        });

        let mut attrs = HashMap::new();
        attrs.insert("axis".to_string(), OnnxAttribute::Int(1));
        model.nodes.push(OnnxNode {
            inputs: vec!["matmul_out".to_string()],
            outputs: vec!["Y".to_string()],
            op_type: "Softmax".to_string(),
            name: "softmax_0".to_string(),
            attributes: attrs,
        });

        let bytes = model.to_bytes();
        assert!(!bytes.is_empty());
        assert!(bytes.len() > 20);
    }

    #[test]
    fn test_dtype_conversion() {
        assert_eq!(dtype_to_onnx(DType::F32), ONNX_FLOAT);
        assert_eq!(dtype_to_onnx(DType::F64), ONNX_DOUBLE);
        assert_eq!(dtype_to_onnx(DType::F16), ONNX_FLOAT16);
        assert_eq!(onnx_to_dtype(ONNX_FLOAT).unwrap(), DType::F32);
        assert_eq!(onnx_to_dtype(ONNX_DOUBLE).unwrap(), DType::F64);
        assert_eq!(onnx_to_dtype(ONNX_FLOAT16).unwrap(), DType::F16);
    }

    #[test]
    fn test_double_data_roundtrip() {
        let t = T::from_f64_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2], DType::F64, &DEV).unwrap();

        let tensors = vec![("w".to_string(), t.clone())];
        let path = std::env::temp_dir().join("shrew_test_f64.onnx");
        export_tensors(&path, &tensors, "f64_model").unwrap();

        let loaded = load_onnx_weights::<B>(&path, &DEV).unwrap();
        let w = loaded.get("w").unwrap();
        assert_eq!(w.dims(), &[2, 2]);

        let orig = t.to_f64_vec().unwrap();
        let load = w.to_f64_vec().unwrap();
        for (a, b) in orig.iter().zip(load.iter()) {
            assert!((a - b).abs() < 1e-10);
        }

        let _ = fs::remove_file(&path);
    }

    // Serializes the model and parses it back through the public loader.
    fn build_and_reload_graph(model: &OnnxModel) -> OnnxGraph {
        let bytes = model.to_bytes();
        load_onnx_graph_from_bytes(&bytes).unwrap()
    }

    #[test]
    fn test_graph_add_two_inputs() {
        let mut model = OnnxModel::new("add_graph");
        model.inputs.push(("A".into(), vec![2, 2], ONNX_FLOAT));
        model.inputs.push(("B".into(), vec![2, 2], ONNX_FLOAT));
        model.outputs.push(("Y".into(), vec![2, 2], ONNX_FLOAT));
        model.nodes.push(OnnxNode {
            inputs: vec!["A".into(), "B".into()],
            outputs: vec!["Y".into()],
            op_type: "Add".into(),
            name: "add_0".into(),
            attributes: HashMap::new(),
        });

        let graph = build_and_reload_graph(&model);
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.nodes[0].op_type, "Add");
        assert_eq!(graph.output_names, vec!["Y"]);

        let a = T::from_f64_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2], DType::F32, &DEV).unwrap();
        let b = T::from_f64_slice(&[10.0, 20.0, 30.0, 40.0], vec![2, 2], DType::F32, &DEV).unwrap();

        let mut inputs = HashMap::new();
        inputs.insert("A".into(), a);
        inputs.insert("B".into(), b);

        let outputs = run_onnx_graph::<B>(&graph, &inputs, &DEV).unwrap();
        let y = outputs.get("Y").unwrap();
        let data = y.to_f64_vec().unwrap();
        assert_eq!(data, vec![11.0, 22.0, 33.0, 44.0]);
    }
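
    // A further check in the style of the tests above: Gemm with transB = 1
    // and a bias input, i.e. Y = X * W^T + C, exercising the Int attribute
    // round trip together with the Gemm handler.
    #[test]
    fn test_graph_gemm_trans_b() {
        let mut model = OnnxModel::new("gemm_graph");
        model.inputs.push(("X".into(), vec![1, 2], ONNX_FLOAT));
        model.outputs.push(("Y".into(), vec![1, 3], ONNX_FLOAT));

        let mut w = OnnxTensor::new("W");
        w.data_type = ONNX_FLOAT;
        w.dims = vec![3, 2];
        w.float_data = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        model.initializers.push(w);

        let mut c = OnnxTensor::new("C");
        c.data_type = ONNX_FLOAT;
        c.dims = vec![3];
        c.float_data = vec![0.5, 0.5, 0.5];
        model.initializers.push(c);

        let mut attrs = HashMap::new();
        attrs.insert("transB".to_string(), OnnxAttribute::Int(1));
        model.nodes.push(OnnxNode {
            inputs: vec!["X".into(), "W".into(), "C".into()],
            outputs: vec!["Y".into()],
            op_type: "Gemm".into(),
            name: "gemm_0".into(),
            attributes: attrs,
        });

        let graph = build_and_reload_graph(&model);
        let x = T::from_f64_slice(&[1.0, 2.0], vec![1, 2], DType::F32, &DEV).unwrap();
        let mut inputs = HashMap::new();
        inputs.insert("X".into(), x);

        let outputs = run_onnx_graph::<B>(&graph, &inputs, &DEV).unwrap();
        let data = outputs.get("Y").unwrap().to_f64_vec().unwrap();
        assert_eq!(data.len(), 3);
        assert!((data[0] - 1.5).abs() < 1e-5);
        assert!((data[1] - 2.5).abs() < 1e-5);
        assert!((data[2] - 3.5).abs() < 1e-5);
    }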

    #[test]
    fn test_graph_linear_relu() {
        let mut model = OnnxModel::new("linear_relu");
        model.inputs.push(("X".into(), vec![1, 2], ONNX_FLOAT));
        model.outputs.push(("Z".into(), vec![1, 3], ONNX_FLOAT));

        let mut w = OnnxTensor::new("W");
        w.data_type = ONNX_FLOAT;
        w.dims = vec![2, 3];
        w.float_data = vec![1.0, -1.0, 0.5, 0.0, 2.0, -0.5];
        model.initializers.push(w);

        let mut b = OnnxTensor::new("B");
        b.data_type = ONNX_FLOAT;
        b.dims = vec![3];
        b.float_data = vec![0.0, 0.0, 0.0];
        model.initializers.push(b);

        model.nodes.push(OnnxNode {
            inputs: vec!["X".into(), "W".into()],
            outputs: vec!["matmul_out".into()],
            op_type: "MatMul".into(),
            name: "matmul_0".into(),
            attributes: HashMap::new(),
        });
        model.nodes.push(OnnxNode {
            inputs: vec!["matmul_out".into(), "B".into()],
            outputs: vec!["add_out".into()],
            op_type: "Add".into(),
            name: "add_0".into(),
            attributes: HashMap::new(),
        });
        model.nodes.push(OnnxNode {
            inputs: vec!["add_out".into()],
            outputs: vec!["Z".into()],
            op_type: "Relu".into(),
            name: "relu_0".into(),
            attributes: HashMap::new(),
        });

        let graph = build_and_reload_graph(&model);
        assert_eq!(graph.nodes.len(), 3);

        let x = T::from_f64_slice(&[1.0, -1.0], vec![1, 2], DType::F32, &DEV).unwrap();
        let mut inputs = HashMap::new();
        inputs.insert("X".into(), x);

        let outputs = run_onnx_graph::<B>(&graph, &inputs, &DEV).unwrap();
        let z = outputs.get("Z").unwrap();
        let data = z.to_f64_vec().unwrap();
        assert_eq!(data.len(), 3);
        assert!((data[0] - 1.0).abs() < 1e-5);
        assert!((data[1] - 0.0).abs() < 1e-5);
        assert!((data[2] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_graph_identity_and_dropout() {
        let mut model = OnnxModel::new("id_drop");
        model.inputs.push(("X".into(), vec![3], ONNX_FLOAT));
        model.outputs.push(("Y".into(), vec![3], ONNX_FLOAT));

        model.nodes.push(OnnxNode {
            inputs: vec!["X".into()],
            outputs: vec!["id_out".into()],
            op_type: "Identity".into(),
            name: "id_0".into(),
            attributes: HashMap::new(),
        });
        model.nodes.push(OnnxNode {
            inputs: vec!["id_out".into()],
            outputs: vec!["Y".into()],
            op_type: "Dropout".into(),
            name: "drop_0".into(),
            attributes: HashMap::new(),
        });

        let graph = build_and_reload_graph(&model);
        let x = T::from_f64_slice(&[5.0, -3.0, 7.0], vec![3], DType::F32, &DEV).unwrap();
        let mut inputs = HashMap::new();
        inputs.insert("X".into(), x.clone());

        let outputs = run_onnx_graph::<B>(&graph, &inputs, &DEV).unwrap();
        let y = outputs.get("Y").unwrap();
        assert_eq!(y.to_f64_vec().unwrap(), x.to_f64_vec().unwrap());
    }
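
    // Same style as above: ReduceSum over axis 0 with keepdims = 0, checking
    // that Ints/Int attributes survive the round trip and drive the reducer.
    #[test]
    fn test_graph_reduce_sum_axis0() {
        let mut model = OnnxModel::new("reduce_graph");
        model.inputs.push(("X".into(), vec![2, 3], ONNX_FLOAT));
        model.outputs.push(("Y".into(), vec![3], ONNX_FLOAT));

        let mut attrs = HashMap::new();
        attrs.insert("axes".to_string(), OnnxAttribute::Ints(vec![0]));
        attrs.insert("keepdims".to_string(), OnnxAttribute::Int(0));
        model.nodes.push(OnnxNode {
            inputs: vec!["X".into()],
            outputs: vec!["Y".into()],
            op_type: "ReduceSum".into(),
            name: "reduce_0".into(),
            attributes: attrs,
        });

        let graph = build_and_reload_graph(&model);
        let x = T::from_f64_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3], DType::F32, &DEV)
            .unwrap();
        let mut inputs = HashMap::new();
        inputs.insert("X".into(), x);

        let outputs = run_onnx_graph::<B>(&graph, &inputs, &DEV).unwrap();
        let y = outputs.get("Y").unwrap();
        assert_eq!(y.dims(), &[3]);
        assert_eq!(y.to_f64_vec().unwrap(), vec![5.0, 7.0, 9.0]);
    }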

    #[test]
    fn test_graph_file_roundtrip() {
        let mut model = OnnxModel::new("file_rt");
        model.inputs.push(("X".into(), vec![2], ONNX_FLOAT));
        model.outputs.push(("Y".into(), vec![2], ONNX_FLOAT));

        model.nodes.push(OnnxNode {
            inputs: vec!["X".into()],
            outputs: vec!["Y".into()],
            op_type: "Sigmoid".into(),
            name: "sig_0".into(),
            attributes: HashMap::new(),
        });

        let path = std::env::temp_dir().join("shrew_test_graph_rt.onnx");
        model.save(&path).unwrap();

        let graph = load_onnx_graph(&path).unwrap();
        assert_eq!(graph.nodes.len(), 1);

        let x = T::from_f64_slice(&[0.0, 1000.0], vec![2], DType::F32, &DEV).unwrap();
        let mut inputs = HashMap::new();
        inputs.insert("X".into(), x);

        let outputs = run_onnx_graph::<B>(&graph, &inputs, &DEV).unwrap();
        let data = outputs.get("Y").unwrap().to_f64_vec().unwrap();
        assert!((data[0] - 0.5).abs() < 1e-5); // sigmoid(0) = 0.5
        assert!((data[1] - 1.0).abs() < 1e-3); // sigmoid(1000) saturates to ~1

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_decode_attribute_roundtrip() {
        let attr = OnnxAttribute::Int(42);
        let encoded = encode_attribute("axis", &attr);
        let (name, decoded) = decode_attribute(&encoded.buf).unwrap();
        assert_eq!(name, "axis");
        match decoded {
            OnnxAttribute::Int(v) => assert_eq!(v, 42),
            _ => panic!("expected Int"),
        }
    }
}