Trait burn::backend::autodiff::ops::Backward

pub trait Backward<B, const D: usize, const N: usize>: Send + Debug + Sized + 'static
where
    B: Backend,
{
    type State: Clone + Send + Debug + 'static;

    // Required method
    fn backward(
        self,
        ops: Ops<Self::State, N>,
        grads: &mut Gradients,
        checkpointer: &mut Checkpointer,
    );

    // Provided method
    fn prepare<C>(
        self,
        nodes: [Arc<Node>; N],
    ) -> OpsPrep<Self, B, Self::State, C, D, N>
       where C: CheckpointStrategy { ... }
}

Trait implemented by all autodiff operations; it defines how each operation computes its backward pass.

Notes

Concrete types implementing this trait should not hold any state. If state is needed during the backward pass, it should be declared through the associated type State.
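The note above can be illustrated with a small, self-contained model. It is a sketch under stated assumptions, not burn's API: `MiniBackward`, `MiniOps` and `Grads` are hypothetical stand-ins for `Backward`, `Ops` and `Gradients`, and plain `f64` scalars replace tensors. The operation type `Mul` is a zero-sized unit struct; the forward inputs it needs later live in the associated `State` type, and the const parameter `N = 2` gives it two parents.

```rust
use std::collections::HashMap;
use std::fmt::Debug;

/// Hypothetical stand-in for the gradient store: node id -> accumulated gradient.
type Grads = HashMap<usize, f64>;

/// Hypothetical stand-in for `Ops<S, N>`: the ids of the N parent nodes,
/// the id of the output node, and the state saved during the forward pass.
struct MiniOps<S, const N: usize> {
    parents: [usize; N],
    node: usize,
    state: S,
}

/// Hypothetical, simplified analogue of `Backward`: the implementing type
/// carries no data; anything needed later goes in `State`.
trait MiniBackward<const N: usize>: Debug + Sized {
    type State: Clone + Debug;
    fn backward(self, ops: MiniOps<Self::State, N>, grads: &mut Grads);
}

/// Multiplication c = a * b as a stateless, zero-sized type.
#[derive(Debug)]
struct Mul;

impl MiniBackward<2> for Mul {
    // The forward inputs (a, b) are the only state the backward pass needs.
    type State = (f64, f64);

    fn backward(self, ops: MiniOps<Self::State, 2>, grads: &mut Grads) {
        let (a, b) = ops.state;
        // Upstream gradient of the output node (0.0 if none was recorded).
        let grad_out = grads.get(&ops.node).copied().unwrap_or(0.0);
        let [lhs, rhs] = ops.parents;
        // dc/da = b and dc/db = a; accumulate so a shared parent sums both terms.
        *grads.entry(lhs).or_insert(0.0) += grad_out * b;
        *grads.entry(rhs).or_insert(0.0) += grad_out * a;
    }
}

fn main() {
    // Node ids: 0 = a, 1 = b, 2 = c = a * b.
    let mut grads: Grads = HashMap::new();
    grads.insert(2, 1.0); // seed dc/dc = 1
    let ops = MiniOps { parents: [0, 1], node: 2, state: (3.0, 4.0) };
    Mul.backward(ops, &mut grads);
    println!("da = {}, db = {}", grads[&0], grads[&1]); // da = 4, db = 3
    // The operation type itself occupies no memory.
    assert_eq!(std::mem::size_of::<Mul>(), 0);
}
```

Keeping the operation type empty and moving everything else into `State` makes it explicit which values must be kept alive from the forward pass until gradients are computed.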

Required Associated Types

type State: Clone + Send + Debug + 'static

The state saved during the forward pass that is needed to compute the backward pass.

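Required Methods

fn backward(self, ops: Ops<Self::State, N>, grads: &mut Gradients, checkpointer: &mut Checkpointer)

Runs the backward pass for this operation, consuming it together with ops (which carries the saved State) and writing the resulting gradients into grads.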

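Provided Methods

fn prepare<C>(self, nodes: [Arc<Node>; N]) -> OpsPrep<Self, B, Self::State, C, D, N>
where
    C: CheckpointStrategy

Prepares the backward operation for its N parent nodes, returning an OpsPrep whose behavior is parameterized by the checkpointing strategy C.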
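`prepare` is a provided method that is generic over a strategy type `C` and returns a preparation value rather than doing the work directly. The sketch below models only that shape; every name in it (`Strategy`, `KeepAll`, `Recompute`, `Prep`, `Op`, `Neg`) is hypothetical and not burn's API. The strategy is chosen purely at the type level through a zero-sized marker stored as `PhantomData`, so selecting it has no runtime cost.

```rust
use std::marker::PhantomData;

/// Hypothetical strategy trait, selected at compile time via a type parameter.
trait Strategy {
    const NAME: &'static str;
}

/// Hypothetical marker: keep forward outputs in memory.
struct KeepAll;
/// Hypothetical marker: recompute forward outputs when they are needed.
struct Recompute;

impl Strategy for KeepAll {
    const NAME: &'static str = "keep-all";
}
impl Strategy for Recompute {
    const NAME: &'static str = "recompute";
}

/// Hypothetical result of `prepare`: owns the op and its N parent node ids,
/// and remembers the strategy only as a type.
struct Prep<O, C: Strategy, const N: usize> {
    op: O,
    parents: [usize; N],
    _strategy: PhantomData<C>,
}

impl<O, C: Strategy, const N: usize> Prep<O, C, N> {
    fn describe(&self) -> String {
        format!("parents = {:?}, strategy = {}", self.parents, C::NAME)
    }
}

trait Op<const N: usize>: Sized {
    /// Provided method: every implementor gets `prepare` without writing it.
    fn prepare<C: Strategy>(self, parents: [usize; N]) -> Prep<Self, C, N> {
        Prep { op: self, parents, _strategy: PhantomData }
    }
}

/// A unary (N = 1) operation that only needs the default `prepare`.
struct Neg;
impl Op<1> for Neg {}

fn main() {
    let prep = Neg.prepare::<Recompute>([7]);
    println!("{}", prep.describe()); // parents = [7], strategy = recompute
    let _op: Neg = prep.op; // the prepared value still owns the operation
}
```

Encoding the strategy as a type parameter lets each combination of operation and strategy be monomorphised separately, instead of branching on a runtime flag.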

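Object Safety

This trait is not object safe: it requires Sized and has a generic method (prepare), so it cannot be used as a trait object such as dyn Backward<B, D, N>.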
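Because a trait with these properties cannot be used behind `dyn`, callers typically rely on static dispatch through generics, or erase the concrete type some other way. The sketch below uses a hypothetical trait `Task` with the same shape (a `Sized` supertrait, a by-value method, and a generic provided method); it is not burn code. Boxing closures is shown only as one common type-erasure option, not as what burn does internally.

```rust
/// Hypothetical trait shaped like `Backward`: `Sized` supertrait, by-value
/// method, generic provided method. The compiler would reject `Box<dyn Task>`.
trait Task: Sized {
    fn run(self) -> f64;

    /// Generic provided method; this alone also rules out `dyn Task`.
    fn run_scaled<T: Into<f64>>(self, factor: T) -> f64 {
        self.run() * factor.into()
    }
}

struct Half(f64);

impl Task for Half {
    fn run(self) -> f64 {
        self.0 / 2.0
    }
}

/// Static dispatch: one copy of this function is generated per concrete `T`.
fn execute<T: Task>(task: T) -> f64 {
    task.run_scaled(3u8)
}

fn main() {
    // Type erasure through closures: the closures themselves are object safe.
    let queue: Vec<Box<dyn FnOnce() -> f64>> = vec![
        Box::new(|| execute(Half(4.0))),
        Box::new(|| Half(10.0).run()),
    ];
    let results: Vec<f64> = queue.into_iter().map(|f| f()).collect();
    println!("{:?}", results); // [6.0, 5.0]
}
```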

Implementors