Trait burn::tensor::backend::AutodiffBackend
/// Trait that allows a backend to support autodiff.
///
/// An autodiff backend is layered on top of an inner backend (`InnerBackend`) that shares its
/// `Device`, `FloatElem` and `IntElem` types. The trait provides:
/// - gradient retrieval and update through the `Gradients` container, and
/// - conversions between this backend's tensor primitives and the inner backend's primitives.
///
/// All functions are associated functions with no `self` receiver; they are called as
/// `B::backward(tensor)` and similar. This trait is not object safe.
pub trait AutodiffBackend: Backend {
/// The inner backend type.
///
/// It must use the same `Device`, `FloatElem` and `IntElem` types as this backend.
type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;
/// Gradients container type.
///
/// `backward` returns it. `grad` reads it, and `grad_remove` and `grad_replace` take it
/// mutably. It must be `Send`.
type Gradients: Send;
// Required methods
/// Consumes `tensor` and returns a gradients container for it.
///
/// NOTE(review): presumably this runs the backward pass starting from `tensor` (e.g. a loss).
/// The actual computation is implementation-defined and not visible here.
fn backward(tensor: Self::FloatTensorPrimitive) -> Self::Gradients;
/// Returns the gradient associated with `tensor` in `grads`, or `None` if there is none.
///
/// `grads` is only borrowed immutably. The returned value is an inner-backend float primitive,
/// not an autodiff one.
fn grad(
tensor: &Self::FloatTensorPrimitive,
grads: &Self::Gradients,
) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>;
/// Like `grad`, but takes `grads` mutably.
///
/// NOTE(review): presumably this removes the gradient for `tensor` from `grads` and returns
/// it (`None` if absent). Confirm against implementations.
fn grad_remove(
tensor: &Self::FloatTensorPrimitive,
grads: &mut Self::Gradients,
) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>;
/// Replace the gradients of a tensor with the one provided.
///
/// If no gradient existed for the provided tensor, register it.
///
/// # Arguments
///
/// - `tensor` - The tensor whose gradient is replaced.
/// - `grads` - The gradients.
/// - `grad` - The updated grad tensor (an inner-backend float primitive).
fn grad_replace(
tensor: &Self::FloatTensorPrimitive,
grads: &mut Self::Gradients,
grad: <Self::InnerBackend as Backend>::FloatTensorPrimitive,
);
/// Converts (consumes) this backend's float tensor primitive into the inner backend's float
/// tensor primitive.
fn inner(
tensor: Self::FloatTensorPrimitive,
) -> <Self::InnerBackend as Backend>::FloatTensorPrimitive;
/// Converts (consumes) this backend's int tensor primitive into the inner backend's int
/// tensor primitive.
fn int_inner(
tensor: Self::IntTensorPrimitive,
) -> <Self::InnerBackend as Backend>::IntTensorPrimitive;
/// Converts (consumes) this backend's bool tensor primitive into the inner backend's bool
/// tensor primitive.
fn bool_inner(
tensor: Self::BoolTensorPrimitive,
) -> <Self::InnerBackend as Backend>::BoolTensorPrimitive;
/// Converts (consumes) this backend's quantized tensor primitive into the inner backend's
/// quantized tensor primitive.
fn q_inner(
tensor: Self::QuantizedTensorPrimitive,
) -> <Self::InnerBackend as Backend>::QuantizedTensorPrimitive;
/// Converts an inner-backend float tensor primitive into this backend's float tensor
/// primitive. This is the reverse direction of `inner`.
fn from_inner(
tensor: <Self::InnerBackend as Backend>::FloatTensorPrimitive,
) -> Self::FloatTensorPrimitive;
/// Converts an inner-backend int tensor primitive into this backend's int tensor primitive.
/// This is the reverse direction of `int_inner`.
fn int_from_inner(
tensor: <Self::InnerBackend as Backend>::IntTensorPrimitive,
) -> Self::IntTensorPrimitive;
/// Converts an inner-backend bool tensor primitive into this backend's bool tensor primitive.
/// This is the reverse direction of `bool_inner`.
fn bool_from_inner(
tensor: <Self::InnerBackend as Backend>::BoolTensorPrimitive,
) -> Self::BoolTensorPrimitive;
/// Converts an inner-backend quantized tensor primitive into this backend's quantized tensor
/// primitive. This is the reverse direction of `q_inner`.
fn q_from_inner(
tensor: <Self::InnerBackend as Backend>::QuantizedTensorPrimitive,
) -> Self::QuantizedTensorPrimitive;
}
Expand description
Trait that allows a backend to support autodiff.
Required Associated Types
type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>
type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>
The inner backend type.
Required Methods
fn backward(tensor: Self::FloatTensorPrimitive) -> Self::Gradients
fn backward(tensor: Self::FloatTensorPrimitive) -> Self::Gradients
fn grad(
tensor: &Self::FloatTensorPrimitive,
grads: &Self::Gradients,
) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>
fn grad( tensor: &Self::FloatTensorPrimitive, grads: &Self::Gradients, ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>
fn grad_remove(
tensor: &Self::FloatTensorPrimitive,
grads: &mut Self::Gradients,
) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>
fn grad_remove( tensor: &Self::FloatTensorPrimitive, grads: &mut Self::Gradients, ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>
fn grad_replace(
tensor: &Self::FloatTensorPrimitive,
grads: &mut Self::Gradients,
grad: <Self::InnerBackend as Backend>::FloatTensorPrimitive,
)
fn grad_replace( tensor: &Self::FloatTensorPrimitive, grads: &mut Self::Gradients, grad: <Self::InnerBackend as Backend>::FloatTensorPrimitive, )
Replace the gradients of a tensor with the one provided.
If no gradient existed for the provided tensor, register it.
Arguments
- tensor - The tensor whose gradient is replaced (or registered, if none existed).
- grads - The gradients.
- grad - The updated grad tensor.
fn inner(
tensor: Self::FloatTensorPrimitive,
) -> <Self::InnerBackend as Backend>::FloatTensorPrimitive
fn inner( tensor: Self::FloatTensorPrimitive, ) -> <Self::InnerBackend as Backend>::FloatTensorPrimitive
fn int_inner(
tensor: Self::IntTensorPrimitive,
) -> <Self::InnerBackend as Backend>::IntTensorPrimitive
fn int_inner( tensor: Self::IntTensorPrimitive, ) -> <Self::InnerBackend as Backend>::IntTensorPrimitive
fn bool_inner(
tensor: Self::BoolTensorPrimitive,
) -> <Self::InnerBackend as Backend>::BoolTensorPrimitive
fn bool_inner( tensor: Self::BoolTensorPrimitive, ) -> <Self::InnerBackend as Backend>::BoolTensorPrimitive
fn q_inner(
tensor: Self::QuantizedTensorPrimitive,
) -> <Self::InnerBackend as Backend>::QuantizedTensorPrimitive
fn q_inner( tensor: Self::QuantizedTensorPrimitive, ) -> <Self::InnerBackend as Backend>::QuantizedTensorPrimitive
fn from_inner(
tensor: <Self::InnerBackend as Backend>::FloatTensorPrimitive,
) -> Self::FloatTensorPrimitive
fn from_inner( tensor: <Self::InnerBackend as Backend>::FloatTensorPrimitive, ) -> Self::FloatTensorPrimitive
fn int_from_inner(
tensor: <Self::InnerBackend as Backend>::IntTensorPrimitive,
) -> Self::IntTensorPrimitive
fn int_from_inner( tensor: <Self::InnerBackend as Backend>::IntTensorPrimitive, ) -> Self::IntTensorPrimitive
fn bool_from_inner(
tensor: <Self::InnerBackend as Backend>::BoolTensorPrimitive,
) -> Self::BoolTensorPrimitive
fn bool_from_inner( tensor: <Self::InnerBackend as Backend>::BoolTensorPrimitive, ) -> Self::BoolTensorPrimitive
fn q_from_inner(
tensor: <Self::InnerBackend as Backend>::QuantizedTensorPrimitive,
) -> Self::QuantizedTensorPrimitive
fn q_from_inner( tensor: <Self::InnerBackend as Backend>::QuantizedTensorPrimitive, ) -> Self::QuantizedTensorPrimitive
Object Safety
This trait is not object safe.