Trait burn::backend::autodiff::checkpoint::retro_forward::RetroForward
pub trait RetroForward:
    Debug
    + Send
    + 'static {
    // Required method
    fn forward(&self, states: &mut BackwardStates, out_node: NodeID);
}
Definition of the forward function of a node, called only during retropropagation. It differs from the normal forward function because it reads its inputs from, and writes its output to, the BackwardStates map instead of receiving and returning them through an explicit function signature.
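To illustrate the "read from and write to the map" contract, here is a minimal, self-contained sketch. It does not use burn: `NodeID`, `BackwardStates`, and their `save`/`get_state` methods below are hypothetical stand-ins (a type-erased `HashMap` keyed by node ID), and `RetroAdd` is a hypothetical element-wise add on `Vec<f32>`. burn's real `BackwardStates` API may differ. Only the shape of the `RetroForward` trait mirrors the signature shown above.

```rust
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;

/// Hypothetical stand-in for burn's `NodeID`: an opaque node identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeID(pub u64);

/// Hypothetical stand-in for burn's `BackwardStates`: a type-erased map
/// from node IDs to the values computed for those nodes.
#[derive(Default)]
pub struct BackwardStates {
    map: HashMap<NodeID, Box<dyn Any + Send>>,
}

impl BackwardStates {
    /// Stores the value computed for `node` (hypothetical method name).
    pub fn save<T: Send + 'static>(&mut self, node: NodeID, value: T) {
        self.map.insert(node, Box::new(value));
    }

    /// Returns a clone of the value stored for `node`; panics if it is
    /// missing or has a different type (hypothetical method name).
    pub fn get_state<T: Clone + 'static>(&self, node: &NodeID) -> T {
        self.map
            .get(node)
            .and_then(|v| v.downcast_ref::<T>())
            .cloned()
            .expect("missing or mistyped state")
    }
}

/// Same shape as the `RetroForward` trait documented above.
pub trait RetroForward: Debug + Send + 'static {
    fn forward(&self, states: &mut BackwardStates, out_node: NodeID);
}

/// Hypothetical retro-forward for an element-wise add. It only records
/// which nodes hold its inputs, not the input values themselves.
#[derive(Debug)]
pub struct RetroAdd {
    lhs: NodeID,
    rhs: NodeID,
}

impl RetroForward for RetroAdd {
    fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
        // Inputs are read from the map rather than passed as arguments.
        let lhs: Vec<f32> = states.get_state(&self.lhs);
        let rhs: Vec<f32> = states.get_state(&self.rhs);
        let out: Vec<f32> = lhs.iter().zip(&rhs).map(|(a, b)| a + b).collect();
        // The result is written back under the output node's ID.
        states.save(out_node, out);
    }
}

fn main() {
    let mut states = BackwardStates::default();
    states.save(NodeID(0), vec![1.0_f32, 2.0]);
    states.save(NodeID(1), vec![10.0_f32, 20.0]);

    // Trait objects work because `forward` takes `&self` and has no generics.
    let op: Box<dyn RetroForward> = Box::new(RetroAdd { lhs: NodeID(0), rhs: NodeID(1) });
    op.forward(&mut states, NodeID(2));

    let out: Vec<f32> = states.get_state(&NodeID(2));
    println!("{:?}", out); // prints [11.0, 22.0]
}
```

Because inputs and outputs are addressed by `NodeID`, a caller can store heterogeneous operations as `Box<dyn RetroForward>` and re-run any of them once its input states are present in the map.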
Required Methods
fn forward(&self, states: &mut BackwardStates, out_node: NodeID)
Applies the forward pass for retropropagation.