
Function det 

pub fn det<B, const D: usize, const D1: usize, const D2: usize>(
    tensor: Tensor<B, D>,
) -> Tensor<B, D2>
where B: Backend,

Computes the determinant on the last two dimensions of the input tensor.

§Arguments

  • tensor - The input tensor of shape [..., N, N].

§Returns

  • The determinant tensor of shape [...], whose rank is two less than the input tensor’s rank.
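
As a minimal sketch of what "the last two dimensions" means, the plain-Rust function below treats a contiguous row-major buffer of shape [..., 2, 2] as a batch of 2×2 matrices and returns one value per matrix. The output shape is therefore the input shape with its last two dimensions dropped. `batched_det_2x2` is a hypothetical helper, not part of burn. It assumes row-major storage and handles only N = 2.

```rust
/// Hypothetical helper (not burn API): computes 2x2 determinants over the
/// last two dimensions of a row-major buffer with shape [..., 2, 2].
/// Returns the output shape and the determinant values.
fn batched_det_2x2(data: &[f64], shape: &[usize]) -> (Vec<usize>, Vec<f64>) {
    let rank = shape.len();
    // Mirror the documented requirement that the input rank is at least 3.
    assert!(rank >= 3, "input rank must be at least 3");
    assert_eq!((shape[rank - 2], shape[rank - 1]), (2, 2), "last two dims must be 2x2");
    assert_eq!(data.len(), shape.iter().product::<usize>(), "data/shape mismatch");

    // The output shape drops the last two dimensions.
    let out_shape = shape[..rank - 2].to_vec();

    // Each contiguous chunk of 4 values is one matrix [[a, b], [c, d]].
    let values = data
        .chunks_exact(4)
        .map(|m| m[0] * m[3] - m[1] * m[2])
        .collect();

    (out_shape, values)
}
```

With the batch from the second example, shape [3, 2, 2] reduces to shape [3] with values [-2.0, 6.0, -2.0].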

§Generic Parameters

  • D: The rank of the input tensor.
  • D1: Must be set to D - 1.
  • D2: Must be set to D - 2.
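
The rank bookkeeping can be sketched with const generics in plain Rust. The helper below takes the same D, D1, and D2 parameters, checks the two required relationships at runtime, and returns the input shape with its last two dimensions removed. `det_output_shape` is a hypothetical name for illustration, not a burn function.

```rust
/// Hypothetical helper (not burn API) mirroring the D, D1, D2 parameters:
/// checks D1 == D - 1 and D2 == D - 2, then drops the last two dimensions.
fn det_output_shape<const D: usize, const D1: usize, const D2: usize>(
    input: [usize; D],
) -> [usize; D2] {
    assert!(D1 + 1 == D, "D1 must be set to D - 1");
    assert!(D2 + 2 == D, "D2 must be set to D - 2");

    // Copy the leading (batch) dimensions into the output shape.
    let mut out = [0usize; D2];
    out.copy_from_slice(&input[..D2]);
    out
}
```

For example, `det_output_shape::<4, 3, 2>([2, 3, 5, 5])` returns `[2, 3]`, following the same parameter pattern as `det::<B, 4, 3, 2>` for a rank-4 input.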

§Panics

This function will panic if:

  • The generic parameters do not satisfy D - 1 == D1.
  • The generic parameters do not satisfy D - 2 == D2.
  • The input tensor rank D is less than 3.
  • The last two dimensions of the input tensor are not equal.
  • The input is a quantized tensor with dtype DType::QFloat.
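
If you would rather reject bad inputs than panic, the conditions above can be checked up front. This is a sketch over a plain shape slice. `DetInputError`, `validate_det_input`, and the `quantized` flag, which stands in for a DType::QFloat check, are hypothetical names, not burn API.

```rust
/// Hypothetical error type covering the documented panic conditions.
#[derive(Debug, PartialEq)]
enum DetInputError {
    RankMismatch { d: usize, d1: usize, d2: usize },
    RankTooSmall(usize),
    NotSquare { rows: usize, cols: usize },
    Quantized,
}

/// Returns an error instead of panicking. `quantized` is a stand-in for
/// checking whether the tensor dtype is DType::QFloat.
fn validate_det_input(
    shape: &[usize],
    d1: usize,
    d2: usize,
    quantized: bool,
) -> Result<(), DetInputError> {
    let d = shape.len();
    // D1 must equal D - 1 and D2 must equal D - 2.
    if d1 + 1 != d || d2 + 2 != d {
        return Err(DetInputError::RankMismatch { d, d1, d2 });
    }
    // The input rank must be at least 3.
    if d < 3 {
        return Err(DetInputError::RankTooSmall(d));
    }
    // The last two dimensions must form a square matrix.
    let (rows, cols) = (shape[d - 2], shape[d - 1]);
    if rows != cols {
        return Err(DetInputError::NotSquare { rows, cols });
    }
    // Quantized inputs are not supported.
    if quantized {
        return Err(DetInputError::Quantized);
    }
    Ok(())
}
```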

§Performance Note

Determinants of 1×1, 2×2, and 3×3 matrices are computed with closed-form expressions. For larger matrices (4×4 and above), the function relies on the LU decomposition under the hood, which is not fully optimized. Expect it to be slower than highly tuned, specialized linear algebra libraries, especially for very large matrices or large batch sizes.
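
For intuition, the plain-Rust sketch below mirrors the strategy described above on a single f64 matrix. It uses closed-form expressions for n ≤ 3 and, for larger n, Gaussian elimination with partial pivoting (an in-place LU factorization). The determinant is then the product of the pivots, with one sign flip per row swap. The function names and the `&[Vec<f64>]` representation are assumptions for illustration; this is not burn's implementation.

```rust
/// Determinant of an n x n matrix given as rows (illustrative only).
/// Closed-form for n <= 3, LU with partial pivoting otherwise.
fn det(m: &[Vec<f64>]) -> f64 {
    let n = m.len();
    assert!(m.iter().all(|row| row.len() == n), "matrix must be square");
    match n {
        1 => m[0][0],
        2 => m[0][0] * m[1][1] - m[0][1] * m[1][0],
        // Cofactor expansion along the first row.
        3 => {
            m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
                - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
                + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
        }
        _ => det_lu(m),
    }
}

/// Gaussian elimination (LU factorization computed in place) with partial
/// pivoting: det(A) = (-1)^swaps * product of U's diagonal entries.
fn det_lu(m: &[Vec<f64>]) -> f64 {
    let n = m.len();
    let mut a = m.to_vec();
    let mut acc = 1.0;
    for k in 0..n {
        // Partial pivoting: choose the row (>= k) with the largest |a[i][k]|.
        let pivot = (k..n)
            .max_by(|&i, &j| a[i][k].abs().total_cmp(&a[j][k].abs()))
            .unwrap();
        if a[pivot][k] == 0.0 {
            return 0.0; // singular matrix
        }
        if pivot != k {
            a.swap(pivot, k);
            acc = -acc; // each row swap flips the sign
        }
        let p = a[k][k];
        acc *= p;
        // Eliminate the entries below the pivot.
        for i in (k + 1)..n {
            let factor = a[i][k] / p;
            for j in k..n {
                let akj = a[k][j];
                a[i][j] -= factor * akj;
            }
        }
    }
    acc
}
```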

§Numerical Behavior

  • If the input tensor has dtype F16 or BF16, it is internally upcast to F32 for the computation and cast back to the original dtype (F16 or BF16) right before the function returns.
  • In this case, determinant values too small in magnitude for the original dtype underflow to zero on the cast back, and values too large in magnitude overflow to infinity.
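
The upcast, compute, and cast-back pattern can be imitated with standard Rust types. Because std has no stable f16 or bf16 type, this sketch swaps in f32 as the storage type and f64 as the compute type. It is an analogy rather than burn's code path, and `det2_upcast` is a hypothetical name.

```rust
/// Analogy only: f32 stands in for the storage dtype (F16/BF16) and f64
/// for the wider compute dtype (F32).
fn det2_upcast(m: [[f32; 2]; 2]) -> f32 {
    // Upcast every entry to the wider type.
    let w = m.map(|row| row.map(f64::from));
    // Compute the 2x2 determinant in the wider type.
    let d = w[0][0] * w[1][1] - w[0][1] * w[1][0];
    // Cast back to the storage type right before returning. Values too small
    // for f32 round to zero; values too large become infinity.
    d as f32
}
```

A diagonal matrix with entries 1e-30 yields 0.0 after the cast back, even though the f64 intermediate (about 1e-60) is nonzero. Entries of 1e30 yield infinity.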

§Example

use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::linalg;

fn example<B: Backend>() {
    let device = Default::default();
    let tensor = Tensor::<B, 3>::from_data([[[4.0, 3.0], [6.0, 3.0]]], &device);

    // Compute determinant
    let result = linalg::det::<B, 3, 2, 1>(tensor);

    // Expected Output:
    // result: [-6.0]
}

fn example2<B: Backend>() {
    let device = Default::default();
    let tensor = Tensor::<B, 3>::from_data(
        [
            [[1.0, 2.0], [3.0, 4.0]],   // det = -2
            [[2.0, 0.0], [0.0, 3.0]],   // det = 6
            [[5.0, 6.0], [7.0, 8.0]],   // det = -2
        ],
        &device,
    );

    // Compute determinant
    let result = linalg::det::<B, 3, 2, 1>(tensor);

    // Expected Output:
    // result: [-2.0, 6.0, -2.0]
}