Optimal Performance without Static Graphs by Fusing Tensor Operation Streams
Introduction
There are three things that are crucial for performance when implementing a Deep Learning framework: reducing data movement to the strict minimum required, using the hardware to the fullest with instructions specialized for the computations, and reducing the overhead the framework itself adds on top of executing tensor operations.
Minimizing data movement is actually the most sensitive of the three, since it is tightly linked to the framework architecture, as well as to model architectures and strategies like gradient checkpointing. Using the most efficient instructions of the GPU or CPU isn’t conceptually hard, but it requires a lot of human effort to implement and optimize each kernel. Finally, reducing the overhead of the framework is probably the easiest part: with lazy evaluation of kernels, the extra computation done by the framework doesn’t block the flow of tensor operations and is supposed to be minimal. While it might become the bottleneck for very small networks, that is quite unlikely with current Deep Learning workloads. This explains why current frameworks written in Python such as PyTorch[1] and TensorFlow[2] are still very fast even though Python can be hundreds of times slower than C++ or Rust. If you are interested in going deeper into the subject, there is an excellent blog post[3] that goes over the process of optimizing deep learning models.
Current approaches to minimizing memory movement normally require a static graph, where all information about a model is gathered at compile time, making it possible for the framework to optimize the graph by generating custom kernels and removing the need to create intermediate tensors. The disadvantage is that it’s impossible to modify the graph at runtime and execute something other than the specified instructions during the forward pass. In other words, it forces users to configure a graph instead of coding their model. This is why PyTorch has been the most popular framework until now: a programmable Deep Learning framework will always have a better developer experience than a configurable graph.
However, the drawback of an Eager-first framework is its sub-optimal performance. Is it something that can be fixed? Well, yes, and this is what this blog is all about!
Memory Movement
First, we have to understand what the actual problem is! What do I mean by memory movement, and why is it important? Let’s start with a simple example: multiple element-wise operations performed on tensors. Say you do the following:
let x = Tensor::random([32, 32]);
let y = Tensor::random([32, 32]);
let tmp1 = x.clone() + y.clone();
let tmp2 = tmp1 * x;
let out = tmp2 * y;
println!("{out}");
In this small code snippet, we created two tensors x and y, added them together to create a new temporary tensor tmp1, then multiplied that temporary tensor by x and then by y. In a fully eager framework, we would normally launch one kernel for each operation, allocating data for the new tensor created by each operation. In this simple case, we would allocate x, y, tmp1, tmp2, and out. In addition, for each of those new temporary tensors, a GPU kernel needs to write to global GPU memory. Then, for each operation, a kernel must read both the left-hand side and right-hand side tensors from global memory. To summarize:
| Tensor Allocations (Bytes) | Global Memory Reads (Bytes) | Global Memory Writes (Bytes) |
|---|---|---|
| 20,480 | 24,576 | 20,480 |
Now, let’s say we want to optimize this computation. We could write our own GPU kernel to fuse the three operations together. In that case, we would reduce the number of reads to only two (x and y) and the number of writes to three (x, y, and out), since all temporary tensors are removed. So we would have:
| Tensor Allocations (Bytes) | Global Memory Reads (Bytes) | Global Memory Writes (Bytes) |
|---|---|---|
| 12,288 | 8,192 | 12,288 |
This is what we mean when we say a framework should minimize memory movement: fewer reads, fewer writes, and fewer allocations! When a tensor contains millions of elements, most of the time is spent on memory movements instead of actually performing computations, so it’s crucial for a framework to optimize those scenarios.
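To make the idea concrete, here is a minimal CPU-side sketch of what the fused version computes. This is only an illustration of the memory access pattern, not Burn’s generated kernel code: each input element is read once and only the final result is written back, with no buffers for tmp1 and tmp2.

// Illustrative only: the three element-wise operations ((x + y) * x) * y
// fused into a single pass over the data.
fn fused_elemwise(x: &[f32], y: &[f32], out: &mut [f32]) {
    for i in 0..out.len() {
        let a = x[i]; // single read of x[i]
        let b = y[i]; // single read of y[i]
        out[i] = (a + b) * a * b; // single write of the result
    }
}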
Tensor Operation Streams
Burn optimizes memory movement in a unique way that allows for extremely dynamic models while keeping performance optimal. Even in edge-case scenarios like calling a database inside a model, Burn will still optimize the tensor operations around it. But how do we do it?
Burn is built around what we call “Tensor Operation Streams”, where a stream is an endless sequence of tensor operations. This is really different from standard frameworks, where the computation is stored in a finite graph. In order to optimize a stream, we have to gather all the information related to each tensor’s dynamic lifetime; in other words, we need to know for sure when a tensor is only ever read, and when it can be modified because it won’t be reused later. This is where Rust is crucial: memory release isn’t tied to waiting for the end of a scope but to the ownership system, which, combined with reference counting, makes it really easy to capture that dynamic lifetime.
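As a rough illustration (the type and method names below are made up, not Burn’s internal API), reference counting is what exposes that lifetime at runtime: if a handle is uniquely owned when an operation consumes it, its memory can safely be reused.

use std::sync::Arc;

// Hypothetical handle type for illustration; Burn's real handles are richer.
struct TensorHandle {
    buffer: Arc<Vec<f32>>,
}

impl TensorHandle {
    // If no other handle points to the same buffer, the operation that
    // consumes this tensor may reuse or mutate its memory in place.
    fn is_exclusively_owned(&self) -> bool {
        Arc::strong_count(&self.buffer) == 1
    }
}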
The next step is to capture the stream in an internal representation, detect patterns that we can optimize, and forward that representation to a just-in-time compiler to create a highly optimized kernel.
You might wonder if this introduces any framework overhead, and you are right to ask! Yes, there is some overhead associated with capturing operation streams, ranking optimizations, compiling them, and autotuning[4] the generated code. However, we have a clever multi-level caching system that exploits the fact that tensor streams have a lot of regularity: an optimization found for a stream of operations can be reused whenever the same stream shows up again, making the overhead extremely light. In the following sections, we will dig deeper into the solution and see how it all works!
Burn Fusion
Although Burn handles multiple streams concurrently based on device and thread IDs, for simplicity we will assume only one stream. Each stream is composed of an operation queue as well as a stream segment processor. A segment is a finite list of operations that can be executed given an execution plan ID and a store, which we will explain in more detail later.
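The sketch below illustrates how these pieces could fit together. The type names are illustrative, not Burn’s actual ones: streams are keyed by device and thread, and each stream pairs an operation queue with a segment processor that has access to a store of execution plans.

use std::collections::HashMap;

// Hypothetical identifiers; Burn's real IDs are richer than plain integers.
type DeviceId = u32;
type ThreadId = u64;
type ExecutionPlanId = usize;

// One stream per (device, thread) pair.
struct Streams {
    streams: HashMap<(DeviceId, ThreadId), Stream>,
}

struct Stream {
    queue: OperationQueue,
    processor: SegmentProcessor,
}

// Placeholders standing in for the components described in the text.
struct OperationQueue;
struct SegmentProcessor {
    // Previously discovered execution plans, looked up by their ID.
    store: HashMap<ExecutionPlanId, ExecutionPlan>,
}
struct ExecutionPlan;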
Operation Queue
The operation queue is responsible for keeping an up-to-date, growing list of tensor operations. However, since a stream is an endless sequence of operations, the tensor IDs and shapes might differ even when the operations are semantically the same. This is why the queue actually keeps two versions of the stream: one with relative tensor IDs and shapes, and another with the real tensor IDs and shapes. The fusion backend is responsible for creating executable operations that work with any shapes, but it can still use the runtime shape information to create specialized kernels; it’s just a different level of optimization. It is important to have a shape-agnostic stream representation, since shapes don’t impact which operations can be fused.
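As a sketch (expanding the placeholder queue from the previous snippet, with made-up field names), the queue could record each operation twice, assigning relative IDs in order of first appearance so that two semantically identical segments produce the same relative form even when their concrete tensors differ:

use std::collections::HashMap;

// Hypothetical types for illustration only.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct RelativeTensorId(usize);
#[derive(Clone, Copy)]
struct RealTensorId(u64);

enum OpKind {
    Add,
    Mul,
    // ...
}

// The shape-agnostic version used to match against known execution plans.
struct RelativeOp {
    kind: OpKind,
    inputs: Vec<RelativeTensorId>,
    output: RelativeTensorId,
}

// The concrete version needed to actually run the operations.
struct RealOp {
    kind: OpKind,
    inputs: Vec<RealTensorId>,
    output: RealTensorId,
    shapes: Vec<Vec<usize>>,
}

struct OperationQueue {
    relative_ops: Vec<RelativeOp>,
    real_ops: Vec<RealOp>,
    // Real tensor IDs mapped to relative ones in order of first appearance,
    // so the relative stream is deterministic.
    id_map: HashMap<u64, RelativeTensorId>,
}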
Segment Processor
The segment processor is where the most interesting stuff happens. Three actions are possible when processing a segment: exploring new potential optimizations, deferring any calculation, or executing an existing optimization. When a model is hot, meaning that all potential optimizations have already been found, the policy will only generate actions to defer the calculation and run an execution plan. This minimizes the overhead significantly, as the only additional tasks involve keeping the operation queue up to date, along with the minor calculations performed by the policy. On the other hand, when the model isn’t hot, we will have to explore new optimizations.
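In code, the three actions map naturally to a small enum; the names below are illustrative of the decision the policy returns for each new operation, not Burn’s exact types.

type ExecutionPlanId = usize;

// Hypothetical action type mirroring the three cases described above.
enum Action {
    // Not enough information yet: keep queuing operations.
    Defer,
    // No stored plan matches, so new optimizations must be explored.
    Explore,
    // A previously discovered plan matches the current segment.
    Execute(ExecutionPlanId),
}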
The exploration part is quite simple: the goal is to create an execution plan that we can store and reuse later to avoid doing the exploration again. The execution plan is found by using optimization builders that receive one operation after another until they can’t optimize the segment further. Since we have multiple optimization builders, we simply explore until all of them are finished and choose the best one following a simple scoring algorithm.
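In pseudo-Rust (the trait and its methods are assumptions for illustration, not Burn’s public API), exploration amounts to feeding each queued operation to every builder until all of them stop, then keeping the highest-scoring result:

struct Operation;
struct ExecutionPlan;

// Hypothetical builder interface.
trait OptimizationBuilder {
    // Returns false once this builder cannot fuse any further operation.
    fn register(&mut self, op: &Operation) -> bool;
    // Simple score, e.g. the number of operations fused.
    fn score(&self) -> u64;
    fn build(&self) -> ExecutionPlan;
}

fn explore(ops: &[Operation], builders: &mut [Box<dyn OptimizationBuilder>]) -> ExecutionPlan {
    let mut open: Vec<usize> = (0..builders.len()).collect();
    for op in ops {
        // Feed the next operation to every builder that is still optimizing.
        open.retain(|&i| builders[i].register(op));
        if open.is_empty() {
            break; // Every builder has finished exploring this segment.
        }
    }
    // Keep the best optimization according to the scoring algorithm.
    builders
        .iter()
        .max_by_key(|builder| builder.score())
        .map(|builder| builder.build())
        .expect("at least one optimization builder")
}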
Now, the beautiful part about the execution plans is that they are stored in memory for the policy, but they can also be serialized to disk to reduce cold starts. Yes, we thought about this too, knowing that cold starts are always a drawback of Just-In-Time compilers. But you might still wonder: how efficient is our policy algorithm at finding the right execution plan based on a stream segment?
The most straightforward approach would be to keep a key-value map where the key is the hash of a segment and the value is the execution plan; however, that would require us to keep a version of the hash for each new operation added, which scales really badly with long segments. Instead, we reformulated the problem as invalidating multiple potential execution plans based on the next operation. When starting a new segment, the policy searches the store for all execution plans whose segment starts with the current operation. Then, as new operations are added to the segment, we invalidate the candidates that diverge. Instead of scaling with the number of operations per execution plan, the cost scales with the number of candidate execution plans for the starting point of a segment, which should remain small at all times.
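A hedged sketch of that candidate-invalidation idea (the types and field names are illustrative): start from every stored plan whose first operation matches, then discard candidates as soon as an incoming operation diverges from their recorded segment.

use std::collections::HashMap;

type ExecutionPlanId = usize;

// Shape-agnostic operation from the relative segment (details omitted).
#[derive(Clone, PartialEq)]
struct RelativeOp;

struct StoredPlan {
    segment: Vec<RelativeOp>,
}

struct Policy {
    store: HashMap<ExecutionPlanId, StoredPlan>,
    // Plans still compatible with the segment currently being built.
    candidates: Vec<ExecutionPlanId>,
    position: usize,
}

impl Policy {
    // Called when a new segment starts with its first operation.
    fn start(&mut self, first: &RelativeOp) {
        self.position = 0;
        self.candidates = self
            .store
            .iter()
            .filter(|(_, plan)| plan.segment.first() == Some(first))
            .map(|(id, _)| *id)
            .collect();
    }

    // Called for every subsequent operation: keep only the plans whose
    // recorded segment still matches at the current position.
    fn next(&mut self, op: &RelativeOp) {
        self.position += 1;
        let (store, position) = (&self.store, self.position);
        self.candidates
            .retain(|id| store[id].segment.get(position) == Some(op));
    }
}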
In the figure above, we see an overview of the most important components of the tensor stream fusion system. We can observe how the segment processor discovers new optimizations while also applying previously discovered ones through the policy and the explorer. The resulting execution plans are stored and then reused by the operation queue, modifying the execution context that stores the tensor handles. Note that the relative segment of the operation queue will always match the relative segment of an execution plan. Thus, we can summarize the policy’s job as finding, as efficiently as possible, the execution plan whose relative segment matches that of the operation queue.
Benchmark
As always, the real impact of optimizations should be measured empirically. At the time of writing, Burn hasn’t implemented all possible optimizations to reduce memory movement. We focused our efforts on creating a robust implementation for element-wise operations. This includes most math operators, but excludes matrix multiplications, reductions, convolutions, and pooling. We plan to work on those after we stabilize our multi-target Just-In-Time compiler to bring our optimizations to graphics APIs other than WebGPU.
Now let's create a benchmark that uses basic math operators to implement the GELU activation function[5]. To test our fused tensor stream, we will compare GELU implementations using our WebGPU backend[6], both without fusion and with fusion, and the LibTorch CUDA backend[7].
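For reference, the function being implemented is GELU(x) = x / 2 * (1 + erf(x / sqrt(2))), which is exactly what the gelu_custom function below expresses with tensor operations.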
We will compare three GELU implementations: the reference implementation, one using basic math operators together with the backend’s error function, and one that also implements the error function[8] itself with the high-level tensor API. Below is the code used for the experiment, followed by the results.
use burn::tensor::{backend::Backend, Tensor};
use core::f64::consts::SQRT_2;

// Selects which GELU implementation to benchmark.
enum GeluKind {
    Reference,
    WithReferenceErf,
    WithCustomErf,
}
// Dispatches to one of the three GELU variants under test.
fn gelu<B: Backend, const D: usize>(tensor: Tensor<B, D>, kind: GeluKind) {
match kind {
GeluKind::Reference => burn::tensor::activation::gelu(tensor),
GeluKind::WithReferenceErf => gelu_custom(tensor, Tensor::erf),
GeluKind::WithCustomErf => gelu_custom(tensor, erf_custom),
};
}
// GELU written with basic tensor math: x / 2 * (1 + erf(x / sqrt(2))).
fn gelu_custom<B: Backend, const D: usize, Erf>(x: Tensor<B, D>, erf: Erf) -> Tensor<B, D>
where
Erf: Fn(Tensor<B, D>) -> Tensor<B, D>,
{
let x = x.clone() * (erf(x / SQRT_2) + 1);
x / 2
}
// Error function built from the positive-half approximation below,
// mirrored onto negative inputs using erf(-x) = -erf(x).
fn erf_custom<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
let x1 = -erf_positive(-x.clone());
let x2 = erf_positive(x.clone());
let mask = x.greater_elem(0);
x1.mask_where(mask, x2)
}
// Polynomial approximation of erf for non-negative inputs
// (Abramowitz & Stegun, formula 7.1.26).
fn erf_positive<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
let p = 0.3275911;
let a1 = 0.254829592;
let a2 = -0.284496736;
let a3 = 1.421413741;
let a4 = -1.453152027;
let a5 = 1.061405429;
let x1 = x.clone().abs() * p + 1;
let t = x1.recip();
let tmp = (((((t.clone() * a5) + a4) * t.clone()) + a3) * t.clone() + a2) * t.clone() + a1;
-(tmp * t * (-x.clone() * x).exp()) + 1.0
}
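As a usage sketch (mirroring the simplified Tensor::random call from the first snippet rather than any particular backend setup, with an arbitrary shape chosen for illustration), the benchmark boils down to running each variant on the same input:

// Hypothetical driver; the real benchmark also handles backend selection,
// warmup, and synchronization before timing.
fn run<B: Backend>() {
    // Arbitrary shape for illustration, using the simplified API style
    // from the first snippet.
    let x = Tensor::<B, 2>::random([2048, 2048]);
    gelu(x.clone(), GeluKind::Reference);
    gelu(x.clone(), GeluKind::WithReferenceErf);
    gelu(x, GeluKind::WithCustomErf);
}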
| Backend | Reference (ms) | Custom with Reference Erf (ms) | Custom with Custom Erf (ms) |
|---|---|---|---|
| WGPU Vulkan | 6.96 | 6.91 | 67.23 |
| WGPU Vulkan with Fusion | 0.838 | 0.819 | 0.858 |
| LibTorch CUDA | 0.653 | 3.86 | 37.89 |
As we can see in the benchmark, the fusion backend improved the execution of the fully custom GELU implementation from 67.23 ms down to 0.858 ms, which is almost as fast as the highly optimized reference implementation used by LibTorch and over 78 times faster than the fully eager execution using the same graphics API (Vulkan). Of course, this is a synthetic benchmark, and nobody in their right mind would use a GELU implementation with a custom error function approximation, but it goes to show that you wouldn’t need to write your own custom GPU kernel if you wanted to research a new activation function. It also shows that our tensor stream approach effectively creates custom GPU kernels without adding enough overhead to affect execution speed.
Conclusion
In this blog post, we summarized our tensor operation stream strategy for creating highly optimized kernels on top of a fully eager API. However, we didn’t go through how the compiler and runtime are built in Burn. This will be covered in a follow-up blog post detailing how we leverage runtime and compile-time information in our just-in-time compiler to create highly optimized kernels specialized for the hardware in use. We believe that creating a programmable and flexible framework is crucial for research and applied AI. This is important not only to unlock new model architectures, but also to achieve optimal performance. Some optimizations are very dynamic in nature, like the one explored in "BranchyNet: Fast Inference via Early Exiting from Deep Neural Networks"[9] and others[10,11,12] that give control back to the model to modify its computation graph dynamically, reducing the total number of computations required. That kind of optimization is pretty hard to achieve, if possible at all, with a static-graph-focused framework, while being totally intuitive with an eager API. The pursuit of dynamic optimization techniques underscores the importance of flexible frameworks in advancing the capabilities of AI models, a vision we are firmly dedicated to achieving with Burn.