Burn 0.20.0 Release: Unifying CPU & GPU kernels with CubeCL

It's been an intense few months of development, and we're ready to release Burn 0.20. This version marks a major turning point for the ecosystem with the introduction of CubeK. Our goal was to solve a classic challenge in deep learning: achieving peak performance on diverse hardware without maintaining fragmented codebases. By unifying CPU and GPU kernels through CubeCL, we've managed to squeeze maximum efficiency out of everything from NVIDIA Blackwell GPUs to standard consumer CPUs. Beyond performance, this release makes the library more robust, flexible, and significantly easier to debug.
Performance
The core of 0.20 is about achieving better performance across all hardware. Most of these gains come from key optimizations within CubeCL and how we handle kernel execution.
CubeK: High-performance multi-platform kernels
The biggest internal change is the introduction of CubeK. This project establishes strict kernel architecture guidelines aimed at minimizing CPU launch latency. Previously, hardware specialization and error handling occurred at every launch; now, we've shifted that work to the just-in-time (JIT) compilation phase. By caching these specializations the first time a kernel is compiled, we significantly reduce the overhead required for the CPU to launch complex kernels.
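As a mental model, the launch path now amounts to a cache lookup. The sketch below is illustrative only (KernelKey, CompiledKernel, and get_or_compile are made-up names, not CubeK's actual types): specialization and validation run once at JIT compile time, and later launches just hash the key.

use std::collections::HashMap;

// Illustrative only: everything that used to be recomputed on every launch
// (element type, line size, tile shape, validation) is folded into one key.
#[derive(Clone, PartialEq, Eq, Hash)]
struct KernelKey {
    dtype: &'static str,
    line_size: u8,
    tile: (u32, u32),
}

struct CompiledKernel; // stand-in for a JIT-compiled, specialized kernel

#[derive(Default)]
struct KernelCache {
    compiled: HashMap<KernelKey, CompiledKernel>,
}

impl KernelCache {
    // Specialization and error checking happen here, once per unique key;
    // every subsequent launch with the same key is a cheap lookup.
    fn get_or_compile(&mut self, key: KernelKey) -> &CompiledKernel {
        self.compiled
            .entry(key)
            .or_insert_with(|| CompiledKernel /* run specialization + codegen */)
    }
}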
This shift was primarily motivated by our new Flash Attention implementation and a reworked reduction engine. While we are still tuning the low-level optimizations for full tensor core saturation, the engine already generates functionally correct kernels for both GPUs and CPUs. This infrastructure allows us to generate optimal Flash Attention kernels by automatically selecting the best instructions and tiling dimensions for whichever backend is being used.
We also refactored the CubeCL backend to support dynamic data types with compile-time information, eliminating the heavy macro-based patterns previously used for generic numerical types. For users, this means a cleaner codebase, smaller binaries, and faster compilation times.
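Roughly speaking, the element type is now a value the JIT compiler reads when specializing a kernel, rather than something baked in through macro-generated impls per numeric type. A simplified illustration (ElemType, TensorDesc, and kernel_source are hypothetical names, not CubeCL's real API):

// One code path covers all numeric types; the compiler specializes on the
// element type at JIT time instead of a macro expanding an impl per type.
#[derive(Clone, Copy)]
enum ElemType {
    F32,
    F16,
    I32,
}

struct TensorDesc {
    elem: ElemType,    // known at kernel-compile time, not at Rust-compile time
    shape: Vec<usize>,
}

fn kernel_source(desc: &TensorDesc) -> String {
    let ty = match desc.elem {
        ElemType::F32 => "f32",
        ElemType::F16 => "f16",
        ElemType::I32 => "i32",
    };
    format!("// generated kernel over {ty}, rank {}", desc.shape.len())
}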
CubeCL CPU Overhaul
The CubeCL CPU backend received a major update. It now features proper lazy execution and the same multi-stream support as our WGPU runtime. By focusing on cache line alignment and memory coalescing, our kernels are now outperforming established libraries like LibTorch in several benchmarks. We've also added support for kernel fusion, which was a missing piece in our previous CPU backends.
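To make the fusion point concrete, here is a simplified scalar illustration in plain Rust (not CubeCL code): the unfused version of relu(a + b) materializes an intermediate buffer and reads it back, while the fused version computes the whole expression in one pass over the data.

// Unfused: allocates and writes an intermediate, then reads it again,
// roughly doubling memory traffic.
fn unfused(a: &[f32], b: &[f32]) -> Vec<f32> {
    let tmp: Vec<f32> = a.iter().zip(b).map(|(x, y)| x + y).collect();
    tmp.iter().map(|v| v.max(0.0)).collect()
}

// Fused: a single loop, no intermediate allocation; values stay in registers.
fn fused(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(x, y)| (x + y).max(0.0)).collect()
}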
As seen in the benchmarks, CubeCL achieves up to a 4x speedup over LibTorch, with even larger margins compared to SIMD-enabled ndarray. Importantly, these gains did not require changes to the max_pool2d logic itself; they stem from a launch strategy that better utilizes the CPU's architecture.
CubeCL kernels are designed to adapt their computation based on the line size provided at launch. This is the real win here. By selecting the optimal line size, cube dimensions, and cube counts specifically for the CPU, we can control exactly how threads map to data without touching the kernel code. We increased the line size to ensure better SIMD vectorization and tuned the cube settings so that data ranges respect physical cache line boundaries. This coordination automatically eliminates cache contention, preventing multiple cores from fighting over the same memory segments, and keeps the underlying logic fully portable and optimal across both GPU and CPU.
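As a rough illustration of the arithmetic involved (the helper below is hypothetical, not Burn's actual tuning code): with 4-byte f32 elements, a 64-byte cache line holds 16 values, so choosing a line size that matches the SIMD width and rounding each cube's data range to whole cache lines keeps every core on its own cache lines.

// Hypothetical launch-tuning sketch for a CPU runtime.
const CACHE_LINE_BYTES: usize = 64;

fn cpu_launch_params(
    num_elems: usize,
    elem_bytes: usize, // 4 for f32
    simd_lanes: usize, // e.g. 8 for AVX2 with f32
    cores: usize,
) -> (usize, usize, usize) {
    // Line size drives SIMD vectorization inside the kernel.
    let line_size = simd_lanes;
    // A 64-byte cache line holds 16 f32 values.
    let elems_per_cache_line = CACHE_LINE_BYTES / elem_bytes;
    // Round each cube's contiguous range to whole cache lines so two cores
    // never write into the same line (no contention or false sharing).
    let per_core = num_elems.div_ceil(cores);
    let elems_per_cube = per_core.div_ceil(elems_per_cache_line) * elems_per_cache_line;
    let cube_count = num_elems.div_ceil(elems_per_cube);
    (line_size, elems_per_cube, cube_count)
}

The kernel code itself never changes; only these launch-time numbers differ between the CPU and GPU runtimes.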
The reduction benchmarks show the impact of the CubeK architecture more clearly. In operations like reduce-argmin on Axis 0, CubeCL is significantly faster than both ndarray and LibTorch. These results come from our reworked reduction engine, which aligns memory access with hardware capabilities automatically. While libraries like LibTorch might have highly optimized kernels for common ops like sum, less common operations like argmin often don't get the same attention. Our engine ensures high performance across all reduction types by design.
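For intuition on why access patterns dominate here, consider argmin over axis 0 of a row-major matrix: walking rows in the outer loop and columns in the inner loop keeps every load contiguous, which is exactly the kind of alignment the engine selects automatically. A simplified scalar sketch (not the generated kernel):

// Argmin over axis 0 of a row-major `rows x cols` matrix. The inner loop
// reads contiguous memory, so it vectorizes and stays within cache lines
// instead of striding down a column.
fn argmin_axis0(data: &[f32], rows: usize, cols: usize) -> Vec<usize> {
    let mut best_val = vec![f32::INFINITY; cols];
    let mut best_idx = vec![0usize; cols];
    for r in 0..rows {
        let row = &data[r * cols..(r + 1) * cols];
        for (c, &v) in row.iter().enumerate() {
            if v < best_val[c] {
                best_val[c] = v;
                best_idx[c] = r;
            }
        }
    }
    best_idx
}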
While these results are promising, the CPU backend is still evolving. It isn't fully optimized across every operator yet; in particular, convolution and matrix multiplication need more work before we recommend it as a primary target for CPU production environments.
Next-Gen GPU Support
For high-end GPUs, this release adds support for the Tensor Memory Accelerator (TMA) and inlined PTX for manual Matrix-Multiply Accumulate (MMA) instructions. These changes bring CubeCL closer to the theoretical peak performance of modern silicon. We've adapted our matmul engine to combine TMA with warp specialization, specifically targeting Blackwell-based hardware like the RTX 5090. These foundational improvements also benefit NVIDIA's Ada and Hopper architectures.
On top of these improvements, we optimized the batched vector-matrix product, a pattern common in LLM inference, to leverage tensor cores through a clever problem transformation that lets it benefit from our matrix multiplication optimizations.
The resulting benchmarks show our kernels reaching near state-of-the-art performance, matching the industry-standard CUTLASS and cuBLAS libraries found in LibTorch.
Burn Train: Generalizing the Learning Paradigm
Training isn't just about supervised learning anymore. We've redesigned our core abstractions to support a wider range of learning paradigms by decoupling the learner from the feedback providers. This makes the training loop much more extensible for custom research or production needs without making the initial setup more difficult.
let training = SupervisedTraining::new(ARTIFACT_DIR, dataloader_train, dataloader_valid)
    .metrics((AccuracyMetric::new(), LossMetric::new()))
    .num_epochs(config.num_epochs)
    .with_learning_strategy(TrainingStrategy::SingleDevice(device));

let result = training.launch(Learner::new(
    model,
    config.optimizer.init(),
    lr_scheduler.init().unwrap(),
));

This new infrastructure will allow us to add official support for reinforcement learning in the next release, followed potentially by evolutionary optimization, all under the same unified architecture.
General Improvements
Enhanced Error Handling
Lazy and deferred execution are great for performance, but they have historically made debugging a pain. In 0.20, we've introduced Result-based error propagation. Synchronizing a device now returns a Result<(), Error>, allowing you to catch issues like out-of-memory (OOM) or compilation errors gracefully. We've extended this model to new methods like tensor.try_into_data(), tensor.try_into_scalar(), and transaction.try_execute().
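For example, a readback that used to panic deep inside the lazy execution graph can now be handled explicitly. A minimal sketch, assuming the WGPU backend (the exact error type depends on the backend):

use burn::backend::Wgpu;
use burn::tensor::Tensor;

fn main() {
    type B = Wgpu;
    let device = Default::default();
    let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);

    // OOM or kernel compilation failures surface as an Err instead of a panic.
    match tensor.try_into_data() {
        Ok(data) => println!("readback succeeded: {data:?}"),
        Err(err) => eprintln!("readback failed: {err:?}"),
    }
}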
ONNX and Burn-Import
The ONNX importer has been overhauled with a node-centric architecture and full support for the new burnpack format. This enables zero-copy loading using memory-mapped tensor references, making the process of loading large models significantly faster and more memory-efficient.
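The zero-copy idea itself is simple: map the file into memory once, then hand out tensor views over byte ranges inside the mapping instead of copying them into fresh buffers. An illustrative sketch using the memmap2 crate (not the actual burnpack loader; the file name and offsets are made up):

use memmap2::Mmap;
use std::fs::File;

fn main() -> std::io::Result<()> {
    // Map the model file; the OS pages data in lazily as it is touched.
    let file = File::open("model.burnpack")?;
    let mmap = unsafe { Mmap::map(&file)? };

    // "Loading" a tensor becomes a bounds-checked slice into the mapping,
    // using an offset and length recorded in the file's header.
    let (offset, len) = (0usize, mmap.len().min(4096));
    let weight_bytes: &[u8] = &mmap[offset..offset + len];
    println!("referencing {} bytes without copying", weight_bytes.len());
    Ok(())
}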
Migration Guide
Training
We refactored burn-train to better support different abstractions and custom training strategies. As part of this, the LearnerBuilder has been replaced by the LearningParadigm flow:
- let learner = LearnerBuilder::new(ARTIFACT_DIR)
+ let training = SupervisedTraining::new(ARTIFACT_DIR, dataloader_train, dataloader_valid)
.metrics((AccuracyMetric::new(), LossMetric::new()))
.num_epochs(config.num_epochs)
- .learning_strategy(burn::train::LearningStrategy::SingleDevice(device))
- .build(model, config.optimizer.init(), lr_scheduler.init().unwrap());
+ .summary();
- let result = learner.fit(dataloader_train, dataloader_valid);
+ let result = training.launch(Learner::new(
+ model,
+ config.optimizer.init(),
+ lr_scheduler.init().unwrap(),
+ ));

Interface Changes
The scatter and select_assign operations now require an IndexingUpdateOp to specify the update behavior.
- let output = tensor.scatter(0, indices, values);
+ let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add);
API calls for slice, slice_assign, and slice_fill no longer require const generics for dimensions, which cleans up the syntax quite a bit:
- let prev_slice = tensor.slice::<[Range<usize>; D]>(slices.try_into().unwrap());
+ let prev_slice = tensor.slice(slices.as_slice());
The grid_sample_2d operation now supports different options. To preserve the previous behavior, make sure to specify the matching options:
- let output = tensor.grid_sample_2d(grid, InterpolateMode::Bilinear);
+ let options = GridSampleOptions::new(InterpolateMode::Bilinear)
+ .with_padding_mode(GridSamplePaddingMode::Border)
+ .with_align_corners(true);
+ let output = tensor.grid_sample_2d(grid, options);
The QuantStore variants used in QuantScheme have been updated to support a packing dimension.
pub enum QuantStore {
/// Native quantization doesn't require packing and unpacking.
Native,
+ /// Store packed quantized values in a natively supported packing format (i.e. e2m1x2).
+ PackedNative(usize),
/// Store packed quantized values in a 4-byte unsigned integer.
- U32,
+ PackedU32(usize),
}
Finally, Shape no longer implements IntoIterator. If you need to iterate by-value over dimensions, access the dims field directly:
- for s in shape {
+ for s in shape.dims {