ONNX to Burn: Development Guide
This guide offers in-depth design insights and step-by-step procedures for developers working on the ONNX to Burn conversion tool. The tool imports ONNX models into the Burn deep learning framework, written in Rust, by converting ONNX models to Rust source code and model weights to Burn state files.
For an introduction to ONNX import in Burn, see this section of the Burn book.
Design Overview
Design Goals
- Perform best-effort conversion of ONNX models to Rust source code via Burn APIs.
- Convert ONNX model weights to Burn state files.
- Support ONNX models generated by PyTorch (ONNX Opset 16+).
- Produce easy-to-understand and modifiable models.
- Ensure the generated models are trainable using Burn APIs.
Design Decisions
- Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process.
- Ensure operator behavior consistency across different OpSet versions.
- Exclude any ONNX/Protobuf-specific logic from the Burn graph.
The conversion process involves three main stages:
- Convert ONNX model to Intermediate Representation (IR).
- Translate IR to a Burn graph.
- Generate Rust source code from the Burn graph.
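For orientation, end users drive all three stages from a build script through the public `ModelGen` API (see Architecture Overview below). A minimal sketch, with placeholder paths:

```rust
// build.rs - minimal sketch of invoking the converter; paths are placeholders.
use burn_import::onnx::ModelGen;

fn main() {
    // Parses the ONNX file, builds the IR and the Burn graph, and emits the
    // generated Rust source plus a Burn state file into the output directory.
    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        .run_from_script();
}
```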
Adding New Operators
To extend burn-import with support for new ONNX operators, follow these steps:
- Create PyTorch Script: Place a PyTorch script using the new operator under `crates/burn-import/onnx-tests/tests/<op>/<op>.py`. Make sure to print both input and output tensors for end-to-end testing.
- Generate ONNX Model: Run the PyTorch script to produce an ONNX model.
- Visualize ONNX Model: Use Netron to verify the ONNX model contains the expected operators.
- Generate IR and Burn Graph: Navigate to `crates/burn-import/` and run
  `cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out`.
- Implement Missing Operators: If you encounter an error stating that an operator is unsupported, implement it. The `./out/my-model.graph.txt` file should provide relevant information.
- Inspect Generated Files: `my-model.graph.txt` contains the IR details, `my-model.rs` holds the Burn model in Rust code, and `my-model.json` includes the model data.
- Add End-to-End Test: Include the test in `crates/burn-import/onnx-tests/tests/test_onnx.rs` (a sketch follows below). Further details can be found in the onnx-tests README.
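As a rough illustration of that last step, an end-to-end test might look like the sketch below. The `include_models!` macro, the `Backend` alias, the model construction, and the tensor shapes are assumptions modeled on the existing tests in `test_onnx.rs`; the actual input and expected output must match what your `<op>.py` script printed.

```rust
// Sketch only: assumes the existing harness in test_onnx.rs (include_models!,
// a Backend alias, and the file's tensor imports). Here a hypothetical
// squeeze.onnx removes the size-1 dimensions of a [1, 4, 1] tensor.
include_models!(squeeze);

#[test]
fn squeeze() {
    let device = Default::default();
    let model: squeeze::Model<Backend> = squeeze::Model::new(&device);

    // Feed the same input the PyTorch script printed...
    let input = Tensor::<Backend, 3>::ones([1, 4, 1], &device);
    let output = model.forward(input);

    // ...and check the result against the printed output (shape only here).
    assert_eq!(output.dims(), [4]);
}
```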
Implementing a New Operator
To extend the capabilities of the Burn library by supporting new operations imported from ONNX
graphs, developers must go through a few systematic steps. Here, we detail the process, using the
implementation of the Squeeze operation to illustrate points as needed. All file/directory paths
are relative to the root of the burn repository.
Step 1: Node Implementation in onnx-ir
The onnx-ir crate handles the Intermediate Representation (IR) of ONNX models. For each operation:
- Add the operation to the `NodeType` enum in `crates/onnx-ir/src/ir.rs`.
- Create a new module file in `crates/onnx-ir/src/node/<operation_name>.rs`. This file should include:
  - A `<operation_name>_config` function to extract operation parameters
  - A `<operation_name>_update_output` function for dimension inference
- Make the module visible in `crates/onnx-ir/src/node/mod.rs`.
- If the operation might work with constants, add it to the list of node types checked for constants in `crates/onnx-ir/src/from_onnx.rs`.
For example, the squeeze operation is defined in `crates/onnx-ir/src/node/squeeze.rs` and contains:
- A `squeeze_config` function that extracts axes from node attributes (sketched below)
- A `squeeze_update_output` function that updates output dimensions by reducing the input rank
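The axes-extraction logic at the heart of `squeeze_config` can be sketched in isolation as follows. This is a self-contained illustration only: the real function operates on onnx-ir's `Node` and attribute types rather than a plain `HashMap`, and it also handles axes passed as a constant input in newer opsets.

```rust
use std::collections::HashMap;

// Illustrative stand-in for the attribute handling inside squeeze_config;
// the onnx-ir types and accessors differ, this only shows the core logic.
fn extract_squeeze_axes(
    attrs: &HashMap<String, Vec<i64>>,
    input_rank: usize,
) -> Option<Vec<i64>> {
    // `None` means "squeeze every size-1 dimension".
    attrs.get("axes").map(|axes| {
        axes.iter()
            // Normalize negative axes, e.g. -1 refers to the last dimension.
            .map(|&a| if a < 0 { a + input_rank as i64 } else { a })
            .collect()
    })
}
```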
Step 2: Node Implementation in burn-import
- Create a new file named `<operation_name>.rs` in the `crates/burn-import/src/burn/node/` directory. This file will define the structure and functionality of your new operation. By convention, the necessary information for carrying out an operation is encapsulated within a struct named `<operation>Node`. For the `Squeeze` operation, we defined a struct called `SqueezeNode` that holds the necessary information about the input tensor, output tensor, and axes for the operation.
- Implement the `OnnxIntoNode` trait for your node. This trait has a single method `from_onnx` that converts an ONNX IR node into your Burn node type:

  ```rust
  impl OnnxIntoNode for SqueezeNode {
      fn from_onnx(node: onnx_ir::Node) -> Self {
          let input = TensorType::from(node.inputs.first().unwrap());
          let output = TensorType::from(node.outputs.first().unwrap());
          let axes = squeeze_config(&node);
          SqueezeNode::new(input, output, axes)
      }
  }
  ```

- The core of integrating a new operation involves implementing the `NodeCodegen` trait for your node. This trait defines how the node generates code during the graph compilation process. The implementation must provide methods to define input and output types, generate the forward pass code, and encapsulate the node into the more general `Node` structure. Specifically:
  - `input_types()` - Returns the types of all input arguments
  - `output_types()` - Returns the types of all output values
  - `forward()` - Generates the Rust code that performs the operation during execution. Use the `quote!` macro to generate code and ensure it's syntactically correct Burn code.
  - `into_node()` - Wraps the node into the general `Node<PS>` enum
Example implementation:
```rust
impl<PS: PrecisionSettings> NodeCodegen<PS> for SqueezeNode {
    fn input_types(&self) -> Vec<Type> {
        vec![self.input.clone()]
    }

    fn output_types(&self) -> Vec<Type> {
        vec![self.output.clone()]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        // Simplified example - actual implementation handles more cases
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name();

        match &self.axes {
            Some(axes) => {
                let axes_tokens = axes.to_tokens();
                quote! {
                    let #output = #input.squeeze_dims(&#axes_tokens);
                }
            }
            None => {
                let output_rank = self.output.rank();
                quote! {
                    let #output = #input.squeeze::<#output_rank>();
                }
            }
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Squeeze(self)
    }
}
```
- Add unit tests in the same file to verify that the generated code compiles and works correctly. These tests typically call a helper function like `assert_tokens()` to validate the generated code against the expected output.
Step 3: Register in Module System
Add the module declaration to `crates/burn-import/src/burn/node/mod.rs`:

```rust
// ... other node modules
pub(crate) mod squeeze;
// ... more node modules
```
The modules are automatically made visible through re-exports in the same file.
Step 4: Register in Node Registry
Add your operation to the node registry in crates/burn-import/src/burn/node_registry.rs. This is
the single source of truth for all ONNX node conversions.
For a single ONNX operation mapping to a single node type:
```rust
node_registry! {
    // ... other operations
    Squeeze => squeeze as SqueezeNode,
    // ... more operations
}
```
For multiple ONNX operations mapping to the same node type (e.g., various reduce operations):
```rust
node_registry! {
    // ... other operations
    [ReduceMax, ReduceMin, ReduceMean, ReduceProd, ReduceSum] => ReduceMax: reduce as ReduceNode,
    // ... more operations
}
```
That's it! The registry automatically generates:
- The `Node<PS>` enum with your operation as a variant
- The `match_all!` macro for dispatching
- The ONNX to Burn conversion logic
- All necessary imports
Step 5: Rank Inference
In crates/onnx-ir/src/node/<operation_name>.rs, implement a rank inference function that updates
the output rank based on the operation:
```rust
pub fn squeeze_update_output(node: &mut Node) {
    // Extract axes information
    let axes = /* ... */;
    let input_rank = /* ... */;
    let output_rank = input_rank - axes.len();

    // Update output rank
    node.outputs[0].ty = ArgType::Tensor(TensorType {
        elem_type: node.inputs[0].ty.elem_type().clone(),
        rank: output_rank,
        static_shape: None,
    });
}
```
Then register this function in crates/onnx-ir/src/rank_inference.rs by adding it to the match
statement:
```rust
pub fn rank_inference(node: &mut Node) {
    match node.node_type {
        // ...
        NodeType::Squeeze => squeeze_update_output(node),
        // Add your new operation here
    }
}
```
The rank_inference.rs file is responsible for determining the output tensor rank for each node in
the graph.
If the rank remains unchanged, you can use helper functions like `same_as_input()` or `same_as_input_broadcast()` instead of writing a custom update function.
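For instance, an element-wise operation keeps its input rank, so its arm in the same match can simply delegate to the shared helper (illustrative arm only; Relu is used as an example here):

```rust
pub fn rank_inference(node: &mut Node) {
    match node.node_type {
        // ...
        NodeType::Squeeze => squeeze_update_output(node),
        // Rank is unchanged for element-wise ops, so the helper is enough.
        NodeType::Relu => same_as_input(node),
        // ...
    }
}
```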
Step 6: Add Newly Supported Op!
As a reward, add a check mark for the newly supported op to `crates/burn-import/SUPPORTED-ONNX-OPS.md`!
Lifting Constant Nodes
If your operation takes inputs from constant nodes (such as weights in Conv1d, shape tensors in
Reshape, etc.), you need to add your operation's NodeType to the LIFT_CONSTANTS_FOR_NODE_TYPES
array in crates/onnx-ir/src/from_onnx.rs.
```rust
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 16] = [
    NodeType::BatchNormalization,
    // other operations...
    NodeType::Squeeze,
    NodeType::Unsqueeze,
    // Add your operation here if it needs constants to be processed
];
```
"Lifting" constants means converting Constant nodes into direct input values. This is similar to how ONNX initializers work. For example, instead of having a separate Constant node providing weights to a Convolution operation, the weights are directly embedded as values in the Convolution node's inputs.
This transformation makes it easier to:
- Access the constant values during node configuration
- Process operations like Conv1d that expect weights as direct inputs
- Handle shape-defining inputs needed for operations like Reshape
Without this, operations that need to extract configuration from constant inputs (such as shapes, weights, or other parameters) would not work correctly because they wouldn't have direct access to those constant values.
Architecture Overview
The burn-import crate is organized into several key modules:
Core Modules
- `burn/node_registry.rs`: Master registry containing all ONNX node mappings. This is a declarative macro that auto-generates the `Node` enum, conversion functions, and dispatch logic.
- `burn/node_codegen.rs`: Contains the `NodeCodegen` and `OnnxIntoNode` traits that all nodes must implement. Also includes code generation utilities.
- `burn/node/`: Directory containing individual node implementations. Each file implements a specific ONNX operation.
- `burn/graph.rs`: Burn graph representation and code generation.
- `burn/ty.rs`: Type system for tensors, scalars, and shapes, including conversions from ONNX types.
- `onnx/model_gen.rs`: Public API (`ModelGen`) for converting ONNX models to Burn code.
Testing
When implementing a new operator, there are several levels of testing to consider:
Unit Testing
- Node Configuration: Write unit tests for the `<operation_name>_config` function in `crates/onnx-ir/src/node/<operation_name>.rs` to verify that it correctly extracts parameters from ONNX nodes.
- Rank Inference: Test the `<operation_name>_update_output` function to ensure it correctly computes output ranks (see the sketch below).
- Code Generation: Test the Node implementation in `burn-import` to verify that it generates correct Rust code. Each node file typically includes a `test_codegen_nodes()` function.
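As an illustration of the rank-inference case, such a test might look roughly like the sketch below; `build_test_node` is a hypothetical helper standing in for however the existing onnx-ir tests construct a `Node` (the exact fields are omitted here):

```rust
// Hypothetical sketch: `build_test_node` is assumed to create a Squeeze node
// with a rank-3 float tensor input, a placeholder output, and axes = [1].
#[test]
fn squeeze_update_output_reduces_rank() {
    let mut node = build_test_node(NodeType::Squeeze, /* input rank */ 3, /* axes */ vec![1]);

    squeeze_update_output(&mut node);

    // Squeezing one axis of a rank-3 tensor must produce a rank-2 tensor.
    match &node.outputs[0].ty {
        ArgType::Tensor(t) => assert_eq!(t.rank, 2),
        other => panic!("expected tensor output, got {other:?}"),
    }
}
```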
Integration Testing
- Create small ONNX models that use your operator and test the end-to-end conversion process
- Ensure the generated Rust code compiles and produces the expected outputs
- Add these tests to `crates/burn-import/onnx-tests/tests/test_onnx.rs`
End-to-End Testing
- Test with realistic ONNX models that use your operator in conjunction with others
- Verify that inputs and outputs match between the original ONNX model and the converted Burn model
- Include models that test edge cases (e.g., different input shapes, parameter combinations)
Testing both the rank inference and node configuration is particularly important as these components directly affect the correctness of the conversion process. Incorrect rank inference can lead to mismatched tensor shapes, while incorrect configuration can cause runtime errors or incorrect results.