ONNX to Burn: Development Guide

This guide offers in-depth design insights and step-by-step procedures for developers working on the ONNX to Burn conversion tool, which imports ONNX models into the Burn deep learning framework written in Rust. It converts ONNX models to Rust source code and model weights to .burnpack files.

For an introduction to ONNX import in Burn, see the ONNX import section of the Burn book.

Design Overview

Design Goals

  • Perform best-effort conversion of ONNX models to Rust source code via Burn APIs.
  • Convert ONNX model weights to Burn state files.
  • Support ONNX models generated by PyTorch (ONNX Opset 16+ recommended for best compatibility).
  • Produce easy-to-understand and modifiable models.
  • Ensure the generated models are trainable using Burn APIs.

Design Decisions

Core Principles:

  • Op/Node-Centric Design: Built around individual operations and nodes for better scalability as more operators are added
  • Opset-Aware Processing: Processors accept opset parameters for flexible behavior across different ONNX versions
  • Constants-First Approach: All ONNX initializers are treated as constant nodes initially, providing a uniform starting point
  • Native Type Integration: Direct use of burn_tensor::TensorData and Dtype for efficiency, consistency, and future mmap support
  • Multi-Phase Pipeline: Explicit transformation phases (initialization → conversion → type inference → post-processing → finalization) for better visibility and maintainability
  • Graph Input Name Preservation: Sanitized ONNX names are preserved for easier development and troubleshooting

Separation of Concerns:

  • Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process
  • Ensure operator behavior consistency across different OpSet versions
  • Exclude any ONNX/Protobuf-specific logic from the Burn graph
  • Feature Support Validation: The onnx-ir crate should extract and preserve all ONNX attributes faithfully, even if Burn does not yet support them. Rejection of unsupported features should happen in burn-import during code generation, not in onnx-ir during configuration extraction. This allows onnx-ir to be reused by other projects that may have different feature support

The conversion process involves three main stages:

  1. Convert ONNX model to Intermediate Representation (IR) via 5-phase pipeline.
  2. Translate IR to a Burn graph.
  3. Generate Rust source code from the Burn graph.
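
For context, downstream users typically drive this whole pipeline from a build script using burn-import's ModelGen builder, as described in the Burn book section linked above. A minimal sketch, with illustrative paths, since the builder API may evolve:

// build.rs — hedged sketch of invoking the converter from a build script.
use burn_import::onnx::ModelGen;

fn main() {
    // All three stages run under the hood:
    // ONNX -> IR -> Burn graph -> Rust source + .burnpack weights.
    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        .run_from_script();
}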

Adding New Operators

To extend burn-import with support for new ONNX operators, follow these steps:

  1. Create PyTorch Script: Place a PyTorch script using the new operator under crates/burn-import/onnx-tests/tests/<op>/<op>.py. Make sure to print both input and output tensors for end-to-end testing.

  2. Generate ONNX Model: Run the PyTorch script to produce an ONNX model.

  3. Visualize ONNX Model: Use Netron to verify the ONNX model contains the expected operators.

  4. Generate IR and Burn Graph: Navigate to crates/burn-import/ and run:

    cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out
    
  5. Implement Missing Operators: If you encounter an error stating that an operator is unsupported, implement it. The ./out/my-model.graph.txt should provide relevant information.

  6. Inspect Generated Files: The my-model.graph.txt contains IR details, my-model.rs holds the Burn model in Rust code, and my-model.burnpack contains the model weights.

  7. Integration Test: Include the test in the tests/<op_name>/mod.rs file in the crates/burn-import/onnx-tests/tests/ directory. Further details can be found in the onnx-tests README.
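
Building on step 7, a typical end-to-end test includes the generated model and checks its output against the values printed by the PyTorch script. A rough sketch, where the include_models! macro and Model constructor follow the conventions used elsewhere in onnx-tests, and the names, shapes, and values are purely illustrative:

// tests/squeeze/mod.rs — illustrative sketch; names, shapes, and values are placeholders.
use crate::include_models;
include_models!(squeeze);

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Tensor;

    type Backend = burn_ndarray::NdArray<f32>;

    #[test]
    fn squeeze_matches_pytorch_output() {
        let device = Default::default();
        let model: squeeze::Model<Backend> = squeeze::Model::new(&device);

        // Mirror the input tensor printed by the PyTorch script.
        let input = Tensor::<Backend, 3>::ones([1, 4, 1], &device);
        let output = model.forward(input);

        // A real test compares full values via output.to_data().assert_approx_eq(...);
        // checking the squeezed shape keeps this sketch short.
        assert_eq!(output.dims(), [4, 1]);
    }
}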

Implementing a New Operator

To extend the capabilities of the Burn library by supporting new operations imported from ONNX graphs, developers follow a few systematic steps. Here, we detail the process, using the implementation of the Squeeze operation to illustrate specific points as needed. All file/directory paths are relative to the root of the burn repository.

Step 1: Node Processor Implementation in onnx-ir

The onnx-ir crate handles the Intermediate Representation (IR) of ONNX models using a processor-based architecture. For each operation:

  1. Create a node module in crates/onnx-ir/src/node/<operation_name>.rs. This file should contain:

    • Configuration struct: Define operation-specific parameters (e.g., SqueezeConfig)
    • Processor struct: Implement NodeProcessor trait (marked as pub(crate))
    • The processor handles:
      • Input/output specification: Define expected inputs and outputs via NodeSpec
      • Type inference: Infer output types from inputs and configuration
      • Configuration extraction: Extract operation parameters from ONNX attributes
      • Node construction: Build the final Node enum variant with config
  2. Make the module visible in crates/onnx-ir/src/node/mod.rs:

    pub mod squeeze;
  3. Create a node struct in your module file (e.g., squeeze.rs) with the standard fields:

    use onnx_ir_derive::NodeBuilder;

    #[derive(Debug, Clone, NodeBuilder)]
    pub struct SqueezeNode {
        pub name: String,
        pub inputs: Vec<Argument>,
        pub outputs: Vec<Argument>,
        pub config: SqueezeConfig,
    }

    The NodeBuilder derive macro generates a test builder (e.g., SqueezeNodeBuilder) with methods for constructing nodes in tests.

  4. Register the node in crates/onnx-ir/src/ir/node.rs by adding a mapping to the define_node_enum! macro:

    define_node_enum! {
        // ... other variants
        Squeeze => squeeze::SqueezeNode,
        // ... more variants
    }

    This single macro invocation generates both the NodeType enum (for parsing) and the Node enum (with tuple variants wrapping node structs) from a single source of truth.

  5. Register your processor in crates/onnx-ir/src/registry.rs by adding it to the with_standard_processors() function:

    registry.register("Squeeze", Box::new(squeeze::SqueezeProcessor));

For example, the squeeze operation in crates/onnx-ir/src/node/squeeze.rs contains:

  • A SqueezeConfig struct with operation parameters (axes)
  • A SqueezeProcessor struct (marked pub(crate)) that implements NodeProcessor
  • The node_spec() method defines input/output requirements
  • The process() method extracts config and constructs the Node::Squeeze variant
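
For orientation, the configuration half of that file might look roughly like this; the sketch is inferred from the fields used later in this guide, and the actual definition lives in crates/onnx-ir/src/node/squeeze.rs:

/// Parameters extracted from an ONNX Squeeze node.
/// `None` means "squeeze all size-1 dimensions".
#[derive(Debug, Clone)]
pub struct SqueezeConfig {
    pub axes: Option<Vec<i64>>,
}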

Step 2: Code Generation in burn-import

  1. Create a new file named <operation_name>.rs in the crates/burn-import/src/burn/node/ directory. This file implements code generation for your operation by implementing the NodeCodegen trait directly on the onnx-ir node type.

  2. Implement the NodeCodegen<PS> trait for the onnx-ir node type. This trait defines how the node generates Rust code during the graph compilation process:

    use super::prelude::*;

    impl<PS: PrecisionSettings> NodeCodegen<PS> for onnx_ir::squeeze::SqueezeNode {
        fn inputs(&self) -> &[Argument] {
            &self.inputs
        }

        fn outputs(&self) -> &[Argument] {
            &self.outputs
        }

        fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
            let input_arg = self.inputs.first().unwrap();
            let output_arg = self.outputs.first().unwrap();

            // Use scope.arg() to handle Tensor/Scalar/Shape arguments automatically
            let input = scope.arg(input_arg);
            let output = arg_to_ident(output_arg);

            // Access node configuration
            match &self.config.axes {
                Some(axes) => {
                    let axes_values: Vec<_> = axes.iter().map(|&i| {
                        proc_macro2::Literal::i64_suffixed(i)
                    }).collect();
                    quote! {
                        let #output = #input.squeeze_dims(&[#(#axes_values),*]);
                    }
                }
                None => {
                    // Get output rank from type inference
                    let output_rank = match &output_arg.ty {
                        ArgType::Tensor(t) => t.rank,
                        _ => panic!("Expected tensor output"),
                    };
                    quote! {
                        let #output = #input.squeeze::<#output_rank>();
                    }
                }
            }
        }
    }

    Key methods to implement:

    • inputs(&self) - Returns references to input arguments (usually just &self.inputs)
    • outputs(&self) - Returns references to output arguments (usually just &self.outputs)
    • forward(&self, scope) - Generates Rust code for the operation using the quote! macro
    • field(&self) - (Optional) Declares module fields for parameters like weights
    • collect_snapshots(&self, field_name) - (Optional) Collects tensor snapshots for burnpack serialization
  3. Use helper utilities from argument_helpers.rs:

    • scope.arg(argument) - Automatically handles Tensor/Scalar/Shape with proper cloning
    • arg_to_ident(argument) - Converts argument to identifier for code generation
  4. Add unit tests using snapshot testing to verify the generated code. These tests typically use the insta crate and test helper functions to validate the generated code:

    #[cfg(test)]
    mod tests {
        use super::super::test_helpers::*;
        use insta::assert_snapshot;
        use onnx_ir::squeeze::SqueezeNodeBuilder;

        #[test]
        fn test_squeeze_forward() {
            let node = SqueezeNodeBuilder::new("squeeze1")
                .input_tensor("input", 3, DType::F32)
                .output_tensor("output", 2, DType::F32)
                .axes(vec![1])
                .build();
            let code = codegen_forward_default(&node);
            assert_snapshot!(code, @"let output = input.squeeze_dims(&[1i64]);");
        }
    }

Step 3: Register in Module System

Add the module declaration to crates/burn-import/src/burn/node/mod.rs:

// ... other node modules
pub(crate) mod squeeze;
// ... more node modules

The modules are automatically made visible through re-exports in the same file.

Step 4: Register in Code Generation Dispatch

Add your operation to the dispatch macro in crates/burn-import/src/burn/node_codegen.rs. The impl_node_codegen_dispatch! macro generates the trait implementation that dispatches to your node-specific code.

Add the node variant name (as defined in onnx-ir's Node enum) to the macro invocation:

impl_node_codegen_dispatch! {
    // ... other operations
    Squeeze,  // Add your operation here (matches Node::Squeeze variant)
    // ... more operations
}

The macro automatically generates:

  • Dispatch implementation for NodeCodegen<PS> on onnx_ir::Node
  • All required trait methods (inputs, outputs, forward, field, etc.)
  • Pattern matching to route to your node-specific implementation

Step 5: Processor Implementation

The NodeProcessor trait defines how operations are processed in onnx-ir. Each processor must implement:

  1. Associated type: type Config - Define your configuration struct (use () if no config)
  2. infer_types() - Infer output types from inputs and config (required)
  3. build_node() - Construct the node struct and wrap it in the Node enum variant (required)
  4. extract_config() - Extract config from attributes/inputs (override if Config != ())
  5. spec() - Define opset and input/output requirements (optional)
  6. lift_constants() - Request constant lifting for inputs (optional)

Example build_node() implementation:

fn build_node(&self, builder: RawNode, opset: usize) -> Node {
    let config = self.extract_config(&builder, opset).expect("Config extraction failed");
    Node::Squeeze(SqueezeNode {
        name: builder.name,
        inputs: builder.inputs,
        outputs: builder.outputs,
        config,
    })
}

Note: RawNode is the intermediate node representation used during processing. The build_node() method converts it into the final typed Node enum variant.
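
To complement the example above, here is a hedged sketch of the configuration extraction that the Constant Lifting section below refers to. The ProcessError name and the exact return type of Argument::value() are assumptions for illustration:

// Hedged sketch of extract_config() for Squeeze; error type name is assumed.
fn extract_config(&self, node: &RawNode, _opset: usize) -> Result<SqueezeConfig, ProcessError> {
    // Since opset 13, Squeeze receives `axes` as an optional constant input
    // rather than an attribute. Constant lifting makes that value available
    // directly on the input via `Argument::value()`.
    let axes = node
        .inputs
        .get(1)
        .and_then(|arg| arg.value())
        .map(|data| data.to_vec::<i64>().expect("axes must be an i64 tensor"));

    Ok(SqueezeConfig { axes })
}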

For complete examples, see existing processors:

  • Simple operation: crates/onnx-ir/src/node/softmax.rs
  • With constant inputs: crates/onnx-ir/src/node/squeeze.rs
  • Complex operation: crates/onnx-ir/src/node/conv2d.rs

See NodeProcessor Trait for the complete trait definition.

Step 6: Add Newly Supported Op!

As a reward, add a checkmark for the newly supported op to crates/burn-import/SUPPORTED-ONNX-OPS.md!

Constant Lifting

The onnx-ir pipeline automatically handles constant lifting during the post-processing phase. "Lifting" constants means making constant values directly accessible on node inputs via Argument::value(), instead of requiring a separate graph traversal to find a Constant node.

When to use: If your operation takes constant inputs (e.g., weights in Conv1d, shape tensors in Reshape, axes in Squeeze), access them via node.inputs[N].value() in your extract_config() method. See the Configuration Extraction example in Step 5.

Optional optimization: Implement lift_constants() to explicitly request constant lifting for specific inputs before extract_config() is called. The pipeline handles this automatically during post-processing.

Architecture Overview

ONNX-IR Pipeline

The onnx-ir crate converts ONNX models to an Intermediate Representation through a 5-phase pipeline:

Phase 1: Initialization

  • Creates GraphState from ONNX proto structures
  • Constants-first approach: Converts all ONNX initializers into Constant nodes, providing a uniform starting point for processing
  • Sets up the value store for tensor data using burn_tensor::TensorData
  • Preserves sanitized graph input names for debugging

Phase 2: Node Conversion

  • Converts ONNX nodes to IR nodes using registered processors
  • Creates RawNode instances from ONNX proto nodes (intermediate representation)
  • Processors extract configuration and construct typed Node enum variants
  • Handles constant nodes specially (extracting values from attributes into tensor store)
  • Each processor is responsible for its own type inference and node construction

Phase 3: Type Inference

  • Type inference happens within each processor's process() method during Phase 2
  • Processors infer output types based on input types and configuration
  • Multi-pass processing handles dependencies between nodes
  • The pipeline may need multiple iterations for complex type dependencies (e.g., control flow)

Phase 4: Post-processing

  • Lifts constants: Makes constant values accessible on downstream node inputs
  • Eliminates Identity nodes: Removes no-op nodes and rewires the graph
  • Re-runs constant lifting after Identity elimination

Phase 5: Finalization

  • Removes unreferenced constant nodes
  • Constructs the final OnnxGraph with inputs, outputs, and nodes

NodeProcessor Trait

The NodeProcessor trait (defined in crates/onnx-ir/src/processor.rs) is the core abstraction for handling ONNX operations. Each processor implements:

Required:

  • type Config - Associated type for configuration (use () if no config needed)
  • infer_types() - Infer output types from inputs and configuration
  • build_node() - Construct the final Node enum variant

Optional (have defaults):

  • spec() - Define opset requirements and input/output count validation (NodeSpec, InputSpec, OutputSpec)
  • extract_config() - Extract configuration from attributes/inputs (default returns Default::default())
  • lift_constants() - Request constant lifting for specific inputs (default does nothing)
  • input_preferences() - Declare preferred input types from producers (default returns None)

Design principles: Each processor is self-contained, handling type inference, config extraction, and node construction. Processors return strongly-typed Node enum variants, ensuring type safety throughout the pipeline.
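
Putting these pieces together, the trait has roughly the following shape. This is a compilable sketch with stub types, and the signatures beyond build_node() are assumptions; consult crates/onnx-ir/src/processor.rs for the authoritative definition:

// Stub types so the sketch compiles in isolation; the real ones live in onnx-ir.
pub struct RawNode;
pub enum Node {}
pub struct ProcessError;

pub trait NodeProcessor {
    /// Operation-specific configuration; use `()` when none is needed.
    type Config: Default;

    /// Required: infer output types from input types and configuration.
    fn infer_types(&self, node: &mut RawNode, opset: usize) -> Result<(), ProcessError>;

    /// Required: construct the final typed `Node` enum variant.
    fn build_node(&self, builder: RawNode, opset: usize) -> Node;

    /// Optional: extract configuration from attributes/inputs.
    fn extract_config(&self, _node: &RawNode, _opset: usize) -> Result<Self::Config, ProcessError> {
        Ok(Self::Config::default())
    }

    /// Optional: request constant lifting for specific inputs.
    fn lift_constants(&self, _node: &mut RawNode) {}
}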

Testing

When implementing a new operator, there are several levels of testing to consider:

Unit Testing

  • Processor Methods: Write unit tests in crates/onnx-ir/src/node/<operation_name>.rs to verify:

    • extract_config() - Correctly extracts configuration from attributes and inputs
    • infer_types() - Correctly infers output types (element type, rank, static shapes)
    • build_node() - Constructs correct Node enum variant
    • spec() - Defines correct opset and input/output requirements
    • Error handling for invalid inputs or configurations

    See existing tests in crates/onnx-ir/src/node/squeeze.rs for examples.

  • Code Generation: Test the burn-import NodeCodegen implementation to verify correct Rust code generation. Each node file typically includes unit tests that validate the generated code against expected output; see the insta-based snapshot example in Step 2.

Integration Testing

  • Test Path: Write integration tests in crates/burn-import/onnx-tests/tests/<op_name>/mod.rs where <op_name> is the name of the new operator.

  • What to Test:

    • Create ONNX models that use your operator and test the end-to-end conversion process
    • Ensure the generated Rust code compiles
    • Test with realistic ONNX models that use your operator in conjunction with others
    • Include models that test edge cases (e.g., different input shapes, parameter combinations)
    • Verify that inputs and outputs match between the original ONNX model and the converted Burn model
  • Further details can be found in the onnx-tests README.

Testing the processor implementation is particularly important as it directly affects the correctness of the conversion process. Incorrect type inference can lead to mismatched tensor shapes or wrong element types, while incorrect configuration extraction can cause runtime errors or produce incorrect results.

Node Enum Architecture

The ONNX-IR uses an enum-based node representation where each ONNX operation is a variant of the Node enum (defined in crates/onnx-ir/src/ir/node.rs). Each variant wraps an operation-specific node struct (e.g., SoftmaxNode, Conv2dNode) that contains name, inputs, outputs, and optionally a config field.

The define_node_enum! macro generates both enums from a single source using the syntax VariantName => module::NodeStructType:

define_node_enum! {
    Softmax => softmax::SoftmaxNode,
    Conv2d => conv2d::Conv2dNode,
    Squeeze => squeeze::SqueezeNode,
    // ... 200+ more variants
}

This macro generates:

  1. NodeType enum: Simple unit variants for ONNX parsing (Softmax, Conv2d, etc.)
  2. Node enum: Tuple variants wrapping node structs (Softmax(SoftmaxNode), Conv2d(Conv2dNode), etc.)
  3. Accessor methods: name(), inputs(), outputs() automatically generated for the Node enum

This design provides:

  • Type safety: Each operation has its own struct type
  • Trait implementations: Operations can implement specific traits on their node structs
  • Single source of truth: Both enums are guaranteed to stay in sync
  • Pattern matching: Easy to match on specific operations and access their configuration
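
For example, downstream code can match on a variant to reach its typed configuration, while the generated accessors work uniformly across all variants. A short sketch using the accessors listed above:

// Sketch: inspecting an IR node; accessor names per the generated methods above.
fn describe(node: &onnx_ir::Node) {
    match node {
        onnx_ir::Node::Squeeze(squeeze) => {
            // Typed access to the operation-specific config.
            println!("{} squeezes axes {:?}", node.name(), squeeze.config.axes);
        }
        other => println!("{} ({} inputs)", other.name(), other.inputs().len()),
    }
}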

Resources

  1. PyTorch to ONNX
  2. ONNX to PyTorch
  3. ONNX Introduction
  4. ONNX Operators
  5. ONNX Protos
  6. ONNX Optimizer
  7. Netron