Data

Typically, a model is trained on some dataset. Burn provides a library of very useful dataset sources and transformations, such as Hugging Face dataset utilities that allow you to download and store data into an SQLite database for extremely efficient data streaming and storage. For this guide though, we will use the MNIST dataset from burn::data::dataset::vision, which requires no external dependency.
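
If you want a quick look at the data before building the batcher, the following minimal sketch loads and inspects it; it relies only on MnistDataset::train() and the Dataset trait's len and get methods (the function name is just for illustration, not part of the guide's code).

use burn::data::dataset::{vision::MnistDataset, Dataset};

fn inspect_mnist() {
    // Load the MNIST training split (downloaded automatically on first use).
    let dataset = MnistDataset::train();

    // The `Dataset` trait provides the number of items and random access by index.
    println!("Number of training items: {}", dataset.len());
    if let Some(item) = dataset.get(0) {
        // Each `MnistItem` holds a 28x28 image and its digit label.
        println!("First label: {}", item.label);
    }
}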

To iterate over a dataset efficiently, we will define a struct that implements the Batcher trait. The goal of a batcher is to map individual dataset items into a batched tensor that can be used as input to our previously defined model.
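
For orientation, the trait we are about to implement has roughly the following shape (a simplified sketch with supertrait bounds omitted; see the burn::data::dataloader::batcher module for the exact definition).

// Simplified sketch: a batcher turns a list of dataset items `I`
// into a single batched output `O`, typically a struct of tensors.
pub trait Batcher<I, O> {
    fn batch(&self, items: Vec<I>) -> O;
}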

Let us start by defining our dataset functionalities in a file src/data.rs. We shall omit some of the imports for brevity, but the full code for following this guide can be found in the examples/guide/ directory.

use burn::{
    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    prelude::*,
};

#[derive(Clone)]
pub struct MnistBatcher<B: Backend> {
    device: B::Device,
}

impl<B: Backend> MnistBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self { device }
    }
}

This code block defines a batcher struct holding the device to which tensors should be sent before being passed to the model. Note that the device is an associated type of the Backend trait, since not all backends expose the same devices. As an example, the LibTorch-based backend exposes Cuda(gpu_index), Cpu, Vulkan and Metal devices, while the ndarray backend only exposes the Cpu device.
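
As a rough illustration of that difference, here is how a device might be chosen for two different backends. This assumes the corresponding backend features are enabled, and the exact type paths may vary between Burn versions.

// ndarray backend: only a CPU device is available.
let cpu_device = burn::backend::ndarray::NdArrayDevice::Cpu;

// LibTorch backend: a specific GPU can be selected by index instead.
let cuda_device = burn::backend::libtorch::LibTorchDevice::Cuda(0);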

Next, we need to actually implement the batching logic.

#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
    pub images: Tensor<B, 3>,
    pub targets: Tensor<B, 1, Int>,
}

impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
        let images = items
            .iter()
            .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
            .map(|data| Tensor::<B, 2>::from_data(data, &self.device))
            .map(|tensor| tensor.reshape([1, 28, 28]))
            // Normalize: scale to [0,1], then standardize so that mean=0 and std=1
            // values mean=0.1307,std=0.3081 are from the PyTorch MNIST example
            // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122
            .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
            .collect();

        let targets = items
            .iter()
            .map(|item| {
                Tensor::<B, 1, Int>::from_data(
                    [(item.label as i64).elem::<B::IntElem>()],
                    &self.device,
                )
            })
            .collect();

        let images = Tensor::cat(images, 0).to_device(&self.device);
        let targets = Tensor::cat(targets, 0).to_device(&self.device);

        MnistBatch { images, targets }
    }
}
🦀 Iterators and Closures

The iterator pattern allows you to perform some task on a sequence of items in turn.

In this example, an iterator is created over the MnistItems in the vector items by calling the iter method.

Iterator adaptors are methods defined on the Iterator trait that produce different iterators by changing some aspect of the original iterator. Here, the map method is called in a chain to transform the original data before consuming the final iterator with collect to obtain the images and targets vectors. Both vectors are then concatenated into a single tensor for the current batch.

You probably noticed that each call to map is different, as it defines a function to execute on the iterator items at each step. These anonymous functions are called closures in Rust. They're easy to recognize due to their syntax, which uses vertical bars ||. The vertical bars enclose the input parameters (if any), while the rest of the expression defines the function to execute.
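
For instance, outside of Burn entirely, the same map/collect pattern with a simple closure looks like this:

let numbers = vec![1, 2, 3];
// `|x| x * 2` is a closure: `x` is the input parameter and `x * 2` is the body.
let doubled: Vec<i32> = numbers.iter().map(|x| x * 2).collect();
assert_eq!(doubled, vec![2, 4, 6]);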

If we go back to the example, we can break down and comment the expression used to process the images.

let images = items                                                       // take items Vec<MnistItem>
    .iter()                                                              // create an iterator over it
    .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())  // for each item, convert the image to a float data struct
    .map(|data| Tensor::<B, 2>::from_data(data, &self.device))           // for each data struct, create a tensor on the device
    .map(|tensor| tensor.reshape([1, 28, 28]))                           // for each tensor, reshape to the image dimensions [C, H, W]
    .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)                    // for each image tensor, apply normalization
    .collect();                                                          // consume the resulting iterator & collect the values into a new vector

For more information on iterators and closures, be sure to check out the corresponding chapter in the Rust Book.


In the previous example, we implement the Batcher trait with a list of MnistItem as input and a single MnistBatch as output. The batch contains the images in the form of a 3D tensor, along with a targets tensor that contains the indices of the correct digit classes. The first step is to parse the image array into a TensorData struct. Burn provides the TensorData struct to encapsulate tensor storage information without being specific to any backend. When creating a tensor from data, we often need to convert the data precision to the current backend in use. This can be done with the .convert() method (in this example, the data is converted to the backend's float element type B::FloatElem). By importing the burn::tensor::ElementConversion trait, you can also call .elem() on a specific number to convert it to the element type of the backend in use.
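
To make these conversion utilities concrete, here is a small standalone sketch (not part of the guide's files; the function name is only illustrative) that uses TensorData::from, .convert() and .elem() the same way the batcher does.

use burn::prelude::*;
use burn::tensor::ElementConversion;

fn conversion_example<B: Backend>(device: &B::Device) {
    // Wrap a plain Rust array in `TensorData`, then convert it to the backend's float element type.
    let data = TensorData::from([[0.0f32, 0.5], [1.0, 1.5]]).convert::<B::FloatElem>();
    let _images = Tensor::<B, 2>::from_data(data, device);

    // `.elem()` converts a single number to the backend's int element type.
    let label = 7i64.elem::<B::IntElem>();
    let _targets = Tensor::<B, 1, Int>::from_data([label], device);
}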