Overview
This release brings major upgrades in performance and platform compatibility (most notably, a new Metal backend via WGPU passthrough). CubeCL now powers backends for CUDA, Metal, ROCm, Vulkan and WebGPU. Tensor operation fusion support has been greatly expanded to optimize element-wise, reduction and matmul operations.
A new compilation cache and improved autotune cache speed up repeated runs by reusing precompiled binaries and tuned kernel configurations. Data parallel training now scales better across multiple GPUs with automatic batch assignment to each worker. A new tensor slice API offers a simpler, more intuitive way to index tensors.
This version also comes with broad performance gains across tensor operations, especially for reductions, matmul, and convolutions. An initial implementation of quantized matmul is now available, with further quantization improvements planned in the future.
As with previous releases, this one includes various bug fixes, further optimizations and enhanced documentation.
Be sure to check out the new burn-bench to compare performance across different versions, hardware and backends.
CubeCL Backends
Burn supports CUDA, ROCm, Vulkan and WebGPU, as well as the newly added Metal backend.
Each backend can be used through its respective type alias, provided that the corresponding backend feature flag is enabled.
Metal

```toml
burn = { version = "0.17.0", features = ["metal"] }
```

```rust
use burn::prelude::*;
use burn::backend::wgpu::{Metal, WgpuDevice};

let tensor = Tensor::<Metal, 2>::zeros([2, 4], &WgpuDevice::default());
```

Cuda

```toml
burn = { version = "0.17.0", features = ["cuda"] }
```

```rust
use burn::prelude::*;
use burn::backend::cuda::{Cuda, CudaDevice};

let tensor = Tensor::<Cuda, 2>::zeros([2, 4], &CudaDevice::default());
```

Rocm

```toml
burn = { version = "0.17.0", features = ["rocm"] }
```

```rust
use burn::prelude::*;
use burn::backend::rocm::{Rocm, HipDevice};

let tensor = Tensor::<Rocm, 2>::zeros([2, 4], &HipDevice::default());
```

Vulkan

```toml
burn = { version = "0.17.0", features = ["vulkan"] }
```

```rust
use burn::prelude::*;
use burn::backend::wgpu::{Vulkan, WgpuDevice};

let tensor = Tensor::<Vulkan, 2>::zeros([2, 4], &WgpuDevice::default());
```

WebGpu

```toml
burn = { version = "0.17.0", features = ["webgpu"] }
```

```rust
use burn::prelude::*;
use burn::backend::wgpu::{WebGpu, WgpuDevice};

let tensor = Tensor::<WebGpu, 2>::zeros([2, 4], &WgpuDevice::default());
```
When using one of the wgpu-based backends, you may encounter compilation errors related to recursive type evaluation. This is due to complex type nesting within the wgpu dependency chain. To resolve this issue, add the following attribute at the top of your main.rs or lib.rs file:

```rust
#![recursion_limit = "256"]
```

The default recursion limit (128) is often just below the required depth (typically 130–150) due to deeply nested associated types and trait bounds.
Data Loader and Batcher
The Batcher trait has been updated to improve multi-device support. Previously, batcher implementations stored a device internally, which could lead to all data being loaded on the same device. The trait is now generic over the backend, and the device is passed explicitly:
```diff
-impl<B: Backend> Batcher<MyItem, MyBatch<B>> for MyBatcher<B> {
+impl<B: Backend> Batcher<B, MyItem, MyBatch<B>> for MyBatcher {
-    fn batch(&self, items: Vec<MyItem>) -> MyBatch<B> {
+    fn batch(&self, items: Vec<MyItem>, device: &B::Device) -> MyBatch<B> {
         // The correct `device` is already provided for the batching logic to use
     }
 }
```

The device can now be set when building a data loader:
```diff
 let dataloader = DataLoaderBuilder::new(batcher)
     .batch_size(batch_size)
     .shuffle(seed)
     .num_workers(num_workers)
+    .set_device(device)
     .build(dataset);
```
This step is not required for the Learner, which handles the device
configuration automatically.
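The multi-device dispatch described above can be pictured as a round-robin assignment of data-loading workers to the available devices. The sketch below is purely illustrative (it is not Burn's actual scheduler, and the exact assignment strategy is an assumption here); it only shows the general pattern:

```rust
// Illustrative only: assign each data-loading worker a device index in
// round-robin order, the general pattern behind spreading batches across
// multiple GPUs. This is NOT Burn's internal implementation.
fn assign_devices(num_workers: usize, num_devices: usize) -> Vec<usize> {
    (0..num_workers).map(|worker| worker % num_devices).collect()
}

fn main() {
    // Four workers over two devices alternate between device 0 and device 1.
    let assignment = assign_devices(4, 2);
    assert_eq!(assignment, vec![0, 1, 0, 1]);
    println!("{:?}", assignment);
}
```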
Better Tensor Slicing & Indexing
Tensor slicing now fully adopts idiomatic Rust range syntax, replacing the older (i64, i64) and Option tuple forms.
For example:
```rust
let tensor = Tensor::<B, 2>::zeros([m, n], &device);
```

```diff
-let slice = tensor.slice([(0, -1), (0, -2)]);
+let slice = tensor.slice([0..-1, 0..-2]);
```
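Negative bounds in these ranges count from the end of the dimension, so `0..-1` drops the last element. A small standalone helper (illustrative only, not part of the Burn API) shows how such a bound resolves against a dimension size:

```rust
// Illustrative helper (not part of the Burn API): resolve a possibly
// negative slice bound against a dimension of size `dim`.
fn resolve_bound(bound: i64, dim: usize) -> usize {
    if bound < 0 {
        (dim as i64 + bound) as usize
    } else {
        bound as usize
    }
}

fn main() {
    // On a dimension of size 10, the range `0..-1` ends at index 9,
    // covering indices 0..9 (length 9).
    assert_eq!(resolve_bound(-1, 10), 9);
    assert_eq!(resolve_bound(-2, 10), 8);
    assert_eq!(resolve_bound(3, 10), 3);
    println!("ok");
}
```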
For more complex or mixed range types, use the s![] macro:
```rust
let tensor = Tensor::<B, 3>::zeros([b, s, d], &device);
```

```diff
-let slice = tensor.slice([None, Some((t as i64, t as i64 + 1)), None]);
+let slice = tensor.slice(s![.., t..t + 1, ..]);
```

The macro is inspired by ndarray's s![] (at least, by name) and helps build flexible slice patterns.
```rust
use burn::prelude::*;

let tensor = Tensor::<B, 4>::zeros([8, 4, 2, 3], &device);
let slice = tensor.slice(s![..=4, 0..=3, .., -1]);

assert_eq!(slice.dims(), [5, 4, 2, 1]);
```