Burn 0.19.0 Release: Quantization, Distributed Training, and LLVM Backend

Flame digital art generated by stable diffusion.

This release post [1] covers the major improvements we made to enable efficient distributed training, quantization, and CPU support in Burn. To achieve true multi-GPU parallelism, we had to rethink several core systems: we implemented multi-stream execution to keep all GPUs busy, optimized device transfers to avoid unnecessary synchronization, and redesigned our locking strategies to eliminate bottlenecks in autotuning, fusion, and autodiff. We also introduced burn-collective [2] for gradient synchronization and refactored our training loop to support different distributed training strategies. Additionally, we added comprehensive quantization support, allowing models to use significantly less memory while maintaining performance through fused dequantization and optimized quantized operations. Finally, we introduced a new CPU backend powered by MLIR [3] and LLVM [4], bringing the same JIT compilation, autotuning, and fusion capabilities from our GPU backends to CPU execution.

Distributed Training

To enable distributed training, we reworked several components to better maximize GPU utilization across all devices.

Multi Stream

Let's start by defining what a stream is exactly. In CUDA, a stream represents a GPU execution queue where kernels are guaranteed to be executed in the order they are submitted [5]. Multiple streams can be created on a single GPU to enable concurrent kernel execution. These streams abstract the underlying GPU queue and the kernel scheduling logic. We don't have such an abstraction with the WGPU runtime, so we implemented our own scheduler, interleaving kernels from different streams to mimic the streaming behavior of CUDA and ROCm.

CubeCL Stream

The first step was to define our desired API for declaring new streams. We wanted to expose concurrent behavior at the GPU level in a way that feels natural in Rust: using threads.

CubeCL streams are created when tensor operations are executed on a new thread. Each stream maintains its own memory pools, because operation order is only guaranteed within a single stream, and reusing memory with deferred execution is only possible when ordering is guaranteed.

CubeCL streams can use tensors created from other streams without any issue. When that happens, we ensure proper ordering using events on CUDA/HIP and by manually flushing the deferred kernels on WGPU, such that users don't have to worry about data availability and manual synchronization between different streams. Here's an example of how it can be used in practice:

fn multi_streams<B: Backend>(tensor_1: Tensor<B, 2>, tensor_2: Tensor<B, 2>) -> Tensor<B, 2> {
    // Create stream_1: spawning a new thread automatically creates a new CubeCL stream.
    let stream_1 = std::thread::spawn(move || tensor_1.clone().matmul(tensor_1.log()));

    // Create stream_2: this runs concurrently with stream_1 on the same GPU.
    let stream_2 = std::thread::spawn(move || tensor_2.clone().matmul(tensor_2.exp()));

    // Synchronize and retrieve tensors from streams.
    // Note: due to lazy execution, the thread joins immediately after queuing the operations,
    // not after the GPU completes them.
    let tensor_3 = stream_1.join().unwrap();
    let tensor_4 = stream_2.join().unwrap();

    // Cross-stream operation: CubeCL detects tensors from different streams
    // and automatically ensures proper ordering with events (CUDA/HIP) or manual flushing (WGPU).
    tensor_3.matmul(tensor_4)
}

Note 1: Right now, streams are coupled to native threads, but we might at some point create an explicit way to declare new streams with configurable runtimes, such as Tokio [6], where each Tokio task spawns a new stream.

Note 2: Spawning threads doesn't always create a new stream. Since each stream has its own memory pool and is never deleted, creating streams is expensive. To limit the number of concurrent streams per device, we provide a configurable maximum.

# Configuration file 'cubecl.toml' at the root of a project
[streaming]
max_streams = 4
logger = { level = "disabled" }
  

Lazy Device Transfer

An important optimization was to avoid synchronizing the GPU queue when transferring data between devices. In other words, we needed to avoid calling into_data, which synchronizes the GPU queue. To make this possible, we pushed the device transfer logic into the CubeCL runtime, where each runtime implementation can perform data transfers as efficiently as possible using deferred synchronization instead of immediate synchronization.


Old Data Transfer Using into_data
fn data_transfer_sync<B: Backend>(tensor: Tensor<B, 2>, device: &B::Device) -> Tensor<B, 2> {
    // This syncs the GPU queue and may reduce GPU utilization.
    let data = tensor.into_data();
    Tensor::from_data(data, device)
}
  
New Data Transfer Using the CubeCL Compute Client
pub fn data_transfer_async(&self, client: ComputeClient<R::Server>, device: R::Device) -> Self {
  // This only changes the compute client for the current device with a copy description
  // in CubeCL without syncing the GPU queue.
  let desc = self
      .handle
      .copy_descriptor(&self.shape.dims, &self.strides, self.elem_size());
  let alloc = self.client.to_client_tensor(desc, &client);
  
  Self {
      client,
      handle: alloc.handle,
      shape: self.shape.clone(),
      device,
      strides: alloc.strides,
      dtype: self.dtype,
      qparams: self.qparams.clone(),
  }
}
  
Note that you don't need to change your code to benefit from these optimizations. The function Tensor::to_device dispatches to the right procedure based on the backend.
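
For reference, usage stays exactly as before; here's a minimal sketch that moves a tensor to a second device and lets the backend pick the lazy transfer path:

use burn::prelude::*;

fn move_to<B: Backend>(tensor: Tensor<B, 2>, target: &B::Device) -> Tensor<B, 2> {
    // Same high-level API as before; on CubeCL backends this now defers the copy
    // instead of synchronizing the GPU queue.
    tensor.to_device(target)
}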

Pinned Memory

Many GPUs don't have direct connections between them and require transferring data through the CPU. This process can be sped up significantly using pinned memory. For those unfamiliar with pinned memory, it's an allocation strategy that returns page-locked host memory the GPU can access directly for transfers [7]. These allocations are normally expensive, since the OS has to find and lock physical pages for the lifetime of the allocation. Therefore, we leverage our existing memory pools, originally used to manage GPU memory, to also manage CPU pinned memory.
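
As a rough illustration of the pooling idea (a simplified sketch, not Burn's actual allocator; PinnedBuffer stands in for a driver-level page-locked allocation), reusing buffers by size amortizes the expensive allocation across many transfers:

use std::collections::HashMap;

/// Hypothetical page-locked staging buffer; in practice this wraps a
/// driver-level pinned host allocation.
struct PinnedBuffer {
    bytes: usize,
}

/// Illustrative pool: buckets of pinned buffers keyed by size, so a staging buffer
/// is allocated once and reused for every CPU<->GPU hop of that size.
#[derive(Default)]
struct PinnedPool {
    free: HashMap<usize, Vec<PinnedBuffer>>,
}

impl PinnedPool {
    fn acquire(&mut self, bytes: usize) -> PinnedBuffer {
        self.free
            .get_mut(&bytes)
            .and_then(|bucket| bucket.pop())
            // Only pay for the expensive page-locked allocation on a pool miss.
            .unwrap_or(PinnedBuffer { bytes })
    }

    fn release(&mut self, buffer: PinnedBuffer) {
        self.free.entry(buffer.bytes).or_default().push(buffer);
    }
}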

State Management: Locking Strategies

As we improved GPU utilization, we discovered that our locking mechanisms were becoming bottlenecks, preventing true parallel execution across devices. We had to carefully redesign our locking strategies to eliminate contention and enable concurrent operations.

Concurrent Autotuning

One of the earliest problems we encountered was that only one GPU was active at a time. One reason was that a single autotune state lock was shared across all devices, so we had to create fine-grained locks to enable autotuning on multiple devices simultaneously.

Burn Fusion Lock

Another problem we encountered was a deadlock when fusion was activated. The issue stemmed from recursive locking during autotune when profiling was active. The root cause was our locking strategy, and to resolve most of these issues we introduced a global device lock, shared between the fusion server and the CubeCL runtime server. Interior mutability ensures that only a single mutable reference to the data structure can exist at any given time, while a reentrant mutex allows recursive locking from the same thread. We use a popular library for this [8], and have an alternative implementation with fine-grained locking for no-std environments.
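
As a simplified sketch of the pattern (not Burn's actual server types), a reentrant mutex from parking_lot wrapped around a RefCell lets the same thread re-enter the critical section, which is exactly what happens when fusion triggers autotuning that calls back into the runtime:

use parking_lot::ReentrantMutex;
use std::cell::RefCell;

/// Sketch of per-device state guarded by one global device lock.
struct DeviceServer {
    state: ReentrantMutex<RefCell<Vec<&'static str>>>,
}

impl DeviceServer {
    fn execute(&self, op: &'static str) {
        let guard = self.state.lock();
        guard.borrow_mut().push(op);
    }

    fn execute_with_autotune(&self) {
        // The fusion path takes the device lock...
        let _guard = self.state.lock();
        // ...and autotune profiling re-enters it on the same thread.
        // A plain Mutex would deadlock here; the reentrant mutex does not.
        self.execute("profiled-kernel");
    }
}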

Autodiff Graph Lock

At that point, we still had problems keeping all GPUs busy, and the reason was the autodiff server lock. The tricky thing with the autodiff graph—which contains the state of all intermediary tensors necessary for the backward pass—is that it can't share the same locking strategy as the fusion server lock. The reason is that a single autodiff graph can span multiple devices. The solution was to implement fine-grained graph locks with a graph locator client. Previously, all nodes within all graphs were kept in the same autodiff server, but now we spawn different servers for different graphs, which enables multiple graphs to be computed in parallel. Note that a single autodiff server still supports multiple graphs, where nodes are shared between them. However, those kinds of graphs will share the same lock, since their state is shared.
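
Conceptually, the locator hands out one lock per graph instead of a single server-wide lock; a minimal sketch (with hypothetical types, not the real burn-autodiff API) looks like this:

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// Hypothetical per-graph state kept for the backward pass.
#[derive(Default)]
struct GraphServer {
    nodes: Vec<u64>,
}

/// Illustrative locator: one lock per graph, so backward passes on
/// independent graphs never contend on the same mutex.
#[derive(Default)]
struct GraphLocator {
    servers: Mutex<HashMap<u64, Arc<Mutex<GraphServer>>>>,
}

impl GraphLocator {
    fn server(&self, graph_id: u64) -> Arc<Mutex<GraphServer>> {
        let mut servers = self.servers.lock().unwrap();
        servers
            .entry(graph_id)
            .or_insert_with(|| Arc::new(Mutex::new(GraphServer::default())))
            .clone()
    }
}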

Burn Collective

Now we were able to keep all GPUs busy almost all the time with the naive data-parallel training strategy. That strategy involves having a single device perform gradient updates on the model and synchronizing the model with the other devices at each iteration. But we also wanted to support the classic data-parallel training strategy where the synchronization is on the gradients, and each device performs the optimization of its own version of the model. So we created burn-collective to expose an all-reduce function that can perform gradient synchronization. Note that the crate burn-collective is young, and the performance is still not satisfactory, but we will improve it over time.
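
To make the distinction concrete, this is what an all-reduce does to gradients, written as a plain sketch over flat per-replica buffers rather than with the burn-collective API:

/// Illustrative all-reduce with mean: every replica contributes its local gradients
/// and every replica ends up with the same averaged result (all buffers are assumed
/// to have the same length).
fn all_reduce_mean(replica_grads: &mut [Vec<f32>]) {
    let replicas = replica_grads.len() as f32;
    let len = replica_grads[0].len();
    for i in 0..len {
        let sum: f32 = replica_grads.iter().map(|grads| grads[i]).sum();
        for grads in replica_grads.iter_mut() {
            grads[i] = sum / replicas;
        }
    }
}

After the reduction, every device holds identical gradients and can apply the same optimizer step to its own copy of the model, which is the classic DDP setup.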

Burn Train

We refactored the provided training loop in Burn to integrate with the newer distributed training strategies. Instead of registering a list of device(s), you now specify a training strategy. That training strategy receives the device as an argument. In the future we want to extend this functionality to allow custom training strategies, which would enable custom training loops with minimal boilerplate while still keeping all of the metric logging features, CLI dashboard, checkpointing, and so on. Here's how the learning strategy is registered:

Learning Strategy: Single Device
let learner = LearnerBuilder::new(artifact_dir)
    .learning_strategy(LearningStrategy::SingleDevice(device))
    ...
    .build(model, optim, lr_scheduler);
  

Learning Strategy: Multi Device
let learner = LearnerBuilder::new(artifact_dir)
    .learning_strategy(LearningStrategy::MultiDeviceNaive(devices))
    ...
    .build(model, optim, lr_scheduler);
  

Learning Strategy: Collective DDP [9]
let collective_config = CollectiveConfig::default();
let learner = LearnerBuilder::new(artifact_dir)
    .learning_strategy(burn::train::ddp(devices, collective_config))
    ...
    .build(model, optim, lr_scheduler);
  

Note: The collective config also supports distributed computing over the network using WebSockets, but the performance is still unknown.

Quantization

The goal of quantization is to reduce global memory usage by compressing floating-point tensors. There are many different ways to perform quantization, and we wanted to support most of them by design.

Quantization Scheme

The key concept behind quantization is transforming a list of floats into a list of integers, where we can multiply the integer values by a scaling factor to retrieve the original floating-point values. How we handle that logic is defined by the quantization scheme.

/// Describes a quantization scheme/configuration.
pub struct QuantScheme {
    /// The logical data type of quantized input values (e.g., QInt8).
    ///
    /// This defines how values are interpreted during computation, independent of how they're stored.
    pub value: QuantValue,
    /// Precision used for quantization parameters (e.g., scale).
    pub param: QuantParam,
    /// Data type used for storing quantized values.
    pub store: QuantStore,
    /// Granularity level of quantization (e.g., per-tensor).
    pub level: QuantLevel,
    /// Quantization mode (e.g., symmetric).
    pub mode: QuantMode,
}
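
As a concrete example of what the scheme controls, here is the round trip for a symmetric, per-tensor, 8-bit scheme, written as a simplified sketch of the math rather than Burn's actual kernels:

/// Symmetric per-tensor int8 quantization: a single scale for the whole tensor,
/// chosen so that the largest magnitude maps to 127.
fn quantize_symmetric(values: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = values.iter().fold(0.0f32, |acc, v| acc.max(v.abs()));
    let scale = (max_abs / 127.0).max(f32::MIN_POSITIVE);
    let quantized = values
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (quantized, scale)
}

/// Dequantization recovers an approximation of the original floats.
fn dequantize_symmetric(quantized: &[i8], scale: f32) -> Vec<f32> {
    quantized.iter().map(|&q| q as f32 * scale).collect()
}

Per-block or per-channel granularity applies the same logic with one scale per block or channel instead of a single scale for the whole tensor.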

Quantized Module

Now this is great, but how can we use it to compress a whole Burn module? Well, we created an API in Burn to do exactly that. To quantize the weights of a module, we need to first define the calibration strategy. This is how we calculate the scaling factors described earlier. It is common practice to use the min-max strategy.

fn quantize_q4_block_32<B: Backend, M: Module<B>>(module: M) -> M {
    let calibration = Calibration::MinMax;
    let scheme = QuantScheme {
        level: QuantLevel::Block(BlockSize::new([32])),
        value: QuantValue::Q4F,
        ..Default::default()
    };
    let mut quantizer = Quantizer {
        calibration,
        scheme,
    };
    module.quantize_weights(&mut quantizer)
}

Fused Dequantization

When a module is quantized, you can use it as if it were in full precision, and by leveraging the quantization propagation setting, we can use a full-precision signal with the quantized weights. The tensor type for float tensors is actually an enum:

#[derive(Debug, Clone)]
/// A primitive tensor representation.
pub enum TensorPrimitive<B: Backend> {
    /// Float tensor primitive.
    Float(B::FloatTensorPrimitive),
    /// Quantized float tensor primitive.
    QFloat(B::QuantizedTensorPrimitive),
}
  

Most tensor operations are not defined on the quantized tensor, and therefore require a call to dequantize:

impl<B: Backend> TensorPrimitive<B> {
    /// Returns the full tensor representation.
    pub fn tensor(self) -> B::FloatTensorPrimitive {
        match self {
            Self::QFloat(tensor) => B::dequantize(tensor),
            Self::Float(tensor) => tensor,
        }
    }
} 

However, using the fusion backend, we can fuse the dequantization kernel with pretty much anything, never actually representing the floating-point tensor in global memory and reducing global memory reads and writes.
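
To see why this matters, compare an unfused and a fused dequantize + ReLU in plain Rust; this is only an illustration of the idea, not the generated kernel:

/// Unfused: materializes the full-precision tensor in memory between the two steps.
fn dequant_then_relu(quantized: &[i8], scale: f32) -> Vec<f32> {
    let floats: Vec<f32> = quantized.iter().map(|&q| q as f32 * scale).collect();
    floats.into_iter().map(|v| v.max(0.0)).collect()
}

/// Fused: dequantization happens in registers inside the same loop, so the
/// intermediate float tensor never exists in global memory.
fn fused_dequant_relu(quantized: &[i8], scale: f32) -> Vec<f32> {
    quantized.iter().map(|&q| (q as f32 * scale).max(0.0)).collect()
}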

Quantized Matmul

One of the few exceptions to fused dequantization is complex, compute-bound operations such as matmul. We adapted our kernel engine to handle quantized inputs for maximum performance. We still need to apply the same optimizations to other compute-bound kernels such as convolution and pooling operations, which will come in upcoming releases.
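
The core idea behind a quantized matmul kernel is to accumulate integer products and apply the scales once per output element. A minimal sketch for symmetric per-tensor int8 inputs (not the actual CubeCL kernel):

/// C = A (m x k) * B (k x n) with int8 inputs, i32 accumulation,
/// and the combined scale applied once per output value.
fn matmul_q8(
    a: &[i8],
    b: &[i8],
    (m, k, n): (usize, usize, usize),
    scale_a: f32,
    scale_b: f32,
) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for row in 0..m {
        for col in 0..n {
            let mut acc: i32 = 0;
            for i in 0..k {
                acc += a[row * k + i] as i32 * b[i * n + col] as i32;
            }
            c[row * n + col] = acc as f32 * scale_a * scale_b;
        }
    }
    c
}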

LLVM Backend

We developed a new CPU backend that reuses the same compute infrastructure as our GPU backends based on CubeCL, enabling true multi-platform execution [10]. The CPU backend leverages portable SIMD instructions, autotuning, JIT compilation, and dynamic graph fusion—all the same optimizations that make our GPU backends efficient.

We achieved this by adding a new compiler and runtime to CubeCL. We integrated an MLIR compiler using the vector dialect, which automatically targets SIMD instructions when lowering the representation to LLVM and compiling to binary.

However, CPU architecture differs fundamentally from GPUs in how work is divided and executed. The divide-and-conquer logic that works well on GPUs doesn't translate directly to CPUs. We addressed this by setting the plane size to 1, since CPUs don't have the concept of warps like GPUs do. As a reminder, a plane in CubeCL corresponds to a warp in CUDA, a wavefront in HIP, a simdgroup in Metal, and a subgroup in Vulkan. We could have chosen to set the SIMD lane width as the plane size, but that would introduce significant complexity in handling synchronization and divergent computation within a single plane. Instead, we chose simplicity and ease of understanding by using the Line type to represent a SIMD lane rather than a plane. The trade-off is that we will need to update our existing kernels and create new ones specifically optimized for CPU execution patterns.
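
As a rough picture of what that means in practice (illustrative only, not CubeCL code): each unit behaves like a scalar thread, and vectorization comes from processing a fixed-width line of elements per step, which the MLIR vector dialect can lower to SIMD instructions.

/// A fixed line width, playing the role of CubeCL's Line type in this sketch.
const LINE: usize = 8;

/// One unit of work scaling a buffer line by line; the fixed-width inner loop is
/// what gets mapped onto SIMD registers (assumes the length is a multiple of LINE).
fn scale_by_lines(input: &[f32], output: &mut [f32], factor: f32) {
    for (src, dst) in input
        .chunks_exact(LINE)
        .zip(output.chunks_exact_mut(LINE))
    {
        for lane in 0..LINE {
            dst[lane] = src[lane] * factor;
        }
    }
}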

The CPU backend is available in this release, but don't expect strong performance yet—we still need to adapt our kernels for CPU-specific optimizations.

Other Improvements

Persistent Memory

In our effort to reduce the global memory footprint of models, we introduced the concept of persistent memory. All CubeCL backends use a memory pool to minimize costly allocations and deallocations on GPUs. Each memory pool features multiple page sizes, determined by the maximum memory available on a device. However, during inference and training, it's common to see allocated memory that remains unused by a model, often due to padding. Padding can account for up to 50% of total allocated memory, and in some cases even 70–80%.

One potential optimization to reduce memory usage is virtual memory with smaller blocks, but the API for creating a virtual memory table is unavailable on WGPU, limiting portability. Additionally, virtual memory introduces data fragmentation issues. A simpler improvement involves defining memory pools tailored to the current model.

This is where persistent memory comes in. The concept is straightforward: some tensors maintain a static shape during computation, allowing them to be allocated in a memory page that matches their exact size, eliminating padding. In most models, parameters follow this logic, so Burn now allocates model parameters with persistent memory by default. However, you can easily disable persistent memory if it doesn't optimize memory usage for your workflow.

# Configuration file 'cubecl.toml' at the root of a project
[streaming]
persistent_memory = "disabled"

You can also enforce persistent memory for all allocations. This approach carries risks, as it may cause memory leaks, but if your workflow avoids dynamic-shape tensors, it could potentially eliminate padding entirely.

# Configuration file 'cubecl.toml' at the root of a project
[streaming]
persistent_memory = "enforced"

You can also enable persistent memory using the new backend function, which allows fine-grained control over allocations:

fn persistent_memory<B: Backend>(distribution: Distribution, device: &B::Device) -> Tensor<B, 2> {
    let (x, y) = B::memory_persistent_allocations(device, (), |_| {
        // Only x,y will be allocated using persistent memory.
        let x = Tensor::random([32, 32], distribution, device);
        let y = Tensor::random([32, 32], distribution, device);
        (x, y)
    });
    x.matmul(y)
}

Evaluator

The burn-train training loop previously lacked a proper evaluation loop. We aimed to reuse as much of the logging code from the Learner struct as possible, but tailored for evaluation. To achieve this, we modified the Learner's return type to include all potential states the training loop might hold. This allows us to transfer the renderer to the evaluator, which can display metrics collected during evaluation alongside training metrics in the CLI dashboard. The API is simple to use and follows the same conventions as the Learner types. Here's a small example:

// You can create multiple evaluation loops using the same renderer.
// Super useful if you want to test on multiple datasets.
fn evaluate<B: Backend>(
    name: &str,
    model: Model<B>,
    renderer: Box<dyn MetricsRenderer>,
    dataset_test: impl Dataset<MnistItem> + 'static,
    batch_size: usize,
) -> Box<dyn MetricsRenderer> {
    let batcher = MnistBatcher::default();
    let dataloader_test = DataLoaderBuilder::new(batcher)
        .batch_size(batch_size)
        .num_workers(2)
        .build(dataset_test);

    let evaluator = EvaluatorBuilder::new(ARTIFACT_DIR)
        .renderer(renderer)
        .metrics((AccuracyMetric::new(), LossMetric::new()))
        .build(model);

    evaluator.eval(name, dataloader_test)
}

fn main() {
  ...
  let learner = LearnerBuilder::new(ARTIFACT_DIR)
    .metrics((AccuracyMetric::new(), LossMetric::new()))
    .with_file_checkpointer(CompactRecorder::new())
    .num_epochs(config.num_epochs)
    .learning_strategy(burn::train::LearningStrategy::SingleDevice(device))
    .build(model, config.optimizer.init(), lr_scheduler.init().unwrap());
    
  let result = learner.fit(dataloader_train, dataloader_valid);
  let (renderer, model) = (result.renderer, result.model);

  let renderer = evaluate(
    "MNIST Official Test Set",
    model.clone(),
    renderer,
    MnistDataset::test(),
    batch_size,
  );
  let renderer = evaluate(
    "MNIST Custom Test Set",
    model,
    renderer,
    MnistDataset::custom("custom/annotations"),
    batch_size,
  );

  // Wait for user input before closing the renderer, allowing the user to analyze the metrics.
  renderer.manual_close();
}
  

Burn Storage

This release introduces burn-store, a new crate providing advanced model storage and serialization capabilities for Burn. It implements comprehensive safetensors support with cross-framework interoperability, enabling seamless conversion between PyTorch and Burn models, efficient handling of large models through lazy loading, and flexible tensor filtering and remapping. Here's what the code looks like:

// Load PyTorch model with automatic conversions
let mut store = SafetensorsStore::from_file("pytorch_model.safetensors")
    .with_from_adapter(PyTorchToBurnAdapter)      // Auto-transpose weights
    .with_regex(r"^transformer..*")              // Only load transformer
    .with_key_pattern(r".attn.", ".attention.")  // Rename layers
    .allow_partial(true);                        // Skip unknown tensors

model.load_from(&mut store)?;

This release also introduces a new file format for optimized weight loading and saving, but it isn't yet integrated with Burn's recorder system. More to come in future releases.

Burn Import

This release adds comprehensive shape type support for ONNX operations. Burn's ONNX import system uses code generation based on the provided ONNX files, which means imported models produce high-level Rust code that you can modify for your specific needs [11].

However, ONNX and Burn handle shapes differently. In ONNX, shapes are represented as tensors. In Burn, they're simple vectors of usize values. This difference required us to rethink our shape handling approach.

Burn's ONNX implementation now uses a dedicated Shape type instead of Tensor for shape operations. This lets us manipulate shapes efficiently on the CPU, since shape data is just a small array of usize values. By avoiding tensor representations for shapes, we eliminate expensive CPU-GPU data transfers, which improves both compatibility and performance.
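
As a small illustration of the difference (not the exact generated code), shape arithmetic now stays on the CPU as plain usize values, and only the final reshape touches the tensor:

use burn::prelude::*;

fn flatten_spatial<B: Backend>(tensor: Tensor<B, 3>) -> Tensor<B, 2> {
    // dims() returns a plain [usize; 3], so this arithmetic runs on the CPU
    // without reading anything back from the device.
    let [batch, height, width] = tensor.dims();
    tensor.reshape([batch, height * width])
}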

Migration Guide

Learning Strategy

We refactored the Learner to support better distributed training strategies. Instead of registering a list of device(s), you now specify a training strategy.

let learner = LearnerBuilder::new(artifact_dir)
  .metric_train_numeric(AccuracyMetric::new())
  .metric_valid_numeric(AccuracyMetric::new())
  .metric_train_numeric(LossMetric::new())
  .metric_valid_numeric(LossMetric::new())
  .with_file_checkpointer(CompactRecorder::new())
-   .devices(vec![device.clone()])
+   .learning_strategy(LearningStrategy::SingleDevice(device.clone()))
  .num_epochs(config.num_epochs)
  .summary()
  .build(
      config.model.init::<B>(&device),
      config.optimizer.init(),
      config.learning_rate,
  );

Learner Training Result

The Learner previously lacked an evaluation loop. We extended its return type to include all training states in a TrainingResult, which includes the trained model and a metrics renderer.

- let model_trained = learner.fit(dataloader_train, dataloader_valid);
+ let result = learner.fit(dataloader_train, dataloader_valid);

- model_trained
+ result
+    .model
     .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
     .expect("Trained model should be saved successfully");

This enables the renderer to be reused by the new evaluator so that training and evaluation metrics appear together in the TUI dashboard:

let mut renderer = result.renderer;
let evaluator = EvaluatorBuilder::new(artifact_dir)
    .renderer(renderer)
    .metrics((AccuracyMetric::new(), LossMetric::new()))
    .build(result.model.clone());

evaluator.eval(name, dataloader_test);

Interface Changes

Config

The Config trait now requires Debug:

- #[derive(Config)]
+ #[derive(Config, Debug)]
  pub struct TrainingConfig {
      // ...
  }

BatchNorm

BatchNorm no longer requires the spatial dimension generic:

#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
    conv: nn::conv::Conv2d<B>,
-   norm: BatchNorm<B, 2>,
+   norm: BatchNorm<B>,
    pool: Option<MaxPool2d>,
    activation: nn::Relu,
}

Backend Seed

Seeding is now device-specific:

- B::seed(seed);
+ B::seed(&device, seed);

Tensor

For consistency with other methods like unsqueeze() / unsqueeze_dim(dim), squeeze(dim) was renamed:

- tensor.squeeze(dim)
+ tensor.squeeze_dim(dim)

We've also added a tensor.squeeze() method which squeezes all singleton dimensions.

Finally, we removed the tensor ^ T syntax, which was clunky.

- use burn::tensor::T;
- tensor ^ T
+ tensor.t()

tensor.t() is also a simple alias for tensor.transpose().

References

[1] GitHub Release Page
[2] Crate burn-collective
[3] Multi-Level Intermediate Representation (MLIR)
[4] The LLVM Compiler Infrastructure
[5] CUDA: Stream management
[6] Tokio - An asynchronous Rust runtime
[7] Pinned Host Memory
[8] Crate parking_lot
[9] A comprehensive guide of Distributed Data Parallel (DDP)
[10] Crate cubecl-cpu
[11] Importing ONNX Models in Burn
