Burn 0.17.0 Release Notes

Thu Apr 24 2025
Guillaume Lagrange

Overview

This release brings major upgrades in performance and platform compatibility (most notably, a new Metal backend via WGPU passthrough). CubeCL now powers backends for CUDA, Metal, ROCm, Vulkan and WebGPU. Tensor operation fusion support has been greatly expanded to optimize element-wise, reduction and matmul operations.

A new compilation cache and improved autotune cache speed up repeated runs by reusing precompiled binaries and tuned kernel configurations. Data parallel training now scales better across multiple GPUs with automatic batch assignment to each worker. A new tensor slice API offers a simpler, more intuitive way to index tensors.

This version also comes with broad performance gains across tensor operations, especially for reductions, matmul, and convolutions. An initial implementation of quantized matmul is now available, with further quantization improvements planned.

As with previous releases, this includes various bug fixes, further optimizations and enhanced documentation.

Be sure to check out the new burn-bench to compare performance across different versions, hardware and backends.

CubeCL Backends

Burn supports the CUDA, ROCm, Vulkan and WebGPU backends, along with the newly added Metal backend.

Each backend can be used through its respective type alias, provided that the appropriate backend feature flag is also enabled.

Metal
burn = { version = "0.17.0", features = ["metal"] }

use burn::prelude::*;
use burn::backend::wgpu::{Metal, WgpuDevice};

let tensor = Tensor::<Metal, 2>::zeros([2, 4], &WgpuDevice::default());

Cuda
burn = { version = "0.17.0", features = ["cuda"] }

use burn::prelude::*;
use burn::backend::cuda::{Cuda, CudaDevice};

let tensor = Tensor::<Cuda, 2>::zeros([2, 4], &CudaDevice::default());

Rocm
burn = { version = "0.17.0", features = ["rocm"] }

use burn::prelude::*;
use burn::backend::rocm::{Rocm, HipDevice};

let tensor = Tensor::<Rocm, 2>::zeros([2, 4], &HipDevice::default());

Vulkan
burn = { version = "0.17.0", features = ["vulkan"] }

use burn::prelude::*;
use burn::backend::wgpu::{Vulkan, WgpuDevice};

let tensor = Tensor::<Vulkan, 2>::zeros([2, 4], &WgpuDevice::default());

WebGpu
burn = { version = "0.17.0", features = ["webgpu"] }

use burn::prelude::*;
use burn::backend::wgpu::{WebGpu, WgpuDevice};

let tensor = Tensor::<WebGpu, 2>::zeros([2, 4], &WgpuDevice::default());

⚠️ WARNING

When using one of the wgpu backends, you may encounter compilation errors related to recursive type evaluation. This is due to complex type nesting within the wgpu dependency chain. To resolve this issue, add the following line at the top of your main.rs or lib.rs file:

#![recursion_limit = "256"]

The default recursion limit (128) is often just below the required depth (typically 130–150) due to deeply nested associated types and trait bounds.

Data Loader and Batcher

The Batcher trait has been updated to improve multi-device support. Previously, batcher implementations stored a device internally, which could lead to all data being loaded on the same device. With the latest changes, the DataLoader is generic over the backend and the device is passed explicitly to the batcher:

-impl<B: Backend> Batcher<MyItem, MyBatch<B>> for MyBatcher<B> {
+impl<B: Backend> Batcher<B, MyItem, MyBatch<B>> for MyBatcher {
-   fn batch(&self, items: Vec<MyItem>) -> MyBatch<B> {
+   fn batch(&self, items: Vec<MyItem>, device: &B::Device) -> MyBatch<B> {
        // The correct `device` is already provided for the batching logic to use
    }
}

The device can now be set when building a data loader:

let dataloader = DataLoaderBuilder::new(batcher)
        .batch_size(batch_size)
        .shuffle(seed)
        .num_workers(num_workers)
+       .set_device(device)
        .build(dataset);

This step is not required for the Learner, which handles the device configuration automatically.
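
To make the motivation concrete, here is a minimal std-only sketch of the same design idea. It does not use Burn: `Device`, `ToyBatch`, `ToyBatcher`, `PlainBatcher` and `batch_across_devices` are hypothetical stand-ins invented for illustration, and the round-robin assignment is an assumption rather than a description of Burn's actual data loading strategy.

```rust
// Simplified stand-ins for illustration only; these are NOT Burn's real types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Device(usize);

/// A toy "batch" that only remembers its values and the device it was built for.
#[derive(Debug, PartialEq)]
struct ToyBatch {
    values: Vec<f32>,
    device: Device,
}

/// Mirrors the shape of the updated trait: the device is an argument, not a field.
trait ToyBatcher<I, O> {
    fn batch(&self, items: Vec<I>, device: &Device) -> O;
}

/// The batcher stores no device, so a single instance can serve every worker.
struct PlainBatcher;

impl ToyBatcher<f32, ToyBatch> for PlainBatcher {
    fn batch(&self, items: Vec<f32>, device: &Device) -> ToyBatch {
        ToyBatch { values: items, device: *device }
    }
}

/// Hypothetical helper: chunk `items` and send chunk `i` to device `i % n`.
/// Panics if `batch_size` is zero or `devices` is empty.
fn batch_across_devices<B: ToyBatcher<f32, ToyBatch>>(
    batcher: &B,
    items: Vec<f32>,
    batch_size: usize,
    devices: &[Device],
) -> Vec<ToyBatch> {
    items
        .chunks(batch_size)
        .enumerate()
        .map(|(i, chunk)| batcher.batch(chunk.to_vec(), &devices[i % devices.len()]))
        .collect()
}
```

Because the batcher no longer owns a device, the caller decides where each batch lives. In this sketch, calling `batch_across_devices(&PlainBatcher, items, 2, &[Device(0), Device(1)])` alternates batches between the two devices.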

Better Tensor Slicing & Indexing

Tensor slicing now fully adopts idiomatic Rust range syntax, replacing the older (i64, i64) and Option tuple forms.

For example:

let tensor = Tensor::<B, 2>::zeros([m, n], &device);
-let slice = tensor.slice([(0, -1), (0, -2)]);
+let slice = tensor.slice([0..-1, 0..-2]);

For more complex or mixed range types, use the s![] macro:

let tensor = Tensor::<B, 3>::zeros([b, s, d], &device);
-let slice = tensor.slice([None, Some((t as i64, t as i64 + 1)), None]);
+let slice = tensor.slice(s![.., t..t + 1, ..]);

The macro is inspired by ndarray's s![] (at least, by name) and helps build flexible slice patterns.

use burn::prelude::*;

let tensor = Tensor::<B, 4>::zeros([8, 4, 2, 3], &device);
let slice = tensor.slice(s![..=4, 0..=3, .., -1]);
assert_eq!(slice.dims(), [5, 4, 2, 1]);
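
The resulting dimensions can be worked out by hand. The following std-only sketch models that arithmetic, assuming negative indices count from the end of a dimension and a single index keeps its dimension with size 1, as the example above suggests. `SliceArg`, `resolve_index`, `resolve` and `sliced_dims` are hypothetical names for this illustration and are not part of Burn's API or implementation.

```rust
use std::ops::Range;

/// One slice argument per dimension (hypothetical model of the forms shown above).
enum SliceArg {
    Full,                                        // `..`
    Range { start: isize, end: isize },          // `a..b`
    RangeInclusive { start: isize, end: isize }, // `a..=b` (use start = 0 for `..=b`)
    Index(isize),                                // a single index such as `-1`
}

/// Map a possibly negative index onto a dimension of length `size`.
/// Negative values count from the end; results below zero are clamped to 0.
/// No other bounds checking is performed in this sketch.
fn resolve_index(index: isize, size: usize) -> usize {
    if index < 0 {
        (size as isize + index).max(0) as usize
    } else {
        index as usize
    }
}

/// Turn one slice argument into a concrete half-open range.
fn resolve(arg: &SliceArg, size: usize) -> Range<usize> {
    match *arg {
        SliceArg::Full => 0..size,
        SliceArg::Range { start, end } => {
            resolve_index(start, size)..resolve_index(end, size)
        }
        SliceArg::RangeInclusive { start, end } => {
            resolve_index(start, size)..(resolve_index(end, size) + 1)
        }
        SliceArg::Index(i) => {
            let i = resolve_index(i, size);
            i..(i + 1)
        }
    }
}

/// Compute the shape that results from slicing `shape` with `args`.
fn sliced_dims(shape: &[usize], args: &[SliceArg]) -> Vec<usize> {
    shape
        .iter()
        .zip(args)
        .map(|(&size, arg)| resolve(arg, size).len())
        .collect()
}
```

Under these assumptions, the pattern `s![..=4, 0..=3, .., -1]` applied to the shape `[8, 4, 2, 3]` resolves to the ranges `0..5`, `0..4`, `0..2` and `2..3`, which matches the asserted dimensions `[5, 4, 2, 1]`.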

Changelog

Module & Tensor

Feature add new one hot function meeting multi-dimensions (ranks) #2613 @tiruka
Expand GRU support #2704 @nwhitehead
feat: bitwise-ops-for-tensors #2498 @quinton11
Feat: Add PoissonNLL loss #2765 @salvomcl
Add metric parametrized name #2808 @laggui
Add boolean and/or to bool tensors #2802 @wingertge
Add ATOL/RTOL defaults #2824 @crutcher
Feat: Add tan trig function #2854 @Msa360
Refactor quantization schemes #2849 #3036 @laggui @maxtremblay
Vectorize pooling for optimization #2905 @wingertge
Feat: Add Cosh and Sinh #2959 @Msa360
Refactor in-memory recorder load args #2892 @BjornTheProgrammer
Improve gradient checkpointing #2997 @nathanielsimard
Optimize minmax #3009 @nathanielsimard
Improve tensor.slice(...) to support multiple range types #3061 @laggui

Bug Fixes

Fix bce loss log #2741 @laggui
Fix repeat_dim backward w/ dim size > 1 #2777 @laggui
[Fix] tch upgrade #2834 @wingertge
Check channels_in matches in convolution layers #2944 @chlobes
Fixed GroupNorm implementation #2945 @computer-whisperer

Backends

Migrate to type magic autotune #2710 @wingertge
Feat/fused matmul tune #2726 @nathanielsimard
Feat/shared sum #2737 @maxtremblay
Improve fusion for broadcasting, mix vectorization and reshape operation #2773 #2833 @nathanielsimard
Remove from_data conversions in backends #2783 @laggui
Feat fuse swap dims #2801 #2877 @nathanielsimard
[Feature] reduce fuse on read #2870 @nathanielsimard
[Feat] SIMD acceleration for ndarray backend #2851 @wingertge
Perf/reduce fuse on write #2937 @nathanielsimard
[metal] Add CubeCL metal compiler support #2993 @syl20bnr
Compilation Cache #3020 @nathanielsimard
Cubecl quantize matmul #3022 #3030 @maxtremblay

Bug Fixes

Fix from data fusion #2735 #2778 @laggui @nathanielsimard
Fix constant creation in fusion to cast at compile time, not runtime #2782 @wingertge
Fix two autotune issues on wasm #2899 @ArthurBrussee
Fix/reduce out of bounds #2906 @nathanielsimard
Fix fusion bug #3031 @nathanielsimard
Fix metal backend name #3040 @nathanielsimard
Fix matmul dynamic line size support #3056 @nathanielsimard
Fix: matmul lower precision / flex32 #3059 @nathanielsimard
Fix/autotune cache conflicts #3070 @nathanielsimard

Documentation & Examples

Wasserstein Generative Adversarial Network #2660 @wangjiawen2013
Add modern lstm #2752 @wangjiawen2013
Improve tensor docs #2951 @PtiLuky

Fixes

chore: fix some comments #2717 @sunxunle
Add hardsigmoid formula and fix WGAN doc + default lr #2706 @laggui
Fix db-pedia-infer backend #2736 @laggui
Fixed typo in the burn book chapter advanced unit no-std. #2731 @xmy314
typo - correct smp_serde to rmp_serde as per crate's name in url #2744 @cameronbraid
typo - missing tick which was breaking formatting #2745 @cameronbraid
Remove autodiff from generate #2759 @laggui
Remove empty format precision specifier #2785 @hkBst
Update tch instructions #2844 #2976 @laggui
Fix from_embedded and bool ops docs #2848 @laggui
Fix tiny typo in mathematical expression #2867 @janhohenheim
Fix typos #2927 @crutcher
Fix/web example #2954 #2978 @laggui
Fix: burn-book getting-started Use Declarations #2966 @jerryshell
chore: fix comment #3008 @tsinghuacoder

ONNX Support

Code generation bug fix for ONNX import #2708 @antimora
One hot ONNX #2784 @akshitgaur2005
Onnx op topk #2305 @oojo12
Fix output elem type for unsqueeze and reshape #2807 @christeefy
Feat/Split ONNX Import #2568 @agelas
Refactor GatherNode to support scalar outputs. #2828 @loloxwg
Rename dim to rank for ONNX import #2831 @antimora
Add rank inference for tan #2868 @Msa360
Fix RandomNormalLike ONNX node output rank #2936 @Knight-Ops
Support multiple outputs being tracked in BurnGraph during ONNX conversion #2938 @Knight-Ops
Ignore ONNX optional node inputs/outputs #2935 @Knight-Ops
Fix ONNX flatten to match spec #2940 @catch-twenty-two
burn-import: add some tests for ConstantNode #2623 @jameshiew @laggui
Update SUPPORTED-ONNX-OPS.md with the latest info #3064 @antimora

Enhancements

Add new burn-vision crate #2753 #2810 #2842 @wingertge
Improve Burn compilation times #2815 #2994 @nathanielsimard
Support training in no-std #2830 @ivila
Perf: Speed up element and TensorData conversion #2913 @wingertge
Feat/cubecl caching #2902 @nathanielsimard
Improve multi-device data loading strategy #2890 #3035 @laggui
Autotune level matmul double buffering #2988 @nathanielsimard @louisfd

Refactoring

Remove deprecated Data and DataSerialize #2703 @laggui
Clean up train system metrics #2707 @laggui
Move IR to its own crate #2796 #2798 @laggui
Refactor burn jit => burn-cubecl #2809 @nathanielsimard
Cleanup Tensor Registry in fusion #2826 @nathanielsimard
Migrate conv2d to cubecl #2908 #3018 @wingertge
Update to edition 2024 #2931 @laggui
Update runtime names #2909 @nathanielsimard
Migrate backend comparison #2961 @laggui
Improve test tolerance assertions #3024 @maxtremblay @laggui
[hip] Move burn-hip to burn-rocm and rename backend to ROCm #3062 @syl20bnr

Miscellaneous

Fix no default features flags + update cubecl #2725 @laggui
Replace return with terminate #2742 @maxtremblay
Clean up -jit suffix in feature flags and modules #2705 @laggui
Fix types under autotune flag #2750 @laggui
Fix BackendValues in backend-comparison after removal of jit suffix #2756 @syl20bnr
Update cubecl #2764 @wingertge
Fix optional burn-import dep + impl module types for isize #2774 @laggui
Update cubecl with fix to shared_sum #2779 @maxtremblay
feat: using rustls instead of native-tls #2799 @ShoofLLC
bump cubecl version with dummy implementations #2814 @maxtremblay
Add data_dir optional argument to Huggingface DataLoader to enable some manual download use cases #2817 @Pablo1785
Bump xtask to 1.1.9 #2896 @syl20bnr
Fix test checks for macos #2952 @PtiLuky
Update cargo deps #2962 @Brooooooklyn
Add train end event #2967 @laggui
Update cubecl bitcast -> reinterpret #2985 @maxtremblay
Update wgpu to v25 #3007 @syl20bnr
update cubecl: sync full cyclic checked #3025 @louisfd
Fix autotune measurement #3043 @nathanielsimard
