Reduced Memory Usage: Burn's Rusty Approach to Tensor Handling
Introduction
The latest release of Burn [1] includes significant changes to its memory management strategy. One of the most notable changes is that tensor-allocated memory can now be reused much more often. This is a big improvement: previously, every operation was implemented using read-only tensor references, which often resulted in unnecessary memory allocations. Overall, these changes significantly reduced memory usage. Compared to PyTorch, the gains are most visible on the CPU.
The new approach motivated a complete rewrite of Burn's automatic differentiation engine, which allows any backend to support backpropagation. The previous implementation relied on object-oriented programming patterns, resulting in excessive indirection and memory allocations. Moreover, the engine was read-only, which limited its ability to free unused tensors or reuse them during the backward pass. The new implementation addresses these limitations by adopting more efficient, rusty patterns, leveraging the updated backend API, and enabling in-place operations.
This is mostly a technical blog post. If you're not that interested in memory management, feel free to skip to the Benchmarks section.
In-place Operations
How does Burn enable the reuse of tensor-allocated memory? One way is through a pattern similar to copy-on-write [2], where data can only be safely written when a single reference points to it. If there are multiple references, the original data is copied before being written, hence the name copy-on-write. Burn differs in one respect: when multiple references point to a tensor, it does not copy the data first. Instead, it calls the normal (not in-place) operation, which writes its result to a new allocation. Here's an example of how the log function is implemented in the tch backend [3] (bindings to LibTorch).
```rust
fn log<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
    tensor.unary_ops(
        // When the tensor is safe to mutate in-place.
        |mut tensor| tensor.log_(),
        // When the tensor is not safe to mutate in-place.
        |tensor| tensor.log(),
    )
}
```
In this example, we call the in-place log function when possible, and the normal function otherwise. Note that in PyTorch/LibTorch, all functions whose names end with an underscore are in-place operations [4]. The unary_ops function abstracts away how references are counted, using Atomic Reference Counting (Arc) from the standard library [5]. Arc is also available through the alloc crate when working with no_std, which is useful in environments without an operating system [6,7]. Its non-thread-safe variant, Reference Counting (Rc), provides the same functionality [8].
Normally, this pattern is used with the try_unwrap and get_mut functions provided by Arc and Rc, which safely return an owned value or a mutable reference to the inner value, respectively, only when that value is not shared [9,10]. However, the PyTorch bindings are not really safe: you can call mutating operations on any tensor handle, even through a read-only reference, simply by making a shallow copy. Hence, we track storage sharing manually instead, with an Arc shared by every tensor that points to the same storage.
```rust
pub fn unary_ops<FOwn, FRef, EOut: tch::kind::Element, const D_OUT: usize>(
    self,
    fown: FOwn,
    fref: FRef,
) -> TchTensor<EOut, D_OUT>
where
    FOwn: Fn(tch::Tensor) -> tch::Tensor,
    FRef: Fn(&tch::Tensor) -> tch::Tensor,
{
    // Check whether multiple tensors point to the same storage.
    if Arc::strong_count(&self.storage) > 1 {
        // If so, the non-in-place function is called.
        return TchTensor::from_existing(fref(&self.tensor), self.storage);
    }

    // Only the current tensor points to the provided storage space.
    // Since the tensor will never be reused, we can safely call the owned
    // function, which may dispatch to an in-place operation.
    TchTensor::from_existing(fown(self.tensor), self.storage)
}
```
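For comparison, here is roughly what the safer try_unwrap route mentioned above looks like when Rust itself owns the buffer, so no manual storage tracking is needed. This is a minimal sketch, not Burn code: the `Buffer` alias and this `unary_ops` helper are hypothetical, and a tensor is reduced to an `Arc<Vec<f32>>` on the CPU.

```rust
use std::sync::Arc;

/// Hypothetical CPU tensor: a shared, reference-counted buffer.
type Buffer = Arc<Vec<f32>>;

/// Same owned/borrowed dispatch as `unary_ops`, but `Arc::try_unwrap`
/// only hands out the Vec when this is the last strong reference.
fn unary_ops<FOwn, FRef>(tensor: Buffer, fown: FOwn, fref: FRef) -> Buffer
where
    FOwn: FnOnce(Vec<f32>) -> Vec<f32>,
    FRef: FnOnce(&[f32]) -> Vec<f32>,
{
    match Arc::try_unwrap(tensor) {
        // Sole owner: we get the Vec back and may mutate it in place.
        Ok(data) => Arc::new(fown(data)),
        // Other handles exist: read through the shared reference and allocate.
        Err(shared) => Arc::new(fref(&shared)),
    }
}

fn log(tensor: Buffer) -> Buffer {
    unary_ops(
        tensor,
        |mut data| {
            data.iter_mut().for_each(|x| *x = x.ln());
            data
        },
        |data| data.iter().map(|x| x.ln()).collect(),
    )
}

fn main() {
    // Unique handle: the output reuses the input's heap buffer.
    let a: Buffer = Arc::new(vec![1.0, std::f32::consts::E]);
    let ptr_a = a.as_ptr();
    let out = log(a);
    println!("reused: {}", out.as_ptr() == ptr_a); // reused: true

    // Shared handle: `b` must stay intact, so a new buffer is allocated.
    let b: Buffer = Arc::new(vec![1.0, 1.0]);
    let out = log(b.clone());
    println!("reused: {}, b = {:?}", out.as_ptr() == b.as_ptr(), b); // reused: false, b = [1.0, 1.0]
}
```

The compiler enforces the uniqueness check here, which is exactly what the tch bindings cannot offer because a shallow copy can mutate shared storage behind Rust's back.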
This strategy can also be used for binary operations, enabling other kinds of optimizations. In some cases, depending on how many references point to each input tensor, you may call a different LibTorch function in order to reuse as much tensor-allocated memory as possible. Here's an example with the lower implementation.
```rust
pub fn lower<const D: usize>(
    lhs: TchTensor<E, D>,
    rhs: TchTensor<E, D>,
) -> TchTensor<bool, D> {
    TchTensor::binary_ops_tensor(
        lhs,
        rhs,
        // When the lhs tensor is safe to mutate.
        |lhs, rhs| lhs.less_tensor_(rhs).to_kind(tch::Kind::Bool),
        // When the rhs tensor is safe to mutate, but not lhs.
        |lhs, rhs| rhs.greater_tensor_(lhs).to_kind(tch::Kind::Bool),
        // When neither tensor is safe to mutate.
        |lhs, rhs| lhs.less_tensor(rhs),
    )
}
```
In this case, when it is safe to mutate the left-hand side (LHS) tensor, we call the in-place operation to reuse its data. However, when there is at least one other reference to that tensor, it is not safe to mutate. We can still reuse tensor-allocated memory by calling the in-place greater operation on the right-hand side (RHS) tensor instead: since lhs < rhs is equivalent to rhs > lhs, this produces the same output, but is more efficient because it reuses the RHS memory. Note that this assumes boolean tensors and float tensors can reuse the same memory space, which may depend on the float data type.
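The same three-way dispatch can be sketched without LibTorch. In this hypothetical CPU version (the `Buffer` alias and `lower` function are ours, not Burn's), the comparison result is stored as 1.0 or 0.0 in whichever input buffer is uniquely owned, mirroring how the in-place `less_tensor_` and `greater_tensor_` write into float storage.

```rust
use std::sync::Arc;

type Buffer = Arc<Vec<f32>>;

/// Hypothetical `lower` (lhs < rhs) that writes 1.0/0.0 into whichever
/// input buffer it owns exclusively, falling back to a fresh allocation.
fn lower(lhs: Buffer, rhs: Buffer) -> Buffer {
    assert_eq!(lhs.len(), rhs.len(), "shapes must match");

    // Case 1: lhs is unique, so compute `lhs < rhs` into lhs.
    let lhs = match Arc::try_unwrap(lhs) {
        Ok(mut l) => {
            for (a, b) in l.iter_mut().zip(rhs.iter()) {
                *a = if *a < *b { 1.0 } else { 0.0 };
            }
            return Arc::new(l);
        }
        Err(shared) => shared,
    };

    match Arc::try_unwrap(rhs) {
        // Case 2: only rhs is unique, so compute `rhs > lhs` into rhs.
        Ok(mut r) => {
            for (b, a) in r.iter_mut().zip(lhs.iter()) {
                *b = if *b > *a { 1.0 } else { 0.0 };
            }
            Arc::new(r)
        }
        // Case 3: both are shared, so allocate a new output buffer.
        Err(shared) => Arc::new(
            lhs.iter()
                .zip(shared.iter())
                .map(|(a, b)| if a < b { 1.0 } else { 0.0 })
                .collect::<Vec<f32>>(),
        ),
    }
}

fn main() {
    let a: Buffer = Arc::new(vec![1.0, 5.0]);
    let b: Buffer = Arc::new(vec![2.0, 3.0]);
    let keep_a = a.clone(); // another handle to lhs: it is no longer safe to mutate
    let b_ptr = b.as_ptr();
    let out = lower(a, b);
    println!("{:?}, rhs reused: {}", out, out.as_ptr() == b_ptr); // [1.0, 0.0], rhs reused: true
    println!("lhs intact: {:?}", keep_a); // lhs intact: [1.0, 5.0]
}
```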
Tensor API
Unfortunately, this pattern was previously impossible to integrate with Burn, because all operations received references to tensors as arguments rather than owned tensors. Borrowing a tensor doesn't increase its strong reference count, so a count of one never meant that the operation held the only handle: the caller still owned the tensor and could use it afterward.
To allow backends to use this pattern, the API has been updated to receive owned tensors as parameters, with some nice quality-of-life improvements as well. As a consequence, tensors are cloned exactly as many times as they are reused, which makes it easy for users to optimize their code by removing clones that aren't required. Note that clippy [11] normally checks for unnecessary cloning, but it is not perfect. You might change the order of your operations to reduce the amount of cloning, something that clippy can't check.
There is no API in Burn to call in-place operations directly on tensors. If the backend supports them, in-place operations are used automatically whenever possible. This is a major quality-of-life improvement: the places where in-place operations can be used differ between training and inference, yet Burn aims to provide the most optimized code for both use cases. Let's reuse the log function as an example.
During inference, a temporary tensor that is not a model parameter is used once by the log operation, which reuses its memory since the tensor was never cloned. During training, however, the backward step of the log function needs the input tensor to calculate its derivative, so the input tensor is cloned during the forward pass. Therefore, the input tensor is left unchanged during the forward pass, but its memory may be reused during the backward pass.
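Both scenarios can be simulated with a hypothetical `Arc<Vec<f32>>` buffer. The `log` and `log_backward` functions below are illustrative sketches, not Burn's autodiff code; they only show where the reference count allows memory reuse. The backward uses d/dx ln(x) = 1/x, so the input gradient is `grad / x`.

```rust
use std::sync::Arc;

type Buffer = Arc<Vec<f32>>;

/// Hypothetical `log` that reuses the input buffer when it is the sole owner.
fn log(t: Buffer) -> Buffer {
    match Arc::try_unwrap(t) {
        Ok(mut v) => {
            v.iter_mut().for_each(|x| *x = x.ln());
            Arc::new(v)
        }
        Err(shared) => Arc::new(shared.iter().map(|x| x.ln()).collect()),
    }
}

/// Hypothetical backward of `log`: grad_in = grad / x.
/// The saved input is consumed here, so its buffer may be reused.
fn log_backward(saved_input: Buffer, grad: &[f32]) -> Buffer {
    match Arc::try_unwrap(saved_input) {
        Ok(mut x) => {
            for (xi, g) in x.iter_mut().zip(grad) {
                *xi = g / *xi;
            }
            Arc::new(x)
        }
        Err(shared) => Arc::new(shared.iter().zip(grad).map(|(x, g)| g / x).collect()),
    }
}

fn main() {
    // Inference: the temporary is used once, so `log` runs in place.
    let x: Buffer = Arc::new(vec![1.0, 2.0]);
    let p = x.as_ptr();
    let y = log(x);
    println!("inference reused input: {}", y.as_ptr() == p); // true

    // Training: the forward pass keeps a clone of the input as state...
    let x: Buffer = Arc::new(vec![1.0, 2.0]);
    let p = x.as_ptr();
    let state = x.clone();
    let y = log(x); // ...so this allocates a new buffer.
    println!("forward reused input: {}", y.as_ptr() == p); // false

    // ...and the backward pass, holding the last handle, can reuse it.
    let grad_in = log_backward(state, &[1.0_f32, 1.0]);
    println!("backward reused input: {} -> {:?}", grad_in.as_ptr() == p, grad_in); // true -> [1.0, 0.5]
}
```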
Autodiff
Burn's previous method for calculating gradients was highly inefficient, but it offered significant flexibility. After learning from the first implementation, I decided to rewrite it from scratch. Although it seemed like a daunting task at first, I was able to complete it in less than a week without any breaking changes to the API. Burn is designed to allow for this kind of refactoring and continuous improvement of performance and architecture. The primary goal was to reduce unnecessary cloning and memory allocations while simplifying complex, hard-to-understand code patterns. Additionally, it was important to make it easy to support new operations of any kind, which presented challenges in terms of flexibility, simplicity, and minimizing repetitive code. To see how these goals play out, consider the implementation of the cosine function.
```rust
fn cos<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
    // Define a struct for static dispatch.
    #[derive(Debug)]
    struct Cos;

    impl<B: Backend, const D: usize> Backward<B, D, 1> for Cos {
        // Define the state to capture during the forward pass.
        type State = B::TensorPrimitive<D>;

        // Code that is executed during the backward pass.
        fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
            let input = ops.state;

            // Calculate the derivative with respect to its parent.
            unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                let value = B::neg(B::sin(input));
                B::mul(grad, value)
            });
        }
    }

    // Prepare a stateful operation.
    match Cos.prepare([tensor.node], [tensor.graph]).statefull() {
        // Executes when the tensor is tracked.
        OpsKind::Tracked(prep) => {
            // Finish the preparation by capturing the state:
            // the input tensor is cloned for the backward pass.
            prep.finish(tensor.primitive.clone(), B::cos(tensor.primitive))
        }
        // Executes when the tensor is not part of the autodiff graph:
        // the cos operation is called without any cloning.
        OpsKind::UnTracked(prep) => prep.finish(B::cos(tensor.primitive)),
    }
}
```
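Written out, the closure passed to `unary` applies the chain rule to y = cos(x), where L is the loss and the incoming `grad` holds dL/dy:

```latex
\frac{\partial L}{\partial x}
  = \frac{\partial L}{\partial y} \cdot \frac{\mathrm{d}y}{\mathrm{d}x}
  = \texttt{grad} \cdot \bigl(-\sin x\bigr)
```

This matches `B::mul(grad, B::neg(B::sin(input)))`, and it shows why `input` has to be captured as state: the derivative depends on x.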
The cosine function has the same signature as in any other backend, since gradients are calculated by a backend decorator that wraps another backend. The implementation uses static dispatch via a zero-sized struct named Cos, which implements the Backward trait. During the backward pass, the derivative with respect to the parent node is calculated using the chain rule. The function supports both tracked and untracked operations; only the former requires cloning the input tensor, which is kept for the backward pass. However, some operations don't need any state during the backward pass. Let's see how it's done with the scalar addition function.
```rust
fn add_scalar<const D: usize>(
    lhs: ADTensor<B, D>,
    rhs: FloatElem<B>,
) -> ADTensor<B, D> {
    // Define a struct for static dispatch.
    #[derive(Debug)]
    struct AddScalar;

    impl<B: Backend, const D: usize> Backward<B, D, 1> for AddScalar {
        type State = ();

        // Code that is executed during the backward pass.
        fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
            unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| grad);
        }
    }

    // Simpler definition where no match is required.
    AddScalar
        .prepare([lhs.node], [lhs.graph])
        .stateless(B::add_scalar(lhs.primitive, rhs))
}
```
Similar to the cosine function, the scalar addition function also supports tracked and untracked operations. However, since scalar addition doesn't require any state during the backward pass, the implementation is simpler, and no match statements are required.
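The split between stateful and stateless operations can be illustrated with a scalar-only toy. This sketch is not Burn's API: the `Backward` trait here is a hypothetical simplification with a single input and an `f64` gradient, but it keeps the idea of one zero-sized struct per operation and an associated `State` type.

```rust
/// Hypothetical, scalar-only analogue of a backward trait: each operation
/// is a zero-sized struct, and the state it needs is an associated type.
trait Backward {
    type State;
    /// Maps the gradient of the output to the gradient of the single input.
    fn backward(&self, state: Self::State, grad: f64) -> f64;
}

#[derive(Debug)]
struct Cos;

impl Backward for Cos {
    // The input value is captured during the forward pass.
    type State = f64;
    fn backward(&self, input: f64, grad: f64) -> f64 {
        grad * -input.sin()
    }
}

#[derive(Debug)]
struct AddScalar;

impl Backward for AddScalar {
    // d(x + c)/dx = 1, so nothing needs to be saved.
    type State = ();
    fn backward(&self, _state: (), grad: f64) -> f64 {
        grad
    }
}

fn main() {
    let x = 0.5_f64;
    // Forward: y = cos(x) + 2, recording only what each backward needs.
    let cos_state = x; // stateful: keep the input
    let y = x.cos() + 2.0; // AddScalar keeps nothing
    // Backward, in reverse order, seeded with dL/dy = 1.
    let g = AddScalar.backward((), 1.0);
    let g = Cos.backward(cos_state, g);
    println!("y = {y:.4}, dy/dx = {g:.4}, expected {:.4}", -x.sin());
}
```

Using `()` as the state type makes "no state" explicit in the type system, so a stateless operation has nothing to clone.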
Benchmarks
Even though the last release focused on structural refactors that allow more optimizations and give backend implementations more control, it's still interesting to see how Burn compares to other frameworks. So let's compare it to PyTorch on simple use cases.
Disclaimer
It's important to note that Burn doesn't support fused operations, even for popular activation functions like softmax, gelu, and sigmoid. Furthermore, the derivatives of each primitive operation calculated during the backward pass are also expressed with primitive operations and are not fused. This lack of operation fusion has a significant impact on real-world performance. As a result, PyTorch is likely to be faster for common models, at least on the GPU, where the impact of operation fusion is more pronounced.
Burn's development direction is likely to differ from that of other frameworks. It's unfortunate that writing mathematical operations in a more declarative way can be less performant than using a high-level function backed by a highly optimized kernel. Ideally, Burn should be able to detect when such a kernel exists and use it automatically, without requiring any changes to the code. This is the kind of developer experience that Burn aims to provide: enabling users to write mathematical operations using primitives, while allowing backend developers to declare graphs of operations that can be fused for optimal performance. Additionally, Burn should allow users to profile their models, identify which functions take the most time, and write optimized kernels for those functions using their preferred backend, without the need to fork a framework, rewrite the model, or change programming languages. All of this should be achievable while still supporting fully dynamic graphs and custom control flow with an eager-like programming model. However, these are significant constraints, and it will require a lot of thinking and hard work to make this a reality. If you have any comments, suggestions, or recommendations regarding fused operations, we invite you to join our Discord and share your thoughts.
Now let's get into the benchmarks.
Softmax
The first benchmark is a custom implementation of the softmax activation function. For numerical stability, we will use an implementation that uses log softmax.
```rust
use burn::tensor::{backend::Backend, Tensor};

fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    log_softmax(tensor, dim).exp()
}

fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    // `tensor` is used twice, so it must be cloned explicitly.
    tensor.clone() - tensor.exp().sum_dim(dim).log()
}
```
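One caveat: as written, `log_softmax` exponentiates the raw inputs, so very large values can overflow to infinity. The standard log-sum-exp trick subtracts the maximum before exponentiating, which leaves the result mathematically unchanged. Here is a minimal sketch on a plain `f32` slice; it is not Burn's tensor API and not the code that was benchmarked.

```rust
/// log_softmax(x)_i = x_i - m - ln(sum_j exp(x_j - m)), with m = max_j x_j.
/// Subtracting m keeps every exponent <= 0, so `exp` cannot overflow.
fn log_softmax(x: &[f32]) -> Vec<f32> {
    let m = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let log_sum = x.iter().map(|v| (v - m).exp()).sum::<f32>().ln();
    x.iter().map(|v| v - m - log_sum).collect()
}

fn softmax(x: &[f32]) -> Vec<f32> {
    log_softmax(x).into_iter().map(f32::exp).collect()
}

fn main() {
    // exp(1000.0) alone would be infinite in f32; this stays finite.
    let p = softmax(&[1000.0_f32, 1000.0]);
    println!("{p:?}");
}
```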
Now, let’s compare it to the equivalent code in PyTorch.
```python
import torch
from torch import Tensor


def softmax(tensor: Tensor, dim: int) -> Tensor:
    return log_softmax(tensor, dim).exp()


def log_softmax(tensor: Tensor, dim: int) -> Tensor:
    return tensor - tensor.exp().sum(dim=dim, keepdim=True).log()
```
The main difference in the code is the extra typing in Burn, which specifies the number of dimensions of a tensor and the backend it runs on. However, we can lighten the notation by moving the generic argument declarations into a zero-sized struct that groups functions operating on tensors of the same type. The other difference is the explicit clone in the Burn version, needed because the tensor is used twice. During inference, we expect `tensor.exp()` not to be executed in-place, since `tensor.clone()` still points to the same data. However, we expect all other operations to reuse memory to avoid unnecessary allocations.
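This expectation can be checked with a toy tensor type. In this sketch, `Tensor`, its `exp` method, and the `ALLOCATIONS` counter are hypothetical stand-ins for a backend: operations take `self` by value and only allocate when `Arc::try_unwrap` fails. Because Rust evaluates the operands of `-` from left to right, the clone exists while `exp` runs, and it is the last handle by the time the subtraction runs.

```rust
use std::ops::Sub;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Counts how many new output buffers the hypothetical ops allocate.
static ALLOCATIONS: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical owned tensor: ops take `self`, like Burn's tensor API.
#[derive(Clone, Debug)]
struct Tensor(Arc<Vec<f32>>);

impl Tensor {
    fn exp(self) -> Tensor {
        match Arc::try_unwrap(self.0) {
            Ok(mut v) => {
                v.iter_mut().for_each(|x| *x = x.exp());
                Tensor(Arc::new(v))
            }
            Err(shared) => {
                ALLOCATIONS.fetch_add(1, Ordering::SeqCst);
                Tensor(Arc::new(shared.iter().map(|x| x.exp()).collect()))
            }
        }
    }
}

impl Sub for Tensor {
    type Output = Tensor;
    fn sub(self, rhs: Tensor) -> Tensor {
        match Arc::try_unwrap(self.0) {
            Ok(mut v) => {
                v.iter_mut().zip(rhs.0.iter()).for_each(|(a, b)| *a -= b);
                Tensor(Arc::new(v))
            }
            Err(shared) => {
                ALLOCATIONS.fetch_add(1, Ordering::SeqCst);
                Tensor(Arc::new(shared.iter().zip(rhs.0.iter()).map(|(a, b)| a - b).collect()))
            }
        }
    }
}

fn main() {
    let tensor = Tensor(Arc::new(vec![0.0, 1.0]));
    // `tensor.clone()` is evaluated first, so `exp` sees two handles and
    // allocates. Afterward the clone is unique, so `-` works in place.
    let out = tensor.clone() - tensor.exp();
    println!("{:?}", out);
    println!("allocations: {}", ALLOCATIONS.load(Ordering::SeqCst)); // allocations: 1
}
```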
The tests were performed on my laptop, so they may not be fully reliable, but they are still informative in terms of what kind of performance we can expect.
| | Inference Memory | Inference Speed | Autodiff Memory | Autodiff Speed |
|---|---|---|---|---|
| CPU - PyTorch | 586 M | 47.39 ms | 980 M | 111.8 ms |
| CPU - Burn | 353 M | 34.25 ms | 1047 M | 146.93 ms |
| GPU - PyTorch | 852 M | 1.474 ms | 980 M | 4.103 ms |
| GPU - Burn | 756 M | 1.479 ms | 1076 M | 5.365 ms |
An interesting takeaway is that Burn seems much faster during inference on the CPU, while being comparable on the GPU. During inference, it also takes less memory on both devices, with a bigger difference on the CPU. This may be because PyTorch has taken great care with its GPU implementation, possibly using something like a memory pool or other tricks to handle GPU memory more efficiently.
We also see that PyTorch is faster and requires less memory to compute gradients. This could be because its backward implementation of each tensor operation executes fewer kernels. In Burn, the logarithm backward implementation uses two kernels, and the sum backward allocates a new tensor while using two other kernels. To test this hypothesis, we would need another set of benchmarks that reduces the difference in the number of kernels executed by both frameworks.
Multi-layer perceptron (MLP)
The second set of benchmarks is pretty simple: we compare a multi-layer perceptron made of linear layers and the ReLU activation function. Even though Burn's ReLU backward implementation uses two kernels, the gap in the number of executed kernels is smaller than in the previous experiment, so we should see a smaller difference in execution time when calculating gradients.
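To make the kernel-count argument concrete, here is a CPU sketch of the gradient of ln(x), which is `grad / x`, computed in two passes (a reciprocal, then a multiplication, with an intermediate buffer) or in one fused pass. Which two kernels Burn actually launches is not stated here; the two-pass split is an assumption made for illustration, and both functions are hypothetical.

```rust
/// Two "kernels": each pass reads its inputs and writes a fresh buffer.
fn log_backward_two_kernels(grad: &[f32], x: &[f32]) -> Vec<f32> {
    let recip: Vec<f32> = x.iter().map(|v| 1.0 / v).collect(); // kernel 1
    grad.iter().zip(&recip).map(|(g, r)| g * r).collect() // kernel 2
}

/// One fused "kernel": a single pass and no intermediate buffer.
fn log_backward_fused(grad: &[f32], x: &[f32]) -> Vec<f32> {
    grad.iter().zip(x).map(|(g, v)| g / v).collect()
}

fn main() {
    let x = [1.0_f32, 2.0, 4.0];
    let grad = [1.0_f32, 1.0, 1.0];
    println!("{:?}", log_backward_two_kernels(&grad, &x)); // [1.0, 0.5, 0.25]
    println!("{:?}", log_backward_fused(&grad, &x)); // [1.0, 0.5, 0.25]
}
```

Both give the same values, but the two-pass version reads and writes memory twice and allocates an extra buffer, which is the kind of overhead fusion removes.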
| | Inference Memory | Inference Speed | Autodiff Memory | Autodiff Speed |
|---|---|---|---|---|
| CPU - PyTorch | 433 M | 22.765 ms | 708 M | 76.85 ms |
| CPU - Burn | 385 M | 22.695 ms | 576 M | 80.429 ms |
| GPU - PyTorch | 1190 M | 0.8474 ms | 1204 M | 3.2708 ms |
| GPU - Burn | 1096 M | 0.8042 ms | 1222 M | 2.4874 ms |
The results are pretty much what I expected, except that Burn is considerably faster on the GPU when computing gradients. This could be explained by how ReLU backward is implemented in Burn, which might be faster in this specific use case. However, I would not conclude that Burn is faster at training MLPs, and I would be really careful about drawing conclusions; it might just be an outlier data point.
For the other data points, Burn and PyTorch are generally similar in execution time, but Burn uses less memory in most cases, especially on the CPU. I didn't include the MLP implementation of either framework, but you can have a look at it in the repository.
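For readers who want a picture of the model without opening the repository, here is a minimal MLP forward pass in plain Rust. The `Linear` struct, `relu`, and `mlp_forward` are hypothetical helpers, not Burn's `nn` module or the benchmarked code, and placing ReLU between layers but not after the last one is our assumption. Note that `relu` takes its input by value, so it can overwrite the buffer in place, in the spirit of the owned-tensor API.

```rust
/// Hypothetical dense layer: y = W x + b, with W stored row-major (out x in).
struct Linear {
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl Linear {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.weight
            .iter()
            .zip(&self.bias)
            .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b)
            .collect()
    }
}

/// Takes ownership, so the activation can overwrite its input in place.
fn relu(mut x: Vec<f32>) -> Vec<f32> {
    x.iter_mut().for_each(|v| *v = v.max(0.0));
    x
}

/// Linear layers with ReLU between them (none after the last layer).
fn mlp_forward(layers: &[Linear], input: &[f32]) -> Vec<f32> {
    let mut x = input.to_vec();
    for (i, layer) in layers.iter().enumerate() {
        x = layer.forward(&x);
        if i + 1 < layers.len() {
            x = relu(x);
        }
    }
    x
}

fn main() {
    let layers = vec![
        Linear { weight: vec![vec![1.0, -1.0], vec![0.5, 0.5]], bias: vec![0.0, 0.0] },
        Linear { weight: vec![vec![1.0, 1.0]], bias: vec![1.0] },
    ];
    println!("{:?}", mlp_forward(&layers, &[1.0, 3.0])); // [3.0]
}
```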
Conclusion
So this is the end of this blog post. I presented how Burn leverages Rust's ownership tracking to safely reuse tensor-allocated memory when possible, along with the changes to the tensor API and the autodiff backend needed to leverage owned tensors and reduce memory usage. Some small benchmarks were run to validate the effectiveness of that strategy, and they showed reduced memory usage in most cases, especially on the CPU. However, we also saw the need for operation fusion to really speed up computation, which may explain why it's a major focus of PyTorch 2.0 [12] with its new graph compilation capabilities. The next phase of Burn will focus on stabilizing the API, improving the docs, and making the project easier to use overall. After that, it will be crucial to focus on operation fusion and come up with a strategy that respects all the wishes mentioned above.
Note that this is not a full report of what has been accomplished since the last release. A lot of work has been done by contributors, and Burn can now be compiled to WebAssembly for inference, running natively in the browser on the client side. You can test it yourself with the online demo. I want to thank everybody who got involved with the project; I received so much constructive feedback that has improved, or will definitely improve, Burn. It's also always interesting to see what kinds of projects, research, or products Burn can help with, so don't hesitate to reach out if you find value in what we are building.