Overview

Welcome to The Burn Book 👋

This book will help you get started with the Burn deep learning framework, whether you are an advanced user or a beginner. We have crafted some sections for you:

  • Basic Workflow: From Training to Inference: We'll start with the fundamentals, guiding you through the entire workflow, from training your models to deploying them for inference. This section lays the groundwork for your Burn expertise.

  • Building Blocks: Dive deeper into Burn's core components, understanding how they fit together. This knowledge forms the basis for more advanced usage and customization.

  • Saving & Loading Models: Learn how to easily save and load your trained models.

  • Custom Training Loop: Gain the power to customize your training loops, fine-tuning your models to meet your specific requirements. This section empowers you to harness Burn's flexibility to its fullest.

  • Importing Models: Learn how to import ONNX and PyTorch models, expanding your compatibility with other deep learning ecosystems.

  • Advanced: Finally, venture into advanced topics, exploring Burn's capabilities at their peak. This section caters to those who want to push the boundaries of what's possible with Burn.

Throughout the book, we assume a basic understanding of deep learning concepts, but we may refer to additional material when it seems appropriate.

Why Burn?

Why bother with the effort of creating an entirely new deep learning framework from scratch when PyTorch, TensorFlow, and other frameworks already exist? Spoiler alert: Burn isn't merely a replication of PyTorch or TensorFlow in Rust. It represents a novel approach, placing significant emphasis on making the right compromises in the right areas to facilitate exceptional flexibility, high performance, and a seamless developer experience. Burn isn't a framework specialized for only one type of application; it is designed to serve as a versatile framework suitable for a wide range of research and production uses. The foundation of Burn's design revolves around three key user profiles:

Machine Learning Researchers require tools to construct and execute experiments efficiently. It's essential for them to iterate quickly on their ideas and design testable experiments which can help them discover new findings. The framework should facilitate the swift implementation of cutting-edge research while ensuring fast execution for testing.

Machine Learning Engineers are another important demographic to keep in mind. Their focus leans less on swift implementation and more on establishing robustness, seamless deployment, and cost-effective operations. They seek dependable, economical models capable of achieving objectives without excessive expense. The whole machine learning workflow, from training to inference, must be as efficient as possible with minimal unpredictable behavior.

Low-level Software Engineers working with hardware vendors want their processing units to run models as fast as possible to gain competitive advantage. This endeavor involves harnessing hardware-specific features such as Tensor Cores for Nvidia. Since they are mostly working at a system level, they want to have absolute control over how the computation will be executed.

The goal of Burn is to satisfy all of those personas!

Getting Started

Burn is a deep learning framework in the Rust programming language. Therefore, it goes without saying that one must understand the basic notions of Rust. Reading the first chapters of the Rust Book is recommended, but don't worry if you're just starting out. We'll try to provide as much context and as many references to external resources as needed. Just look out for the 🦀 Rust Note indicators.

Installing Rust

For installation instructions, please refer to the installation page. It explains in detail the most convenient way to install Rust on your computer, which is the very first thing to do before using Burn.

Creating a Burn application

Once Rust is correctly installed, create a new Rust application using Cargo, Rust's build system and package manager, which is installed automatically with Rust.

🦀 Cargo Cheat Sheet

Cargo is a very useful tool to manage Rust projects because it handles a lot of tasks. More precisely, it is used to compile your code, download the libraries/packages your code depends on, and build said libraries.

Below is a quick cheat sheet of the main cargo commands you might use throughout this guide.

| Command | Description |
| ------- | ----------- |
| cargo new path | Create a new Cargo package in the given directory. |
| cargo add crate | Add dependencies to the Cargo.toml manifest file. |
| cargo build | Compile the local package and all of its dependencies (in debug mode, use -r for release). |
| cargo check | Check the local package for compilation errors (much faster). |
| cargo run | Run the local package binary. |

For more information, check out Hello, Cargo! in the Rust Book.


In the directory of your choice, run the following:

cargo new my_burn_app

This will initialize the my_burn_app project directory with a Cargo.toml file and a src directory with an auto-generated main.rs file inside. Head inside the directory to check:

cd my_burn_app

Then, add Burn as a dependency:

cargo add burn --features wgpu

Finally, compile the local package by executing the following:

cargo build

That's it, you're ready to start! You have a project configured with Burn and the WGPU backend, which allows you to execute low-level operations on the GPU on any platform.

Writing a code snippet

The src/main.rs was automatically generated by Cargo, so let's replace its content with the following:

use burn::tensor::Tensor;
use burn::backend::Wgpu;

// Type alias for the backend to use.
type Backend = Wgpu;

fn main() {
    let device = Default::default();
    // Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first
    let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device);
    let tensor_2 = Tensor::<Backend, 2>::ones_like(&tensor_1);

    // Print the element-wise addition (done with the WGPU backend) of the two tensors.
    println!("{}", tensor_1 + tensor_2);
}
🦀 Use Declarations

To bring any Burn module or item into scope, a use declaration is added.

In the example above, we wanted to bring the Tensor struct and Wgpu backend into scope with the following:

use burn::tensor::Tensor;
use burn::backend::Wgpu;

This is pretty self-explanatory in this case. But the same declaration could be written as a shortcut that simultaneously binds multiple paths sharing a common prefix:

use burn::{tensor::Tensor, backend::Wgpu};

In this example, the common prefix is pretty short and there are only two items to bind locally. Therefore, the first usage with two use declarations might be preferred. But know that both examples are valid. For more details on the use keyword, take a look at this section of the Rust Book or the Rust reference.


🦀 Generic Data Types

If you're new to Rust, you're probably wondering why we had to use Tensor::<Backend, 2>::.... That's because the Tensor struct is generic over multiple concrete data types. More specifically, a Tensor can be defined using three generic parameters: the backend, the number of dimensions (rank) and the data type (defaults to Float). Here, we only specify the backend and number of dimensions since a Float tensor is used by default. For more details on the Tensor struct, take a look at this section.

Most of the time when generics are involved, the compiler can infer the generic parameters automatically. In this case, the compiler needs a little help. This can usually be done in one of two ways: providing a type annotation or binding the generic parameter via the turbofish ::<> syntax. In the example above we used the so-called turbofish syntax, but we could have used type annotations instead like this:

let tensor_1: Tensor<Backend, 2> = Tensor::from_data([[2., 3.], [4., 5.]], &device);
let tensor_2 = Tensor::ones_like(&tensor_1);

You probably noticed that we provided a type annotation for the first tensor only and yet this example still works. That's because the compiler (correctly) inferred that tensor_2 had the same generic parameters. The same could have been done in the original example, but specifying the parameters for both is more explicit.


By running cargo run, you should now see the result of the addition:

Tensor {
  data:
[[3.0, 4.0],
 [5.0, 6.0]],
  shape:  [2, 2],
  device:  BestAvailable,
  backend:  "wgpu",
  kind:  "Float",
  dtype:  "f32",
}

While the previous example is somewhat trivial, the upcoming basic workflow section will walk you through a much more relevant example for deep learning applications.

Using prelude

Burn comes with a variety of components in its core library. When creating a new model or using an existing one for inference, you may need to import every single component you use, which can be a little verbose.

To address this, a prelude module is provided, allowing you to easily import commonly used structs and macros as a group:

use burn::prelude::*;

which is equivalent to:

use burn::{
    config::Config,
    module::Module,
    nn,
    tensor::{
        backend::Backend, Bool, Device, ElementConversion, Float, Int, Shape, Tensor,
        TensorData,
    },
};

For the sake of simplicity, the subsequent chapters of this book will all use this form of importing, except in the Building Blocks chapter, where explicit imports help users grasp the usage of particular structures and macros.

Examples

In the next chapter you'll have the opportunity to implement the whole Burn guide example yourself, step by step.

Many additional Burn examples are available in the examples directory. Burn examples are organized as library crates with one or more examples that are executable binaries. An example can then be executed using the following cargo command line in the root of the Burn repository:

cargo run --example <example name>

To learn more about crates and examples, read the Rust section below.

🦀 About Rust crates

Each Burn example is a package located in a subdirectory of the examples directory. A package is composed of one or more crates.

A package is a bundle of one or more crates that provides a set of functionality. A package contains a Cargo.toml file that describes how to build those crates.

A crate is a compilation unit in Rust. It could be a single file, but it is often easier to split up crates into multiple modules.

A module lets us organize code within a crate for readability and easy reuse. Modules also allow us to control the privacy of items. For instance, the pub(crate) keyword makes a module publicly available inside the crate only. In the snippet below, four modules are declared: two are public and visible to users of the crate, one is public inside the crate only (so crate users cannot see it), and the last one, with no keyword, is private. These modules can be single files or a directory with a mod.rs file inside.

pub mod data;
pub mod inference;
pub(crate) mod model;
mod training;

A crate can come in one of two forms: a binary crate or a library crate. When compiling a crate, the compiler first looks in the crate root file (src/lib.rs for a library crate and src/main.rs for a binary crate). Any module declared in the crate root file will be inserted in the crate for compilation.

All Burn examples are library crates and they can contain one or more executable examples that use the library. We even have some Burn examples that use the library crates of other examples.

The examples are unique files under the examples directory. Each file produces an executable file with the same name. Each example can then be executed with cargo run --example <executable name>.

Below is a file tree of a typical Burn example package:

examples/burn-example
├── Cargo.toml
├── examples
│   ├── example1.rs      ---> compiled to example1 binary
│   ├── example2.rs      ---> compiled to example2 binary
│   └── ...
└── src
    ├── lib.rs           ---> this is the root file for a library
    ├── module1.rs
    ├── module2.rs
    └── ...
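
To make the relationship between the library crate and its example binaries concrete, below is a hypothetical examples/example1.rs for the package above. The crate name burn_example and the training::run() entry point are illustrative assumptions, not code from the actual repository.

// Hypothetical examples/example1.rs: an executable example that simply calls
// into the package's library crate (whose root is src/lib.rs).
use burn_example::training; // assumes the library crate is named `burn-example`

fn main() {
    // `run()` is a placeholder entry point exposed by the library crate.
    training::run();
}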

The following additional examples are currently available if you want to check them out:

| Example | Description |
| ------- | ----------- |
| Custom CSV Dataset | Implements a dataset to parse CSV data for a regression task. |
| Regression | Trains a simple MLP on the California Housing dataset to predict the median house value for a district. |
| Custom Image Dataset | Trains a simple CNN on custom image dataset following a simple folder structure. |
| Custom Renderer | Implements a custom renderer to display the Learner progress. |
| Image Classification Web | Image classification web browser demo using Burn, WGPU and WebAssembly. |
| MNIST Inference on Web | An interactive MNIST inference demo in the browser. The demo is available online. |
| MNIST Training | Demonstrates how to train a custom Module (MLP) with the Learner configured to log metrics and keep training checkpoints. |
| Named Tensor | Performs operations with the experimental NamedTensor feature. |
| ONNX Import Inference | Imports an ONNX model pre-trained on MNIST to perform inference on a sample image with Burn. |
| PyTorch Import Inference | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. |
| Text Classification | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. |
| Text Generation | Trains a text generation transformer model on the DbPedia dataset. |

For more information on each example, see their respective README.md file. Be sure to check out the examples directory for an up-to-date list.

Note that some examples use the datasets library by HuggingFace to download the datasets required in the examples. This is a Python library, which means that you will need to install Python before running these examples. This requirement will be clearly indicated in the example's README when applicable.

Guide

This guide will walk you through the process of creating a custom model built with Burn. We will train a simple convolutional neural network model on the MNIST dataset and prepare it for inference.

For clarity, we sometimes omit imports in our code snippets. For more details, please refer to the corresponding code in the examples/guide directory. We reproduce this example in a step-by-step fashion, from dataset creation to modeling and training in the following sections. It is recommended to use the capabilities of your IDE or text editor to automatically add the missing imports as you add the code snippets to your code.

Be sure to check out the git branch corresponding to the version of Burn you are using to follow this guide.

The current version of Burn is 0.15 and the corresponding branch to checkout is release/0.15.

The code for this demo can be executed from Burn's base directory using the command:

cargo run --example guide

Key Learnings

  • Creating a project
  • Creating neural network models
  • Importing and preparing datasets
  • Training models on data
  • Choosing a backend
  • Using a model for inference

Model

The first step is to create a project and add the different Burn dependencies. Start by creating a new project with Cargo:

cargo new guide

As mentioned previously, this will initialize your guide project directory with a Cargo.toml and a src/main.rs file.

In the Cargo.toml file, add the burn dependency with train, wgpu and vision features. Then run cargo build to build the project and import all the dependencies.

[package]
name = "guide"
version = "0.1.0"
edition = "2021"

[dependencies]
burn = { version = "~0.15", features = ["train", "wgpu", "vision"] }

Our goal will be to create a basic convolutional neural network used for image classification. We will keep the model simple by using two convolution layers followed by two linear layers, some pooling and ReLU activations. We will also use dropout to improve training performance.

Let us start by defining our model struct in a new file src/model.rs.

use burn::{
    nn::{
        conv::{Conv2d, Conv2dConfig},
        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
    },
    prelude::*,
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
    pool: AdaptiveAvgPool2d,
    dropout: Dropout,
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}

There are two major things going on in this code sample.

  1. You can create a deep learning module with the #[derive(Module)] attribute on top of a struct. This will generate the necessary code so that the struct implements the Module trait. This trait will make your module both trainable and (de)serializable while adding related functionalities. Like other derive macros often used in Rust, such as Clone, PartialEq or Debug, each field within the struct must also implement the Module trait.

    🦀 Trait

    Traits are a powerful and flexible Rust language feature. They provide a way to define shared behavior that multiple types can implement.

    A type's behavior consists of the methods called on that type. Since all Modules should implement the same functionality, it is defined as a trait. Implementing a trait on a particular type usually requires the user to implement the defined behaviors of the trait for their types, though that is not the case here as explained above with the derive attribute. Check out the explainer below to learn why.

    For more details on traits, take a look at the associated chapter in the Rust Book.


    🦀 Derive Macro

    The derive attribute allows traits to be implemented easily by generating code that will implement a trait with its own default implementation on the type that was annotated with the derive syntax.

    This is accomplished through a feature of Rust called procedural macros, which allow us to run code at compile time that operates over Rust syntax, both consuming and producing Rust syntax. Using the attribute #[my_macro], you can effectively extend the provided code. You will see that the derive macro is very frequently employed to recursively implement traits, where the implementation consists of the composition of all fields.

    In this example, we want to derive the Module and Debug traits.

    #[derive(Module, Debug)]
    pub struct MyCustomModule<B: Backend> {
        linear1: Linear<B>,
        linear2: Linear<B>,
        activation: Relu,
    }

    The basic Debug implementation is provided by the compiler to format a value using the {:?} formatter. For ease of use, the Module trait implementation is automatically handled by Burn so you don't have to do anything special. It essentially acts as a parameter container.

    For more details on derivable traits, take a look at the Rust appendix, reference or example.


  2. Note that the struct is generic over the Backend trait. The backend trait abstracts the underlying low-level implementations of tensor operations, allowing your new model to run on any backend. Contrary to other frameworks, the backend abstraction isn't determined by a compilation flag or a device type. This is important because you can extend the functionalities of a specific backend (see backend extension section), and it allows for an innovative autodiff system. You can also change backends at runtime, for instance to compute training metrics on a CPU backend while using a GPU one only to train the model. In our example, the backend in use will be determined later on.

    🦀 Trait Bounds

    Trait bounds provide a way for generic items to restrict which types are used as their parameters. The trait bounds stipulate what functionality a type implements. Therefore, bounding restricts the generic to types that conform to the bounds. It also allows generic instances to access the methods of traits specified in the bounds.

    For a simple but concrete example, check out the Rust By Example on bounds.

    In Burn, the Backend trait enables you to run tensor operations using different implementations as it abstracts tensor, device and element types. The getting started example illustrates the advantage of having a simple API that works for different backend implementations. While it used the WGPU backend, you could easily swap it with any other supported backend.

    // Choose from any of the supported backends.
    // type Backend = Candle<f32, i64>;
    // type Backend = LibTorch<f32>;
    // type Backend = NdArray<f32>;
    type Backend = Wgpu;
    
    // Creation of two tensors.
    let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device);
    let tensor_2 = Tensor::<Backend, 2>::ones_like(&tensor_1);
    
    // Print the element-wise addition (done with the selected backend) of the two tensors.
    println!("{}", tensor_1 + tensor_2);

    For more details on trait bounds, check out the Rust trait bound section or reference.


Note that each time you create a new file in the src directory you also need to explicitly add this module to the main.rs file. For instance after creating the model.rs, you need to add the following at the top of the main file:

mod model;

fn main() {
}

Next, we need to instantiate the model for training.

use burn::{
    nn::{
        conv::{Conv2d, Conv2dConfig},
        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
    },
    prelude::*,
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
    pool: AdaptiveAvgPool2d,
    dropout: Dropout,
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}

#[derive(Config, Debug)]
pub struct ModelConfig {
    num_classes: usize,
    hidden_size: usize,
    #[config(default = "0.5")]
    dropout: f64,
}

impl ModelConfig {
    /// Returns the initialized model.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
        Model {
            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
            activation: Relu::new(),
            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
            dropout: DropoutConfig::new(self.dropout).init(),
        }
    }
}

At a glance, you can view the model configuration by printing the model instance:

use crate::model::ModelConfig;
use burn::backend::Wgpu;

fn main() {
    type MyBackend = Wgpu<f32, i32>;

    let device = Default::default();
    let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);

    println!("{}", model);
}

Output:

Model {
  conv1: Conv2d {stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 80}
  conv2: Conv2d {stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 1168}
  pool: AdaptiveAvgPool2d {output_size: [8, 8]}
  dropout: Dropout {prob: 0.5}
  linear1: Linear {d_input: 1024, d_output: 512, bias: true, params: 524800}
  linear2: Linear {d_input: 512, d_output: 10, bias: true, params: 5130}
  activation: Relu
  params: 531178
}
🦀 References

In the previous example, the init() method signature uses & to indicate that the parameter types are references: &self, a reference to the current receiver (ModelConfig), and device: &B::Device, a reference to the backend device.

pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
    Model {
        // ...
    }
}

References in Rust allow us to point to a resource to access its data without owning it. The idea of ownership is quite core to Rust and is worth reading up on.

In a language like C, memory management is explicit and up to the programmer, which means it is easy to make mistakes. In a language like Java or Python, memory management is automatic with the help of a garbage collector. This is very safe and straightforward, but also incurs a runtime cost.

In Rust, memory management is rather unique. Aside from simple types that implement Copy (e.g., primitives like integers, floats, booleans and char), every value is owned by some variable called the owner. Ownership can be transferred from one variable to another and sometimes a value can be borrowed. Once the owner variable goes out of scope, the value is dropped, which means that any memory it allocated can be freed safely.

Because the method does not own the self and device variables, the values the references point to will not be dropped when the references go out of scope at the end of the method.
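
To make the distinction concrete, here is a minimal, Burn-agnostic sketch of a move versus a borrow:

fn print_len(s: &String) {
    // Borrows `s`: the caller keeps ownership.
    println!("length: {}", s.len());
}

fn consume(s: String) {
    // Takes ownership of `s`: the value is dropped at the end of this function.
    println!("consumed: {s}");
}

fn main() {
    let owned = String::from("tensor data");
    print_len(&owned); // borrow: `owned` is still usable afterwards
    consume(owned); // move: `owned` can no longer be used from here on
    // println!("{owned}"); // would not compile: the value was moved
}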

For more information on references and borrowing, be sure to read the corresponding chapter in the Rust Book.


When creating a custom neural network module, it is often a good idea to create a config alongside the model struct. This allows you to define default values for your network, thanks to the Config attribute. The benefit of this attribute is that it makes the configuration serializable, enabling you to painlessly save your model hyperparameters and enhancing your experimentation process. Note that a constructor will automatically be generated for your configuration, taking as input the parameters that do not have default values: let config = ModelConfig::new(num_classes, hidden_size);. The default values can be overridden easily with builder-like methods (e.g., config.with_dropout(0.2)).
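
For instance, the configuration above could be created and tweaked as follows (the values are illustrative):

// `new` only takes the fields without defaults; `dropout` falls back to 0.5.
let config = ModelConfig::new(10, 512);

// Defaults can be overridden with the generated builder-like methods.
let config = ModelConfig::new(10, 512).with_dropout(0.2);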

The first implementation block is related to the initialization method. As we can see, all fields are set using the configuration of the corresponding neural network's underlying module. In this specific case, we have chosen to expand the tensor channels from 1 to 8 with the first layer, then from 8 to 16 with the second layer, using a kernel size of 3 on all dimensions. We also use the adaptive average pooling module to reduce the dimensionality of the images to an 8 by 8 matrix, which we will flatten in the forward pass to have a 1024 (16 * 8 * 8) resulting tensor.

Now let's see how the forward pass is defined.

use burn::{
    nn::{
        conv::{Conv2d, Conv2dConfig},
        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
        Dropout, DropoutConfig, Linear, LinearConfig, Relu,
    },
    prelude::*,
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
    pool: AdaptiveAvgPool2d,
    dropout: Dropout,
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}

#[derive(Config, Debug)]
pub struct ModelConfig {
    num_classes: usize,
    hidden_size: usize,
    #[config(default = "0.5")]
    dropout: f64,
}

impl ModelConfig {
    /// Returns the initialized model.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
        Model {
            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),
            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),
            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
            activation: Relu::new(),
            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
            dropout: DropoutConfig::new(self.dropout).init(),
        }
    }
}

impl<B: Backend> Model<B> {
    /// # Shapes
    ///   - Images [batch_size, height, width]
    ///   - Output [batch_size, num_classes]
    pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch_size, height, width] = images.dims();

        // Create a channel at the second dimension.
        let x = images.reshape([batch_size, 1, height, width]);


        let x = self.conv1.forward(x); // [batch_size, 8, _, _]
        let x = self.dropout.forward(x);
        let x = self.conv2.forward(x); // [batch_size, 16, _, _]
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);

        let x = self.pool.forward(x); // [batch_size, 16, 8, 8]
        let x = x.reshape([batch_size, 16 * 8 * 8]);
        let x = self.linear1.forward(x);
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);

        self.linear2.forward(x) // [batch_size, num_classes]
    }
}

For former PyTorch users, this might feel very intuitive, as each module is directly incorporated into the code using an eager API. Note that no abstraction is imposed for the forward method. You are free to define multiple forward functions with the names of your liking. Most of the neural network modules already built with Burn use the forward nomenclature, simply because it is standard in the field.

Similar to neural network modules, the Tensor struct given as a parameter also takes the Backend trait as a generic argument, alongside its dimensionality. Even if it is not used in this specific example, it is possible to add the kind of the tensor as a third generic argument. For example, a 3-dimensional Tensor of different data types (Float, Int, Bool) would be defined as follows:

Tensor<B, 3> // Float tensor (default)
Tensor<B, 3, Float> // Float tensor (explicit)
Tensor<B, 3, Int> // Int tensor
Tensor<B, 3, Bool> // Bool tensor

Note that the specific element type, such as f16 or f32, will be defined later with the backend.
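
For instance, with the WGPU backend used earlier, the float and integer element types are picked once when declaring the backend type alias, not on each tensor:

// Element types are selected on the backend, not on individual tensors.
type MyBackend = burn::backend::Wgpu<f32, i32>;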

Data

Typically, one trains a model on some dataset. Burn provides a library of very useful dataset sources and transformations, such as Hugging Face dataset utilities that allow you to download and store data in an SQLite database for extremely efficient data streaming and storage. For this guide though, we will use the MNIST dataset from burn::data::dataset::vision which requires no external dependency.

To iterate over a dataset efficiently, we will define a struct which will implement the Batcher trait. The goal of a batcher is to map individual dataset items into a batched tensor that can be used as input to our previously defined model.

Let us start by defining our dataset functionalities in a file src/data.rs. We shall omit some of the imports for brevity, but the full code for following this guide can be found in the examples/guide/ directory.

use burn::{
    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    prelude::*,
};

#[derive(Clone)]
pub struct MnistBatcher<B: Backend> {
    device: B::Device,
}

impl<B: Backend> MnistBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self { device }
    }
}

This codeblock defines a batcher struct with the device to which the tensors should be sent before being passed to the model. Note that the device is an associated type of the Backend trait since not all backends expose the same devices. As an example, the LibTorch-based backend exposes Cuda(gpu_index), Cpu, Vulkan and Metal devices, while the ndarray backend only exposes the Cpu device.
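
As a small sketch, this is how the batcher could be instantiated once a concrete backend (and therefore a concrete device type) is chosen; the generic code above only ever sees B::Device:

use burn::backend::wgpu::WgpuDevice;
use burn::backend::Wgpu;

// With the WGPU backend, B::Device resolves to WgpuDevice.
let device = WgpuDevice::default();
let batcher = MnistBatcher::<Wgpu>::new(device);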

Next, we need to actually implement the batching logic.

use burn::{
    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    prelude::*,
};

#[derive(Clone)]
pub struct MnistBatcher<B: Backend> {
    device: B::Device,
}

impl<B: Backend> MnistBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self { device }
    }
}

#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
    pub images: Tensor<B, 3>,
    pub targets: Tensor<B, 1, Int>,
}

impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
        let images = items
            .iter()
            .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
            .map(|data| Tensor::<B, 2>::from_data(data, &self.device))
            .map(|tensor| tensor.reshape([1, 28, 28]))
            // Normalize: make between [0,1] and make the mean=0 and std=1
            // values mean=0.1307,std=0.3081 are from the PyTorch MNIST example
            // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122
            .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
            .collect();

        let targets = items
            .iter()
            .map(|item| {
                Tensor::<B, 1, Int>::from_data(
                    [(item.label as i64).elem::<B::IntElem>()],
                    &self.device,
                )
            })
            .collect();

        let images = Tensor::cat(images, 0).to_device(&self.device);
        let targets = Tensor::cat(targets, 0).to_device(&self.device);

        MnistBatch { images, targets }
    }
}
🦀 Iterators and Closures

The iterator pattern allows you to perform some tasks on a sequence of items in turn.

In this example, an iterator is created over the MnistItems in the vector items by calling the iter method.

Iterator adaptors are methods defined on the Iterator trait that produce different iterators by changing some aspect of the original iterator. Here, the map method is called in a chain to transform the original data before consuming the final iterator with collect to obtain the images and targets vectors. Both vectors are then concatenated into a single tensor for the current batch.

You probably noticed that each call to map is different, as it defines a function to execute on the iterator items at each step. These anonymous functions are called closures in Rust. They're easy to recognize due to their syntax, which uses vertical bars ||. The vertical bars enclose the input parameters (if any), while the rest of the expression defines the function to execute.

If we go back to the example, we can break down and comment the expression used to process the images.

let images = items                                                       // take items Vec<MnistItem>
    .iter()                                                              // create an iterator over it
    .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())  // for each item, convert the image to float data struct
    .map(|data| Tensor::<B, 2>::from_data(data, &self.device))           // for each data struct, create a tensor on the device
    .map(|tensor| tensor.reshape([1, 28, 28]))                           // for each tensor, reshape to the image dimensions [C, H, W]
    .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)                    // for each image tensor, apply normalization
    .collect();                                                          // consume the resulting iterator & collect the values into a new vector

For more information on iterators and closures, be sure to check out the corresponding chapter in the Rust Book.


In the previous example, we implement the Batcher trait with a list of MnistItem as input and a single MnistBatch as output. The batch contains the images in the form of a 3D tensor, along with a targets tensor that contains the indexes of the correct digit class. The first step is to parse the image array into a TensorData struct. Burn provides the TensorData struct to encapsulate tensor storage information without being specific to a backend. When creating a tensor from data, we often need to convert the data precision to the current backend in use. This can be done with the .convert() method (in this example, the data is converted to the backend's float element type B::FloatElem). By importing the burn::tensor::ElementConversion trait, you can call .elem() on a specific number to convert it to the current backend element type in use.
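
As a focused sketch of these two conversions outside the batcher (the helper names are hypothetical, and ElementConversion is already part of the prelude):

use burn::prelude::*;

fn to_backend_tensor<B: Backend>(values: [f32; 3], device: &B::Device) -> Tensor<B, 1> {
    // Convert the raw values to the backend's float element type before building the tensor.
    let data = TensorData::from(values).convert::<B::FloatElem>();
    Tensor::from_data(data, device)
}

fn to_backend_int<B: Backend>(label: usize) -> B::IntElem {
    // `.elem()` converts a single number to the backend's element type in use.
    (label as i64).elem::<B::IntElem>()
}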

Training

We are now ready to write the necessary code to train our model on the MNIST dataset. We shall define the code for this training section in the file: src/training.rs.

Instead of a simple tensor, the model should output an item that can be understood by the learner, a struct whose responsibility is to apply an optimizer to the model. The output struct is used for all metrics calculated during the training. Therefore it should include all the necessary information to calculate any metric that you want for a task.

Burn provides two basic output types: ClassificationOutput and RegressionOutput. They implement the necessary trait to be used with metrics. It is possible to create your own item, but it is beyond the scope of this guide.

Since the MNIST task is a classification problem, we will use the ClassificationOutput type.

use crate::{
    data::{MnistBatch, MnistBatcher},
    model::{Model, ModelConfig},
};
use burn::{
    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
    nn::loss::CrossEntropyLossConfig,
    optim::AdamConfig,
    prelude::*,
    record::CompactRecorder,
    tensor::backend::AutodiffBackend,
    train::{
        metric::{AccuracyMetric, LossMetric},
        ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
    },
};

impl<B: Backend> Model<B> {
    pub fn forward_classification(
        &self,
        images: Tensor<B, 3>,
        targets: Tensor<B, 1, Int>,
    ) -> ClassificationOutput<B> {
        let output = self.forward(images);
        let loss = CrossEntropyLossConfig::new()
            .init(&output.device())
            .forward(output.clone(), targets.clone());

        ClassificationOutput::new(loss, output, targets)
    }
}

As evident from the preceding code block, we employ the cross-entropy loss module for loss calculation, without the inclusion of any padding token. We then return the classification output containing the loss, the output tensor with all logits and the targets.

Please take note that tensor operations receive owned tensors as input. To reuse a tensor multiple times, you need to use the clone() function. There's no need to worry; this process won't involve actual copying of the tensor data. Instead, it will simply indicate that the tensor is used in multiple instances, implying that certain operations won't be performed in place. In summary, our API has been designed with owned tensors to optimize performance.

Moving forward, we will proceed with the implementation of both the training and validation steps for our model.

use crate::{
    data::{MnistBatch, MnistBatcher},
    model::{Model, ModelConfig},
};
use burn::{
    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
    nn::loss::CrossEntropyLossConfig,
    optim::AdamConfig,
    prelude::*,
    record::CompactRecorder,
    tensor::backend::AutodiffBackend,
    train::{
        metric::{AccuracyMetric, LossMetric},
        ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
    },
};

impl<B: Backend> Model<B> {
    pub fn forward_classification(
        &self,
        images: Tensor<B, 3>,
        targets: Tensor<B, 1, Int>,
    ) -> ClassificationOutput<B> {
        let output = self.forward(images);
        let loss = CrossEntropyLossConfig::new()
            .init(&output.device())
            .forward(output.clone(), targets.clone());

        ClassificationOutput::new(loss, output, targets)
    }
}

impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let item = self.forward_classification(batch.images, batch.targets);

        TrainOutput::new(self, item.loss.backward(), item)
    }
}

impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
        self.forward_classification(batch.images, batch.targets)
    }
}

Here we define the input and output types as generic arguments in the TrainStep and ValidStep. We will call them MnistBatch and ClassificationOutput. In the training step, the computation of gradients is straightforward, necessitating a simple invocation of backward() on the loss. Note that contrary to PyTorch, gradients are not stored alongside each tensor parameter, but are rather returned by the backward pass, as such: let gradients = loss.backward();. The gradient of a parameter can be obtained with the grad function: let grad = tensor.grad(&gradients);. Although it is not necessary when using the learner struct and the optimizers, it can prove to be quite useful when debugging or writing custom training loops. One of the differences between the training and the validation steps is that the former requires the backend to implement AutodiffBackend and not just Backend. Otherwise, the backward function is not available, as the backend does not support autodiff. We will see later how to create a backend with autodiff support.
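
As a rough sketch of that debugging access (the weight tensor is hypothetical, B implements AutodiffBackend, and grad returns an Option in recent Burn versions):

// With a computed `loss` and some parameter tensor `weight` in scope:
let gradients = loss.backward();

// Gradients live on the inner (non-autodiff) backend.
if let Some(weight_grad) = weight.grad(&gradients) {
    println!("{}", weight_grad);
}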

🦀 Generic Type Constraints in Method Definitions

Although generic data types, trait and trait bounds were already introduced in previous sections of this guide, the previous code snippet might be a lot to take in at first.

In the example above, we implement the TrainStep and ValidStep trait for our Model struct, which is generic over the Backend trait as has been covered before. These traits are provided by burn::train and define a common step method that should be implemented for all structs. Since the trait is generic over the input and output types, the trait implementation must specify the concrete types used. This is where the additional type constraints appear <MnistBatch<B>, ClassificationOutput<B>>. As we saw previously, the concrete input type for the batch is MnistBatch, and the output of the forward pass is ClassificationOutput. The step method signature matches the concrete input and output types.

For more details specific to constraints on generic types when defining methods, take a look at this section of the Rust Book.


Let us move on to establishing the practical training configuration.

use crate::{
    data::{MnistBatch, MnistBatcher},
    model::{Model, ModelConfig},
};
use burn::{
    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
    nn::loss::CrossEntropyLossConfig,
    optim::AdamConfig,
    prelude::*,
    record::CompactRecorder,
    tensor::backend::AutodiffBackend,
    train::{
        metric::{AccuracyMetric, LossMetric},
        ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
    },
};

impl<B: Backend> Model<B> {
    pub fn forward_classification(
        &self,
        images: Tensor<B, 3>,
        targets: Tensor<B, 1, Int>,
    ) -> ClassificationOutput<B> {
        let output = self.forward(images);
        let loss = CrossEntropyLossConfig::new()
            .init(&output.device())
            .forward(output.clone(), targets.clone());

        ClassificationOutput::new(loss, output, targets)
    }
}

impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let item = self.forward_classification(batch.images, batch.targets);

        TrainOutput::new(self, item.loss.backward(), item)
    }
}

impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {
    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
        self.forward_classification(batch.images, batch.targets)
    }
}

#[derive(Config)]
pub struct TrainingConfig {
    pub model: ModelConfig,
    pub optimizer: AdamConfig,
    #[config(default = 10)]
    pub num_epochs: usize,
    #[config(default = 64)]
    pub batch_size: usize,
    #[config(default = 4)]
    pub num_workers: usize,
    #[config(default = 42)]
    pub seed: u64,
    #[config(default = 1.0e-4)]
    pub learning_rate: f64,
}

fn create_artifact_dir(artifact_dir: &str) {
    // Remove existing artifacts to get an accurate learner summary
    std::fs::remove_dir_all(artifact_dir).ok();
    std::fs::create_dir_all(artifact_dir).ok();
}

pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
    create_artifact_dir(artifact_dir);
    config
        .save(format!("{artifact_dir}/config.json"))
        .expect("Config should be saved successfully");

    B::seed(config.seed);

    let batcher_train = MnistBatcher::<B>::new(device.clone());
    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MnistDataset::train());

    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MnistDataset::test());

    let learner = LearnerBuilder::new(artifact_dir)
        .metric_train_numeric(AccuracyMetric::new())
        .metric_valid_numeric(AccuracyMetric::new())
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        .devices(vec![device.clone()])
        .num_epochs(config.num_epochs)
        .summary()
        .build(
            config.model.init::<B>(&device),
            config.optimizer.init(),
            config.learning_rate,
        );

    let model_trained = learner.fit(dataloader_train, dataloader_test);

    model_trained
        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
        .expect("Trained model should be saved successfully");
}

It is a good practice to use the Config derive to create the experiment configuration. In the train function, the first thing we are doing is making sure the artifact_dir exists, using the standard Rust library for file manipulation. All checkpoints, logging and metrics will be stored under this directory. We then initialize our dataloaders using our previously created batcher. Since no automatic differentiation is needed during the validation phase, the backend used for the corresponding batcher is B::InnerBackend (see Backend). The autodiff capabilities are available through the type system, making it nearly impossible to forget to deactivate gradient calculation.

Next, we create our learner with the accuracy and loss metric on both training and validation steps along with the device and the epoch. We also configure the checkpointer using the CompactRecorder to indicate how weights should be stored. This struct implements the Recorder trait, which makes it capable of saving records for persistency.

We then build the learner with the model, the optimizer and the learning rate. Notably, the third argument of the build function should actually be a learning rate scheduler. When provided with a float as in our example, it is automatically transformed into a constant learning rate scheduler. The learning rate is not part of the optimizer config as it is often done in other frameworks, but rather passed as a parameter when executing the optimizer step. This avoids having to mutate the state of the optimizer and is therefore more functional. It makes no difference when using the learner struct, but it will be an essential nuance to grasp if you implement your own training loop.
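
To illustrate that nuance, here is a rough sketch of a single hand-written optimizer step where the learning rate is passed to step directly (assumes model, optimizer, loss and learning_rate are in scope; exact signatures may vary slightly between Burn versions):

use burn::optim::{GradientsParams, Optimizer};

// Gradients are returned by the backward pass, then attached to the model parameters.
let gradients = loss.backward();
let grads = GradientsParams::from_grads(gradients, &model);

// The learning rate is passed at each step instead of being stored in the optimizer.
let model = optimizer.step(learning_rate, model, grads);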

Once the learner is created, we can simply call fit and provide the training and validation dataloaders. For the sake of simplicity in this example, we employ the test set as the validation set; however, we do not recommend this practice for actual usage.

Finally, the trained model is returned by the fit method. The trained weights are then saved using the CompactRecorder. This recorder employs the MessagePack format with half precision, f16 for floats and i16 for integers. Other recorders are available, offering support for various formats, such as BinCode and JSON, with or without compression. Any backend, regardless of precision, can load recorded data of any kind.
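
For example, saving the same weights with a full-precision JSON recorder could look like the following (the recorder type names reflect burn 0.15 and may differ in other versions):

use burn::record::{FullPrecisionSettings, PrettyJsonFileRecorder};

model_trained
    .save_file(
        format!("{artifact_dir}/model_json"),
        &PrettyJsonFileRecorder::<FullPrecisionSettings>::new(),
    )
    .expect("Trained model should be saved successfully");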

Backend

We have effectively written most of the necessary code to train our model. However, we have not explicitly designated the backend to be used at any point. This will be defined in the main entrypoint of our program, namely the main function defined in src/main.rs.

mod data;
mod model;
mod training;

use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
    backend::{Autodiff, Wgpu},
    data::dataset::Dataset,
    optim::AdamConfig,
};

fn main() {
    type MyBackend = Wgpu<f32, i32>;
    type MyAutodiffBackend = Autodiff<MyBackend>;

    let device = burn::backend::wgpu::WgpuDevice::default();
    let artifact_dir = "/tmp/guide";
    crate::training::train::<MyAutodiffBackend>(
        artifact_dir,
        TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
        device.clone(),
    );
}

In this code snippet, we use the Wgpu backend which is compatible with any operating system and will use the GPU. For other options, see the Burn README. This backend type takes the float type and the int type as generic arguments that will be used during training. The autodiff backend is simply the same backend, wrapped within the Autodiff struct which imparts differentiability to any backend.

We call the train function defined earlier with a directory for artifacts, the configuration of the model (the number of digit classes is 10 and the hidden dimension is 512), the optimizer configuration which in our case will be the default Adam configuration, and the device which can be obtained from the backend.

You can now train your freshly created model with the command:

cargo run --release

When running your project with the command above, you should see the training progression through a basic CLI dashboard:

[Screenshot: training progress CLI dashboard]

Inference

Now that we have trained our model, the next natural step is to use it for inference.

You need two things in order to load weights for a model: the model's record and the model's config. Since parameters in Burn are lazily initialized, no allocations are made and no GPU/CPU kernels are executed by the ModelConfig::init function. The weights are initialized when used for the first time, so you can safely use config.init(device).load_record(record) without any meaningful performance cost. Let's create a simple infer method in a new file src/inference.rs which we will use to load our trained model.

use crate::{data::MnistBatcher, training::TrainingConfig};
use burn::{
    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    prelude::*,
    record::{CompactRecorder, Recorder},
};

pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
    let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
        .expect("Config should exist for the model");
    let record = CompactRecorder::new()
        .load(format!("{artifact_dir}/model").into(), &device)
        .expect("Trained model should exist");

    let model = config.model.init::<B>(&device).load_record(record);

    let label = item.label;
    let batcher = MnistBatcher::new(device);
    let batch = batcher.batch(vec![item]);
    let output = model.forward(batch.images);
    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();

    println!("Predicted {} Expected {}", predicted, label);
}

The first step is to load the configuration of the training to fetch the correct model configuration. Then we can fetch the record using the same recorder as we used during training. Finally, we can initialize the model with the configuration and the record. For simplicity, we can use the same batcher used during the training to go from a MnistItem to a tensor.

By running the infer function, you should see the predictions of your model!

Add the call to infer to the main.rs file after the train function call:

mod data;
mod inference;
mod model;
mod training;

use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
    backend::{Autodiff, Wgpu},
    data::dataset::Dataset,
    optim::AdamConfig,
};

fn main() {
    type MyBackend = Wgpu<f32, i32>;
    type MyAutodiffBackend = Autodiff<MyBackend>;

    let device = burn::backend::wgpu::WgpuDevice::default();
    let artifact_dir = "/tmp/guide";
    crate::training::train::<MyAutodiffBackend>(
        artifact_dir,
        TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
        device.clone(),
    );
    crate::inference::infer::<MyBackend>(
        artifact_dir,
        device,
        burn::data::dataset::vision::MnistDataset::test()
            .get(42)
            .unwrap(),
    );
}

The number 42 is the index of the image in the MNIST dataset. You can explore and verify them using this MNIST viewer.

Conclusion

In this short guide, we've introduced you to the fundamental building blocks for getting started with Burn. While there's still plenty to explore, our goal has been to provide you with the essential knowledge to kickstart your productivity within the framework.

Building Blocks

In this section, we'll guide you through the core elements that make up Burn. We'll walk you through the key components that serve as the building blocks of the framework and your future projects.

As you explore Burn, you might notice that we occasionally draw comparisons to PyTorch. We believe it can provide a smoother learning curve and help you grasp the nuances more effectively.

Backend

Nearly everything in Burn is based on the Backend trait, which enables you to run tensor operations using different implementations without having to modify your code. While a backend may not necessarily have autodiff capabilities, the AutodiffBackend trait specifies when autodiff is needed. This trait not only abstracts operations but also tensor, device, and element types, providing each backend the flexibility it needs. It's worth noting that the trait assumes eager mode since Burn fully supports dynamic graphs. However, we may create another API to assist with integrating graph-based backends, without requiring any changes to the user's code.

Users are not expected to directly use the backend trait methods, as it is primarily designed with backend developers in mind rather than Burn users. Therefore, most Burn userland APIs are generic across backends. This approach helps users discover the API more organically with proper autocomplete and documentation.
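
In practice, this means user code is written once against the Backend trait and runs unchanged on any implementation; a minimal sketch:

use burn::prelude::*;

// Mean squared error between two rank-2 tensors, written once for every backend.
fn mse<B: Backend>(pred: Tensor<B, 2>, target: Tensor<B, 2>) -> Tensor<B, 1> {
    let diff = pred - target;
    (diff.clone() * diff).mean()
}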

Tensor

As previously explained in the model section, the Tensor struct has 3 generic arguments: the backend B, the dimensionality D, and the data type.

Tensor<B, D>           // Float tensor (default)
Tensor<B, D, Float>    // Explicit float tensor
Tensor<B, D, Int>      // Int tensor
Tensor<B, D, Bool>     // Bool tensor

Note that the specific element types used for Float, Int, and Bool tensors are defined by backend implementations.

Burn Tensors are defined by the number of dimensions D in their declaration as opposed to their shape. The actual shape of the tensor is inferred from its initialization. For example, a Tensor of size (5,) is initialized as below:

let floats = [1.0, 2.0, 3.0, 4.0, 5.0];

// Get the default device
let device = Default::default();

// correct: Tensor is 1-Dimensional with 5 elements
let tensor_1 = Tensor::<Backend, 1>::from_floats(floats, &device);

// incorrect: let tensor_1 = Tensor::<Backend, 5>::from_floats(floats, &device);
// this will lead to an error and is for creating a 5-D tensor

Initialization

Burn Tensors are primarily initialized using the from_data() method which takes the TensorData struct as input. The TensorData struct has two public fields: shape and dtype. The value, now stored as bytes, is private but can be accessed via any of the following methods: as_slice, as_mut_slice, to_vec and iter. To retrieve the data from a tensor, the method .to_data() should be employed when intending to reuse the tensor afterward. Alternatively, .into_data() is recommended for one-time use. Let's look at a couple of examples for initializing a tensor from different inputs.


// Initialization from a given Backend (Wgpu)
let tensor_1 = Tensor::<Wgpu, 1>::from_data([1.0, 2.0, 3.0], &device);

// Initialization from a generic Backend
let tensor_2 = Tensor::<Backend, 1>::from_data(TensorData::from([1.0, 2.0, 3.0]), &device);

// Initialization using from_floats (Recommended for f32 ElementType)
// Will be converted to TensorData internally.
let tensor_3 = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0], &device);

// Initialization of Int Tensor from array slices
let arr: [i32; 6] = [1, 2, 3, 4, 5, 6];
let tensor_4 = Tensor::<Backend, 1, Int>::from_data(TensorData::from(&arr[0..3]), &device);

// Initialization from a custom type

struct BodyMetrics {
    age: i8,
    height: i16,
    weight: f32
}

let bmi = BodyMetrics {
    age: 25,
    height: 180,
    weight: 80.0,
};
let data = TensorData::from([bmi.age as f32, bmi.height as f32, bmi.weight]);
let tensor_5 = Tensor::<Backend, 1>::from_data(data, &device);
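To read values back out of a tensor, here is a brief sketch using the TensorData accessors mentioned above (reusing the Backend placeholder and device from the previous examples):

let tensor = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0], &device);

// Keep the tensor alive and copy its values out.
let data = tensor.to_data();
let values: Vec<f32> = data.to_vec().expect("should contain valid f32 values");

// Consume the tensor once it is no longer needed.
let data = tensor.into_data();
assert_eq!(data.shape, vec![3]);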

Ownership and Cloning

Almost all Burn operations take ownership of the input tensors. Therefore, reusing a tensor multiple times will necessitate cloning it. Let's look at an example to understand the ownership rules and cloning better. Suppose we want to do a simple min-max normalization of an input tensor.

let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);
let min = input.min();
let max = input.max();
let input = (input - min).div(max - min);

With PyTorch tensors, the above code would work as expected. However, Rust's strict ownership rules will give a compilation error and prevent using the input tensor after the first .min() operation. The ownership of the input tensor is transferred to the variable min, and the input tensor is no longer available for further operations. Burn Tensors, like most complex primitives, do not implement the Copy trait and therefore have to be cloned explicitly. Now let's rewrite a working example of doing min-max normalization with cloning.

let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);
let min = input.clone().min();
let max = input.clone().max();
let input = (input.clone() - min.clone()).div(max - min);
println!("{}", input.to_data());// Success: [0.0, 0.33333334, 0.6666667, 1.0]

// Notice that max, min have been moved in last operation so
// the below print will give an error.
// If we want to use them for further operations,
// they will need to be cloned in similar fashion.
// println!("{:?}", min.to_data());

We don't need to be worried about memory overhead because with cloning, the tensor's buffer isn't copied; only its reference count is incremented. This makes it possible to determine exactly how many times a tensor is used, which is very convenient for reusing tensor buffers or even fusing operations into a single kernel (burn-fusion). For that reason, we don't provide explicit inplace operations. If a tensor is used only one time, inplace operations will always be used when available.

Tensor Operations

Normally with PyTorch, explicit inplace operations aren't supported during the backward pass, making them useful only for data preprocessing or inference-only model implementations. With Burn, you can focus more on what the model should do, rather than on how to do it. We take the responsibility of making your code run as fast as possible during training as well as inference. The same principles apply to broadcasting; all operations support broadcasting unless specified otherwise.

Here, we provide a list of all supported operations along with their PyTorch equivalents. Note that for the sake of simplicity, we ignore type signatures. For more details, refer to the full documentation.

Basic Operations

Those operations are available for all tensor kinds: Int, Float, and Bool.

Burn | PyTorch Equivalent
Tensor::cat(tensors, dim) | torch.cat(tensors, dim)
Tensor::empty(shape, device) | torch.empty(shape, device=device)
Tensor::from_primitive(primitive) | N/A
Tensor::stack(tensors, dim) | torch.stack(tensors, dim)
tensor.all() | tensor.all()
tensor.all_dim(dim) | tensor.all(dim)
tensor.any() | tensor.any()
tensor.any_dim(dim) | tensor.any(dim)
tensor.chunk(num_chunks, dim) | tensor.chunk(num_chunks, dim)
tensor.device() | tensor.device
tensor.dims() | tensor.size()
tensor.equal(other) | x == y
tensor.expand(shape) | tensor.expand(shape)
tensor.flatten(start_dim, end_dim) | tensor.flatten(start_dim, end_dim)
tensor.flip(axes) | tensor.flip(axes)
tensor.into_data() | N/A
tensor.into_primitive() | N/A
tensor.into_scalar() | tensor.item()
tensor.narrow(dim, start, length) | tensor.narrow(dim, start, length)
tensor.not_equal(other) | x != y
tensor.permute(axes) | tensor.permute(axes)
tensor.movedim(src, dst) | tensor.movedim(src, dst)
tensor.repeat_dim(dim, times) | tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])
tensor.repeat(sizes) | tensor.repeat(sizes)
tensor.reshape(shape) | tensor.view(shape)
tensor.shape() | tensor.shape
tensor.slice(ranges) | tensor[(*ranges,)]
tensor.slice_assign(ranges, values) | tensor[(*ranges,)] = values
tensor.squeeze(dim) | tensor.squeeze(dim)
tensor.to_data() | N/A
tensor.to_device(device) | tensor.to(device)
tensor.unsqueeze() | tensor.unsqueeze(0)
tensor.unsqueeze_dim(dim) | tensor.unsqueeze(dim)

Numeric Operations

Those operations are available for numeric tensor kinds: Float and Int.

Burn | PyTorch Equivalent
Tensor::eye(size, device) | torch.eye(size, device=device)
Tensor::full(shape, fill_value, device) | torch.full(shape, fill_value, device=device)
Tensor::ones(shape, device) | torch.ones(shape, device=device)
Tensor::zeros(shape) | torch.zeros(shape)
Tensor::zeros(shape, device) | torch.zeros(shape, device=device)
tensor.abs() | torch.abs(tensor)
tensor.add(other) or tensor + other | tensor + other
tensor.add_scalar(scalar) or tensor + scalar | tensor + scalar
tensor.all_close(other, atol, rtol) | torch.allclose(tensor, other, atol, rtol)
tensor.argmax(dim) | tensor.argmax(dim)
tensor.argmin(dim) | tensor.argmin(dim)
tensor.argsort(dim) | tensor.argsort(dim)
tensor.argsort_descending(dim) | tensor.argsort(dim, descending=True)
tensor.bool() | tensor.bool()
tensor.clamp(min, max) | torch.clamp(tensor, min=min, max=max)
tensor.clamp_max(max) | torch.clamp(tensor, max=max)
tensor.clamp_min(min) | torch.clamp(tensor, min=min)
tensor.contains_nan() | N/A
tensor.div(other) or tensor / other | tensor / other
tensor.div_scalar(scalar) or tensor / scalar | tensor / scalar
tensor.equal_elem(other) | tensor.eq(other)
tensor.full_like(fill_value) | torch.full_like(tensor, fill_value)
tensor.gather(dim, indices) | torch.gather(tensor, dim, indices)
tensor.greater(other) | tensor.gt(other)
tensor.greater_elem(scalar) | tensor.gt(scalar)
tensor.greater_equal(other) | tensor.ge(other)
tensor.greater_equal_elem(scalar) | tensor.ge(scalar)
tensor.is_close(other, atol, rtol) | torch.isclose(tensor, other, atol, rtol)
tensor.is_nan() | torch.isnan(tensor)
tensor.lower(other) | tensor.lt(other)
tensor.lower_elem(scalar) | tensor.lt(scalar)
tensor.lower_equal(other) | tensor.le(other)
tensor.lower_equal_elem(scalar) | tensor.le(scalar)
tensor.mask_fill(mask, value) | tensor.masked_fill(mask, value)
tensor.mask_where(mask, value_tensor) | torch.where(mask, value_tensor, tensor)
tensor.max() | tensor.max()
tensor.max_dim(dim) | tensor.max(dim, keepdim=True)
tensor.max_dim_with_indices(dim) | N/A
tensor.max_pair(other) | torch.Tensor.max(a, b)
tensor.mean() | tensor.mean()
tensor.mean_dim(dim) | tensor.mean(dim, keepdim=True)
tensor.min() | tensor.min()
tensor.min_dim(dim) | tensor.min(dim, keepdim=True)
tensor.min_dim_with_indices(dim) | N/A
tensor.min_pair(other) | torch.Tensor.min(a, b)
tensor.mul(other) or tensor * other | tensor * other
tensor.mul_scalar(scalar) or tensor * scalar | tensor * scalar
tensor.neg() or -tensor | -tensor
tensor.not_equal_elem(scalar) | tensor.ne(scalar)
tensor.ones_like() | torch.ones_like(tensor)
tensor.pad(pads, value) | torch.nn.functional.pad(input, pad, value)
tensor.powf(other) or tensor.powi(intother) | tensor.pow(other)
tensor.powf_scalar(scalar) or tensor.powi_scalar(intscalar) | tensor.pow(scalar)
tensor.prod() | tensor.prod()
tensor.prod_dim(dim) | tensor.prod(dim, keepdim=True)
tensor.rem(other) or tensor % other | tensor % other
tensor.scatter(dim, indices, values) | tensor.scatter_add(dim, indices, values)
tensor.select(dim, indices) | tensor.index_select(dim, indices)
tensor.select_assign(dim, indices, values) | N/A
tensor.sign() | tensor.sign()
tensor.sort(dim) | tensor.sort(dim).values
tensor.sort_descending(dim) | tensor.sort(dim, descending=True).values
tensor.sort_descending_with_indices(dim) | tensor.sort(dim, descending=True)
tensor.sort_with_indices(dim) | tensor.sort(dim)
tensor.sub(other) or tensor - other | tensor - other
tensor.sub_scalar(scalar) or tensor - scalar | tensor - scalar
tensor.sum() | tensor.sum()
tensor.sum_dim(dim) | tensor.sum(dim, keepdim=True)
tensor.topk(k, dim) | tensor.topk(k, dim).values
tensor.topk_with_indices(k, dim) | tensor.topk(k, dim)
tensor.tril(diagonal) | torch.tril(tensor, diagonal)
tensor.triu(diagonal) | torch.triu(tensor, diagonal)
tensor.zeros_like() | torch.zeros_like(tensor)

Float Operations

Those operations are only available for Float tensors.

Burn API | PyTorch Equivalent
Tensor::one_hot(index, num_classes, device) | N/A
tensor.ceil() | tensor.ceil()
tensor.cos() | tensor.cos()
tensor.erf() | tensor.erf()
tensor.exp() | tensor.exp()
tensor.floor() | tensor.floor()
tensor.from_floats(floats, device) | N/A
tensor.from_full_precision(tensor) | N/A
tensor.int() | Similar to tensor.to(torch.long)
tensor.log() | tensor.log()
tensor.log1p() | tensor.log1p()
tensor.matmul(other) | tensor.matmul(other)
tensor.random(shape, distribution, device) | N/A
tensor.random_like(distribution) | torch.rand_like() only uniform
tensor.recip() | tensor.reciprocal()
tensor.round() | tensor.round()
tensor.sin() | tensor.sin()
tensor.sqrt() | tensor.sqrt()
tensor.swap_dims(dim1, dim2) | tensor.transpose(dim1, dim2)
tensor.tanh() | tensor.tanh()
tensor.to_full_precision() | tensor.to(torch.float)
tensor.transpose() | tensor.T
tensor.var(dim) | tensor.var(dim)
tensor.var_bias(dim) | N/A
tensor.var_mean(dim) | N/A
tensor.var_mean_bias(dim) | N/A

Int Operations

Those operations are only available for Int tensors.

Burn API | PyTorch Equivalent
tensor.arange(5..10, device) | tensor.arange(start=5, end=10, device=device)
tensor.arange_step(5..10, 2, device) | tensor.arange(start=5, end=10, step=2, device=device)
tensor.float() | tensor.to(torch.float)
tensor.from_ints(ints) | N/A
tensor.int_random(shape, distribution, device) | N/A
tensor.cartesian_grid(shape, device) | N/A
tensor.one_hot(num_classes) | N/A

Bool Operations

Those operations are only available for Bool tensors.

Burn API | PyTorch Equivalent
Tensor::diag_mask(shape, diagonal) | N/A
Tensor::tril_mask(shape, diagonal) | N/A
Tensor::triu_mask(shape, diagonal) | N/A
tensor.argwhere() | tensor.argwhere()
tensor.float() | tensor.to(torch.float)
tensor.int() | tensor.to(torch.long)
tensor.nonzero() | tensor.nonzero(as_tuple=True)
tensor.not() | tensor.logical_not()

Quantization Operations

Those operations are only available for Float tensors on backends that implement quantization strategies.

Burn API | PyTorch Equivalent
tensor.quantize(scheme, qparams) | N/A
tensor.dequantize() | N/A

Activation Functions

Burn API | PyTorch Equivalent
activation::gelu(tensor) | nn.functional.gelu(tensor)
activation::hard_sigmoid(tensor, alpha, beta) | nn.functional.hardsigmoid(tensor)
activation::leaky_relu(tensor, negative_slope) | nn.functional.leaky_relu(tensor, negative_slope)
activation::log_sigmoid(tensor) | nn.functional.logsigmoid(tensor)
activation::log_softmax(tensor, dim) | nn.functional.log_softmax(tensor, dim)
activation::mish(tensor) | nn.functional.mish(tensor)
activation::prelu(tensor, alpha) | nn.functional.prelu(tensor, weight)
activation::quiet_softmax(tensor, dim) | nn.functional.quiet_softmax(tensor, dim)
activation::relu(tensor) | nn.functional.relu(tensor)
activation::sigmoid(tensor) | nn.functional.sigmoid(tensor)
activation::silu(tensor) | nn.functional.silu(tensor)
activation::softmax(tensor, dim) | nn.functional.softmax(tensor, dim)
activation::softmin(tensor, dim) | nn.functional.softmin(tensor, dim)
activation::softplus(tensor, beta) | nn.functional.softplus(tensor, beta)
activation::tanh(tensor) | nn.functional.tanh(tensor)
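Activation functions are free functions in the burn::tensor::activation module rather than modules, so they can be applied directly to tensors. Here is a minimal sketch, reusing the Backend placeholder and device from the earlier examples:

use burn::tensor::activation;

let logits = Tensor::<Backend, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
let hidden = activation::relu(logits.clone());
let probabilities = activation::softmax(logits, 1);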

Displaying Tensor Details

Burn provides flexible options for displaying tensor information, allowing you to control the level of detail and formatting to suit your needs.

Basic Display

To display a detailed view of a tensor, you can simply use Rust's println! or format! macros:

#![allow(unused)]
fn main() {
let tensor = Tensor::<Backend, 2>::full([2, 3], 0.123456789, &Default::default());
println!("{}", tensor);
}

This will output:

Tensor {
  data:
[[0.12345679, 0.12345679, 0.12345679],
 [0.12345679, 0.12345679, 0.12345679]],
  shape:  [2, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}

Controlling Precision

You can control the number of decimal places displayed using Rust's formatting syntax:

#![allow(unused)]
fn main() {
println!("{:.2}", tensor);
}

Output:

Tensor {
  data:
[[0.12, 0.12, 0.12],
 [0.12, 0.12, 0.12]],
  shape:  [2, 3],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}

Global Print Options

For more fine-grained control over tensor printing, Burn provides a PrintOptions struct and a set_print_options function:

#![allow(unused)]
fn main() {
use burn::tensor::{set_print_options, PrintOptions};

let print_options = PrintOptions {
    precision: Some(2),
    ..Default::default()
};

set_print_options(print_options);
}

Options:

  • precision: Number of decimal places for floating-point numbers (default: None)

  • threshold: Maximum number of elements to display before summarizing (default: 1000)

  • edge_items: Number of items to show at the beginning and end of each dimension when summarizing (default: 3)

Checking Tensor Closeness

Burn provides a utility function check_closeness to compare two tensors and assess their similarity. This function is particularly useful for debugging and validating tensor operations, especially when working with floating-point arithmetic where small numerical differences can accumulate. It's also valuable when comparing model outputs during the process of importing models from other frameworks, helping to ensure that the imported model produces results consistent with the original.

Here's an example of how to use check_closeness:

#![allow(unused)]
fn main() {
use burn::tensor::{check_closeness, Tensor};
type B = burn::backend::NdArray;

let device = Default::default();
let tensor1 = Tensor::<B, 1>::from_floats(
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.001, 7.002, 8.003, 9.004, 10.1],
    &device,
);
let tensor2 = Tensor::<B, 1>::from_floats(
    [1.0, 2.0, 3.0, 4.000, 5.0, 6.0, 7.001, 8.002, 9.003, 10.004],
    &device,
);

check_closeness(&tensor1, &tensor2);
}

The check_closeness function compares the two input tensors element-wise, checking their absolute differences against a range of epsilon values. It then prints a detailed report showing the percentage of elements that are within each tolerance level.

The output provides a breakdown for different epsilon values, allowing you to assess the closeness of the tensors at various precision levels. This is particularly helpful when dealing with operations that may introduce small numerical discrepancies.

The function uses color-coded output to highlight the results:

  • Green [PASS]: All elements are within the specified tolerance.
  • Yellow [WARN]: Most elements (90% or more) are within tolerance.
  • Red [FAIL]: Significant differences are detected.

This utility can be invaluable when implementing or debugging tensor operations, especially those involving complex mathematical computations or when porting algorithms from other frameworks. It's also an essential tool when verifying the accuracy of imported models, ensuring that the Burn implementation produces results that closely match those of the original model.

Autodiff

Burn's tensor also supports auto-differentiation, which is an essential part of any deep learning framework. We introduced the Backend trait in the previous section, but Burn also has another trait for autodiff: AutodiffBackend.

However, not all tensors support auto-differentiation; you need a backend that implements both the Backend and AutodiffBackend traits. Fortunately, you can add auto-differentiation capabilities to any backend using a backend decorator: type MyAutodiffBackend = Autodiff<MyBackend>. This decorator implements both the AutodiffBackend and Backend traits by maintaining a dynamic computational graph and utilizing the inner backend to execute tensor operations.

The AutodiffBackend trait adds new operations on float tensors that can't be called otherwise. It also provides a new associated type, B::Gradients, where each calculated gradient resides.

fn calculate_gradients<B: AutodiffBackend>(tensor: Tensor<B, 2>) -> B::Gradients {
    let mut gradients = tensor.clone().backward();

    let tensor_grad = tensor.grad(&gradients);        // get
    let tensor_grad = tensor.grad_remove(&mut gradients); // pop

    gradients
}

Note that some functions will always be available even if the backend doesn't implement the AutodiffBackend trait. In such cases, those functions will do nothing.

Burn API | PyTorch Equivalent
tensor.detach() | tensor.detach()
tensor.require_grad() | tensor.requires_grad()
tensor.is_require_grad() | tensor.requires_grad
tensor.set_require_grad(require_grad) | tensor.requires_grad(False)
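For instance, here is a brief sketch reusing the MyAutodiffBackend alias from above (a device and any concrete inner backend are assumed) that marks a leaf tensor as requiring gradients and reads its gradient back after the backward pass:

let x = Tensor::<MyAutodiffBackend, 2>::ones([2, 2], &device).require_grad();
let y = x.clone().mul_scalar(3.0).sum();

let grads = y.backward();
let x_grad = x.grad(&grads); // Option over a tensor on the inner (non-autodiff) backend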

However, you're unlikely to make any mistakes since you can't call backward on a tensor that is on a backend that doesn't implement AutodiffBackend. Additionally, you can't retrieve the gradient of a tensor without an autodiff backend.

Difference with PyTorch

The way Burn handles gradients is different from PyTorch. First, when calling backward, each parameter doesn't have its grad field updated. Instead, the backward pass returns all the calculated gradients in a container. This approach offers numerous benefits, such as the ability to easily send gradients to other threads.

You can also retrieve the gradient for a specific parameter using the grad method on a tensor. Since this method takes the gradients as input, it's hard to forget to call backward beforehand. Note that sometimes, using grad_remove can improve performance by allowing inplace operations.

In PyTorch, when you don't need gradients for inference or validation, you typically need to scope your code using a block.

# Inference mode
with torch.inference_mode():
    # your code
    ...

# Or no grad
with torch.no_grad():
    # your code
    ...

With Burn, you don't need to wrap the backend with Autodiff for inference, and you can call inner() to obtain the inner tensor, which is useful for validation.

/// Use `B: AutodiffBackend`
fn example_validation<B: AutodiffBackend>(tensor: Tensor<B, 2>) {
    let inner_tensor: Tensor<B::InnerBackend, 2> = tensor.inner();
    let _ = inner_tensor + 5;
}

/// Use `B: Backend`
fn example_inference<B: Backend>(tensor: Tensor<B, 2>) {
    let _ = tensor + 5;
    // ...
}

Gradients with Optimizers

We've seen how gradients can be used with tensors, but the process is a bit different when working with optimizers from burn-core. To work with the Module trait, a translation step is required to link tensor parameters with their gradients. This step is necessary to easily support gradient accumulation and training on multiple devices, where each module can be forked and run on different devices in parallel. We'll explore this topic in more depth in the Module section.
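In practice, this translation is a single call. Here is a minimal sketch (the loss, model, optim, and learning_rate variables are assumed); the same pattern appears in the custom training loop later in this book:

// Compute the gradients for the current backward pass.
let grads = loss.backward();
// Link each gradient to the corresponding parameter id of the model.
let grads = GradientsParams::from_grads(grads, &model);
// Let the optimizer update the model parameters.
let model = optim.step(learning_rate, model, grads);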

Module

The Module derive allows you to create your own neural network modules, similar to PyTorch. The derive function only generates the necessary methods to essentially act as a parameter container for your type; it makes no assumptions about how the forward pass is declared.

use burn::module::Module;
use burn::nn::{Dropout, Gelu, Linear};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: Gelu,
}

impl<B: Backend> PositionWiseFeedForward<B> {
    /// Normal method added to a struct.
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.linear_inner.forward(input);
        let x = self.gelu.forward(x);
        let x = self.dropout.forward(x);

        self.linear_outer.forward(x)
    }
}

Note that all fields declared in the struct must also implement the Module trait.

Tensor

If you want to create your own module that contains tensors, and not just other modules defined with the Module derive, you need to be careful to achieve the behavior you want.

  • Param<Tensor<B, D>>: If you want the tensor to be included as a parameter of your modules, you need to wrap the tensor in a Param struct. This will create an ID that will be used to identify this parameter. This is essential when performing module optimization and when saving states such as optimizer and module checkpoints. Note that a module's record only contains parameters.

  • Param<Tensor<B, D>>.set_require_grad(false): If you want the tensor to be included as a parameter of your modules, and therefore saved with the module's weights, but you don't want it to be updated by the optimizer.

  • Tensor<B, D>: If you want the tensor to act as a constant that can be recreated when instantiating a module. This can be useful when generating sinusoidal embeddings, for example (see the sketch after this list).
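As a brief illustration, a hypothetical PositionalEmbedding module (not taken from the library) might combine the first and last cases like this:

use burn::module::{Module, Param};
use burn::tensor::{backend::Backend, Tensor};

#[derive(Module, Debug)]
pub struct PositionalEmbedding<B: Backend> {
    /// Learnable parameter: tracked with an ID, updated by the optimizer, saved in the record.
    weight: Param<Tensor<B, 2>>,
    /// Constant tensor: recreated at initialization time and not part of the record.
    sinusoids: Tensor<B, 2>,
}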

Methods

These methods are available for all modules.

Burn API | PyTorch Equivalent
module.devices() | N/A
module.fork(device) | Similar to module.to(device).detach()
module.to_device(device) | module.to(device)
module.no_grad() | module.requires_grad_(False)
module.num_params() | N/A
module.visit(visitor) | N/A
module.map(mapper) | N/A
module.into_record() | Similar to state_dict
module.load_record(record) | Similar to load_state_dict(state_dict)
module.save_file(file_path, recorder) | N/A
module.load_file(file_path, recorder) | N/A

Similar to the backend trait, there is also the AutodiffModule trait to signify a module with autodiff support.

Burn API | PyTorch Equivalent
module.valid() | module.eval()

Visitor & Mapper

As mentioned earlier, modules primarily function as parameter containers. Therefore, we naturally offer several ways to perform functions on each parameter. This is distinct from PyTorch, where extending module functionalities is not as straightforward.

The visit and map methods are quite similar but serve different purposes. Mapping is used for potentially mutable operations where each parameter of a module can be updated to a new value. In Burn, optimizers are essentially just sophisticated module mappers. Visitors, on the other hand, are used when you don't intend to modify the module but need to retrieve specific information from it, such as the number of parameters or a list of devices in use.

You can implement your own mapper or visitor by implementing these simple traits:

/// Module visitor trait.
pub trait ModuleVisitor<B: Backend> {
    /// Visit a tensor in the module.
    fn visit<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D>);
}

/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
    /// Map a tensor in the module.
    fn map<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
}
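For example, here is a sketch of a visitor that counts the total number of parameter elements in a module (the ParamCounter name is ours, and it mirrors what module.num_params() already provides):

use burn::module::{Module, ModuleVisitor, ParamId};
use burn::tensor::{backend::Backend, Tensor};

#[derive(Default)]
struct ParamCounter {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for ParamCounter {
    fn visit<const D: usize>(&mut self, _id: ParamId, tensor: &Tensor<B, D>) {
        self.count += tensor.shape().num_elements();
    }
}

// Usage, assuming a `module` value created with the Module derive:
let mut counter = ParamCounter::default();
module.visit(&mut counter);
println!("{} parameters", counter.count);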

Module Display

Burn provides a simple way to display the structure of a module and its configuration at a glance. You can print the module to see its structure, which is useful for debugging and tracking changes across different versions of a module. (See the print output of the Basic Workflow Model example.)

To customize the display of a module, you can implement the ModuleDisplay trait for your module. This will change the default display settings for the module and its children. Note that ModuleDisplay is automatically implemented for all modules, but you can override it to customize the display by annotating the module with #[module(custom_display)].

#![allow(unused)]
fn main() {
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
use burn::nn::{Dropout, Gelu, Linear};
use burn::tensor::backend::Backend;

#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: Gelu,
}

impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
    /// Custom settings for the display of the module.
    /// If `None` is returned, the default settings will be used.
    fn custom_settings(&self) -> Option<burn::module::DisplaySettings> {
        DisplaySettings::new()
            // Will show all attributes (default is false)
            .with_show_all_attributes(false)
            // Will show each attribute on a new line (default is true)
            .with_new_line_after_attribute(true)
            // Will show the number of parameters (default is true)
            .with_show_num_parameters(true)
            // Will indent by 2 spaces (default is 2)
            .with_indentation_size(2)
            // Will show the parameter ID (default is false)
            .with_show_param_id(false)
            // Convenience method to wrap settings in Some()
            .optional()
    }

    /// Custom content to be displayed.
    /// If `None` is returned, the default content will be used
    /// (all attributes of the module)
    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("linear_inner", &self.linear_inner)
            .add("linear_outer", &self.linear_outer)
            .add("anything", "anything_else")
            .optional()
    }
}
}

Built-in Modules

Burn comes with built-in modules that you can use to build your own modules.

General

Burn API | PyTorch Equivalent
BatchNorm | nn.BatchNorm1d, nn.BatchNorm2d etc.
Dropout | nn.Dropout
Embedding | nn.Embedding
Gelu | nn.GELU
GroupNorm | nn.GroupNorm
HardSigmoid | nn.Hardsigmoid
InstanceNorm | nn.InstanceNorm1d, nn.InstanceNorm2d etc.
LayerNorm | nn.LayerNorm
LeakyRelu | nn.LeakyReLU
Linear | nn.Linear
Prelu | nn.PReLU
Relu | nn.ReLU
RmsNorm | No direct equivalent
SwiGlu | No direct equivalent
Interpolate1d | No direct equivalent
Interpolate2d | No direct equivalent

Convolutions

Burn API | PyTorch Equivalent
Conv1d | nn.Conv1d
Conv2d | nn.Conv2d
Conv3d | nn.Conv3d
ConvTranspose1d | nn.ConvTranspose1d
ConvTranspose2d | nn.ConvTranspose2d
ConvTranspose3d | nn.ConvTranspose3d
DeformConv2d | torchvision.ops.DeformConv2d

Pooling

Burn API | PyTorch Equivalent
AdaptiveAvgPool1d | nn.AdaptiveAvgPool1d
AdaptiveAvgPool2d | nn.AdaptiveAvgPool2d
AvgPool1d | nn.AvgPool1d
AvgPool2d | nn.AvgPool2d
MaxPool1d | nn.MaxPool1d
MaxPool2d | nn.MaxPool2d

RNNs

Burn API | PyTorch Equivalent
Gru | nn.GRU
Lstm/BiLstm | nn.LSTM
GateController | No direct equivalent

Transformer

Burn API | PyTorch Equivalent
MultiHeadAttention | nn.MultiheadAttention
TransformerDecoder | nn.TransformerDecoder
TransformerEncoder | nn.TransformerEncoder
PositionalEncoding | No direct equivalent
RotaryEncoding | No direct equivalent

Loss

Burn API | PyTorch Equivalent
CrossEntropyLoss | nn.CrossEntropyLoss
MseLoss | nn.MSELoss
HuberLoss | nn.HuberLoss

Learner

The burn-train crate encapsulates multiple utilities for training deep learning models. The goal of the crate is to provide users with a well-crafted and flexible training loop, so that projects do not have to write such components from the ground up. Most of the interactions with burn-train will be with the LearnerBuilder struct, briefly presented in the previous training section. This struct enables you to configure the training loop, offering support for registering metrics, enabling logging, checkpointing states, using multiple devices, and so on.

There are still some assumptions in the current provided APIs, which may make them inappropriate for your learning requirements. Indeed, they assume your model will learn from a training dataset and be validated against another dataset. This is the most common paradigm, allowing users to do both supervised and unsupervised learning as well as fine-tuning. However, for more complex requirements, creating a custom training loop might be what you need.

Usage

The learner builder provides numerous options when it comes to configurations.

Configuration | Description
Training Metric | Register a training metric
Validation Metric | Register a validation metric
Training Metric Plot | Register a training metric with plotting (requires the metric to be numeric)
Validation Metric Plot | Register a validation metric with plotting (requires the metric to be numeric)
Metric Logger | Configure the metric loggers (default is saving them to files)
Renderer | Configure how to render metrics (default is CLI)
Grad Accumulation | Configure the number of steps before applying gradients
File Checkpointer | Configure how the model, optimizer and scheduler states are saved
Num Epochs | Set the number of epochs
Devices | Set the devices to be used
Checkpoint | Restart training from a checkpoint
Application logging | Configure the application logging installer (default is writing to experiment.log)

When the builder is configured to your liking, you can then move forward to build the learner. The build method requires three inputs: the model, the optimizer, and the learning rate scheduler. Note that the latter can be a simple float if you want it to be constant during training.

The result will be a newly created Learner struct, which has only one method, fit, which must be called with the training and validation dataloaders. This will start the training and return the trained model once finished.
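As a rough sketch only (the artifact_dir, device, model, optim, and dataloader variables are assumed to exist, and the method names mirror the basic workflow guide), a typical configuration might look like this:

let learner = LearnerBuilder::new(artifact_dir)
    .metric_train_numeric(LossMetric::new())
    .metric_valid_numeric(LossMetric::new())
    .devices(vec![device.clone()])
    .num_epochs(10)
    .build(model, optim, 1e-4); // a constant learning rate is a valid scheduler

let model_trained = learner.fit(dataloader_train, dataloader_valid);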

Again, please refer to the training section for a relevant code snippet.

Artifacts

When creating a new builder, all the collected data will be saved under the directory provided as the argument to the new method. Here is an example of the data layout for a model recorded using the compressed message pack format, with the accuracy and loss metrics registered:

β”œβ”€β”€ experiment.log
β”œβ”€β”€ checkpoint
β”‚Β Β  β”œβ”€β”€ model-1.mpk.gz
β”‚Β Β  β”œβ”€β”€ optim-1.mpk.gz
β”‚   β”œβ”€β”€ scheduler-1.mpk.gz
β”‚Β Β  β”œβ”€β”€ model-2.mpk.gz
β”‚Β Β  β”œβ”€β”€ optim-2.mpk.gz
β”‚Β Β  └── scheduler-2.mpk.gz
β”œβ”€β”€ train
β”‚Β Β  β”œβ”€β”€ epoch-1
β”‚Β Β  β”‚Β Β  β”œβ”€β”€ Accuracy.log
β”‚Β Β  β”‚Β Β  └── Loss.log
β”‚Β Β  └── epoch-2
β”‚Β Β      β”œβ”€β”€ Accuracy.log
β”‚Β Β      └── Loss.log
└── valid
    β”œβ”€β”€ epoch-1
    β”‚Β Β  β”œβ”€β”€ Accuracy.log
    β”‚Β Β  └── Loss.log
    └── epoch-2
        β”œβ”€β”€ Accuracy.log
        └── Loss.log

You can choose to save or synchronize that local directory with a remote file system, if desired. The file checkpointer is capable of automatically deleting old checkpoints according to a specified configuration.

Metric

When working with the learner, you have the option to record metrics that will be monitored throughout the training process. We currently offer a restricted range of metrics.

Metric | Description
Accuracy | Calculate the accuracy in percentage
TopKAccuracy | Calculate the top-k accuracy in percentage
Loss | Output the loss used for the backward pass
CPU Temperature | Fetch the temperature of CPUs
CPU Usage | Fetch the CPU utilization
CPU Memory Usage | Fetch the CPU RAM usage
GPU Temperature | Fetch the GPU temperature
Learning Rate | Fetch the current learning rate for each optimizer step
CUDA | Fetch general CUDA metrics such as utilization

In order to use a metric, the output of your training step has to implement the Adaptor trait from burn-train::metric. Here is an example for the classification output, already provided with the crate.

/// Simple classification output adapted for multiple metrics.
#[derive(new)]
pub struct ClassificationOutput<B: Backend> {
    /// The loss.
    pub loss: Tensor<B, 1>,

    /// The output.
    pub output: Tensor<B, 2>,

    /// The targets.
    pub targets: Tensor<B, 1, Int>,
}

impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
    fn adapt(&self) -> AccuracyInput<B> {
        AccuracyInput::new(self.output.clone(), self.targets.clone())
    }
}

impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
    fn adapt(&self) -> LossInput<B> {
        LossInput::new(self.loss.clone())
    }
}

Custom Metric

Generating your own custom metrics is done by implementing the Metric trait.

/// Metric trait.
///
/// Implementations should define their own input type only used by the metric.
/// This is important since some conflict may happen when the model output is adapted for each
/// metric's input type.
///
/// The only exception is for metrics that don't need any input, setting the associated type
/// to the null type `()`.
pub trait Metric: Send + Sync {
    /// The input type of the metric.
    type Input;

    /// Updates the metric state and returns the current metric entry.
    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry;
    /// Clear the metric state.
    fn clear(&mut self);
}

As an example, let's see how the loss metric is implemented.

/// The loss metric.
#[derive(Default)]
pub struct LossMetric<B: Backend> {
    state: NumericMetricState,
    _b: B,
}

/// The loss metric input type.
#[derive(new)]
pub struct LossInput<B: Backend> {
    tensor: Tensor<B, 1>,
}

impl<B: Backend> Metric for LossMetric<B> {
    type Input = LossInput<B>;

    fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
        let loss = loss.tensor.clone().mean().into_scalar().elem::<f64>();

        self.state
            .update(loss, 1, FormatOptions::new("Loss").precision(2))
    }

    fn clear(&mut self) {
        self.state.reset()
    }
}

When the metric you are implementing is numeric in nature, you may want to also implement the Numeric trait. This will allow your metric to be plotted.

impl<B: Backend> Numeric for LossMetric<B> {
    fn value(&self) -> f64 {
        self.state.value()
    }
}

Config

When writing scientific code, you normally have a lot of values that are set, and deep learning is no exception. Python allows defining default parameter values for functions, which helps improve the developer experience. However, this has the downside of potentially breaking your code when upgrading to a new version, as the default values might change without your knowledge, making debugging very challenging.

With that in mind, we came up with the Config system. It's a simple Rust derive that you can apply to your types, allowing you to define default values with ease. Additionally, all configs can be serialized, reducing potential bugs when upgrading versions and improving reproducibility.

use burn::config::Config;

#[derive(Config)]
pub struct MyModuleConfig {
    d_model: usize,
    d_ff: usize,
    #[config(default = 0.1)]
    dropout: f64,
}

The derive also adds useful with_ methods for every attribute of your config, similar to a builder pattern, along with a save method.

fn main() {
    let config = MyModuleConfig::new(512, 2048);
    println!("{}", config.d_model); // 512
    println!("{}", config.d_ff); // 2048
    println!("{}", config.dropout); // 0.1
    let config =  config.with_dropout(0.2);
    println!("{}", config.dropout); // 0.2

    config.save("config.json").unwrap();
}

Good practices

Using the config type makes it easy to create new module instances. The initialization method should be implemented on the config type, taking the device as an argument.

impl MyModuleConfig {
    /// Create a module on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> MyModule<B> {
        MyModule {
            linear: LinearConfig::new(self.d_model, self.d_ff).init(device),
            dropout: DropoutConfig::new(self.dropout).init(),
        }
    }
}

Then we could add these lines to the above main:

use burn::backend::Wgpu;
let device = Default::default();
let my_module = config.init::<Wgpu>(&device);

Record

Records are how states are saved with Burn. Compared to most other frameworks, Burn has its own advanced saving mechanism that allows interoperability between backends with minimal possible runtime errors. There are multiple reasons why Burn decided to create its own saving formats.

First, Rust has serde, which is an extremely well-developed serialization and deserialization library that also powers the safetensors format developed by Hugging Face. If used properly, all the validations are done when deserializing, which removes the need to write validation code. Since modules in Burn are created with configurations, they can't implement serialization and deserialization. That's why the record system was created: it lets you save the state of modules independently of the backend in use, extremely fast, while still giving you all the flexibility you need to include any non-serializable field within your module.

Why not use safetensors?

safetensors uses serde with the JSON file format and only supports serializing and deserializing tensors. The record system in Burn gives you the possibility to serialize any type, which is very useful for optimizers that save their state, but also for any non-standard, cutting-edge modeling needs you may have. Additionally, the record system performs automatic precision conversion by using Rust types, making it more reliable with fewer manual manipulations.

It is important to note that the safetensors format uses the word safe to distinguish itself from Pickle, which is vulnerable to Python code injection. On our end, the simple fact that we use Rust already ensures that no code injection is possible. If your storage mechanism doesn't handle data corruption, you might prefer a recorder that performs checksum validation (i.e., any recorder with Gzip compression).

Recorder

Recorders are independent of the backend and serialize records with precision and a format. Note that the format can also be in-memory, allowing you to save the records directly into bytes.

Recorder | Format | Compression
DefaultFileRecorder | File - Named MessagePack | None
NamedMpkFileRecorder | File - Named MessagePack | None
NamedMpkGzFileRecorder | File - Named MessagePack | Gzip
BinFileRecorder | File - Binary | None
BinGzFileRecorder | File - Binary | Gzip
JsonGzFileRecorder | File - Json | Gzip
PrettyJsonFileRecorder | File - Pretty Json | Gzip
BinBytesRecorder | In Memory - Binary | None

Each recorder supports precision settings decoupled from the precision used for training or inference. These settings allow you to define the floating-point and integer types that will be used for serialization and deserialization.

Setting | Float Precision | Integer Precision
DoublePrecisionSettings | f64 | i64
FullPrecisionSettings | f32 | i32
HalfPrecisionSettings | f16 | i16
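For instance, here is a minimal sketch of saving the same module with two different precision settings (assuming a model value created with the Module derive and using the save_file method from the Methods table above):

use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, NamedMpkFileRecorder};

// Full precision, e.g. for checkpoints you plan to resume training from.
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model.clone().save_file("/tmp/my_model", &recorder).expect("should save the model");

// Half precision, for smaller files when only inference is needed.
let recorder = NamedMpkFileRecorder::<HalfPrecisionSettings>::new();
model.save_file("/tmp/my_model_half", &recorder).expect("should save the model");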

Note that when loading a record into a module, the type conversion is automatically handled, so you can't encounter errors. The only crucial aspect is using the same recorder for both serialization and deserialization; otherwise, you will encounter loading errors.

Which recorder should you use?

  • If you want fast serialization and deserialization, choose a recorder without compression. The one with the lowest file size without compression is the binary format; otherwise, the named MessagePack could be used.
  • If you want to save models for storage, you can use compression, but avoid using the binary format, as it may not be backward compatible.
  • If you want to debug your model's weights, you can use the pretty JSON format.
  • If you want to deploy with no-std, use the in-memory binary format and include the bytes with the compiled code.

For examples on saving and loading records, take a look at Saving and Loading Models.

Dataset

At its core, a dataset is a collection of data typically related to a specific analysis or processing task. The data modality can vary depending on the task, but most datasets primarily consist of images, texts, audio or videos.

This data source represents an integral part of machine learning to successfully train a model. Thus, it is essential to provide a convenient and performant API to handle your data. Since this process varies wildly from one problem to another, it is defined as a trait that should be implemented on your type. The dataset trait is quite similar to the dataset abstract class in PyTorch:

pub trait Dataset<I>: Send + Sync {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
}

The dataset trait assumes a fixed-length set of items that can be randomly accessed in constant time. This is a major difference from datasets that use Apache Arrow underneath to improve streaming performance. Datasets in Burn don't assume how they are going to be accessed; it's just a collection of items.

However, you can compose multiple dataset transformations to lazily obtain what you want with zero pre-processing, so that your training can start instantly!

Transformation

Transformations in Burn are all lazy and modify one or multiple input datasets. The goal of these transformations is to provide you with the necessary tools so that you can model complex data distributions.

Transformation | Description
SamplerDataset | Samples items from a dataset. This is a convenient way to model a dataset as a probability distribution of a fixed size.
ShuffledDataset | Maps each input index to a random index, similar to a dataset sampled without replacement.
PartialDataset | Returns a view of the input dataset with a specified range.
MapperDataset | Computes a transformation lazily on the input dataset.
ComposedDataset | Composes multiple datasets together to create a larger one without copying any data.
WindowsDataset | Dataset designed to work with overlapping windows of data extracted from an input dataset.

Let us look at the basic usages of each dataset transform and how they can be composed together. These transforms are lazy by default except when specified, reducing the need for unnecessary intermediate allocations and improving performance. The full documentation of each transform can be found at the API reference.

  • SamplerDataset: This transform can be used to sample items from a dataset with (default) or without replacement. The transform is initialized with a sampling size which can be bigger or smaller than the input dataset size. This is particularly useful in cases where we want to checkpoint larger datasets more often during training and smaller datasets less often, as the size of an epoch is now controlled by the sampling size. Sample usage:
type DbPedia = SqliteDataset<DbPediaItem>;
let dataset: DbPedia = HuggingfaceDatasetLoader::new("dbpedia_14")
        .dataset("train")
        .unwrap();

let dataset = SamplerDataset::<DbPedia, DbPediaItem>::new(dataset, 10000);
  • ShuffledDataset: This transform can be used to shuffle the items of a dataset. Particularly useful before splitting the raw dataset into train/test splits. Can be initialized with a seed to ensure reproducibility.
let dataset = ShuffledDataset::<DbPedia, DbPediaItem>::with_seed(dataset, 42);
  • PartialDataset: This transform is useful to return a view of the dataset with specified start and end indices. Used to create train/val/test splits. In the example below, we show how to chain ShuffledDataset and PartialDataset to create splits.
// define chained dataset type here for brevity
type PartialData = PartialDataset<ShuffledDataset<DbPedia, DbPediaItem>, DbPediaItem>;
let len = dataset.len();
let split = "train"; // or "val"/"test"

let data_split = match split {
    "train" => PartialData::new(dataset, 0, len * 8 / 10), // Get first 80% of the dataset
    "test" => PartialData::new(dataset, len * 8 / 10, len), // Take remaining 20%
    _ => panic!("Invalid split type"),                      // Handle unexpected split types
};
  • MapperDataset: This transform is useful to apply a transformation on each of the items of a dataset. Particularly useful for normalization of image data when channel means are known.

  • ComposedDataset: This transform is useful to compose multiple datasets downloaded from multiple sources (say different HuggingfaceDatasetLoader sources) into a single bigger dataset which can be sampled from one source.

  • WindowsDataset: This transform is useful to create overlapping windows of a dataset. Particularly useful for sequential time series data, for example when working with an LSTM.

Storage

There are multiple dataset storage options available for you to choose from. The choice of the dataset to use should be based on the dataset's size as well as its intended purpose.

Storage | Description
InMemDataset | In-memory dataset that uses a vector to store items. Well-suited for smaller datasets.
SqliteDataset | Dataset that uses SQLite to index items that can be saved in a simple SQL database file. Well-suited for larger datasets.
DataframeDataset | Dataset that uses a Polars dataframe to store and manage data. Well-suited for efficient data manipulation and analysis.

Sources

For now, there are only a couple of dataset sources available with Burn, but more to come!

Hugging Face

You can easily import any Hugging Face dataset with Burn. We use SQLite as the storage to avoid downloading the dataset each time or starting a Python process. You need to know the format of each item in the dataset beforehand. Here's an example with the dbpedia dataset.

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DbPediaItem {
    pub title: String,
    pub content: String,
    pub label: usize,
}

fn main() {
    let dataset: SqliteDataset<DbPediaItem> = HuggingfaceDatasetLoader::new("dbpedia_14")
        .dataset("train") // The training split.
        .unwrap();
}

We see that items must derive serde::Serialize, serde::Deserialize, Clone, and Debug, but those are the only requirements.

Images

ImageFolderDataset is a generic vision dataset used to load images from disk. It is currently available for multi-class and multi-label classification tasks.

// Create an image classification dataset from the root folder,
// where images for each class are stored in their respective folder.
//
// For example:
// root/dog/dog1.png
// root/dog/dog2.png
// ...
// root/cat/cat1.png
let dataset = ImageFolderDataset::new_classification("path/to/dataset/root").unwrap();
// Create a multi-label image classification dataset from a list of items,
// where each item is a tuple `(image path, labels)`, and a list of classes
// in the dataset.
//
// For example:
let items = vec![
    ("root/dog/dog1.png", vec!["animal".to_string(), "dog".to_string()]),
    ("root/cat/cat1.png", vec!["animal".to_string(), "cat".to_string()]),
];
let dataset = ImageFolderDataset::new_multilabel_classification_with_items(
    items,
    &["animal", "cat", "dog"],
)
.unwrap();

Comma-Separated Values (CSV)

Loading records from a simple CSV file in-memory is simple with the InMemDataset:

// Build dataset from csv with tab ('\t') delimiter.
// The reader can be configured for your particular file.
let mut rdr = csv::ReaderBuilder::new();
let rdr = rdr.delimiter(b'\t');

let dataset = InMemDataset::from_csv("path/to/csv", rdr).unwrap();

Note that this requires the csv crate.

What about streaming datasets?

There is no streaming dataset API with Burn, and this is by design! The learner struct will iterate multiple times over the dataset and only checkpoint when done. You can consider the length of the dataset as the number of iterations before performing checkpointing and running the validation. There is nothing stopping you from returning different items even when called with the same index multiple times.
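For instance, here is a sketch of a hypothetical dataset that ignores the index entirely and synthesizes a fresh item on every call; the rand crate is an assumed dependency of this sketch:

use burn::data::dataset::Dataset;
use rand::Rng;

/// Hypothetical "infinite" data source: `len` only controls the epoch length.
struct RandomNoiseDataset {
    epoch_size: usize,
}

impl Dataset<Vec<f32>> for RandomNoiseDataset {
    fn get(&self, _index: usize) -> Option<Vec<f32>> {
        // A fresh random item is generated regardless of the requested index.
        let mut rng = rand::thread_rng();
        Some((0..16).map(|_| rng.gen::<f32>()).collect())
    }

    fn len(&self) -> usize {
        // Number of iterations before checkpointing and validation.
        self.epoch_size
    }
}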

How Is The Dataset Used?

During training, the dataset is used to access the data samples and, for most use cases in supervised learning, their corresponding ground-truth labels. Remember that the Dataset trait implementation is responsible for retrieving the data from its source, usually some sort of data storage. At this point, the dataset could be naively iterated over to provide the model a single sample to process at a time, but this is not very efficient.

Instead, we collect multiple samples that the model can process as a batch to fully leverage modern hardware (e.g., GPUs, which have impressive parallel processing capabilities). Since each data sample in the dataset can be collected independently, the data loading is typically done in parallel to further speed things up. In this case, we parallelize the data loading using a multi-threaded BatchDataLoader to obtain a sequence of items from the Dataset implementation. Finally, the sequence of items is combined into a batched tensor that can be used as input to a model with the Batcher trait implementation. Other tensor operations can be performed during this step to prepare the batch data, as is done in the basic workflow guide. The process is illustrated in the figure below for the MNIST dataset.

Burn Data Loading Pipeline

Although we have conveniently implemented the MnistDataset used in the guide, we'll go over its implementation to demonstrate how the Dataset and Batcher traits are used.

The MNIST dataset of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples. A single item in the dataset is represented by a \(28 \times 28\) pixel black-and-white image (stored as raw bytes) with its corresponding label (a digit between \(0\) and \(9\)). This is defined by the MnistItemRaw struct.

#[derive(Deserialize, Debug, Clone)]
struct MnistItemRaw {
    pub image_bytes: Vec<u8>,
    pub label: u8,
}

With single-channel images of such low resolution, the entire training and test sets can be loaded in memory at once. Therefore, we leverage the already existing InMemDataset to retrieve the raw images and labels data. At this point, the image data is still just a bunch of bytes, but we want to retrieve the structured image data in its intended form. For that, we can define a MapperDataset that transforms the raw image bytes to a 2D array image (which we convert to float while we're at it).

const WIDTH: usize = 28;
const HEIGHT: usize = 28;

/// MNIST item.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct MnistItem {
    /// Image as a 2D array of floats.
    pub image: [[f32; WIDTH]; HEIGHT],

    /// Label of the image.
    pub label: u8,
}

struct BytesToImage;

impl Mapper<MnistItemRaw, MnistItem> for BytesToImage {
    /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
    fn map(&self, item: &MnistItemRaw) -> MnistItem {
        // Ensure the image dimensions are correct.
        debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);

        // Convert the image to a 2D array of floats.
        let mut image_array = [[0f32; WIDTH]; HEIGHT];
        for (i, pixel) in item.image_bytes.iter().enumerate() {
            let x = i % WIDTH;
            let y = i / HEIGHT;
            image_array[y][x] = *pixel as f32;
        }

        MnistItem {
            image: image_array,
            label: item.label,
        }
    }
}

type MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;

/// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digit), with 7,000
/// images per class. There are 60,000 training images and 10,000 test images.
///
/// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).
pub struct MnistDataset {
    dataset: MappedDataset,
}

To construct the MnistDataset, the data source must be parsed into the expected MappedDataset type. Since both the train and test sets use the same file format, we can separate the functionality to load the train() and test() dataset.


impl MnistDataset {
    /// Creates a new train dataset.
    pub fn train() -> Self {
        Self::new("train")
    }

    /// Creates a new test dataset.
    pub fn test() -> Self {
        Self::new("test")
    }

    fn new(split: &str) -> Self {
        // Download dataset
        let root = MnistDataset::download(split);

        // Parse data as vector of images bytes and vector of labels
        let images: Vec<Vec<u8>> = MnistDataset::read_images(&root, split);
        let labels: Vec<u8> = MnistDataset::read_labels(&root, split);

        // Collect as vector of MnistItemRaw
        let items: Vec<_> = images
            .into_iter()
            .zip(labels)
            .map(|(image_bytes, label)| MnistItemRaw { image_bytes, label })
            .collect();

        // Create the MapperDataset for InMemDataset<MnistItemRaw> to transform
        // items (MnistItemRaw -> MnistItem)
        let dataset = InMemDataset::new(items);
        let dataset = MapperDataset::new(dataset, BytesToImage);

        Self { dataset }
    }

   /// Download the MNIST dataset files from the web.
   /// Panics if the download cannot be completed or the content of the file cannot be written to disk.
   fn download(split: &str) -> PathBuf {
       // Dataset files are stored in the burn-dataset cache directory
       let cache_dir = dirs::home_dir()
           .expect("Could not get home directory")
           .join(".cache")
           .join("burn-dataset");
       let split_dir = cache_dir.join("mnist").join(split);

       if !split_dir.exists() {
           create_dir_all(&split_dir).expect("Failed to create base directory");
       }

       // Download split files
       match split {
           "train" => {
               MnistDataset::download_file(TRAIN_IMAGES, &split_dir);
               MnistDataset::download_file(TRAIN_LABELS, &split_dir);
           }
           "test" => {
               MnistDataset::download_file(TEST_IMAGES, &split_dir);
               MnistDataset::download_file(TEST_LABELS, &split_dir);
           }
           _ => panic!("Invalid split specified {}", split),
       };

       split_dir
   }

   /// Download a file from the MNIST dataset URL to the destination directory.
   /// File download progress is reported with the help of a [progress bar](indicatif).
   fn download_file<P: AsRef<Path>>(name: &str, dest_dir: &P) -> PathBuf {
       // Output file name
       let file_name = dest_dir.as_ref().join(name);

       if !file_name.exists() {
           // Download gzip file
           let bytes = download_file_as_bytes(&format!("{URL}{name}.gz"), name);

           // Create file to write the downloaded content to
           let mut output_file = File::create(&file_name).unwrap();

           // Decode gzip file content and write to disk
           let mut gz_buffer = GzDecoder::new(&bytes[..]);
           std::io::copy(&mut gz_buffer, &mut output_file).unwrap();
       }

       file_name
   }

   /// Read images at the provided path for the specified split.
   /// Each image is a vector of bytes.
   fn read_images<P: AsRef<Path>>(root: &P, split: &str) -> Vec<Vec<u8>> {
       let file_name = if split == "train" {
           TRAIN_IMAGES
       } else {
           TEST_IMAGES
       };
       let file_name = root.as_ref().join(file_name);

       // Read number of images from 16-byte header metadata
       let mut f = File::open(file_name).unwrap();
       let mut buf = [0u8; 4];
       let _ = f.seek(SeekFrom::Start(4)).unwrap();
       f.read_exact(&mut buf)
           .expect("Should be able to read image file header");
       let size = u32::from_be_bytes(buf);

       let mut buf_images: Vec<u8> = vec![0u8; WIDTH * HEIGHT * (size as usize)];
       let _ = f.seek(SeekFrom::Start(16)).unwrap();
       f.read_exact(&mut buf_images)
           .expect("Should be able to read images from file");

       buf_images
           .chunks(WIDTH * HEIGHT)
           .map(|chunk| chunk.to_vec())
           .collect()
   }

   /// Read labels at the provided path for the specified split.
   fn read_labels<P: AsRef<Path>>(root: &P, split: &str) -> Vec<u8> {
       let file_name = if split == "train" {
           TRAIN_LABELS
       } else {
           TEST_LABELS
       };
       let file_name = root.as_ref().join(file_name);

       // Read number of labels from 8-byte header metadata
       let mut f = File::open(file_name).unwrap();
       let mut buf = [0u8; 4];
       let _ = f.seek(SeekFrom::Start(4)).unwrap();
       f.read_exact(&mut buf)
           .expect("Should be able to read label file header");
       let size = u32::from_be_bytes(buf);

       let mut buf_labels: Vec<u8> = vec![0u8; size as usize];
       let _ = f.seek(SeekFrom::Start(8)).unwrap();
       f.read_exact(&mut buf_labels)
           .expect("Should be able to read labels from file");

       buf_labels
   }
}

Since the MnistDataset simply wraps a MapperDataset instance with InMemDataset, we can easily implement the Dataset trait.

impl Dataset<MnistItem> for MnistDataset {
    fn get(&self, index: usize) -> Option<MnistItem> {
        self.dataset.get(index)
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}

The only thing missing now is the Batcher, which we already went over in the basic workflow guide. The Batcher takes a list of MnistItem retrieved by the dataloader as input and returns a batch of images as a 3D tensor along with their targets.

Custom Training Loops

Even though Burn comes with a project dedicated to simplifying training, it doesn't mean that you have to use it. Sometimes you may have special needs for your training, and it might be faster to just reimplement the training loop yourself. Also, you may just prefer implementing your own training loop instead of using a pre-built one in general.

Burn's got you covered!

We will start from the same example shown in the basic workflow section, but without using the Learner struct.

#[derive(Config)]
pub struct MnistTrainingConfig {
    #[config(default = 10)]
    pub num_epochs: usize,
    #[config(default = 64)]
    pub batch_size: usize,
    #[config(default = 4)]
    pub num_workers: usize,
    #[config(default = 42)]
    pub seed: u64,
    #[config(default = 1e-4)]
    pub lr: f64,
    pub model: ModelConfig,
    pub optimizer: AdamConfig,
}

pub fn run<B: AutodiffBackend>(device: &B::Device) {
    // Create the configuration.
    let config_model = ModelConfig::new(10, 1024);
    let config_optimizer = AdamConfig::new();
    let config = MnistTrainingConfig::new(config_model, config_optimizer);

    B::seed(config.seed);

    // Create the model and optimizer.
    let mut model = config.model.init(device);
    let mut optim = config.optimizer.init();

    // Create the batcher.
    let batcher_train = MnistBatcher::<B>::new(device.clone());
    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

    // Create the dataloaders.
    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MnistDataset::train());

    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MnistDataset::test());

    ...
}

As seen with the previous example, setting up the configurations and the dataloader hasn't changed. Now, let's move forward and write our own training loop:

pub fn run<B: AutodiffBackend>(device: &B::Device) {
    ...

    // Iterate over our training and validation loop for X epochs.
    for epoch in 1..config.num_epochs + 1 {
        // Implement our training loop.
        for (iteration, batch) in dataloader_train.iter().enumerate() {
            let output = model.forward(batch.images);
            let loss = CrossEntropyLoss::new(None, &output.device())
                .forward(output.clone(), batch.targets.clone());
            let accuracy = accuracy(output, batch.targets);

            println!(
                "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %",
                epoch,
                iteration,
                loss.clone().into_scalar(),
                accuracy,
            );

            // Gradients for the current backward pass
            let grads = loss.backward();
            // Gradients linked to each parameter of the model.
            let grads = GradientsParams::from_grads(grads, &model);
            // Update the model using the optimizer.
            model = optim.step(config.lr, model, grads);
        }

        // Get the model without autodiff.
        let model_valid = model.valid();

        // Implement our validation loop.
        for (iteration, batch) in dataloader_test.iter().enumerate() {
            let output = model_valid.forward(batch.images);
            let loss = CrossEntropyLoss::new(None, &output.device())
                .forward(output.clone(), batch.targets.clone());
            let accuracy = accuracy(output, batch.targets);

            println!(
                "[Valid - Epoch {} - Iteration {}] Loss {} | Accuracy {}",
                epoch,
                iteration,
                loss.clone().into_scalar(),
                accuracy,
            );
        }
    }
}

In the previous code snippet, we can observe that the loop starts from epoch 1 and goes up to num_epochs. Within each epoch, we iterate over the training dataloader. During this process, we execute the forward pass, which is necessary for computing both the loss and accuracy. To maintain simplicity, we print the results to stdout.

Upon obtaining the loss, we can invoke the backward() function, which returns the gradients specific to each variable. It's important to note that we need to map these gradients to their corresponding parameters using the GradientsParams type. This step is essential because you might run multiple different autodiff graphs and accumulate gradients for each parameter id.

Finally, we can perform the optimization step using the learning rate, the model, and the computed gradients. It's worth mentioning that, unlike PyTorch, there's no need to register the gradients with the optimizer, nor do you have to call zero_grad. The gradients are automatically consumed during the optimization step. If you're interested in gradient accumulation, you can easily achieve this by using the GradientsAccumulator.

let mut accumulator = GradientsAccumulator::new();
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
accumulator.accumulate(&model, grads); ...
let grads = accumulator.grads(); // Pop the accumulated gradients.
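
For instance, here is a minimal sketch of accumulating gradients over several iterations before applying a single optimizer step, reusing the variables from the training loop above (the accumulation interval is arbitrary):

let mut accumulator = GradientsAccumulator::new();
let accumulation = 4; // Arbitrary accumulation interval.

for (iteration, batch) in dataloader_train.iter().enumerate() {
    let output = model.forward(batch.images);
    let loss = CrossEntropyLoss::new(None, &output.device())
        .forward(output.clone(), batch.targets.clone());

    // Accumulate the gradients of the current backward pass.
    let grads = GradientsParams::from_grads(loss.backward(), &model);
    accumulator.accumulate(&model, grads);

    // Apply the accumulated gradients every `accumulation` iterations.
    if (iteration + 1) % accumulation == 0 {
        let grads = accumulator.grads(); // Pop the accumulated gradients.
        model = optim.step(config.lr, model, grads);
    }
}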

Note that after each epoch, we include a validation loop to assess our model's performance on previously unseen data. To disable gradient tracking during this validation step, we can invoke model.valid(), which provides a model on the inner backend without autodiff capabilities. It's important to emphasize that we've declared our validation batcher to be on the inner backend, specifically MnistBatcher<B::InnerBackend>; not using model.valid() will result in a compilation error.

You can find the code above available as an example for you to test.

Multiple optimizers

It's common practice to use different learning rates, optimizer parameters, or even entirely different optimizers for different parts of a model. In Burn, each GradientsParams can contain only a subset of the gradients to apply with an optimizer. This allows you to flexibly mix and match optimizers!

// Start with calculating all gradients
let grads = loss.backward();

// Now split the gradients into various parts.
let grads_conv1 = GradientsParams::from_module(&mut grads, &model.conv1);
let grads_conv2 = GradientsParams::from_module(&mut grads, &model.conv2);

// You can step the model with these gradients, using different learning
// rates for each param. You could also use an entirely different optimizer here!
model = optim.step(config.lr * 2.0, model, grads_conv1);
model = optim.step(config.lr * 4.0, model, grads_conv2);

// For even more granular control, you can split off individual parameters,
// e.g. a linear bias usually needs a smaller learning rate.
if let Some(bias) = &model.linear1.bias {
    let grads_bias = GradientsParams::from_params(&mut grads, &model.linear1, &[bias.id]);
    model = optim.step(config.lr * 0.1, model, grads_bias);
}

// Note that the above calls remove the corresponding gradients, so we can simply collect all "remaining" gradients at the end.
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(config.lr, model, grads);

Custom Type

The explanations above demonstrate how to create a basic training loop. However, you may find it beneficial to organize your program using intermediary types. There are various ways to do this, but it requires getting comfortable with generics.

If you wish to group the optimizer and the model into the same structure, you have several options. It's important to note that the optimizer trait depends on both the AutodiffModule trait and the AutodiffBackend trait, while the module only depends on the AutodiffBackend trait.

Here's a closer look at how you can create your types:

Create a struct that is generic over the backend and the optimizer, with a predefined model.

struct Learner<B, O>
where
    B: AutodiffBackend,
{
    model: Model<B>,
    optim: O,
}

This is quite straightforward. You can be generic over the backend since it's used with the concrete type Model in this case.
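
For instance, the trait constraints can sit directly on the implementation block, since the backend appears in the concrete Model<B> field (a sketch, reusing the MnistBatch type from the basic workflow guide):

impl<B, O> Learner<B, O>
where
    B: AutodiffBackend,
    O: Optimizer<Model<B>, B>,
{
    pub fn step(&mut self, _batch: MnistBatch<B>) {
        //
    }
}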

Create a struct that is generic over the model and the optimizer.

struct Learner2<M, O> {
    model: M,
    optim: O,
}

This option is quite an intuitive way to declare the struct. You don't need to write type constraints with a where clause when defining it; you can wait until you implement the actual functions. However, with this struct, you may encounter some issues when writing implementation blocks for it.

impl<B, M, O> Learner2<M, O>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: Optimizer<M, B>,
{
    pub fn step(&mut self, _batch: MnistBatch<B>) {
        //
    }
}

This will result in the following compilation error:

error[E0207]: the type parameter `B` is not constrained by the impl trait, self type, or predicates

To resolve this issue, you have two options. The first one is to make your function generic over the backend and add your trait constraint within its definition:

#[allow(dead_code)]
impl<M, O> Learner2<M, O> {
    pub fn step<B: AutodiffBackend>(&mut self, _batch: MnistBatch<B>)
    where
        M: AutodiffModule<B>,
        O: Optimizer<M, B>,
    {
        //
    }
}

However, some people may prefer to have the constraints on the implementation block itself. In that case, you can make your struct generic over the backend using PhantomData<B>.

Create a struct that is generic over the backend, the model, and the optimizer.

struct Learner3<B, M, O> {
    model: M,
    optim: O,
    _b: PhantomData<B>,
}

You might wonder why PhantomData is required. Every generic parameter declared on a struct must be used in its definition. When you don't otherwise need the generic argument, you can use PhantomData to mark it as a zero-sized field.
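
With the phantom field in place, the constraints can live on the implementation block itself, as preferred above. A minimal sketch:

impl<B, M, O> Learner3<B, M, O>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: Optimizer<M, B>,
{
    pub fn step(&mut self, _batch: MnistBatch<B>) {
        //
    }
}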

These are just some suggestions on how to define your own types, but you are free to use any pattern that you prefer.

Saving and Loading Models

Saving your trained machine learning model is quite easy, no matter the output format you choose. As mentioned in the Record section, different formats are supported to serialize/deserialize models. By default, we use the NamedMpkFileRecorder, which uses the MessagePack binary serialization format with the help of rmp_serde.

// Save model in MessagePack format with full precision
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
    .save_file(model_path, &recorder)
    .expect("Should be able to save the model");

Note that the file extension is automatically handled by the recorder depending on the one you choose. Therefore, only the file path and base name should be provided.

Now that you have a trained model saved to your disk, you can easily load it in a similar fashion.

// Load model in full precision from MessagePack file
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
    .load_file(model_path, &recorder, device)
    .expect("Should be able to load the model weights from the provided file");

Note: models can be saved in different output formats, just make sure you are using the correct recorder type when loading the saved model. Type conversion between different precision settings is automatically handled, but formats are not interchangeable. A model can be loaded from one format and saved to another format, just as long as you load it back with the new recorder type afterwards.
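
For example, a sketch of converting a model from the MessagePack format to the compact binary format, reusing the model_path and device variables from the snippets above (each recorder appends its own file extension):

// Load the weights with the recorder they were saved with...
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
let model = model
    .load_file(model_path, &recorder, device)
    .expect("Should be able to load the model weights");

// ...then save them again with a different recorder.
let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
model
    .save_file(model_path, &recorder)
    .expect("Should be able to save the model in binary format");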

Initialization from Recorded Weights

The most straightforward way to load weights into a module is simply to use the generated load_record method. Note that parameter initialization is lazy, so no tensor allocations or GPU/CPU kernels are executed before the module is actually used. This means that you can use init(device) followed by load_record(record) without any meaningful performance cost.

// Create a dummy initialized model to save
let device = Default::default();
let model = Model::<MyBackend>::init(&device);

// Save model in MessagePack format with full precision
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
    .save_file(model_path, &recorder)
    .expect("Should be able to save the model");

Afterwards, the model can just as easily be loaded from the record saved on disk.

// Load model record on the backend's default device
let record: ModelRecord<MyBackend> = NamedMpkFileRecorder::<FullPrecisionSettings>::new()
    .load(model_path.into(), &device)
    .expect("Should be able to load the model weights from the provided file");

// Initialize a new model with the loaded record/weights
let model = Model::init(&device).load_record(record);

No Storage, No Problem!

For applications where file storage may not be available (or desired) at runtime, you can use the BinBytesRecorder.

In the previous examples we used a FileRecorder based on the MessagePack format, which could be replaced with another file recorder of your choice. To embed a model as part of your runtime application, first save the model to a binary file with BinFileRecorder.

// Save model in binary format with full precision
let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
model
    .save_file(model_path, &recorder)
    .expect("Should be able to save the model");

Then, in your final application, include the model and use the BinBytesRecorder to load it.

Embedding the model as part of your application is especially useful for smaller models but not recommended for very large models as it would significantly increase the binary size as well as consume a lot more memory at runtime.

// Include the model file as a reference to a byte array
static MODEL_BYTES: &[u8] = include_bytes!("path/to/model.bin");

// Load model binary record in full precision
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
    .load(MODEL_BYTES.to_vec(), device)
    .expect("Should be able to load model the model weights from bytes");

// Load that record with the model
model.load_record(record);

This example assumes that the model was already created before loading the model record. If instead you want to skip the random initialization and directly initialize the weights with the provided record, you can adapt this as in the previous example, as sketched below.
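
A minimal sketch, reusing the MyBackend type, device, and MODEL_BYTES from the examples above:

// Load the embedded binary record, then attach it to a lazily initialized model
// instead of loading it into an existing instance.
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
    .load(MODEL_BYTES.to_vec(), device)
    .expect("Should be able to load the model weights from bytes");

let model = Model::<MyBackend>::init(&device).load_record(record);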

Importing Models

The Burn project supports the import of models from various frameworks, emphasizing efficiency and compatibility. Currently, it handles two primary model formats:

  1. ONNX: Facilitates direct import, ensuring the model's performance and structure are maintained.

  2. PyTorch: Enables the loading of PyTorch model weights into Burn’s native model architecture, ensuring seamless integration.

Importing ONNX Models in Burn

Table of Contents

  1. Introduction
  2. Why Import Models?
  3. Understanding ONNX
  4. Burn's ONNX Support
  5. Step-by-Step Guide
  6. Advanced Configuration
  7. Loading and Using Models
  8. Troubleshooting
  9. Examples and Resources
  10. Conclusion

Introduction

As the field of deep learning continues to evolve, the need for interoperability between different frameworks becomes increasingly important. Burn, a modern deep learning framework in Rust, recognizes this need and provides robust support for importing models from other popular frameworks. This section focuses on importing ONNX (Open Neural Network Exchange) models into Burn, enabling you to leverage pre-trained models and seamlessly integrate them into your Rust-based deep learning projects.

Why Import Models?

Importing pre-trained models offers several advantages:

  1. Time-saving: Avoid the need to train models from scratch, which can be time-consuming and resource-intensive.
  2. Access to state-of-the-art architectures: Utilize cutting-edge models developed by researchers and industry leaders.
  3. Transfer learning: Fine-tune imported models for your specific tasks, benefiting from knowledge transfer.
  4. Consistency across frameworks: Ensure consistent performance when moving from one framework to another.

Understanding ONNX

ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models. Key features include:

  • Framework agnostic: ONNX provides a common format that works across various deep learning frameworks.
  • Comprehensive representation: It captures both the model architecture and trained weights.
  • Wide support: Many popular frameworks like PyTorch, TensorFlow, and scikit-learn support ONNX export.

By using ONNX, you can easily move models between different frameworks and deployment environments.

Burn's ONNX Support

Burn takes a unique approach to ONNX import, offering several advantages:

  1. Native Rust code generation: ONNX models are translated into Rust source code, allowing for deep integration with Burn's ecosystem.
  2. Compile-time optimization: The generated Rust code can be optimized by the Rust compiler, potentially improving performance.
  3. No runtime dependency: Unlike some solutions that require an ONNX runtime, Burn's approach eliminates this dependency.
  4. Trainability: Imported models can be further trained or fine-tuned using Burn.
  5. Portability: The generated Rust code can be compiled for various targets, including WebAssembly and embedded devices.
  6. Any Burn Backend: The imported models can be used with any of Burn's backends.

Step-by-Step Guide

Let's walk through the process of importing an ONNX model into a Burn project:

Step 1: Update build.rs

First, add the burn-import crate to your Cargo.toml:

[build-dependencies]
burn-import = "~0.15"

Then, in your build.rs file:

use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        .run_from_script();
}

This script uses ModelGen to generate Rust code from your ONNX model during the build process.

Step 2: Modify mod.rs

In your src/model/mod.rs file, include the generated code:

pub mod my_model {
    include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
}

This makes the generated model code available in your project.

Step 3: Use the Imported Model

Now you can use the imported model in your Rust code:

use burn::tensor;
use burn_ndarray::{NdArray, NdArrayDevice};
use model::my_model::Model;

fn main() {
    let device = NdArrayDevice::default();

    // Create a model instance and load weights from the target directory using the default device.
    // (see more load options below in "Loading and Using Models" section)
    let model: Model<NdArray<f32>> = Model::default();

    // Create input tensor (replace with your actual input)
    let input = tensor::Tensor::<NdArray<f32>, 4>::zeros([1, 3, 224, 224], &device);

    // Perform inference
    let output = model.forward(input);

    println!("Model output: {:?}", output);
}

Advanced Configuration

The ModelGen struct offers several configuration options:

ModelGen::new()
    .input("path/to/model.onnx")
    .out_dir("model/")
    .record_type(RecordType::NamedMpk)
    .half_precision(false)
    .embed_states(false)
    .run_from_script();
  • record_type: Specifies the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or PrettyJson).
  • half_precision: Use half-precision (f16) for weights to reduce model size.
  • embed_states: Embed model weights directly in the generated Rust code. Note: This requires record type Bincode.

Loading and Using Models

Depending on your configuration, you can load models in different ways:

// Create a new model instance with device. Initializes weights randomly and lazily.
// You can load weights via `load_record` afterwards.
let model = Model::<Backend>::new(&device);

// Load from a file (must specify weights file in the target output directory or copy it from there).
// File type should match the record type specified in `ModelGen`.
let model = Model::<Backend>::from_file("path/to/weights", &device);

// Load from embedded weights (if embed_states was true)
let model = Model::<Backend>::from_embedded();

// Load from the OUT_DIR location using the default device (useful for testing)
let model = Model::<Backend>::default();

Troubleshooting

Here are some common issues and their solutions:

  1. Unsupported ONNX operator: If you encounter an error about an unsupported operator, check the list of supported ONNX operators. You may need to simplify your model or wait for support to be added.

  2. Build errors: Ensure that your burn-import version matches your Burn version. Also, check that the ONNX file path in build.rs is correct.

  3. Runtime errors: If you get errors when running your model, double-check that your input tensors match the expected shape and data type of your model.

  4. Performance issues: If your imported model is slower than expected, try using the half_precision option to reduce memory usage, or experiment with different record_type options.

  5. Artifact Files: You can view the generated Rust code and weights files in the OUT_DIR directory specified in build.rs (usually target/debug/build/<project>/out).

Examples and Resources

For more detailed examples, check out:

  1. MNIST Inference Example
  2. SqueezeNet Image Classification

These examples demonstrate real-world usage of ONNX import in Burn projects.

Conclusion

Importing ONNX models into Burn opens up a world of possibilities, allowing you to leverage pre-trained models from other frameworks while taking advantage of Burn's performance and Rust's safety features. By following this guide, you should be able to seamlessly integrate ONNX models into your Burn projects, whether for inference, fine-tuning, or as a starting point for further development.

Remember that the burn-import crate is actively developed, with ongoing work to support more ONNX operators and improve performance. Stay tuned to the Burn repository for updates and new features!


🚨Note: The burn-import crate is in active development. For the most up-to-date information on supported ONNX operators, please refer to the official documentation.

PyTorch Model

Introduction

Whether you've trained your model in PyTorch or you want to use a pre-trained model from PyTorch, you can import it into Burn. Burn supports importing PyTorch model weights with the .pt file extension. Compared to ONNX models, .pt files only contain the weights of the model, so you will need to reconstruct the model architecture in Burn.

This document shows the full workflow of exporting a PyTorch model and importing it into Burn. You can also refer to the Transitioning From PyTorch to Burn tutorial for importing a more complex model.

How to export a PyTorch model

If you have a PyTorch model that you want to import into Burn, you will need to export it first, unless you are using a pre-trained published model. To export a PyTorch model, you can use the torch.save function.

The following is an example of how to export a PyTorch model:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(2, 2, (2,2))
        self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

if __name__ == "__main__":
    torch.manual_seed(42)  # To make it reproducible
    model = Net().to(torch.device("cpu"))
    model_weights = model.state_dict()
    torch.save(model_weights, "conv2d.pt")

Use Netron to view the exported model. You should see something like this:

(Netron screenshot of the exported conv2d.pt model)

How to import a PyTorch model

  1. Define the model in Burn:

    use burn::{
        nn::conv::{Conv2d, Conv2dConfig},
        prelude::*,
    };
    
    #[derive(Module, Debug)]
    pub struct Net<B: Backend> {
        conv1: Conv2d<B>,
        conv2: Conv2d<B>,
    }
    
    impl<B: Backend> Net<B> {
        /// Create a new model.
        pub fn init(device: &B::Device) -> Self {
            let conv1 = Conv2dConfig::new([2, 2], [2, 2])
                .init(device);
            let conv2 = Conv2dConfig::new([2, 2], [2, 2])
                .with_bias(false)
                .init(device);
            Self { conv1, conv2 }
        }
    
        /// Forward pass of the model.
        pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
            let x = self.conv1.forward(x);
            self.conv2.forward(x)
        }
    }
  2. Load the model weights from the exported PyTorch model (2 options):

    a) Dynamically, but this requires the burn-import crate as a runtime dependency:

    use crate::model;
    
    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;
    
    type Backend = burn_ndarray::NdArray<f32>;
    
    fn main() {
        let device = Default::default();
        let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
            .load("./conv2d.pt".into(), &device)
            .expect("Should decode state successfully");
    
        let model = model::Net::<Backend>::init(&device).load_record(record);
    }

    b) Pre-converted to Burn's binary format:

    // Convert the PyTorch model to Burn's binary format in
    // build.rs or in a separate executable. Then, include the generated file
    // in your project. See `examples/pytorch-import` for an example.
    
    use crate::model;
    
    use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;
    
    type Backend = burn_ndarray::NdArray<f32>;
    
    fn main() {
        let device = Default::default();
        let recorder = PyTorchFileRecorder::<FullPrecisionSettings>::default();
        let record: model::NetRecord<Backend> = recorder
            .load("./conv2d.pt".into(), &device)
            .expect("Should decode state successfully");
    
        // Save the model record to a file.
        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
        recorder
            .record(record, "MY_FILE_OUTPUT_PATH".into())
            .expect("Failed to save model record");
    }
    
    /// Load the model from the file in your source code (not in build.rs or script).
    fn load_model() -> Net::<Backend> {
        let device = Default::default();
        let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
            .load("./MY_FILE_OUTPUT_PATH".into(), &device)
            .expect("Should decode state successfully");
    
        Net::<Backend>::init(&device).load_record(record)
    }

Extract Configuration

In some cases, models may require additional configuration settings, which are often included in a .pt file during export. The config_from_file function from the burn-import cargo package allows for the extraction of these configurations directly from the .pt file. The extracted configuration can then be used to initialize the model in Burn. Here is an example of how to extract the configuration from a .pt file:

use std::collections::HashMap;

use burn::config::Config;
use burn_import::pytorch::config_from_file;

#[derive(Debug, Config)]
struct NetConfig {
    n_head: usize,
    n_layer: usize,
    d_model: usize,
    // Candle's pickle has a bug with float serialization
    // https://github.com/huggingface/candle/issues/1729
    // some_float: f64,
    some_int: i32,
    some_bool: bool,
    some_str: String,
    some_list_int: Vec<i32>,
    some_list_str: Vec<String>,
    // Candle's pickle has a bug with float serialization
    // https://github.com/huggingface/candle/issues/1729
    // some_list_float: Vec<f64>,
    some_dict: HashMap<String, String>,
}

fn main() {
    let path = "weights_with_config.pt";
    let top_level_key = Some("my_config");
    let config: NetConfig = config_from_file(path, top_level_key).unwrap();
    println!("{:#?}", config);

    // After extracting, it's recommended you save it as a json file.
    config.save("my_config.json").unwrap();
}

Troubleshooting

Adjusting the source model architecture

If your target model differs structurally from the model you exported, PyTorchFileRecorder allows you to change the attribute names and the order of the attributes. For example, if you exported a model with the following structure:

class ConvModule(nn.Module):
    def __init__(self):
        super(ConvModule, self).__init__()
        self.conv1 = nn.Conv2d(2, 2, (2,2))
        self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = ConvModule()

    def forward(self, x):
        x = self.conv(x)
        return x

But you need to import it into a model with the following structure:

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
}

Which produces the following weights structure (viewed in Netron):

(Netron screenshot of the source model's weight keys)

You can use the PyTorchFileRecorder to change the attribute names and the order of the attributes by specifying a regular expression (See regex::Regex::replace and try it online) to match the attribute name and a replacement string in LoadArgs:

let device = Default::default();
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
    // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
    .with_key_remap("conv\\.(.*)", "$1");

let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");

let model = Net::<Backend>::init(&device).load_record(record);

Printing the source model keys and tensor information

If you are unsure about the keys in the source model, you can print them using the following code:

let device = Default::default();
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
    // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
    .with_key_remap("conv\\.(.*)", "$1")
    .with_debug_print(); // Print the keys and remapped keys

let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");

let model = Net::<Backend>::init(&device).load_record(record);

Here is an example of the output:

Debug information of keys and tensor shapes:
---
Original Key: conv.conv1.bias
Remapped Key: conv1.bias
Shape: [2]
Dtype: F32
---
Original Key: conv.conv1.weight
Remapped Key: conv1.weight
Shape: [2, 2, 2, 2]
Dtype: F32
---
Original Key: conv.conv2.weight
Remapped Key: conv2.weight
Shape: [2, 2, 2, 2]
Dtype: F32
---

Non-contiguous indices in the source model

Sometimes the indices of the source model are non-contiguous. For example, the source model has:

"model.ax.0.weight",
"model.ax.0.bias",
"model.ax.2.weight",
"model.ax.2.bias",
"model.ax.4.weight",
"model.ax.4.bias",
"model.ax.6.weight",
"model.ax.6.bias",
"model.ax.8.weight",
"model.ax.8.bias",

This may occur when the model.ax attribute (in the above example) uses Sequential to define the layers and some of the items do not have weight tensors, such as a ReLU layer. PyTorch simply skips the layers without weight tensors, resulting in non-contiguous indices. In this case, PyTorchFileRecorder automatically corrects the indices to be contiguous, preserving the order of the weights, resulting in:

"model.ax.0.weight",
"model.ax.0.bias",
"model.ax.1.weight",
"model.ax.1.bias",
"model.ax.2.weight",
"model.ax.2.bias",
"model.ax.3.weight",
"model.ax.3.bias",
"model.ax.4.weight",
"model.ax.4.bias",

Loading the model weights to a partial model

PyTorchFileRecorder enables selective weight loading into partial models. For instance, in a model with both an encoder and a decoder, it's possible to load only the encoder weights. This is done by defining the encoder in Burn, allowing the loading of its weights while excluding the decoder's.
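
A minimal sketch of the idea, assuming a hypothetical Encoder module whose fields mirror the encoder portion of the source model:

// The source checkpoint contains both `encoder.*` and `decoder.*` keys, but the
// Burn-side module only declares the encoder, so only those weights are loaded.
#[derive(Module, Debug)]
pub struct PartialNet<B: Backend> {
    encoder: Encoder<B>, // hypothetical module mirroring the source `encoder`
    // No `decoder` field: its weights are simply not loaded.
}

Loading then proceeds exactly as in the earlier PyTorchFileRecorder examples.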

Specifying the top-level key for state_dict

Sometimes the state_dict is nested under a top-level key along with other metadata, as in a general checkpoint. For example, the state_dict of the whisper model is nested under the model_state_dict key. In this case, you can specify the top-level key in LoadArgs:

let device = Default::default();
let load_args = LoadArgs::new("tiny.en.pt".into())
    .with_top_level_key("model_state_dict");

let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully")
}

Models containing enum modules

Burn supports models containing enum modules with new-type variants (tuple with one item). Importing weights for such models is automatically supported by the PyTorchFileRecorder. However, it should be noted that since the source weights file does not contain the enum variant information, the enum variant is picked based on the enum variant type. Let's consider the following example:

#[derive(Module, Debug)]
pub enum Conv<B: Backend> {
    DwsConv(DwsConv<B>),
    Conv(Conv2d<B>),
}

#[derive(Module, Debug)]
pub struct DwsConv<B: Backend> {
    dconv: Conv2d<B>,
    pconv: Conv2d<B>,
}

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv: Conv<B>,
}

If the source weights file contains weights for DwsConv, such as the following keys:

---
Key: conv.dconv.bias
Shape: [2]
Dtype: F32
---
Key: conv.dconv.weight
Shape: [2, 1, 3, 3]
Dtype: F32
---
Key: conv.pconv.bias
Shape: [2]
Dtype: F32
---
Key: conv.pconv.weight
Shape: [2, 2, 1, 1]
Dtype: F32

The weights will be imported into the DwsConv variant of the Conv enum module.

If the variant types are identical, then the first variant is picked. Generally, it won't be a problem since the variant types are usually different.

Current known issues

  1. Candle's pickle does not currently unpack boolean tensors.

Models and Pre-Trained Weights

The models repository contains definitions of different deep learning models with examples for different domains like computer vision and natural language processing.

This includes image classification models such as MobileNetV2, SqueezeNet and ResNet, object detection models such as YOLOX and language models like BERT and RoBERTa.

Be sure to check out the up-to-date collection of models to get you started. Pre-trained weights are available for every supported architecture in this collection. You will also find a spotlight of community contributed models.

Quantization (Beta)

Quantization techniques perform computations and store tensors in lower-precision data types, such as 8-bit integers, instead of floating-point precision. There are multiple approaches to quantize a deep learning model, categorized as:

  • Post-training quantization (PTQ)
  • Quantization aware training (QAT)

In post-training quantization, the model is trained in floating point precision and later converted to the lower precision data type.

There are two types of post-training quantization:

  1. Static quantization: quantizes the weights and activations of the model. Quantizing the activations statically requires data to be calibrated (i.e., recording the activation values to compute the optimal quantization parameters with representative data).
  2. Dynamic quantization: quantizes the weights ahead of time (like static quantization), but the activations are quantized dynamically at runtime.

Sometimes post-training quantization is not able to achieve acceptable task accuracy. This is where quantization aware training comes into play, as it models the effects of quantization during training. Quantization errors are thus modeled in the forward and backward passes using fake quantization modules, which helps the model learn representations that are more robust to the reduction in precision.

Quantization support in Burn is currently in active development.

It supports the following modes on some backends:

  • Static per-tensor quantization to signed 8-bit integer (i8)

No integer operations are currently supported, which means tensors are dequantized to perform the operations in floating point precision.

Module Quantization

Quantizing the weights of your model after training is quite simple. We have access to the weight tensors and can collect their statistics, such as the min and max value when using MinMaxCalibration, to compute the quantization parameters.

use burn::module::Quantizer;
use burn::tensor::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType};

// Quantization config
let mut quantizer = Quantizer {
    calibration: MinMaxCalibration {},
    scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
};

// Quantize the weights
let model = model.quantize_weights(&mut quantizer);

Given that all operations are currently performed in floating point precision, it might be wise to dequantize the module parameters before inference. This allows us to save disk space by storing the model in reduced precision while preserving the inference speed.

This can easily be implemented with a ModuleMapper.

use burn::module::{ModuleMapper, ParamId};
use burn::tensor::{backend::Backend, Tensor};

/// Module mapper used to dequantize the model params being loaded.
pub struct Dequantize {}

impl<B: Backend> ModuleMapper<B> for Dequantize {
    fn map_float<const D: usize>(
        &mut self,
        _id: ParamId,
        tensor: Tensor<B, D>,
    ) -> Tensor<B, D> {
        tensor.dequantize()
    }
}

// Load saved quantized model in floating point precision
model = model
    .load_file(file_path, recorder, &device)
    .expect("Should be able to load the quantized model weights")
    .map(&mut Dequantize {});

Calibration

Calibration is the step during quantization where the range of all floating-point tensors is computed. This is pretty straightforward for weights since the actual range is known at quantization-time (weights are static), but activations require more attention.

To compute the quantization parameters, Burn supports the following Calibration methods.

  • MinMaxCalibration: Computes the quantization range mapping based on the running min and max values.

Quantization Scheme

A quantization scheme defines the quantized type, quantization granularity and range mapping technique.

Burn currently supports the following QuantizationType variants.

  • QInt8: 8-bit signed integer quantization.

Quantization parameters are defined based on the range of values to represent and can typically be calculated for the layer's entire weight tensor with per-tensor quantization or separately for each channel with per-channel quantization (commonly used with CNNs).

Burn currently supports the following QuantizationScheme variants.

  • PerTensorAffine: Computes the quantization parameters for the whole tensor and applies an affine range mapping with zero point.
  • PerTensorSymmetric: Computes the quantization parameters for the whole tensor and applies a scale range mapping centered around 0.
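
To make the two range mappings concrete, here is a small arithmetic sketch (not Burn's internal code) of how per-tensor parameters can be derived for i8:

/// Per-tensor symmetric: a single scale centered around zero (the zero point is 0).
fn per_tensor_symmetric_scale(min: f32, max: f32) -> f32 {
    let bound = min.abs().max(max.abs());
    bound / 127.0 // maps [-bound, bound] onto [-127, 127]
}

/// Per-tensor affine: a scale plus a zero point, so asymmetric value ranges can use
/// the full [-128, 127] integer range.
fn per_tensor_affine_params(min: f32, max: f32) -> (f32, i32) {
    let scale = (max - min) / 255.0;
    let zero_point = (-128.0 - min / scale).round() as i32;
    (scale, zero_point)
}

// Quantization then maps a value x to clamp(round(x / scale) + zero_point, -128, 127).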

Advanced

In this section, we will go into advanced topics that extend beyond basic usage. Given Burn's exceptional flexibility, a lot of advanced use cases become possible.

Before going through this section, we strongly recommend exploring the basic workflow section and the building blocks section. Establishing a solid understanding of how the framework operates is crucial to comprehending the advanced concepts presented here. While you have the freedom to explore the advanced sections in any order you prefer, it's important to note that this section is not intended to be linear, contrary to preceding sections. Instead, it serves as a repository of use cases that you can refer to for guidance as needed.

Backend Extension

Burn aims to be the most flexible deep learning framework. While it's crucial to maintain compatibility with a wide variety of backends, Burn provides the ability to extend the functionality of a backend implementation to suit your modeling requirements. This versatility is advantageous in numerous ways, such as supporting custom operations like flash attention or manually fusing operations for enhanced performance.

In this section, we will go into the process of extending a backend, providing multiple examples. But before we proceed, let's establish the fundamental principles that will empower you to craft your own backend extensions.

As you can observe, most types in Burn are generic over the Backend trait. This might give the impression that Burn operates at a high level over the backend layer. However, making the trait explicit instead of being chosen via a compilation flag was a thoughtful design decision. This explicitness does not imply that all backends must be identical; rather, it offers a great deal of flexibility when composing backends. The autodifferentiation backend trait (see autodiff section) is an example of how the backend trait has been extended to enable gradient computation with backpropagation. Furthermore, this design allows you to create your own backend extension. To achieve this, you need to design your own backend trait specifying which functions should be supported.

pub trait Backend: burn::tensor::backend::Backend {
    fn my_new_function(tensor: Self::TensorPrimitive<2>) -> Self::TensorPrimitive<2> {
        // You can define a basic implementation reusing the Burn Backend API.
        // This can be useful since all backends will now automatically support
        // your model. But performance can be improved for this new
        // operation by implementing this block in specific backends.
    }
}

You can then implement your new custom backend trait for any backend that you want to support:

impl<E: TchElement> Backend for burn_tch::LibTorch<E> {
   fn my_new_function(tensor: TchTensor<E, 2>) -> TchTensor<E, 2> {
      // My Tch implementation
   }
}

impl<E: NdArrayElement> Backend for burn_ndarray::NdArray<E> {
    // No specific implementation, but the backend can still be used.
}

You can support the backward pass using the same pattern.

impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {
    // No specific implementation; autodiff will work with the default
    // implementation. Useful if you still want to train your model, but
    // observe performance gains mostly during inference.
}

impl<B: Backend> Backend for burn_autodiff::Autodiff<B> {
   fn my_new_function(tensor: AutodiffTensor<E, 2>) -> AutodiffTensor<E, 2> {
      // My own backward implementation, generic over my custom Backend trait.
      //
      // You can add a new method `my_new_function_backward` to your custom backend
      // trait if you want to invoke a custom kernel during the backward pass.
   }
}

impl<E: TchElement> Backend for burn_autodiff::Autodiff<burn_tch::LibTorch<E>> {
   fn my_new_function(tensor: AutodiffTensor<E, 2>) -> AutodiffTensor<E, 2> {
      // My own backward implementation, generic over a backend implementation.
      //
      // This is another way to call a custom kernel for the backward pass that
      // doesn't require the addition of a new `backward` function in the custom backend.
      // This is useful if you don't want all backends to support training, reducing
      // the need for extra code when you know your model will only be trained on one
      // specific backend.
   }
}

The specifics of each implementation will be covered by the examples provided in this section. The cubecl compiler frontend is the recommended method of implementing custom kernels, since it supports multiple backends, including wgpu and CUDA, and is the way first-party burn kernels are written.

Custom cubecl Kernel

In this section, you will learn how to create your own custom operation by writing your own kernel with the cubecl compiler frontend. We will take the example of a common workflow in the deep learning field, where we create a kernel to fuse multiple operations together. Note that burn does this automatically, but a manual implementation might be more efficient in some cases. We will fuse a matmul kernel followed by an addition and the ReLU activation function, which is commonly found in various models. All the code can be found under the examples directory.

Custom Backend Trait

First, we need to determine the type signature of our newly created operation by defining our custom backend traits. As we will use the associated type TensorPrimitive of the Backend trait, which encapsulates the underlying tensor implementation of the backend, we will use a type alias to avoid the ugly disambiguation with associated types.

/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self>;
}

/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.
pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}

In our project, we can use these traits instead of the burn::tensor::backend::{Backend, AutodiffBackend} traits provided by Burn. Burn's user APIs typically make use of the Tensor struct rather than dealing directly with primitive tensor types. Therefore, we can encapsulate our newly defined backend traits with functions that expose new operations while maintaining a consistent API.

/// We define our custom implementation using the added function on our custom backend.
pub fn matmul_add_relu_custom<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let output = B::fused_matmul_add_relu(
        lhs.into_primitive().tensor(),
        rhs.into_primitive().tensor(),
        bias.into_primitive().tensor(),
    );

    Tensor::from_primitive(TensorPrimitive::Float(output))
}

/// We define a reference implementation using basic tensor operations.
pub fn matmul_add_relu_reference<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let x = lhs.matmul(rhs) + bias;

    activation::relu(x)
}

Note that we also provide a reference implementation for testing purposes, which allows us to easily validate our new implementation. While not mandatory, having a reference implementation can be valuable, especially in projects where creating a reference implementation solely using basic tensor operations is feasible.
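
For instance, a comparison test could be sketched as follows, generic over any backend that implements the custom trait (the shapes and tolerance are arbitrary):

fn assert_fused_matches_reference<B: Backend>(device: &B::Device) {
    let lhs = Tensor::<B, 3>::random([2, 4, 8], Distribution::Default, device);
    let rhs = Tensor::<B, 3>::random([2, 8, 4], Distribution::Default, device);
    let bias = Tensor::<B, 3>::random([2, 4, 4], Distribution::Default, device);

    // Run both the reference and the fused implementations on the same inputs.
    let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone()).into_data();
    let fused = matmul_add_relu_custom(lhs, rhs, bias).into_data();

    // Compare the results within a small tolerance.
    reference.assert_approx_eq(&fused, 3);
}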

Forward Kernel

Now, let's proceed to write the fused kernel using the cubecl compiler frontend. To keep things simple, we'll create a straightforward matmul kernel without employing any intricate techniques. We won't delve into the details of the cube macro, but if you're interested to learn more, please see the cubecl Book. The actual matmul, add and relu computations are found at the end, after an extensive prelude that serves to correctly map each compute unit to the data it is responsible for, with support for batches.

use cubecl::{cube, prelude::*};

#[cube(launch)]
pub fn fused_matmul_add_relu_kernel<F: Float>(
    lhs: &Tensor<F>,
    rhs: &Tensor<F>,
    bias: &Tensor<F>,
    output: &mut Tensor<F>,
) {
    let row = ABSOLUTE_POS_X;
    let col = ABSOLUTE_POS_Y;
    let batch = ABSOLUTE_POS_Z;

    let n_rows = output.shape(output.rank() - 2);
    let n_cols = output.shape(output.rank() - 1);
    let dim_k = rhs.shape(rhs.rank() - 1);

    if row >= n_rows || col >= n_cols {
        return;
    }

    let offset_output = batch * n_rows * n_cols;
    let mut offset_lhs = 0;
    let mut offset_rhs = 0;

    let batch_dims = output.rank() - 2;
    for dim in 0..batch_dims {
        offset_lhs += offset_output / output.stride(dim) % lhs.shape(dim) * lhs.stride(dim);
        offset_rhs += offset_output / output.stride(dim) % rhs.shape(dim) * rhs.stride(dim);
    }

    let mut sum = F::new(0.0);
    for k in 0..dim_k {
        let lhs_index = row * dim_k + k;
        let rhs_index = k * n_cols + col;

        sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];
    }

    let out_index = row * n_cols + col;
    let index = offset_output + out_index;

    output[index] = F::max(sum + bias[index], F::new(0.0));
}

Now, let's move on to the next step, which involves implementing the remaining code to launch the kernel. We'll go into implementing our custom backend trait for the generic JIT backend. This automatically implements the trait for burn-cuda, burn-wgpu as well as fusion.

/// Implement our custom backend trait for the generic `JitBackend`.
impl<R: JitRuntime, F: FloatElement, I: IntElement> Backend for JitBackend<R, F, I> {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Define cube dim, hardcoded for simplicity.
        let cube_dim = CubeDim { x: 16, y: 16, z: 1 };

        lhs.assert_is_on_same_device(&rhs);
        lhs.assert_is_on_same_device(&bias);

        // For simplicity, make sure each tensor is contiguous.
        let lhs = into_contiguous(lhs);
        let rhs = into_contiguous(rhs);
        let bias = into_contiguous(bias);

        // Get the matmul relevant shapes.
        let ndims = lhs.shape.num_dims();
        let num_rows = lhs.shape.dims[ndims - 2];
        let num_cols = rhs.shape.dims[ndims - 1];

        // Compute shape of output, while tracking number of batches.
        let mut num_batches = 1;
        let mut shape_out = vec![0; ndims];
        for i in 0..ndims - 2 {
            shape_out[i] = usize::max(lhs.shape.dims[i], rhs.shape.dims[i]);
            num_batches *= shape_out[i];
        }
        shape_out[ndims - 2] = num_rows;
        shape_out[ndims - 1] = num_cols;
        let shape_out = Shape::from(shape_out);

        // Create a buffer for the output tensor.
        let buffer = lhs
            .client
            .empty(shape_out.num_elements() * core::mem::size_of::<F>());

        // Create the output tensor primitive.
        let output =
            JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);

        // Declare the cube count with the number of cubes in x, y and z.
        let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
        let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
        let cube_count =
            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);

        // Execute lazily the kernel with the launch information and the given buffers. For
        // simplicity, no vectorization is performed
        fused_matmul_add_relu_kernel::launch::<F, R>(
            &lhs.client,
            cube_count,
            cube_dim,
            lhs.as_tensor_arg(1),
            rhs.as_tensor_arg(1),
            bias.as_tensor_arg(1),
            output.as_tensor_arg(1),
        );

        // Return the output tensor.
        output
    }
}

In the preceding code block, we demonstrated how to launch the kernel that modifies the correct buffer. It's important to note that Rust's mutability safety doesn't apply here; the context has the capability to execute any mutable operation on any buffer. While this isn't a problem in the previous scenario where we only modify the newly created output buffer, it is wise to keep this in mind.

Backward

Now that the custom backend trait is implemented for the JIT backend, you can use it to invoke the matmul_add_relu_custom function. However, calculating gradients is not yet possible at this stage. If your use case does not extend beyond inference, there is no need to implement any of the following code.

For the backward pass, we will leverage the backend implementation from burn-autodiff, which is actually generic over the backend. Instead of crafting our own cubecl kernel for the backward pass, we will use our fused kernel only for the forward pass, and compute the gradient using basic operations.

// Implement our custom backend trait for the autodiff-decorated backend, generic over
// any inner backend that also implements our custom backend trait.
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Create our zero-sized type that will implement the Backward trait.
        #[derive(Debug)]
        struct FusedMatmulAddReluBackward;

        // Implement the backward trait for the given backend B, the node gradient
        // with three other gradients to calculate (lhs, rhs, and bias).
        impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {
            // Our state that we must build during the forward pass to compute the backward pass.
            //
            // Note that we could improve the performance further by only keeping the state of
            // tensors that are tracked, improving memory management, but for simplicity, we avoid
            // that part.
            type State = (NodeID, NodeID, FloatTensor<B>, Shape);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                // Get the nodes of each variable.
                let [node_lhs, node_rhs, node_bias] = ops.parents;
                // Fetch the gradient for the current node.
                let grad = grads.consume::<B>(&ops.node);

                // Set our state.
                let (lhs_state, rhs_state, output, shape_bias) = ops.state;
                let lhs = checkpointer.retrieve_node_output(lhs_state);
                let rhs = checkpointer.retrieve_node_output(rhs_state);

                // Fetch shapes of our tensor to support broadcasting.
                let shape_lhs = B::float_shape(&lhs);
                let shape_rhs = B::float_shape(&rhs);

                // Compute the gradient of the output using the already existing `relu_backward`
                // function in the basic Burn backend trait.
                let grad_output = B::relu_backward(output, grad);

                // Compute the lhs gradient, which is the derivative of matmul with support for
                // broadcasting.
                let grad_lhs = broadcast_shape::<B>(
                    B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),
                    &shape_lhs,
                );
                // Compute the rhs gradient, which is the derivative of matmul with support for
                // broadcasting.
                let grad_rhs = broadcast_shape::<B>(
                    B::float_matmul(B::float_transpose(lhs), grad_output.clone()),
                    &shape_rhs,
                );
                // The add derivative is only 1, so we just need to support broadcasting to
                // compute the bias gradient.
                let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);

                // Register the gradient for each variable based on whether they are marked as
                // `tracked`.
                if let Some(node) = node_bias {
                    grads.register::<B>(node.id, grad_bias);
                }
                if let Some(node) = node_lhs {
                    grads.register::<B>(node.id, grad_lhs);
                }
                if let Some(node) = node_rhs {
                    grads.register::<B>(node.id, grad_rhs);
                }
            }
        }

        // Prepare a stateful operation with each variable node and corresponding graph.
        //
        // Each node can be fetched with `ops.parents` in the same order as defined here.
        match FusedMatmulAddReluBackward
            .prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])
            // Marks the operation as compute bound, meaning it will save its
            // state instead of recomputing itself during checkpointing
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                // When at least one node is tracked, we should register our backward step.

                // The state consists of what will be needed for this operation's backward pass.
                // Since we need the parents' outputs, we must checkpoint their ids to retrieve
                // their node output at the beginning of the backward pass. We can also save
                // utility data such as the bias shape. If we also need this operation's output,
                // we can either save it in the state or recompute it during the backward pass.
                // Here we choose to save it in the state because it's a compute-bound operation.
                let lhs_state = prep.checkpoint(&lhs);
                let rhs_state = prep.checkpoint(&rhs);
                let bias_shape = B::float_shape(&bias.primitive);

                let output = B::fused_matmul_add_relu(
                    lhs.primitive.clone(),
                    rhs.primitive.clone(),
                    bias.primitive,
                );

                let state = (lhs_state, rhs_state, output.clone(), bias_shape);

                prep.finish(state, output)
            }
            OpsKind::UnTracked(prep) => {
                // When no node is tracked, we can just compute the original operation without
                // keeping any state.
                let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);
                prep.finish(output)
            }
        }
    }
}

The previous code is heavily commented to make it clearer, but here is a summary of what it does:

We define fused_matmul_add_relu within Autodiff<B>, allowing any autodiff-decorated backend to benefit from our implementation. In an autodiff-decorated backend, the forward pass must still be implemented. This is achieved using a comprehensive match statement block where computation is delegated to the inner backend, while keeping track of a state. The state comprises any information relevant to the backward pass, such as input and output tensors, along with the bias shape. When an operation isn't tracked (meaning there won't be a backward pass for this specific operation in the graph), storing a state becomes unnecessary, and we simply perform the forward computation.

The backward pass uses the gradient obtained from the preceding node in the computation graph. It calculates the derivatives for relu (relu_backward), add (no operation is required here, as the derivative is one), and matmul (another matmul with transposed inputs). This results in gradients for both input tensors and the bias, which are registered for consumption by subsequent operation nodes.
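
In equation form, writing the forward pass as output = relu(matmul(lhs, rhs) + bias) and $\delta$ for the gradient flowing in from the next node, the backward pass computes

$$
\delta_Z = \delta \odot \mathbf{1}[\text{output} > 0], \qquad
\nabla_{\text{lhs}} = \delta_Z \, \text{rhs}^{\top}, \qquad
\nabla_{\text{rhs}} = \text{lhs}^{\top} \, \delta_Z, \qquad
\nabla_{\text{bias}} = \delta_Z,
$$

where each gradient is then reduced back to the shape of its variable by broadcast_shape (summing over broadcast dimensions).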

The only remaining part is to implement our autodiff-decorated backend trait for our JIT Backend.

impl<R: JitRuntime, F: FloatElement, I: IntElement> AutodiffBackend
    for Autodiff<JitBackend<R, F, I>>
{
}

Conclusion

In this guide, we've implemented a fused kernel using the cubecl compiler frontend, enabling execution on any GPU and any cubecl backend. By delving into the inner workings of both the JIT backend and the autodiff backend, we've gained a deeper understanding of these systems.

While extending a backend may be harder than working with straightforward tensors, the benefits can be worth it. This approach enables the crafting of custom models with greater control over execution, which can potentially greatly enhance the performance of your models.

As we conclude this guide, we hope that you have gained insights into Burn's world of backend extensions, and that it will help you to unleash the full potential of your projects.

Custom WGPU Kernel

In this section, you will learn how to create your own custom operation by writing your own kernel with the WGPU backend. We will take the example of a common workflow in the deep learning field, where we create a kernel to fuse multiple operations together. Note that Burn does this automatically, but a manual implementation might be more efficient in some cases. We will fuse a matmul kernel followed by an addition and the ReLU activation function, which is commonly found in various models. All the code can be found under the examples directory.

Custom Backend Trait

First, we need to determine the type signature of our newly created operation by defining our custom backend traits. As we will use the associated type TensorPrimitive of the Backend trait, which encapsulates the underlying tensor implementation of the backend, we will use a type alias to avoid the ugly disambiguation with associated types.

/// We create our own Backend trait that extends the Burn backend trait.
pub trait Backend: burn::tensor::backend::Backend {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self>;
}

/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.
pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}

In our project, we can use these traits instead of the burn::tensor::backend::{Backend, AutodiffBackend} traits provided by Burn. Burn's user APIs typically make use of the Tensor struct rather than dealing directly with primitive tensor types. Therefore, we can encapsulate our newly defined backend traits with functions that expose new operations while maintaining a consistent API.

/// We define our custom implementation using the added function on our custom backend.
pub fn matmul_add_relu_custom<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let output = B::fused_matmul_add_relu(
        lhs.into_primitive().tensor(),
        rhs.into_primitive().tensor(),
        bias.into_primitive().tensor(),
    );

    Tensor::from_primitive(TensorPrimitive::Float(output))
}

/// We define a reference implementation using basic tensor operations.
pub fn matmul_add_relu_reference<B: Backend>(
    lhs: Tensor<B, 3>,
    rhs: Tensor<B, 3>,
    bias: Tensor<B, 3>,
) -> Tensor<B, 3> {
    let x = lhs.matmul(rhs) + bias;

    activation::relu(x)
}

Note that we also provide a reference implementation for testing purposes, which allows us to easily validate our new implementation. While not mandatory, having a reference implementation is valuable whenever the operation can be expressed with basic tensor operations.
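
For instance, a quick test along the following lines could compare the two paths. This is only a sketch: it assumes the trait implementations and imports from the rest of this guide are in scope, and the shapes and tolerance are arbitrary choices.

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Distribution, Tensor};

    // Use the same backend type we implement the custom trait for later in this guide.
    type TestBackend = JitBackend<WgpuRuntime, f32, i32>;

    #[test]
    fn fused_matches_reference() {
        let device = Default::default();
        let lhs = Tensor::<TestBackend, 3>::random([2, 4, 8], Distribution::Default, &device);
        let rhs = Tensor::<TestBackend, 3>::random([2, 8, 6], Distribution::Default, &device);
        let bias = Tensor::<TestBackend, 3>::random([2, 4, 6], Distribution::Default, &device);

        let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone());
        let fused = matmul_add_relu_custom(lhs, rhs, bias);

        // The two paths may not be bit-identical, so compare with a tolerance.
        reference
            .into_data()
            .assert_approx_eq(&fused.into_data(), 3);
    }
}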

Forward Kernel

Now, let's proceed to write the fused kernel using the WGSL shading language. To keep things simple, we'll create a straightforward matmul kernel without employing any intricate techniques. Although we won't delve into the details of the WGSL syntax, as it falls beyond the scope of this guide, we still provide the implementation below for readers who are curious. The actual matmul, add and relu computations are found at the end, after an extensive preamble whose purpose is to correctly map each compute unit to the data it is responsible for, with support for batches.

@group(0)
@binding(0)
var<storage, read_write> lhs: array<{{ elem }}>;

@group(0)
@binding(1)
var<storage, read_write> rhs: array<{{ elem }}>;

@group(0)
@binding(2)
var<storage, read_write> bias: array<{{ elem }}>;

@group(0)
@binding(3)
var<storage, read_write> output: array<{{ elem }}>;

@group(0)
@binding(4)
var<storage, read_write> info: array<u32>;

const BLOCK_SIZE = {{ workgroup_size_x }}u;

@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_index) local_idx: u32,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    // Indices
    let row = workgroup_id.x * BLOCK_SIZE + (local_idx / BLOCK_SIZE);
    let col = workgroup_id.y * BLOCK_SIZE + (local_idx % BLOCK_SIZE);
    let batch = global_id.z;

    // Basic information
    let dim = info[0];
    let n_rows = info[6u * dim - 1u];
    let n_cols = info[6u * dim];
    let K = info[5u * dim - 1u];

    // Returns if outside the output dimension
    if row >= n_rows || col >= n_cols {
        return;
    }

    // Calculate the corresponding offsets with support for broadcasting.
    let offset_output = batch * n_rows * n_cols;
    var offset_lhs: u32 = 0u;
    var offset_rhs: u32 = 0u;

    let batch_dims = dim - 2u;
    for (var b: u32 = 1u; b <= batch_dims; b++) {
        let stride_lhs = info[b];
        let stride_rhs = info[b + dim];
        let stride_output = info[b + 2u * dim];
        let shape_lhs = info[b + 3u * dim];
        let shape_rhs = info[b + 4u * dim];

        offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
        offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
    }

    // Basic matmul implementation
    var sum = 0.0;
    for (var k: u32 = 0u; k < K; k++) {
        let lhs_index = row * K + k;
        let rhs_index = k * n_cols + col;

        sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];
    }

    let output_index = row * n_cols + col;
    let index = offset_output + output_index;

    // Add and ReLU
    output[index] = max(sum + bias[index], 0.0);
}

Now, let's move on to the next step, which involves implementing the remaining code to launch the kernel. The initial part entails loading the template and populating it with the appropriate variables. The register(name, value) method simply replaces occurrences of {{ name }} in the above WGSL code with some other string before it is compiled. In order to use templating utilities, you will have to activate the template feature of Burn in your Cargo.toml.

// Source the kernel written in WGSL.
kernel_wgsl!(FusedMatmulAddReluRaw, "./kernel.wgsl");

// Define our kernel type with cube information.
#[derive(new, Debug)]
struct FusedMatmulAddRelu<E: FloatElement> {
    cube_dim: CubeDim,
    _elem: PhantomData<E>,
}

// Implement the dynamic kernel trait for our kernel type.
impl<E: FloatElement> KernelSource for FusedMatmulAddRelu<E> {
    fn source(&self) -> SourceTemplate {
        // Extend our raw kernel with cube size information using the
        // `SourceTemplate` trait.
        FusedMatmulAddReluRaw::new()
            .source()
            .register("workgroup_size_x", self.cube_dim.x.to_string())
            .register("workgroup_size_y", self.cube_dim.y.to_string())
            .register("elem", E::type_name())
            .register("int", "i32")
    }

    fn id(&self) -> cubecl::KernelId {
        cubecl::KernelId::new::<Self>().info(self.cube_dim)
    }
}

Subsequently, we'll implement our custom backend trait for the WGPU backend. Note that we won't go into supporting the fusion feature flag in this tutorial, so we implement the trait for the raw JitBackend type parameterized with the WgpuRuntime.

/// Implement our custom backend trait for the existing backend `WgpuBackend`.
impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Define cube dim, hardcoded for simplicity.
        let cube_dim = CubeDim { x: 16, y: 16, z: 1 };

        lhs.assert_is_on_same_device(&rhs);
        lhs.assert_is_on_same_device(&bias);

        // For simplicity, make sure each tensor is contiguous.
        let lhs = into_contiguous(lhs);
        let rhs = into_contiguous(rhs);
        let bias = into_contiguous(bias);

        // Get the matmul relevant shapes.
        let ndims = lhs.shape.num_dims();
        let num_rows = lhs.shape.dims[ndims - 2];
        let num_cols = rhs.shape.dims[ndims - 1];

        // Compute shape of output, while tracking number of batches.
        let mut num_batches = 1;
        let mut shape_out = vec![0; ndims];
        for i in 0..ndims - 2 {
            shape_out[i] = usize::max(lhs.shape.dims[i], rhs.shape.dims[i]);
            num_batches *= shape_out[i];
        }
        shape_out[ndims - 2] = num_rows;
        shape_out[ndims - 1] = num_cols;
        let shape_out = Shape::from(shape_out);

        // Create a buffer for the output tensor.
        let buffer = lhs
            .client
            .empty(shape_out.num_elements() * core::mem::size_of::<F>());

        // Create the output tensor primitive.
        let output =
            JitTensor::new_contiguous(lhs.client.clone(), lhs.device.clone(), shape_out, buffer);

        // Create the kernel.
        let kernel = FusedMatmulAddRelu::<F>::new(cube_dim);

        // Build info buffer with tensor information needed by the kernel, such as shapes and strides.
        let info = build_info(&[&lhs, &rhs, &output]);
        let info_handle = lhs.client.create(bytemuck::cast_slice(&info));

        // Declare the wgsl workgroup with the number of cubes in x, y and z.
        let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;
        let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;
        let cube_count =
            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);

        // Execute lazily the kernel with the launch information and the given buffers.
        lhs.client.execute(
            Box::new(SourceKernel::new(kernel, cube_dim)),
            cube_count,
            vec![
                lhs.handle.binding(),
                rhs.handle.binding(),
                bias.handle.binding(),
                output.handle.clone().binding(),
                info_handle.binding(),
            ],
        );

        // Return the output tensor.
        output
    }
}

In the preceding code block, we demonstrated how to launch the kernel that modifies the correct buffer. It's important to note that Rust's mutability safety doesn't apply here; the context has the capability to execute any mutable operation on any buffer. While this isn't a problem in the previous scenario where we only modify the newly created output buffer, it is wise to keep this in mind.

Backward

Now that the custom backend trait is implemented for the WGPU backend, you can use it to invoke the matmul_add_relu_custom function. However, calculating gradients is not yet possible at this stage. If your use case does not extend beyond inference, there is no need to implement any of the following code.

For the backward pass, we will leverage the backend implementation from burn-autodiff, which is actually generic over the backend. Instead of crafting our own WGSL kernel for the backward pass, we will use our fused kernel only for the forward pass, and compute the gradient using basic operations.

// Implement our custom backend trait for the autodiff-decorated backend `Autodiff<B>`, for any
// inner backend `B` that also implements our custom backend trait.
//
// Note that we could instead implement the trait only for the Wgpu backend rather than being
// generic over the inner backend. This would allow us to call functions that exist only on the
// Wgpu backend, and potentially dispatch to a custom kernel crafted only for this task.
impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
    fn fused_matmul_add_relu(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        bias: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        // Create our zero-sized type that will implement the Backward trait.
        #[derive(Debug)]
        struct FusedMatmulAddReluBackward;

        // Implement the backward trait for the given backend B. The operation has three
        // parent nodes whose gradients must be computed (lhs, rhs, and bias).
        impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {
            // Our state that we must build during the forward pass to compute the backward pass.
            //
            // Note that we could improve the performance further by only keeping the state of
            // tensors that are tracked, improving memory management, but for simplicity, we avoid
            // that part.
            type State = (NodeID, NodeID, FloatTensor<B>, Shape);

            fn backward(
                self,
                ops: Ops<Self::State, 3>,
                grads: &mut Gradients,
                checkpointer: &mut Checkpointer,
            ) {
                // Get the nodes of each variable.
                let [node_lhs, node_rhs, node_bias] = ops.parents;
                // Fetch the gradient for the current node.
                let grad = grads.consume::<B>(&ops.node);

                // Set our state.
                let (lhs_state, rhs_state, output, shape_bias) = ops.state;
                let lhs = checkpointer.retrieve_node_output(lhs_state);
                let rhs = checkpointer.retrieve_node_output(rhs_state);

                // Fetch shapes of our tensor to support broadcasting.
                let shape_lhs = B::float_shape(&lhs);
                let shape_rhs = B::float_shape(&rhs);

                // Compute the gradient of the output using the already existing `relu_backward`
                // function in the basic Burn backend trait.
                let grad_output = B::relu_backward(output, grad);

                // Compute the lhs gradient, which is the derivative of matmul with support for
                // broadcasting.
                let grad_lhs = broadcast_shape::<B>(
                    B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),
                    &shape_lhs,
                );
                // Compute the rhs gradient, which is the derivative of matmul with support for
                // broadcasting.
                let grad_rhs = broadcast_shape::<B>(
                    B::float_matmul(B::float_transpose(lhs), grad_output.clone()),
                    &shape_rhs,
                );
                // The add derivative is only 1, so we just need to support broadcasting to
                // compute the bias gradient.
                let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);

                // Register the gradient for each variable based on whether they are marked as
                // `tracked`.
                if let Some(node) = node_bias {
                    grads.register::<B>(node.id, grad_bias);
                }
                if let Some(node) = node_lhs {
                    grads.register::<B>(node.id, grad_lhs);
                }
                if let Some(node) = node_rhs {
                    grads.register::<B>(node.id, grad_rhs);
                }
            }
        }

        // Prepare a stateful operation with each variable node and corresponding graph.
        //
        // Each node can be fetched with `ops.parents` in the same order as defined here.
        match FusedMatmulAddReluBackward
            .prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])
            // Marks the operation as compute bound, meaning it will save its
            // state instead of recomputing itself during checkpointing
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(mut prep) => {
                // When at least one node is tracked, we should register our backward step.

                // The state consists of what will be needed for this operation's backward pass.
                // Since we need the parents' outputs, we must checkpoint their ids to retrieve
                // their node output at the beginning of the backward pass. We can also save
                // auxiliary data such as the bias shape. If we also need this operation's output,
                // we can either save it in the state or recompute it during the backward pass.
                // Here we choose to save it in the state because it's a compute-bound operation.
                let lhs_state = prep.checkpoint(&lhs);
                let rhs_state = prep.checkpoint(&rhs);
                let bias_shape = B::float_shape(&bias.primitive);

                let output = B::fused_matmul_add_relu(
                    lhs.primitive.clone(),
                    rhs.primitive.clone(),
                    bias.primitive,
                );

                let state = (lhs_state, rhs_state, output.clone(), bias_shape);

                prep.finish(state, output)
            }
            OpsKind::UnTracked(prep) => {
                // When no node is tracked, we can just compute the original operation without
                // keeping any state.
                let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);
                prep.finish(output)
            }
        }
    }
}

The previous code is heavily commented to make it clearer, but here is a summary of what it does:

We define fused_matmul_add_relu within Autodiff<B>, allowing any autodiff-decorated backend to benefit from our implementation. In an autodiff-decorated backend, the forward pass must still be implemented. This is achieved using a comprehensive match statement block where computation is delegated to the inner backend, while keeping track of a state. The state comprises any information relevant to the backward pass, such as input and output tensors, along with the bias shape. When an operation isn't tracked (meaning there won't be a backward pass for this specific operation in the graph), storing a state becomes unnecessary, and we simply perform the forward computation.

The backward pass uses the gradient obtained from the preceding node in the computation graph. It calculates the derivatives for relu (relu_backward), add (no operation is required here, as the derivative is one), and matmul (another matmul with transposed inputs). This results in gradients for both input tensors and the bias, which are registered for consumption by subsequent operation nodes.

The only remaining part is to implement our autodiff-decorated backend trait for our WGPU Backend.

impl<F: FloatElement, I: IntElement> AutodiffBackend
    for Autodiff<JitBackend<WgpuRuntime, F, I>>
{
}
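
With this in place, gradients can flow through the custom operation like any other. Below is a minimal sketch of exercising the backward pass through the public API; it assumes the implementations and imports above are in scope, and the shapes are arbitrary.

use burn::backend::Autodiff;
use burn::tensor::{Distribution, Tensor};

fn check_gradients() {
    type B = Autodiff<JitBackend<WgpuRuntime, f32, i32>>;
    let device = Default::default();

    let lhs = Tensor::<B, 3>::random([2, 4, 8], Distribution::Default, &device).require_grad();
    let rhs = Tensor::<B, 3>::random([2, 8, 6], Distribution::Default, &device).require_grad();
    let bias = Tensor::<B, 3>::random([2, 4, 6], Distribution::Default, &device).require_grad();

    // Forward pass through the fused kernel.
    let output = matmul_add_relu_custom(lhs.clone(), rhs.clone(), bias.clone());

    // Backpropagate from a scalar and read back the gradients registered by our backward step.
    let grads = output.sum().backward();
    let _grad_lhs = lhs.grad(&grads).unwrap();
    let _grad_rhs = rhs.grad(&grads).unwrap();
    let _grad_bias = bias.grad(&grads).unwrap();
}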

Conclusion

In this guide, we've implemented a fused kernel using the WGPU backend, enabling execution on any GPU. By delving into the inner workings of both the WGPU backend and the autodiff backend, we've gained a deeper understanding of these systems.

While extending a backend may be harder than working with straightforward tensors, the benefits can be worth it. This approach enables the crafting of custom models with greater control over execution, which can potentially greatly enhance the performance of your models.

As we conclude this guide, we hope that you have gained insights into Burn's world of backend extensions, and that it will help you to unleash the full potential of your projects.

No Standard Library

In this section, you will learn how to run an ONNX inference model on an embedded system with no standard library support, using a Raspberry Pi Pico as the target. The approach should be broadly applicable to other platforms. All the code can be found under the examples directory.

Step-by-Step Guide

Let's walk through the process of running an embedded ONNX model:

Setup

Follow the Embassy guide for your specific environment. Once set up, you should have something similar to the following.

./inference
β”œβ”€β”€ Cargo.lock
β”œβ”€β”€ Cargo.toml
β”œβ”€β”€ build.rs
β”œβ”€β”€ memory.x
└── src
    └── main.rs

Some other dependencies have to be added:

[dependencies]
embedded-alloc = "0.5.1" # Only if there is no default allocator for your chip
burn = { version = "0.15", default-features = false, features = ["ndarray"] } # Backend must be ndarray

[build-dependencies]
burn-import = { version = "0.15" } # Used to auto-generate the Rust code to import the model

Import the Model

Follow the directions to import models.

Use the following ModelGen config:

ModelGen::new()
    .input(my_model)
    .out_dir("model/")
    .record_type(RecordType::Bincode)
    .embed_states(true)
    .run_from_script();
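
For reference, this configuration normally lives in the crate's build.rs. A minimal sketch could look like the following; the ONNX path is a placeholder, and the import path for RecordType is an assumption about the burn-import crate.

// build.rs (sketch)
use burn_import::burn::graph::RecordType; // Import path assumed; adjust if needed.
use burn_import::onnx::ModelGen;

fn main() {
    // Re-run the build script when the model file changes (path is a placeholder).
    println!("cargo:rerun-if-changed=src/model/my_model.onnx");

    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        .record_type(RecordType::Bincode)
        .embed_states(true)
        .run_from_script();
}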

Global Allocator

First, define a global allocator (required if your chip does not provide a default allocator).

use embedded_alloc::Heap;

#[global_allocator]
static HEAP: Heap = Heap::empty();

#[embassy_executor::main]
async fn main(_spawner: Spawner) {
	{
        use core::mem::MaybeUninit;
        const HEAP_SIZE: usize = 100 * 1024; // This is dependent on the model size in memory.
        static mut HEAP_MEM: [MaybeUninit<u8>; HEAP_SIZE] = [MaybeUninit::uninit(); HEAP_SIZE];
        unsafe { HEAP.init(HEAP_MEM.as_ptr() as usize, HEAP_SIZE) }
    }
}

Define Backend

We are using ndarray, so we just need to define the NdArray backend as usual:

use burn::{backend::NdArray, tensor::Tensor};

type Backend = NdArray<f32>;
type BackendDevice = <Backend as burn::tensor::backend::Backend>::Device;

Then inside the main function add:

use your_model::Model;

// Get a default device for the backend
let device = BackendDevice::default();

// Create a new model and load the state
let model: Model<Backend> = Model::default();

Running the Model

To run the model, just call it as you would normally:

// Define the tensor
let input = Tensor::<Backend, 2>::from_floats([[input]], &device);

// Run the model on the input
let output = model.forward(input);

Conclusion

Running a model in a no_std environment is pretty much identical to running one in a standard environment. All that is needed is a global allocator.
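
For reference, the pieces above assemble into a single main.rs along these lines. This is only a sketch: the Embassy setup follows whatever the Embassy guide generated for your chip, your_model is the crate exposing the generated Model, the input value is a placeholder, and the panic handler is a bare-bones stand-in for something like panic-probe.

#![no_std]
#![no_main]

use burn::{backend::NdArray, tensor::Tensor};
use embassy_executor::Spawner;
use embedded_alloc::Heap;
use your_model::Model;

#[global_allocator]
static HEAP: Heap = Heap::empty();

type Backend = NdArray<f32>;
type BackendDevice = <Backend as burn::tensor::backend::Backend>::Device;

#[embassy_executor::main]
async fn main(_spawner: Spawner) {
    // Initialize the heap before any allocation takes place.
    {
        use core::mem::MaybeUninit;
        const HEAP_SIZE: usize = 100 * 1024; // Depends on the model size in memory.
        static mut HEAP_MEM: [MaybeUninit<u8>; HEAP_SIZE] = [MaybeUninit::uninit(); HEAP_SIZE];
        unsafe { HEAP.init(HEAP_MEM.as_ptr() as usize, HEAP_SIZE) }
    }

    let device = BackendDevice::default();
    let model: Model<Backend> = Model::default();

    // `4.8` is an arbitrary placeholder for whatever scalar your model expects.
    let input = Tensor::<Backend, 2>::from_floats([[4.8]], &device);
    let _output = model.forward(input);
}

// A minimal panic handler; most projects would use a dedicated crate instead.
#[panic_handler]
fn panic(_info: &core::panic::PanicInfo) -> ! {
    loop {}
}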