Trait burn::optim::SimpleOptimizer

pub trait SimpleOptimizer<B>: Send + Sync + Clone
where
    B: Backend,
{
    type State<const D: usize>: Record<B> + Clone + 'static;

    // Required methods
    fn step<const D: usize>(
        &self,
        lr: f64,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>);
    fn to_device<const D: usize>(
        state: Self::State<D>,
        device: &<B as Backend>::Device,
    ) -> Self::State<D>;
}

SimpleOptimizer is an opinionated trait that simplifies implementing an optimizer.

Implementations don’t have to handle missing gradients, load or export records, navigate the module parameter structure, handle tracked and untracked tensors, and so on.
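To make this division of labor concrete, here is a minimal std-only sketch. It is not burn's API: tensors are plain `Vec<f64>`, there is no backend or device, and the trait, the `PlainSgd` type and the `update_all` driver are hypothetical. The optimizer implements only the per-tensor math; the driver skips parameters without a gradient and keeps per-parameter states in a map, which is the kind of bookkeeping the description says implementations are spared.

```rust
use std::collections::HashMap;

/// Hypothetical std-only analogue of `SimpleOptimizer`: tensors are `Vec<f64>`
/// and there is no backend or device.
trait SimpleOptimizer: Clone {
    type State: Clone;

    fn step(
        &self,
        lr: f64,
        tensor: Vec<f64>,
        grad: Vec<f64>,
        state: Option<Self::State>,
    ) -> (Vec<f64>, Option<Self::State>);
}

/// Plain SGD without state: only the per-tensor update lives here.
#[derive(Clone)]
struct PlainSgd;

impl SimpleOptimizer for PlainSgd {
    type State = ();

    fn step(&self, lr: f64, tensor: Vec<f64>, grad: Vec<f64>, _state: Option<()>) -> (Vec<f64>, Option<()>) {
        // w <- w - lr * g, element by element.
        let updated: Vec<f64> = tensor.iter().zip(&grad).map(|(w, g)| w - lr * g).collect();
        (updated, None)
    }
}

/// Hypothetical driver: missing gradients and state storage are handled here,
/// not by the optimizer.
fn update_all<O: SimpleOptimizer>(
    opt: &O,
    lr: f64,
    params: &mut HashMap<String, Vec<f64>>,
    grads: &HashMap<String, Vec<f64>>,
    states: &mut HashMap<String, O::State>,
) {
    for (name, tensor) in params.iter_mut() {
        // Parameters without a gradient are left untouched.
        let Some(grad) = grads.get(name) else { continue };
        // Hand the stored state (if any) to the optimizer and keep what it returns.
        let state = states.remove(name);
        let (updated, new_state) = opt.step(lr, tensor.clone(), grad.clone(), state);
        *tensor = updated;
        if let Some(s) = new_state {
            states.insert(name.clone(), s);
        }
    }
}
```

Because `update_all` is generic over the optimizer, a stateful optimizer only has to choose its `State` type and return the new state from `step`.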

Required Associated Types

type State<const D: usize>: Record<B> + Clone + 'static

The state of the optimizer. It also implements Record, so that it can be saved.
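As an illustration of why a plain, owned state is easy to save, the following std-only sketch writes a hypothetical per-tensor state to text and reads it back. `RunningState`, `save` and `load` are assumptions standing in for burn's `Record` machinery, which this sketch does not use. The `Clone + 'static` bound means, roughly, that the state owns its data (it holds no non-`'static` borrows), so it can be copied out and stored independently of the optimizer.

```rust
/// Hypothetical per-tensor state: a step counter plus one running statistic
/// per element. Std-only analogue; it does not implement burn's `Record`.
#[derive(Clone, Debug, PartialEq)]
struct RunningState {
    time: usize,
    values: Vec<f64>,
}

impl RunningState {
    /// Hypothetical stand-in for recording: encode as "time;v1,v2,...".
    fn save(&self) -> String {
        let values: Vec<String> = self.values.iter().map(|v| v.to_string()).collect();
        format!("{};{}", self.time, values.join(","))
    }

    /// Inverse of `save`; returns `None` on malformed input.
    fn load(text: &str) -> Option<Self> {
        let (time, rest) = text.split_once(';')?;
        let time = time.parse::<usize>().ok()?;
        let values = if rest.is_empty() {
            Vec::new()
        } else {
            // Any element that fails to parse makes the whole load fail.
            rest.split(',')
                .map(|v| v.parse::<f64>().ok())
                .collect::<Option<Vec<f64>>>()?
        };
        Some(RunningState { time, values })
    }
}
```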

Required Methods

fn step<const D: usize>(
    &self,
    lr: f64,
    tensor: Tensor<B, D>,
    grad: Tensor<B, D>,
    state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>)

The optimizer step operates on one tensor at a time, together with its gradient and its state.

Note that the state is passed in as a parameter and the updated state is returned alongside the updated tensor, so implementations don’t have to handle saving and loading recorded states.
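A std-only sketch of that calling convention, assuming a hypothetical `MomentumSgd` that uses a plain heavy-ball update (v = momentum * v + g, then w = w - lr * v) chosen for illustration, not taken from burn's `Sgd`. On the first call the state is `None` and is initialized to zero velocity; afterwards the caller feeds back whatever `step` returned.

```rust
/// Hypothetical per-tensor state (not burn's `SgdState`).
#[derive(Clone, Debug, PartialEq)]
struct MomentumState {
    velocity: Vec<f64>,
}

/// Hypothetical optimizer with a plain heavy-ball momentum rule.
#[derive(Clone)]
struct MomentumSgd {
    momentum: f64,
}

impl MomentumSgd {
    /// Same shape as `SimpleOptimizer::step`: the state comes in as an
    /// `Option` and the updated state is handed back to the caller.
    fn step(
        &self,
        lr: f64,
        tensor: Vec<f64>,
        grad: Vec<f64>,
        state: Option<MomentumState>,
    ) -> (Vec<f64>, Option<MomentumState>) {
        // First call for this tensor: start from zero velocity.
        let prev = state
            .map(|s| s.velocity)
            .unwrap_or_else(|| vec![0.0; grad.len()]);
        // v_t = momentum * v_{t-1} + g_t
        let velocity: Vec<f64> = prev
            .iter()
            .zip(&grad)
            .map(|(v, g)| self.momentum * v + g)
            .collect();
        // w_t = w_{t-1} - lr * v_t
        let updated: Vec<f64> = tensor
            .iter()
            .zip(&velocity)
            .map(|(w, v)| w - lr * v)
            .collect();
        (updated, Some(MomentumState { velocity }))
    }
}
```

For example, with momentum = 0.5 and lr = 0.5, a first step on w = 1.0 with gradient 2.0 yields w = 0.0 and velocity 2.0; a second step with the same gradient yields velocity 3.0 and w = -1.5.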

fn to_device<const D: usize>(
    state: Self::State<D>,
    device: &<B as Backend>::Device,
) -> Self::State<D>

Change the device of the state.

This function is called as needed so that the state is on the same device as the gradient and the tensor when step is called.
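The sketch below models that split of responsibilities with std-only stand-ins: `Device`, `DeviceTensor`, `State`, `to_device` and `prepare_state` are all hypothetical, and "moving" is only a tag change. As in the real signature, `to_device` takes no `&self` and consumes the state. The caller, not the optimizer, checks whether the state already matches the parameter's device before running the step, which matters when, for instance, a restored state sits on a different device than the parameter.

```rust
/// Hypothetical device tag standing in for `<B as Backend>::Device`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Device {
    Cpu,
    Gpu(usize),
}

/// Hypothetical tensor analogue that remembers where it lives.
#[derive(Clone, Debug, PartialEq)]
struct DeviceTensor {
    data: Vec<f64>,
    device: Device,
}

/// Hypothetical optimizer state holding one device-resident tensor.
#[derive(Clone, Debug, PartialEq)]
struct State {
    velocity: DeviceTensor,
}

/// Analogue of `to_device`: no `&self`; consumes the state and returns it
/// on the requested device (here only the tag changes).
fn to_device(state: State, device: &Device) -> State {
    State {
        velocity: DeviceTensor {
            data: state.velocity.data,
            device: *device,
        },
    }
}

/// Hypothetical caller-side logic: align the state with the parameter's
/// device before the step runs.
fn prepare_state(state: Option<State>, tensor: &DeviceTensor) -> Option<State> {
    state.map(|s| {
        if s.velocity.device == tensor.device {
            s
        } else {
            to_device(s, &tensor.device)
        }
    })
}
```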

Object Safety

This trait is not object safe.
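The reasons are not spelled out here, but several features of the trait each make it non-object-safe under Rust's rules: the const-generic methods, the generic associated type `State<const D: usize>`, and the `Clone` supertrait, which implies `Sized`. In practice, code that works with any optimizer takes it as a generic parameter rather than as `Box<dyn SimpleOptimizer<B>>`. The std-only sketch below shows the pattern with a hypothetical `ArrayOptimizer` trait whose const-generic `step` mirrors the real one; note that `D` is an array length here, whereas in `Tensor<B, D>` it is the tensor rank.

```rust
/// Hypothetical trait with a const-generic method, mirroring the shape of
/// `SimpleOptimizer::step`. The generic method and the `Clone` supertrait
/// both prevent `dyn ArrayOptimizer`.
trait ArrayOptimizer: Clone {
    fn step<const D: usize>(&self, lr: f64, tensor: [f64; D], grad: [f64; D]) -> [f64; D];
}

/// Hypothetical plain SGD implementation.
#[derive(Clone)]
struct PlainSgd;

impl ArrayOptimizer for PlainSgd {
    fn step<const D: usize>(&self, lr: f64, tensor: [f64; D], grad: [f64; D]) -> [f64; D] {
        let mut out = tensor;
        for (w, g) in out.iter_mut().zip(grad) {
            *w -= lr * g;
        }
        out
    }
}

/// `Box<dyn ArrayOptimizer>` would not compile, so the optimizer is a
/// generic parameter: each optimizer type gets its own monomorphized copy.
fn train<O: ArrayOptimizer, const D: usize>(
    opt: &O,
    lr: f64,
    mut w: [f64; D],
    grad: [f64; D],
    steps: usize,
) -> [f64; D] {
    for _ in 0..steps {
        w = opt.step(lr, w, grad);
    }
    w
}
```

With static dispatch the optimizer is fixed at compile time; choosing one at runtime would need an enum or a similar wrapper.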

Implementors

impl<B> SimpleOptimizer<B> for AdaGrad<B>
where B: Backend,

type State<const D: usize> = AdaGradState<B, D>

impl<B> SimpleOptimizer<B> for Adam<B>
where B: Backend,

type State<const D: usize> = AdamState<B, D>

impl<B> SimpleOptimizer<B> for AdamW<B>
where B: Backend,

type State<const D: usize> = AdamWState<B, D>

impl<B> SimpleOptimizer<B> for RmsProp<B>
where B: Backend,

type State<const D: usize> = RmsPropState<B, D>

impl<B> SimpleOptimizer<B> for Sgd<B>
where B: Backend,

type State<const D: usize> = SgdState<B, D>