Trait TrainStep

pub trait TrainStep {
    type Input: Send + 'static;
    type Output: ItemLazy + 'static;

    // Required method
    fn step(&self, item: Self::Input) -> TrainOutput<Self::Output>;

    // Provided methods
    fn optimize<B, O>(
        self,
        optim: &mut O,
        lr: f64,
        grads: GradientsParams,
    ) -> Self
       where B: AutodiffBackend,
             O: Optimizer<Self, B>,
             Self: AutodiffModule<B> { ... }
    fn optimize_multi<B, O>(
        self,
        optim: &mut O,
        lr: f64,
        grads: MultiGradientsParams,
    ) -> Self
       where B: AutodiffBackend,
             O: Optimizer<Self, B>,
             Self: AutodiffModule<B> { ... }
}

Trait implemented by models so that they can be trained.

The step method must be implemented manually for each struct.

The optimize method can be overridden if you want to control how the optimizer updates the model. This is useful if you need to call custom mutating functions on your model (e.g., to clip the weights) before or after the optimizer runs.

§Notes

To be used with the Learner struct, the struct which implements this trait must also implement the AutodiffModule trait, which is done automatically with the Module derive.
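A minimal, self-contained sketch of the pattern described above. The types here (TrainOutput as a plain struct, gradients as a Vec<f64>, a one-parameter Linear model with a hand-derived gradient) are simplified stand-ins for illustration, not burn's actual types, which rely on autodiff backends:

```rust
/// Stand-in for burn's TrainOutput: pairs the model output with gradients.
pub struct TrainOutput<O> {
    pub grads: Vec<f64>,
    pub output: O,
}

/// Stand-in for the TrainStep trait: one forward + backward pass.
pub trait TrainStep {
    type Input;
    type Output;
    fn step(&self, item: Self::Input) -> TrainOutput<Self::Output>;
}

/// A one-parameter linear model, pred = weight * x, trained on squared error.
pub struct Linear {
    pub weight: f64,
}

impl TrainStep for Linear {
    type Input = (f64, f64); // (x, target)
    type Output = f64;       // the loss value

    fn step(&self, (x, target): Self::Input) -> TrainOutput<Self::Output> {
        let pred = self.weight * x;           // forward pass
        let loss = (pred - target).powi(2);   // squared error
        let grad = 2.0 * (pred - target) * x; // d(loss)/d(weight), by hand
        TrainOutput { grads: vec![grad], output: loss }
    }
}
```

In the real trait, the backward pass and gradient bookkeeping are handled by the autodiff backend rather than written by hand; only the shape of the contract is shown here.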

Required Associated Types§

type Input: Send + 'static

Type of input for a step of the training stage.

type Output: ItemLazy + 'static

Type of output for a step of the training stage.

Required Methods§

fn step(&self, item: Self::Input) -> TrainOutput<Self::Output>

Runs a step for training, which executes the forward and backward passes.

§Arguments
  • item - The input for the model.
§Returns

The output containing the model output and the gradients.

Provided Methods§

fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
where B: AutodiffBackend, O: Optimizer<Self, B>, Self: AutodiffModule<B>,

Optimize the current module with the provided gradients and learning rate.

§Arguments
  • optim: Optimizer used for learning.
  • lr: The learning rate used for this step.
  • grads: The gradients of each parameter in the current model.
§Returns

The updated model.
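The description above notes that optimize can be overridden to mutate the model around the optimizer step. A minimal sketch of that idea, using plain SGD on f64 slices instead of burn's Optimizer and GradientsParams types (all names here are simplified assumptions):

```rust
/// Plain SGD update on a slice of parameters; a stand-in for an
/// optimizer's step in this simplified setting.
pub fn sgd_step(params: &mut [f64], grads: &[f64], lr: f64) {
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= lr * g;
    }
}

/// An "overridden optimize": run the optimizer, then clip the weights,
/// mirroring the custom mutation (e.g., weight clipping) the docs describe.
pub fn optimize_then_clip(params: &mut [f64], grads: &[f64], lr: f64, max_abs: f64) {
    sgd_step(params, grads, lr);
    for p in params.iter_mut() {
        *p = p.clamp(-max_abs, max_abs);
    }
}
```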

fn optimize_multi<B, O>(self, optim: &mut O, lr: f64, grads: MultiGradientsParams) -> Self
where B: AutodiffBackend, O: Optimizer<Self, B>, Self: AutodiffModule<B>,

Optimize the current module with the provided multiple sets of gradients and learning rate.

§Arguments
  • optim: Optimizer used for learning.
  • lr: The learning rate used for this step.
  • grads: Multiple gradients associated with each parameter in the current model.
§Returns

The updated model.
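One way to picture multiple gradient sets per parameter is aggregation before the update, as in data-parallel training. This sketch averages one gradient vector per source before stepping; averaging is an assumption here, not necessarily what burn's optimizers do with MultiGradientsParams:

```rust
/// Average several gradient vectors element-wise (e.g., one per device).
/// Panics if `per_source` is empty or the vectors differ in length.
pub fn average_grads(per_source: &[Vec<f64>]) -> Vec<f64> {
    let n = per_source.len() as f64;
    let len = per_source[0].len();
    (0..len)
        .map(|i| per_source.iter().map(|g| g[i]).sum::<f64>() / n)
        .collect()
}
```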

Implementors§