ONNX to Burn: Development Guide

This guide offers in-depth design insights and step-by-step procedures for developers working on the ONNX to Burn conversion tool, which imports ONNX models into the Burn deep learning framework written in Rust. The tool converts ONNX models to Rust source code and their weights to Burn state files.

For an introduction to ONNX import in Burn, see this section of the Burn book.

Design Overview

Design Goals

  • Perform best-effort conversion of ONNX models to Rust source code via Burn APIs.
  • Convert ONNX model weights to Burn state files.
  • Support ONNX models generated by PyTorch (ONNX Opset 16+).
  • Produce easy-to-understand and modifiable models.
  • Ensure the generated models are trainable using Burn APIs.

Design Decisions

  • Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process.
  • Ensure operator behavior consistency across different OpSet versions.
  • Exclude any ONNX/Protobuf-specific logic from the Burn graph.

The conversion process involves three main stages:

  1. Convert ONNX model to Intermediate Representation (IR).
  2. Translate IR to a Burn graph.
  3. Generate Rust source code from the Burn graph.

Adding New Operators

To extend burn-import with support for new ONNX operators, follow these steps:

  1. Create PyTorch Script: Place a PyTorch script using the new operator under crates/burn-import/onnx-tests/tests/<op>/<op>.py. Make sure to print both input and output tensors for end-to-end testing.

  2. Generate ONNX Model: Run the PyTorch script to produce an ONNX model.

  3. Visualize ONNX Model: Use Netron to verify the ONNX model contains the expected operators.

  4. Generate IR and Burn Graph: Navigate to crates/burn-import/ and run:

    cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out
    
  5. Implement Missing Operators: If you encounter an error stating that an operator is unsupported, implement it. The generated ./out/my-model.graph.txt file contains the IR of the model and should help you pin down what is missing.

  6. Inspect Generated Files: The my-model.graph.txt contains IR details, my-model.rs holds the Burn model in Rust code, and my-model.json includes the model data.

  7. Add End-to-End Test: Include the test in crates/burn-import/onnx-tests/tests/test_onnx.rs. Further details can be found in the onnx-tests README.
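
    A minimal end-to-end test might look like the sketch below. The model name, input tensor, and expected shape are illustrative and depend on your PyTorch script; the include_models! macro and test layout follow the pattern of the existing tests in test_onnx.rs.

    // Sketch only -- model name, input, and expected shape are illustrative.
    include_models!(squeeze);

    #[cfg(test)]
    mod tests {
        use super::*;
        use burn::tensor::Tensor;

        type Backend = burn_ndarray::NdArray<f32>;

        #[test]
        fn squeeze() {
            let device = Default::default();
            let model: squeeze::Model<Backend> = squeeze::Model::new(&device);

            // Must match the input printed by the PyTorch script.
            let input = Tensor::<Backend, 3>::ones([1, 3, 1], &device);
            let output = model.forward(input);

            assert_eq!(output.dims(), [1, 3]);
        }
    }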

Implementing a New Operator

To extend the capabilities of the Burn library by supporting new operations imported from ONNX graphs, developers must work through a series of systematic steps. Here, we detail the process, using the implementation of the Squeeze operation as a running example. All file/directory paths are relative to the root of the burn repository.

Step 1: Node Implementation in onnx-ir

The onnx-ir crate handles the Intermediate Representation (IR) of ONNX models. For each operation:

  1. Add the operation to the NodeType enum in crates/onnx-ir/src/ir.rs.

  2. Create a new module file in crates/onnx-ir/src/node/<operation_name>.rs. This file should include:

    • A <operation_name>_config function to extract operation parameters
    • A <operation_name>_update_output function for dimension inference
  3. Make the module visible in crates/onnx-ir/src/node/mod.rs.

  4. If the operation might work with constants, add it to the list of node types checked for constants in crates/onnx-ir/src/from_onnx.rs.

For example, the squeeze operation is defined in crates/onnx-ir/src/node/squeeze.rs and contains:

  • A squeeze_config function that extracts axes from node attributes
  • A squeeze_update_output function that updates output dimensions by reducing input rank
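
A simplified sketch of the config side is shown below. It assumes node.attrs maps attribute names to values and that an into_i64s() helper converts an attribute into a Vec<i64>; the real implementation also covers newer opsets, where axes arrives as an input rather than an attribute.

// Simplified sketch of a config function. Assumes `node.attrs` maps attribute
// names to values and `into_i64s()` converts an attribute into a Vec<i64>;
// the real implementation also handles opsets where `axes` is an input.
pub fn squeeze_config(node: &Node) -> Option<Vec<i64>> {
    node.attrs
        .iter()
        .find(|(key, _)| key.as_str() == "axes")
        .map(|(_, value)| value.clone().into_i64s())
}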

Step 2: Node Implementation in burn-import

  1. Create a new file named <operation_name>.rs in the crates/burn-import/src/burn/node/ directory. This file defines the structure and functionality of your new operation. By convention, the information needed to carry out an operation is encapsulated in a struct named <operation>Node. For the Squeeze operation, this is a SqueezeNode struct that holds the input tensor, output tensor, and axes of the operation.

  2. Implement the OnnxIntoNode trait for your node. This trait has a single method from_onnx that converts an ONNX IR node into your Burn node type:

    impl OnnxIntoNode for SqueezeNode {
        fn from_onnx(node: onnx_ir::Node) -> Self {
            let input = TensorType::from(node.inputs.first().unwrap());
            let output = TensorType::from(node.outputs.first().unwrap());
            let axes = squeeze_config(&node);
    
            SqueezeNode::new(input, output, axes)
        }
    }
  3. The core of integrating a new operation involves implementing the NodeCodegen trait for your node. This trait defines how the node generates code during the graph compilation process. The implementation must provide methods to define input and output types, to generate the forward pass code, and to encapsulate the node into the more general Node structure. Specifically:

    • input_types() - Returns the types of all input arguments
    • output_types() - Returns the types of all output values
    • forward() - Generates the Rust code that performs the operation during execution. Use the quote! macro to generate code and ensure it's syntactically correct Burn code
    • into_node() - Wraps the node into the general Node<PS> enum

    Example implementation:

    impl<PS: PrecisionSettings> NodeCodegen<PS> for SqueezeNode {
        fn input_types(&self) -> Vec<Type> {
            vec![self.input.clone()]
        }
    
        fn output_types(&self) -> Vec<Type> {
            vec![self.output.clone()]
        }
    
        fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
            // Simplified example - actual implementation handles more cases
            let input = scope.tensor_use_owned(&self.input, node_position);
            let output = &self.output.name();
    
            match &self.axes {
                Some(axes) => {
                    let axes_tokens = axes.to_tokens();
                    quote! {
                        let #output = #input.squeeze_dims(&#axes_tokens);
                    }
                }
                None => {
                    let output_rank = self.output.rank();
                    quote! {
                        let #output = #input.squeeze::<#output_rank>();
                    }
                }
            }
        }
    
        fn into_node(self) -> Node<PS> {
            Node::Squeeze(self)
        }
    }
  4. Add unit tests in the same file to verify the generated code compiles and works correctly. These tests typically call a helper function like assert_tokens() to validate the generated code against expected output.
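
    Such a test registers the node in a BurnGraph and compares the generated tokens against an expected quote! block. A sketch follows; the import paths and helper names are assumptions based on the pattern of the other node files.

    // Sketch of a codegen unit test; import paths and helper names are assumed
    // to follow the pattern of the other node files.
    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::burn::{TensorType, graph::BurnGraph, node::test::assert_tokens};
        use burn::record::FullPrecisionSettings;

        #[test]
        fn test_codegen_nodes() {
            let mut graph = BurnGraph::<FullPrecisionSettings>::default();

            graph.register(SqueezeNode::new(
                TensorType::new_float("tensor1", 3),
                TensorType::new_float("tensor2", 2),
                Some(vec![1]),
            ));
            graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

            // Expected module source, compared token-by-token (elided here).
            let expected = quote! { /* ... */ };

            assert_tokens(graph.codegen(), expected);
        }
    }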

Step 3: Register in Module System

Add the module declaration to crates/burn-import/src/burn/node/mod.rs:

// ... other node modules
pub(crate) mod squeeze;
// ... more node modules

The modules are automatically made visible through re-exports in the same file.

Step 4: Register in Node Registry

Add your operation to the node registry in crates/burn-import/src/burn/node_registry.rs. This is the single source of truth for all ONNX node conversions.

For a single ONNX operation mapping to a single node type:

node_registry! {
    // ... other operations
    Squeeze => squeeze as SqueezeNode,
    // ... more operations
}

For multiple ONNX operations mapping to the same node type (e.g., various reduce operations):

node_registry! {
    // ... other operations
    [ReduceMax, ReduceMin, ReduceMean, ReduceProd, ReduceSum]
        => ReduceMax: reduce as ReduceNode,
    // ... more operations
}

That's it! The registry automatically generates:

  • The Node<PS> enum with your operation as a variant
  • The match_all! macro for dispatching
  • The ONNX to Burn conversion logic
  • All necessary imports

Step 5: Rank Inference

In crates/onnx-ir/src/node/<operation_name>.rs, implement a rank inference function that updates the output rank based on the operation:

pub fn squeeze_update_output(node: &mut Node) {
    // Extract axes information
    let axes = /* ... */;
    let input_rank = /* ... */;
    let output_rank = input_rank - axes.len();

    // Update output rank
    node.outputs[0].ty = ArgType::Tensor(TensorType {
        elem_type: node.inputs[0].ty.elem_type().clone(),
        rank: output_rank,
        static_shape: None,
    });
}

Then register this function in crates/onnx-ir/src/rank_inference.rs by adding it to the match statement:

pub fn rank_inference(node: &mut Node) {
    match node.node_type {
        // ...
        NodeType::Squeeze => squeeze_update_output(node),
        // Add your new operation here
    }
}

The rank_inference.rs file is responsible for determining the output tensor rank for each node in the graph.

If the rank remains unchanged, you can use helper functions like same_as_input() or same_as_input_broadcast() instead of writing a custom update function.
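
For example, such operations are wired up in the rank_inference match using the helper directly (illustrative arms):

// Illustrative match arms using the rank-inference helpers.
NodeType::Relu => same_as_input(node),
NodeType::Add => same_as_input_broadcast(node),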

Step 6: Add Newly Supported Op!

As a reward, add a check mark for your new operation in crates/burn-import/SUPPORTED-ONNX-OPS.md!

Lifting Constant Nodes

If your operation takes inputs from constant nodes (such as weights in Conv1d, shape tensors in Reshape, etc.), you need to add your operation's NodeType to the LIFT_CONSTANTS_FOR_NODE_TYPES array in crates/onnx-ir/src/from_onnx.rs.

const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 16] = [
    NodeType::BatchNormalization,
    // other operations...
    NodeType::Squeeze,
    NodeType::Unsqueeze,
    // Add your operation here if it needs constants to be processed
];

"Lifting" constants means converting Constant nodes into direct input values. This is similar to how ONNX initializers work. For example, instead of having a separate Constant node providing weights to a Convolution operation, the weights are directly embedded as values in the Convolution node's inputs.

This transformation makes it easier to:

  1. Access the constant values during node configuration
  2. Process operations like Conv1d that expect weights as direct inputs
  3. Handle shape-defining inputs needed for operations like Reshape

Without this, operations that need to extract configuration from constant inputs (such as shapes, weights, or other parameters) would not work correctly because they wouldn't have direct access to those constant values.
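
As an illustration, once a constant has been lifted, a config function can read its data straight from the node's inputs. The sketch below assumes the IR exposes lifted constants through a value field on the input argument:

// Hypothetical sketch: after lifting, the constant's data sits directly on
// the input argument instead of on a separate Constant node.
let shape = node.inputs[1]
    .value
    .as_ref()
    .expect("Reshape: shape input must be a lifted constant")
    .clone()
    .into_i64s();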

Architecture Overview

The burn-import crate is organized into several key modules:

Core Modules

  • burn/node_registry.rs: Master registry containing all ONNX node mappings. This is a declarative macro that auto-generates the Node enum, conversion functions, and dispatch logic.

  • burn/node_codegen.rs: Contains the NodeCodegen and OnnxIntoNode traits that all nodes must implement. Also includes code generation utilities.

  • burn/node/: Directory containing individual node implementations. Each file implements a specific ONNX operation.

  • burn/graph.rs: Burn graph representation and code generation.

  • burn/ty.rs: Type system for tensors, scalars, and shapes, including conversions from ONNX types.

  • onnx/model_gen.rs: Public API (ModelGen) for converting ONNX models to Burn code.
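
ModelGen is typically driven from a build script. A minimal build.rs sketch (the input path is illustrative):

use burn_import::onnx::ModelGen;

fn main() {
    // Generate Rust source code and a Burn record from the ONNX file at build time.
    ModelGen::new()
        .input("src/model/my-model.onnx") // illustrative path
        .out_dir("model/")
        .run_from_script();
}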

Testing

When implementing a new operator, there are several levels of testing to consider:

Unit Testing

  • Node Configuration: Write unit tests for the <operation_name>_config function in crates/onnx-ir/src/node/<operation_name>.rs to verify that it correctly extracts parameters from ONNX nodes.

  • Rank Inference: Test the <operation_name>_update_output function to ensure it correctly computes output ranks.

  • Code Generation: Test the Node implementation in burn-import to verify that it generates correct Rust code. Each node file typically includes a test_codegen_nodes() function.
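
For example, a configuration test can build a minimal IR node and check the extracted parameters. The sketch below uses a hypothetical NodeBuilder test helper for assembling the node; adapt it to the test utilities available in onnx-ir:

// Sketch of a config unit test. `NodeBuilder` is a hypothetical test helper
// here; onnx-ir's test utilities follow a similar pattern.
#[test]
fn test_squeeze_config_with_axes() {
    let node = NodeBuilder::new(NodeType::Squeeze, "test_squeeze")
        .input_tensor_f32("data", 4, None)
        .output_tensor_f32("squeezed", 3, None)
        .attr_ints("axes", vec![1])
        .node();

    assert_eq!(squeeze_config(&node), Some(vec![1]));
}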

Integration Testing

  • Create small ONNX models that use your operator and test the end-to-end conversion process
  • Ensure the generated Rust code compiles and produces the expected outputs
  • Add these tests to crates/burn-import/onnx-tests/tests/test_onnx.rs

End-to-End Testing

  • Test with realistic ONNX models that use your operator in conjunction with others
  • Verify that inputs and outputs match between the original ONNX model and the converted Burn model
  • Include models that test edge cases (e.g., different input shapes, parameter combinations)

Testing both the rank inference and node configuration is particularly important as these components directly affect the correctness of the conversion process. Incorrect rank inference can lead to mismatched tensor shapes, while incorrect configuration can cause runtime errors or incorrect results.

Resources

  1. PyTorch to ONNX
  2. ONNX to PyTorch
  3. ONNX Introduction
  4. ONNX Operators
  5. ONNX Protos
  6. ONNX Optimizer
  7. Netron