Trait `burn::tensor::backend::AutodiffBackend`

/// Trait that allows a backend to support autodiff.
///
/// An autodiff backend wraps an inner backend ([`AutodiffBackend::InnerBackend`]) and adds
/// gradient tracking on top of it. The methods below run the backward pass, query and edit
/// the resulting gradients, and convert tensor primitives between the autodiff backend and
/// its inner backend.
///
/// Every method is an associated function without a `self` receiver, so this trait is not
/// object safe.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    ///
    /// It is constrained to use the same `Device`, `FloatElem` and `IntElem` types as this
    /// backend.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;
    /// Gradients type returned by [`AutodiffBackend::backward`].
    ///
    /// Must be `Send`, so a gradients value can be moved to another thread.
    type Gradients: Send;

    // Required methods
    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The last node of the computational graph, from which the gradients are
    ///   computed. Taken by value.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: Self::FloatTensorPrimitive) -> Self::Gradients;
    /// Returns the gradient of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradient for.
    /// * `grads` - The gradients to read from (borrowed immutably; not modified).
    ///
    /// # Returns
    ///
    /// An optional inner-backend float tensor containing the gradient (presumably `None`
    /// when `grads` holds no gradient for `tensor`).
    fn grad(
        tensor: &Self::FloatTensorPrimitive,
        grads: &Self::Gradients,
    ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>;
    /// Pops the gradient of a tensor out of `grads` and returns it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradient for.
    /// * `grads` - The gradients; mutated to remove the entry for `tensor`.
    ///
    /// # Returns
    ///
    /// An optional inner-backend float tensor containing the removed gradient.
    fn grad_remove(
        tensor: &Self::FloatTensorPrimitive,
        grads: &mut Self::Gradients,
    ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>;
    /// Replaces the gradient of a tensor with the one provided.
    ///
    /// If no gradient existed for `tensor`, `grad` is registered as its gradient.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradient is replaced.
    /// * `grads` - The gradients to update.
    /// * `grad` - The new gradient, as an inner-backend float tensor.
    fn grad_replace(
        tensor: &Self::FloatTensorPrimitive,
        grads: &mut Self::Gradients,
        grad: <Self::InnerBackend as Backend>::FloatTensorPrimitive,
    );
    /// Returns the float tensor as an inner-backend float tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for (taken by value).
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(
        tensor: Self::FloatTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::FloatTensorPrimitive;
    /// Returns the int tensor as an inner-backend int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for (taken by value).
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(
        tensor: Self::IntTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::IntTensorPrimitive;
    /// Returns the bool tensor as an inner-backend bool tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for (taken by value).
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(
        tensor: Self::BoolTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::BoolTensorPrimitive;
    /// Returns the quantized tensor as an inner-backend quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for (taken by value).
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(
        tensor: Self::QuantizedTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::QuantizedTensorPrimitive;
    /// Converts an inner-backend float tensor to the autodiff backend float tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert (taken by value).
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(
        tensor: <Self::InnerBackend as Backend>::FloatTensorPrimitive,
    ) -> Self::FloatTensorPrimitive;
    /// Converts an inner-backend int tensor to the autodiff backend int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert (taken by value).
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(
        tensor: <Self::InnerBackend as Backend>::IntTensorPrimitive,
    ) -> Self::IntTensorPrimitive;
    /// Converts an inner-backend bool tensor to the autodiff backend bool tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert (taken by value).
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(
        tensor: <Self::InnerBackend as Backend>::BoolTensorPrimitive,
    ) -> Self::BoolTensorPrimitive;
    /// Converts an inner-backend quantized tensor to the autodiff backend quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert (taken by value).
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(
        tensor: <Self::InnerBackend as Backend>::QuantizedTensorPrimitive,
    ) -> Self::QuantizedTensorPrimitive;
}
Description

Trait that allows a backend to support autodiff.

Required Associated Types

type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>

The inner backend type.

type Gradients: Send

Gradients type.

Required Methods

fn backward(tensor: Self::FloatTensorPrimitive) -> Self::Gradients

Backward pass.

Arguments
  • tensor - The last node of the computational graph, from which the gradients are computed.

Returns

The gradients.

fn grad( tensor: &Self::FloatTensorPrimitive, grads: &Self::Gradients, ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>

Returns the gradient of a tensor.

Arguments
  • tensor - The tensor to extract the gradient for.
  • grads - The gradients to read from.

Returns

An optional tensor containing the gradient.

fn grad_remove( tensor: &Self::FloatTensorPrimitive, grads: &mut Self::Gradients, ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>

Pops the gradient of a tensor out of the gradients and returns it.

Arguments
  • tensor - The tensor to pop the gradient for.
  • grads - The gradients.

Returns

An optional tensor containing the removed gradient.

fn grad_replace( tensor: &Self::FloatTensorPrimitive, grads: &mut Self::Gradients, grad: <Self::InnerBackend as Backend>::FloatTensorPrimitive, )

Replaces the gradient of a tensor with the one provided.

If no gradient existed for the provided tensor, the new one is registered.

Arguments
  • tensor - The tensor whose gradient is replaced.
  • grads - The gradients.
  • grad - The updated grad tensor.

fn inner( tensor: Self::FloatTensorPrimitive, ) -> <Self::InnerBackend as Backend>::FloatTensorPrimitive

Returns the float tensor with the inner backend type.

Arguments
  • tensor - The tensor to get the inner backend tensor for.

Returns

The inner backend tensor.

fn int_inner( tensor: Self::IntTensorPrimitive, ) -> <Self::InnerBackend as Backend>::IntTensorPrimitive

Returns the int tensor with the inner backend type.

Arguments
  • tensor - The tensor to get the inner backend tensor for.

Returns

The inner backend tensor.

fn bool_inner( tensor: Self::BoolTensorPrimitive, ) -> <Self::InnerBackend as Backend>::BoolTensorPrimitive

Returns the bool tensor with the inner backend type.

Arguments
  • tensor - The tensor to get the inner backend tensor for.

Returns

The inner backend tensor.

fn q_inner( tensor: Self::QuantizedTensorPrimitive, ) -> <Self::InnerBackend as Backend>::QuantizedTensorPrimitive

Returns the quantized tensor with the inner backend type.

Arguments
  • tensor - The tensor to get the inner backend tensor for.

Returns

The inner backend tensor.

fn from_inner( tensor: <Self::InnerBackend as Backend>::FloatTensorPrimitive, ) -> Self::FloatTensorPrimitive

Converts the inner backend float tensor to the autodiff backend float tensor.

Arguments
  • tensor - The inner backend tensor to convert.

Returns

The autodiff backend tensor.

fn int_from_inner( tensor: <Self::InnerBackend as Backend>::IntTensorPrimitive, ) -> Self::IntTensorPrimitive

Converts the inner backend int tensor to the autodiff backend int tensor.

Arguments
  • tensor - The inner backend tensor to convert.

Returns

The autodiff backend tensor.

fn bool_from_inner( tensor: <Self::InnerBackend as Backend>::BoolTensorPrimitive, ) -> Self::BoolTensorPrimitive

Converts the inner backend bool tensor to the autodiff backend bool tensor.

Arguments
  • tensor - The inner backend tensor to convert.

Returns

The autodiff backend tensor.

fn q_from_inner( tensor: <Self::InnerBackend as Backend>::QuantizedTensorPrimitive, ) -> Self::QuantizedTensorPrimitive

Converts the inner backend quantized tensor to the autodiff backend quantized tensor.

Arguments
  • tensor - The inner backend tensor to convert.

Returns

The autodiff backend tensor.

Object Safety

This trait is not object safe.

Implementors