Trait burn::optim::Optimizer

pub trait Optimizer<M, B>: Send {
    type Record: Record<B>;

    // Required methods
    fn step(&mut self, lr: f64, module: M, grads: GradientsParams) -> M;
    fn to_record(&self) -> Self::Record;
    fn load_record(self, record: Self::Record) -> Self;
}

General trait for optimizing a module.

Required Associated Types

type Record: Record<B>

Associated record type the optimizer uses when saving and loading its state.

Required Methods

fn step(&mut self, lr: f64, module: M, grads: GradientsParams) -> M

Perform one optimizer step using the given learning rate and gradients. The module is taken by value and the updated module is returned.
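To illustrate the ownership pattern of `step` (the module goes in by value, the updated module comes out), here is a minimal sketch using plain SGD. It does not use burn: `ToyModule`, `ToyGrads`, `ToyOptimizer` and `Sgd` are hypothetical stand-ins, and a "module" is assumed to be just a vector of weights.

```rust
/// Hypothetical stand-in for a module: just a flat vector of weights.
/// This is NOT a burn type.
#[derive(Debug, Clone, PartialEq)]
pub struct ToyModule {
    pub weights: Vec<f64>,
}

/// Hypothetical stand-in for `GradientsParams`: one gradient per weight.
pub struct ToyGrads(pub Vec<f64>);

/// Simplified trait with the same `step` shape as `Optimizer::step`:
/// the module is consumed and the updated module is returned.
pub trait ToyOptimizer {
    fn step(&mut self, lr: f64, module: ToyModule, grads: ToyGrads) -> ToyModule;
}

/// Plain SGD with no internal state: w <- w - lr * g.
pub struct Sgd;

impl ToyOptimizer for Sgd {
    fn step(&mut self, lr: f64, mut module: ToyModule, grads: ToyGrads) -> ToyModule {
        for (w, g) in module.weights.iter_mut().zip(grads.0.iter()) {
            *w -= lr * g;
        }
        module
    }
}
```

Because the module is moved into `step`, the caller rebinds it on each iteration, e.g. `model = optim.step(lr, model, grads);`.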

fn to_record(&self) -> Self::Record

Get the current state of the optimizer as a record.

fn load_record(self, record: Self::Record) -> Self

Load the optimizer state from a record, consuming the optimizer and returning it with the restored state.
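The `Record` / `to_record` / `load_record` trio lets stateful optimizers be checkpointed and resumed. The following is a minimal sketch, assuming a hypothetical trait `Stateful` (not burn's API) and SGD with momentum, where the velocity buffer is the state worth saving and `beta` is treated as configuration rather than state.

```rust
/// Hypothetical simplified trait mirroring the record-related part of
/// `Optimizer` (not burn's actual API).
pub trait Stateful {
    type Record;
    fn to_record(&self) -> Self::Record;
    fn load_record(self, record: Self::Record) -> Self;
}

/// SGD with momentum; `velocity` is the optimizer state.
#[derive(Debug, Default)]
pub struct Momentum {
    pub beta: f64,
    pub velocity: Vec<f64>,
}

impl Momentum {
    /// v <- beta * v + g;  w <- w - lr * v
    pub fn step(&mut self, lr: f64, mut weights: Vec<f64>, grads: &[f64]) -> Vec<f64> {
        // Lazily initialize the state on the first step.
        if self.velocity.len() != grads.len() {
            self.velocity = vec![0.0; grads.len()];
        }
        for ((w, v), g) in weights.iter_mut().zip(self.velocity.iter_mut()).zip(grads) {
            *v = self.beta * *v + g;
            *w -= lr * *v;
        }
        weights
    }
}

impl Stateful for Momentum {
    // Only the velocity is persisted.
    type Record = Vec<f64>;

    fn to_record(&self) -> Self::Record {
        self.velocity.clone()
    }

    // Same `load_record(self, ...) -> Self` shape as the trait above.
    fn load_record(mut self, record: Self::Record) -> Self {
        self.velocity = record;
        self
    }
}
```

A freshly constructed optimizer that loads a saved record continues exactly where the original left off, because the momentum buffer is restored rather than reset to zero.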

Object Safety

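This trait is not object safe: `load_record` returns `Self`, so the trait cannot be used as `dyn Optimizer<M, B>`. Use it through a generic type parameter instead.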
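Since a trait whose methods return `Self` cannot be used as a trait object, code that accepts "any optimizer" is typically written with a generic bound (static dispatch). A minimal sketch, assuming a hypothetical `ToyOptimizer` trait with the same kind of `load_record` method and a hypothetical `train` helper (neither is part of burn):

```rust
/// Hypothetical trait with the same object-safety blocker as `Optimizer`:
/// `load_record` returns `Self`.
pub trait ToyOptimizer {
    type Record;
    fn step(&mut self, lr: f64, weights: Vec<f64>, grads: &[f64]) -> Vec<f64>;
    fn to_record(&self) -> Self::Record;
    fn load_record(self, record: Self::Record) -> Self;
}

// `Box<dyn ToyOptimizer<Record = ()>>` would not compile because
// `load_record` returns `Self`. A generic parameter works instead: the
// concrete optimizer type is known at compile time.
pub fn train<O: ToyOptimizer>(
    opt: &mut O,
    mut weights: Vec<f64>,
    grads: &[f64],
    steps: usize,
    lr: f64,
) -> Vec<f64> {
    for _ in 0..steps {
        weights = opt.step(lr, weights, grads);
    }
    weights
}

/// Stateless SGD used to exercise `train`.
pub struct Sgd;

impl ToyOptimizer for Sgd {
    type Record = ();

    fn step(&mut self, lr: f64, mut weights: Vec<f64>, grads: &[f64]) -> Vec<f64> {
        for (w, g) in weights.iter_mut().zip(grads) {
            *w -= lr * g;
        }
        weights
    }

    fn to_record(&self) -> Self::Record {}

    fn load_record(self, _record: Self::Record) -> Self {
        self
    }
}
```

The trade-off is that each concrete optimizer type produces its own monomorphized copy of `train`, in exchange for no dynamic dispatch.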

Implementors

impl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>

type Record = HashMap<ParamId, AdaptorRecord<O, B>>
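The record type above stores optimizer state per parameter, keyed by parameter id. The sketch below imitates that shape only in spirit: `ParamId`, `ParamState` and `ToyAdaptor` are hypothetical stand-ins (not burn's types), parameters are plain `Vec<f64>` values, and the per-parameter update is assumed to be SGD with momentum.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a parameter identifier (not burn's `ParamId`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ParamId(pub u64);

/// Per-parameter momentum state: one velocity entry per element.
#[derive(Debug, Clone, PartialEq)]
pub struct ParamState {
    pub velocity: Vec<f64>,
}

/// Toy adaptor that visits each parameter by id and keeps that parameter's
/// optimizer state in a map.
#[derive(Default)]
pub struct ToyAdaptor {
    pub beta: f64,
    pub states: HashMap<ParamId, ParamState>,
}

impl ToyAdaptor {
    pub fn step(
        &mut self,
        lr: f64,
        mut params: HashMap<ParamId, Vec<f64>>,
        grads: &HashMap<ParamId, Vec<f64>>,
    ) -> HashMap<ParamId, Vec<f64>> {
        for (id, weights) in params.iter_mut() {
            // Parameters without a gradient are left untouched.
            let Some(g) = grads.get(id) else { continue };
            // Create zeroed state the first time a parameter is seen.
            let state = self
                .states
                .entry(*id)
                .or_insert_with(|| ParamState { velocity: vec![0.0; g.len()] });
            for ((w, v), gi) in weights.iter_mut().zip(state.velocity.iter_mut()).zip(g) {
                *v = self.beta * *v + gi;
                *w -= lr * *v;
            }
        }
        params
    }

    /// The record is a snapshot of the per-parameter state map.
    pub fn to_record(&self) -> HashMap<ParamId, ParamState> {
        self.states.clone()
    }
}
```

Keying state by id rather than by position means the record only holds entries for parameters that have actually received gradients.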