Function lu 

pub fn lu<B, const D: usize, const D1: usize>(
    tensor: Tensor<B, D>,
) -> (Tensor<B, D>, Tensor<B, D>, Tensor<B, D>)
where B: Backend,

Computes the LU decomposition of a square or rectangular matrix with partial pivoting.

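This function decomposes the input tensor A into three tensors P, L, and U such that A = PLU, where P is a permutation matrix, L is lower triangular with ones on its diagonal, and U is upper triangular. For inputs with more than two dimensions, the leading dimensions are batch dimensions and each trailing [n_rows, n_cols] matrix is decomposed on its own.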
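The decomposition can be sketched for a single matrix stored as a plain row-major `Vec<Vec<f64>>`. This is a minimal illustration of Gaussian elimination with partial pivoting (Doolittle form), not burn's implementation, which operates on batched tensors and may differ in tie-breaking and zero-pivot handling. The names `Matrix` and `lu_partial_pivot` are hypothetical.

```rust
/// Dense row-major matrix used only for this sketch (not burn's `Tensor`).
type Matrix = Vec<Vec<f64>>;

/// LU decomposition with partial pivoting of one m x n matrix.
/// Returns (P, L, U) with A = P * L * U, where k = min(m, n) and
/// P is m x m, L is m x k with a unit diagonal, and U is k x n.
fn lu_partial_pivot(a: &Matrix) -> (Matrix, Matrix, Matrix) {
    let m = a.len();
    let n = a.first().map_or(0, |row| row.len());
    let k = m.min(n);
    let mut work = a.clone();
    // perm[i] is the index of the original row that ends up at position i.
    let mut perm: Vec<usize> = (0..m).collect();

    for j in 0..k {
        // Partial pivoting: pick the row at or below j with the largest |value| in column j.
        let pivot = (j..m)
            .max_by(|&x, &y| work[x][j].abs().total_cmp(&work[y][j].abs()))
            .expect("j < m, so the range is non-empty");
        // Swap whole rows, including multipliers already stored for L.
        work.swap(j, pivot);
        perm.swap(j, pivot);

        let diag = work[j][j];
        if diag == 0.0 {
            // Every entry at or below the diagonal is zero: nothing to eliminate.
            continue;
        }
        let pivot_row = work[j].clone();
        for i in (j + 1)..m {
            let factor = work[i][j] / diag;
            work[i][j] = factor; // multiplier, becomes L[i][j]
            for c in (j + 1)..n {
                work[i][c] -= factor * pivot_row[c];
            }
        }
    }

    // P[perm[i]][i] = 1, so row i of L * U is sent back to original row perm[i].
    let mut p = vec![vec![0.0; m]; m];
    for (i, &row) in perm.iter().enumerate() {
        p[row][i] = 1.0;
    }
    let mut l = vec![vec![0.0; k]; m];
    let mut u = vec![vec![0.0; n]; k];
    for i in 0..m {
        for j in 0..k {
            if i == j {
                l[i][j] = 1.0;
            } else if i > j {
                l[i][j] = work[i][j];
            }
        }
    }
    for i in 0..k {
        for j in i..n {
            u[i][j] = work[i][j];
        }
    }
    (p, l, u)
}
```

Storing the multipliers in place and swapping entire rows is the usual LAPACK-style layout: it keeps L consistent with the final row order, so a single permutation P recovers A.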

§Arguments

  • tensor - The input tensor of shape [..., n_rows, n_cols].

§Returns

A tuple of three tensors (P, L, U):

  • P - The permutation tensor of shape [..., n_rows, n_rows].
  • L - The lower triangular tensor of shape [..., n_rows, min(n_rows, n_cols)] with unit diagonal elements.
  • U - The upper triangular tensor of shape [..., min(n_rows, n_cols), n_cols].
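The shape rules listed above are simple arithmetic on the last two dimensions, with any batch dimensions copied through unchanged. A minimal sketch, using a hypothetical helper `lu_output_shapes` that works on plain `usize` slices rather than burn shapes:

```rust
/// Shapes of (P, L, U) for an input of shape [..., n_rows, n_cols].
/// Hypothetical helper for illustration; batch dimensions are copied as-is.
fn lu_output_shapes(input: &[usize]) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
    assert!(input.len() >= 2, "input needs at least 2 dimensions");
    let (batch, matrix) = input.split_at(input.len() - 2);
    let (rows, cols) = (matrix[0], matrix[1]);
    let k = rows.min(cols);
    // Append the two trailing matrix dimensions to the batch dimensions.
    let with_batch = |a: usize, b: usize| {
        let mut shape = batch.to_vec();
        shape.extend_from_slice(&[a, b]);
        shape
    };
    (
        with_batch(rows, rows), // P: [..., n_rows, n_rows]
        with_batch(rows, k),    // L: [..., n_rows, min(n_rows, n_cols)]
        with_batch(k, cols),    // U: [..., min(n_rows, n_cols), n_cols]
    )
}
```

For a batch of eight 5 x 3 matrices, this gives P as [8, 5, 5], L as [8, 5, 3], and U as [8, 3, 3].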

§Generic Parameters

  • D: The number of dimensions of the input tensor.
  • D1: The number of dimensions of the internal pivot tensor, which has one dimension fewer than the input. Must be exactly D - 1 (for example, D1 = 1 for a single 2-D matrix).

§Panics

This function will panic if the tensor checks fail:

  • The input tensor has fewer than 2 dimensions (D < 2).
  • The generic parameters do not satisfy D - 1 == D1.
  • The input is a quantized tensor with dtype DType::QFloat.
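The panic conditions above, together with the D1 requirement, can be mirrored as a validation step that returns an error instead of panicking. This is a hypothetical sketch: `check_lu_inputs` and `SketchDType` are illustrative stand-ins, not burn's actual checks or its `DType` enum.

```rust
/// Hypothetical stand-in for the dtypes mentioned in this page; not burn's `DType`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum SketchDType {
    F32,
    F16,
    BF16,
    QFloat,
}

/// Returns Err with a message for each documented panic condition.
fn check_lu_inputs<const D: usize, const D1: usize>(dtype: SketchDType) -> Result<(), String> {
    if D < 2 {
        return Err(format!("expected at least 2 dimensions, got D = {}", D));
    }
    // Safe to subtract: D >= 2 was checked above.
    if D1 != D - 1 {
        return Err(format!("D1 must equal D - 1, got D = {}, D1 = {}", D, D1));
    }
    if dtype == SketchDType::QFloat {
        return Err("quantized (QFloat) tensors are not supported".to_string());
    }
    Ok(())
}
```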

§Performance Note

The current implementation of LU decomposition is not fully optimized. It will not be as fast as highly tuned specialized libraries, especially for very large matrices or large batch sizes.

§Numerical Behavior

  • If the input tensor has dtype F16 or BF16, it is internally upcast to F32
    for the computation and cast back to the original dtype before returning.
  • In this case, values in L and U that fall outside the original dtype’s
    representable range will saturate or underflow on cast-back.
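For F16 the range limits are concrete: the largest finite value is 65504 and the smallest positive subnormal is 2^-24. BF16 keeps F32's 8-bit exponent, so the range issue mainly concerns F16. The sketch below roughly classifies what casting an F32 result back to F16 would do; `classify_f16_cast` and `F16Cast` are hypothetical names, and the thresholds ignore round-to-nearest effects right at the boundaries.

```rust
/// Largest finite IEEE 754 half-precision (F16) value.
const F16_MAX: f32 = 65504.0;
/// Smallest positive F16 subnormal, 2^-24.
const F16_MIN_SUBNORMAL: f32 = 1.0 / 16_777_216.0;

#[derive(Debug, PartialEq)]
enum F16Cast {
    /// Magnitude fits in F16 (possibly with reduced precision).
    InRange,
    /// Too large for F16; per the note above, such values saturate.
    Overflow,
    /// Nonzero but smaller than any F16 value, so it is lost on cast-back.
    Underflow,
}

/// Rough classification for a finite, non-NaN `x`; boundary rounding is ignored.
fn classify_f16_cast(x: f32) -> F16Cast {
    let mag = x.abs();
    if mag > F16_MAX {
        F16Cast::Overflow
    } else if mag != 0.0 && mag < F16_MIN_SUBNORMAL {
        F16Cast::Underflow
    } else {
        F16Cast::InRange
    }
}
```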

§Example

use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::linalg;

fn example<B: Backend>() {
    let device = Default::default();
    let tensor = Tensor::<B, 2>::from_data([[4.0, 3.0], [6.0, 3.0]], &device);

    // Compute P, L, U
    let (p, l, u) = linalg::lu::<B, 2, 1>(tensor);

    // Expected Output:
    // p: [[0.0, 1.0],
    //     [1.0, 0.0]]
    //
    // l: [[1.0,       0.0],
    //     [0.6666667, 1.0]]
    //
    // u: [[6.0, 3.0],
    //     [0.0, 1.0]]
}
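The expected output follows from one step of partial pivoting. In column 0 the largest magnitude is 6 (row 2), so the rows are swapped first; the single multiplier is then 4/6, and eliminating it from row 2 leaves 1 in the bottom-right of U. Multiplying back confirms A = PLU:

```latex
\begin{aligned}
A &= \begin{pmatrix} 4 & 3 \\ 6 & 3 \end{pmatrix}
  \xrightarrow{\text{swap rows 1, 2}}
  \begin{pmatrix} 6 & 3 \\ 4 & 3 \end{pmatrix},
  \qquad \ell_{21} = \tfrac{4}{6} = \tfrac{2}{3},
  \qquad u_{22} = 3 - \ell_{21} \cdot 3 = 1, \\[4pt]
LU &= \begin{pmatrix} 1 & 0 \\ \tfrac{2}{3} & 1 \end{pmatrix}
      \begin{pmatrix} 6 & 3 \\ 0 & 1 \end{pmatrix}
    = \begin{pmatrix} 6 & 3 \\ 4 & 3 \end{pmatrix},
  \qquad
PLU = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}
      \begin{pmatrix} 6 & 3 \\ 4 & 3 \end{pmatrix}
    = \begin{pmatrix} 4 & 3 \\ 6 & 3 \end{pmatrix} = A.
\end{aligned}
```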