Overview
This release brings major upgrades in performance and platform compatibility (most notably, a new Metal backend via WGPU passthrough). CubeCL now powers backends for CUDA, Metal, ROCm, Vulkan and WebGPU. Tensor operation fusion support has been greatly expanded to optimize element-wise, reduction and matmul operations.
A new compilation cache and improved autotune cache speed up repeated runs by reusing precompiled binaries and tuned kernel configurations. Data parallel training now scales better across multiple GPUs with automatic batch assignment to each worker. A new tensor slice API offers a simpler, more intuitive way to index tensors.
This version also comes with broad performance gains across tensor operations, especially for reductions, matmul, and convolutions. An initial implementation of quantized matmul is now available, with further quantization improvements planned in the future.
As with previous releases, this one also includes various bug fixes, further optimizations, and enhanced documentation.
Be sure to check out the new burn-bench to compare performance across different versions, hardware and backends.
CubeCL Backends
Burn supports CUDA, ROCm, Vulkan and WebGPU, and the newly added Metal backend.
Each backend can be used through its respective type alias, provided that the appropriate backend feature flag is also enabled.
Metal
burn = { version = "0.17.0", features = ["metal"] }
use burn::prelude::*;
use burn::backend::wgpu::{Metal, WgpuDevice};
let tensor = Tensor::<Metal, 2>::zeros([2, 4], &WgpuDevice::default());
Cuda
burn = { version = "0.17.0", features = ["cuda"] }
use burn::prelude::*;
use burn::backend::cuda::{Cuda, CudaDevice};
let tensor = Tensor::<Cuda, 2>::zeros([2, 4], &CudaDevice::default());
Rocm
burn = { version = "0.17.0", features = ["rocm"] }
use burn::prelude::*;
use burn::backend::rocm::{Rocm, HipDevice};
let tensor = Tensor::<Rocm, 2>::zeros([2, 4], &HipDevice::default());
Vulkan
burn = { version = "0.17.0", features = ["vulkan"] }
use burn::prelude::*;
use burn::backend::wgpu::{Vulkan, WgpuDevice};
let tensor = Tensor::<Vulkan, 2>::zeros([2, 4], &WgpuDevice::default());
WebGpu
burn = { version = "0.17.0", features = ["webgpu"] }
use burn::prelude::*;
use burn::backend::wgpu::{WebGpu, WgpuDevice};
let tensor = Tensor::<WebGpu, 2>::zeros([2, 4], &WgpuDevice::default());
When using one of the wgpu backends, you may encounter compilation errors related to recursive type evaluation. This is due to complex type nesting within the wgpu dependency chain. To resolve this issue, add the following line at the top of your main.rs or lib.rs file:
#![recursion_limit = "256"]
The default recursion limit (128) is often just below the required depth (typically 130–150) due to deeply nested associated types and trait bounds.
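For instance, a minimal main.rs sketch (reusing the Vulkan alias from above; the same applies to any wgpu-based backend) places the attribute before all other items:

// main.rs
#![recursion_limit = "256"]

use burn::prelude::*;
use burn::backend::wgpu::{Vulkan, WgpuDevice};

fn main() {
    // The crate-level attribute above must come first for it to take effect.
    let tensor = Tensor::<Vulkan, 2>::zeros([2, 4], &WgpuDevice::default());
    println!("{tensor}");
}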
Data Loader and Batcher
The Batcher trait has been updated to improve multi-device support. Previously, batcher implementations stored a device internally, which could lead to all data being loaded on the same device. The latest changes make the DataLoader generic over the backend, while the device is passed explicitly:
-impl<B: Backend> Batcher<MyItem, MyBatch<B>> for MyBatcher<B> {
+impl<B: Backend> Batcher<B, MyItem, MyBatch<B>> for MyBatcher {
- fn batch(&self, items: Vec<MyItem>) -> MyBatch<B> {
+ fn batch(&self, items: Vec<MyItem>, device: &B::Device) -> MyBatch<B> {
// The correct `device` is already provided for the batching logic to use
}
}
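As a concrete illustration, here is a minimal sketch of a batcher under the new trait. The MyItem, MyBatch, and MyBatcher names follow the diff above; the tensor construction in the body is just one plausible way to fill it in:

use burn::data::dataloader::batcher::Batcher;
use burn::prelude::*;
use burn::tensor::TensorData;

#[derive(Clone, Debug)]
pub struct MyItem {
    pub value: f32,
}

#[derive(Clone, Debug)]
pub struct MyBatch<B: Backend> {
    pub values: Tensor<B, 1>,
}

// Note: the batcher no longer stores a device.
#[derive(Clone, Default)]
pub struct MyBatcher;

impl<B: Backend> Batcher<B, MyItem, MyBatch<B>> for MyBatcher {
    fn batch(&self, items: Vec<MyItem>, device: &B::Device) -> MyBatch<B> {
        let values: Vec<f32> = items.iter().map(|item| item.value).collect();
        let n = values.len();
        // The tensor is created directly on the device provided by the data loader.
        let values = Tensor::<B, 1>::from_data(TensorData::new(values, [n]), device);
        MyBatch { values }
    }
}

Because the device arrives as an argument, the same batcher can serve data loaders bound to different devices.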
The device can now be set when building a data loader:
let dataloader = DataLoaderBuilder::new(batcher)
.batch_size(batch_size)
.shuffle(seed)
.num_workers(num_workers)
+ .set_device(device)
.build(dataset);
This step is not required for the Learner, which handles the device configuration automatically.
Better Tensor Slicing & Indexing
Tensor slicing now fully adopts idiomatic Rust range syntax, replacing the older (i64, i64) and Option<(i64, i64)> tuple forms. For example:
let tensor = Tensor::<B, 2>::zeros([m, n], &device);
-let slice = tensor.slice([(0, -1), (0, -2)]);
+let slice = tensor.slice([0..-1, 0..-2]);
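Negative indices count from the end of the dimension, so the diff above keeps all but the last row and all but the last two columns. A small illustrative check (assuming a generic B: Backend and a device in scope, with m = 4 and n = 5):

let tensor = Tensor::<B, 2>::zeros([4, 5], &device);
// 0..-1 keeps rows 0..3; 0..-2 keeps columns 0..3.
let slice = tensor.slice([0..-1, 0..-2]);
assert_eq!(slice.dims(), [3, 3]);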
For more complex or mixed range types, use the s![] macro:
let tensor = Tensor::<B, 3>::zeros([b, s, d], &device);
-let slice = tensor.slice([None, Some((t as i64, t as i64 + 1)), None]);
+let slice = tensor.slice(s![.., t..t + 1, ..]);
The macro is inspired by ndarray's s![] (at least, by name) and helps build flexible slice patterns.
use burn::prelude::*;
let tensor = Tensor::<B, 4>::zeros([8, 4, 2, 3], &device);
let slice = tensor.slice(s![..=4, 0..=3, .., -1]);
assert_eq!(slice.dims(), [5, 4, 2, 1]);