burn::optim

Trait SimpleOptimizer

pub trait SimpleOptimizer<B>:
    Send
    + Sync
    + Clone
where B: Backend,
{
    type State<const D: usize>: Record<B> + Clone + 'static;

    // Required methods
    fn step<const D: usize>(
        &self,
        lr: f64,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>);
    fn to_device<const D: usize>(
        state: Self::State<D>,
        device: &<B as Backend>::Device,
    ) -> Self::State<D>;
}

Simple optimizer is an opinionated trait to simplify the process of implementing an optimizer.

Implementations don’t have to handle missing gradients, load or export records, navigate the module parameter structure, distinguish tracked from untracked tensors, and the like.
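A minimal sketch of that division of labor, written as standalone Rust rather than against burn: Tensor, the backend-free SimpleOptimizer mirror, Sgd, and the apply_all driver are all hypothetical simplifications, not burn's API. The driver does the chores listed above (skipping parameters without a gradient and keeping each parameter's state between calls), so step only ever sees one tensor, its gradient, and its optional state.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for burn's Tensor<B, D>: values plus a rank-D shape, no backend.
#[derive(Clone, Debug, PartialEq)]
struct Tensor<const D: usize> {
    data: Vec<f64>,
    shape: [usize; D],
}

// Backend-free mirror of the trait's shape (a simplification, not burn's API).
trait SimpleOptimizer: Send + Sync + Clone {
    type State<const D: usize>: Clone + 'static;

    fn step<const D: usize>(
        &self,
        lr: f64,
        tensor: Tensor<D>,
        grad: Tensor<D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<D>, Option<Self::State<D>>);
}

// Plain gradient descent: it needs no state, so it always returns None.
#[derive(Clone)]
struct Sgd;

impl SimpleOptimizer for Sgd {
    type State<const D: usize> = ();

    fn step<const D: usize>(
        &self,
        lr: f64,
        mut tensor: Tensor<D>,
        grad: Tensor<D>,
        _state: Option<Self::State<D>>,
    ) -> (Tensor<D>, Option<Self::State<D>>) {
        assert_eq!(tensor.shape, grad.shape);
        for (w, g) in tensor.data.iter_mut().zip(&grad.data) {
            *w -= lr * g;
        }
        (tensor, None)
    }
}

// Hypothetical driver: does the bookkeeping so `step` does not have to.
fn apply_all<O: SimpleOptimizer, const D: usize>(
    opt: &O,
    lr: f64,
    params: &mut HashMap<&'static str, Tensor<D>>,
    grads: &HashMap<&'static str, Tensor<D>>,
    states: &mut HashMap<&'static str, O::State<D>>,
) {
    for (name, param) in params.iter_mut() {
        // No gradient for this parameter: skip it, so `step` never sees that case.
        let Some(grad) = grads.get(name) else { continue };
        // Hand over the stored state (if any), then keep whatever `step` returns.
        let state = states.remove(name);
        let (updated, new_state) = opt.step(lr, param.clone(), grad.clone(), state);
        *param = updated;
        if let Some(s) = new_state {
            states.insert(*name, s);
        }
    }
}

fn main() {
    let mut params = HashMap::from([
        ("w", Tensor { data: vec![1.0, 2.0], shape: [2] }),
        ("b", Tensor { data: vec![3.0], shape: [1] }),
    ]);
    // Only "w" received a gradient in this iteration.
    let grads = HashMap::from([("w", Tensor { data: vec![2.0, 2.0], shape: [2] })]);
    let mut states: HashMap<&str, ()> = HashMap::new();
    apply_all(&Sgd, 0.5, &mut params, &grads, &mut states);
    println!("{:?} {:?}", params["w"].data, params["b"].data);
}
```

Here only w is updated, to [0.0, 1.0]; b has no gradient and keeps [3.0].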

Required Associated Types

type State<const D: usize>: Record<B> + Clone + 'static

The state of the optimizer. It also implements Record<B>, so it can be saved.
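The state type is a generic associated type indexed by the tensor rank D, so one optimizer can keep a rank-1 state for a bias and a rank-2 state for a weight matrix. The sketch below assumes a hypothetical Saveable trait in place of burn's Record<B> bound and a hypothetical backend-free Tensor; neither is burn's API, and the flat shape-then-values encoding is purely illustrative.

```rust
// Hypothetical stand-in for burn's Tensor<B, D>.
#[derive(Clone, Debug, PartialEq)]
struct Tensor<const D: usize> {
    data: Vec<f64>,
    shape: [usize; D],
}

// Hypothetical stand-in for the Record<B> bound: something a state can be saved as.
trait Saveable {
    fn save(&self) -> Vec<f64>;
}

// Backend-free mirror of the associated-type part of the trait (not burn's API).
trait SimpleOptimizer {
    type State<const D: usize>: Saveable + Clone + 'static;
}

// One state value per parameter tensor; its rank follows the tensor's rank D.
#[derive(Clone, Debug)]
struct MomentumState<const D: usize> {
    velocity: Tensor<D>,
}

impl<const D: usize> Saveable for MomentumState<D> {
    fn save(&self) -> Vec<f64> {
        // Illustrative encoding: the shape first, then the values.
        let mut out: Vec<f64> = self.velocity.shape.iter().map(|&n| n as f64).collect();
        out.extend_from_slice(&self.velocity.data);
        out
    }
}

struct Momentum;

impl SimpleOptimizer for Momentum {
    // A single definition covers every rank D.
    type State<const D: usize> = MomentumState<D>;
}

// Works for any optimizer and any rank, thanks to the Saveable bound on State.
fn save_state<O: SimpleOptimizer, const D: usize>(state: &O::State<D>) -> Vec<f64> {
    state.save()
}

fn main() {
    let bias: MomentumState<1> =
        MomentumState { velocity: Tensor { data: vec![0.5, 0.25], shape: [2] } };
    let weight: MomentumState<2> =
        MomentumState { velocity: Tensor { data: vec![1.0; 4], shape: [2, 2] } };
    println!("{:?}", save_state::<Momentum, 1>(&bias));
    println!("{:?}", save_state::<Momentum, 2>(&weight));
}
```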

Required Methods

fn step<const D: usize>( &self, lr: f64, tensor: Tensor<B, D>, grad: Tensor<B, D>, state: Option<Self::State<D>>, ) -> (Tensor<B, D>, Option<Self::State<D>>)

The optimizer step is performed for one tensor at a time with its gradient and state.

Note that the state is passed as a parameter, so implementations don’t have to handle saving and loading recorded states.
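A sketch of how an implementation can treat the state argument, assuming a backend-free Tensor and a hypothetical Momentum optimizer (velocity update v = beta * v + g, then w = w - lr * v); neither is burn's API. None marks the first step for a tensor, and whatever state step returns is handed back by the caller on the next call, so the optimizer itself stores nothing.

```rust
// Hypothetical stand-in for burn's Tensor<B, D>.
#[derive(Clone, Debug, PartialEq)]
struct Tensor<const D: usize> {
    data: Vec<f64>,
    shape: [usize; D],
}

// Backend-free mirror of the trait (not burn's API).
trait SimpleOptimizer: Send + Sync + Clone {
    type State<const D: usize>: Clone + 'static;

    fn step<const D: usize>(
        &self,
        lr: f64,
        tensor: Tensor<D>,
        grad: Tensor<D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<D>, Option<Self::State<D>>);
}

#[derive(Clone, Debug)]
struct MomentumState<const D: usize> {
    velocity: Tensor<D>,
}

#[derive(Clone)]
struct Momentum {
    beta: f64,
}

impl SimpleOptimizer for Momentum {
    type State<const D: usize> = MomentumState<D>;

    fn step<const D: usize>(
        &self,
        lr: f64,
        mut tensor: Tensor<D>,
        grad: Tensor<D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<D>, Option<Self::State<D>>) {
        assert_eq!(tensor.shape, grad.shape);
        let velocity = match state {
            // First step for this tensor: start from v = g.
            None => grad,
            Some(s) => {
                let mut v = s.velocity;
                // v <- beta * v + g
                for (vi, gi) in v.data.iter_mut().zip(&grad.data) {
                    *vi = self.beta * *vi + gi;
                }
                v
            }
        };
        // w <- w - lr * v
        for (w, v) in tensor.data.iter_mut().zip(&velocity.data) {
            *w -= lr * v;
        }
        // Return the new state; storing it is the caller's job.
        (tensor, Some(MomentumState { velocity }))
    }
}

fn main() {
    let opt = Momentum { beta: 0.5 };
    let w = Tensor { data: vec![1.0], shape: [1] };
    let g = Tensor { data: vec![1.0], shape: [1] };
    // First call: no state yet.
    let (w, state) = opt.step(0.5, w, g.clone(), None);
    // Second call: feed back the state returned by the first.
    let (w, _state) = opt.step(0.5, w, g, state);
    println!("{:?}", w.data); // [-0.25]
}
```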

fn to_device<const D: usize>( state: Self::State<D>, device: &<B as Backend>::Device, ) -> Self::State<D>

Change the device of the state.

This function is called as needed so that the state is on the same device as the gradient and the tensor when the step function is called.
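A sketch of why to_device matters, assuming a hypothetical Device enum and a Tensor that only records which device it is on (no real transfer happens). The hypothetical update function stands in for whatever caller invokes to_device: it moves a restored state onto the tensor's device before calling step, which in this sketch refuses to mix devices. None of these names are burn's API.

```rust
// Hypothetical device type; in burn the real one is <B as Backend>::Device.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Device {
    Cpu,
    Gpu(usize),
}

// Hypothetical stand-in for burn's Tensor<B, D>, tagged with its device.
#[derive(Clone, Debug, PartialEq)]
struct Tensor<const D: usize> {
    data: Vec<f64>,
    shape: [usize; D],
    device: Device,
}

impl<const D: usize> Tensor<D> {
    // Simulated transfer: only the device tag changes in this sketch.
    fn moved_to(mut self, device: &Device) -> Self {
        self.device = *device;
        self
    }
}

trait SimpleOptimizer: Send + Sync + Clone {
    type State<const D: usize>: Clone + 'static;

    fn step<const D: usize>(
        &self,
        lr: f64,
        tensor: Tensor<D>,
        grad: Tensor<D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<D>, Option<Self::State<D>>);

    fn to_device<const D: usize>(state: Self::State<D>, device: &Device) -> Self::State<D>;
}

#[derive(Clone)]
struct Momentum {
    beta: f64,
}

impl SimpleOptimizer for Momentum {
    // The state is the velocity buffer, itself a tensor that lives on a device.
    type State<const D: usize> = Tensor<D>;

    fn step<const D: usize>(
        &self,
        lr: f64,
        mut tensor: Tensor<D>,
        grad: Tensor<D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<D>, Option<Self::State<D>>) {
        let mut v = state.unwrap_or_else(|| Tensor {
            data: vec![0.0; grad.data.len()],
            shape: grad.shape,
            device: grad.device,
        });
        // The step assumes everything already lives on one device.
        assert!(v.device == tensor.device && grad.device == tensor.device, "device mismatch");
        for ((w, vi), g) in tensor.data.iter_mut().zip(v.data.iter_mut()).zip(&grad.data) {
            *vi = self.beta * *vi + g;
            *w -= lr * *vi;
        }
        (tensor, Some(v))
    }

    fn to_device<const D: usize>(state: Self::State<D>, device: &Device) -> Self::State<D> {
        state.moved_to(device)
    }
}

// Hypothetical caller: moves the state next to the tensor, then steps.
fn update<O: SimpleOptimizer, const D: usize>(
    opt: &O,
    lr: f64,
    tensor: Tensor<D>,
    grad: Tensor<D>,
    state: Option<O::State<D>>,
) -> (Tensor<D>, Option<O::State<D>>) {
    let device = tensor.device;
    let state = state.map(|s| O::to_device::<D>(s, &device));
    opt.step(lr, tensor, grad, state)
}

fn main() {
    let gpu = Device::Gpu(0);
    let w = Tensor { data: vec![1.0], shape: [1], device: gpu };
    let g = Tensor { data: vec![1.0], shape: [1], device: gpu };
    // A velocity restored from a checkpoint starts out on the CPU.
    let saved = Tensor { data: vec![1.0], shape: [1], device: Device::Cpu };
    let (w, state) = update(&Momentum { beta: 0.5 }, 1.0, w, g, Some(saved));
    println!("{:?} {:?}", w.data, state.map(|v| v.device)); // [-0.5] Some(Gpu(0))
}
```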

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
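As a result, code that should work with any optimizer uses generics (static dispatch) instead of a trait object. Several parts of the signature rule out trait objects, for example the Clone supertrait and the const-generic step method. The sketch below uses a backend-free mirror of the trait with two hypothetical optimizers, Sgd and SignSgd, and shows an enum as one way to pick an optimizer at runtime without dyn; none of these names are burn's API.

```rust
// Hypothetical stand-in for burn's Tensor<B, D>.
#[derive(Clone, Debug, PartialEq)]
struct Tensor<const D: usize> {
    data: Vec<f64>,
    shape: [usize; D],
}

// Backend-free mirror of the trait (not burn's API). Like the real trait, it is
// not dyn compatible, so `Box<dyn SimpleOptimizer>` is rejected by the compiler.
trait SimpleOptimizer: Send + Sync + Clone {
    type State<const D: usize>: Clone + 'static;

    fn step<const D: usize>(
        &self,
        lr: f64,
        tensor: Tensor<D>,
        grad: Tensor<D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<D>, Option<Self::State<D>>);
}

// Plain gradient descent: w <- w - lr * g.
#[derive(Clone)]
struct Sgd;

// Hypothetical second optimizer: w <- w - lr * sign(g).
#[derive(Clone)]
struct SignSgd;

impl SimpleOptimizer for Sgd {
    type State<const D: usize> = ();

    fn step<const D: usize>(
        &self,
        lr: f64,
        mut tensor: Tensor<D>,
        grad: Tensor<D>,
        _state: Option<Self::State<D>>,
    ) -> (Tensor<D>, Option<Self::State<D>>) {
        assert_eq!(tensor.shape, grad.shape);
        for (w, g) in tensor.data.iter_mut().zip(&grad.data) {
            *w -= lr * g;
        }
        (tensor, None)
    }
}

impl SimpleOptimizer for SignSgd {
    type State<const D: usize> = ();

    fn step<const D: usize>(
        &self,
        lr: f64,
        mut tensor: Tensor<D>,
        grad: Tensor<D>,
        _state: Option<Self::State<D>>,
    ) -> (Tensor<D>, Option<Self::State<D>>) {
        assert_eq!(tensor.shape, grad.shape);
        for (w, g) in tensor.data.iter_mut().zip(&grad.data) {
            *w -= lr * g.signum();
        }
        (tensor, None)
    }
}

// Static dispatch: the compiler generates one copy per optimizer type.
fn run_steps<O: SimpleOptimizer, const D: usize>(
    opt: &O,
    lr: f64,
    mut tensor: Tensor<D>,
    grad: &Tensor<D>,
    steps: usize,
) -> Tensor<D> {
    let mut state = None;
    for _ in 0..steps {
        let (next, new_state) = opt.step(lr, tensor, grad.clone(), state);
        tensor = next;
        state = new_state;
    }
    tensor
}

// Runtime choice without dyn: an enum over the concrete optimizer types.
enum AnyOptimizer {
    Sgd(Sgd),
    SignSgd(SignSgd),
}

impl AnyOptimizer {
    fn run<const D: usize>(&self, lr: f64, tensor: Tensor<D>, grad: &Tensor<D>, steps: usize) -> Tensor<D> {
        match self {
            AnyOptimizer::Sgd(o) => run_steps(o, lr, tensor, grad, steps),
            AnyOptimizer::SignSgd(o) => run_steps(o, lr, tensor, grad, steps),
        }
    }
}

fn main() {
    let t = Tensor { data: vec![4.0], shape: [1] };
    let g = Tensor { data: vec![2.0], shape: [1] };
    for opt in [AnyOptimizer::Sgd(Sgd), AnyOptimizer::SignSgd(SignSgd)] {
        println!("{:?}", opt.run(0.5, t.clone(), &g, 2).data);
    }
}
```

The enum keeps each call statically dispatched inside its match arm; the cost is that every new optimizer needs a new variant.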

Implementors

impl<B> SimpleOptimizer<B> for AdaGrad
where B: Backend,

type State<const D: usize> = AdaGradState<B, D>

impl<B> SimpleOptimizer<B> for Adam
where B: Backend,

type State<const D: usize> = AdamState<B, D>

impl<B> SimpleOptimizer<B> for AdamW
where B: Backend,

type State<const D: usize> = AdamWState<B, D>

impl<B> SimpleOptimizer<B> for RmsProp
where B: Backend,

type State<const D: usize> = RmsPropState<B, D>

impl<B> SimpleOptimizer<B> for Sgd<B>
where B: Backend,

type State<const D: usize> = SgdState<B, D>