pub trait TrainStep<TI, TO> {
    // Required method
    fn step(&self, item: TI) -> TrainOutput<TO>;

    // Provided method
    fn optimize<B, O>(
        self,
        optim: &mut O,
        lr: f64,
        grads: GradientsParams,
    ) -> Self
       where B: AutodiffBackend,
             O: Optimizer<Self, B>,
             Self: AutodiffModule<B> { ... }
}
Trait to be implemented for training models.
The step method has no default implementation and must be implemented manually by every struct that implements this trait.
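As an illustration of what a step implementation does, here is a minimal, self-contained sketch. It is not Burn code: GradientsParams, TrainOutput and the one-method TrainStep trait below are simplified stand-ins that only mirror the shape of step(&self, item: TI) -> TrainOutput<TO>, and Batch, RegressionOutput and Linear are hypothetical types. The gradients of a mean-squared-error loss are derived by hand here; in Burn they would typically come from calling backward() on the loss tensor of an autodiff backend.

```rust
// Hypothetical stand-ins: NOT Burn's real types. They only mirror the
// shape of `TrainStep::step(&self, item: TI) -> TrainOutput<TO>`.

/// Stand-in for the gradients of the model's parameters (two scalars here).
#[derive(Debug, Clone, PartialEq)]
pub struct GradientsParams {
    pub dw: f64,
    pub db: f64,
}

/// Stand-in for `TrainOutput<TO>`: the gradients plus the item that would
/// be reported to metrics.
pub struct TrainOutput<TO> {
    pub grads: GradientsParams,
    pub item: TO,
}

/// Simplified trait containing only the required method.
pub trait TrainStep<TI, TO> {
    fn step(&self, item: TI) -> TrainOutput<TO>;
}

/// Hypothetical batch: inputs and targets of equal length.
pub struct Batch {
    pub xs: Vec<f64>,
    pub ys: Vec<f64>,
}

/// Hypothetical per-step output: the mean squared error of the batch.
#[derive(Debug)]
pub struct RegressionOutput {
    pub loss: f64,
}

/// Toy model y = w * x + b.
pub struct Linear {
    pub w: f64,
    pub b: f64,
}

impl TrainStep<Batch, RegressionOutput> for Linear {
    fn step(&self, batch: Batch) -> TrainOutput<RegressionOutput> {
        let n = batch.xs.len() as f64;
        let (mut loss, mut dw, mut db) = (0.0, 0.0, 0.0);
        for (x, y) in batch.xs.iter().zip(&batch.ys) {
            let err = self.w * x + self.b - y;
            loss += err * err;
            // d(err^2)/dw = 2 * err * x and d(err^2)/db = 2 * err
            dw += 2.0 * err * x;
            db += 2.0 * err;
        }
        // Average over the batch, matching a mean-squared-error loss.
        TrainOutput {
            grads: GradientsParams { dw: dw / n, db: db / n },
            item: RegressionOutput { loss: loss / n },
        }
    }
}
```

Note that step only borrows the model (&self): it computes the loss and gradients but leaves the parameter update to optimize, which takes the model by value.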
The optimize method can be overridden if you want to control how the optimizer is used to update the model. This can be useful if you want to call custom mutable functions on your model (e.g., clipping the weights) before or after the optimizer is used.
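The sketch below shows the kind of override described above: the optimizer update runs as usual, and the weights are clipped afterwards. It is a self-contained approximation, not Burn's API. Optimizer, GradientsParams and the optimize-only TrainStep trait are simplified stand-ins (the backend parameter B is dropped), and Sgd, ClippedModel and clip_weights are hypothetical. The stand-in optimizer takes the model by value and returns the updated one, mirroring the self -> Self shape of optimize.

```rust
// Hypothetical stand-ins: NOT Burn's real types or signatures.

/// Stand-in gradients for a one-parameter model.
#[derive(Debug, Clone, Copy)]
pub struct GradientsParams {
    pub dw: f64,
}

/// Stand-in optimizer: consumes the model and returns the updated model.
pub trait Optimizer<M> {
    fn step(&mut self, lr: f64, model: M, grads: GradientsParams) -> M;
}

/// Simplified trait containing only the provided `optimize` method.
pub trait TrainStep: Sized {
    /// Default behaviour: delegate the whole update to the optimizer.
    fn optimize<O: Optimizer<Self>>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self {
        optim.step(lr, self, grads)
    }
}

/// Hypothetical model whose weight must stay within [-limit, limit].
pub struct ClippedModel {
    pub w: f64,
    pub limit: f64,
}

impl ClippedModel {
    /// Custom mutable operation applied after the optimizer update.
    fn clip_weights(mut self) -> Self {
        self.w = self.w.clamp(-self.limit, self.limit);
        self
    }
}

/// Plain stochastic gradient descent: w <- w - lr * dw.
pub struct Sgd;

impl Optimizer<ClippedModel> for Sgd {
    fn step(&mut self, lr: f64, model: ClippedModel, grads: GradientsParams) -> ClippedModel {
        ClippedModel { w: model.w - lr * grads.dw, ..model }
    }
}

impl TrainStep for ClippedModel {
    /// Override: run the usual optimizer update, then clip the weight.
    fn optimize<O: Optimizer<Self>>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self {
        optim.step(lr, self, grads).clip_weights()
    }
}
```

Because the override wraps the default call rather than replacing it, the same pattern works for work that must happen before the update: call the custom function on self first, then pass the result to the optimizer.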
Notes
To be used with the Learner struct, a struct implementing this trait must also implement the AutodiffModule trait; the Module derive implements it automatically.