Adding a New Operation to burn

Let's discuss how one might go about adding new operators to Burn, using the example of the pow operator added in this PR.

Adding the Op to burn-tensor

burn-tensor is the crate that defines all tensor operations that need to be implemented by the various backends. The core of this lies in crates/burn-tensor/src/tensor/api/numeric.rs, which is home to the numeric trait and its implementation for the different tensor types. The numeric trait contains all tensor operations that are numeric in nature and are shared by the Int and Float tensor types. More information on the relationship between tensor modules can be found under the section for Tensor Architecture.

Here is where pow was added to crates/burn-tensor/src/tensor/api/numeric.rs:

  1. for the Tensor<Backend, Dimension, Kind> struct
  2. for the numeric trait
  3. for the implementation of numeric for float and int

Tensor is a struct with a single member, primitive (defined here), whose type is determined by its Kind: one of Bool, Float, or Int (those linked in item 3). These call the ops for that data type defined in the Backend supertrait[1]. This is the trait that is then implemented by the different burn-* backends (such as burn-ndarray and burn-wgpu), which must implement every function for which no default is provided.
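To make the layering concrete, here is a minimal, self-contained Rust sketch. It assumes heavily simplified, hypothetical versions of Backend, TensorKind, Numeric and Tensor (no dimension parameter, plain Vec primitives, and a toy backend called ToyBackend); Burn's real traits have many more ops and different signatures. It shows a Tensor holding a single primitive whose type comes from its Kind, with each Kind forwarding to the type-prefixed op on the backend.

```rust
// Hypothetical, heavily simplified stand-ins for Burn's Backend trait,
// tensor kinds and Numeric trait. The dimension parameter is omitted.

/// Stand-in for the Backend supertrait: one op per data type, prefixed by type.
pub trait Backend {
    type FloatPrimitive;
    type IntPrimitive;
    fn float_powf(lhs: Self::FloatPrimitive, rhs: Self::FloatPrimitive) -> Self::FloatPrimitive;
    fn int_powf(lhs: Self::IntPrimitive, rhs: Self::FloatPrimitive) -> Self::IntPrimitive;
}

/// The Kind decides which backend primitive a Tensor wraps.
pub trait TensorKind<B: Backend> {
    type Primitive;
}
pub struct Float;
pub struct Int;
impl<B: Backend> TensorKind<B> for Float {
    type Primitive = B::FloatPrimitive;
}
impl<B: Backend> TensorKind<B> for Int {
    type Primitive = B::IntPrimitive;
}

/// Ops shared by Int and Float tensors; each impl forwards to the backend op.
pub trait Numeric<B: Backend>: TensorKind<B> {
    fn powf(lhs: Self::Primitive, rhs: B::FloatPrimitive) -> Self::Primitive;
}
impl<B: Backend> Numeric<B> for Float {
    fn powf(lhs: B::FloatPrimitive, rhs: B::FloatPrimitive) -> B::FloatPrimitive {
        B::float_powf(lhs, rhs)
    }
}
impl<B: Backend> Numeric<B> for Int {
    fn powf(lhs: B::IntPrimitive, rhs: B::FloatPrimitive) -> B::IntPrimitive {
        B::int_powf(lhs, rhs)
    }
}

/// The user-facing struct: a single `primitive` member whose type depends on K.
pub struct Tensor<B: Backend, K: TensorKind<B> = Float> {
    primitive: K::Primitive,
}

impl<B: Backend, K: Numeric<B>> Tensor<B, K> {
    pub fn new(primitive: K::Primitive) -> Self {
        Self { primitive }
    }
    pub fn into_primitive(self) -> K::Primitive {
        self.primitive
    }
    /// Element-wise power with a float exponent tensor.
    pub fn powf(self, exponent: Tensor<B, Float>) -> Self {
        Self::new(K::powf(self.primitive, exponent.primitive))
    }
}

/// A toy backend whose primitives are plain vectors.
pub struct ToyBackend;
impl Backend for ToyBackend {
    type FloatPrimitive = Vec<f32>;
    type IntPrimitive = Vec<i64>;
    fn float_powf(lhs: Vec<f32>, rhs: Vec<f32>) -> Vec<f32> {
        lhs.into_iter().zip(rhs).map(|(x, y)| x.powf(y)).collect()
    }
    fn int_powf(lhs: Vec<i64>, rhs: Vec<f32>) -> Vec<i64> {
        // Compute in floating point, then round back to an integer.
        lhs.into_iter()
            .zip(rhs)
            .map(|(x, y)| (x as f64).powf(y as f64).round() as i64)
            .collect()
    }
}
```

The user only ever calls `Tensor::powf`; which backend function runs is decided by the Kind and the backend type parameters.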

In this case, we don't need to worry about Bool tensors. Ops for Float are implemented under crates/burn-tensor/src/tensor/ops/tensor.rs, and for Int under crates/burn-tensor/src/tensor/ops/int_tensor.rs. The current convention is that ops of each type, if not unique to that type, are prefixed with the type. So powf and similar ops are defined as int_powf for IntTensorOps and float_powf for FloatTensorOps. If an op is unique to a type, then it should be implemented under burn-tensor/src/tensor/api/{type}.rs. For example, here is an implementation for sin under crates/burn-tensor/src/tensor/api/float.rs, since sin doesn't make sense for Int or Bool tensors.

The Int tensor function reuses the Float implementation with two extra casts (the LHS to a Float tensor, and the output back to Int). Given that, the rest of this guide only looks at the float implementations.
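The cast-and-reuse pattern can be sketched as a default trait method. This is a hypothetical simplification: the traits FloatOps and IntOps, the helpers float_into_int and int_into_float, and VecBackend are illustrative names, and rounding when converting back to integers is a choice of this sketch rather than something the text specifies.

```rust
/// Hypothetical float-side ops, following the `float_` prefix convention.
pub trait FloatOps {
    type FloatTensor;
    type IntTensor;
    fn float_powf(lhs: Self::FloatTensor, rhs: Self::FloatTensor) -> Self::FloatTensor;
    fn float_into_int(tensor: Self::FloatTensor) -> Self::IntTensor;
}

/// Hypothetical int-side ops, following the `int_` prefix convention.
pub trait IntOps: FloatOps {
    fn int_into_float(tensor: Self::IntTensor) -> Self::FloatTensor;

    /// Default implementation: cast the LHS to float, reuse `float_powf`,
    /// then cast the output back to int. Backends only need `float_powf`.
    fn int_powf(lhs: Self::IntTensor, rhs: Self::FloatTensor) -> Self::IntTensor {
        Self::float_into_int(Self::float_powf(Self::int_into_float(lhs), rhs))
    }
}

/// A toy backend storing tensors as flat vectors.
pub struct VecBackend;

impl FloatOps for VecBackend {
    type FloatTensor = Vec<f32>;
    type IntTensor = Vec<i32>;

    fn float_powf(lhs: Vec<f32>, rhs: Vec<f32>) -> Vec<f32> {
        assert_eq!(lhs.len(), rhs.len(), "toy backend: shapes must match");
        lhs.into_iter().zip(rhs).map(|(x, y)| x.powf(y)).collect()
    }

    fn float_into_int(tensor: Vec<f32>) -> Vec<i32> {
        // Round so that e.g. 7.9999995 becomes 8 rather than truncating to 7.
        tensor.into_iter().map(|x| x.round() as i32).collect()
    }
}

impl IntOps for VecBackend {
    fn int_into_float(tensor: Vec<i32>) -> Vec<f32> {
        tensor.into_iter().map(|x| x as f32).collect()
    }
    // `int_powf` is inherited from the default above.
}
```

Because the default lives on the trait, adding float_powf to a backend gives it int_powf for free.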

Adding Tests

Additional tests should be added to burn-tensor under crates/burn-tensor/src/tests/ops/{op_name}.rs, with the module name inserted into crates/burn-tensor/src/tests/ops/mod.rs. Then add it to the testgen_all macro in crates/burn-tensor/src/tests/mod.rs. This macro is called from the lib.rs file of each backend and autogenerates the tests for that specific backend. It isn't necessary to define tests in the backends directly, except for those that require specific testing, such as burn-autodiff.
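The testgen pattern can be sketched with plain macro_rules!. Everything here is hypothetical (PowBackend, testgen_powf, the two toy backends), and the generated functions are ordinary pub fns rather than #[test] functions so the sketch runs standalone. The point is that the test body is written once, and each backend module only supplies its own TestBackend alias before invoking the aggregating macro.

```rust
/// Hypothetical minimal backend interface used by the generated tests.
pub trait PowBackend {
    fn powf(lhs: &[f32], rhs: &[f32]) -> Vec<f32>;
}

/// Per-op test generator, playing the role of tests/ops/{op_name}.rs.
/// `TestBackend` is resolved at the call site, so each backend supplies its own.
macro_rules! testgen_powf {
    () => {
        pub mod powf {
            use super::*;

            pub fn should_support_powf_ops() {
                let out = TestBackend::powf(&[1.0, 2.0, 3.0], &[2.0, 2.0, 0.5]);
                let expected = [1.0_f32, 4.0, 3.0_f32.sqrt()];
                for (o, e) in out.iter().zip(expected.iter()) {
                    assert!((o - e).abs() < 1e-4, "got {o}, expected {e}");
                }
            }
        }
    };
}

/// Aggregating macro in the spirit of testgen_all: one line per op module.
macro_rules! testgen_all {
    () => {
        testgen_powf!();
    };
}

/// Toy backend using the standard library's powf.
pub struct StdBackend;
impl PowBackend for StdBackend {
    fn powf(lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
        lhs.iter().zip(rhs).map(|(x, y)| x.powf(*y)).collect()
    }
}

/// Toy backend computing exp(y * ln x); only valid for positive bases.
pub struct ExpLnBackend;
impl PowBackend for ExpLnBackend {
    fn powf(lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
        lhs.iter().zip(rhs).map(|(x, y)| (y * x.ln()).exp()).collect()
    }
}

// Each backend's lib.rs would do the equivalent of these two modules.
mod std_backend_tests {
    use super::*;
    pub type TestBackend = StdBackend;
    testgen_all!();
}

mod exp_ln_backend_tests {
    use super::*;
    pub type TestBackend = ExpLnBackend;
    testgen_all!();
}
```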

Adding the Op to burn-autodiff

Since this is probably the hardest and least straightforward part, we'll cover this backend separately. burn-autodiff enables other backends to use autodifferentiation[2]. Ops for float types are implemented in crates/burn-autodiff/src/ops/tensor.rs and need to:

  1. Define a unit struct[3] that implements a backward (pass) function.
  2. Within the backward function, since this is an elementwise binary operation, call the binary function (from backward.rs in the same directory); its last two arguments are closures that define the left and right partial derivatives.
  3. Then define what happens when a specific operation is tracked or untracked: untracked just calls the function in the normal way, while tracked registers the backward function defined above for execution.
  4. When tracked, operations are part of the autodiff graph and must save the information needed to efficiently perform their backward pass later. If the information is light (such as a shape), it should be saved directly in the state. If the operation's inputs are needed to compute the backward pass, they should be checkpointed rather than saved. This allows the inputs to be provided lazily at the backward pass, depending on the checkpointing strategy.
  5. An operation must also be identified as compute-bound (.compute_bound()) or memory-bound (.memory_bound()) for gradient checkpointing. Compute-bound operations are heavy to compute (for instance matmul or convolution), which means that even with checkpointing they will save their output for the backward pass and not recompute it. Memory-bound operations are more trivial (like powf, which only performs one small operation per tensor entry), so it can be beneficial to recompute them during the backward pass instead of saving their whole forward output to memory. Operations registered as memory-bound need to know their parents (.parents() method) and how to recompute their forward pass during the backward pass (with a struct that implements RetroForward), using their parents' outputs.
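Steps 1 to 3 can be illustrated with a CPU-only sketch on Vec<f32>. The names binary, Backward, PowfBackward and powf_forward are hypothetical, and Burn's real binary helper works on tensors and graph nodes rather than on slices. The sketch only shows the shape of the idea: a unit struct whose backward pass applies one closure per input, and a forward function that keeps that backward step only when tracked.

```rust
/// Hypothetical stand-in for the `binary` helper: applies one closure per
/// input to turn the upstream gradient into that input's gradient.
fn binary<FL, FR>(grad: &[f32], lhs: &[f32], rhs: &[f32], left: FL, right: FR) -> (Vec<f32>, Vec<f32>)
where
    FL: Fn(f32, f32) -> f32,
    FR: Fn(f32, f32) -> f32,
{
    // Chain rule, element-wise: dL/dx = dL/df * df/dx (and likewise for y).
    let grad_lhs: Vec<f32> = grad
        .iter()
        .zip(lhs.iter().zip(rhs))
        .map(|(g, (x, y))| g * left(*x, *y))
        .collect();
    let grad_rhs: Vec<f32> = grad
        .iter()
        .zip(lhs.iter().zip(rhs))
        .map(|(g, (x, y))| g * right(*x, *y))
        .collect();
    (grad_lhs, grad_rhs)
}

/// Simplified, hypothetical backward interface.
trait Backward {
    fn backward(&self, grad: &[f32], lhs: &[f32], rhs: &[f32]) -> (Vec<f32>, Vec<f32>);
}

/// Step 1: a unit struct for the op.
struct PowfBackward;

/// Step 2: the backward pass passes the two partial derivatives as closures.
impl Backward for PowfBackward {
    fn backward(&self, grad: &[f32], lhs: &[f32], rhs: &[f32]) -> (Vec<f32>, Vec<f32>) {
        binary(
            grad,
            lhs,
            rhs,
            |x, y| y * x.powf(y - 1.0), // df/dx for f = x^y
            |x, y| x.powf(y) * x.ln(),  // df/dy for f = x^y
        )
    }
}

/// Step 3: untracked just runs the forward op; tracked also keeps the backward step.
fn powf_forward(lhs: &[f32], rhs: &[f32], tracked: bool) -> (Vec<f32>, Option<Box<dyn Backward>>) {
    let out: Vec<f32> = lhs.iter().zip(rhs).map(|(x, y)| x.powf(*y)).collect();
    let step = if tracked {
        Some(Box::new(PowfBackward) as Box<dyn Backward>)
    } else {
        None
    };
    (out, step)
}
```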

The above steps are mostly boilerplate, so you can often just copy the contents of another similar op, change the names of the structs, and ensure that both sides have the data they need (if one side needs a copy of the opposite tensor, clone it).
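The compute-bound versus memory-bound choice from steps 4 and 5 can be modelled with a toy checkpointer. Everything in it (Checkpoint, Checkpointer, powf_retro_forward, integer node ids) is a hypothetical illustration rather than burn-autodiff's API: a compute-bound node stores its output, while a memory-bound node stores only its parents and a retro-forward function that recomputes the output from the parents' outputs when the backward pass asks for it.

```rust
use std::collections::HashMap;

type NodeId = usize;

/// Hypothetical model of what a tracked node keeps for the backward pass.
enum Checkpoint {
    /// Compute-bound op (e.g. matmul): keep the forward output itself.
    Stored(Vec<f32>),
    /// Memory-bound op (e.g. powf): keep only parent ids and a retro-forward fn.
    Recompute {
        parents: [NodeId; 2],
        retro_forward: fn(&[f32], &[f32]) -> Vec<f32>,
    },
}

struct Checkpointer {
    nodes: HashMap<NodeId, Checkpoint>,
}

impl Checkpointer {
    /// Returns a node's forward output, recomputing it from its parents if needed.
    fn output(&self, id: NodeId) -> Vec<f32> {
        match &self.nodes[&id] {
            Checkpoint::Stored(out) => out.clone(),
            Checkpoint::Recompute { parents, retro_forward } => {
                let lhs = self.output(parents[0]);
                let rhs = self.output(parents[1]);
                (*retro_forward)(&lhs[..], &rhs[..])
            }
        }
    }
}

/// Retro-forward for powf: simply redo the cheap forward computation.
fn powf_retro_forward(lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
    lhs.iter().zip(rhs).map(|(x, y)| x.powf(*y)).collect()
}
```

Here only the two inputs occupy memory; the powf output exists only while it is being recomputed, which is the trade-off memory-bound registration aims for.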

Computing derivatives

For those that need it, here is a quick refresher on the necessary calculus. If you are familiar with how to calculate partial derivatives, you can skip this section.

Since pow is a binary operation, the left and right functions are the partial derivatives with respect to the left and right sided tensors.

Let's define the operator as a function \(f(x,y)=x^{y}\), where \(x\) is the left-hand tensor and \(y\) is the right-hand tensor. The two closures define the partial derivatives of \(f\) with respect to \(x\) and \(y\), treating the other variable as a constant:

$$\frac{\partial}{\partial x}\left(x^{y}\right) = y \cdot x^{y-1}$$ is the left-hand closure, and

$$\frac{\partial}{\partial y}\left(x^{y}\right) = x^{y} \cdot \ln(x)$$

is the right-hand closure. If you aren't sure how to calculate these by hand, a tool such as Symbolab can help: plug in your operator in terms of \(x\) and \(y\), take the partial derivative with respect to one variable, then swap in the other variable (\(x\) or \(y\)) to get the other side.
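A finite-difference check is another way to confirm the two closures before wiring them in. The sketch below assumes \(x > 0\) so that \(\ln(x)\) is defined; the step size h = 1e-5 is an arbitrary choice for f64, and the function names are illustrative.

```rust
/// f(x, y) = x^y and its two partial derivatives (valid for x > 0).
fn f(x: f64, y: f64) -> f64 {
    x.powf(y)
}
fn df_dx(x: f64, y: f64) -> f64 {
    y * x.powf(y - 1.0)
}
fn df_dy(x: f64, y: f64) -> f64 {
    x.powf(y) * x.ln()
}

/// Central finite difference: (g(t + h) - g(t - h)) / (2h).
fn central_difference(g: impl Fn(f64) -> f64, t: f64, h: f64) -> f64 {
    (g(t + h) - g(t - h)) / (2.0 * h)
}

/// Largest absolute gap between the analytic and numeric partials at (x, y).
fn max_partial_error(x: f64, y: f64) -> f64 {
    let h = 1e-5;
    // Hold y fixed to differentiate in x, and x fixed to differentiate in y.
    let numeric_dx = central_difference(|t| f(t, y), x, h);
    let numeric_dy = central_difference(|t| f(x, t), y, h);
    (numeric_dx - df_dx(x, y)).abs().max((numeric_dy - df_dy(x, y)).abs())
}
```

A small error at a few sample points suggests the closures are right; a large one usually means a sign or exponent slip.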

Testing autodiff

For testing the autodiff operations, please refer to this section.

Adding the Op to other backends

Most of these are fairly straightforward implementations. For reference, here's pow's float implementation for the torch, ndarray and candle backends:

  1. Torch implementation in crates/burn-tch/src/ops/tensor.rs and the Op used in crates/burn-tch/src/ops/base.rs
  2. NdArray in crates/burn-ndarray/src/ops/tensor.rs
  3. Candle in crates/burn-candle/src/ops/tensor.rs

These backends are where the actual computation happens. Playing a guessing game with method names and seeing which completions are suggested will take you far. If you have trouble figuring out how to do it from the docs for that backend, try searching GitHub for relevant function calls.

Adding the Op to fusion, JIT and cubecl backends

Adding an operator to these backends can be fairly straightforward, though, due to what these backends are for, it involves a bit more indirection. Fusion and JIT, like autodiff, are not target backends so much as backends that enable certain functionality for other backends, in this case kernel fusion and just-in-time compilation. Adding the operator won't involve doing any calculation; you'll just be describing how the generated code should look. Most of this can be copied from other functions and adjusted.

Here's how powf was added to burn-fusion:

  1. Added powf to the float ops under crates/burn-fusion/src/ops/float.rs
  2. Added powf to the NumericOperationDescription enum under crates/burn-fusion/src/stream/operation.rs
  3. Added powf to the implementations of the NumericOperationDescription enum under crates/burn-fusion/src/stream/context.rs
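The three fusion steps boil down to recording a description instead of computing, adding an enum variant, and covering the new variant in every match over that enum. The sketch below is hypothetical (NumericOperationSketch, Stream and execute are not burn-fusion's types), and its execute function runs the ops eagerly on CPU vectors purely so the example can be checked; in burn-fusion the recorded descriptions are instead handed to the fusion machinery.

```rust
use std::collections::HashMap;

type TensorId = usize;

/// Hypothetical, trimmed-down operation description enum (cf. step 2).
#[derive(Debug, Clone, PartialEq)]
enum NumericOperationSketch {
    Add { lhs: TensorId, rhs: TensorId, out: TensorId },
    /// The newly added variant.
    Powf { lhs: TensorId, rhs: TensorId, out: TensorId },
}

impl NumericOperationSketch {
    // Every match over the enum needs an arm for the new variant (cf. step 3).
    fn ids(&self) -> (TensorId, TensorId, TensorId) {
        match *self {
            NumericOperationSketch::Add { lhs, rhs, out }
            | NumericOperationSketch::Powf { lhs, rhs, out } => (lhs, rhs, out),
        }
    }

    fn apply(&self, a: f32, b: f32) -> f32 {
        match self {
            NumericOperationSketch::Add { .. } => a + b,
            NumericOperationSketch::Powf { .. } => a.powf(b),
        }
    }
}

/// Cf. step 1: the fusion-side op only records a description, it computes nothing.
#[derive(Default)]
struct Stream {
    recorded: Vec<NumericOperationSketch>,
}

impl Stream {
    fn powf(&mut self, lhs: TensorId, rhs: TensorId, out: TensorId) {
        self.recorded.push(NumericOperationSketch::Powf { lhs, rhs, out });
    }
}

/// Simplified eager executor over recorded descriptions (for checking only).
fn execute(ops: &[NumericOperationSketch], buffers: &mut HashMap<TensorId, Vec<f32>>) {
    for op in ops {
        let (lhs, rhs, out) = op.ids();
        let result: Vec<f32> = buffers[&lhs]
            .iter()
            .zip(&buffers[&rhs])
            .map(|(a, b)| op.apply(*a, *b))
            .collect();
        buffers.insert(out, result);
    }
}
```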

The way cubecl handles tensor-scalar operations is by transforming both into a sequence of vectorized scalar operations. Since powf already existed in cubecl, it was fairly easy to reuse the existing implementation for the case where both sides of the operation are tensors. The cubecl crate is primarily concerned with how the operation is compiled and executed by the GPU. The actual implementation is defined in burn-jit.
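The idea of turning a tensor-scalar op into the same per-lane scalar op used for tensor-tensor inputs can be illustrated in plain Rust. This is a conceptual CPU sketch, not cubecl code: the fixed LANES = 4, the Line alias and the helper names are assumptions for illustration.

```rust
/// Number of lanes per vectorized "line" (an assumption of this sketch).
const LANES: usize = 4;
type Line = [f32; LANES];

/// The single scalar operation, applied lane by lane.
fn powf_line(lhs: Line, rhs: Line) -> Line {
    std::array::from_fn(|i| lhs[i].powf(rhs[i]))
}

/// Tensor-tensor: pair up lines from both inputs.
fn powf_tensor_tensor(lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
    assert_eq!(lhs.len(), rhs.len());
    assert_eq!(lhs.len() % LANES, 0, "sketch assumes a multiple of LANES");
    lhs.chunks_exact(LANES)
        .zip(rhs.chunks_exact(LANES))
        .flat_map(|(a, b)| powf_line(Line::try_from(a).unwrap(), Line::try_from(b).unwrap()))
        .collect()
}

/// Tensor-scalar: broadcast the scalar into a line, then reuse the same per-lane op.
fn powf_tensor_scalar(lhs: &[f32], rhs: f32) -> Vec<f32> {
    assert_eq!(lhs.len() % LANES, 0, "sketch assumes a multiple of LANES");
    let rhs_line: Line = [rhs; LANES];
    lhs.chunks_exact(LANES)
        .flat_map(|a| powf_line(Line::try_from(a).unwrap(), rhs_line))
        .collect()
}
```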

Here is where code was added for powf in burn-jit and cubecl:

  1. to the implementation of FloatTensorOps under crates/burn-jit/src/ops/float_ops.rs
  2. the function being called was added to crates/burn-jit/src/ops/numeric.rs
  3. the operator was defined in cubecl-core/src/ir/operation.rs
  4. how the operation looks to the GPU was added to crates/burn-jit/src/fusion/on_write/ir.rs
  5. the mappings between the GPU operation and the CPP, WGSL and SPIR-V instructions were added to cubecl-cpp/src/shared/base.rs, cubecl-wgpu/src/compiler/wgsl/compiler.rs and cubecl-spirv/src/instruction.rs
  6. the instructions themselves were added: for WGSL, to the instruction op enum in cubecl-wgpu/src/compiler/wgsl/instructions.rs, with the actual instruction in WGSL here; for CPP, to the enum in cubecl-cpp/src/shared/instruction.rs, with the actual instruction in cubecl-cpp/src/shared/binary.rs

We needed to generate some custom WGSL code for powf, primarily due to issues with edge-case handling in the WGSL pow function, such as 0 to the power 0 being 1, and any negative number raised to an even power being positive. We reused as much of the existing logic as possible, and then branched at the last point based on the var type of the rhs. See here. For most operations, you shouldn't need to add to cubecl-wgpu/src/compiler/wgsl/extension.rs unless the operation isn't native to WGSL.
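The branching can be sketched on the CPU in Rust. As an assumption for illustration, naive_pow computes exp(y * ln x), which shares the edge cases named above: it yields NaN for 0^0 and for negative bases. The wrapper handles those cases (a zero exponent gives 1, and a negative base with an integer exponent takes its sign from the exponent's parity). It branches on whether the exponent value is integral, which stands in for the rhs var-type check; this is not the WGSL that cubecl generates.

```rust
/// Naive power via exp(y * ln x): NaN for 0^0 and for any negative base.
fn naive_pow(x: f32, y: f32) -> f32 {
    (y * x.ln()).exp()
}

/// Power with the edge cases handled before falling back to `naive_pow`.
fn powf_with_edge_cases(x: f32, y: f32) -> f32 {
    if y == 0.0 {
        // x^0 = 1 for every x, including 0^0.
        return 1.0;
    }
    if x < 0.0 && y.fract() == 0.0 {
        // Integer exponent on a negative base: compute |x|^y, then fix the sign.
        let magnitude = naive_pow(-x, y);
        let exponent_is_odd = y % 2.0 != 0.0;
        return if exponent_is_odd { -magnitude } else { magnitude };
    }
    // Positive bases, zero base with non-zero exponent, and genuinely
    // undefined cases (negative base, fractional exponent -> NaN).
    naive_pow(x, y)
}
```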

For functions that need a complex kernel without a direct mapping to a base instruction, simply use the cube macro (see the cubecl book).

Adding the Op to burn-import

Generating the ONNX test files or tests is already covered in the ONNX to burn guide; this is more about the specific changes you need to make when adding new operators after you have generated the tests.

Changes will need to be made to both onnx-ir and burn-import. The code within onnx-ir defines how to parse the nodes in an ONNX file and produces the intermediate representation. The code within burn-import is divided into two sections: src/onnx and src/burn. The code under the former maps that intermediate representation to one used for code generation, and the latter defines how to generate code for the operator you've implemented earlier in this guide.

So when you are loading a model, the operator is first parsed to an intermediate representation defined by onnx-ir and then mapped to a Burn operation defined under src/burn/node; the mapping from ONNX to Burn is aptly defined in src/onnx/to_burn.rs.

Let's review the changes made for powf starting from src/burn and moving to src/onnx:

  1. Determine the type of operator and add your operator to the appropriate node (operation) type, in this case BinaryNode under crates/burn-import/src/burn/node/binary.rs along with its to_str definition
  2. Add an arm to the match statement inside the into_burn function in crates/burn-import/src/onnx/to_burn.rs for the ONNX NodeType (which corresponds to an op in the ONNX spec), and make an {op}_conversion function that maps the ONNX node to the binary type
  3. Specify how dimensions for the output should be derived in crates/onnx-ir/src/dim_inference.rs
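The three steps can be tied together in a small, hypothetical sketch. NodeType, Argument, Node, BinaryType, BinaryNode and the conversion functions are simplified stand-ins for the onnx-ir and burn-import types, and the rank rule (output rank is the larger input rank, as for a broadcasting element-wise op) is an assumption of the sketch. It shows a match arm dispatching Pow to pow_conversion, dimension inference on the node, and to_str supplying the method name in the generated line.

```rust
/// Hypothetical, simplified stand-ins for the IR's node types.
#[derive(Debug, Clone, Copy, PartialEq)]
enum NodeType {
    Add,
    Pow,
}

#[derive(Debug, Clone)]
struct Argument {
    name: String,
    rank: usize,
}

#[derive(Debug, Clone)]
struct Node {
    node_type: NodeType,
    inputs: Vec<Argument>,
    outputs: Vec<Argument>,
}

/// Cf. step 3: assume a broadcasting binary op outputs the larger input rank.
fn binary_dim_inference(node: &mut Node) {
    let rank = node.inputs.iter().map(|arg| arg.rank).max().unwrap_or(0);
    node.outputs[0].rank = rank;
}

/// Cf. step 1: the binary op type and the method name used in generated code.
#[derive(Debug, Clone, Copy, PartialEq)]
enum BinaryType {
    Add,
    Powf,
}

impl BinaryType {
    fn to_str(self) -> &'static str {
        match self {
            BinaryType::Add => "add",
            BinaryType::Powf => "powf",
        }
    }
}

#[derive(Debug, Clone)]
struct BinaryNode {
    lhs: String,
    rhs: String,
    output: String,
    binary_type: BinaryType,
}

impl BinaryNode {
    /// Emits one line of Burn code (a drastically simplified code generator).
    fn to_code(&self) -> String {
        format!("let {} = {}.{}({});", self.output, self.lhs, self.binary_type.to_str(), self.rhs)
    }
}

fn binary_conversion(node: &Node, binary_type: BinaryType) -> BinaryNode {
    BinaryNode {
        lhs: node.inputs[0].name.clone(),
        rhs: node.inputs[1].name.clone(),
        output: node.outputs[0].name.clone(),
        binary_type,
    }
}

/// Cf. step 2: the {op}_conversion function for Pow.
fn pow_conversion(node: &Node) -> BinaryNode {
    binary_conversion(node, BinaryType::Powf)
}

/// Cf. step 2: one match arm per ONNX node type.
fn into_burn(node: &Node) -> BinaryNode {
    match node.node_type {
        NodeType::Add => binary_conversion(node, BinaryType::Add),
        NodeType::Pow => pow_conversion(node),
    }
}
```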

And you're done! Congrats, you just fully added a new operation to Burn, and we are all one step closer to the answer to "Are we learning yet?" being "Yes, and it's freaking fast!". Buy yourself a coffee.

[1] For more on supertraits, see the Advanced Traits section of the Rust Book.

[2] Wikipedia article on automatic differentiation.

[3] For more information on unit structs, see the Defining and Instantiating Structs section of the Rust Book.