Dataset
At its core, a dataset is a collection of data typically related to a specific analysis or processing task. The data modality can vary depending on the task, but most datasets primarily consist of images, text, audio, or video.
This data source represents an integral part of machine learning to successfully train a model. Thus, it is essential to provide a convenient and performant API to handle your data. Since this process varies wildly from one problem to another, it is defined as a trait that should be implemented on your type. The dataset trait is quite similar to the dataset abstract class in PyTorch:
pub trait Dataset<I>: Send + Sync {
    fn get(&self, index: usize) -> Option<I>;
    fn len(&self) -> usize;
}
The dataset trait assumes a fixed-length set of items that can be randomly accessed in constant time. This is a major difference from datasets that use Apache Arrow underneath to improve streaming performance. Datasets in Burn don't assume how they are going to be accessed; a dataset is just a collection of items.
However, you can compose multiple dataset transformations to lazily obtain what you want with zero pre-processing, so that your training can start instantly!
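To make this concrete, here is a minimal sketch of a hand-rolled implementation of the trait above. The SentenceDataset type is hypothetical and only serves as an illustration; in practice, a vector-backed dataset like this is already covered by InMemDataset (described below).

// A hypothetical dataset backed by a plain vector of strings.
pub struct SentenceDataset {
    items: Vec<String>,
}

impl Dataset<String> for SentenceDataset {
    fn get(&self, index: usize) -> Option<String> {
        // Return a clone of the item, or None when the index is out of bounds.
        self.items.get(index).cloned()
    }

    fn len(&self) -> usize {
        self.items.len()
    }
}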
Transformation
Transformations in Burn are all lazy and modify one or multiple input datasets. The goal of these transformations is to provide you with the necessary tools so that you can model complex data distributions.
| Transformation  | Description                                                                                                               |
|-----------------|---------------------------------------------------------------------------------------------------------------------------|
| SamplerDataset  | Samples items from a dataset. This is a convenient way to model a dataset as a probability distribution of a fixed size. |
| ShuffledDataset | Maps each input index to a random index, similar to a dataset sampled without replacement.                                |
| PartialDataset  | Returns a view of the input dataset with a specified range.                                                               |
| MapperDataset   | Computes a transformation lazily on the input dataset.                                                                    |
| ComposedDataset | Composes multiple datasets together to create a larger one without copying any data.                                      |
| WindowsDataset  | Dataset designed to work with overlapping windows of data extracted from an input dataset.                                |
Let us look at the basic usage of each dataset transform and how they can be composed together. These transforms are lazy by default except when specified otherwise, reducing the need for unnecessary intermediate allocations and improving performance. The full documentation of each transform can be found at the API reference.
- SamplerDataset: This transform can be used to sample items from a dataset with replacement (the default) or without replacement. It is initialized with a sampling size, which can be bigger or smaller than the input dataset size. This is particularly useful in cases where we want to checkpoint larger datasets more often during training and smaller datasets less often, as the size of an epoch is now controlled by the sampling size. Sample usage:
type DbPedia = SqliteDataset<DbPediaItem>;

let dataset: DbPedia = HuggingfaceDatasetLoader::new("dbpedia_14")
    .dataset("train")
    .unwrap();

let dataset = SamplerDataset::<DbPedia, DbPediaItem>::new(dataset, 10000);
- ShuffledDataset: This transform can be used to shuffle the items of a dataset. Particularly useful before splitting the raw dataset into train/test splits. Can be initialized with a seed to ensure reproducibility.
let dataset = ShuffledDataset::<DbPedia, DbPediaItem>::with_seed(dataset, 42);
- PartialDataset: This transform is useful to return a view of the dataset with specified start and end indices. Used to create train/val/test splits. In the example below, we show how to chain ShuffledDataset and PartialDataset to create splits.
// Define the chained dataset type here for brevity.
type PartialData = PartialDataset<ShuffledDataset<DbPedia, DbPediaItem>, DbPediaItem>;

let len = dataset.len();
let split = "train"; // or "test"

let data_split = match split {
    "train" => PartialData::new(dataset, 0, len * 8 / 10), // Get first 80% of the dataset
    "test" => PartialData::new(dataset, len * 8 / 10, len), // Take remaining 20%
    _ => panic!("Invalid split type"), // Handle unexpected split types
};
- MapperDataset: This transform is useful to apply a transformation on each of the items of a dataset. Particularly useful for normalization of image data when channel means are known.
- ComposedDataset: This transform is useful to compose multiple datasets downloaded from multiple sources (say different HuggingfaceDatasetLoader sources) into a single bigger dataset which can be sampled from one source (see the sketch after this list).
- WindowsDataset: This transform is useful to create overlapping windows of a dataset. Particularly useful for sequential data such as time series, for example when working with an LSTM.
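As a brief illustration of how these transforms fit together, the sketch below composes two small in-memory datasets into one larger dataset. The import paths and constructor signatures are assumptions based on the burn-dataset API and should be checked against the API reference.

use burn::data::dataset::{transform::ComposedDataset, Dataset, InMemDataset};

// Two small in-memory datasets standing in for data pulled from different sources.
let part_a = InMemDataset::new(vec![1, 2, 3]);
let part_b = InMemDataset::new(vec![4, 5, 6]);

// Combine them into one larger dataset without copying the underlying items.
let combined = ComposedDataset::new(vec![part_a, part_b]);
assert_eq!(combined.len(), 6);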
Storage
There are multiple dataset storage options available for you to choose from. The choice of storage should be based on the dataset's size as well as its intended purpose.
| Storage          | Description                                                                                                                |
|------------------|------------------------------------------------------------------------------------------------------------------------------|
| InMemDataset     | In-memory dataset that uses a vector to store items. Well-suited for smaller datasets.                                       |
| SqliteDataset    | Dataset that uses SQLite to index items that can be saved in a simple SQL database file. Well-suited for larger datasets.    |
| DataframeDataset | Dataset that uses a Polars dataframe to store and manage data. Well-suited for efficient data manipulation and analysis.     |
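For instance, a small in-memory dataset can be built directly from a vector of items. The Point item type below is purely illustrative; the only assumption is that items can be cloned and shared across threads.

use burn::data::dataset::{Dataset, InMemDataset};

// A hypothetical item type; InMemDataset simply stores the items in a vector.
#[derive(Clone, Debug)]
struct Point {
    x: f32,
    y: f32,
}

let dataset = InMemDataset::new(vec![
    Point { x: 0.0, y: 1.0 },
    Point { x: 2.0, y: 3.0 },
]);

assert_eq!(dataset.len(), 2);
assert!(dataset.get(0).is_some());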
Sources
For now, there are only a couple of dataset sources available with Burn, but more to come!
Hugging Face
You can easily import any Hugging Face dataset with Burn. We use SQLite as the storage to avoid downloading the dataset each time or starting a Python process. You need to know the format of each item in the dataset beforehand. Here's an example with the dbpedia dataset.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DbPediaItem {
    pub title: String,
    pub content: String,
    pub label: usize,
}

fn main() {
    let dataset: SqliteDataset<DbPediaItem> = HuggingfaceDatasetLoader::new("dbpedia_14")
        .dataset("train") // The training split.
        .unwrap();
}
We see that items must derive serde::Serialize, serde::Deserialize, Clone, and Debug, but those
are the only requirements.
Images
ImageFolderDataset is a generic vision dataset used to load images from disk. It is currently
available for multi-class and multi-label classification tasks as well as semantic segmentation
and object detection tasks.
// Create an image classification dataset from the root folder,
// where images for each class are stored in their respective folder.
//
// For example:
// root/dog/dog1.png
// root/dog/dog2.png
// ...
// root/cat/cat1.png
let dataset = ImageFolderDataset::new_classification("path/to/dataset/root").unwrap();
// Create a multi-label image classification dataset from a list of items,
// where each item is a tuple `(image path, labels)`, and a list of classes
// in the dataset.
//
// For example:
let items = vec![
    ("root/dog/dog1.png", vec!["animal".to_string(), "dog".to_string()]),
    ("root/cat/cat1.png", vec!["animal".to_string(), "cat".to_string()]),
];

let dataset = ImageFolderDataset::new_multilabel_classification_with_items(
    items,
    &["animal", "cat", "dog"],
)
.unwrap();
// Create a segmentation mask dataset from a list of items, where each
// item is a tuple `(image path, mask path)` and a list of classes
// corresponding to the integer values in the mask.
let items = vec![
    (
        "path/to/images/image0.png",
        "path/to/annotations/mask0.png",
    ),
    (
        "path/to/images/image1.png",
        "path/to/annotations/mask1.png",
    ),
    (
        "path/to/images/image2.png",
        "path/to/annotations/mask2.png",
    ),
];

let dataset = ImageFolderDataset::new_segmentation_with_items(
    items,
    &[
        "cat",        // 0
        "dog",        // 1
        "background", // 2
    ],
)
.unwrap();
// Create an object detection dataset from a COCO dataset. Currently only
// the import of object detection data (bounding boxes) is supported.
//
// COCO offers separate annotation and image archives for training and
// validation; the paths to the unpacked files need to be passed as parameters:
let dataset = ImageFolderDataset::new_coco_detection(
    "/path/to/coco/instances_train2017.json",
    "/path/to/coco/images/train2017",
)
.unwrap();
Comma-Separated Values (CSV)
Loading records from a simple CSV file in memory is straightforward with the InMemDataset:
// Build dataset from csv with tab ('\t') delimiter.
// The reader can be configured for your particular file.
let mut rdr = csv::ReaderBuilder::new();
let rdr = rdr.delimiter(b'\t');
let dataset = InMemDataset::from_csv("path/to/csv", rdr).unwrap();
Note that this requires the csv crate.
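The record type is inferred from the type you annotate the dataset with. As a sketch, a struct like the following is enough, assuming serde::Deserialize is derived and the field names match the CSV columns (the CsvItem type and its fields are illustrative):

// Hypothetical record type; each CSV row is deserialized into one of these via serde.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
struct CsvItem {
    text: String,
    label: usize,
}

// Same reader configuration as above, with the item type made explicit.
let mut rdr = csv::ReaderBuilder::new();
let rdr = rdr.delimiter(b'\t');
let dataset: InMemDataset<CsvItem> = InMemDataset::from_csv("path/to/csv", rdr).unwrap();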
What about streaming datasets?
There is no streaming dataset API with Burn, and this is by design! The learner struct will iterate
multiple times over the dataset and only checkpoint when done. You can consider the length of the
dataset as the number of iterations before performing checkpointing and running the validation.
There is nothing stopping you from returning different items even when called with the same index
multiple times.
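For example, nothing prevents a Dataset implementation from synthesizing items on the fly. The sketch below generates a fresh random item on every call and uses the reported length purely to control the epoch size; it assumes the rand crate (0.8-style API) is available as a dependency.

use burn::data::dataset::Dataset;
use rand::Rng;

// A dataset that ignores the stored data entirely and produces a new random vector each time.
pub struct RandomNoiseDataset {
    epoch_size: usize,
}

impl Dataset<Vec<f32>> for RandomNoiseDataset {
    fn get(&self, index: usize) -> Option<Vec<f32>> {
        if index >= self.epoch_size {
            return None;
        }
        // A different item is returned even when called with the same index twice.
        let mut rng = rand::thread_rng();
        Some((0..16).map(|_| rng.gen::<f32>()).collect())
    }

    fn len(&self) -> usize {
        self.epoch_size
    }
}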
How Is The Dataset Used?
During training, the dataset is used to access the data samples and, for most use cases in
supervised learning, their corresponding ground-truth labels. Remember that the Dataset trait
implementation is responsible for retrieving the data from its source, usually some sort of data
storage. At this point, the dataset could be naively iterated over to provide the model a single
sample to process at a time, but this is not very efficient.
Instead, we collect multiple samples that the model can process as a batch to fully leverage
modern hardware (e.g., GPUs, which have impressive parallel processing capabilities). Since each
data sample in the dataset can be collected independently, the data loading is typically done in
parallel to further speed things up. In this case, we parallelize the data loading using a
multi-threaded BatchDataLoader to obtain a sequence of items from the Dataset implementation.
Finally, the sequence of items is combined into a batched tensor that can be used as input to a
model with the Batcher trait implementation. Other tensor operations can be performed during this
step to prepare the batch data, as is done in the basic workflow guide.
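As a rough sketch, and assuming a batcher value constructed as in the basic workflow guide, building such a data loader typically looks like the following (the batch size, seed, and worker count are illustrative):

use burn::data::dataloader::DataLoaderBuilder;

// `batcher` and `MnistDataset::train()` come from the basic workflow guide.
// Build a multi-threaded data loader that shuffles the dataset each epoch
// and yields batches of 64 items produced by the batcher.
let dataloader_train = DataLoaderBuilder::new(batcher)
    .batch_size(64)
    .shuffle(42)
    .num_workers(4)
    .build(MnistDataset::train());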
The process is illustrated in the figure below for the MNIST dataset.

Although we have conveniently implemented the MnistDataset used in the guide, we'll go over its
implementation to demonstrate how the Dataset and Batcher traits are used.
The MNIST dataset of handwritten digits has a training set of 60,000 examples and a test set of
10,000 examples. A single item in the dataset is represented by a 28×28 pixels black-and-white
image (stored as raw bytes) with its corresponding label (a digit between 0 and 9). This is
defined by the MnistItemRaw struct.
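A minimal sketch of such a struct might look like the following (the field names are illustrative; the actual definition lives in the burn-dataset source):

#[derive(serde::Deserialize, Debug, Clone)]
struct MnistItemRaw {
    // Raw pixel bytes of the 28×28 image.
    pub image_bytes: Vec<u8>,
    // The digit label, between 0 and 9.
    pub label: u8,
}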
With single-channel images of such low resolution, the entire training and test sets can be loaded
in memory at once. Therefore, we leverage the already existing InMemDataset to retrieve the raw
images and labels data. At this point, the image data is still just a bunch of bytes, but we want
to retrieve the structured image data in its intended form. For that, we can define a
MapperDataset that transforms the raw image bytes to a 2D array image (which we convert to float
while we're at it).
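Sketched out, the mapper could look something like this. The constants, the MnistItem definition, and the BytesToImage name are illustrative, and the snippet assumes the Mapper trait used by MapperDataset exposes a map(&self, item: &I) -> O method; the exact import path may differ slightly across Burn versions.

use burn::data::dataset::transform::Mapper;

const WIDTH: usize = 28;
const HEIGHT: usize = 28;

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
pub struct MnistItem {
    pub image: [[f32; WIDTH]; HEIGHT],
    pub label: u8,
}

struct BytesToImage;

impl Mapper<MnistItemRaw, MnistItem> for BytesToImage {
    /// Convert a raw MNIST item (image bytes) to a structured MNIST item (2D float image).
    fn map(&self, item: &MnistItemRaw) -> MnistItem {
        // Ensure the image dimensions are correct.
        debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);

        // Convert the raw bytes to a 2D array of floats.
        let mut image_array = [[0f32; WIDTH]; HEIGHT];
        for (i, pixel) in item.image_bytes.iter().enumerate() {
            let x = i % WIDTH;
            let y = i / WIDTH;
            image_array[y][x] = *pixel as f32;
        }

        MnistItem {
            image: image_array,
            label: item.label,
        }
    }
}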
To construct the MnistDataset, the data source must be parsed into the expected MappedDataset
type. Since both the train and test sets use the same file format, we can separate the
functionality to load the train() and test() datasets.
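A sketch of how these pieces could fit together is shown below. The MappedDataset alias wraps the in-memory raw items with the byte-to-image mapper; read_split is a hypothetical helper that downloads and parses the raw files of the given split, and the exact generic parameters of MapperDataset should be checked against the API reference.

use burn::data::dataset::{transform::MapperDataset, InMemDataset};

type MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;

pub struct MnistDataset {
    dataset: MappedDataset,
}

impl MnistDataset {
    /// Creates the training dataset.
    pub fn train() -> Self {
        Self::new("train")
    }

    /// Creates the test dataset.
    pub fn test() -> Self {
        Self::new("test")
    }

    fn new(split: &str) -> Self {
        // `read_split` is a hypothetical helper that parses the raw MNIST
        // files for the given split into a vector of raw items.
        let raw_items: Vec<MnistItemRaw> = read_split(split);

        // Keep the raw items in memory and lazily map bytes to images.
        let dataset = InMemDataset::new(raw_items);
        let dataset = MapperDataset::new(dataset, BytesToImage);

        Self { dataset }
    }
}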
Since the MnistDataset simply wraps a MapperDataset instance with InMemDataset, we can easily
implement the Dataset trait.
impl Dataset<MnistItem> for MnistDataset {
    fn get(&self, index: usize) -> Option<MnistItem> {
        self.dataset.get(index)
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}
The only thing missing now is the Batcher, which we already went over in the basic workflow
guide. The Batcher takes a list of MnistItem retrieved by the dataloader as input and returns a
batch of images as a 3D tensor along with their targets.