Autodiff

Burn's tensor also supports auto-differentiation, which is an essential part of any deep learning framework. We introduced the Backend trait in the previous section, but Burn also has another trait for autodiff: AutodiffBackend.

However, not all tensors support auto-differentiation; you need a backend that implements both the Backend and AutodiffBackend traits. Fortunately, you can add auto-differentiation capabilities to any backend using a backend decorator: type MyAutodiffBackend = Autodiff<MyBackend>. This decorator implements both the AutodiffBackend and Backend traits by maintaining a dynamic computational graph and utilizing the inner backend to execute tensor operations.
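
To make the decorator idea concrete, here is a minimal, std-only sketch. It is not Burn's implementation: ToyBackend, Plain, ToyAutodiff, Var, and Node are hypothetical names, and values are scalars rather than tensors. It only illustrates the two responsibilities described above: delegating each operation to the inner backend, and recording a graph that a backward pass can walk.

```rust
use std::marker::PhantomData;

/// Toy stand-in for a backend: it only knows how to compute on scalars.
trait ToyBackend {
    fn add(a: f64, b: f64) -> f64;
    fn mul(a: f64, b: f64) -> f64;
}

/// A plain backend with no gradient tracking.
struct Plain;

impl ToyBackend for Plain {
    fn add(a: f64, b: f64) -> f64 { a + b }
    fn mul(a: f64, b: f64) -> f64 { a * b }
}

/// A tracked value: its position in the recorded graph plus its data.
#[derive(Clone, Copy)]
struct Var { id: usize, value: f64 }

/// One recorded operation: (parent id, local derivative) pairs.
struct Node { parents: Vec<(usize, f64)> }

/// Toy decorator: computes with the inner backend `B`, records the graph itself.
struct ToyAutodiff<B: ToyBackend> {
    tape: Vec<Node>,
    _inner: PhantomData<B>,
}

impl<B: ToyBackend> ToyAutodiff<B> {
    fn new() -> Self { Self { tape: Vec::new(), _inner: PhantomData } }

    /// Register an input value with no parents.
    fn leaf(&mut self, value: f64) -> Var { self.record(value, vec![]) }

    fn add(&mut self, a: Var, b: Var) -> Var {
        let value = B::add(a.value, b.value); // delegated to the inner backend
        self.record(value, vec![(a.id, 1.0), (b.id, 1.0)])
    }

    fn mul(&mut self, a: Var, b: Var) -> Var {
        let value = B::mul(a.value, b.value); // delegated to the inner backend
        self.record(value, vec![(a.id, b.value), (b.id, a.value)])
    }

    fn record(&mut self, value: f64, parents: Vec<(usize, f64)>) -> Var {
        self.tape.push(Node { parents });
        Var { id: self.tape.len() - 1, value }
    }

    /// Walk the recorded graph in reverse and return one gradient per node id.
    fn backward(&self, output: Var) -> Vec<f64> {
        let mut grads = vec![0.0; self.tape.len()];
        grads[output.id] = 1.0;
        for id in (0..=output.id).rev() {
            let upstream = grads[id];
            for &(parent, local) in &self.tape[id].parents {
                grads[parent] += upstream * local;
            }
        }
        grads
    }
}
```

With x = 2, w = 3, and b = 1, building y = x * w + b through a ToyAutodiff::<Plain> and calling backward(y) yields gradients 3, 2, and 1 for x, w, and b. The arithmetic itself never leaves Plain, which is the point of the decorator.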

The AutodiffBackend trait adds new operations on float tensors that can't be called otherwise. It also provides a new associated type, B::Gradients, where each calculated gradient resides.

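use burn::tensor::{backend::AutodiffBackend, Tensor};

fn calculate_gradients<B: AutodiffBackend>(tensor: Tensor<B, 2>) -> B::Gradients {
    // Run the backward pass; the result holds every calculated gradient.
    let mut gradients = tensor.clone().backward();

    // get: read the gradient and leave it in the container
    let _tensor_grad = tensor.grad(&gradients);
    // pop: take the gradient out of the container
    let _tensor_grad = tensor.grad_remove(&mut gradients);

    gradients
}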

Note that some functions are always available, even if the backend doesn't implement the AutodiffBackend trait. On such a backend, those functions do nothing.

| Burn API | PyTorch Equivalent |
| --- | --- |
| tensor.detach() | tensor.detach() |
| tensor.require_grad() | tensor.requires_grad_() |
| tensor.is_require_grad() | tensor.requires_grad |
| tensor.set_require_grad(require_grad) | tensor.requires_grad_(require_grad) |

However, mistakes are unlikely: calling backward on a tensor whose backend doesn't implement AutodiffBackend does not compile, and neither does retrieving the gradient of a tensor without an autodiff backend.

Difference with PyTorch

Burn handles gradients differently from PyTorch. First, calling backward does not update a grad field on each parameter. Instead, the backward pass returns all the calculated gradients in a container. This approach has several benefits, such as making it easy to send gradients to other threads.

You can also retrieve the gradient for a specific parameter using the grad method on a tensor. Since this method takes the gradients as input, it's hard to forget to call backward beforehand. Note that using grad_remove instead can sometimes improve performance, because removing the gradient from the container allows in-place operations on it.
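
The ownership difference between reading and removing a gradient, and the fact that a returned container can move across threads, can be sketched with plain Rust. This is a toy analogy, not Burn's Gradients type: ToyGradients, scale_in_place, and sum_on_worker are hypothetical names, and gradients are plain Vec<f32> buffers keyed by an integer id.

```rust
use std::collections::HashMap;
use std::thread;

/// Hypothetical stand-in for a gradients container: id -> gradient values.
struct ToyGradients {
    grads: HashMap<u32, Vec<f32>>,
}

impl ToyGradients {
    /// Like a "get": the container keeps its entry, the caller gets a copy.
    fn get(&self, id: u32) -> Option<Vec<f32>> {
        self.grads.get(&id).cloned()
    }

    /// Like a "pop": the entry is moved out, so the caller owns the buffer.
    fn remove(&mut self, id: u32) -> Option<Vec<f32>> {
        self.grads.remove(&id)
    }
}

/// Scale an owned gradient in place; no second buffer is allocated.
fn scale_in_place(mut grad: Vec<f32>, factor: f32) -> Vec<f32> {
    for g in grad.iter_mut() {
        *g *= factor;
    }
    grad
}

/// Move the whole container to a worker thread and reduce it there.
fn sum_on_worker(grads: ToyGradients) -> f32 {
    thread::spawn(move || grads.grads.values().flatten().sum::<f32>())
        .join()
        .expect("worker thread panicked")
}
```

Because the container is an ordinary value rather than state attached to each parameter, handing it to another thread is a plain move. A buffer obtained through remove has a single owner, which is what makes reusing it in place safe in this sketch.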

In PyTorch, when you don't need gradients for inference or validation, you typically need to scope your code using a block.

import torch

# Inference mode
with torch.inference_mode():
    # your code
    ...

# Or no grad
with torch.no_grad():
    # your code
    ...

With Burn, you don't need to wrap the backend with Autodiff for inference at all. During validation, where the backend is still an autodiff backend, you can call inner() to obtain the inner tensor, which lives on the inner backend and is not tracked.

/// Use `B: AutodiffBackend`
fn example_validation<B: AutodiffBackend>(tensor: Tensor<B, 2>) {
    let inner_tensor: Tensor<B::InnerBackend, 2> = tensor.inner();
    let _ = inner_tensor + 5;
}

/// Use `B: Backend`
fn example_inference<B: Backend>(tensor: Tensor<B, 2>) {
    let _ = tensor + 5;
    // ...
}

Gradients with Optimizers

We've seen how gradients can be used with tensors, but the process is a bit different when working with optimizers from burn-core. To work with the Module trait, a translation step is required to link tensor parameters with their gradients. This step makes it easy to support gradient accumulation and training on multiple devices, where each module can be forked and run on different devices in parallel. We'll explore this topic in more depth in the Module section.
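
As a rough illustration of why such a translation step helps, the sketch below re-keys gradients from graph node ids to stable parameter ids, then accumulates them across two backward passes. Everything here is assumed for illustration and is not burn-core's API: ToyParam, to_param_grads, and accumulate are hypothetical names, and each gradient is a single f32.

```rust
use std::collections::HashMap;

/// Hypothetical module parameter: a stable id plus the id of the graph
/// node that currently holds its value.
struct ToyParam {
    param_id: &'static str,
    node_id: usize,
}

/// Translate gradients keyed by graph node into gradients keyed by parameter id.
fn to_param_grads(
    node_grads: &HashMap<usize, f32>,
    params: &[ToyParam],
) -> HashMap<&'static str, f32> {
    params
        .iter()
        .filter_map(|p| node_grads.get(&p.node_id).map(|g| (p.param_id, *g)))
        .collect()
}

/// Add one batch of parameter gradients into a running total.
fn accumulate(total: &mut HashMap<&'static str, f32>, batch: HashMap<&'static str, f32>) {
    for (id, grad) in batch {
        *total.entry(id).or_insert(0.0) += grad;
    }
}
```

In this toy, two passes may assign different node ids to the same parameter (for example, in a forked copy of the module), yet both translate to the same "weight" and "bias" keys, so their gradients can simply be summed.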