burn::tensor::backend

Trait AutodiffBackend

/// Trait that allows a backend to support autodiff.
///
/// An autodiff backend wraps an inner backend. It exposes the backward pass, access to the
/// computed gradients, and conversions of tensor primitives to and from that inner backend.
///
/// This trait is not dyn compatible (not object safe), so it cannot be used as a trait object.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    ///
    /// It must use the same `Device`, `FloatElem` and `IntElem` types as this backend.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;
    /// Gradients type, as produced by [`AutodiffBackend::backward`]. Must be `Send`.
    type Gradients: Send;

    // Required methods
    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor that is the last node of the computational graph where the
    ///   gradients are computed. Taken by value.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: Self::FloatTensorPrimitive) -> Self::Gradients;
    /// Returns the gradient of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradient from.
    /// * `grads` - The gradients to look the tensor up in (borrowed immutably).
    ///
    /// # Returns
    ///
    /// An optional tensor, of the inner backend's float tensor type, containing the gradient.
    fn grad(
        tensor: &Self::FloatTensorPrimitive,
        grads: &Self::Gradients,
    ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>;
    /// Pops the gradient of a tensor out of `grads` and returns it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradient from.
    /// * `grads` - The gradients; mutated by removing the tensor's entry.
    ///
    /// # Returns
    ///
    /// An optional tensor, of the inner backend's float tensor type, containing the removed
    /// gradient.
    fn grad_remove(
        tensor: &Self::FloatTensorPrimitive,
        grads: &mut Self::Gradients,
    ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>;
    /// Replaces the gradient of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, `grad` is registered for it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradient is replaced.
    /// * `grads` - The gradients; mutated in place.
    /// * `grad` - The updated gradient tensor, of the inner backend's float tensor type.
    fn grad_replace(
        tensor: &Self::FloatTensorPrimitive,
        grads: &mut Self::Gradients,
        grad: <Self::InnerBackend as Backend>::FloatTensorPrimitive,
    );
    /// Returns the float tensor as an inner backend float tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to get the inner backend tensor for (taken by value).
    ///
    /// # Returns
    ///
    /// The inner backend float tensor.
    fn inner(
        tensor: Self::FloatTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::FloatTensorPrimitive;
    /// Returns the int tensor as an inner backend int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The int tensor to get the inner backend tensor for (taken by value).
    ///
    /// # Returns
    ///
    /// The inner backend int tensor.
    fn int_inner(
        tensor: Self::IntTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::IntTensorPrimitive;
    /// Returns the bool tensor as an inner backend bool tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The bool tensor to get the inner backend tensor for (taken by value).
    ///
    /// # Returns
    ///
    /// The inner backend bool tensor.
    fn bool_inner(
        tensor: Self::BoolTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::BoolTensorPrimitive;
    /// Returns the quantized tensor as an inner backend quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The quantized tensor to get the inner backend tensor for (taken by value).
    ///
    /// # Returns
    ///
    /// The inner backend quantized tensor.
    fn q_inner(
        tensor: Self::QuantizedTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::QuantizedTensorPrimitive;
    /// Converts an inner backend float tensor to an autodiff backend float tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend float tensor to convert (taken by value).
    ///
    /// # Returns
    ///
    /// The autodiff backend float tensor.
    fn from_inner(
        tensor: <Self::InnerBackend as Backend>::FloatTensorPrimitive,
    ) -> Self::FloatTensorPrimitive;
    /// Converts an inner backend int tensor to an autodiff backend int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend int tensor to convert (taken by value).
    ///
    /// # Returns
    ///
    /// The autodiff backend int tensor.
    fn int_from_inner(
        tensor: <Self::InnerBackend as Backend>::IntTensorPrimitive,
    ) -> Self::IntTensorPrimitive;
    /// Converts an inner backend bool tensor to an autodiff backend bool tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend bool tensor to convert (taken by value).
    ///
    /// # Returns
    ///
    /// The autodiff backend bool tensor.
    fn bool_from_inner(
        tensor: <Self::InnerBackend as Backend>::BoolTensorPrimitive,
    ) -> Self::BoolTensorPrimitive;
    /// Converts an inner backend quantized tensor to an autodiff backend quantized tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend quantized tensor to convert (taken by value).
    ///
    /// # Returns
    ///
    /// The autodiff backend quantized tensor.
    fn q_from_inner(
        tensor: <Self::InnerBackend as Backend>::QuantizedTensorPrimitive,
    ) -> Self::QuantizedTensorPrimitive;
}
Expand description

Trait that allows a backend to support autodiff.

Required Associated Types§

type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>

The inner backend type.

type Gradients: Send

Gradients type.

Required Methods§

fn backward(tensor: Self::FloatTensorPrimitive) -> Self::Gradients

Backward pass.

§Arguments
  • tensor - The tensor that is the last node of the computational graph where the gradients are computed.
§Returns

The gradients.

fn grad( tensor: &Self::FloatTensorPrimitive, grads: &Self::Gradients, ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>

Returns the gradients of a tensor.

§Arguments
  • tensor - The tensor to extract the gradients from.
  • grads - The gradients.
§Returns

An optional tensor containing the gradient.

fn grad_remove( tensor: &Self::FloatTensorPrimitive, grads: &mut Self::Gradients, ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>

Pops the gradients of a tensor and returns them.

§Arguments
  • tensor - The tensor to pop the gradients from.
  • grads - The gradients.
§Returns

An optional tensor containing the removed gradients.

fn grad_replace( tensor: &Self::FloatTensorPrimitive, grads: &mut Self::Gradients, grad: <Self::InnerBackend as Backend>::FloatTensorPrimitive, )

Replaces the gradient of a tensor with the one provided.

If no gradient existed for the provided tensor, the provided gradient is registered for it.

§Arguments
  • tensor - The tensor whose gradient is replaced.
  • grads - The gradients.
  • grad - The updated grad tensor.

fn inner( tensor: Self::FloatTensorPrimitive, ) -> <Self::InnerBackend as Backend>::FloatTensorPrimitive

Returns the float tensor as an inner backend float tensor.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

fn int_inner( tensor: Self::IntTensorPrimitive, ) -> <Self::InnerBackend as Backend>::IntTensorPrimitive

Returns the int tensor as an inner backend int tensor.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

fn bool_inner( tensor: Self::BoolTensorPrimitive, ) -> <Self::InnerBackend as Backend>::BoolTensorPrimitive

Returns the bool tensor as an inner backend bool tensor.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

fn q_inner( tensor: Self::QuantizedTensorPrimitive, ) -> <Self::InnerBackend as Backend>::QuantizedTensorPrimitive

Returns the quantized tensor as an inner backend quantized tensor.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

fn from_inner( tensor: <Self::InnerBackend as Backend>::FloatTensorPrimitive, ) -> Self::FloatTensorPrimitive

Converts an inner backend float tensor to an autodiff backend float tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

fn int_from_inner( tensor: <Self::InnerBackend as Backend>::IntTensorPrimitive, ) -> Self::IntTensorPrimitive

Converts an inner backend int tensor to an autodiff backend int tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

fn bool_from_inner( tensor: <Self::InnerBackend as Backend>::BoolTensorPrimitive, ) -> Self::BoolTensorPrimitive

Converts an inner backend bool tensor to an autodiff backend bool tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

fn q_from_inner( tensor: <Self::InnerBackend as Backend>::QuantizedTensorPrimitive, ) -> Self::QuantizedTensorPrimitive

Converts an inner backend quantized tensor to an autodiff backend quantized tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementors§