// shrew/onnx.rs

1// =============================================================================
2// ONNX — Import / Export for interoperability
3// =============================================================================
4//
5// ONNX (Open Neural Network Exchange) is the industry standard for
6// exchanging trained models between frameworks (PyTorch, TensorFlow,
7// CoreML, TensorRT, etc.).
8//
9// This module provides:
10//
11//   - Export: convert a Shrew module's state_dict to ONNX format
12//   - Import: load an ONNX model's weights into Shrew tensors
13//
14// ONNX files use Protocol Buffers encoding. We implement a minimal
15// protobuf encoder/decoder (no external crate needed) that handles
16// the subset of the ONNX spec we need: ModelProto, GraphProto,
17// TensorProto, and NodeProto.
18//
19// SUPPORTED ONNX OPS (for graph export):
20//   MatMul, Add, Relu, Sigmoid, Tanh, Softmax, Gemm, Reshape, Transpose,
21//   Conv, BatchNormalization, Dropout, Concat, Flatten
22//
23// REFERENCE:
24//   https://onnx.ai/onnx/repo-docs/IR.html
25//   https://protobuf.dev/programming-guides/encoding/
26
27use std::collections::HashMap;
28use std::fs;
29use std::path::Path;
30
31use shrew_core::backend::Backend;
32use shrew_core::dtype::DType;
33use shrew_core::error::Result;
34use shrew_core::tensor::Tensor;
35
36use shrew_nn::Module;
37
38// =============================================================================
39// ONNX constants
40// =============================================================================
41
/// ONNX IR version (we target ONNX IR version 9 / opset 17).
const ONNX_IR_VERSION: i64 = 9;
/// Default opset version.
const ONNX_OPSET_VERSION: i64 = 17;
/// Default operator-set domain; the empty string denotes the standard
/// `ai.onnx` domain.
const ONNX_DOMAIN: &str = "";

// ONNX TensorProto data-type codes.
/// See https://onnx.ai/onnx/repo-docs/IR.html#tensor-data-types
const ONNX_FLOAT: i32 = 1;
const ONNX_DOUBLE: i32 = 11;
const ONNX_FLOAT16: i32 = 10;
const ONNX_BFLOAT16: i32 = 16;
const ONNX_INT8: i32 = 3;
const ONNX_UINT8: i32 = 2;
const ONNX_INT32: i32 = 6;
const ONNX_INT64: i32 = 7;
const ONNX_UINT32: i32 = 12;
60
61// =============================================================================
62// Minimal protobuf encoder
63// =============================================================================
64
/// A minimal protobuf wire-format encoder. Supports:
/// - Varint (wire type 0)
/// - Length-delimited (wire type 2: bytes, strings, nested messages)
/// - Fixed32/Fixed64 (wire types 5 and 1)
struct PbEncoder {
    buf: Vec<u8>,
}

impl PbEncoder {
    fn new() -> Self {
        PbEncoder { buf: Vec::new() }
    }

    fn into_bytes(self) -> Vec<u8> {
        self.buf
    }

    /// Append a base-128 varint: 7 payload bits per byte, MSB set on every
    /// byte except the last.
    fn write_varint(&mut self, mut val: u64) {
        while val >= 0x80 {
            self.buf.push((val as u8 & 0x7F) | 0x80);
            val >>= 7;
        }
        self.buf.push(val as u8);
    }

    /// Append a field key: (field_number << 3) | wire_type.
    fn write_tag(&mut self, field: u32, wire_type: u32) {
        let key = ((field as u64) << 3) | u64::from(wire_type);
        self.write_varint(key);
    }

    /// Varint-encoded scalar field (wire type 0).
    fn write_varint_field(&mut self, field: u32, val: u64) {
        self.write_tag(field, 0);
        self.write_varint(val);
    }

    /// int64 field. NOTE: despite the name this is plain two's-complement
    /// varint encoding (protobuf `int64`), not zigzag (`sint64`); a negative
    /// value occupies the full 10 bytes. This matches the plain-varint
    /// interpretation used on the decode side.
    fn write_sint64_field(&mut self, field: u32, val: i64) {
        self.write_varint_field(field, val as u64);
    }

    /// Length-delimited field (wire type 2): length varint, then raw bytes.
    fn write_bytes_field(&mut self, field: u32, data: &[u8]) {
        self.write_tag(field, 2);
        self.write_varint(data.len() as u64);
        self.buf.extend_from_slice(data);
    }

    /// UTF-8 string field (length-delimited).
    fn write_string_field(&mut self, field: u32, val: &str) {
        self.write_bytes_field(field, val.as_bytes());
    }

    /// Nested message field: the already-encoded body, length-delimited.
    fn write_message_field(&mut self, field: u32, encoder: &PbEncoder) {
        self.write_bytes_field(field, &encoder.buf);
    }

    /// Packed little-endian f32 payload as one length-delimited field.
    #[allow(dead_code)]
    fn write_float_data(&mut self, field: u32, data: &[f32]) {
        let mut bytes = Vec::with_capacity(data.len() * 4);
        for v in data {
            bytes.extend_from_slice(&v.to_le_bytes());
        }
        self.write_bytes_field(field, &bytes);
    }

    /// Packed little-endian f64 payload as one length-delimited field.
    #[allow(dead_code)]
    fn write_double_data(&mut self, field: u32, data: &[f64]) {
        let mut bytes = Vec::with_capacity(data.len() * 8);
        for v in data {
            bytes.extend_from_slice(&v.to_le_bytes());
        }
        self.write_bytes_field(field, &bytes);
    }
}
143
144// =============================================================================
145// Minimal protobuf decoder
146// =============================================================================
147
148/// A minimal protobuf wire-format decoder.
149struct PbDecoder<'a> {
150    data: &'a [u8],
151    pos: usize,
152}
153
154impl<'a> PbDecoder<'a> {
155    fn new(data: &'a [u8]) -> Self {
156        Self { data, pos: 0 }
157    }
158
159    fn remaining(&self) -> usize {
160        self.data.len() - self.pos
161    }
162
163    fn read_varint(&mut self) -> Result<u64> {
164        let mut result: u64 = 0;
165        let mut shift = 0;
166        loop {
167            if self.pos >= self.data.len() {
168                return Err(shrew_core::Error::msg("protobuf: unexpected end of data"));
169            }
170            let byte = self.data[self.pos];
171            self.pos += 1;
172            result |= ((byte & 0x7F) as u64) << shift;
173            if byte & 0x80 == 0 {
174                break;
175            }
176            shift += 7;
177            if shift > 63 {
178                return Err(shrew_core::Error::msg("protobuf: varint too long"));
179            }
180        }
181        Ok(result)
182    }
183
184    fn read_tag(&mut self) -> Result<(u32, u32)> {
185        let val = self.read_varint()?;
186        let field = (val >> 3) as u32;
187        let wire_type = (val & 0x7) as u32;
188        Ok((field, wire_type))
189    }
190
191    fn read_bytes(&mut self) -> Result<&'a [u8]> {
192        let len = self.read_varint()? as usize;
193        if self.pos + len > self.data.len() {
194            return Err(shrew_core::Error::msg("protobuf: bytes field exceeds data"));
195        }
196        let result = &self.data[self.pos..self.pos + len];
197        self.pos += len;
198        Ok(result)
199    }
200
201    fn read_string(&mut self) -> Result<String> {
202        let bytes = self.read_bytes()?;
203        String::from_utf8(bytes.to_vec())
204            .map_err(|_| shrew_core::Error::msg("protobuf: invalid UTF-8 string"))
205    }
206
207    fn skip_field(&mut self, wire_type: u32) -> Result<()> {
208        match wire_type {
209            0 => {
210                self.read_varint()?;
211            }
212            1 => {
213                self.pos += 8;
214            } // fixed64
215            2 => {
216                self.read_bytes()?;
217            }
218            5 => {
219                self.pos += 4;
220            } // fixed32
221            _ => {
222                return Err(shrew_core::Error::msg(format!(
223                    "protobuf: unsupported wire type {wire_type}"
224                )))
225            }
226        }
227        Ok(())
228    }
229}
230
231// =============================================================================
232// ONNX TensorProto
233// =============================================================================
234
/// Represents an ONNX TensorProto (a named tensor with shape and data).
///
/// At most one of `float_data`, `double_data`, or `raw_data` is expected to
/// carry the payload; the others stay empty.
#[derive(Debug, Clone)]
pub struct OnnxTensor {
    /// Tensor name.
    pub name: String,
    /// ONNX data-type code (ONNX_FLOAT, ONNX_DOUBLE, etc.).
    pub data_type: i32,
    /// Shape dimensions.
    pub dims: Vec<i64>,
    /// Element values as f32 (TensorProto `float_data`, for FLOAT type).
    pub float_data: Vec<f32>,
    /// Element values as f64 (TensorProto `double_data`, for DOUBLE type).
    pub double_data: Vec<f64>,
    /// Raw little-endian bytes (TensorProto `raw_data`; used for packed
    /// formats like FLOAT16).
    pub raw_data: Vec<u8>,
}
251
252impl OnnxTensor {
253    fn new(name: &str) -> Self {
254        Self {
255            name: name.to_string(),
256            data_type: ONNX_FLOAT,
257            dims: Vec::new(),
258            float_data: Vec::new(),
259            double_data: Vec::new(),
260            raw_data: Vec::new(),
261        }
262    }
263
264    /// Convert to protobuf bytes.
265    fn encode(&self) -> Vec<u8> {
266        let mut enc = PbEncoder::new();
267        // field 1: dims (repeated int64)
268        for &d in &self.dims {
269            enc.write_sint64_field(1, d);
270        }
271        // field 2: data_type (int32)
272        enc.write_varint_field(2, self.data_type as u64);
273        // field 8: name (string)
274        if !self.name.is_empty() {
275            enc.write_string_field(8, &self.name);
276        }
277        // field 4: float_data (packed repeated float — as raw_data for efficiency)
278        if !self.float_data.is_empty() {
279            // field 13: raw_data (bytes) — more efficient than repeated float
280            let bytes: Vec<u8> = self
281                .float_data
282                .iter()
283                .flat_map(|v| v.to_le_bytes())
284                .collect();
285            enc.write_bytes_field(13, &bytes);
286        } else if !self.double_data.is_empty() {
287            let bytes: Vec<u8> = self
288                .double_data
289                .iter()
290                .flat_map(|v| v.to_le_bytes())
291                .collect();
292            enc.write_bytes_field(13, &bytes);
293        } else if !self.raw_data.is_empty() {
294            enc.write_bytes_field(13, &self.raw_data);
295        }
296        enc.into_bytes()
297    }
298
299    /// Decode from protobuf bytes.
300    fn decode(data: &[u8]) -> Result<Self> {
301        let mut dec = PbDecoder::new(data);
302        let mut tensor = OnnxTensor::new("");
303        while dec.remaining() > 0 {
304            let (field, wire_type) = dec.read_tag()?;
305            match (field, wire_type) {
306                (1, 0) => {
307                    // dims (varint)
308                    let v = dec.read_varint()? as i64;
309                    tensor.dims.push(v);
310                }
311                (1, 2) => {
312                    // dims (packed)
313                    let bytes = dec.read_bytes()?;
314                    let mut sub = PbDecoder::new(bytes);
315                    while sub.remaining() > 0 {
316                        tensor.dims.push(sub.read_varint()? as i64);
317                    }
318                }
319                (2, 0) => {
320                    // data_type
321                    tensor.data_type = dec.read_varint()? as i32;
322                }
323                (8, 2) => {
324                    // name
325                    tensor.name = dec.read_string()?;
326                }
327                (13, 2) => {
328                    // raw_data
329                    tensor.raw_data = dec.read_bytes()?.to_vec();
330                }
331                (4, 2) => {
332                    // float_data (packed)
333                    let bytes = dec.read_bytes()?;
334                    for chunk in bytes.chunks_exact(4) {
335                        let val = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
336                        tensor.float_data.push(val);
337                    }
338                }
339                (4, 5) => {
340                    // float_data (repeated fixed32)
341                    let bytes = &dec.data[dec.pos..dec.pos + 4];
342                    dec.pos += 4;
343                    let val = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
344                    tensor.float_data.push(val);
345                }
346                (5, 2) => {
347                    // double_data (packed)
348                    let bytes = dec.read_bytes()?;
349                    for chunk in bytes.chunks_exact(8) {
350                        let val = f64::from_le_bytes([
351                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
352                            chunk[7],
353                        ]);
354                        tensor.double_data.push(val);
355                    }
356                }
357                _ => {
358                    dec.skip_field(wire_type)?;
359                }
360            }
361        }
362        Ok(tensor)
363    }
364
365    /// Get the float data (converting from raw_data if needed).
366    fn to_f64_vec(&self) -> Vec<f64> {
367        if !self.double_data.is_empty() {
368            return self.double_data.clone();
369        }
370        if !self.float_data.is_empty() {
371            return self.float_data.iter().map(|&v| v as f64).collect();
372        }
373        if !self.raw_data.is_empty() {
374            match self.data_type {
375                ONNX_FLOAT => self
376                    .raw_data
377                    .chunks_exact(4)
378                    .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f64)
379                    .collect(),
380                ONNX_DOUBLE => self
381                    .raw_data
382                    .chunks_exact(8)
383                    .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
384                    .collect(),
385                ONNX_FLOAT16 => self
386                    .raw_data
387                    .chunks_exact(2)
388                    .map(|c| {
389                        let bits = u16::from_le_bytes([c[0], c[1]]);
390                        half::f16::from_bits(bits).to_f64()
391                    })
392                    .collect(),
393                _ => Vec::new(),
394            }
395        } else {
396            Vec::new()
397        }
398    }
399}
400
401/// Convert a Shrew DType to ONNX data type integer.
402fn dtype_to_onnx(dtype: DType) -> i32 {
403    match dtype {
404        DType::F32 => ONNX_FLOAT,
405        DType::F64 => ONNX_DOUBLE,
406        DType::F16 => ONNX_FLOAT16,
407        DType::BF16 => ONNX_BFLOAT16,
408        DType::U8 => ONNX_UINT8,
409        DType::U32 => ONNX_UINT32,
410        DType::I64 => ONNX_INT64,
411    }
412}
413
414/// Convert an ONNX data type integer to Shrew DType.
415fn onnx_to_dtype(onnx_type: i32) -> Result<DType> {
416    match onnx_type {
417        ONNX_FLOAT => Ok(DType::F32),
418        ONNX_DOUBLE => Ok(DType::F64),
419        ONNX_FLOAT16 => Ok(DType::F16),
420        ONNX_BFLOAT16 => Ok(DType::BF16),
421        ONNX_UINT8 => Ok(DType::U8),
422        ONNX_UINT32 => Ok(DType::U32),
423        ONNX_INT64 => Ok(DType::I64),
424        ONNX_INT8 => Ok(DType::U8),   // map to u8
425        ONNX_INT32 => Ok(DType::I64), // upcast
426        _ => Err(shrew_core::Error::msg(format!(
427            "unsupported ONNX data type: {onnx_type}"
428        ))),
429    }
430}
431
432// =============================================================================
433// ONNX NodeProto (graph operation)
434// =============================================================================
435
/// An ONNX graph node (operation).
#[derive(Debug, Clone)]
pub struct OnnxNode {
    /// Input tensor names (positional).
    pub inputs: Vec<String>,
    /// Output tensor names.
    pub outputs: Vec<String>,
    /// Operation type (e.g., "MatMul", "Relu", "Add").
    pub op_type: String,
    /// Node name (for debugging).
    pub name: String,
    /// Attributes keyed by name. Note: a HashMap does not preserve the
    /// attribute order of the source protobuf.
    pub attributes: HashMap<String, OnnxAttribute>,
}
450
/// An ONNX attribute value.
///
/// Covers the subset of AttributeProto payloads used here: scalar
/// int/float/string plus repeated ints/floats. Tensor- and graph-valued
/// attributes are not represented.
#[derive(Debug, Clone)]
pub enum OnnxAttribute {
    /// Scalar integer.
    Int(i64),
    /// Scalar float.
    Float(f32),
    /// UTF-8 string.
    String(String),
    /// Repeated integers.
    Ints(Vec<i64>),
    /// Repeated floats.
    Floats(Vec<f32>),
}
460
461impl OnnxNode {
462    fn encode(&self) -> Vec<u8> {
463        let mut enc = PbEncoder::new();
464        // field 1: inputs (repeated string)
465        for input in &self.inputs {
466            enc.write_string_field(1, input);
467        }
468        // field 2: outputs (repeated string)
469        for output in &self.outputs {
470            enc.write_string_field(2, output);
471        }
472        // field 3: name (string)
473        if !self.name.is_empty() {
474            enc.write_string_field(3, &self.name);
475        }
476        // field 4: op_type (string)
477        enc.write_string_field(4, &self.op_type);
478        // field 5: attributes (repeated AttributeProto)
479        for (key, val) in &self.attributes {
480            let attr = encode_attribute(key, val);
481            enc.write_message_field(5, &attr);
482        }
483        enc.into_bytes()
484    }
485}
486
487fn encode_attribute(name: &str, val: &OnnxAttribute) -> PbEncoder {
488    let mut enc = PbEncoder::new();
489    enc.write_string_field(1, name); // field 1: name
490    match val {
491        OnnxAttribute::Int(i) => {
492            enc.write_varint_field(2, 2); // type = INT
493            enc.write_sint64_field(3, *i); // field 3: i
494        }
495        OnnxAttribute::Float(f) => {
496            enc.write_varint_field(2, 1); // type = FLOAT
497                                          // field 4: f (float, fixed32)
498            enc.write_tag(4, 5);
499            enc.buf.extend_from_slice(&f.to_le_bytes());
500        }
501        OnnxAttribute::String(s) => {
502            enc.write_varint_field(2, 3); // type = STRING
503            enc.write_bytes_field(5, s.as_bytes()); // field 5: s
504        }
505        OnnxAttribute::Ints(ints) => {
506            enc.write_varint_field(2, 7); // type = INTS
507            for &i in ints {
508                enc.write_sint64_field(8, i); // field 8: ints
509            }
510        }
511        OnnxAttribute::Floats(floats) => {
512            enc.write_varint_field(2, 6); // type = FLOATS
513            for &f in floats {
514                enc.write_tag(7, 5); // field 7: floats (fixed32)
515                enc.buf.extend_from_slice(&f.to_le_bytes());
516            }
517        }
518    }
519    enc
520}
521
522// =============================================================================
523// ONNX ModelProto — top-level export
524// =============================================================================
525
/// An ONNX model with graph, metadata, and opset information.
#[derive(Debug, Clone)]
pub struct OnnxModel {
    /// Model producer name (written to ModelProto.producer_name).
    pub producer_name: String,
    /// Model producer version.
    pub producer_version: String,
    /// Graph name.
    pub graph_name: String,
    /// Graph nodes (operations).
    pub nodes: Vec<OnnxNode>,
    /// Initializer tensors (weights).
    pub initializers: Vec<OnnxTensor>,
    /// Graph inputs as (name, dims, ONNX data-type code); a negative dim is
    /// serialized as a symbolic "dynamic" dimension.
    pub inputs: Vec<(String, Vec<i64>, i32)>,
    /// Graph outputs as (name, dims, ONNX data-type code).
    pub outputs: Vec<(String, Vec<i64>, i32)>,
}
544
545impl OnnxModel {
546    /// Create a new empty ONNX model.
547    pub fn new(graph_name: &str) -> Self {
548        Self {
549            producer_name: "Shrew".to_string(),
550            producer_version: "0.1.0".to_string(),
551            graph_name: graph_name.to_string(),
552            nodes: Vec::new(),
553            initializers: Vec::new(),
554            inputs: Vec::new(),
555            outputs: Vec::new(),
556        }
557    }
558
559    /// Encode to ONNX protobuf binary format.
560    pub fn to_bytes(&self) -> Vec<u8> {
561        // Build GraphProto
562        let mut graph = PbEncoder::new();
563
564        // field 1: nodes (repeated NodeProto)
565        for node in &self.nodes {
566            let node_bytes = node.encode();
567            graph.write_bytes_field(1, &node_bytes);
568        }
569
570        // field 2: name
571        graph.write_string_field(2, &self.graph_name);
572
573        // field 5: initializers (repeated TensorProto — the weights)
574        for init in &self.initializers {
575            let tensor_bytes = init.encode();
576            graph.write_bytes_field(5, &tensor_bytes);
577        }
578
579        // field 11: inputs (repeated ValueInfoProto)
580        for (name, dims, dtype) in &self.inputs {
581            let vi = encode_value_info(name, dims, *dtype);
582            graph.write_message_field(11, &vi);
583        }
584
585        // field 12: outputs (repeated ValueInfoProto)
586        for (name, dims, dtype) in &self.outputs {
587            let vi = encode_value_info(name, dims, *dtype);
588            graph.write_message_field(12, &vi);
589        }
590
591        // Build ModelProto
592        let mut model = PbEncoder::new();
593        // field 1: ir_version (int64)
594        model.write_varint_field(1, ONNX_IR_VERSION as u64);
595        // field 2: producer_name
596        model.write_string_field(2, &self.producer_name);
597        // field 3: producer_version
598        model.write_string_field(3, &self.producer_version);
599        // field 7: graph (GraphProto)
600        model.write_message_field(7, &graph);
601        // field 8: opset_import (OperatorSetIdProto)
602        let mut opset = PbEncoder::new();
603        opset.write_string_field(1, ONNX_DOMAIN); // domain
604        opset.write_varint_field(2, ONNX_OPSET_VERSION as u64); // version
605        model.write_message_field(8, &opset);
606
607        model.into_bytes()
608    }
609
610    /// Save ONNX model to a file.
611    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
612        let bytes = self.to_bytes();
613        fs::write(path.as_ref(), &bytes)
614            .map_err(|e| shrew_core::Error::msg(format!("failed to write ONNX file: {e}")))
615    }
616}
617
618/// Encode a ValueInfoProto (input/output description).
619fn encode_value_info(name: &str, dims: &[i64], data_type: i32) -> PbEncoder {
620    let mut vi = PbEncoder::new();
621    vi.write_string_field(1, name); // field 1: name
622
623    // field 2: type (TypeProto)
624    let mut type_proto = PbEncoder::new();
625    // field 1: tensor_type (Tensor_TypeProto)
626    let mut tensor_type = PbEncoder::new();
627    tensor_type.write_varint_field(1, data_type as u64); // elem_type
628                                                         // field 2: shape (TensorShapeProto)
629    let mut shape = PbEncoder::new();
630    for &d in dims {
631        let mut dim = PbEncoder::new();
632        if d >= 0 {
633            dim.write_sint64_field(1, d); // dim_value
634        } else {
635            dim.write_string_field(2, "dynamic"); // dim_param (symbolic)
636        }
637        shape.write_message_field(1, &dim);
638    }
639    tensor_type.write_message_field(2, &shape);
640    type_proto.write_message_field(1, &tensor_type);
641    vi.write_message_field(2, &type_proto);
642
643    vi
644}
645
646// =============================================================================
647// Export API
648// =============================================================================
649
650/// Export a module's weights as an ONNX model file.
651///
652/// This creates a "weight-only" ONNX model: the initializer tensors contain
653/// the model's learned parameters, and the graph describes a simple
654/// sequential pass from input to output.
655///
656/// # Arguments
657/// - `path`: output file path (typically `.onnx`)
658/// - `module`: the trained module to export
659/// - `model_name`: name for the ONNX graph
660/// - `input_shape`: shape of the model's input tensor
661///
662/// # Example
663/// ```ignore
664/// let model = Linear::new(784, 10, true, DType::F32, &dev)?;
665/// export_weights("model.onnx", &model, "classifier", &[1, 784])?;
666/// ```
667pub fn export_weights<P, B, M>(
668    path: P,
669    module: &M,
670    model_name: &str,
671    input_shape: &[i64],
672) -> Result<()>
673where
674    P: AsRef<Path>,
675    B: Backend,
676    M: Module<B>,
677{
678    let named = module.named_parameters();
679
680    let mut model = OnnxModel::new(model_name);
681
682    // Add input
683    model
684        .inputs
685        .push(("input".to_string(), input_shape.to_vec(), ONNX_FLOAT));
686
687    // Add each parameter as an initializer
688    for (name, tensor) in &named {
689        let data = tensor.to_f64_vec()?;
690        let dims: Vec<i64> = tensor.dims().iter().map(|&d| d as i64).collect();
691
692        let mut onnx_tensor = OnnxTensor::new(name);
693        onnx_tensor.data_type = dtype_to_onnx(tensor.dtype());
694        onnx_tensor.dims = dims;
695
696        match tensor.dtype() {
697            DType::F32 => {
698                onnx_tensor.float_data = data.iter().map(|&v| v as f32).collect();
699            }
700            DType::F64 => {
701                onnx_tensor.double_data = data;
702            }
703            _ => {
704                // Store as F32 for compatibility
705                onnx_tensor.data_type = ONNX_FLOAT;
706                onnx_tensor.float_data = data.iter().map(|&v| v as f32).collect();
707            }
708        }
709
710        model.initializers.push(onnx_tensor);
711    }
712
713    // Add a simple identity graph: input → output through the params
714    // (The actual computation graph would require op-level tracking)
715    model.outputs.push((
716        "output".to_string(),
717        vec![-1], // dynamic output shape
718        ONNX_FLOAT,
719    ));
720
721    model.save(path)
722}
723
724/// Export named tensors directly to ONNX format.
725///
726/// Lower-level API: saves a set of named tensors as ONNX initializers.
727pub fn export_tensors<P, B>(
728    path: P,
729    tensors: &[(String, Tensor<B>)],
730    model_name: &str,
731) -> Result<()>
732where
733    P: AsRef<Path>,
734    B: Backend,
735{
736    let mut model = OnnxModel::new(model_name);
737
738    for (name, tensor) in tensors {
739        let data = tensor.to_f64_vec()?;
740        let dims: Vec<i64> = tensor.dims().iter().map(|&d| d as i64).collect();
741
742        let mut onnx_tensor = OnnxTensor::new(name);
743        onnx_tensor.data_type = dtype_to_onnx(tensor.dtype());
744        onnx_tensor.dims = dims;
745
746        match tensor.dtype() {
747            DType::F32 => {
748                onnx_tensor.float_data = data.iter().map(|&v| v as f32).collect();
749            }
750            DType::F64 => {
751                onnx_tensor.double_data = data;
752            }
753            _ => {
754                onnx_tensor.data_type = ONNX_FLOAT;
755                onnx_tensor.float_data = data.iter().map(|&v| v as f32).collect();
756            }
757        }
758
759        model.initializers.push(onnx_tensor);
760    }
761
762    model.save(path)
763}
764
765// =============================================================================
766// Import API
767// =============================================================================
768
769/// Load tensor weights from an ONNX model file.
770///
771/// Returns a map of tensor name → Tensor for all initializers found.
772///
773/// # Example
774/// ```ignore
775/// let weights = load_onnx_weights::<CpuBackend>("model.onnx", &CpuDevice)?;
776/// for (name, tensor) in &weights {
777///     println!("{}: {:?}", name, tensor.dims());
778/// }
779/// ```
780pub fn load_onnx_weights<B: Backend>(
781    path: impl AsRef<Path>,
782    device: &B::Device,
783) -> Result<HashMap<String, Tensor<B>>> {
784    let bytes = fs::read(path.as_ref())
785        .map_err(|e| shrew_core::Error::msg(format!("failed to read ONNX file: {e}")))?;
786
787    load_onnx_weights_from_bytes::<B>(&bytes, device)
788}
789
790/// Load tensor weights from ONNX bytes (in-memory).
791pub fn load_onnx_weights_from_bytes<B: Backend>(
792    data: &[u8],
793    device: &B::Device,
794) -> Result<HashMap<String, Tensor<B>>> {
795    let mut dec = PbDecoder::new(data);
796    let mut result = HashMap::new();
797
798    // Parse ModelProto
799    while dec.remaining() > 0 {
800        let (field, wire_type) = dec.read_tag()?;
801        match (field, wire_type) {
802            (7, 2) => {
803                // GraphProto
804                let graph_bytes = dec.read_bytes()?;
805                let tensors = parse_graph_initializers::<B>(graph_bytes, device)?;
806                result.extend(tensors);
807            }
808            _ => {
809                dec.skip_field(wire_type)?;
810            }
811        }
812    }
813
814    Ok(result)
815}
816
817/// Parse graph initializer tensors from a GraphProto.
818fn parse_graph_initializers<B: Backend>(
819    data: &[u8],
820    device: &B::Device,
821) -> Result<HashMap<String, Tensor<B>>> {
822    let mut dec = PbDecoder::new(data);
823    let mut result = HashMap::new();
824
825    while dec.remaining() > 0 {
826        let (field, wire_type) = dec.read_tag()?;
827        match (field, wire_type) {
828            (5, 2) => {
829                // Initializer (TensorProto)
830                let tensor_bytes = dec.read_bytes()?;
831                let onnx_tensor = OnnxTensor::decode(tensor_bytes)?;
832
833                if !onnx_tensor.name.is_empty() {
834                    let dtype = onnx_to_dtype(onnx_tensor.data_type)?;
835                    let shape: Vec<usize> = onnx_tensor.dims.iter().map(|&d| d as usize).collect();
836                    let f64_data = onnx_tensor.to_f64_vec();
837
838                    if !f64_data.is_empty() {
839                        let tensor = Tensor::<B>::from_f64_slice(&f64_data, shape, dtype, device)?;
840                        result.insert(onnx_tensor.name.clone(), tensor);
841                    }
842                }
843            }
844            _ => {
845                dec.skip_field(wire_type)?;
846            }
847        }
848    }
849
850    Ok(result)
851}
852
853// =============================================================================
854// Graph Import — Full ONNX graph parsing
855// =============================================================================
856
/// A fully parsed ONNX graph: nodes + initializers + I/O metadata.
#[derive(Debug, Clone)]
pub struct OnnxGraph {
    /// Computation nodes in topological order.
    pub nodes: Vec<OnnxNode>,
    /// Initializer tensors (weights / constants), still in TensorProto form.
    pub initializer_protos: Vec<OnnxTensor>,
    /// Graph input names (including initializer names).
    pub input_names: Vec<String>,
    /// Graph output names.
    pub output_names: Vec<String>,
    /// Graph name.
    pub name: String,
}
871
872/// Decode an OnnxNode from protobuf bytes.
873fn decode_node(data: &[u8]) -> Result<OnnxNode> {
874    let mut dec = PbDecoder::new(data);
875    let mut node = OnnxNode {
876        inputs: Vec::new(),
877        outputs: Vec::new(),
878        op_type: String::new(),
879        name: String::new(),
880        attributes: HashMap::new(),
881    };
882    while dec.remaining() > 0 {
883        let (field, wire_type) = dec.read_tag()?;
884        match (field, wire_type) {
885            (1, 2) => node.inputs.push(dec.read_string()?),
886            (2, 2) => node.outputs.push(dec.read_string()?),
887            (3, 2) => node.name = dec.read_string()?,
888            (4, 2) => node.op_type = dec.read_string()?,
889            (5, 2) => {
890                let attr_bytes = dec.read_bytes()?;
891                let (key, val) = decode_attribute(attr_bytes)?;
892                node.attributes.insert(key, val);
893            }
894            _ => dec.skip_field(wire_type)?,
895        }
896    }
897    Ok(node)
898}
899
900/// Decode an OnnxAttribute from protobuf bytes.
901fn decode_attribute(data: &[u8]) -> Result<(String, OnnxAttribute)> {
902    let mut dec = PbDecoder::new(data);
903    let mut name = String::new();
904    let mut attr_type: u64 = 0;
905    let mut int_val: i64 = 0;
906    let mut float_val: f32 = 0.0;
907    let mut string_val = Vec::new();
908    let mut ints_val: Vec<i64> = Vec::new();
909    let mut floats_val: Vec<f32> = Vec::new();
910    while dec.remaining() > 0 {
911        let (field, wire_type) = dec.read_tag()?;
912        match (field, wire_type) {
913            (1, 2) => name = dec.read_string()?,           // name
914            (2, 0) => attr_type = dec.read_varint()?,      // type
915            (3, 0) => int_val = dec.read_varint()? as i64, // i
916            (4, 5) => {
917                // f (fixed32)
918                if dec.pos + 4 > dec.data.len() {
919                    return Err(shrew_core::Error::msg("attribute: unexpected end"));
920                }
921                let b = &dec.data[dec.pos..dec.pos + 4];
922                float_val = f32::from_le_bytes([b[0], b[1], b[2], b[3]]);
923                dec.pos += 4;
924            }
925            (5, 2) => string_val = dec.read_bytes()?.to_vec(), // s
926            (7, 5) => {
927                // floats (repeated fixed32)
928                if dec.pos + 4 > dec.data.len() {
929                    return Err(shrew_core::Error::msg("attribute: unexpected end"));
930                }
931                let b = &dec.data[dec.pos..dec.pos + 4];
932                floats_val.push(f32::from_le_bytes([b[0], b[1], b[2], b[3]]));
933                dec.pos += 4;
934            }
935            (7, 2) => {
936                // floats (packed)
937                let bytes = dec.read_bytes()?;
938                for c in bytes.chunks_exact(4) {
939                    floats_val.push(f32::from_le_bytes([c[0], c[1], c[2], c[3]]));
940                }
941            }
942            (8, 0) => ints_val.push(dec.read_varint()? as i64), // ints (repeated varint)
943            (8, 2) => {
944                // ints (packed)
945                let bytes = dec.read_bytes()?;
946                let mut sub = PbDecoder::new(bytes);
947                while sub.remaining() > 0 {
948                    ints_val.push(sub.read_varint()? as i64);
949                }
950            }
951            _ => dec.skip_field(wire_type)?,
952        }
953    }
954    let val = match attr_type {
955        1 => OnnxAttribute::Float(float_val),
956        2 => OnnxAttribute::Int(int_val),
957        3 => OnnxAttribute::String(String::from_utf8(string_val).unwrap_or_default()),
958        6 => OnnxAttribute::Floats(floats_val),
959        7 => OnnxAttribute::Ints(ints_val),
960        _ => OnnxAttribute::Int(int_val), // fallback
961    };
962    Ok((name, val))
963}
964
965/// Parse a full GraphProto: nodes, initializers, inputs, outputs.
966fn parse_graph_proto(data: &[u8]) -> Result<OnnxGraph> {
967    let mut dec = PbDecoder::new(data);
968    let mut graph = OnnxGraph {
969        nodes: Vec::new(),
970        initializer_protos: Vec::new(),
971        input_names: Vec::new(),
972        output_names: Vec::new(),
973        name: String::new(),
974    };
975    while dec.remaining() > 0 {
976        let (field, wire_type) = dec.read_tag()?;
977        match (field, wire_type) {
978            (1, 2) => {
979                let node_bytes = dec.read_bytes()?;
980                graph.nodes.push(decode_node(node_bytes)?);
981            }
982            (2, 2) => graph.name = dec.read_string()?,
983            (5, 2) => {
984                let tensor_bytes = dec.read_bytes()?;
985                graph
986                    .initializer_protos
987                    .push(OnnxTensor::decode(tensor_bytes)?);
988            }
989            (11, 2) => {
990                // input ValueInfoProto — extract name (field 1)
991                let vi_bytes = dec.read_bytes()?;
992                let name = extract_value_info_name(vi_bytes)?;
993                graph.input_names.push(name);
994            }
995            (12, 2) => {
996                // output ValueInfoProto
997                let vi_bytes = dec.read_bytes()?;
998                let name = extract_value_info_name(vi_bytes)?;
999                graph.output_names.push(name);
1000            }
1001            _ => dec.skip_field(wire_type)?,
1002        }
1003    }
1004    Ok(graph)
1005}
1006
1007/// Extract just the name from a ValueInfoProto.
1008fn extract_value_info_name(data: &[u8]) -> Result<String> {
1009    let mut dec = PbDecoder::new(data);
1010    while dec.remaining() > 0 {
1011        let (field, wire_type) = dec.read_tag()?;
1012        if field == 1 && wire_type == 2 {
1013            return dec.read_string();
1014        }
1015        dec.skip_field(wire_type)?;
1016    }
1017    Ok(String::new())
1018}
1019
1020/// Load a full ONNX graph (nodes + initializers) from a file.
1021pub fn load_onnx_graph(path: impl AsRef<Path>) -> Result<OnnxGraph> {
1022    let bytes = fs::read(path.as_ref())
1023        .map_err(|e| shrew_core::Error::msg(format!("failed to read ONNX file: {e}")))?;
1024    load_onnx_graph_from_bytes(&bytes)
1025}
1026
1027/// Load a full ONNX graph from in-memory bytes.
1028pub fn load_onnx_graph_from_bytes(data: &[u8]) -> Result<OnnxGraph> {
1029    let mut dec = PbDecoder::new(data);
1030    while dec.remaining() > 0 {
1031        let (field, wire_type) = dec.read_tag()?;
1032        if field == 7 && wire_type == 2 {
1033            let graph_bytes = dec.read_bytes()?;
1034            return parse_graph_proto(graph_bytes);
1035        }
1036        dec.skip_field(wire_type)?;
1037    }
1038    Err(shrew_core::Error::msg("ONNX file contains no graph"))
1039}
1040
1041// =============================================================================
1042// Graph Execution — Run an ONNX graph with Shrew tensors
1043// =============================================================================
1044
1045/// Execute an ONNX graph on the given backend.
1046///
1047/// Takes a parsed `OnnxGraph` and a map of input tensors. Initializer tensors
1048/// from the graph are materialised on the given device. Each node is executed
1049/// in order (the ONNX spec requires nodes in topological order).
1050///
1051/// Returns a map of output-name → Tensor for all graph outputs.
1052///
1053/// # Supported ops
1054///
1055/// `Add`, `Sub`, `Mul`, `Div`, `MatMul`, `Gemm`, `Relu`, `Sigmoid`, `Tanh`,
1056/// `Softmax`, `LogSoftmax`, `Reshape`, `Transpose`, `Flatten`, `Squeeze`,
1057/// `Unsqueeze`, `Concat`, `Identity`, `Neg`, `Sqrt`, `Exp`, `Log`, `Abs`,
1058/// `Clip`, `ReduceMean`, `ReduceSum`, `ReduceMax`, `ReduceMin`, `Gather`,
1059/// `BatchNormalization`, `Dropout`, `Shape`, `Cast`, `Pow`.
1060///
1061/// Unsupported ops produce an error.
1062pub fn run_onnx_graph<B: Backend>(
1063    graph: &OnnxGraph,
1064    inputs: &HashMap<String, Tensor<B>>,
1065    device: &B::Device,
1066) -> Result<HashMap<String, Tensor<B>>> {
1067    let mut env: HashMap<String, Tensor<B>> = HashMap::new();
1068
1069    // 1. Load initializers
1070    for init in &graph.initializer_protos {
1071        if init.name.is_empty() {
1072            continue;
1073        }
1074        let dtype = onnx_to_dtype(init.data_type)?;
1075        let shape: Vec<usize> = init.dims.iter().map(|&d| d as usize).collect();
1076        let f64_data = init.to_f64_vec();
1077        if !f64_data.is_empty() {
1078            let tensor = Tensor::<B>::from_f64_slice(&f64_data, shape, dtype, device)?;
1079            env.insert(init.name.clone(), tensor);
1080        }
1081    }
1082
1083    // 2. Insert user-provided inputs (overrides initializers if names clash)
1084    for (name, tensor) in inputs {
1085        env.insert(name.clone(), tensor.clone());
1086    }
1087
1088    // 3. Execute nodes in order
1089    for node in &graph.nodes {
1090        execute_node(node, &mut env, device)?;
1091    }
1092
1093    // 4. Collect outputs
1094    let mut outputs = HashMap::new();
1095    for name in &graph.output_names {
1096        if let Some(t) = env.get(name) {
1097            outputs.insert(name.clone(), t.clone());
1098        }
1099    }
1100    Ok(outputs)
1101}
1102
1103/// Helper: get a tensor from the environment by name.
1104fn get_tensor<'a, B: Backend>(
1105    env: &'a HashMap<String, Tensor<B>>,
1106    name: &str,
1107) -> Result<&'a Tensor<B>> {
1108    env.get(name)
1109        .ok_or_else(|| shrew_core::Error::msg(format!("ONNX runtime: tensor '{name}' not found")))
1110}
1111
1112/// Helper: get an integer attribute with a default.
1113fn attr_i(node: &OnnxNode, key: &str, default: i64) -> i64 {
1114    match node.attributes.get(key) {
1115        Some(OnnxAttribute::Int(v)) => *v,
1116        _ => default,
1117    }
1118}
1119
1120/// Helper: get an integer-list attribute.
1121fn attr_ints(node: &OnnxNode, key: &str) -> Vec<i64> {
1122    match node.attributes.get(key) {
1123        Some(OnnxAttribute::Ints(v)) => v.clone(),
1124        _ => Vec::new(),
1125    }
1126}
1127
1128/// Helper: get a float attribute with default.
1129fn attr_f(node: &OnnxNode, key: &str, default: f32) -> f32 {
1130    match node.attributes.get(key) {
1131        Some(OnnxAttribute::Float(v)) => *v,
1132        _ => default,
1133    }
1134}
1135
1136/// Execute a single ONNX node, inserting results into the environment.
1137fn execute_node<B: Backend>(
1138    node: &OnnxNode,
1139    env: &mut HashMap<String, Tensor<B>>,
1140    device: &B::Device,
1141) -> Result<()> {
1142    match node.op_type.as_str() {
1143        // ── Element-wise binary ──────────────────────────────────────────
1144        "Add" => {
1145            let a = get_tensor(env, &node.inputs[0])?;
1146            let b = get_tensor(env, &node.inputs[1])?;
1147            let out = a.add(b)?;
1148            env.insert(node.outputs[0].clone(), out);
1149        }
1150        "Sub" => {
1151            let a = get_tensor(env, &node.inputs[0])?;
1152            let b = get_tensor(env, &node.inputs[1])?;
1153            let out = a.sub(b)?;
1154            env.insert(node.outputs[0].clone(), out);
1155        }
1156        "Mul" => {
1157            let a = get_tensor(env, &node.inputs[0])?;
1158            let b = get_tensor(env, &node.inputs[1])?;
1159            let out = a.mul(b)?;
1160            env.insert(node.outputs[0].clone(), out);
1161        }
1162        "Div" => {
1163            let a = get_tensor(env, &node.inputs[0])?;
1164            let b = get_tensor(env, &node.inputs[1])?;
1165            let out = a.div(b)?;
1166            env.insert(node.outputs[0].clone(), out);
1167        }
1168        "Pow" => {
1169            let a = get_tensor(env, &node.inputs[0])?;
1170            // ONNX Pow has two inputs; exponent is second input
1171            let b = get_tensor(env, &node.inputs[1])?;
1172            let exp_val = b.to_f64_vec()?;
1173            if exp_val.len() == 1 {
1174                let out = a.powf(exp_val[0])?;
1175                env.insert(node.outputs[0].clone(), out);
1176            } else {
1177                return Err(shrew_core::Error::msg(
1178                    "ONNX Pow: only scalar exponent supported",
1179                ));
1180            }
1181        }
1182
1183        // ── MatMul / Gemm ────────────────────────────────────────────────
1184        "MatMul" => {
1185            let a = get_tensor(env, &node.inputs[0])?;
1186            let b = get_tensor(env, &node.inputs[1])?;
1187            let out = a.matmul(b)?;
1188            env.insert(node.outputs[0].clone(), out);
1189        }
1190        "Gemm" => {
1191            // Y = alpha * A' * B' + beta * C
1192            let alpha = attr_f(node, "alpha", 1.0) as f64;
1193            let beta = attr_f(node, "beta", 1.0) as f64;
1194            let trans_a = attr_i(node, "transA", 0) != 0;
1195            let trans_b = attr_i(node, "transB", 0) != 0;
1196
1197            let mut a = get_tensor(env, &node.inputs[0])?.clone();
1198            let mut b = get_tensor(env, &node.inputs[1])?.clone();
1199
1200            if trans_a {
1201                a = a.t()?;
1202            }
1203            if trans_b {
1204                b = b.t()?;
1205            }
1206
1207            let mut out = a.matmul(&b)?;
1208            if (alpha - 1.0).abs() > 1e-7 {
1209                out = out.affine(alpha, 0.0)?;
1210            }
1211            if node.inputs.len() > 2 && !node.inputs[2].is_empty() {
1212                let c = get_tensor(env, &node.inputs[2])?;
1213                if (beta - 1.0).abs() > 1e-7 {
1214                    let bc = c.affine(beta, 0.0)?;
1215                    out = out.add(&bc)?;
1216                } else {
1217                    out = out.add(c)?;
1218                }
1219            }
1220            env.insert(node.outputs[0].clone(), out);
1221        }
1222
1223        // ── Unary activations ────────────────────────────────────────────
1224        "Relu" => {
1225            let x = get_tensor(env, &node.inputs[0])?;
1226            env.insert(node.outputs[0].clone(), x.relu()?);
1227        }
1228        "Sigmoid" => {
1229            let x = get_tensor(env, &node.inputs[0])?;
1230            env.insert(node.outputs[0].clone(), x.sigmoid()?);
1231        }
1232        "Tanh" => {
1233            let x = get_tensor(env, &node.inputs[0])?;
1234            env.insert(node.outputs[0].clone(), x.tanh()?);
1235        }
1236        "Neg" => {
1237            let x = get_tensor(env, &node.inputs[0])?;
1238            env.insert(node.outputs[0].clone(), x.neg()?);
1239        }
1240        "Sqrt" => {
1241            let x = get_tensor(env, &node.inputs[0])?;
1242            env.insert(node.outputs[0].clone(), x.sqrt()?);
1243        }
1244        "Exp" => {
1245            let x = get_tensor(env, &node.inputs[0])?;
1246            env.insert(node.outputs[0].clone(), x.exp()?);
1247        }
1248        "Log" => {
1249            let x = get_tensor(env, &node.inputs[0])?;
1250            env.insert(node.outputs[0].clone(), x.log()?);
1251        }
1252        "Abs" => {
1253            let x = get_tensor(env, &node.inputs[0])?;
1254            env.insert(node.outputs[0].clone(), x.abs()?);
1255        }
1256
1257        // ── Softmax / LogSoftmax ─────────────────────────────────────────
1258        "Softmax" => {
1259            let x = get_tensor(env, &node.inputs[0])?;
1260            let axis = attr_i(node, "axis", -1);
1261            let dim = if axis < 0 {
1262                (x.rank() as i64 + axis) as usize
1263            } else {
1264                axis as usize
1265            };
1266            env.insert(node.outputs[0].clone(), x.softmax(dim)?);
1267        }
1268        "LogSoftmax" => {
1269            let x = get_tensor(env, &node.inputs[0])?;
1270            let axis = attr_i(node, "axis", -1);
1271            let dim = if axis < 0 {
1272                (x.rank() as i64 + axis) as usize
1273            } else {
1274                axis as usize
1275            };
1276            env.insert(node.outputs[0].clone(), x.log_softmax(dim)?);
1277        }
1278
1279        // ── Clip (clamp) ─────────────────────────────────────────────────
1280        "Clip" => {
1281            let x = get_tensor(env, &node.inputs[0])?;
1282            let min_val = if node.inputs.len() > 1 && !node.inputs[1].is_empty() {
1283                get_tensor(env, &node.inputs[1])?.to_f64_vec()?[0]
1284            } else {
1285                f64::NEG_INFINITY
1286            };
1287            let max_val = if node.inputs.len() > 2 && !node.inputs[2].is_empty() {
1288                get_tensor(env, &node.inputs[2])?.to_f64_vec()?[0]
1289            } else {
1290                f64::INFINITY
1291            };
1292            env.insert(node.outputs[0].clone(), x.clamp(min_val, max_val)?);
1293        }
1294
1295        // ── Shape manipulation ───────────────────────────────────────────
1296        "Reshape" => {
1297            let x = get_tensor(env, &node.inputs[0])?;
1298            let shape_tensor = get_tensor(env, &node.inputs[1])?;
1299            let shape_vals = shape_tensor.to_f64_vec()?;
1300
1301            // Resolve -1 dims
1302            let total = x.elem_count();
1303            let mut new_shape: Vec<usize> = shape_vals.iter().map(|&v| v as i64 as usize).collect();
1304            let neg_idx = new_shape.iter().position(|&s| s == usize::MAX); // -1 as usize wraps
1305            if let Some(idx) = neg_idx {
1306                let known: usize = new_shape
1307                    .iter()
1308                    .enumerate()
1309                    .filter(|&(i, _)| i != idx)
1310                    .map(|(_, &s)| s)
1311                    .product();
1312                if known > 0 {
1313                    new_shape[idx] = total / known;
1314                }
1315            }
1316            env.insert(node.outputs[0].clone(), x.reshape(new_shape)?);
1317        }
1318        "Transpose" => {
1319            let x = get_tensor(env, &node.inputs[0])?;
1320            let perm = attr_ints(node, "perm");
1321            if perm.is_empty() {
1322                // Default: reverse all dims
1323                let rank = x.rank();
1324                let rev: Vec<usize> = (0..rank).rev().collect();
1325                env.insert(node.outputs[0].clone(), x.permute(&rev)?);
1326            } else {
1327                let perm_usize: Vec<usize> = perm.iter().map(|&p| p as usize).collect();
1328                env.insert(node.outputs[0].clone(), x.permute(&perm_usize)?);
1329            }
1330        }
1331        "Flatten" => {
1332            let x = get_tensor(env, &node.inputs[0])?;
1333            let axis = attr_i(node, "axis", 1) as usize;
1334            env.insert(node.outputs[0].clone(), x.flatten(axis, x.rank() - 1)?);
1335        }
1336        "Squeeze" => {
1337            let x = get_tensor(env, &node.inputs[0])?;
1338            let axes = attr_ints(node, "axes");
1339            if axes.is_empty() {
1340                env.insert(node.outputs[0].clone(), x.squeeze_all());
1341            } else {
1342                let mut result = x.clone();
1343                // Squeeze from highest axis to lowest to avoid index shifting
1344                let mut sorted_axes: Vec<usize> = axes.iter().map(|&a| a as usize).collect();
1345                sorted_axes.sort_unstable();
1346                sorted_axes.reverse();
1347                for ax in sorted_axes {
1348                    result = result.squeeze(ax)?;
1349                }
1350                env.insert(node.outputs[0].clone(), result);
1351            }
1352        }
1353        "Unsqueeze" => {
1354            let x = get_tensor(env, &node.inputs[0])?;
1355            let axes = if node.inputs.len() > 1 && !node.inputs[1].is_empty() {
1356                // ONNX opset >= 13: axes is a tensor input
1357                let axes_t = get_tensor(env, &node.inputs[1])?;
1358                axes_t
1359                    .to_f64_vec()?
1360                    .iter()
1361                    .map(|&v| v as i64)
1362                    .collect::<Vec<_>>()
1363            } else {
1364                attr_ints(node, "axes")
1365            };
1366            let mut result = x.clone();
1367            let mut sorted_axes: Vec<usize> = axes
1368                .iter()
1369                .map(|&a| {
1370                    if a < 0 {
1371                        (result.rank() as i64 + a + 1) as usize
1372                    } else {
1373                        a as usize
1374                    }
1375                })
1376                .collect();
1377            sorted_axes.sort_unstable();
1378            for ax in sorted_axes {
1379                result = result.unsqueeze(ax)?;
1380            }
1381            env.insert(node.outputs[0].clone(), result);
1382        }
1383        "Concat" => {
1384            let axis = attr_i(node, "axis", 0) as usize;
1385            let tensors: Vec<Tensor<B>> = node
1386                .inputs
1387                .iter()
1388                .map(|n| get_tensor(env, n).cloned())
1389                .collect::<Result<Vec<_>>>()?;
1390            let refs: Vec<Tensor<B>> = tensors;
1391            let out = Tensor::<B>::cat(&refs, axis)?;
1392            env.insert(node.outputs[0].clone(), out);
1393        }
1394
1395        // ── Reductions ───────────────────────────────────────────────────
1396        "ReduceSum" => {
1397            let x = get_tensor(env, &node.inputs[0])?;
1398            let axes = attr_ints(node, "axes");
1399            let keepdims = attr_i(node, "keepdims", 1) != 0;
1400            let mut result = x.clone();
1401            if axes.is_empty() {
1402                result = result.sum_all()?;
1403            } else {
1404                let mut sorted: Vec<usize> = axes.iter().map(|&a| a as usize).collect();
1405                sorted.sort_unstable();
1406                sorted.reverse();
1407                for ax in sorted {
1408                    result = result.sum(ax, keepdims)?;
1409                }
1410            }
1411            env.insert(node.outputs[0].clone(), result);
1412        }
1413        "ReduceMean" => {
1414            let x = get_tensor(env, &node.inputs[0])?;
1415            let axes = attr_ints(node, "axes");
1416            let keepdims = attr_i(node, "keepdims", 1) != 0;
1417            let mut result = x.clone();
1418            if axes.is_empty() {
1419                result = result.mean_all()?;
1420            } else {
1421                let mut sorted: Vec<usize> = axes.iter().map(|&a| a as usize).collect();
1422                sorted.sort_unstable();
1423                sorted.reverse();
1424                for ax in sorted {
1425                    result = result.mean(ax, keepdims)?;
1426                }
1427            }
1428            env.insert(node.outputs[0].clone(), result);
1429        }
1430        "ReduceMax" => {
1431            let x = get_tensor(env, &node.inputs[0])?;
1432            let axes = attr_ints(node, "axes");
1433            let keepdims = attr_i(node, "keepdims", 1) != 0;
1434            let mut result = x.clone();
1435            let mut sorted: Vec<usize> = axes.iter().map(|&a| a as usize).collect();
1436            sorted.sort_unstable();
1437            sorted.reverse();
1438            for ax in sorted {
1439                result = result.max(ax, keepdims)?;
1440            }
1441            env.insert(node.outputs[0].clone(), result);
1442        }
1443        "ReduceMin" => {
1444            let x = get_tensor(env, &node.inputs[0])?;
1445            let axes = attr_ints(node, "axes");
1446            let keepdims = attr_i(node, "keepdims", 1) != 0;
1447            let mut result = x.clone();
1448            let mut sorted: Vec<usize> = axes.iter().map(|&a| a as usize).collect();
1449            sorted.sort_unstable();
1450            sorted.reverse();
1451            for ax in sorted {
1452                result = result.min(ax, keepdims)?;
1453            }
1454            env.insert(node.outputs[0].clone(), result);
1455        }
1456
1457        // ── Gather ───────────────────────────────────────────────────────
1458        "Gather" => {
1459            let x = get_tensor(env, &node.inputs[0])?;
1460            let indices = get_tensor(env, &node.inputs[1])?;
1461            let axis = attr_i(node, "axis", 0) as usize;
1462            env.insert(node.outputs[0].clone(), x.gather(axis, indices)?);
1463        }
1464
1465        // ── BatchNormalization ───────────────────────────────────────────
1466        "BatchNormalization" => {
1467            // inputs: X, scale, B, mean, var
1468            let x = get_tensor(env, &node.inputs[0])?;
1469            let scale = get_tensor(env, &node.inputs[1])?;
1470            let bias = get_tensor(env, &node.inputs[2])?;
1471            let mean = get_tensor(env, &node.inputs[3])?;
1472            let var = get_tensor(env, &node.inputs[4])?;
1473            let eps = attr_f(node, "epsilon", 1e-5) as f64;
1474
1475            // y = scale * (x - mean) / sqrt(var + eps) + bias
1476            // Broadcast: mean/var/scale/bias are 1-D [C], x is [N, C, ...]
1477            let x_sub = x.sub(mean)?;
1478            let std_inv = var.affine(1.0, eps)?.sqrt()?.reciprocal()?;
1479            let normed = x_sub.mul(&std_inv)?;
1480            let scaled = normed.mul(scale)?;
1481            let out = scaled.add(bias)?;
1482            env.insert(node.outputs[0].clone(), out);
1483        }
1484
1485        // ── Dropout (inference) ──────────────────────────────────────────
1486        "Dropout" => {
1487            // In inference mode, Dropout is identity
1488            let x = get_tensor(env, &node.inputs[0])?.clone();
1489            env.insert(node.outputs[0].clone(), x.clone());
1490            // Optional second output (mask) — insert copy
1491            if node.outputs.len() > 1 && !node.outputs[1].is_empty() {
1492                env.insert(node.outputs[1].clone(), x);
1493            }
1494        }
1495
1496        // ── Identity ─────────────────────────────────────────────────────
1497        "Identity" => {
1498            let x = get_tensor(env, &node.inputs[0])?;
1499            env.insert(node.outputs[0].clone(), x.clone());
1500        }
1501
1502        // ── Shape ────────────────────────────────────────────────────────
1503        "Shape" => {
1504            let x = get_tensor(env, &node.inputs[0])?;
1505            let shape: Vec<f64> = x.dims().iter().map(|&d| d as f64).collect();
1506            let n = shape.len();
1507            let out = Tensor::<B>::from_f64_slice(&shape, vec![n], DType::I64, device)?;
1508            env.insert(node.outputs[0].clone(), out);
1509        }
1510
1511        // ── Cast ─────────────────────────────────────────────────────────
1512        "Cast" => {
1513            let x = get_tensor(env, &node.inputs[0])?;
1514            let to = attr_i(node, "to", ONNX_FLOAT as i64) as i32;
1515            let target_dtype = onnx_to_dtype(to)?;
1516            env.insert(node.outputs[0].clone(), x.to_dtype(target_dtype)?);
1517        }
1518
1519        // ── Constant ─────────────────────────────────────────────────────
1520        "Constant" => {
1521            // Try to get value from attributes
1522            if let Some(OnnxAttribute::Float(v)) = node.attributes.get("value_float") {
1523                let out = Tensor::<B>::from_f64_slice(&[*v as f64], vec![1], DType::F32, device)?;
1524                env.insert(node.outputs[0].clone(), out);
1525            } else if let Some(OnnxAttribute::Int(v)) = node.attributes.get("value_int") {
1526                let out = Tensor::<B>::from_f64_slice(&[*v as f64], vec![1], DType::I64, device)?;
1527                env.insert(node.outputs[0].clone(), out);
1528            } else if let Some(OnnxAttribute::Floats(v)) = node.attributes.get("value_floats") {
1529                let data: Vec<f64> = v.iter().map(|f| *f as f64).collect();
1530                let n = data.len();
1531                let out = Tensor::<B>::from_f64_slice(&data, vec![n], DType::F32, device)?;
1532                env.insert(node.outputs[0].clone(), out);
1533            } else if let Some(OnnxAttribute::Ints(v)) = node.attributes.get("value_ints") {
1534                let data: Vec<f64> = v.iter().map(|i| *i as f64).collect();
1535                let n = data.len();
1536                let out = Tensor::<B>::from_f64_slice(&data, vec![n], DType::I64, device)?;
1537                env.insert(node.outputs[0].clone(), out);
1538            } else {
1539                return Err(shrew_core::Error::msg(format!(
1540                    "ONNX Constant: unsupported value attribute in node '{}'",
1541                    node.name
1542                )));
1543            }
1544        }
1545
1546        other => {
1547            return Err(shrew_core::Error::msg(format!(
1548                "ONNX runtime: unsupported op '{other}' (node '{}')",
1549                node.name
1550            )));
1551        }
1552    }
1553    Ok(())
1554}
1555
1556// =============================================================================
1557// Tests
1558// =============================================================================
1559
1560#[cfg(test)]
1561mod tests {
1562    use super::*;
1563    use shrew_cpu::{CpuBackend, CpuDevice};
1564
1565    type B = CpuBackend;
1566    type T = Tensor<B>;
1567    const DEV: CpuDevice = CpuDevice;
1568
1569    #[test]
1570    fn test_protobuf_varint_roundtrip() {
1571        let mut enc = PbEncoder::new();
1572        enc.write_varint(0);
1573        enc.write_varint(1);
1574        enc.write_varint(127);
1575        enc.write_varint(128);
1576        enc.write_varint(300);
1577        enc.write_varint(16384);
1578
1579        let mut dec = PbDecoder::new(&enc.buf);
1580        assert_eq!(dec.read_varint().unwrap(), 0);
1581        assert_eq!(dec.read_varint().unwrap(), 1);
1582        assert_eq!(dec.read_varint().unwrap(), 127);
1583        assert_eq!(dec.read_varint().unwrap(), 128);
1584        assert_eq!(dec.read_varint().unwrap(), 300);
1585        assert_eq!(dec.read_varint().unwrap(), 16384);
1586    }
1587
1588    #[test]
1589    fn test_onnx_tensor_encode_decode() {
1590        let mut tensor = OnnxTensor::new("test_weight");
1591        tensor.data_type = ONNX_FLOAT;
1592        tensor.dims = vec![2, 3];
1593        tensor.float_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
1594
1595        let encoded = tensor.encode();
1596        let decoded = OnnxTensor::decode(&encoded).unwrap();
1597
1598        assert_eq!(decoded.name, "test_weight");
1599        assert_eq!(decoded.data_type, ONNX_FLOAT);
1600        assert_eq!(decoded.dims, vec![2, 3]);
1601        // Data stored in raw_data (packed float), so check via to_f64_vec
1602        let data = decoded.to_f64_vec();
1603        assert_eq!(data.len(), 6);
1604        for (a, b) in tensor.float_data.iter().zip(data.iter()) {
1605            assert!((*a as f64 - b).abs() < 1e-6);
1606        }
1607    }
1608
1609    #[test]
1610    fn test_export_import_roundtrip() {
1611        let linear = shrew_nn::Linear::<B>::new(4, 3, true, DType::F32, &DEV).unwrap();
1612
1613        // Export
1614        let path = std::env::temp_dir().join("shrew_test_onnx.onnx");
1615        export_weights(&path, &linear, "test_model", &[1, 4]).unwrap();
1616
1617        // Import
1618        let weights = load_onnx_weights::<B>(&path, &DEV).unwrap();
1619
1620        // Should have weight and bias
1621        assert_eq!(weights.len(), 2);
1622
1623        // Verify weight shape
1624        let named = linear.named_parameters();
1625        for (name, original) in &named {
1626            let loaded = weights.get(name).expect(&format!("missing: {name}"));
1627            assert_eq!(original.dims(), loaded.dims());
1628
1629            // Values should match
1630            let orig_data = original.to_f64_vec().unwrap();
1631            let load_data = loaded.to_f64_vec().unwrap();
1632            for (a, b) in orig_data.iter().zip(load_data.iter()) {
1633                assert!((a - b).abs() < 1e-5, "mismatch for {name}: {a} vs {b}");
1634            }
1635        }
1636
1637        // Cleanup
1638        let _ = fs::remove_file(&path);
1639    }
1640
1641    #[test]
1642    fn test_export_tensors() {
1643        let t1 = T::randn(vec![2, 3], DType::F32, &DEV).unwrap();
1644        let t2 = T::ones(vec![5], DType::F32, &DEV).unwrap();
1645
1646        let tensors = vec![
1647            ("weight".to_string(), t1.clone()),
1648            ("bias".to_string(), t2.clone()),
1649        ];
1650
1651        let path = std::env::temp_dir().join("shrew_test_tensors.onnx");
1652        export_tensors(&path, &tensors, "tensors_model").unwrap();
1653
1654        let loaded = load_onnx_weights::<B>(&path, &DEV).unwrap();
1655        assert_eq!(loaded.len(), 2);
1656        assert_eq!(loaded.get("weight").unwrap().dims(), &[2, 3]);
1657        assert_eq!(loaded.get("bias").unwrap().dims(), &[5]);
1658
1659        let _ = fs::remove_file(&path);
1660    }
1661
1662    #[test]
1663    fn test_onnx_model_builder() {
1664        let mut model = OnnxModel::new("test_graph");
1665        model
1666            .inputs
1667            .push(("X".to_string(), vec![1, 784], ONNX_FLOAT));
1668        model
1669            .outputs
1670            .push(("Y".to_string(), vec![1, 10], ONNX_FLOAT));
1671
1672        model.nodes.push(OnnxNode {
1673            inputs: vec!["X".to_string(), "weight".to_string()],
1674            outputs: vec!["matmul_out".to_string()],
1675            op_type: "MatMul".to_string(),
1676            name: "matmul_0".to_string(),
1677            attributes: HashMap::new(),
1678        });
1679
1680        let mut attrs = HashMap::new();
1681        attrs.insert("axis".to_string(), OnnxAttribute::Int(1));
1682        model.nodes.push(OnnxNode {
1683            inputs: vec!["matmul_out".to_string()],
1684            outputs: vec!["Y".to_string()],
1685            op_type: "Softmax".to_string(),
1686            name: "softmax_0".to_string(),
1687            attributes: attrs,
1688        });
1689
1690        let bytes = model.to_bytes();
1691        assert!(!bytes.is_empty());
1692        assert!(bytes.len() > 20); // non-trivial size
1693    }
1694
1695    #[test]
1696    fn test_dtype_conversion() {
1697        assert_eq!(dtype_to_onnx(DType::F32), ONNX_FLOAT);
1698        assert_eq!(dtype_to_onnx(DType::F64), ONNX_DOUBLE);
1699        assert_eq!(dtype_to_onnx(DType::F16), ONNX_FLOAT16);
1700        assert_eq!(onnx_to_dtype(ONNX_FLOAT).unwrap(), DType::F32);
1701        assert_eq!(onnx_to_dtype(ONNX_DOUBLE).unwrap(), DType::F64);
1702        assert_eq!(onnx_to_dtype(ONNX_FLOAT16).unwrap(), DType::F16);
1703    }
1704
1705    #[test]
1706    fn test_double_data_roundtrip() {
1707        let t = T::from_f64_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2], DType::F64, &DEV).unwrap();
1708
1709        let tensors = vec![("w".to_string(), t.clone())];
1710        let path = std::env::temp_dir().join("shrew_test_f64.onnx");
1711        export_tensors(&path, &tensors, "f64_model").unwrap();
1712
1713        let loaded = load_onnx_weights::<B>(&path, &DEV).unwrap();
1714        let w = loaded.get("w").unwrap();
1715        assert_eq!(w.dims(), &[2, 2]);
1716
1717        let orig = t.to_f64_vec().unwrap();
1718        let load = w.to_f64_vec().unwrap();
1719        for (a, b) in orig.iter().zip(load.iter()) {
1720            assert!((a - b).abs() < 1e-10);
1721        }
1722
1723        let _ = fs::remove_file(&path);
1724    }
1725
1726    // ────────────────────────────────────────────────────────────────────
1727    //  Graph Import + Execution tests
1728    // ────────────────────────────────────────────────────────────────────
1729
1730    /// Helper: build a minimal ONNX model bytes from an OnnxModel.
1731    fn build_and_reload_graph(model: &OnnxModel) -> OnnxGraph {
1732        let bytes = model.to_bytes();
1733        load_onnx_graph_from_bytes(&bytes).unwrap()
1734    }
1735
1736    #[test]
1737    fn test_graph_add_two_inputs() {
1738        // Graph:  Y = A + B
1739        let mut model = OnnxModel::new("add_graph");
1740        model.inputs.push(("A".into(), vec![2, 2], ONNX_FLOAT));
1741        model.inputs.push(("B".into(), vec![2, 2], ONNX_FLOAT));
1742        model.outputs.push(("Y".into(), vec![2, 2], ONNX_FLOAT));
1743        model.nodes.push(OnnxNode {
1744            inputs: vec!["A".into(), "B".into()],
1745            outputs: vec!["Y".into()],
1746            op_type: "Add".into(),
1747            name: "add_0".into(),
1748            attributes: HashMap::new(),
1749        });
1750
1751        let graph = build_and_reload_graph(&model);
1752        assert_eq!(graph.nodes.len(), 1);
1753        assert_eq!(graph.nodes[0].op_type, "Add");
1754        assert_eq!(graph.output_names, vec!["Y"]);
1755
1756        let a = T::from_f64_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2], DType::F32, &DEV).unwrap();
1757        let b = T::from_f64_slice(&[10.0, 20.0, 30.0, 40.0], vec![2, 2], DType::F32, &DEV).unwrap();
1758
1759        let mut inputs = HashMap::new();
1760        inputs.insert("A".into(), a);
1761        inputs.insert("B".into(), b);
1762
1763        let outputs = run_onnx_graph::<B>(&graph, &inputs, &DEV).unwrap();
1764        let y = outputs.get("Y").unwrap();
1765        let data = y.to_f64_vec().unwrap();
1766        assert_eq!(data, vec![11.0, 22.0, 33.0, 44.0]);
1767    }
1768
1769    #[test]
1770    fn test_graph_linear_relu() {
1771        // Graph:  Z = Relu(X * W + B)
1772        //   matmul_out = MatMul(X, W)
1773        //   add_out    = Add(matmul_out, B)
1774        //   Z          = Relu(add_out)
1775        let mut model = OnnxModel::new("linear_relu");
1776        model.inputs.push(("X".into(), vec![1, 2], ONNX_FLOAT));
1777        model.outputs.push(("Z".into(), vec![1, 3], ONNX_FLOAT));
1778
1779        // W and B as initializers
1780        let mut w = OnnxTensor::new("W");
1781        w.data_type = ONNX_FLOAT;
1782        w.dims = vec![2, 3];
1783        w.float_data = vec![1.0, -1.0, 0.5, 0.0, 2.0, -0.5];
1784        model.initializers.push(w);
1785
1786        let mut b = OnnxTensor::new("B");
1787        b.data_type = ONNX_FLOAT;
1788        b.dims = vec![3];
1789        b.float_data = vec![0.0, 0.0, 0.0];
1790        model.initializers.push(b);
1791
1792        model.nodes.push(OnnxNode {
1793            inputs: vec!["X".into(), "W".into()],
1794            outputs: vec!["matmul_out".into()],
1795            op_type: "MatMul".into(),
1796            name: "matmul_0".into(),
1797            attributes: HashMap::new(),
1798        });
1799        model.nodes.push(OnnxNode {
1800            inputs: vec!["matmul_out".into(), "B".into()],
1801            outputs: vec!["add_out".into()],
1802            op_type: "Add".into(),
1803            name: "add_0".into(),
1804            attributes: HashMap::new(),
1805        });
1806        model.nodes.push(OnnxNode {
1807            inputs: vec!["add_out".into()],
1808            outputs: vec!["Z".into()],
1809            op_type: "Relu".into(),
1810            name: "relu_0".into(),
1811            attributes: HashMap::new(),
1812        });
1813
1814        let graph = build_and_reload_graph(&model);
1815        assert_eq!(graph.nodes.len(), 3);
1816
1817        // X = [[1.0, -1.0]]
1818        // X*W = [[1*1+(-1)*0, 1*(-1)+(-1)*2, 1*0.5+(-1)*(-0.5)]]
1819        //     = [[1.0, -3.0, 1.0]]
1820        // Relu => [[1.0, 0.0, 1.0]]
1821        let x = T::from_f64_slice(&[1.0, -1.0], vec![1, 2], DType::F32, &DEV).unwrap();
1822        let mut inputs = HashMap::new();
1823        inputs.insert("X".into(), x);
1824
1825        let outputs = run_onnx_graph::<B>(&graph, &inputs, &DEV).unwrap();
1826        let z = outputs.get("Z").unwrap();
1827        let data = z.to_f64_vec().unwrap();
1828        assert_eq!(data.len(), 3);
1829        assert!((data[0] - 1.0).abs() < 1e-5);
1830        assert!((data[1] - 0.0).abs() < 1e-5);
1831        assert!((data[2] - 1.0).abs() < 1e-5);
1832    }
1833
1834    #[test]
1835    fn test_graph_identity_and_dropout() {
1836        let mut model = OnnxModel::new("id_drop");
1837        model.inputs.push(("X".into(), vec![3], ONNX_FLOAT));
1838        model.outputs.push(("Y".into(), vec![3], ONNX_FLOAT));
1839
1840        model.nodes.push(OnnxNode {
1841            inputs: vec!["X".into()],
1842            outputs: vec!["id_out".into()],
1843            op_type: "Identity".into(),
1844            name: "id_0".into(),
1845            attributes: HashMap::new(),
1846        });
1847        model.nodes.push(OnnxNode {
1848            inputs: vec!["id_out".into()],
1849            outputs: vec!["Y".into()],
1850            op_type: "Dropout".into(),
1851            name: "drop_0".into(),
1852            attributes: HashMap::new(),
1853        });
1854
1855        let graph = build_and_reload_graph(&model);
1856        let x = T::from_f64_slice(&[5.0, -3.0, 7.0], vec![3], DType::F32, &DEV).unwrap();
1857        let mut inputs = HashMap::new();
1858        inputs.insert("X".into(), x.clone());
1859
1860        let outputs = run_onnx_graph::<B>(&graph, &inputs, &DEV).unwrap();
1861        let y = outputs.get("Y").unwrap();
1862        assert_eq!(y.to_f64_vec().unwrap(), x.to_f64_vec().unwrap());
1863    }
1864
1865    #[test]
1866    fn test_graph_file_roundtrip() {
1867        // Build, save, load from file, execute
1868        let mut model = OnnxModel::new("file_rt");
1869        model.inputs.push(("X".into(), vec![2], ONNX_FLOAT));
1870        model.outputs.push(("Y".into(), vec![2], ONNX_FLOAT));
1871
1872        model.nodes.push(OnnxNode {
1873            inputs: vec!["X".into()],
1874            outputs: vec!["Y".into()],
1875            op_type: "Sigmoid".into(),
1876            name: "sig_0".into(),
1877            attributes: HashMap::new(),
1878        });
1879
1880        let path = std::env::temp_dir().join("shrew_test_graph_rt.onnx");
1881        model.save(&path).unwrap();
1882
1883        let graph = load_onnx_graph(&path).unwrap();
1884        assert_eq!(graph.nodes.len(), 1);
1885
1886        let x = T::from_f64_slice(&[0.0, 1000.0], vec![2], DType::F32, &DEV).unwrap();
1887        let mut inputs = HashMap::new();
1888        inputs.insert("X".into(), x);
1889
1890        let outputs = run_onnx_graph::<B>(&graph, &inputs, &DEV).unwrap();
1891        let data = outputs.get("Y").unwrap().to_f64_vec().unwrap();
1892        assert!((data[0] - 0.5).abs() < 1e-5); // sigmoid(0) = 0.5
1893        assert!((data[1] - 1.0).abs() < 1e-3); // sigmoid(1000) ≈ 1.0
1894
1895        let _ = fs::remove_file(&path);
1896    }
1897
1898    #[test]
1899    fn test_decode_attribute_roundtrip() {
1900        // Encode an Int attribute, decode it back
1901        let attr = OnnxAttribute::Int(42);
1902        let encoded = encode_attribute("axis", &attr);
1903        let (name, decoded) = decode_attribute(&encoded.buf).unwrap();
1904        assert_eq!(name, "axis");
1905        match decoded {
1906            OnnxAttribute::Int(v) => assert_eq!(v, 42),
1907            _ => panic!("expected Int"),
1908        }
1909    }
1910}