# ONNX to Burn Conversion Tool: Development Guide
This guide offers in-depth design insights and step-by-step procedures for developers working on the ONNX to Burn conversion tool. This tool allows the importation of ONNX models into the Burn deep learning framework written in Rust. It converts ONNX models to Rust source code and their weights to Burn state files.
For an introduction to ONNX import in Burn, see this section of the Burn book.
## Design Overview

### Design Goals
- Perform best-effort conversion of ONNX models to Rust source code via Burn APIs.
- Convert ONNX model weights to Burn state files.
- Support ONNX models generated by PyTorch (ONNX Opset 16).
- Produce easy-to-understand and modifiable models.
- Ensure the generated models are trainable using Burn APIs.
### Design Decisions
- Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process.
- Ensure operator behavior consistency across different OpSet versions.
- Exclude any ONNX/Protobuf-specific logic from the Burn graph.
The conversion process involves three main stages:
- Convert ONNX model to Intermediate Representation (IR).
- Translate IR to a Burn graph.
- Generate Rust source code from the Burn graph.
## Adding New Operators
To extend `burn-import` with support for new ONNX operators, follow these steps:
1. **Create PyTorch Script**: Place a PyTorch script using the new operator under
   `crates/burn-import/onnx-tests/tests/<op>/<op>.py`. Make sure to print both input and output
   tensors for end-to-end testing.

2. **Generate ONNX Model**: Run the PyTorch script to produce an ONNX model.

3. **Visualize ONNX Model**: Use Netron to verify the ONNX model contains the expected operators.

4. **Generate IR and Burn Graph**: Navigate to `crates/burn-import/` and run:

   ```
   cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out
   ```

5. **Implement Missing Operators**: If you encounter an error stating that an operator is
   unsupported, implement it. The `./out/my-model.graph.txt` file should provide relevant
   information.

6. **Inspect Generated Files**: `my-model.graph.txt` contains IR details, `my-model.rs` holds the
   Burn model in Rust code, and `my-model.json` includes the model data.

7. **Add End-to-End Test**: Include the test in `crates/burn-import/onnx-tests/tests/onnx_tests.rs`.
   Further details can be found in the onnx-tests README.
## Implementing a New Operator
To extend the capabilities of the Burn library by supporting new operations imported from ONNX
graphs, developers must go through a few systematic steps. Here, we detail the process, using the
implementation of the `Squeeze` operation to illustrate points as needed. All file/directory paths
are relative to `burn/crates/burn-import/`.
### Step 1: Visibility
To make a new operation accessible to the rest of the Burn project, declare the module within the
`mod.rs` file located in the `src/burn/node/` directory.
### Step 2: Node Implementation
#### Within Onnx-IR
If the node type does not exist within the `NodeType` enum, it will need to be added (support for
custom operators is planned). If the node might be provided an input that is a constant or the
output of an identity node, it will need to be added to the list of node types checked for
constants. The node will need to be added to `dim_inference`, and in most cases the parsing-side
work will then be done. If a node requires extra parsing (such as handling an edge case like
potentially remapping an unsqueeze to a reshape), the best place for that is after constant
checking and prior to `dim_inference` in `OnnxGraphBuilder::build`.
#### Within burn-import
Create a new file named `<operation_name>.rs` in the `src/burn/node/` directory. This file will
define the structure and functionality of your new operation. By convention, the necessary
information for carrying out an operation is encapsulated within a struct named `<operation>Node`.
For the `Squeeze` operation, we defined a struct called `SqueezeNode` that holds the necessary
information about the input tensor, output tensor, and axes for the operation. If implementing a
unary or binary operation, see the note below.
The core of integrating a new operation involves implementing the `NodeCodegen` trait for your
node. This trait defines how the node generates code during the graph compilation process. The
implementation must provide methods to define input and output types, to generate the forward-pass
code, and to encapsulate the node into the more general `Node` structure. Specifically:

- `output_types` and `input_types` return the tensor (or element) types for the output and inputs
  of the node, respectively.
- `forward` generates the Rust code that performs the operation during the execution phase. The
  `quote!` macro is used to generate the Rust code; ensure that it is syntactically correct Burn
  code.
- `into_node` wraps the specific node in a general `Node` type, facilitating its inclusion in the
  broader Burn graph structure.

This file is also where you would put `test_codegen_nodes()`, to make sure that the generated code
works within the Burn library.
**For unary and binary operations:** The implementation of `NodeCodegen` is mostly provided by
`binary.rs` and `unary.rs`, so each new operation only has to define a method that applies the
function to the input token stream(s).
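To illustrate the shape of this pattern, here is a heavily simplified, dependency-free sketch. All types and names (`TensorType`, `NodeCodegenSketch`, `SqueezeNode` fields) are stand-ins, not the real burn-import definitions, and the real `forward` emits a `proc_macro2::TokenStream` via `quote!` rather than a `String`:

```rust
// Simplified stand-in for burn-import's tensor type description.
#[derive(Debug, Clone)]
struct TensorType {
    name: String,
    rank: usize,
}

// Simplified stand-in for the NodeCodegen trait; the real trait also has
// an `into_node` method and works with token streams.
trait NodeCodegenSketch {
    fn input_types(&self) -> Vec<TensorType>;
    fn output_types(&self) -> Vec<TensorType>;
    fn forward(&self) -> String;
}

struct SqueezeNode {
    input: TensorType,
    output: TensorType,
    axes: Vec<i64>,
}

impl NodeCodegenSketch for SqueezeNode {
    fn input_types(&self) -> Vec<TensorType> {
        vec![self.input.clone()]
    }
    fn output_types(&self) -> Vec<TensorType> {
        vec![self.output.clone()]
    }
    fn forward(&self) -> String {
        // Mirrors the Burn API call the generated model would contain.
        format!(
            "let {} = {}.squeeze::<{}>({});",
            self.output.name, self.input.name, self.output.rank, self.axes[0]
        )
    }
}

fn main() {
    let node = SqueezeNode {
        input: TensorType { name: "input".into(), rank: 3 },
        output: TensorType { name: "output".into(), rank: 2 },
        axes: vec![1],
    };
    println!("{}", node.forward());
}
```

The point of the pattern is that each node struct carries just enough type information (`input_types`/`output_types`) for the graph builder to wire nodes together, while `forward` produces the line(s) of Rust that end up in the generated model.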
### Step 3: Registering New Operations
Register the `NodeType::<operation>` and create an `<operation>_conversion(node: Node)` function,
both in `src/onnx/to_burn.rs`.
#### Registering new operations in the ONNX -> Burn conversion
To integrate new operations from an ONNX graph into the Burn framework, each operation must be
registered within the ONNX graph conversion process. This is done in the `src/onnx/to_burn.rs`
file, where the conversion from ONNX nodes to Burn nodes is orchestrated.
In the `into_burn()` method of the `OnnxGraph` struct, operations are matched with their
corresponding conversion functions. This method iterates over each node in the ONNX graph and,
depending on the node type, calls a specific conversion function that translates the ONNX node into
a corresponding Burn node.
```rust
impl OnnxGraph {
    pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
        let mut graph = BurnGraph::<PS>::default();
        let mut unsupported_ops = vec![];
        for node in self.nodes {
            match node.node_type {
                NodeType::Add => graph.register(Self::add_conversion(node)),
                // Other operations...
                NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
                // Add new operations here
                node_type => unsupported_ops.push(node_type),
            }
        }
        // ... (report `unsupported_ops`, then finalize and return the graph)
    }
}
```
Here, the `NodeType::Squeeze` arm matches the ONNX node type and calls the `squeeze_conversion()`
function that you define to handle the specific attributes and settings of a Squeeze operation.
#### Define the Conversion Function
Each operation conversion function extracts necessary information from the ONNX node and constructs
a corresponding Burn node. The structure of these functions generally includes:

- Extracting input and output tensors from the node.
- Retrieving and processing operation-specific configurations.
- Calling `<operation>_config()` to parse ONNX node configurations.
- Creating an instance of the appropriate Burn node (defined in Step 2) using this information.
### Step 4: Create a Config Function
Create an `<operation>_config(curr: &Node)` function in `src/onnx/op_configuration.rs`.

The `squeeze_conversion()` function in `src/onnx/to_burn.rs` from the previous step calls the
`squeeze_config()` function in `src/onnx/op_configuration.rs` in order to parse the ONNX node's
attributes and extract parameters specific to the Squeeze operation; in this case, the axes along
which the squeeze is performed.
> **📘 Info: Understanding Generic `config` Patterns**
>
> The `<op>_config()` functions follow a similar pattern:
>
> - Extract tensor or scalar types for inputs and outputs.
> - Validate the input structure and types for each node, ensuring they conform to expected
>   formats (panicking if not).
> - Parse and convert configurations or parameters specific to each operation.
> - Create and return a node specific to the operation, initialized with the extracted values and
>   configurations.
>
> For example, config functions handle operation-specific settings such as kernel size for
> pooling, or different tensor and scalar types for power operations.
These functions translate the more varied and flexible structure of ONNX nodes into the more structured and type-safe environment of Rust and the Burn framework. Spec compliance is dealt with here.
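As a sketch of the validate-then-normalize pattern (the `AttrValue` enum and `NodeSketch` struct below are simplified stand-ins for onnx-ir's attribute types, and the negative-axis normalization is an illustrative assumption, not necessarily how the real `squeeze_config()` handles it):

```rust
use std::collections::HashMap;

// Simplified stand-in for onnx-ir's attribute values.
#[derive(Debug, Clone)]
enum AttrValue {
    Int64s(Vec<i64>),
}

struct NodeSketch {
    attrs: HashMap<String, AttrValue>,
    input_rank: usize,
}

// Shape of an `<op>_config` function: validate the node's attributes,
// panic on anything the Burn node cannot represent, and normalize the
// parsed values into a canonical form.
fn squeeze_config(node: &NodeSketch) -> i64 {
    let axes = match node.attrs.get("axes") {
        Some(AttrValue::Int64s(v)) => v.clone(),
        _ => panic!("Squeeze: axes must be provided as a list of ints"),
    };
    assert_eq!(axes.len(), 1, "Squeeze: only one axis is supported");
    // ONNX allows negative axes; normalize against the input rank.
    if axes[0] < 0 { axes[0] + node.input_rank as i64 } else { axes[0] }
}

fn main() {
    let mut attrs = HashMap::new();
    attrs.insert("axes".to_string(), AttrValue::Int64s(vec![-2]));
    let node = NodeSketch { attrs, input_rank: 3 };
    println!("{}", squeeze_config(&node)); // axis -2 on a rank-3 input -> 1
}
```

The panics are deliberate: spec compliance is enforced here, so an ONNX model that uses an unsupported configuration fails loudly at conversion time rather than producing silently wrong Rust code.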
### Step 5: Dimension Inference
If needed, create a dimension inference function called `<operation>_update_output(node: &mut Node)`
in `src/onnx/dim_inference.rs`. If dimensions remain unchanged, use the `same_as_input()` function,
for example `NodeType::AveragePool1d => same_as_input(node)`. Match the `NodeType` to the function
in the `dim_inference()` match block.
Dimension inference is an important step in the conversion process where Burn determines the
dimensions of each output tensor based on the operation.

The `dim_inference()` function is responsible for determining the dimensions of the output tensors
for each node in the graph. It does this by:

- **Matching the node type**: The function uses a `match` statement on the `node_type` of each node
  to apply the correct dimension inference logic for the operation.
- **Applying operation-specific logic**: For each operation, a specific inference function is
  called that encapsulates the rules for how output dimensions should be derived from the inputs.
For the Squeeze operation, dimension inference is handled by the `squeeze_update_output()`
function, which is tailored to the nuances of the squeeze operation (currently not that nuanced):
the output tensor's rank should be the input tensor's rank minus one.
> **📘 Info: How `squeeze_update_output()` Works**
>
> - **Validation of axes input**: We first check whether the second input of the node contains a
>   list of integers, which represent the axes along which the squeeze operation is applied. The
>   function also validates that only one axis is specified for squeezing, ensuring that the
>   operation's requirements within Burn are followed.
> - **Extracting input dimensions**: The input tensor's rank is extracted from the first input.
> - **Configuring output dimensions**: The output tensor's rank is then set to one less than the
>   input tensor's rank, reflecting the reduction caused by the squeeze operation.
> - The function includes several checks that panic if the inputs do not meet the expected types
>   or configurations, such as when the axes are not provided as an integer list or the input type
>   is not a tensor.
By invoking this function within the `dim_inference()` match block, the output dimensions of each
node are updated before the graph is finalized. This ensures that all subsequent operations within
the graph can rely on correct tensor sizes, which is critical both for compiling the graph and for
runtime execution efficiency.
If something is amiss after this step (i.e. weird panics are happening) and the dimensions of your
output tensor differ from the dimensions of your input, see the warning at the very end.
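The rank arithmetic described above can be sketched like this. `ArgType` and `NodeSketch` are simplified stand-ins for the types in `dim_inference.rs`, keeping only the validation-then-derive structure:

```rust
// Simplified stand-in for the argument types tracked during inference.
#[derive(Debug, Clone, PartialEq)]
enum ArgType {
    Tensor { rank: usize },
    Scalar,
}

struct NodeSketch {
    inputs: Vec<ArgType>,
    outputs: Vec<ArgType>,
}

// Shape of `squeeze_update_output`: validate the input, then derive the
// output rank from the input rank.
fn squeeze_update_output(node: &mut NodeSketch) {
    let input_rank = match node.inputs.first() {
        Some(ArgType::Tensor { rank }) => *rank,
        _ => panic!("Squeeze: input must be a tensor"),
    };
    // Squeezing exactly one axis removes exactly one dimension.
    node.outputs[0] = ArgType::Tensor { rank: input_rank - 1 };
}

fn main() {
    let mut node = NodeSketch {
        inputs: vec![ArgType::Tensor { rank: 3 }],
        outputs: vec![ArgType::Scalar], // placeholder, overwritten below
    };
    squeeze_update_output(&mut node);
    println!("{:?}", node.outputs[0]); // Tensor { rank: 2 }
}
```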
### Step 6: Integrate into the Graph Building Process
When a new node type is introduced, it must be added to the `Node<PS: PrecisionSettings>` enum and
the `match_all!` macro in `src/burn/node/base.rs`.
The `Node` enum abstracts over the different types of operations (nodes) within a network graph.
Each variant of the enum corresponds to a specific type of operation and encapsulates the
operation-specific data structure (like `SqueezeNode`) that was defined in Step 2.
### Step 7: Add Newly Supported Op!
As a reward, add an extra check to `SUPPORTED-ONNX-OPS.md`!
## Misc
> **🚧 Warning: Dimension Changes**
>
> If your operation changes the dimensions of the input tensor, you may need to modify the
> `LIFT_CONSTANTS_FOR_NODE_TYPES` list in `src/onnx/from_onnx.rs` by adding the `NodeType` of your
> operation to it.
## Testing
- Unit tests for the Burn graph to Rust source code conversion are mandatory.
- End-to-end tests should include a test ONNX model and its expected output for each operator.