ONNX to Burn Conversion Tool: Development Guide

This guide offers in-depth design insights and step-by-step procedures for developers working on the ONNX to Burn conversion tool. This tool imports ONNX models into the Burn deep learning framework, which is written in Rust. It converts ONNX models to Rust source code and model weights to Burn state files.

For an introduction to ONNX import in Burn, see this section of the Burn book.

Table of Contents

  • Design Overview
  • Adding New Operators
  • Implementing a New Operator
  • Testing
  • Resources

Design Overview

Design Goals

  • Perform best-effort conversion of ONNX models to Rust source code via Burn APIs.
  • Convert ONNX model weights to Burn state files.
  • Support ONNX models generated by PyTorch (ONNX Opset 16).
  • Produce easy-to-understand and modifiable models.
  • Ensure the generated models are trainable using Burn APIs.

Design Decisions

  • Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process.
  • Ensure operator behavior consistency across different OpSet versions.
  • Exclude any ONNX/Protobuf-specific logic from the Burn graph.

The conversion process involves three main stages:

  1. Convert ONNX model to Intermediate Representation (IR).
  2. Translate IR to a Burn graph.
  3. Generate Rust source code from the Burn graph.

Adding New Operators

To extend burn-import with support for new ONNX operators, follow these steps:

  1. Create PyTorch Script: Place a PyTorch script using the new operator under crates/burn-import/onnx-tests/tests/<op>/<op>.py. Make sure to print both input and output tensors for end-to-end testing.

  2. Generate ONNX Model: Run the PyTorch script to produce an ONNX model.

  3. Visualize ONNX Model: Use Netron to verify the ONNX model contains the expected operators.

  4. Generate IR and Burn Graph: Navigate to crates/burn-import/ and run:

    cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out
    
  5. Implement Missing Operators: If you encounter an error stating that an operator is unsupported, implement it. The ./out/my-model.graph.txt should provide relevant information.

  6. Inspect Generated Files: The my-model.graph.txt contains IR details, my-model.rs holds the Burn model in Rust code, and my-model.json includes the model data.

  7. Add End-to-End Test: Include the test in crates/burn-import/onnx-tests/tests/onnx_tests.rs. Further details can be found in the onnx-tests README.
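
     A typical test builds the generated model, feeds it the same input as the Python script, and asserts on the output. A minimal sketch, assuming a squeeze operator and the Backend type alias already used in that file (the shapes and expected values come from your <op>.py script):

        #[test]
        fn squeeze() {
            let device = Default::default();
            // The model type is generated from the ONNX file by the build script.
            let model: squeeze::Model<Backend> = squeeze::Model::new(&device);
            let input = Tensor::<Backend, 4>::ones([1, 5, 2, 1], &device);

            let output = model.forward(input);

            // Expected shape taken from the PyTorch script's printed output.
            assert_eq!(output.shape(), Shape::from([1, 5, 2]));
        }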

Implementing a New Operator

To extend the capabilities of the Burn library by supporting new operations imported from ONNX graphs, developers must go through a few systematic steps. Here, we detail the process, using the implementation of the Squeeze operation to illustrate points as needed. All file/directory paths are relative to burn/crates/burn-import/.

Step 1: Visibility

To make a new operation accessible to the rest of the Burn project, you need to declare the module within the mod.rs file located in the src/burn/node/ directory.
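
For example, assuming the new operation is called squeeze, the declaration is a single line (a sketch):

    // src/burn/node/mod.rs
    pub(crate) mod squeeze;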

Step 2: Node Implementation

Within Onnx-IR

If the node type does not exist in the NodeType enum, it will need to be added (support for custom operators is planned). If the node might receive an input that is a constant or the output of an Identity node, it will need to be added to the list of node types checked for constants. The node will also need to be added to dim_inference, and in most cases that completes the work on the parsing side. If a node requires extra parsing (such as handling an edge case, like potentially remapping an Unsqueeze to a Reshape), the best place for that is after the constant check and prior to dim_inference in OnnxGraphBuilder::build.

Within burn-import

Create a new file named <operation_name>.rs in the src/burn/node/ directory.
This file will define the structure and functionality of your new operation. By convention, the necessary information for carrying out an operation is encapsulated within a struct named <operation>Node. For the Squeeze operation, we defined a SqueezeNode struct that holds the input tensor, output tensor, and axes for the operation. If you are implementing a unary or binary operation, see the note below.
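
As a sketch, such a struct might look like the following (the derive-new crate provides the new() constructor; the field types here are assumptions, so check neighboring node files for the exact conventions):

    #[derive(Debug, Clone, new)]
    pub struct SqueezeNode {
        pub input: TensorType,      // name and rank of the input tensor
        pub output: TensorType,     // name and rank of the output tensor
        pub axes: Option<Vec<i64>>, // axes to squeeze, taken from the ONNX node
    }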

The core of integrating a new operation involves implementing the NodeCodegen trait for your node. This trait defines how the node generates code during the graph compilation process. The implementation must provide methods to define input and output types, to generate the forward pass code, and to encapsulate the node into the more general Node structure. Specifically:

  • output_types and input_types return the tensor (or element) types for the output and inputs of the node, respectively.
  • forward generates the Rust code that performs the operation during the execution phase. The quote! macro is used to generate the Rust code; make sure it is syntactically correct and uses valid Burn APIs.
  • into_node wraps the specific node in a general Node type, facilitating its inclusion in the broader Burn graph structure.
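
Put together, a skeleton of the implementation looks roughly like this (the method signatures follow the NodeCodegen trait; the forward body is illustrative and elides details such as supplying the output rank to squeeze):

    impl<PS: PrecisionSettings> NodeCodegen<PS> for SqueezeNode {
        fn input_types(&self) -> Vec<Type> {
            vec![Type::Tensor(self.input.clone())]
        }

        fn output_types(&self) -> Vec<Type> {
            vec![Type::Tensor(self.output.clone())]
        }

        fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
            let input = scope.tensor_use_owned(&self.input, node_position);
            let output = &self.output.name;
            let axis = self.axes.as_ref().expect("static axes")[0] as usize;

            // Emit the Burn call that performs the squeeze.
            quote! {
                let #output = #input.squeeze(#axis);
            }
        }

        fn into_node(self) -> Node<PS> {
            Node::Squeeze(self)
        }
    }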

This file is also where you would put test_codegen_nodes(), to make sure that the generated code works within the Burn library.

For unary and binary operations: NodeCodegen is already implemented in binary.rs and unary.rs, so each new operation only has to define a function that maps the input token stream(s) to the output expression.
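
For instance, adding a new unary operation can be as small as a constructor in unary.rs that supplies that closure (a sketch; the constructor arguments are assumptions based on the pattern described here):

    impl UnaryNode {
        pub(crate) fn erf(input: Type, output: Type) -> Self {
            // The closure maps the input token stream to the output expression.
            let function = move |input| quote! { #input.erf() };
            Self::new(input, output, UnaryNodeKind::Erf, Rc::new(function))
        }
    }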

Step 3: Registering New Operations

Register the NodeType::<operation> and create an <operation>_conversion(node: Node) function, both in src/onnx/to_burn.rs.

Registering new operations in the ONNX -> Burn Conversion
To integrate new operations from an ONNX graph into the Burn framework, each operation must be registered within the ONNX graph conversion process. This is done in the src/onnx/to_burn.rs file, where the conversion from ONNX nodes to Burn nodes is orchestrated.

In the into_burn() method of the OnnxGraph struct, operations are matched with their corresponding conversion functions. This method iterates over each node in the ONNX graph and, depending on the node type, calls a specific conversion function that translates the ONNX node into a corresponding Burn node.

impl OnnxGraph {
    pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
        let mut graph = BurnGraph::<PS>::default();
        let mut unsupported_ops = vec![];

        for node in self.nodes {
            match node.node_type {
                NodeType::Add => graph.register(Self::add_conversion(node)),
                // Other operations...
                NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
                // Add new operations here
                node_type => unsupported_ops.push(node_type),
            }
        }

        if !unsupported_ops.is_empty() {
            panic!("Unsupported ops: {unsupported_ops:?}");
        }

        // ... input/output registration elided ...
        graph
    }
}

Here, NodeType::Squeeze maps the ONNX node type to the squeeze_conversion() function, which you define to handle the specific attributes and settings of a Squeeze operation.

Define the Conversion Function
Each operation conversion function extracts necessary information from the ONNX node and constructs a corresponding Burn node. The structure of these functions generally includes:

  1. Extracting input and output tensors from the node.
  2. Retrieving and processing operation-specific configurations.
  3. Calling <operation>_config() to parse ONNX node configurations.
  4. Creating an instance of the appropriate Burn node (defined in step 2) using this information.
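
  Following that structure, a conversion function for Squeeze might look like this sketch (squeeze_config() is defined in the next step; the input/output extraction helpers are assumptions):

    fn squeeze_conversion(node: Node) -> SqueezeNode {
        // 1. Extract the input and output tensors from the ONNX node.
        let input = TensorType::from(node.inputs.first().unwrap());
        let output = TensorType::from(node.outputs.first().unwrap());

        // 2./3. Parse the operation-specific configuration (the axes).
        let axes = squeeze_config(&node);

        // 4. Build the Burn node defined in step 2.
        SqueezeNode::new(input, output, axes)
    }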

Step 4: Create a Config Function

Create an <operation>_config(curr: &Node) in src/onnx/op_configuration.rs.

The squeeze_conversion() function in src/onnx/to_burn.rs from the previous step calls the squeeze_config() function in src/onnx/op_configuration.rs in order to parse the ONNX node's attributes and extract the parameters specific to the Squeeze operation; in this case, the axes along which the squeeze is performed.

📘 Info: Understanding Generic config Patterns

The <op>_config() functions follow a similar pattern:

  1. Extract tensor or scalar types for inputs and outputs.
  2. Validate the input structure and types for each node, ensuring they conform to expected formats (panicking if not).
  3. Parse and convert configurations or parameters specific to each operation.
  4. Create and return a node specific to the operation, initialized with extracted values and configurations.

For example, config functions handle specific settings like kernel size for pooling or handling different tensor and scalar types for power operations.

These functions translate the more varied and flexible structure of ONNX nodes into the more structured and type-safe environment of Rust and the Burn framework. Spec compliance is dealt with here.
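
A sketch of squeeze_config() following this pattern (the attribute access and panic messages are illustrative):

    pub fn squeeze_config(curr: &Node) -> Option<Vec<i64>> {
        // 1./2. Validate the input structure: squeeze expects a tensor input.
        if !matches!(curr.inputs.first().unwrap().ty, ArgType::Tensor(_)) {
            panic!("Squeeze: only tensor input is valid");
        }

        // 3./4. Parse and return the axes attribute, if present.
        curr.attrs
            .get("axes")
            .map(|value| value.clone().into_i64s())
    }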

Step 5: Dimension Inference

If needed, create a dimension inference function called <operation>_update_output(node: &mut Node) in src/onnx/dim_inference.rs. If the dimensions remain unchanged, use the same_as_input() function, for example NodeType::AveragePool1d => same_as_input(node). Match the NodeType to the function in the dim_inference() match block.

Dimension inference is an important step in the conversion process where Burn determines the dimensions of each output tensor based on the operation. The dim_inference() function is responsible for determining the dimensions of the output tensors for each node in the graph. It does this by:

  1. Matching the Node Type: The function uses a match statement on the node_type of each node to apply the correct dimension inference logic depending on the operation.
  2. Applying Operation Specific Logic: For each operation, a specific inference function is called that encapsulates the rules for how output dimensions should be derived from the inputs.

For the Squeeze operation, the dimension inference is handled by the squeeze_update_output() function, which is specifically tailored to the squeeze operation's nuances (of which there are currently few): the output tensor's rank should be one less than the input tensor's.

📘 Info: How squeeze_update_output() Works

  1. Validation of axes input: We first check if the second input of the node contains a list of integers, which represent the axes along which the squeeze operation is applied. The function also validates that only one axis is specified for squeezing, ensuring that the operation's requirements within Burn are followed.
  2. Extracting input dimensions: The input tensor's dimension is extracted from the first input.
  3. Configuring output dimensions: The output tensor's dimensions are then set to be one less than the input tensor's dimensions, reflecting the reduction in dimensions caused by the squeeze operation.
  4. The function includes several checks that throw errors (panics) if the inputs do not meet the expected types or configurations, such as when the axes are not provided as an integer list or if the input type is not a tensor.
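
In code, a sketch of the function mirrors those four steps (the ArgType/TensorType field names are assumptions; see dim_inference.rs for the real types):

    fn squeeze_update_output(node: &mut Node) {
        // 1. Validate the axes input: exactly one axis must be provided.
        let axes = match &node.inputs[1].value {
            Some(Data::Int64s(axes)) => axes.clone(),
            _ => panic!("Squeeze: axes must be given as an int64 list"),
        };
        assert_eq!(axes.len(), 1, "Squeeze: only one axis is supported");

        // 2. Extract the input tensor's rank.
        let input = match &node.inputs[0].ty {
            ArgType::Tensor(tensor) => tensor.clone(),
            _ => panic!("Squeeze: input must be a tensor"),
        };

        // 3. The output rank is one less than the input rank.
        node.outputs[0].ty = ArgType::Tensor(TensorType {
            dim: input.dim - 1,
            ..input
        });
    }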

By invoking this function within the dim_inference() match block, the output dimensions of each node are updated before the graph is finalized. This ensures that all subsequent operations within the graph can rely on correct tensor sizes, which is critical for both compiling the graph and for runtime execution efficiency.

If something is amiss after this step (i.e. weird panics are happening) and the dimensions of your output tensor differ from the dimensions of your input, see the warning at the very end.

Step 6: Integrate into the Graph Building Process

When a new node type is introduced, it must be added to the Node<PS: PrecisionSettings> enum and match_all! macro in src/burn/node/base.rs.

The Node enum abstracts over different types of operations (nodes) within a network graph. Each variant of the enum corresponds to a specific type of operation, and it encapsulates the operation-specific data structures (like SqueezeNode) that were defined in step 2.
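
The addition itself is mechanical (a sketch; other variants elided):

    // src/burn/node/base.rs
    pub enum Node<PS: PrecisionSettings> {
        // ... existing variants ...
        Squeeze(SqueezeNode),
    }

The match_all! macro in the same file dispatches trait calls to every variant, so the new variant must be added to its list as well.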

Step 7: Add Newly Supported Op!

As a reward, add an extra checkmark to SUPPORTED-ONNX-OPS.md!

Misc:

🚧 Warning: Dimension Changes

If your operation changes the dimensions of the input tensor, you may need to add your operation's NodeType to the LIFT_CONSTANTS_FOR_NODE_TYPES list in src/onnx/from_onnx.rs.
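
That is, add your operation's NodeType to the list (a sketch; the actual contents of the list differ):

    // src/onnx/from_onnx.rs
    const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 2] =
        [NodeType::Squeeze, NodeType::Unsqueeze];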

Testing

  • Unit tests for the Burn graph to Rust source code conversion are mandatory.
  • End-to-end tests should include a test ONNX model and its expected output for each operator.

Resources

  1. PyTorch to ONNX
  2. ONNX to PyTorch
  3. ONNX Introduction
  4. ONNX Operators
  5. ONNX Protos
  6. ONNX Optimizer
  7. Netron