Tensor

As previously explained in the model section, the Tensor struct has 3 generic arguments: the backend B, the dimensionality D, and the data type.

Tensor<B, D>           // Float tensor (default)
Tensor<B, D, Float>    // Explicit float tensor
Tensor<B, D, Int>      // Int tensor
Tensor<B, D, Bool>     // Bool tensor

Note that the specific element types used for Float, Int, and Bool tensors are defined by backend implementations.

Burn Tensors are defined by the number of dimensions D in their declaration, as opposed to their shape. The actual shape of a tensor is inferred from its initialization. For example, a Tensor of size (5,) is initialized as below:

let floats = [1.0, 2.0, 3.0, 4.0, 5.0];

// Get the default device
let device = Default::default();

// correct: Tensor is 1-Dimensional with 5 elements
let tensor_1 = Tensor::<Backend, 1>::from_floats(floats, &device);

// incorrect: let tensor_1 = Tensor::<Backend, 5>::from_floats(floats, &device);
// this will lead to an error because `Tensor::<Backend, 5>` declares a 5-D tensor
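
The same rule extends to higher dimensions: the declaration fixes only the number of dimensions, while the concrete shape is inferred from the nested array used at initialization. Below is a minimal sketch, assuming the same Backend type and device as above.

// 2-D tensor: the declared dimensionality is 2, the shape [2, 3] is inferred
// from the nested array.
let tensor_2d = Tensor::<Backend, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
println!("{:?}", tensor_2d.dims()); // [2, 3]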

Initialization

Burn Tensors are primarily initialized using the from_data() method, which takes the Data struct as input. The Data struct has two fields: value and shape. To retrieve the data from a tensor, use .to_data() when you intend to reuse the tensor afterward, or .into_data() for one-time use. Let's look at a couple of examples for initializing a tensor from different inputs.


// Initialization from a given Backend (Wgpu)
let tensor_1 = Tensor::<Wgpu, 1>::from_data([1.0, 2.0, 3.0], &device);

// Initialization from a generic Backend
let tensor_2 = Tensor::<Backend, 1>::from_data(Data::from([1.0, 2.0, 3.0]).convert(), &device);

// Initialization using from_floats (Recommended for f32 ElementType)
// Will be converted to Data internally.
// `.convert()` not needed as from_floats() defined for fixed ElementType
let tensor_3 = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0], &device);

// Initialization of Int Tensor from array slices
let arr: [i32; 6] = [1, 2, 3, 4, 5, 6];
let tensor_4 = Tensor::<Backend, 1, Int>::from_data(Data::from(&arr[0..3]).convert(), &device);

// Initialization from a custom type

struct BodyMetrics {
    age: i8,
    height: i16,
    weight: f32
}

let bmi = BodyMetrics {
    age: 25,
    height: 180,
    weight: 80.0,
};
let data = Data::from([bmi.age as f32, bmi.height as f32, bmi.weight]).convert();
let tensor_5 = Tensor::<Backend, 1>::from_data(data, &device);

The .convert() method on the Data struct is called to ensure that the data's primitive type is consistent across all backends. With the .from_floats() method, the ElementType is fixed as f32, so no convert operation is required across backends. The conversion can also be done at the element level, as in: let tensor_6 = Tensor::<B, 1, Int>::from_data(Data::from([(bmi.age as i64).elem()]), &device);. The ElementConversion trait, however, needs to be imported for the element-wise conversion.
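
As an illustration, here is a hedged sketch of the element-wise conversion, reusing the BodyMetrics example above and assuming the same Backend type and device are in scope. Note the ElementConversion import.

use burn::tensor::{Data, ElementConversion, Int, Tensor};

// Convert each value to the backend's int element type individually with `.elem()`.
let tensor_6 = Tensor::<Backend, 1, Int>::from_data(
    Data::from([
        (bmi.age as i64).elem(),
        (bmi.height as i64).elem(),
        (bmi.weight as i64).elem(),
    ]),
    &device,
);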

Ownership and Cloning

Almost all Burn operations take ownership of their input tensors, so reusing a tensor multiple times requires cloning it. Let's look at an example to better understand the ownership rules and cloning. Suppose we want to do a simple min-max normalization of an input tensor.

let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);
let min = input.min();
let max = input.max();
let input = (input - min).div(max - min);

With PyTorch tensors, the above code would work as expected. However, Rust's strict ownership rules produce an error and prevent using the input tensor after the first .min() operation: ownership of the input tensor is transferred to the variable min, and the input tensor is no longer available for further operations. Burn Tensors, like most complex primitives, do not implement the Copy trait and therefore have to be cloned explicitly. Now let's rewrite a working example of min-max normalization with cloning.

let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);
let min = input.clone().min();
let max = input.clone().max();
let input = (input.clone() - min.clone()).div(max - min);
println!("{:?}", input.to_data());// Success: [0.0, 0.33333334, 0.6666667, 1.0]

// Notice that max, min have been moved in last operation so
// the below print will give an error.
// If we want to use them for further operations,
// they will need to be cloned in similar fashion.
// println!("{:?}", min.to_data());
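
If min and max are needed after the normalization, the same computation can clone each of them one more time. A sketch, again assuming the Wgpu backend and the same device:

let input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);
let min = input.clone().min();
let max = input.clone().max();
let normalized = (input - min.clone()).div(max.clone() - min.clone());

println!("{:?}", normalized.to_data()); // [0.0, 0.33333334, 0.6666667, 1.0]
println!("{:?}", min.into_data());      // still available: [1.0]
println!("{:?}", max.into_data());      // still available: [4.0]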

We don't need to worry about memory overhead because, with cloning, the tensor's buffer isn't copied; only its reference count is increased. This makes it possible to determine exactly how many times a tensor is used, which is very convenient for reusing tensor buffers or even fusing operations into a single kernel (burn-fusion). For that reason, we don't provide explicit inplace operations. If a tensor is used only one time, inplace operations will always be used when available.

Tensor Operations

Normally with PyTorch, explicit inplace operations aren't supported during the backward pass, making them useful only for data preprocessing or inference-only model implementations. With Burn, you can focus more on what the model should do rather than on how to do it. We take responsibility for making your code run as fast as possible during training as well as inference. The same principles apply to broadcasting: all operations support broadcasting unless specified otherwise.
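
For instance, a minimal broadcasting sketch, assuming a Backend type and device as in the earlier examples: a tensor of shape [1, 3] is broadcast across the first dimension of a tensor of shape [2, 3].

let a = Tensor::<Backend, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
let b = Tensor::<Backend, 2>::from_floats([[10.0, 20.0, 30.0]], &device);

// b (shape [1, 3]) is broadcast over the first dimension of a (shape [2, 3]).
let c = a + b;
println!("{:?}", c.to_data()); // [[11.0, 22.0, 33.0], [14.0, 25.0, 36.0]]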

Here, we provide a list of all supported operations along with their PyTorch equivalents. Note that for the sake of simplicity, we ignore type signatures. For more details, refer to the full documentation.

Basic Operations

Those operations are available for all tensor kinds: Int, Float, and Bool.

Burn API | PyTorch Equivalent
Tensor::cat(tensors, dim) | torch.cat(tensors, dim)
Tensor::empty(shape, device) | torch.empty(shape, device=device)
Tensor::from_primitive(primitive) | N/A
Tensor::stack(tensors, dim) | torch.stack(tensors, dim)
tensor.all() | tensor.all()
tensor.all_dim(dim) | tensor.all(dim)
tensor.any() | tensor.any()
tensor.any_dim(dim) | tensor.any(dim)
tensor.chunk(num_chunks, dim) | tensor.chunk(num_chunks, dim)
tensor.device() | tensor.device
tensor.dims() | tensor.size()
tensor.equal(other) | x == y
tensor.expand(shape) | tensor.expand(shape)
tensor.flatten(start_dim, end_dim) | tensor.flatten(start_dim, end_dim)
tensor.flip(axes) | tensor.flip(axes)
tensor.into_data() | N/A
tensor.into_primitive() | N/A
tensor.into_scalar() | tensor.item()
tensor.narrow(dim, start, length) | tensor.narrow(dim, start, length)
tensor.not_equal(other) | x != y
tensor.permute(axes) | tensor.permute(axes)
tensor.repeat(2, 4) | tensor.repeat([1, 1, 4])
tensor.reshape(shape) | tensor.view(shape)
tensor.shape() | tensor.shape
tensor.slice(ranges) | tensor[(*ranges,)]
tensor.slice_assign(ranges, values) | tensor[(*ranges,)] = values
tensor.squeeze(dim) | tensor.squeeze(dim)
tensor.to_data() | N/A
tensor.to_device(device) | tensor.to(device)
tensor.unsqueeze() | tensor.unsqueeze(0)
tensor.unsqueeze_dim(dim) | tensor.unsqueeze(dim)
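
To make these entries concrete, here is a small sketch combining a few of them. It assumes a Backend type and a device in scope, as in the earlier examples, and is meant as an illustration rather than canonical usage.

let a = Tensor::<Backend, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<Backend, 2>::from_floats([[5.0, 6.0], [7.0, 8.0]], &device);

// Concatenate along dimension 0: shape [4, 2].
let cat = Tensor::cat(vec![a, b], 0);
println!("{:?}", cat.dims()); // [4, 2]

// Slice out the first two rows, then flatten them into a 1-D tensor.
let rows = cat.slice([0..2]);
let flat: Tensor<Backend, 1> = rows.flatten(0, 1);
println!("{:?}", flat.to_data()); // [1.0, 2.0, 3.0, 4.0]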

Numeric Operations

Those operations are available for numeric tensor kinds: Float and Int.

Burn API | PyTorch Equivalent
Tensor::eye(size, device) | torch.eye(size, device=device)
Tensor::full(shape, fill_value, device) | torch.full(shape, fill_value, device=device)
Tensor::ones(shape, device) | torch.ones(shape, device=device)
Tensor::zeros(shape) | torch.zeros(shape)
Tensor::zeros(shape, device) | torch.zeros(shape, device=device)
tensor.abs() | torch.abs(tensor)
tensor.add(other) or tensor + other | tensor + other
tensor.add_scalar(scalar) or tensor + scalar | tensor + scalar
tensor.all_close(other, atol, rtol) | torch.allclose(tensor, other, atol, rtol)
tensor.argmax(dim) | tensor.argmax(dim)
tensor.argmin(dim) | tensor.argmin(dim)
tensor.argsort(dim) | tensor.argsort(dim)
tensor.argsort_descending(dim) | tensor.argsort(dim, descending=True)
tensor.bool() | tensor.bool()
tensor.clamp(min, max) | torch.clamp(tensor, min=min, max=max)
tensor.clamp_max(max) | torch.clamp(tensor, max=max)
tensor.clamp_min(min) | torch.clamp(tensor, min=min)
tensor.div(other) or tensor / other | tensor / other
tensor.div_scalar(scalar) or tensor / scalar | tensor / scalar
tensor.equal_elem(other) | tensor.eq(other)
tensor.gather(dim, indices) | torch.gather(tensor, dim, indices)
tensor.greater(other) | tensor.gt(other)
tensor.greater_elem(scalar) | tensor.gt(scalar)
tensor.greater_equal(other) | tensor.ge(other)
tensor.greater_equal_elem(scalar) | tensor.ge(scalar)
tensor.is_close(other, atol, rtol) | torch.isclose(tensor, other, atol, rtol)
tensor.lower(other) | tensor.lt(other)
tensor.lower_elem(scalar) | tensor.lt(scalar)
tensor.lower_equal(other) | tensor.le(other)
tensor.lower_equal_elem(scalar) | tensor.le(scalar)
tensor.mask_fill(mask, value) | tensor.masked_fill(mask, value)
tensor.mask_where(mask, value_tensor) | torch.where(mask, value_tensor, tensor)
tensor.max() | tensor.max()
tensor.max_dim(dim) | tensor.max(dim, keepdim=True)
tensor.max_dim_with_indices(dim) | N/A
tensor.max_pair(other) | torch.Tensor.max(a, b)
tensor.mean() | tensor.mean()
tensor.mean_dim(dim) | tensor.mean(dim, keepdim=True)
tensor.min() | tensor.min()
tensor.min_dim(dim) | tensor.min(dim, keepdim=True)
tensor.min_dim_with_indices(dim) | N/A
tensor.min_pair(other) | torch.Tensor.min(a, b)
tensor.mul(other) or tensor * other | tensor * other
tensor.mul_scalar(scalar) or tensor * scalar | tensor * scalar
tensor.neg() or -tensor | -tensor
tensor.not_equal_elem(scalar) | tensor.ne(scalar)
tensor.pad(pads, value) | torch.nn.functional.pad(input, pad, value)
tensor.powf(other) or tensor.powi(intother) | tensor.pow(other)
tensor.powf_scalar(scalar) or tensor.powi_scalar(intscalar) | tensor.pow(scalar)
tensor.prod() | tensor.prod()
tensor.prod_dim(dim) | tensor.prod(dim, keepdim=True)
tensor.scatter(dim, indices, values) | tensor.scatter_add(dim, indices, values)
tensor.select(dim, indices) | tensor.index_select(dim, indices)
tensor.select_assign(dim, indices, values) | N/A
tensor.sign() | tensor.sign()
tensor.sort(dim) | tensor.sort(dim).values
tensor.sort_descending(dim) | tensor.sort(dim, descending=True).values
tensor.sort_descending_with_indices(dim) | tensor.sort(dim, descending=True)
tensor.sort_with_indices(dim) | tensor.sort(dim)
tensor.sub(other) or tensor - other | tensor - other
tensor.sub_scalar(scalar) or tensor - scalar | tensor - scalar
tensor.sum() | tensor.sum()
tensor.sum_dim(dim) | tensor.sum(dim, keepdim=True)
tensor.topk(k, dim) | tensor.topk(k, dim).values
tensor.topk_with_indices(k, dim) | tensor.topk(k, dim)
tensor.tril(diagonal) | torch.tril(tensor, diagonal)
tensor.triu(diagonal) | torch.triu(tensor, diagonal)
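
As a small illustration of the numeric operations, here is a hedged sketch, assuming a Backend type and device as in the earlier examples.

let x = Tensor::<Backend, 2>::from_floats([[1.0, -2.0, 3.0], [-4.0, 5.0, -6.0]], &device);

// Clamp the values into [-3.0, 3.0], then average over dimension 1 (keeping the dim).
let mean = x.clone().clamp(-3.0, 3.0).mean_dim(1);
println!("{:?}", mean.dims()); // [2, 1]

// Comparing against a scalar yields a Bool tensor.
let positive = x.greater_elem(0.0);
println!("{:?}", positive.to_data()); // [true, false, true, false, true, false]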

Float Operations

Those operations are only available for Float tensors.

Burn API | PyTorch Equivalent
tensor.cos() | tensor.cos()
tensor.erf() | tensor.erf()
tensor.exp() | tensor.exp()
tensor.from_floats(floats, device) | N/A
tensor.from_full_precision(tensor) | N/A
tensor.int() | Similar to tensor.to(torch.long)
tensor.log() | tensor.log()
tensor.log1p() | tensor.log1p()
tensor.matmul(other) | tensor.matmul(other)
tensor.one_hot(index, num_classes, device) | N/A
tensor.ones_like() | torch.ones_like(tensor)
tensor.random(shape, distribution, device) | N/A
tensor.random_like(distribution) | torch.rand_like() only uniform
tensor.recip() | tensor.reciprocal()
tensor.sin() | tensor.sin()
tensor.sqrt() | tensor.sqrt()
tensor.swap_dims(dim1, dim2) | tensor.transpose(dim1, dim2)
tensor.tanh() | tensor.tanh()
tensor.to_full_precision() | tensor.to(torch.float)
tensor.transpose() | tensor.T
tensor.var(dim) | tensor.var(dim)
tensor.var_bias(dim) | N/A
tensor.var_mean(dim) | N/A
tensor.var_mean_bias(dim) | N/A
tensor.zeros_like() | torch.zeros_like(tensor)
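
A brief float-only sketch, assuming a Backend type and device as before:

let a = Tensor::<Backend, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
let b = Tensor::<Backend, 2>::from_floats([[0.5, 0.0], [0.0, 0.5]], &device);

// Matrix multiplication followed by a couple of element-wise float operations.
let product = a.matmul(b);           // [[0.5, 1.0], [1.5, 2.0]]
let activated = product.exp().log(); // exp then log: back to roughly the original values

// `transpose` swaps the last two dimensions.
let transposed = activated.transpose();
println!("{:?}", transposed.dims()); // [2, 2]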

Int Operations

Those operations are only available for Int tensors.

Burn API | PyTorch Equivalent
tensor.arange(5..10, device) | torch.arange(start=5, end=10, device=device)
tensor.arange_step(5..10, 2, device) | torch.arange(start=5, end=10, step=2, device=device)
tensor.float() | tensor.to(torch.float)
tensor.from_ints(ints) | N/A
tensor.int_random(shape, distribution, device) | N/A
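
For example, a minimal sketch, again assuming a Backend type and device in scope:

// `arange` builds a 1-D Int tensor from a Rust range; `float()` converts it to a Float tensor.
let ints = Tensor::<Backend, 1, Int>::arange(0..5, &device);
let halves = ints.float().div_scalar(2.0);
println!("{:?}", halves.to_data()); // values: 0.0, 0.5, 1.0, 1.5, 2.0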

Bool Operations

Those operations are only available for Bool tensors.

Burn API | PyTorch Equivalent
Tensor::diag_mask(shape, diagonal) | N/A
Tensor::tril_mask(shape, diagonal) | N/A
Tensor::triu_mask(shape, diagonal) | N/A
tensor.argwhere() | tensor.argwhere()
tensor.float() | tensor.to(torch.float)
tensor.int() | tensor.to(torch.long)
tensor.nonzero() | tensor.nonzero(as_tuple=True)
tensor.not() | tensor.logical_not()
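
A Bool tensor typically comes from a comparison; it can then be negated or converted to a numeric kind. A short sketch, assuming a Backend type and device as before:

let x = Tensor::<Backend, 1>::from_floats([-1.0, 0.5, 2.0], &device);
let mask = x.greater_elem(0.0);   // Bool tensor: [false, true, true]
let flipped = mask.clone().not(); // [true, false, false]
let as_int = mask.int();          // Int tensor: [0, 1, 1]
println!("{:?}", flipped.to_data());
println!("{:?}", as_int.to_data());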

Activation Functions

Burn API | PyTorch Equivalent
activation::gelu(tensor) | nn.functional.gelu(tensor)
activation::leaky_relu(tensor, negative_slope) | nn.functional.leaky_relu(tensor, negative_slope)
activation::log_sigmoid(tensor) | nn.functional.log_sigmoid(tensor)
activation::log_softmax(tensor, dim) | nn.functional.log_softmax(tensor, dim)
activation::mish(tensor) | nn.functional.mish(tensor)
activation::prelu(tensor, alpha) | nn.functional.prelu(tensor, weight)
activation::quiet_softmax(tensor, dim) | nn.functional.quiet_softmax(tensor, dim)
activation::relu(tensor) | nn.functional.relu(tensor)
activation::sigmoid(tensor) | nn.functional.sigmoid(tensor)
activation::silu(tensor) | nn.functional.silu(tensor)
activation::softmax(tensor, dim) | nn.functional.softmax(tensor, dim)
activation::softplus(tensor, beta) | nn.functional.softplus(tensor, beta)
activation::tanh(tensor) | nn.functional.tanh(tensor)
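
To close, a minimal sketch of the activation functions, assuming a Backend type and device as in the earlier examples:

use burn::tensor::activation;

let logits = Tensor::<Backend, 2>::from_floats([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]], &device);

// Softmax over the last dimension: each row of `probs` sums to 1.
let probs = activation::softmax(logits.clone(), 1);
println!("{:?}", probs.to_data());

// ReLU zeroes out the negative entries.
let rectified = activation::relu(logits - 2.0);
println!("{:?}", rectified.to_data()); // [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]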