pub trait Optimizer<M, B>: Send
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    type Record: Record<B>;

    // Required methods
    fn step(&mut self, lr: f64, module: M, grads: GradientsParams) -> M;
    fn to_record(&self) -> Self::Record;
    fn load_record(self, record: Self::Record) -> Self;
}
General trait for optimizing modules.
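Required Associated Types
type Record: Record<B>
The record type used to save and load the optimizer's state (produced by `to_record` and consumed by `load_record`).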
Required Methods
fn step(&mut self, lr: f64, module: M, grads: GradientsParams) -> M
Perform one optimizer step using the given learning rate and gradients. The module is taken by value and the updated module is returned.
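The contract of `step` (consume the module, apply the gradients scaled by `lr`, hand back the updated module) can be illustrated without Burn. The sketch below is a minimal, self-contained stand-in: `ToyModule`, `ToyGrads`, `ToyOptimizer`, and `ToySgd` are hypothetical names, not Burn types, and the update rule is assumed to be plain SGD (`w <- w - lr * g`).

```rust
// Hypothetical stand-in for a module: a flat list of parameters.
// None of these types come from Burn.
#[derive(Debug, Clone, PartialEq)]
struct ToyModule {
    weights: Vec<f64>,
}

// Hypothetical gradients, aligned index-by-index with `ToyModule::weights`.
struct ToyGrads {
    values: Vec<f64>,
}

// Mirrors the shape of `Optimizer::step`: the module is taken by value
// and the updated module is returned.
trait ToyOptimizer {
    fn step(&mut self, lr: f64, module: ToyModule, grads: ToyGrads) -> ToyModule;
}

// Plain stochastic gradient descent with no internal state.
struct ToySgd;

impl ToyOptimizer for ToySgd {
    fn step(&mut self, lr: f64, mut module: ToyModule, grads: ToyGrads) -> ToyModule {
        // w <- w - lr * g for every parameter.
        for (w, g) in module.weights.iter_mut().zip(grads.values.iter()) {
            *w -= lr * *g;
        }
        module
    }
}

fn main() {
    let mut opt = ToySgd;
    let module = ToyModule { weights: vec![1.0, -2.0] };
    let grads = ToyGrads { values: vec![0.5, -1.0] };
    // The caller rebinds the returned module, because `step` consumed the old one.
    let module = opt.step(0.1, module, grads);
    println!("{:?}", module.weights);
}
```

Taking the module by value and returning it lets an implementation update parameters without cloning the whole module, at the cost of the caller having to rebind the result.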
fn to_record(&self) -> Self::Record
Get the current state of the optimizer as a record.
fn load_record(self, record: Self::Record) -> Self
Load the optimizer state from the given record, returning the updated optimizer.
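The pair `to_record` / `load_record` matters for optimizers that carry state between steps. As a hedged illustration, the sketch below uses SGD with momentum, where the velocity vector is the state that must survive a save/load round trip. `ToyMomentum` and `MomentumRecord` are hypothetical names, the methods are inherent rather than a Burn trait impl, and the momentum rule (`v <- beta * v + g`, `w <- w - lr * v`) is an assumed example rather than any specific Burn optimizer.

```rust
// Hypothetical record type: a snapshot of the optimizer's internal state.
#[derive(Debug, Clone, PartialEq)]
struct MomentumRecord {
    velocity: Vec<f64>,
}

// SGD with momentum; `velocity` is the state that must be saved and restored.
struct ToyMomentum {
    beta: f64,
    velocity: Vec<f64>,
}

impl ToyMomentum {
    fn new(beta: f64, n_params: usize) -> Self {
        Self { beta, velocity: vec![0.0; n_params] }
    }

    // v <- beta * v + g;  w <- w - lr * v
    fn step(&mut self, lr: f64, mut weights: Vec<f64>, grads: &[f64]) -> Vec<f64> {
        let beta = self.beta;
        for ((w, v), g) in weights.iter_mut().zip(self.velocity.iter_mut()).zip(grads.iter()) {
            *v = beta * *v + *g;
            *w -= lr * *v;
        }
        weights
    }

    // Mirrors `to_record(&self) -> Self::Record`: copy the state out.
    fn to_record(&self) -> MomentumRecord {
        MomentumRecord { velocity: self.velocity.clone() }
    }

    // Mirrors `load_record(self, record) -> Self`: consume the optimizer
    // and return one whose state has been replaced by the record.
    fn load_record(mut self, record: MomentumRecord) -> Self {
        self.velocity = record.velocity;
        self
    }
}

fn main() {
    let grads = [1.0, -1.0];
    let mut opt = ToyMomentum::new(0.9, 2);
    let w = opt.step(0.1, vec![0.0, 0.0], &grads);

    // Save the state, then restore it into a freshly constructed optimizer.
    let record = opt.to_record();
    let mut restored = ToyMomentum::new(0.9, 2).load_record(record);

    // Both optimizers now produce the same next step.
    let a = opt.step(0.1, w.clone(), &grads);
    let b = restored.step(0.1, w, &grads);
    assert_eq!(a, b);
}
```

Without the round trip, a freshly constructed optimizer would restart from zero velocity and take a different next step.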
Object Safety
This trait is not object safe: `load_record` returns `Self`, so the trait cannot be used as `dyn Optimizer<M, B>`.
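Because the trait is not object safe, code that works with any optimizer usually takes it as a generic parameter instead of a trait object. The sketch below shows the pattern on a hypothetical mirror trait, `ToyOptimizer`, with the same method shapes (notably `load_record` returning `Self`); `Sgd` and `train` are also hypothetical and stand in for real Burn types and training loops.

```rust
// Hypothetical mirror of the trait's shape. `load_record` returns `Self`,
// which is what rules out `dyn ToyOptimizer`.
trait ToyOptimizer {
    type Record;
    fn step(&mut self, lr: f64, params: Vec<f64>, grads: &[f64]) -> Vec<f64>;
    fn to_record(&self) -> Self::Record;
    fn load_record(self, record: Self::Record) -> Self;
}

// Stateless SGD, so its record is the unit type.
struct Sgd;

impl ToyOptimizer for Sgd {
    type Record = ();

    fn step(&mut self, lr: f64, mut params: Vec<f64>, grads: &[f64]) -> Vec<f64> {
        for (p, g) in params.iter_mut().zip(grads.iter()) {
            *p -= lr * *g;
        }
        params
    }

    fn to_record(&self) -> Self::Record {}

    fn load_record(self, _record: Self::Record) -> Self {
        self
    }
}

// Would not compile, because the trait is not object safe:
// fn train_dyn(opt: &mut dyn ToyOptimizer<Record = ()>) {}

// Works: accept the optimizer as a generic parameter instead.
fn train<O: ToyOptimizer>(opt: &mut O, mut params: Vec<f64>, grads: &[f64], steps: usize) -> Vec<f64> {
    for _ in 0..steps {
        params = opt.step(0.1, params, grads);
    }
    params
}

fn main() {
    let mut opt = Sgd;
    let params = train(&mut opt, vec![1.0, 2.0], &[1.0, 1.0], 2);
    println!("{:?}", params);
}
```

Generics are monomorphized, so each optimizer type gets its own copy of `train`; that is the usual trade-off for traits that cannot be used behind `dyn`.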