Trait AutodiffBackend 

pub trait AutodiffBackend: Backend {
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;
    type Gradients: Send;

    // Required methods
    fn backward(tensor: Self::FloatTensorPrimitive) -> Self::Gradients;
    fn grad(
        tensor: &Self::FloatTensorPrimitive,
        grads: &Self::Gradients,
    ) -> Option<<Self::InnerBackend as BackendTypes>::FloatTensorPrimitive>;
    fn grad_remove(
        tensor: &Self::FloatTensorPrimitive,
        grads: &mut Self::Gradients,
    ) -> Option<<Self::InnerBackend as BackendTypes>::FloatTensorPrimitive>;
    fn grad_replace(
        tensor: &Self::FloatTensorPrimitive,
        grads: &mut Self::Gradients,
        grad: <Self::InnerBackend as BackendTypes>::FloatTensorPrimitive,
    );
    fn inner(
        tensor: Self::FloatTensorPrimitive,
    ) -> <Self::InnerBackend as BackendTypes>::FloatTensorPrimitive;
    fn int_inner(
        tensor: Self::IntTensorPrimitive,
    ) -> <Self::InnerBackend as BackendTypes>::IntTensorPrimitive;
    fn bool_inner(
        tensor: Self::BoolTensorPrimitive,
    ) -> <Self::InnerBackend as BackendTypes>::BoolTensorPrimitive;
    fn q_inner(
        tensor: Self::QuantizedTensorPrimitive,
    ) -> <Self::InnerBackend as BackendTypes>::QuantizedTensorPrimitive;
    fn from_inner(
        tensor: <Self::InnerBackend as BackendTypes>::FloatTensorPrimitive,
    ) -> Self::FloatTensorPrimitive;
    fn int_from_inner(
        tensor: <Self::InnerBackend as BackendTypes>::IntTensorPrimitive,
    ) -> Self::IntTensorPrimitive;
    fn bool_from_inner(
        tensor: <Self::InnerBackend as BackendTypes>::BoolTensorPrimitive,
    ) -> Self::BoolTensorPrimitive;
    fn q_from_inner(
        tensor: <Self::InnerBackend as BackendTypes>::QuantizedTensorPrimitive,
    ) -> Self::QuantizedTensorPrimitive;
}

Trait that allows a backend to support autodiff.

Required Associated Types

type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>

The inner backend type.

type Gradients: Send

Gradients type.
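The bound on `InnerBackend` ties the inner backend's associated types to those of the autodiff backend. A minimal, standalone sketch of that pattern follows. Every name in it (`ToyBackend`, `ToyAutodiffBackend`, `Cpu`, `Tracked`, `inner_name`) is hypothetical and exists only for illustration; none of them are burn types.

```rust
use std::marker::PhantomData;

/// Hypothetical, heavily reduced stand-in for a backend trait.
trait ToyBackend {
    type Device;
    type FloatElem;
    fn name() -> &'static str;
}

/// Hypothetical stand-in for an autodiff backend: the inner backend must
/// share the device and float element types of the outer one.
trait ToyAutodiffBackend: ToyBackend {
    type InnerBackend: ToyBackend<Device = Self::Device, FloatElem = Self::FloatElem>;
}

/// A plain backend with no gradient tracking.
struct Cpu;

impl ToyBackend for Cpu {
    type Device = ();
    type FloatElem = f32;
    fn name() -> &'static str {
        "cpu"
    }
}

/// A decorator that wraps any backend `B` and forwards its associated types.
struct Tracked<B>(PhantomData<B>);

impl<B: ToyBackend> ToyBackend for Tracked<B> {
    type Device = B::Device;
    type FloatElem = B::FloatElem;
    fn name() -> &'static str {
        "tracked"
    }
}

/// Because `Tracked<B>` forwards `B`'s types, `B` satisfies the bound.
impl<B: ToyBackend> ToyAutodiffBackend for Tracked<B> {
    type InnerBackend = B;
}

/// Generic code can reach the inner backend through the associated type.
fn inner_name<A: ToyAutodiffBackend>() -> &'static str {
    <A::InnerBackend as ToyBackend>::name()
}
```

In this sketch, `inner_name::<Tracked<Cpu>>()` resolves to the `Cpu` implementation, while an implementation whose inner backend used a different `Device` or `FloatElem` would be rejected at compile time.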

Required Methods

fn backward(tensor: Self::FloatTensorPrimitive) -> Self::Gradients

Backward pass.

§Arguments
  • tensor - The last node of the computational graph, from which the gradients are computed.
§Returns

The gradients.
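To illustrate what a backward pass starting from the last node involves, here is a toy reverse-mode sketch over scalar values. It is not burn's implementation: `NodeId`, `Step`, `TapeGradients`, `toy_backward`, and `example_gradients` are all hypothetical names, and the sketch assumes the tape is recorded in execution order.

```rust
use std::collections::HashMap;

/// Hypothetical node identifier; not a burn type.
type NodeId = usize;

/// One recorded operation: the node it produced and, for each parent,
/// the local partial derivative d(output)/d(parent).
struct Step {
    output: NodeId,
    parents: Vec<(NodeId, f64)>,
}

/// Toy stand-in for a gradients container: one scalar gradient per node.
type TapeGradients = HashMap<NodeId, f64>;

/// Walks the tape in reverse, starting from `last`, and accumulates
/// gradients with the chain rule.
fn toy_backward(tape: &[Step], last: NodeId) -> TapeGradients {
    let mut grads = TapeGradients::new();
    grads.insert(last, 1.0); // d(last)/d(last) = 1
    for step in tape.iter().rev() {
        // Nodes that do not influence `last` have no gradient to propagate.
        let Some(&upstream) = grads.get(&step.output) else {
            continue;
        };
        for &(parent, local) in &step.parents {
            *grads.entry(parent).or_insert(0.0) += upstream * local;
        }
    }
    grads
}

/// Records z = x * y + x at x = 3, y = 4 and differentiates z.
fn example_gradients() -> TapeGradients {
    let (x, y, xy, z) = (0, 1, 2, 3);
    let (x_val, y_val) = (3.0, 4.0);
    let tape = vec![
        // xy = x * y: d(xy)/dx = y, d(xy)/dy = x
        Step { output: xy, parents: vec![(x, y_val), (y, x_val)] },
        // z = xy + x: both partial derivatives are 1
        Step { output: z, parents: vec![(xy, 1.0), (x, 1.0)] },
    ];
    toy_backward(&tape, z)
}
```

For this graph, dz/dx = y + 1 = 5 and dz/dy = x = 3, which is what the returned map holds for nodes `0` and `1`.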

fn grad( tensor: &Self::FloatTensorPrimitive, grads: &Self::Gradients, ) -> Option<<Self::InnerBackend as BackendTypes>::FloatTensorPrimitive>

Returns the gradients of a tensor.

§Arguments
  • tensor - The tensor to extract the gradients from.
  • grads - The gradients.
§Returns

An optional tensor containing the gradient.

fn grad_remove( tensor: &Self::FloatTensorPrimitive, grads: &mut Self::Gradients, ) -> Option<<Self::InnerBackend as BackendTypes>::FloatTensorPrimitive>

Pops the gradients of a tensor and returns them.

§Arguments
  • tensor - The tensor to pop the gradients from.
  • grads - The gradients.
§Returns

An optional tensor containing the removed gradient.

fn grad_replace( tensor: &Self::FloatTensorPrimitive, grads: &mut Self::Gradients, grad: <Self::InnerBackend as BackendTypes>::FloatTensorPrimitive, )

Replaces the gradient of a tensor with the one provided.

If no gradient exists for the provided tensor, the new gradient is registered.

§Arguments
  • tensor - The tensor whose gradient is replaced.
  • grads - The gradients.
  • grad - The updated grad tensor.
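The three gradient accessors (`grad`, `grad_remove`, `grad_replace`) can be pictured as lookup, removal, and insertion on a keyed container. The sketch below models that with a `HashMap`. `GradientStore`, `NodeId`, and `clip_gradient` are hypothetical names, and it is an assumption of this sketch, not a statement about burn internals, that gradients are keyed by a node identifier and that a missing entry yields `None`.

```rust
use std::collections::HashMap;

/// Hypothetical key identifying a tensor in the graph.
type NodeId = u64;

/// Toy gradients container; a gradient is modelled as a `Vec<f32>`.
#[derive(Default)]
struct GradientStore {
    by_node: HashMap<NodeId, Vec<f32>>,
}

impl GradientStore {
    /// Mirrors `grad`: read-only lookup, `None` if nothing is registered.
    fn get(&self, id: NodeId) -> Option<Vec<f32>> {
        self.by_node.get(&id).cloned()
    }

    /// Mirrors `grad_remove`: takes the gradient out of the container.
    fn remove(&mut self, id: NodeId) -> Option<Vec<f32>> {
        self.by_node.remove(&id)
    }

    /// Mirrors `grad_replace`: overwrites the entry, or registers it if absent.
    fn replace(&mut self, id: NodeId, grad: Vec<f32>) {
        self.by_node.insert(id, grad);
    }
}

/// Typical remove-modify-replace cycle: clamp a gradient to [-limit, limit].
/// Returns `false` when the tensor has no gradient.
fn clip_gradient(grads: &mut GradientStore, id: NodeId, limit: f32) -> bool {
    match grads.remove(id) {
        Some(grad) => {
            let clipped = grad.into_iter().map(|v| v.clamp(-limit, limit)).collect();
            grads.replace(id, clipped);
            true
        }
        None => false,
    }
}
```

Removing before replacing avoids cloning the gradient; a read-only inspection would use the lookup instead.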

fn inner( tensor: Self::FloatTensorPrimitive, ) -> <Self::InnerBackend as BackendTypes>::FloatTensorPrimitive

Returns the float tensor with the inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

fn int_inner( tensor: Self::IntTensorPrimitive, ) -> <Self::InnerBackend as BackendTypes>::IntTensorPrimitive

Returns the int tensor with the inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

fn bool_inner( tensor: Self::BoolTensorPrimitive, ) -> <Self::InnerBackend as BackendTypes>::BoolTensorPrimitive

Returns the bool tensor with the inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

fn q_inner( tensor: Self::QuantizedTensorPrimitive, ) -> <Self::InnerBackend as BackendTypes>::QuantizedTensorPrimitive

Returns the quantized tensor with the inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

fn from_inner( tensor: <Self::InnerBackend as BackendTypes>::FloatTensorPrimitive, ) -> Self::FloatTensorPrimitive

Converts the inner backend float tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

fn int_from_inner( tensor: <Self::InnerBackend as BackendTypes>::IntTensorPrimitive, ) -> Self::IntTensorPrimitive

Converts the inner backend int tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

fn bool_from_inner( tensor: <Self::InnerBackend as BackendTypes>::BoolTensorPrimitive, ) -> Self::BoolTensorPrimitive

Converts the inner backend bool tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

fn q_from_inner( tensor: <Self::InnerBackend as BackendTypes>::QuantizedTensorPrimitive, ) -> Self::QuantizedTensorPrimitive

Converts the inner backend quantized tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.
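The `inner` and `from_inner` families convert between the autodiff tensor and the tensor of the wrapped backend. A toy model of that round trip is sketched below. `InnerTensor`, `TrackedTensor`, and the two free functions are hypothetical, and it is an assumption of this sketch that wrapping an inner tensor starts with no recorded history.

```rust
/// Hypothetical tensor of the wrapped backend: just the data.
#[derive(Debug, Clone, PartialEq)]
struct InnerTensor {
    data: Vec<f32>,
}

/// Hypothetical autodiff tensor: the inner tensor plus graph bookkeeping.
#[derive(Debug, Clone, PartialEq)]
struct TrackedTensor {
    inner: InnerTensor,
    /// Identifier of the operation that produced this tensor, if any.
    produced_by: Option<u64>,
}

/// Analogue of `inner`: drops the graph bookkeeping and keeps the data.
fn inner(tensor: TrackedTensor) -> InnerTensor {
    tensor.inner
}

/// Analogue of `from_inner`: wraps the data with empty bookkeeping
/// (assumption of this sketch).
fn from_inner(tensor: InnerTensor) -> TrackedTensor {
    TrackedTensor {
        inner: tensor,
        produced_by: None,
    }
}
```

In this model, `from_inner(inner(t))` keeps the values of `t` but forgets which operation produced it, so the round trip behaves like a detach.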

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementors

impl<B, C> AutodiffBackend for Autodiff<B, C>
where B: Backend, C: CheckpointStrategy,

Available on non-crate feature distributed only.

type InnerBackend = B

type Gradients = Gradients