Trait TrainStep
```rust
pub trait TrainStep {
    type Input: Send + 'static;
    type Output: ItemLazy + 'static;

    // Required method
    fn step(&self, item: Self::Input) -> TrainOutput<Self::Output>;

    // Provided methods
    fn optimize<B, O>(
        self,
        optim: &mut O,
        lr: f64,
        grads: GradientsParams,
    ) -> Self
       where B: AutodiffBackend,
             O: Optimizer<Self, B>,
             Self: AutodiffModule<B> { ... }
    fn optimize_multi<B, O>(
        self,
        optim: &mut O,
        lr: f64,
        grads: MultiGradientsParams,
    ) -> Self
       where B: AutodiffBackend,
             O: Optimizer<Self, B>,
             Self: AutodiffModule<B> { ... }
}
```
Trait to be implemented by models so that they can be trained.
The `step` method is required and must be implemented manually for every struct: it consumes one `Self::Input` and returns a `TrainOutput<Self::Output>`.
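A minimal, self-contained sketch of implementing the required `step` method is shown below. It does not depend on Burn: `TrainStep`, `TrainOutput`, `Linear`, and `Batch` here are hypothetical, simplified stand-ins (gradients are a plain `Vec<f64>` rather than Burn's gradient types, and `Output` is just the batch loss). It only illustrates the shape of an implementation: associated `Input`/`Output` types, plus a `step` that runs a forward pass, computes the loss, and returns the gradients together with the output item.

```rust
// Hypothetical, heavily simplified stand-ins for Burn's types; not Burn's API.
struct TrainOutput<T> {
    grads: Vec<f64>, // stand-in for the gradients produced by the backward pass
    item: T,         // stand-in for the output item (here: the batch loss)
}

trait TrainStep {
    type Input: Send + 'static;
    type Output: 'static;
    fn step(&self, item: Self::Input) -> TrainOutput<Self::Output>;
}

/// Toy linear model y = w * x with a single weight.
struct Linear {
    w: f64,
}

/// Toy batch of (x, y) pairs.
struct Batch {
    xs: Vec<f64>,
    ys: Vec<f64>,
}

impl TrainStep for Linear {
    type Input = Batch;
    type Output = f64; // mean squared error of the batch

    fn step(&self, batch: Batch) -> TrainOutput<f64> {
        let n = batch.xs.len() as f64;
        let mut loss = 0.0;
        let mut grad_w = 0.0;
        for (x, y) in batch.xs.iter().zip(batch.ys.iter()) {
            let err = self.w * x - y;
            loss += err * err / n;
            // d/dw of (w*x - y)^2 / n is 2 * (w*x - y) * x / n
            grad_w += 2.0 * err * x / n;
        }
        TrainOutput { grads: vec![grad_w], item: loss }
    }
}

fn main() {
    let model = Linear { w: 1.0 };
    let out = model.step(Batch { xs: vec![1.0, 2.0], ys: vec![2.0, 4.0] });
    println!("loss = {}, grads = {:?}", out.item, out.grads);
}
```

In real Burn code the gradients come from the autodiff backend rather than being derived by hand; the hand-written derivative only keeps the sketch dependency-free.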
The provided `optimize` method can be overridden to control how the optimizer updates the model. This is useful, for example, when you need to call custom mutating functions on the model (such as clipping the weights) before or after the optimizer step.
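The override pattern can be sketched without Burn as follows. `Optimizer`, `Sgd`, `Model`, and this two-method `TrainStep` are hypothetical simplified stand-ins (gradients are a `Vec<f64>`, there is no backend parameter `B`); the point is only that the default `optimize` delegates to the optimizer, while the overriding implementation runs the same step and then clips every weight into `[-1, 1]`.

```rust
// Hypothetical stand-ins; not Burn's real Optimizer or TrainStep.
trait Optimizer<M> {
    fn step(&mut self, lr: f64, module: M, grads: Vec<f64>) -> M;
}

/// Plain gradient descent: w <- w - lr * g.
struct Sgd;

#[derive(Debug, Clone, PartialEq)]
struct Model {
    weights: Vec<f64>,
}

impl Optimizer<Model> for Sgd {
    fn step(&mut self, lr: f64, mut module: Model, grads: Vec<f64>) -> Model {
        for (w, g) in module.weights.iter_mut().zip(grads.iter()) {
            *w -= lr * g;
        }
        module
    }
}

trait TrainStep: Sized {
    // Provided method: by default, simply delegate to the optimizer.
    fn optimize<O: Optimizer<Self>>(self, optim: &mut O, lr: f64, grads: Vec<f64>) -> Self {
        optim.step(lr, self, grads)
    }
}

impl TrainStep for Model {
    // Override: run the usual optimizer step, then clip each weight into [-1, 1].
    fn optimize<O: Optimizer<Self>>(self, optim: &mut O, lr: f64, grads: Vec<f64>) -> Self {
        let mut model = optim.step(lr, self, grads);
        for w in model.weights.iter_mut() {
            *w = w.clamp(-1.0, 1.0);
        }
        model
    }
}

fn main() {
    let mut sgd = Sgd;
    let model = Model { weights: vec![0.5, -0.5, 0.0] };
    let updated = model.optimize(&mut sgd, 0.1, vec![-10.0, 10.0, 1.0]);
    println!("{:?}", updated.weights);
}
```

Clipping after the step keeps the optimizer itself unaware of the constraint; a pre-step hook would go before the `optim.step` call instead.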
§Notes
For a model to be used with the `Learner` struct, the type implementing this trait must also implement the `AutodiffModule` trait; deriving `Module` implements it automatically.
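The requirement is enforced through trait bounds. The sketch below is a hypothetical, Burn-free illustration: `AutodiffModule`, `TrainStep`, `Learner`, and `Doubler` are simplified stand-ins (Burn's real traits are generic over a backend `B`, and its `Learner` is far richer). The learner only accepts a model implementing both traits, so removing the hand-written `AutodiffModule` impl makes `Learner::new(Doubler)` fail to compile.

```rust
// Hypothetical, simplified stand-ins; Burn's real traits are generic over a backend `B`.
trait AutodiffModule {
    fn param_count(&self) -> usize;
}

trait TrainStep {
    fn step(&self, x: f64) -> f64;
}

// Stand-in for a learner: it only accepts models implementing BOTH traits.
struct Learner<M: TrainStep + AutodiffModule> {
    model: M,
}

impl<M: TrainStep + AutodiffModule> Learner<M> {
    fn new(model: M) -> Self {
        Learner { model }
    }

    // Runs `step` on every input and sums the results (a toy "epoch").
    fn fit(&self, inputs: &[f64]) -> f64 {
        inputs.iter().map(|&x| self.model.step(x)).sum()
    }
}

struct Doubler;

impl TrainStep for Doubler {
    fn step(&self, x: f64) -> f64 {
        2.0 * x
    }
}

// In Burn, `#[derive(Module)]` generates this kind of impl; here it is written by hand.
// Deleting it makes `Learner::new(Doubler)` a compile-time error.
impl AutodiffModule for Doubler {
    fn param_count(&self) -> usize {
        0
    }
}

fn main() {
    let learner = Learner::new(Doubler);
    println!(
        "params = {}, total = {}",
        learner.model.param_count(),
        learner.fit(&[1.0, 2.0, 3.0])
    );
}
```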