# ONNX to Burn: Development Guide
This guide offers in-depth design insights and step-by-step procedures for developers working on the ONNX to Burn conversion tool, which imports ONNX models into the Burn deep learning framework written in Rust. The tool converts ONNX models to Rust source code and model weights to Burn state files.
For an introduction to ONNX import in Burn, see the ONNX import section of the Burn book.
## Design Overview

### Design Goals
- Perform best-effort conversion of ONNX models to Rust source code via Burn APIs.
- Convert ONNX model weights to Burn state files.
- Support ONNX models generated by PyTorch (ONNX Opset 16).
- Produce easy-to-understand and modifiable models.
- Ensure the generated models are trainable using Burn APIs.
### Design Decisions
- Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process.
- Ensure operator behavior consistency across different OpSet versions.
- Exclude any ONNX/Protobuf-specific logic from the Burn graph.
The conversion process involves three main stages:
- Convert ONNX model to Intermediate Representation (IR).
- Translate IR to a Burn graph.
- Generate Rust source code from the Burn graph.
## Adding New Operators
To extend `burn-import` with support for new ONNX operators, follow these steps:
1. **Create PyTorch Script**: Place a PyTorch script using the new operator under
   `crates/burn-import/onnx-tests/tests/<op>/<op>.py`. Make sure to print both input and output
   tensors for end-to-end testing.

2. **Generate ONNX Model**: Run the PyTorch script to produce an ONNX model.

3. **Visualize ONNX Model**: Use [Netron](https://github.com/lutzroeder/netron) to verify that the
   ONNX model contains the expected operators.

4. **Generate IR and Burn Graph**: Navigate to `crates/burn-import/` and run:

   ```
   cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out
   ```

5. **Implement Missing Operators**: If you encounter an error stating that an operator is
   unsupported, implement it. The `./out/my-model.graph.txt` file should provide relevant
   information.

6. **Inspect Generated Files**: `my-model.graph.txt` contains the IR details, `my-model.rs` holds
   the Burn model in Rust code, and `my-model.json` includes the model data.

7. **Add End-to-End Test**: Include the test in `crates/burn-import/onnx-tests/tests/test_onnx.rs`.
   Further details can be found in the onnx-tests README. A sketch of such a test follows this
   list.
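For orientation, a minimal end-to-end test might look like the sketch below. The `include_models!` macro and the generated `squeeze::Model` follow the conventions of the existing tests in `test_onnx.rs`; the backend alias, input shape, and expected output used here are placeholders, so the real assertions should compare against the tensors printed by the PyTorch script.

```rust
use burn::tensor::Tensor;

// Assumption: the onnx-tests build script compiles squeeze.onnx into a
// `squeeze` module that exposes a generated `Model` type.
include_models!(squeeze);

#[test]
fn squeeze() {
    let device = Default::default();
    let model: squeeze::Model<Backend> = squeeze::Model::new(&device);

    // Placeholder input; mirror the tensor printed by <op>.py.
    let input = Tensor::<Backend, 4>::ones([1, 2, 3, 1], &device);
    let output = model.forward(input);

    // Placeholder expectation; compare against the output printed by <op>.py.
    assert_eq!(output.shape().dims, [1, 2, 3]);
}
```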
## Implementing a New Operator
To extend the capabilities of the Burn library by supporting new operations imported from ONNX graphs, developers must go through a few systematic steps. Here, we detail the process, using the implementation of the `Squeeze` operation to illustrate points as needed. All file/directory paths are relative to the root of the burn repository.
### Step 1: Visibility
To make a new operation accessible, there are two key modules to update:
- In `crates/onnx-ir/src/node/mod.rs`, add your new operation module to make it visible within the
  IR.
- In `crates/burn-import/src/burn/node/mod.rs`, make the corresponding node type visible within
  `burn-import`.
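For `Squeeze`, that amounts to a one-line module declaration in each file; the exact visibility qualifiers should match the neighboring declarations:

```rust
// crates/onnx-ir/src/node/mod.rs
pub mod squeeze;

// crates/burn-import/src/burn/node/mod.rs
pub(crate) mod squeeze;
```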
### Step 2: Node Implementation

#### Within onnx-ir

The `onnx-ir` crate handles the Intermediate Representation (IR) of ONNX models. For each
operation:
1. Add the operation to the `NodeType` enum in `crates/onnx-ir/src/ir.rs`.

2. Create a new module file in `crates/onnx-ir/src/node/<operation_name>.rs`. This file should
   include:

   - A `<operation_name>_config` function to extract operation parameters
   - A `<operation_name>_update_output` function for dimension inference

3. If the operation might work with constants, add it to the list of node types checked for
   constants in `crates/onnx-ir/src/from_onnx.rs`.
For example, the squeeze operation is defined in `crates/onnx-ir/src/node/squeeze.rs` and contains:

- A `squeeze_config` function that extracts axes from node attributes
- A `squeeze_update_output` function that updates output dimensions by reducing input rank
#### Within burn-import
1. Create a new file named `<operation_name>.rs` in the `crates/burn-import/src/burn/node/`
   directory. This file will define the structure and functionality of your new operation. By
   convention, the necessary information for carrying out an operation is encapsulated within a
   struct named `<operation>Node`. For the `Squeeze` operation, we defined a struct called
   `SqueezeNode` that holds the necessary information about the input tensor, output tensor, and
   axes for the operation. If implementing a unary or binary operation, please see the note below.

2. The core of integrating a new operation involves implementing the `NodeCodegen` trait for your
   node. This trait defines how the node generates code during the graph compilation process. The
   implementation must provide methods to define input and output types, to generate the forward
   pass code, and to encapsulate the node into the more general `Node` structure (a condensed
   sketch of such an implementation follows this list). Specifically:

   - `output_types` and `input_types` return the tensor (or element) types for the output and
     inputs of the node, respectively.
   - `forward` generates the Rust code that performs the operation during the execution phase. The
     `quote!` macro is used to generate the code; make sure it is syntactically correct and uses
     the Burn API.
   - `into_node` wraps the specific node in the general `Node` type, facilitating its inclusion in
     the broader Burn graph structure.

3. This file is also where you would put `test_codegen_nodes()`, to make sure that the generated
   code works within the Burn library.
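Put together, a condensed implementation for `SqueezeNode` might look like the sketch below. The exact trait signature, `Scope` API, and token handling should be copied from an existing node in `crates/burn-import/src/burn/node/`; the generated squeeze call here assumes a single axis for brevity.

```rust
impl<PS: PrecisionSettings> NodeCodegen<PS> for SqueezeNode {
    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;
        // Assumption: a single squeeze axis, for brevity.
        let axis = self.axes.first().unwrap().to_owned() as usize;

        // Generate the Burn call that performs the squeeze at runtime.
        quote! {
            let #output = #input.squeeze(#axis);
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Squeeze(self)
    }
}
```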
For unary and binary operations: the implementation of `NodeCodegen` is mostly provided by `binary.rs` and `unary.rs`, so each new operation only has to define a method that applies the operation to the input token stream(s).
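For instance, an element-wise operation like `Erf` can be added to `unary.rs` as little more than a constructor that supplies the mapping from input tokens to the Burn call. The constructor shape below (`UnaryNodeKind`, `Self::new`, the `Rc`-wrapped closure) is an assumption; mirror an existing constructor in `unary.rs`.

```rust
impl UnaryNode {
    pub(crate) fn erf(input: Type, output: Type) -> Self {
        // Map the input token stream to the equivalent Burn call.
        let function = move |input| quote! { #input.erf() };
        Self::new(input, output, UnaryNodeKind::Erf, Rc::new(function))
    }
}
```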
### Step 3: Registering New Operations
- In `crates/burn-import/src/onnx/to_burn.rs`, add the operation to the match statement in the
  `into_burn()` method:
  ```rust
  impl ParsedOnnxGraph {
      pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
          // ...

          for node in self.0.nodes {
              match node.node_type {
                  // ...
                  NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
                  // Add your new operation here
              }
          }
      }
  }
  ```
- Then create a conversion function that builds an instance of your Burn node:
  ```rust
  fn squeeze_conversion(node: Node) -> SqueezeNode {
      let input = TensorType::from(node.inputs.first().unwrap());
      let output = TensorType::from(node.outputs.first().unwrap());
      let axes = squeeze_config(&node);

      SqueezeNode::new(input, output, axes)
  }
  ```
This function extracts the necessary information from the ONNX node and passes it to your node's constructor.
### Step 4: Create a Config Function
In `crates/onnx-ir/src/node/<operation_name>.rs`, create a config function that extracts
operation-specific parameters from the ONNX node:
```rust
pub fn squeeze_config(curr: &Node) -> Vec<i64> {
    // Extract the "axes" attribute, defaulting to an empty vector if absent.
    let axes = curr
        .attrs
        .iter()
        .filter_map(|(key, value)| {
            if key == "axes" {
                Some(value.clone().into_i64s())
            } else {
                None
            }
        })
        .next()
        .unwrap_or_else(Vec::new);

    // Validate that the input is a tensor.
    match curr.inputs.first().unwrap().clone().ty {
        ArgType::Tensor(_) => (),
        _ => panic!("Only tensor input is valid"),
    };

    axes
}
```
This config function is responsible for parsing the ONNX node attributes and extracting operation-specific parameters. In this case, it extracts the "axes" attribute from the squeeze operation.
### Step 5: Rank Inference
In `crates/onnx-ir/src/node/<operation_name>.rs`, implement a rank inference function that updates
the output rank based on the operation:
```rust
pub fn squeeze_update_output(node: &mut Node) {
    // Extract axes information
    let axes = /* ... */;
    let input_rank = /* ... */;
    let output_rank = input_rank - axes.len();

    // Update output rank
    node.outputs[0].ty = ArgType::Tensor(TensorType {
        elem_type: node.inputs[0].ty.elem_type().clone(),
        rank: output_rank,
        static_shape: None,
    });
}
```
Then register this function in `crates/onnx-ir/src/rank_inference.rs` by adding it to the match
statement:
```rust
pub fn rank_inference(node: &mut Node) {
    match node.node_type {
        // ...
        NodeType::Squeeze => squeeze_update_output(node),
        // Add your new operation here
    }
}
```
The `rank_inference.rs` file is responsible for determining the output tensor rank for each node in
the graph.
If the rank remains unchanged, you can use helper functions like `same_as_input()` or
`same_as_input_broadcast()` instead of writing a custom update function.
### Step 6: Integrate into the Graph Building Process
When a new node type is introduced, it must be added to the `Node<PS: PrecisionSettings>` enum in
`crates/burn-import/src/burn/node/base.rs` and to the `match_all!` macro in the same file.
The `Node` enum abstracts over the different types of operations (nodes) within a network graph.
Each variant of the enum corresponds to a specific type of operation and encapsulates the
operation-specific data structures (like `SqueezeNode`) that were defined in Step 2.
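Schematically, the addition to `base.rs` looks like this (abbreviated; the real enum and the `match_all!` invocation list every supported node):

```rust
pub enum Node<PS: PrecisionSettings> {
    // ... existing variants ...
    Squeeze(SqueezeNode),
}

// The new variant must also be listed in the match_all! macro in the same
// file, so that trait calls are dispatched to SqueezeNode.
```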
### Step 7: Add Newly Supported Op!
As a reward, add an extra check to `crates/burn-import/SUPPORTED-ONNX-OPS.md`!
## Lifting Constant Nodes
If your operation takes inputs from constant nodes (such as weights in Conv1d, shape tensors in
Reshape, etc.), you need to add your operation's `NodeType` to the
`LIFT_CONSTANTS_FOR_NODE_TYPES` array in `crates/onnx-ir/src/from_onnx.rs`:
```rust
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 16] = [
    NodeType::BatchNormalization,
    // other operations...
    NodeType::Squeeze,
    NodeType::Unsqueeze,
    // Add your operation here if it needs constants to be processed
];
```
"Lifting" constants means converting Constant nodes into direct input values. This is similar to how ONNX initializers work. For example, instead of having a separate Constant node providing weights to a Convolution operation, the weights are directly embedded as values in the Convolution node's inputs.
This transformation makes it easier to:
- Access the constant values during node configuration
- Process operations like Conv1d that expect weights as direct inputs
- Handle shape-defining inputs needed for operations like Reshape
Without this, operations that need to extract configuration from constant inputs (such as shapes, weights, or other parameters) would not work correctly because they wouldn't have direct access to those constant values.
## Testing
When implementing a new operator, there are several levels of testing to consider:
### Unit Testing
1. **Node Configuration**: Write unit tests for the `<operation_name>_config` function in
   `crates/onnx-ir/src/node/<operation_name>.rs` to verify that it correctly extracts parameters
   from ONNX nodes (a sketch follows this list).

2. **Rank Inference**: Test the `<operation_name>_update_output` function to ensure it correctly
   computes output ranks.

3. **Code Generation**: Test the `Node` implementation in `burn-import` to verify that it
   generates correct Rust code.
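For the node-configuration tests, a sketch might look like the following; `build_test_node` is a hypothetical helper standing in for however the existing tests in the module construct a `Node` with a tensor input and an `axes` attribute.

```rust
#[cfg(test)]
mod tests {
    use super::*;

    // Hypothetical fixture: build a Squeeze node with a tensor input and
    // the given `axes` attribute, mimicking the module's existing tests.
    fn build_test_node(axes: Vec<i64>) -> Node {
        todo!("construct inputs, outputs, and attrs for a Squeeze node")
    }

    #[test]
    fn squeeze_config_extracts_axes() {
        let node = build_test_node(vec![0]);
        assert_eq!(squeeze_config(&node), vec![0]);
    }
}
```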
### Integration Testing
- Create small ONNX models that use your operator and test the end-to-end conversion process
- Ensure the generated Rust code compiles and produces the expected outputs
- Add these tests to `crates/burn-import/onnx-tests/tests/test_onnx.rs`
### End-to-End Testing
- Test with realistic ONNX models that use your operator in conjunction with others
- Verify that inputs and outputs match between the original ONNX model and the converted Burn model
- Include models that test edge cases (e.g., different input shapes, parameter combinations)
Testing both the rank inference and node configuration is particularly important as these components directly affect the correctness of the conversion process. Incorrect rank inference can lead to mismatched tensor shapes, while incorrect configuration can cause runtime errors or incorrect results.