Trait burn::train::TrainStep

pub trait TrainStep<TI, TO> {
    // Required method
    fn step(&self, item: TI) -> TrainOutput<TO>;

    // Provided method
    fn optimize<B, O>(
        self,
        optim: &mut O,
        lr: f64,
        grads: GradientsParams,
    ) -> Self
       where B: AutodiffBackend,
             O: Optimizer<Self, B>,
             Self: AutodiffModule<B> { ... }
}

Trait to be implemented for training models.

The step method has no default implementation and must be implemented manually for every struct that implements this trait.

The optimize method has a default implementation that can be overridden to control how the optimizer updates the model. This is useful when you want to call custom mutable functions on the model (e.g., clipping the weights) before or after the optimizer runs.
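The sketch below overrides optimize so that every weight is clipped after the optimizer has run. It does not use Burn. Model, Sgd, the simplified Optimizer trait, and this cut-down TrainStep (which drops the step method, the type parameters, and the autodiff bounds) are hypothetical stand-ins, and gradients are passed as a plain Vec<f64> instead of GradientsParams. The default optimize shown here just forwards to the optimizer; that is an assumption about the provided body, which the signature alone does not reveal.

```rust
/// Hypothetical stand-in for an optimizer (not Burn's Optimizer trait).
pub trait Optimizer<M> {
    fn step(&mut self, lr: f64, module: M, grads: Vec<f64>) -> M;
}

/// Cut-down stand-in for TrainStep: only the provided optimize method is kept.
pub trait TrainStep: Sized {
    /// Assumed default: let the optimizer update the module, nothing else.
    fn optimize<O: Optimizer<Self>>(self, optim: &mut O, lr: f64, grads: Vec<f64>) -> Self {
        optim.step(lr, self, grads)
    }
}

/// Hypothetical model holding a flat list of weights.
#[derive(Debug, Clone, PartialEq)]
pub struct Model {
    pub weights: Vec<f64>,
}

/// Plain SGD: each weight becomes w - lr * g.
pub struct Sgd;

impl Optimizer<Model> for Sgd {
    fn step(&mut self, lr: f64, mut module: Model, grads: Vec<f64>) -> Model {
        for (w, g) in module.weights.iter_mut().zip(grads) {
            *w -= lr * g;
        }
        module
    }
}

impl TrainStep for Model {
    /// Override: run the optimizer as usual, then clip each weight to [-1, 1].
    fn optimize<O: Optimizer<Self>>(self, optim: &mut O, lr: f64, grads: Vec<f64>) -> Self {
        let mut updated = optim.step(lr, self, grads);
        for w in updated.weights.iter_mut() {
            *w = (*w).clamp(-1.0, 1.0);
        }
        updated
    }
}
```

Starting from weights [0.5, -0.5, 0.0] with gradients [-2.0, 2.0, 0.25] and a learning rate of 1.0, plain SGD would produce [2.5, -2.5, -0.25]; the override clips this to [1.0, -1.0, -0.25]. Clipping could equally run before the optimizer call, which is the "before or after" choice the description mentions.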

Notes

To be used with the Learner struct, the implementing struct must also implement the AutodiffModule trait; deriving Module provides this implementation automatically.

Required Methods

fn step(&self, item: TI) -> TrainOutput<TO>

Runs the training step, which executes the forward and backward passes.

Arguments
  • item: The training input for the model.

Returns

The training output containing the model output and the gradients.
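A std-only sketch of what a step implementation does: run the forward pass, compute the loss, run the backward pass, and return the output together with the gradients. It assumes a one-parameter linear model trained with mean squared error. LinearModel, Batch, Output, and the simplified TrainOutput (gradients as a plain Vec<f64>) are hypothetical stand-ins rather than Burn types, and the gradient is derived by hand where Burn would obtain it from an autodiff backend.

```rust
/// Hypothetical stand-in for TrainOutput: the model output plus its gradients.
#[derive(Debug)]
pub struct TrainOutput<TO> {
    pub grads: Vec<f64>,
    pub item: TO,
}

/// Simplified trait: only the required method is shown.
pub trait TrainStep<TI, TO> {
    fn step(&self, item: TI) -> TrainOutput<TO>;
}

/// Hypothetical input batch: inputs x and targets y.
pub struct Batch {
    pub x: Vec<f64>,
    pub y: Vec<f64>,
}

/// Hypothetical output: predictions and the scalar loss.
#[derive(Debug)]
pub struct Output {
    pub predictions: Vec<f64>,
    pub loss: f64,
}

/// One-parameter model: y_hat = w * x.
pub struct LinearModel {
    pub w: f64,
}

impl TrainStep<Batch, Output> for LinearModel {
    fn step(&self, batch: Batch) -> TrainOutput<Output> {
        let n = batch.x.len() as f64;
        // Forward pass: predictions and mean squared error.
        let predictions: Vec<f64> = batch.x.iter().map(|x| self.w * x).collect();
        let loss = predictions
            .iter()
            .zip(&batch.y)
            .map(|(p, y)| (p - y).powi(2))
            .sum::<f64>()
            / n;
        // Backward pass, by hand: dL/dw = (2 / n) * sum((w * x - y) * x).
        let grad_w = predictions
            .iter()
            .zip(batch.x.iter().zip(&batch.y))
            .map(|(p, (x, y))| (p - y) * x)
            .sum::<f64>()
            * 2.0
            / n;
        TrainOutput {
            grads: vec![grad_w],
            item: Output { predictions, loss },
        }
    }
}
```

For w = 1 on inputs [1, 2] with targets [2, 4], step returns a loss of 2.5 and a gradient of -5; the gradients are then what gets handed to optimize.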

Provided Methods

fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
where B: AutodiffBackend, O: Optimizer<Self, B>, Self: AutodiffModule<B>,

Optimize the current module with the provided gradients and learning rate.

Arguments
  • optim: Optimizer used for training this model.
  • lr: The learning rate used for this step.
  • grads: The gradients of each parameter in the current model.

Returns

The updated model.
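To show how the two methods fit together, the sketch below runs a minimal loop in which the gradients returned by step are passed, with a learning rate, to the provided optimize. Everything here is a std-only stand-in: Linear, Sgd, train, and the simplified Optimizer, TrainOutput, and TrainStep are hypothetical, and the default optimize is assumed to forward the module, learning rate, and gradients straight to the optimizer.

```rust
/// Hypothetical stand-in for an optimizer (not Burn's Optimizer trait).
pub trait Optimizer<M> {
    fn step(&mut self, lr: f64, module: M, grads: Vec<f64>) -> M;
}

/// Simplified stand-in for TrainOutput.
pub struct TrainOutput<TO> {
    pub grads: Vec<f64>,
    pub item: TO,
}

/// Simplified stand-in for TrainStep, keeping both methods.
pub trait TrainStep<TI, TO>: Sized {
    fn step(&self, item: TI) -> TrainOutput<TO>;

    /// Assumed default: hand module, learning rate and gradients to the optimizer.
    fn optimize<O: Optimizer<Self>>(self, optim: &mut O, lr: f64, grads: Vec<f64>) -> Self {
        optim.step(lr, self, grads)
    }
}

/// Hypothetical one-parameter model: y_hat = w * x.
#[derive(Debug, Clone, Copy)]
pub struct Linear {
    pub w: f64,
}

/// Plain SGD: w becomes w - lr * g.
pub struct Sgd;

impl Optimizer<Linear> for Sgd {
    fn step(&mut self, lr: f64, module: Linear, grads: Vec<f64>) -> Linear {
        Linear { w: module.w - lr * grads[0] }
    }
}

impl TrainStep<(f64, f64), f64> for Linear {
    /// Forward: squared error on one sample. Backward: dL/dw = 2 * (w * x - y) * x.
    fn step(&self, (x, y): (f64, f64)) -> TrainOutput<f64> {
        let err = self.w * x - y;
        TrainOutput { grads: vec![2.0 * err * x], item: err * err }
    }
}

/// Runs `epochs` rounds of step + optimize on one sample. Returns the final model
/// and the loss computed in the last step (i.e. before that step's update).
pub fn train(mut model: Linear, sample: (f64, f64), lr: f64, epochs: usize) -> (Linear, f64) {
    let mut optim = Sgd;
    let mut loss = f64::INFINITY;
    for _ in 0..epochs {
        let out = model.step(sample);
        loss = out.item;
        model = model.optimize(&mut optim, lr, out.grads);
    }
    (model, loss)
}
```

With w = 0, the sample (1.0, 2.0) and lr = 0.25, each iteration halves the error, so w goes from 0 to 1, then 1.5, then 1.75, and approaches 2. Because Linear does not override optimize, the loop exercises only the default method.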

Implementors