ONNX to Burn: Development Guide
This guide offers in-depth design insights and step-by-step procedures for developers working on the
ONNX to Burn conversion tool, which imports ONNX models into the Burn deep learning framework
written in Rust. The tool converts ONNX models to Rust source code and model weights to
`.burnpack` files.
For an introduction to ONNX import in Burn, see this section of the Burn book.
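In a downstream project, the conversion is typically driven from a build script via `burn_import::onnx::ModelGen`, as shown in the Burn book. A minimal `build.rs` along those lines (the ONNX path is a placeholder):

```rust
use burn_import::onnx::ModelGen;

fn main() {
    // Generates the Rust model source and weights file at build time.
    ModelGen::new()
        .input("src/model/my_model.onnx") // placeholder path
        .out_dir("model/")
        .run_from_script();
}
```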
Design Overview
Design Goals
- Perform best-effort conversion of ONNX models to Rust source code via Burn APIs.
- Convert ONNX model weights to Burn state files.
- Support ONNX models generated by PyTorch (ONNX Opset 16+ recommended for best compatibility).
- Produce easy-to-understand and modifiable models.
- Ensure the generated models are trainable using Burn APIs.
Design Decisions
Core Principles:
- Op/Node-Centric Design: Built around individual operations and nodes for better scalability as more operators are added
- Opset-Aware Processing: Processors accept opset parameters for flexible behavior across different ONNX versions
- Constants-First Approach: All ONNX initializers are treated as constant nodes initially, providing a uniform starting point
- Native Type Integration: Direct use of `burn_tensor::TensorData` and `DType` for efficiency, consistency, and future mmap support
- Multi-Phase Pipeline: Explicit transformation phases (initialization → conversion → type inference → post-processing → finalization) for better visibility and maintainability
- Graph Input Name Preservation: Sanitized ONNX names are preserved for easier development and troubleshooting
Separation of Concerns:
- Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process
- Ensure operator behavior consistency across different OpSet versions
- Exclude any ONNX/Protobuf-specific logic from the Burn graph
- Feature Support Validation: The `onnx-ir` crate should extract and preserve all ONNX attributes faithfully, even if Burn does not yet support them. Rejection of unsupported features should happen in `burn-import` during code generation, not in `onnx-ir` during configuration extraction. This allows `onnx-ir` to be reused by other projects that may have different feature support
The conversion process involves three main stages:
- Convert the ONNX model to an Intermediate Representation (IR) via a 5-phase pipeline.
- Translate IR to a Burn graph.
- Generate Rust source code from the Burn graph.
Adding New Operators
To extend burn-import with support for new ONNX operators, follow these steps:
1. Create PyTorch Script: Place a PyTorch script using the new operator under
   `crates/burn-import/onnx-tests/tests/<op>/<op>.py`. Make sure to print both input and output
   tensors for end-to-end testing.

2. Generate ONNX Model: Run the PyTorch script to produce an ONNX model.

3. Visualize ONNX Model: Use Netron to verify the ONNX model contains the expected operators.

4. Generate IR and Burn Graph: Navigate to `crates/burn-import/` and run:

   ```
   cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out
   ```

5. Implement Missing Operators: If you encounter an error stating that an operator is
   unsupported, implement it. The `./out/my-model.graph.txt` file should provide relevant
   information.

6. Inspect Generated Files: `my-model.graph.txt` contains IR details, `my-model.rs` holds the
   Burn model in Rust code, and `my-model.burnpack` contains the model weights.

7. Integration Test: Include the test in the `tests/<op_name>/mod.rs` file in the
   `crates/burn-import/onnx-tests/tests/` directory. Further details can be found in the
   onnx-tests README.
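For orientation, the integration test module from step 7 often looks roughly like the sketch below. The `include_models!` macro, the `TestBackend` alias, the model name, and the expected dimensions are assumptions for illustration; see the onnx-tests README for the authoritative pattern.

```rust
use crate::include_models;
include_models!(my_op); // hypothetical model generated from my_op.onnx

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::TestBackend; // assumed test backend alias
    use burn::tensor::Tensor;

    #[test]
    fn my_op_matches_pytorch_output() {
        let device = Default::default();
        let model = my_op::Model::<TestBackend>::new(&device);

        // Placeholder input; real tests reuse the values printed by the
        // PyTorch script and assert on the printed outputs as well.
        let input = Tensor::<TestBackend, 3>::ones([1, 4, 1], &device);
        let output = model.forward(input);

        assert_eq!(output.dims(), [1, 4]);
    }
}
```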
Implementing a New Operator
To extend the capabilities of the Burn library by supporting new operations imported from ONNX
graphs, developers must go through a few systematic steps. Here, we detail the process, using the
implementation of the Squeeze operation to illustrate points as needed. All file/directory paths
are relative to the root of the burn repository.
Step 1: Node Processor Implementation in onnx-ir
The onnx-ir crate handles the Intermediate Representation (IR) of ONNX models using a
processor-based architecture. For each operation:
1. Create a node module in `crates/onnx-ir/src/node/<operation_name>.rs`. This file should contain:

   - Configuration struct: Define operation-specific parameters (e.g., `SqueezeConfig`)
   - Processor struct: Implement the `NodeProcessor` trait (marked as `pub(crate)`)

   The processor handles:

   - Input/output specification: Define expected inputs and outputs via `NodeSpec`
   - Type inference: Infer output types from inputs and configuration
   - Configuration extraction: Extract operation parameters from ONNX attributes
   - Node construction: Build the final `Node` enum variant with its config
2. Make the module visible in `crates/onnx-ir/src/node/mod.rs`:

   ```rust
   pub mod squeeze;
   ```
3. Create a node struct in your module file (e.g., `squeeze.rs`) with the standard fields:

   ```rust
   use onnx_ir_derive::NodeBuilder;

   #[derive(Debug, Clone, NodeBuilder)]
   pub struct SqueezeNode {
       pub name: String,
       pub inputs: Vec<Argument>,
       pub outputs: Vec<Argument>,
       pub config: SqueezeConfig,
   }
   ```

   The `NodeBuilder` derive macro generates a test builder (e.g., `SqueezeNodeBuilder`) with methods for constructing nodes in tests.
4. Add the node to the macro invocation in `crates/onnx-ir/src/ir/node.rs` by adding a mapping to the `define_node_enum!` macro:

   ```rust
   define_node_enum! {
       // ... other variants
       Squeeze => squeeze::SqueezeNode,
       // ... more variants
   }
   ```

   This single macro invocation generates both the `NodeType` enum (for parsing) and the `Node` enum (with tuple variants wrapping node structs) from a single source of truth.
5. Register your processor in `crates/onnx-ir/src/registry.rs` by adding it to the `with_standard_processors()` function:

   ```rust
   registry.register("Squeeze", Box::new(squeeze::SqueezeProcessor));
   ```
For example, the squeeze operation in `crates/onnx-ir/src/node/squeeze.rs` contains:

- A `SqueezeConfig` struct with operation parameters (axes)
- A `SqueezeProcessor` struct (marked `pub(crate)`) that implements `NodeProcessor`
- The `node_spec()` method, which defines input/output requirements
- The `process()` method, which extracts the config and constructs the `Node::Squeeze` variant
Step 2: Code Generation in burn-import
1. Create a new file named `<operation_name>.rs` in the `crates/burn-import/src/burn/node/` directory. This file implements code generation for your operation by implementing the `NodeCodegen` trait directly on the onnx-ir node type.
2. Implement the `NodeCodegen<PS>` trait for the onnx-ir node type. This trait defines how the node generates Rust code during the graph compilation process:

   ```rust
   use super::prelude::*;

   impl<PS: PrecisionSettings> NodeCodegen<PS> for onnx_ir::squeeze::SqueezeNode {
       fn inputs(&self) -> &[Argument] {
           &self.inputs
       }

       fn outputs(&self) -> &[Argument] {
           &self.outputs
       }

       fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
           let input_arg = self.inputs.first().unwrap();
           let output_arg = self.outputs.first().unwrap();

           // Use scope.arg() to handle Tensor/Scalar/Shape arguments automatically
           let input = scope.arg(input_arg);
           let output = arg_to_ident(output_arg);

           // Access node configuration
           match &self.config.axes {
               Some(axes) => {
                   let axes_values: Vec<_> = axes
                       .iter()
                       .map(|&i| proc_macro2::Literal::i64_suffixed(i))
                       .collect();
                   quote! {
                       let #output = #input.squeeze_dims(&[#(#axes_values),*]);
                   }
               }
               None => {
                   // Get output rank from type inference
                   let output_rank = match &output_arg.ty {
                       ArgType::Tensor(t) => t.rank,
                       _ => panic!("Expected tensor output"),
                   };
                   quote! {
                       let #output = #input.squeeze::<#output_rank>();
                   }
               }
           }
       }
   }
   ```

   Key methods to implement:

   - `inputs(&self)` - Returns references to input arguments (usually just `&self.inputs`)
   - `outputs(&self)` - Returns references to output arguments (usually just `&self.outputs`)
   - `forward(&self, scope)` - Generates Rust code for the operation using the `quote!` macro
   - `field(&self)` - (Optional) Declares module fields for parameters like weights
   - `collect_snapshots(&self, field_name)` - (Optional) Collects tensor snapshots for burnpack serialization
3. Use the helper utilities from `argument_helpers.rs`:

   - `scope.arg(argument)` - Automatically handles Tensor/Scalar/Shape with proper cloning
   - `arg_to_ident(argument)` - Converts an argument to an identifier for code generation
4. Add unit tests using snapshot testing to verify the generated code. These tests typically use the `insta` crate and test helper functions to validate the generated code:

   ```rust
   #[cfg(test)]
   mod tests {
       use super::super::test_helpers::*;
       use insta::assert_snapshot;
       use onnx_ir::squeeze::SqueezeNodeBuilder;

       #[test]
       fn test_squeeze_forward() {
           let node = SqueezeNodeBuilder::new("squeeze1")
               .input_tensor("input", 3, DType::F32)
               .output_tensor("output", 2, DType::F32)
               .axes(vec![1])
               .build();

           let code = codegen_forward_default(&node);
           assert_snapshot!(code, @"let output = input.squeeze_dims(&[1i64]);");
       }
   }
   ```
Step 3: Register in Module System
Add the module declaration to `crates/burn-import/src/burn/node/mod.rs`:

```rust
// ... other node modules
pub(crate) mod squeeze;
// ... more node modules
```
The modules are automatically made visible through re-exports in the same file.
Step 4: Register in Code Generation Dispatch
Add your operation to the dispatch macro in `crates/burn-import/src/burn/node_codegen.rs`. The
`impl_node_codegen_dispatch!` macro generates the trait implementation that dispatches to your
node-specific code.
Add the node variant name (as defined in onnx-ir's `Node` enum) to the macro invocation:
```rust
impl_node_codegen_dispatch! {
    // ... other operations
    Squeeze, // Add your operation here (matches Node::Squeeze variant)
    // ... more operations
}
```
The macro automatically generates:
- A dispatch implementation of `NodeCodegen<PS>` for `onnx_ir::Node`
- All required trait methods (`inputs`, `outputs`, `forward`, `field`, etc.)
- Pattern matching to route to your node-specific implementation
Step 5: Processor Implementation
The `NodeProcessor` trait defines how operations are processed in onnx-ir. Each processor must
implement:
- `type Config` - Associated type defining your configuration struct (use `()` if no config)
- `infer_types()` - Infer output types from inputs and config (required)
- `build_node()` - Construct the node struct and wrap it in the `Node` enum variant (required)
- `extract_config()` - Extract config from attributes/inputs (override if `Config != ()`)
- `spec()` - Define opset and input/output requirements (optional)
- `lift_constants()` - Request constant lifting for inputs (optional)
Example `build_node()` implementation:
```rust
fn build_node(&self, builder: RawNode, opset: usize) -> Node {
    let config = self
        .extract_config(&builder, opset)
        .expect("Config extraction failed");

    Node::Squeeze(SqueezeNode {
        name: builder.name,
        inputs: builder.inputs,
        outputs: builder.outputs,
        config,
    })
}
```
Note: `RawNode` is the intermediate node representation used during processing. The `build_node()`
method converts it into the final typed `Node` enum variant.
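The Constant Lifting section below refers to a configuration-extraction example. As an illustration, an `extract_config()` for Squeeze might look like the sketch below; the method signature, the `ProcessError` type, and the return shape of `value()` are assumptions, so check `crates/onnx-ir/src/node/squeeze.rs` for the real code.

```rust
// Illustrative sketch only; signatures and error types are assumed.
fn extract_config(&self, node: &RawNode, _opset: usize) -> Result<SqueezeConfig, ProcessError> {
    // Opset 13+ passes `axes` as an optional second input. When it has been
    // lifted as a constant, its value is available directly on the argument.
    let axes = node
        .inputs
        .get(1)
        .and_then(|arg| arg.value())
        .map(|data| data.to_vec::<i64>().expect("axes must be i64"));

    Ok(SqueezeConfig { axes })
}
```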
For complete examples, see existing processors:
- Simple operation: `crates/onnx-ir/src/node/softmax.rs`
- With constant inputs: `crates/onnx-ir/src/node/squeeze.rs`
- Complex operation: `crates/onnx-ir/src/node/conv2d.rs`
See the NodeProcessor Trait section below for the complete trait definition.
Step 6: Add Newly Supported Op!
As a reward, add a check for the newly supported operator to `crates/burn-import/SUPPORTED-ONNX-OPS.md`!
Constant Lifting
The onnx-ir pipeline automatically handles constant lifting during the post-processing phase.
"Lifting" constants means making constant values directly accessible on node inputs via
`Argument::value()`, instead of requiring a separate graph traversal to find a `Constant` node.
When to use: If your operation takes constant inputs (e.g., weights in Conv1d, shape tensors in
Reshape, axes in Squeeze), access them via `node.inputs[N].value()` in your `extract_config()`
method. See the configuration extraction example in Step 5.
Optional optimization: Implement `lift_constants()` to explicitly request constant lifting for
specific inputs before `extract_config()` is called. The pipeline handles this automatically during
post-processing.
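For instance, a Squeeze processor might request lifting of its optional `axes` input. The signature and return type below are assumptions for illustration; consult the `NodeProcessor` trait for the actual hook.

```rust
// Hypothetical signature: request lifting for input index 1 (the optional
// `axes` input) so its value is visible in extract_config().
fn lift_constants(&self, node: &RawNode) -> Vec<usize> {
    if node.inputs.len() > 1 {
        vec![1]
    } else {
        Vec::new()
    }
}
```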
Architecture Overview
ONNX-IR Pipeline
The onnx-ir crate converts ONNX models to an Intermediate Representation through a 5-phase
pipeline:
Phase 1: Initialization
- Creates `GraphState` from ONNX proto structures
- Constants-first approach: Converts all ONNX initializers into Constant nodes, providing a uniform starting point for processing
- Sets up the value store for tensor data using `burn_tensor::TensorData`
- Preserves sanitized graph input names for debugging
Phase 2: Node Conversion
- Converts ONNX nodes to IR nodes using registered processors
- Creates `RawNode` instances from ONNX proto nodes (intermediate representation)
- Processors extract configuration and construct typed `Node` enum variants
- Handles constant nodes specially (extracting values from attributes into the tensor store)
- Each processor is responsible for its own type inference and node construction
Phase 3: Type Inference
- Type inference happens within each processor's `process()` method during Phase 2
- Processors infer output types based on input types and configuration
- Multi-pass processing handles dependencies between nodes
- The pipeline may need multiple iterations for complex type dependencies (e.g., control flow)
Phase 4: Post-processing
- Lifts constants: Makes constant values accessible on downstream node inputs
- Eliminates Identity nodes: Removes no-op nodes and rewires the graph
- Re-runs constant lifting after Identity elimination
Phase 5: Finalization
- Removes unreferenced constant nodes
- Constructs the final `OnnxGraph` with inputs, outputs, and nodes
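End to end, downstream tooling drives these phases through a single entry point. A usage sketch, assuming the `parse_onnx` function and a public `nodes` field on `OnnxGraph` (verify against the current crate API):

```rust
use std::path::Path;

fn main() {
    // Runs all five phases and returns the finalized IR graph.
    let graph = onnx_ir::parse_onnx(Path::new("model.onnx"));

    // Inspect the typed nodes, e.g. while debugging a conversion.
    for node in &graph.nodes {
        println!("{}", node.name());
    }
}
```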
NodeProcessor Trait
The `NodeProcessor` trait (defined in `crates/onnx-ir/src/processor.rs`) is the core abstraction for
handling ONNX operations. Each processor implements:
Required:
- `type Config` - Associated type for configuration (use `()` if no config is needed)
- `infer_types()` - Infer output types from inputs and configuration
- `build_node()` - Construct the final `Node` enum variant
Optional (have defaults):
- `spec()` - Define opset requirements and input/output count validation (`NodeSpec`, `InputSpec`, `OutputSpec`)
- `extract_config()` - Extract configuration from attributes/inputs (default returns `Default::default()`)
- `lift_constants()` - Request constant lifting for specific inputs (default does nothing)
- `input_preferences()` - Declare preferred input types from producers (default returns `None`)
Design principles: Each processor is self-contained, handling type inference, config extraction, and
node construction. Processors return strongly typed `Node` enum variants, ensuring type safety
throughout the pipeline.
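Putting this together, the trait has roughly the following shape. This is a paraphrase of the description above rather than the literal definition; defaults, error types, and parameter lists may differ from `crates/onnx-ir/src/processor.rs`.

```rust
// Approximate shape only; see crates/onnx-ir/src/processor.rs for the
// authoritative definition.
pub trait NodeProcessor {
    /// Operation-specific configuration; use `()` when none is needed.
    type Config: Default;

    // Required: infer output types from input types and configuration.
    fn infer_types(&self, node: &mut RawNode, opset: usize) -> Result<(), ProcessError>;

    // Required: construct the final typed Node enum variant.
    fn build_node(&self, builder: RawNode, opset: usize) -> Node;

    // Optional: opset requirements and input/output count validation.
    fn spec(&self) -> NodeSpec {
        NodeSpec::default()
    }

    // Optional: extract configuration from attributes/inputs.
    fn extract_config(&self, _node: &RawNode, _opset: usize) -> Result<Self::Config, ProcessError> {
        Ok(Self::Config::default())
    }

    // lift_constants() and input_preferences() are further optional hooks
    // with no-op defaults (omitted here).
}
```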
Testing
When implementing a new operator, there are several levels of testing to consider:
Unit Testing
- Processor Methods: Write unit tests in `crates/onnx-ir/src/node/<operation_name>.rs` to verify:

  - `extract_config()` - Correctly extracts configuration from attributes and inputs
  - `infer_types()` - Correctly infers output types (element type, rank, static shapes)
  - `build_node()` - Constructs the correct `Node` enum variant
  - `spec()` - Defines the correct opset and input/output requirements
  - Error handling for invalid inputs or configurations

  See existing tests in `crates/onnx-ir/src/node/squeeze.rs` for examples.
- Code Generation: Test the burn-import node implementation to verify correct Rust code generation. Each node file typically includes unit tests using `assert_tokens()` to validate generated code against expected output.
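For instance, a minimal processor-level unit test might assert that a builder-constructed node carries the expected configuration. The sketch below reuses the `SqueezeNodeBuilder` from Step 1; the exact builder methods and the `DType` import path are assumptions based on the earlier examples.

```rust
#[cfg(test)]
mod tests {
    use super::*; // inside crates/onnx-ir/src/node/squeeze.rs
    use crate::ir::DType; // assumed import path

    #[test]
    fn squeeze_node_carries_axes_config() {
        let node = SqueezeNodeBuilder::new("squeeze1")
            .input_tensor("input", 3, DType::F32)
            .output_tensor("output", 2, DType::F32)
            .axes(vec![1])
            .build();

        // The config should reflect the axes given to the builder.
        assert_eq!(node.config.axes, Some(vec![1]));
    }
}
```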
Integration Testing
- Test Path: Write integration tests in `crates/burn-import/onnx-tests/tests/<op_name>/mod.rs`, where `<op_name>` is the name of the new operator.

- What to Test:
- Create ONNX models that use your operator and test the end-to-end conversion process
- Ensure the generated Rust code compiles
- Test with realistic ONNX models that use your operator in conjunction with others
- Include models that test edge cases (e.g., different input shapes, parameter combinations)
- Verify that inputs and outputs match between the original ONNX model and the converted Burn model
- Further details can be found in the onnx-tests README.
Testing the processor implementation is particularly important as it directly affects the correctness of the conversion process. Incorrect type inference can lead to mismatched tensor shapes or wrong element types, while incorrect configuration extraction can cause runtime errors or produce incorrect results.
Node Enum Architecture
The ONNX-IR uses an enum-based node representation where each ONNX operation is a variant of the
`Node` enum (defined in `crates/onnx-ir/src/ir/node.rs`). Each variant wraps an operation-specific
node struct (e.g., `SoftmaxNode`, `Conv2dNode`) that contains `name`, `inputs`, `outputs`, and
optionally a `config` field.
The `define_node_enum!` macro generates both enums from a single source using the syntax
`VariantName => module::NodeStructType`:
```rust
define_node_enum! {
    Softmax => softmax::SoftmaxNode,
    Conv2d => conv2d::Conv2dNode,
    Squeeze => squeeze::SqueezeNode,
    // ... 200+ more variants
}
```
This macro generates:
- `NodeType` enum: Simple unit variants for ONNX parsing (`Softmax`, `Conv2d`, etc.)
- `Node` enum: Tuple variants wrapping node structs (`Softmax(SoftmaxNode)`, `Conv2d(Conv2dNode)`, etc.)
- Accessor methods: `name()`, `inputs()`, and `outputs()` automatically generated for the `Node` enum
This design provides:
- Type safety: Each operation has its own struct type
- Trait implementations: Operations can implement specific traits on their node structs
- Single source of truth: Both enums are guaranteed to stay in sync
- Pattern matching: Easy to match on specific operations and access their configuration
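As a usage illustration, downstream code can pattern match on the enum to reach an operation's typed configuration. A sketch (the `onnx_ir::Node` path is assumed; verify the crate's re-exports):

```rust
// Sketch: matching on the Node enum to access per-operation configuration.
fn describe(node: &onnx_ir::Node) {
    match node {
        onnx_ir::Node::Squeeze(n) => {
            println!("squeeze `{}` with axes {:?}", n.name, n.config.axes);
        }
        // The generated accessors cover fields common to all variants.
        other => println!("node `{}`", other.name()),
    }
}
```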