Adding a New Operation to burn
Let's discuss how one might go about adding new operators to Burn, using the example of the pow operator added in this PR.
Adding the Op to burn-tensor
burn-tensor is the crate that defines all tensor operations that need to be implemented by the various backends. The core of this lies in crates/burn-tensor/src/tensor/api/numeric.rs, which is home to the numeric trait and its implementation for the different tensor types. The numeric trait is the home of all tensor operations that are numeric in nature and that are shared by Int and Float Tensor types. More information on the relationship between Tensor modules can be found under the section for Tensor Architecture.

Here is where pow was added to crates/burn-tensor/src/tensor/api/numeric.rs:
- for the Tensor<Backend, Dimension, Kind> struct
- for the numeric trait
- for the implementation of numeric for float and int
Tensor is a struct that has a single member, primitive (defined here), whose concrete type is determined by its Kind: one of Bool, Float, or Int. These call the ops for that data type defined in the Backend supertrait[^1]. This is the trait that is then implemented by the different burn backends (such as burn-ndarray and burn-wgpu), which must implement the functions if no default is provided.
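To make that dispatch chain concrete, here is a heavily simplified, self-contained model of the pattern. The Backend, Numeric, Tensor and Toy types below are stand-ins written for this guide, not Burn's actual definitions:

```rust
// Toy model: the user-facing Tensor method forwards to a numeric trait
// implemented per kind, which in turn calls the backend op.
trait Backend {
    type FloatPrimitive;
    fn float_powf(lhs: Self::FloatPrimitive, rhs: Self::FloatPrimitive) -> Self::FloatPrimitive;
}

// Stand-in for the numeric trait: one impl per tensor kind.
trait Numeric<B: Backend> {
    type Primitive;
    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
}

struct Float;

impl<B: Backend> Numeric<B> for Float {
    type Primitive = B::FloatPrimitive;
    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        B::float_powf(lhs, rhs)
    }
}

// Stand-in for the Tensor struct: a single `primitive` member.
struct Tensor<B: Backend, K: Numeric<B>> {
    primitive: K::Primitive,
}

impl<B: Backend, K: Numeric<B>> Tensor<B, K> {
    fn powf(self, rhs: Self) -> Self {
        Self { primitive: K::powf(self.primitive, rhs.primitive) }
    }
}

// A toy backend whose float primitive is just an f32.
struct Toy;

impl Backend for Toy {
    type FloatPrimitive = f32;
    fn float_powf(lhs: f32, rhs: f32) -> f32 {
        lhs.powf(rhs)
    }
}

fn main() {
    let lhs = Tensor::<Toy, Float> { primitive: 2.0 };
    let rhs = Tensor::<Toy, Float> { primitive: 3.0 };
    println!("{}", lhs.powf(rhs).primitive); // prints 8
}
```

The real Tensor also carries a rank as a const generic and supports more kinds, but the shape of the forwarding is the same: the user-facing method delegates to the kind's trait impl, which delegates to the backend op.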
In this case, we don't need to worry about Bool Tensors. Ops for Float are implemented under crates/burn-tensor/src/tensor/ops/tensor.rs, and for Int under crates/burn-tensor/src/tensor/ops/int_tensor.rs.
The current convention is that ops of each type, unless unique to that type, are prefixed with the type. So powf and the like are defined as int_powf for IntTensorOps and float_powf for FloatTensorOps. If an op is unique to a type, then it should be implemented under burn-tensor/src/api/{type}.rs. For example, here is an implementation for sin under crates/burn-tensor/src/api/float.rs, which obviously doesn't make sense for Int or Bool tensors.
The Int Tensor function uses the ones defined for Float with 2 extra casts (LHS to a Float tensor, output back to an Int). Given that, the rest of this guide will only look at the float implementations.
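As a rough sketch of both the naming convention and the int-via-float approach described above, here is a simplified stand-in written for this guide; the real burn-tensor ops traits operate on backend primitive types and hang off the Backend trait rather than looking like this:

```rust
// Simplified stand-in traits illustrating the float_/int_ prefix convention.
trait FloatTensorOps {
    fn float_powf(lhs: Vec<f32>, rhs: Vec<f32>) -> Vec<f32>;
}

trait IntTensorOps: FloatTensorOps {
    // Cast the int LHS to float, reuse `float_powf`, then cast the output
    // back to int: the "2 extra casts" mentioned above.
    fn int_powf(lhs: Vec<i32>, rhs: Vec<f32>) -> Vec<i32> {
        let lhs_float: Vec<f32> = lhs.into_iter().map(|x| x as f32).collect();
        Self::float_powf(lhs_float, rhs)
            .into_iter()
            .map(|x| x as i32)
            .collect()
    }
}

struct ToyBackend;

impl FloatTensorOps for ToyBackend {
    fn float_powf(lhs: Vec<f32>, rhs: Vec<f32>) -> Vec<f32> {
        lhs.into_iter().zip(rhs).map(|(x, y)| x.powf(y)).collect()
    }
}

impl IntTensorOps for ToyBackend {}

fn main() {
    // [2^2, 3^2] -> [4, 9]
    println!("{:?}", ToyBackend::int_powf(vec![2, 3], vec![2.0, 2.0]));
}
```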
Adding Tests
Additional tests should be added to burn-tensor under crates/burn-tensor/src/tests/ops/{op_name}.rs, inserting the module name into crates/burn-tensor/src/tests/ops/mod.rs. Then add it to the testgen_all macro under crates/burn-tensor/src/tests/mod.rs. This macro is called from the lib.rs file in each backend, which autogenerates the tests for that specific backend. It isn't necessary to define tests in the backends directly, save for those that require specific testing such as burn-autodiff.
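For orientation, such a test file typically follows the testgen pattern along these lines; treat the exact helper names (TestBackend, TensorData, assert_approx_eq) as version-dependent and check an existing test in the repo before copying:

```rust
// Hedged sketch of a burn-tensor op test; exact helper APIs may differ
// between Burn versions.
#[burn_tensor_testgen::testgen(powf)]
mod tests {
    use super::*;
    use burn_tensor::{Tensor, TensorData};

    #[test]
    fn should_support_powf_ops() {
        let device = Default::default();
        let lhs = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let rhs = Tensor::<TestBackend, 2>::from_data([[2.0, 2.0], [2.0, 2.0]], &device);

        let output = lhs.powf(rhs);
        let expected = TensorData::from([[1.0, 4.0], [9.0, 16.0]]);

        output.into_data().assert_approx_eq(&expected, 3);
    }
}
```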
Adding the Op to burn-autodiff
Since this is probably the hardest and the least straightforward, we'll cover this backend separately. burn-autodiff enables other backends to use autodifferentiation[^2]. Ops for float types are implemented in crates/burn-autodiff/src/ops/tensor.rs and need to:
- Define a unit struct[^3] that implements a backward (pass) function.
- Within the backward function, since this is an element-wise binary operation, it implements the binary function (from backward.rs under the same directory); the last 2 arguments are two closures that define the left and right partial derivatives.
- Then define what happens when a specific operation is tracked or untracked, where untracked just calls the function in the normal way and tracked sets the execution to the backward function defined above.
- When tracked, operations are part of the autodiff graph and must save the needed information to efficiently perform their backward pass later. If the information is light (such as a shape), it should be directly saved in the state. If the operation's inputs are needed to compute the backward pass, it should be checkpointed rather than saved. This allows the input to be provided lazily at the backward pass depending on the checkpointing strategy.
- An operation must also be identified as compute-bound (.compute_bound()) or memory-bound (.memory_bound()) for gradient checkpointing. Compute-bound operations are heavy to compute (for instance matmul or convolution), which means that even with checkpointing they will save their output for the backward pass and not recompute it. Memory-bound operations are more trivial (like powf, which only performs one small operation per tensor entry), so it can be beneficial to recompute them during the backward pass instead of saving their whole forward output to memory. Operations registered as memory-bound need to know their parents (.parents() method) and how to recompute their forward pass during the backward pass (with a struct that implements RetroForward), using their parents' outputs.
The above steps are mostly boilerplate, so you can often just copy the contents of another similar op, change the name of the structs, and ensure that both sides have the data they need (if they need a copy of the opposite tensor, clone its contents).
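To give a feel for the shape of that boilerplate without reproducing burn-autodiff's real types, here is a self-contained caricature: a zero-sized op struct whose backward step hands the left and right partial-derivative closures to a binary-style helper. PowF and binary here are stand-ins for this guide, not the actual burn-autodiff API, and the two derivative formulas are worked out in the next section:

```rust
struct PowF;

/// Toy stand-in for the `binary` helper: it applies the left and right
/// partial-derivative closures to the gradient flowing back from the output.
fn binary<L, R>(grad_output: f64, left: L, right: R) -> (f64, f64)
where
    L: Fn(f64) -> f64,
    R: Fn(f64) -> f64,
{
    (left(grad_output), right(grad_output))
}

impl PowF {
    fn backward(lhs: f64, rhs: f64, grad_output: f64) -> (f64, f64) {
        binary(
            grad_output,
            // d/d(lhs) of lhs^rhs = rhs * lhs^(rhs - 1)
            |grad| grad * rhs * lhs.powf(rhs - 1.0),
            // d/d(rhs) of lhs^rhs = lhs^rhs * ln(lhs)
            |grad| grad * lhs.powf(rhs) * lhs.ln(),
        )
    }
}

fn main() {
    let (grad_lhs, grad_rhs) = PowF::backward(2.0, 3.0, 1.0);
    println!("dlhs = {grad_lhs}, drhs = {grad_rhs}"); // 12 and 8 * ln(2)
}
```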
Computing derivatives
For those that need it, here is a quick refresher on the necessary calculus. If you are familiar with how to calculate partial derivatives, you can skip this section.
Since pow is a binary operation, the left and right functions are the partial derivatives with respect to the left and right sided tensors.

Let's define the operator as a function \(f(x,y)=x^{y}\), where \(x\) is the left hand tensor and \(y\) is the right hand tensor. The two closures define the partial derivatives of \(f\) with respect to \(x\) and \(y\), treating the other variable as a constant.

$$\frac{\partial}{\partial x}\left(x^{y}\right) = y \cdot x^{y-1}$$

is the left handed closure, and

$$\frac{\partial}{\partial y}\left(x^{y}\right) = x^{y} \cdot \ln(x)$$

is the right. If you aren't sure how to calculate these by hand, it is recommended to use Symbolab: plug in your operator in terms of \(x\) and \(y\), then swap out the variable (\(x\) or \(y\)) in the partial derivative to get the other side.
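If you want to double-check formulas like these, one quick sanity test is to compare them against central finite differences in plain Rust (everything below is just scratch code for this guide):

```rust
// Sanity check of the two partial derivatives used in the backward closures.
fn main() {
    let f = |x: f64, y: f64| x.powf(y);
    // d/dx x^y = y * x^(y - 1)   (left closure)
    let df_dx = |x: f64, y: f64| y * x.powf(y - 1.0);
    // d/dy x^y = x^y * ln(x)     (right closure)
    let df_dy = |x: f64, y: f64| x.powf(y) * x.ln();

    // Compare against central finite differences at an arbitrary point.
    let (x, y, h) = (1.5_f64, 2.5_f64, 1e-6);
    let fd_dx = (f(x + h, y) - f(x - h, y)) / (2.0 * h);
    let fd_dy = (f(x, y + h) - f(x, y - h)) / (2.0 * h);

    assert!((df_dx(x, y) - fd_dx).abs() < 1e-4);
    assert!((df_dy(x, y) - fd_dy).abs() < 1e-4);
    println!("analytic and numeric derivatives agree");
}
```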
Testing autodiff
For testing the autodiff operations, please refer to this section.
Adding the Op to other backends
Most of these are fairly straightforward implementations. For reference, here are pow's float implementations for the torch, ndarray and candle backends:
- Torch implementation in crates/burn-tch/src/ops/tensor.rs and the Op used in crates/burn-tch/src/ops/base.rs
- NdArray in crates/burn-ndarray/src/ops/tensor.rs
- Candle in crates/burn-candle/src/ops/tensor.rs
This is where the actual calculation currently happens. Playing a guessing game with method names and seeing what completions are suggested will take you far. If you are having trouble figuring out how to do it from the docs for that backend, try searching GitHub for relevant function calls.
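As an illustration of how small these implementations tend to be, the heart of a tensor-tensor powf in an ndarray-style backend is just a zipped element-wise loop. The sketch below uses the ndarray crate directly and ignores the broadcasting and primitive-type plumbing the real burn-ndarray op handles:

```rust
// Illustrative element-wise tensor-tensor powf using the ndarray crate.
use ndarray::{Array2, Zip};

fn powf(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
    let mut out = Array2::<f32>::zeros(lhs.raw_dim());
    Zip::from(&mut out)
        .and(lhs)
        .and(rhs)
        .for_each(|o, &x, &y| *o = x.powf(y));
    out
}

fn main() {
    let lhs = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let rhs = Array2::from_elem((2, 2), 2.0);
    println!("{}", powf(&lhs, &rhs)); // [[1, 4], [9, 16]]
}
```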
Adding the Op to fusion, JIT and cubecl backends
Adding an operator to these backends can be fairly straightforward, though due to what these backends are for, it involves a bit more indirection. Fusion and JIT, like autodiff, are not target backends so much as backends that enable certain functionality for other backends, in this case kernel fusion or just-in-time compilation. Adding the operator won't involve doing any calculation; you'll just be describing how the generated code should look. Most of this can be copy/pasted/adjusted from other functions.
Here's how powf was added to burn-fusion:

- Added powf to the float ops under crates/burn-fusion/src/ops/float.rs
- Added powf to the NumericOperationDescription enum under crates/burn-fusion/src/stream/operation.rs
- Added powf to the implementations of the NumericOperationDescription enum under crates/burn-fusion/src/stream/context.rs
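In spirit, these changes amount to adding a variant to an operation-description enum and handling it wherever that enum is matched. Schematically (the types below are invented for this guide, not burn-fusion's actual ones):

```rust
// Invented description enum: each variant records which tensors an op reads
// and writes, rather than performing any computation.
enum NumericOperation {
    Add { lhs: usize, rhs: usize, out: usize },
    // New variant describing a tensor-tensor powf between registered tensors.
    Powf { lhs: usize, rhs: usize, out: usize },
}

fn describe(op: &NumericOperation) -> String {
    match op {
        NumericOperation::Add { lhs, rhs, out } => format!("add {lhs} {rhs} -> {out}"),
        // Every place that matches on the enum needs an arm for the new op.
        NumericOperation::Powf { lhs, rhs, out } => format!("powf {lhs} {rhs} -> {out}"),
    }
}

fn main() {
    let op = NumericOperation::Powf { lhs: 0, rhs: 1, out: 2 };
    println!("{}", describe(&op));
}
```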
The way cubecl handles tensor-scalar operations is by transforming both into a sequence of vectorized scalar operations. Since powf already existed in cubecl, it was pretty easy to reuse the existing implementation for the situation where both sides of the operation were tensors. The cubecl crate is primarily concerned with how the operation is compiled and executed by the GPU. The actual implementation is defined in burn-jit.

Here is where code was added for powf in burn-jit and cubecl:
- to the implementation of FloatTensorOps under crates/burn-jit/src/ops/float_ops.rs
- the function being called was added to crates/burn-jit/src/ops/numeric.rs
- the operator was defined in cubecl-core/src/ir/operation.rs
- how the operation looks to the GPU was added to crates/burn-jit/src/fusion/on_write/ir.rs
- the mappings between the GPU operation and the CPP, WGSL and SPIR-V instructions were added to cubecl-cpp/src/shared/base.rs, cubecl-wgpu/src/compiler/wgsl/compiler.rs and cubecl-spirv/src/instruction.rs
- the instructions themselves were added: for WGSL, to the instruction op enum in cubecl-wgpu/src/compiler/wgsl/instructions.rs along with the actual instruction in WGSL; for CPP, to the enum in cubecl-cpp/src/shared/instruction.rs along with the actual instruction in cubecl-cpp/src/shared/binary.rs
We needed to generate some custom WGSL code for powf, primarily due to issues with proper case handling of the WGSL pow function, like 0 to the 0 power being 1, and any negative number to an even power being positive. We reused as much of the existing logic as possible, and then branched at the last point based on the var type of the rhs. See here.
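Expressed in plain Rust rather than WGSL, purely to illustrate the semantics (the real fix lives in the generated WGSL code), the special-case handling looks roughly like this:

```rust
// Special-case semantics for powf, written in Rust only for readability.
fn powf_with_edge_cases(lhs: f32, rhs: f32) -> f32 {
    if rhs == 0.0 {
        // Anything (including 0) to the power 0 is defined as 1 here.
        1.0
    } else if lhs < 0.0 && rhs == rhs.round() {
        // Negative base with an integer exponent: compute on the absolute
        // value and restore the sign for odd exponents.
        let magnitude = lhs.abs().powf(rhs);
        if (rhs as i64) % 2 == 0 { magnitude } else { -magnitude }
    } else {
        lhs.powf(rhs)
    }
}

fn main() {
    assert_eq!(powf_with_edge_cases(0.0, 0.0), 1.0);
    assert_eq!(powf_with_edge_cases(-2.0, 2.0), 4.0);
    assert_eq!(powf_with_edge_cases(-2.0, 3.0), -8.0);
    println!("edge cases handled");
}
```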
For most operations, you shouldn't need to add to cubecl-wgpu/src/compiler/wgsl/extension.rs unless the operation isn't native to WGSL.

For functions that need a complex kernel without a direct mapping to a base instruction, simply use the cube macro (see the cubecl book).
Adding the Op to burn-import
Generating the ONNX test files or tests is already covered in the ONNX to burn guide; this is more about the specific changes you need to make when adding new operators after you have generated the tests.
Changes will need to be made to both onnx-ir and burn-import. The code within onnx-ir defines how to parse the nodes in an onnx file and produces the intermediate representation. The code within burn-import is divided into two sections: src/onnx and src/burn. The code under the former maps that intermediate representation to one used for code generation, and the latter defines how to generate code for the operator you've implemented earlier in this guide.

So when you are loading a model, the operator is first parsed to an intermediate representation defined by burn-import and then mapped to a Burn operation defined under src/burn/node; the mapping from onnx to burn is aptly defined in src/onnx/to_burn.

Let's review the changes made for powf, starting from src/burn and moving to src/onnx:
- Determine the type of operator and add your operator to the appropriate node (operation) type, in this case BinaryNode under crates/burn-import/src/burn/node/binary.rs, along with its to_str definition
- Add an arm to the match statement inside the into_burn function in crates/burn-import/src/onnx/to_burn.rs for the ONNX NodeType (which corresponds to an op in the ONNX spec), and make an {op}_conversion function that maps the ONNX node to the binary type
- Specify how dimensions for the output should be derived in crates/onnx-ir/src/dim_inference.rs
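Schematically, the conversion step boils down to a match arm plus a small {op}_conversion helper that pulls the inputs and output off the ONNX node. The types below are invented for this guide and are not the actual onnx-ir / burn-import definitions:

```rust
// Invented stand-ins for the parsed ONNX node and the generated binary node.
struct OnnxNode {
    node_type: NodeType,
    inputs: Vec<String>,
    output: String,
}

enum NodeType {
    Pow,
}

enum BinaryNode {
    Powf { lhs: String, rhs: String, out: String },
}

fn into_burn(node: OnnxNode) -> BinaryNode {
    match node.node_type {
        // New arm: the ONNX `Pow` node maps onto the powf binary node.
        NodeType::Pow => powf_conversion(node),
    }
}

// The `{op}_conversion`-style helper extracts inputs/outputs from the node.
fn powf_conversion(node: OnnxNode) -> BinaryNode {
    BinaryNode::Powf {
        lhs: node.inputs[0].clone(),
        rhs: node.inputs[1].clone(),
        out: node.output,
    }
}

fn main() {
    let node = OnnxNode {
        node_type: NodeType::Pow,
        inputs: vec!["x".into(), "y".into()],
        output: "z".into(),
    };
    let BinaryNode::Powf { lhs, rhs, out } = into_burn(node);
    println!("{out} = {lhs}.powf({rhs})");
}
```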
And you're done! Congrats, you just fully added a new operation to burn, and we are all one step closer to the answer to Are we learning yet? being "Yes, and it's freaking fast!". Buy yourself a coffee.
[^1]: For more on supertraits, see the advanced trait section of the Rust Book.
[^2]: Wiki link for automatic differentiation.
[^3]: For more information on unit structs, see the defining and instantiating structs section of the Rust Book.