Burn 0.21.0 Release: Up to 8× Lower Framework Overhead, Differentiable Collectives and Improved Kernels

Flame digital art generated by Stable Diffusion.

Burn 0.21.0 brings 4 months of improvements that make the framework significantly faster and more reliable across the board. The gains span distributed workflows for training large models all the way down to small-model inference, where the reduced framework overhead becomes especially noticeable.

We rethought our distributed computing stack around differentiable collective operations. Kernel selection is now more reliable thanks to better autotuning and a new validation layer, and a project-level burn.toml file lets you tweak those internals (and many others) without recompiling. A reworked device handle reduces framework overhead, and a new burn-dispatch crate simplifies backend selection while paving the way for faster compile times. The release also ships burn-flex, a lightweight eager CPU backend for WebAssembly and embedded targets that replaces burn-ndarray. Finally, we added early off-policy reinforcement learning support and a fresh round of kernel work on GEMV, top-k, and FFT. This post highlights the headline changes; the release also includes many bug fixes and smaller improvements, all listed in the release notes[1].

Distributed Training

To improve distributed performance when training on multiple GPUs, we had to rethink our earlier approach. Burn should be flexible enough to support many kinds of parallelization strategies, and that requires strong foundations.

We redesigned our distributed story around differentiable collective operations, which serve as the base for improved multi-device communication and a better DDP trainer. The work isn't complete yet, but the foundation is in place.

Two of the most common building blocks in distributed training are device transfer (to_device) and all_reduce, and both got a substantial overhaul this release. As the chart below shows, on 4 CUDA GPUs we measured roughly 16-21× faster device transfers and around 6× faster all-reduce across the two shapes we tested.

Distributed Operations on 4 GPUs (Lower is Better), CubeCL on Cuda(0..3):

  Operation    Shape              Burn 0.20    Burn 0.21    Speedup
  to_device    (32, 512, 1024)    21.13 ms     1.29 ms      16.4×
  to_device    (128, 512, 2048)   170.44 ms    8.03 ms      21.2×
  all_reduce   (32, 512, 1024)    98.37 ms     15.64 ms     6.3×
  all_reduce   (128, 512, 2048)   780.91 ms    126.18 ms    6.2×
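For a sense of how these primitives appear from user code, here is a minimal sketch that uses only the public Tensor API. It is not the burn-collective interface: a real all_reduce avoids gathering everything onto one device, but the data movement it replaces is exactly these to_device transfers.

use burn::prelude::*;

// Illustrative only: average per-device partial results by moving everything to
// a single root device. A proper all_reduce avoids this hub pattern entirely.
fn naive_all_reduce<B: Backend>(parts: Vec<Tensor<B, 2>>, root: &B::Device) -> Tensor<B, 2> {
    let n = parts.len() as f32;
    parts
        .into_iter()
        .map(|p| p.to_device(root)) // device transfer: the `to_device` path benchmarked above
        .reduce(|a, b| a + b)       // sum on the root device
        .expect("at least one tensor")
        .div_scalar(n)              // average, as an all_reduce with mean would do
}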

Kernel Reliability Improvements

In this release we spent a lot of time improving the reliability of kernel selection: better autotuning with smarter kernel priority, kernel grouping, and more reliable micro-benchmarking with proper scoring instead of simply selecting the median.
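To illustrate what proper scoring means here (a sketch of the general idea only, not the actual autotune code): rather than ranking kernels purely by their median sample, each candidate's timing samples can be collapsed into a score that also penalizes noisy measurements.

use std::time::Duration;

// Sketch of a variance-aware score for autotune candidates (illustrative only):
// lower is better, and noisy candidates are penalized even when their median
// looks competitive.
fn score(samples: &[Duration]) -> f64 {
    assert!(!samples.is_empty());
    let secs: Vec<f64> = samples.iter().map(Duration::as_secs_f64).collect();
    let mean = secs.iter().sum::<f64>() / secs.len() as f64;
    let var = secs.iter().map(|s| (s - mean).powi(2)).sum::<f64>() / secs.len() as f64;
    mean + var.sqrt() // mean plus one standard deviation
}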

There's still more to do here. One direction we're exploring is computing theoretical values for benchmarks (memory throughput, compute throughput) and setting a threshold at which we stop autotuning.

We also added a validation layer on top of CubeCL[2] kernels. With this, we were able to identify kernels that generated out-of-bounds memory accesses and could potentially cause memory errors. This validation is opt-in through the new burn.toml configuration file described in the next section.

The burn.toml Configuration File

Each release ships more internal components worth tweaking: fusion's beam search, autotune aggressiveness, compilation cache and validation modes, streaming concurrency, memory pool persistence, and that list continues to grow. Instead of spreading these knobs across environment variables, build features, and ad-hoc Rust APIs, we introduced a single project-level burn.toml file. Drop it at the root of your project and parameterize every internal subsystem without touching code or recompiling.

The same file is also where we expose logging levels across CubeCL and Burn services. You can dial visibility up when you're profiling or chasing a regression, and dial it back down for production runs. No rebuild required. Here's a representative burn.toml:

[fusion.beam_search]
max_blocks = 5      # usize. Higher = more fusion opportunities, more cache misses.

[cubecl.autotune]
level = "balanced"  # "minimal" | "balanced" | "extensive" | "full"
cache = "target"    # "local" | "target" | "global" | { file = "<path>" }

[cubecl.compilation]
check_mode = "auto" # "enforce" | "validate" | "auto"
cache = "target"    # "local" | "target" | "global" | { file = "<path>" }

[cubecl.streaming]
max_streams = 8     # u8

[cubecl.memory]
persistent_memory = "enabled"  # "enabled" | "disabled" | "enforced"

Logging is configured per component, and a single logger can fan out to several destinations at once: a file, stdout, stderr, or routing through the standard log crate so an existing logging stack picks the messages up. Levels exist per component and, for now, aren't standardized across subsystems: autotune speaks in terms of disabled/minimal/full, fusion's beam search uses disabled/basic/verbose, and other subsystems have their own vocabulary. Unifying these is on the to-do list, but the per-component flexibility has been useful while each subsystem is still evolving.

[cubecl.autotune.logger]
level = "full"                  # "disabled" | "minimal" | "full"
targets = ["stdout", "log", { file = "autotune.log" }]

[fusion.beam_search.logger]
level = "basic"                 # "disabled" | "basic" | "verbose"
targets = ["stderr"]

[cubecl.compilation.logger]
level = "minimal"               # vocabulary varies per subsystem
targets = [
    "stdout",
    "stderr",
    { file = "build.log" },
]

Up to 8.2× Lower Framework Overhead

One of the biggest challenges in a deep learning framework is efficiently executing computations while minimizing overhead. If the framework spends more time figuring out how to run an algorithm than actually running it, slowdowns appear. That's why with fusion enabled we sometimes saw no meaningful speedup, and occasionally even regressions.

To tackle this, we substantially refined our device handle implementation. A device handle is a primitive interior-mutability data structure that lets different device services execute in a context where mutation is allowed. It was previously implemented with a recursive mutex so that different services could be accessed recursively on the same device, but in that setup fusion was slowing down the CubeCL runtime.

So we changed the underlying primitive to leverage a custom communication channel tailored for lazy, fire-and-forget task execution. Now worker threads execute tasks, and multiple services can live on the same thread at once. On top of that, multiple worker threads are also available for a single device through a new device service stage. This pipelines the work done in fusion with the work done in the CubeCL runtime, and improves overall speed.

This is a fairly sophisticated piece of the framework's internals: by treating a feature like fusion as a device service, we can cleanly optimize communication between different execution contexts. For those interested in the channel implementation itself, we wrote a dedicated technical post [3].
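To make the channel idea concrete without going into those details, here is a heavily simplified sketch in plain std Rust (illustrative only, not the actual implementation): callers enqueue work and return immediately, while a dedicated worker thread owns the mutable state and drains the queue in order.

use std::sync::mpsc;
use std::thread;

// Minimal fire-and-forget task channel: submitting never blocks on execution,
// and the worker thread processes tasks in submission order.
struct DeviceWorker {
    sender: mpsc::Sender<Box<dyn FnOnce() + Send>>,
}

impl DeviceWorker {
    fn spawn() -> Self {
        let (sender, receiver) = mpsc::channel::<Box<dyn FnOnce() + Send>>();
        thread::spawn(move || {
            // The worker owns the mutable device state, so no lock is needed on
            // the hot path; ordering is guaranteed by the channel itself.
            for task in receiver {
                task();
            }
        });
        Self { sender }
    }

    /// Fire-and-forget: enqueue the task and return immediately.
    fn submit(&self, task: impl FnOnce() + Send + 'static) {
        self.sender.send(Box::new(task)).expect("worker thread alive");
    }
}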

Launch Overhead: Median Time (Lower is Better), shape (1, 8, 8, 8).
Avg eager speedup: 3.4×, avg fusion speedup: 5.4×, best: 8.2×, regressions: 0.

  Configuration           0.20.1 eager   0.20.1 fusion   0.21.0 eager   0.21.0 fusion
  512 reps, 1 thread      13.94 ms       7.44 ms         9.82 ms        2.56 ms
  1024 reps, 1 thread     26.95 ms       14.45 ms        19.22 ms       4.61 ms
  512 reps, 4 threads     62.14 ms       68.67 ms        46.39 ms       9.22 ms
  1024 reps, 4 threads    110.27 ms      126.90 ms       90.27 ms       17.95 ms
  512 reps, 8 threads     166.47 ms      140.38 ms       121.42 ms      17.92 ms
  1024 reps, 8 threads    330.09 ms      261.20 ms       240.51 ms      35.36 ms
  512 reps, 16 threads    466.27 ms      367.01 ms       264.87 ms      53.56 ms
  1024 reps, 16 threads   926.66 ms      733.06 ms       547.62 ms      89.19 ms

The impact is most visible when fusion is enabled. On a small (1, 8, 8, 8) shape across thread counts from 1 to 16, the 0.21.0 fusion times sit dramatically below their 0.20.1 counterparts: overhead drops by an average of 5.4×, peaking at 8.2× on the heaviest 16-thread workload (89 ms vs. 733 ms). The base Vulkan runtime also sees a consistent 3.4× average improvement. There are no regressions across the configurations we measured.

Burn Dispatch

A new burn-dispatch crate introduces a global backend dispatch mechanism. This simplifies backend selection and is a major step toward removing the Backend generic from user-facing APIs. Instead of having backends as a generic on the user side, we expose them through the device:

let device = DispatchDevice::Cuda(CudaDevice::Cuda(0));
let device = DispatchDevice::Wgpu(WgpuDevice::DiscreteGpu(0));
let device = DispatchDevice::LibTorch(LibTorchDevice::Cpu);

Feature flags remain, but from the user's point of view they only enable additional variants on the device enum, allowing easy runtime device (i.e., backend) selection.
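Because the backend is now just an enum variant, picking it at runtime is ordinary control flow. A small sketch, reusing only the DispatchDevice variants shown above (the surrounding function is illustrative):

// Pick a backend at runtime based on a flag, no generics involved.
fn select_device(prefer_gpu: bool) -> DispatchDevice {
    if prefer_gpu {
        DispatchDevice::Wgpu(WgpuDevice::DiscreteGpu(0))
    } else {
        DispatchDevice::LibTorch(LibTorchDevice::Cpu)
    }
}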

Early benchmarks suggest no performance regression from this mechanism, since only a static enum dispatch is performed for tensor operations. Combined with the new device handle mechanism, those static dispatches are executed on the user's thread, so they don't slow down fusion or the CubeCL runtime.

The biggest advantage of this dispatch approach over the current generic-backend setup is future compilation improvements. With careful dependency handling and type obfuscation, we can drive incremental compile times for Burn projects way down. An early experiment reduced incremental release compilation time from 11 seconds to under 1 second. This isn't a promise, but it's very likely that Burn will soon compile fast enough for rapid experimentation with minimal friction.

Burn Flex: A Lightweight Eager CPU Backend

This release introduces burn-flex, a new CPU backend implemented entirely in Rust. It replaces burn-ndarray, which is now on a deprecation path.

Unlike the CubeCL-based backends, burn-flex is intentionally simple: eager execution only, with no fusion, no autotune, and no burn.toml integration. The trade-off is deliberate. When you're targeting WebAssembly, an embedded device, or just running a small model, you usually don't want a JIT, kernel cache, or compute scheduler in the loop; you want a small, predictable, dependency-light CPU backend that runs the math and gets out of the way. That's the niche burn-flex fills.

For higher-performance CPU and GPU workloads, the CubeCL-based backends remain the right choice and continue to be where most of the kernel work goes. burn-ndarray still works in 0.21 to give existing projects a transition window, but new development should target burn-flex. We'll keep burn-ndarray around for one or two more releases and then remove it.
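Model code written against B: Backend stays untouched by the swap; only the backend type alias plugged in at the edge of the program changes. A minimal backend-agnostic sketch (the concrete burn-flex backend type isn't named in this post, so it is omitted here too):

use burn::prelude::*;

// Backend-agnostic code: nothing here changes when moving from burn-ndarray to
// burn-flex; only the concrete type you substitute for `B` at the call site does.
fn smoke_test<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
    Tensor::<B, 2>::ones([2, 2], device).mul_scalar(3.0)
}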

Early Reinforcement Learning Support

Since the last release, we've been working on generalizing the training utilities in burn-train while still allowing customization. As part of that, we added a new off-policy training loop that follows the same pattern as the supervised training loop in burn-train. Here's a simple Deep Q-Network example in Burn:

pub fn run<B: AutodiffBackend>(device: B::Device) {
    let dqn_config = DqnAgentConfig {
        gamma: 0.99,
        learning_rate: 3e-4,
        tau: 0.005,
        epsilon_start: 0.99,
        epsilon_end: 0.05,
        epsilon_decay: 6000.0,
    };
    let model_config = MlpNetConfig {
        num_layers: 3,
        dropout: 0.0,
        d_input: 4,
        d_output: 2,
        d_hidden: 64,
    };
    let learning_config = OffPolicyConfig {
        num_envs: 8,
        autobatch_size: 8,
        replay_buffer_size: 50_000,
        train_interval: 8,
        eval_interval: 4_000,
        eval_episodes: 5,
        train_batch_size: 128,
        train_steps: 4,
        warmup_steps: 0,
    };

    let policy_model = MlpNet::<B>::new(&model_config, &device);
    let optimizer = AdamWConfig::new()
        .with_grad_clipping(Some(GradientClippingConfig::Value(100.0)))
        .init();
    let agent = DqnLearningAgent::new(policy_model, optimizer, dqn_config);
    let learner = RLTraining::new(ARTIFACT_DIR, CartPoleWrapper::new)
        .metrics_train((LossMetric::new(),))
        .metrics_agent((ExplorationRateMetric::new(),))
        .metrics_episode((EpisodeLengthMetric::new(), CumulativeRewardMetric::new()))
        .with_file_checkpointer(CompactRecorder::new())
        .num_steps(40_000)
        .with_learning_strategy(burn::train::RLStrategies::OffPolicyStrategy(
            learning_config,
        ))
        .summary();

    let _result = learner.launch(agent);
}

Users only have to implement a few traits; all the orchestration is handled by the training loop.

Kernel Optimizations

As always, we're constantly working on the performance of our kernels[4], and this release is no exception. This time we focused on GEMV, top-k and FFT.

GEMV performance was significantly lacking; many cases are now fixed and reach state-of-the-art performance across all CubeCL runtimes. As the benchmarks below show, on a (4096, 4096) column-major problem we match LibTorch on CUDA and beat it on CPU. The exception is row-major GEMV, which isn't yet optimal on our side; we focused mostly on column-major, since that's the more efficient layout and the one that should be prioritized.

GEMV: Execution Time (Lower is Better), shape (1, 1, 4096) × (4096, 4096):

  Layout         Device   CubeCL      LibTorch
  Column-major   CUDA     174 µs      171 µs
  Column-major   CPU      1.451 ms    1.626 ms
  Row-major      CUDA     231 µs      177 µs
  Row-major      CPU      2.749 ms    1.692 ms
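From the user side, the GEMV case in the table above is just a matrix multiplication with a single-row left-hand side; whether the backend routes it to a dedicated GEMV kernel is an internal decision. A minimal sketch using only the public Tensor API:

use burn::prelude::*;
use burn::tensor::Distribution;

// The benchmarked shape, (1, 1, 4096) × (4096, 4096), expressed as a plain matmul.
// Kernel selection (GEMV vs. general matmul) happens inside the backend.
fn gemv_example<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
    let lhs = Tensor::<B, 3>::random([1, 1, 4096], Distribution::Default, device);
    let rhs = Tensor::<B, 3>::random([1, 4096, 4096], Distribution::Default, device);
    lhs.matmul(rhs) // output shape: (1, 1, 4096)
}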

Top-k was also a focus this release. As usual with our kernels, we implemented a general kernel scheme rather than a fixed set of kernels for a few common configurations. As the benchmarks below show, we are routinely an order of magnitude faster than LibTorch on CPU, and sometimes more than 40× faster, across a range of k, axis, and shape combinations.

Top-k on CPU: Execution Time (Lower is Better), argtopk(k) on a given axis:

  k    Axis   Shape               CubeCL (CPU)   LibTorch (CPU)   Speedup
  1    0      (16, 1024, 1024)    645 µs         26.94 ms         41.8×
  1    0      (2, 6144, 6144)     7.44 ms        115.36 ms        15.5×
  3    0      (16, 1024, 1024)    1.72 ms        30.05 ms         17.5×
  5    0      (256, 256, 256)     1.87 ms        29.21 ms         15.7×
  5    1      (16, 1024, 1024)    1.99 ms        17.60 ms         8.9×
  10   0      (256, 256, 256)     3.67 ms        30.82 ms         8.4×
  1    2      (16, 1024, 1024)    957 µs         2.96 ms          3.1×

  Where LibTorch's specialized kernel wins (axis=2, large k):
  10   2      (256, 256, 256)     27.72 ms       20.45 ms         LibTorch 1.36× faster
  10   2      (2, 6144, 6144)     24.27 ms       17.55 ms         LibTorch 1.38× faster

That said, the picture isn't uniformly in our favor on axis=2. For small k we still come out ahead (~3× at k=1), but at k=10 LibTorch wins by around 1.36-1.38×. This isn't a kernel we're missing: LibTorch ships a very fast specialized kernel for exactly this configuration, while ours is a single general scheme that covers every layout. The flip side is the broader story you can read off the rest of the table: LibTorch's performance is highly fragmented between settings, while ours stays predictable and consistent across the whole space. We expect to claw back that ~30% by tuning the last-axis path of our existing kernel; no new specialized kernel is required.
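For reference, the operation being benchmarked maps onto the Tensor top-k calls; the "argtopk" column above corresponds to the indices returned by topk_with_indices. A short sketch (the benchmark harness itself isn't shown here):

use burn::prelude::*;
use burn::tensor::Distribution;

// Top-k along a chosen axis; k and dim mirror the benchmark parameters above.
fn topk_example<B: Backend>(device: &B::Device) {
    let x = Tensor::<B, 3>::random([16, 1024, 1024], Distribution::Default, device);
    let (values, indices) = x.topk_with_indices(5, /* dim */ 1);
    println!("{:?} {:?}", values.dims(), indices.dims()); // [16, 5, 1024] for dim = 1
}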

Attention Progress

Work on flash attention continued as well, but we ran into challenges handling the complexity of compute-bound kernels. That's why we started a fairly ambitious refactoring around tile abstractions inside CubeK. The new abstraction will be more fine-grained than other tile DSLs like TileLang and Triton, while bringing significant composability and ease of use for writing complex compute-bound, matrix-oriented programs.

Early Complex Number Support

We also introduced the highly requested fast Fourier transform and inverse fast Fourier transform in this release. This is the first step toward supporting complex tensors in Burn.

What's Next

Several items in this release are explicit setup for what lands next. The Backend trait will disappear from the user API: burn-dispatch is the foundation, and the next release puts it in front of users so models, optimizers, and training code stop carrying a B: Backend generic everywhere. We also want tighter integration between CubeCL, CubeK, and Burn so that extending the framework with your own custom kernels, reusing the CubeK building blocks, feels like a first-class workflow instead of duct-taping multiple crates together.

As always, we'd love to hear your feedback. Try the release, file issues, and let us know what you'd like to see next.

Migration Guide

burn-dataset cache directory

To respect platform conventions, downloaded dataset artifacts no longer use a hardcoded ~/.cache directory root; the cache directory now follows the platform standard listed below. For Linux users without $XDG_CACHE_HOME configured, this change has no effect: the cache directory remains ~/.cache.

  • Linux: $XDG_CACHE_HOME or $HOME/.cache
  • macOS: $HOME/Library/Caches
  • Windows: {FOLDERID_LocalAppData}

Interface Changes

TensorData

The shape field now stores a Shape instead of a Vec<usize>. Existing binary records using BinFileRecorder or BinBytesRecorder are not forward-compatible and must be converted before upgrading.

static STATE_ENCODED: &[u8] = include_bytes!("model.bin");

let model: Model<B> = Model::new(&Default::default());

// Old format can still be loaded before upgrade, but must be re-saved
// in a forward-compatible format.
let record = BinBytesRecorder::<FullPrecisionSettings, &'static [u8]>::default()
    .load(STATE_ENCODED, &Default::default())
    .expect("Failed to decode state");

let model = model.load_record(record);

model
    .save_file(
        "model.mpk",
        &NamedMpkFileRecorder::<FullPrecisionSettings>::new(),
    )
    .unwrap();

Module

The module derive macro has been improved, and the Ignored<T> wrapper is now deprecated. For fields that should not be considered modules, use #[module(skip)] instead.

#[derive(Module, Debug)]
pub struct Conv1d<B: Backend> {
-    pub padding: Ignored<PaddingConfig1d>,
+    #[module(skip)]
+    pub padding: PaddingConfig1d,
}

PaddingConfig1d/2d

We added support for explicit asymmetric padding. If you were using explicit padding, specify the same value for each side of a dimension to keep the previous symmetric behavior. Note that PaddingConfig3d does not support asymmetric padding yet.

// Symmetric (left, right)
- PaddingConfig1d::Explicit(1)
+ PaddingConfig1d::Explicit(1, 1)

// Symmetric (top, left, bottom, right)
- PaddingConfig2d::Explicit(1, 1)
+ PaddingConfig2d::Explicit(1, 1, 1, 1)

Gelu

The Gelu activation module can now be configured with tanh approximation. This only affects code that instantiated Gelu directly.

- let activation = Gelu;
+ let activation = Gelu::new(); // or Gelu::default()

The position-wise feed-forward module now has a configurable activation function. To keep it backwards compatible with previously saved records, the field is marked as #[module(skip)].

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    // ...
-   /// GELU activation function.
-   pub gelu: Gelu,
+   /// Activation function.
+   #[module(skip)]
+   pub activation: Activation<B>,
}

Shape

The Shape fields are now private and some methods have been renamed. ShapeError has been renamed to MetadataError.

- let b = tensor.shape().dims[0];
+ let b = tensor.shape()[0];

- if let Err(ShapeError::RankMismatch{...}) = lhs.broadcast(&rhs) {
+ if let Err(MetadataError::RankMismatch{...}) = lhs.broadcast(&rhs) {

- let shape = shape.swap(1, 2).unwrap();
+ let shape = shape.swapped(1, 2).unwrap();

- let shape = shape.permute(&[0, 2, 1, 3]).unwrap();
+ let shape = shape.permuted(&[0, 2, 1, 3]).unwrap();

DType

The boolean data type was expanded to include its storage type.

match bool_tensor.dtype() {
-   DType::Bool => todo!(),
+   DType::Bool(BoolStore::Native) => todo!(),
+   DType::Bool(BoolStore::U8) => todo!(),
+   DType::Bool(BoolStore::U32) => todo!(),
    _ => unreachable!(),
}

powf

powf is no longer supported for Int tensors, as it previously relied on incorrect implicit truncation. These operations are now only available for Float tensors.

- let tensor_i = tensor_int.powf(tensor_float);
+ let tensor_f = tensor_int.float().powf(tensor_float);

- let tensor_i = tensor_int.powf_scalar(scalar_float);
+ let tensor_f = tensor_int.float().powf_scalar(scalar_float);

Backend

Backend tensor creation and conversion ops now take an explicit output dtype. This removes backend-specific dtype inference and ensures consistent behavior across backends. (Backend implementors only.)

impl BoolTensorOps<Self> for MyBackend {
-    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
+    fn bool_empty(
+        shape: Shape,
+        device: &Device<Self>,
+        dtype: BoolDType,
+    ) -> BoolTensor<Self> {
        // use `dtype` instead of inferring internally
    }

-    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
+    fn bool_into_int(
+        tensor: BoolTensor<Self>,
+        out_dtype: IntDType,
+    ) -> IntTensor<Self> {
        // use `out_dtype` instead of inferring internally
    }
}

Associated types were moved from Backend to BackendTypes. Prefer the type aliases (Device<B>, FloatTensor<B>, etc.) to avoid type resolution issues.

impl BoolTensorOps<Self> for MyBackend {
-    fn bool_empty(
-        shape: Shape,
-        device: &<Self as Backend>::Device,
-        dtype: BoolDType,
-    ) -> <Self as Backend>::BoolTensorPrimitive {
+    fn bool_empty(
+        shape: Shape,
+        device: &Device<Self>,
+        dtype: BoolDType,
+    ) -> BoolTensor<Self> {
    }
}

References

[1] GitHub Release Page
[2] CubeCL: Multi-platform high-performance compute language extension for Rust.
[3] A Faster Channel for Lazy Task Execution
[4] CubeK: High-performance multi-platform kernels in CubeCL
