Trait burn::optim::SimpleOptimizer
pub trait SimpleOptimizer<B>: Send + Sync + Clone
where
    B: Backend,
{
    type State<const D: usize>: Record<B> + Clone + 'static;

    // Required methods
    fn step<const D: usize>(
        &self,
        lr: f64,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>);
    fn to_device<const D: usize>(
        state: Self::State<D>,
        device: &<B as Backend>::Device,
    ) -> Self::State<D>;
}
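To make the shape of the trait concrete, here is a self-contained sketch that mirrors it without depending on burn. `ToyTensor`, `ToySimpleOptimizer` and `ToySgd` are hypothetical stand-ins, not burn types: the backend parameter, the `Record<B>` bound and `to_device` are dropped so the code compiles on its own. A stateless optimizer such as plain SGD can use `()` as its state type and always return `None`.

```rust
// Hypothetical stand-in for burn's `Tensor<B, D>`: a flat buffer tagged with a
// rank `D`. It is NOT burn's type; it only lets this sketch compile alone.
#[derive(Clone, Debug, PartialEq)]
struct ToyTensor<const D: usize> {
    data: Vec<f64>,
}

// Simplified mirror of the `SimpleOptimizer` shape (assumption: no backend
// generic, no `Record<B>` bound, no `to_device`).
trait ToySimpleOptimizer: Send + Sync + Clone {
    type State<const D: usize>: Clone + 'static;

    fn step<const D: usize>(
        &self,
        lr: f64,
        tensor: ToyTensor<D>,
        grad: ToyTensor<D>,
        state: Option<Self::State<D>>,
    ) -> (ToyTensor<D>, Option<Self::State<D>>);
}

// Plain SGD keeps no state: its state type is `()` and it always returns `None`.
#[derive(Clone)]
struct ToySgd;

impl ToySimpleOptimizer for ToySgd {
    type State<const D: usize> = ();

    fn step<const D: usize>(
        &self,
        lr: f64,
        tensor: ToyTensor<D>,
        grad: ToyTensor<D>,
        _state: Option<Self::State<D>>,
    ) -> (ToyTensor<D>, Option<Self::State<D>>) {
        // w <- w - lr * g, element by element
        let data = tensor
            .data
            .iter()
            .zip(&grad.data)
            .map(|(w, g)| w - lr * g)
            .collect();
        (ToyTensor { data }, None)
    }
}

fn main() {
    let w = ToyTensor::<2> { data: vec![1.0, 2.0] };
    let g = ToyTensor::<2> { data: vec![0.5, -1.0] };
    let (w, state) = ToySgd.step(0.1, w, g, None);
    println!("{:?} {:?}", w.data, state);
}
```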
SimpleOptimizer is an opinionated trait that simplifies implementing an optimizer.
Implementations don’t have to handle missing gradients, load and export records, navigate the module parameter structure, distinguish tracked from untracked tensors, and the like.
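The bookkeeping listed above still has to happen somewhere; in burn it is done by the code that wraps a simple optimizer. The hypothetical `Driver` below is not burn's implementation; it only illustrates the division of labor on plain `Vec<f64>` buffers. It walks a list of parameters, skips any whose gradient is missing, and keeps each parameter's state between calls, while the optimizer (the hypothetical `StepFn` trait) only ever sees one tensor, its gradient and its state.

```rust
use std::collections::HashMap;

// Hypothetical, heavily simplified optimizer contract: one flat tensor at a
// time, with its state handed in and handed back.
trait StepFn {
    type State;
    fn step(&self, lr: f64, w: Vec<f64>, g: Vec<f64>, s: Option<Self::State>)
        -> (Vec<f64>, Option<Self::State>);
}

struct Sgd;

impl StepFn for Sgd {
    type State = ();
    fn step(&self, lr: f64, w: Vec<f64>, g: Vec<f64>, _s: Option<Self::State>)
        -> (Vec<f64>, Option<Self::State>) {
        (w.iter().zip(&g).map(|(w, g)| w - lr * g).collect(), None)
    }
}

// Hypothetical driver doing the work a simple optimizer is spared from:
// walking parameters, skipping missing gradients, storing per-parameter state.
struct Driver<O: StepFn> {
    optim: O,
    states: HashMap<usize, O::State>, // keyed by a parameter index
}

impl<O: StepFn> Driver<O> {
    fn new(optim: O) -> Self {
        Self { optim, states: HashMap::new() }
    }

    fn apply(&mut self, lr: f64, params: &mut [Vec<f64>], grads: &[Option<Vec<f64>>]) {
        for (id, (param, grad)) in params.iter_mut().zip(grads).enumerate() {
            // Missing gradient: leave the parameter and its state untouched.
            let Some(grad) = grad else { continue };
            let state = self.states.remove(&id);
            let (updated, new_state) =
                self.optim.step(lr, std::mem::take(param), grad.clone(), state);
            *param = updated;
            if let Some(s) = new_state {
                self.states.insert(id, s);
            }
        }
    }
}

fn main() {
    let mut driver = Driver::new(Sgd);
    let mut params = vec![vec![1.0, 1.0], vec![2.0]];
    // The second parameter has no gradient this step.
    let grads = vec![Some(vec![1.0, -1.0]), None];
    driver.apply(0.5, &mut params, &grads);
    println!("{:?}", params);
}
```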
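Required Associated Types
type State<const D: usize>: Record<B> + Clone + 'static
The optimizer state associated with a tensor of rank D. Because it must implement Record<B>, the state can be saved and loaded without the optimizer handling serialization itself.
Required Methods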
fn step<const D: usize>(
    &self,
    lr: f64,
    tensor: Tensor<B, D>,
    grad: Tensor<B, D>,
    state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>)
The optimizer step is performed on one tensor at a time, together with its gradient and state, and returns the updated tensor along with the new state.
Note that the state is passed as a parameter, so implementations don’t have to handle saving and loading recorded states.
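Because the state travels in and out through step, a stateful optimizer only reads the incoming Option and returns the updated value; the caller keeps it between steps. The sketch below assumes heavy-ball momentum (the PyTorch-style form without dampening: v ← βv + g, then w ← w − lr·v) and uses plain `Vec<f64>` buffers and a hypothetical `Velocity` type instead of burn tensors and records. A `None` state on the first call is treated as a zero velocity.

```rust
// Hypothetical per-tensor state: the running velocity.
#[derive(Clone, Debug, PartialEq)]
struct Velocity(Vec<f64>);

// Hypothetical momentum optimizer following the same step contract.
struct Momentum {
    beta: f64,
}

impl Momentum {
    fn step(
        &self,
        lr: f64,
        w: Vec<f64>,
        g: Vec<f64>,
        state: Option<Velocity>,
    ) -> (Vec<f64>, Option<Velocity>) {
        // No state yet (first step): start from a zero velocity.
        let prev = state.map(|v| v.0).unwrap_or_else(|| vec![0.0; g.len()]);
        // v <- beta * v + g
        let v: Vec<f64> = prev.iter().zip(&g).map(|(v, g)| self.beta * v + g).collect();
        // w <- w - lr * v
        let w: Vec<f64> = w.iter().zip(&v).map(|(w, v)| w - lr * v).collect();
        (w, Some(Velocity(v)))
    }
}

fn main() {
    let opt = Momentum { beta: 0.5 };
    let mut w = vec![1.0];
    let mut state = None; // owned by the caller, not by the optimizer
    for _ in 0..2 {
        let (new_w, new_state) = opt.step(0.5, w, vec![1.0], state);
        w = new_w;
        state = new_state;
    }
    println!("{:?} {:?}", w, state);
}
```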
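Object Safety
This trait is not object safe: it cannot be used as a trait object such as Box<dyn SimpleOptimizer<B>>, so code that works with a simple optimizer takes it as a generic type parameter instead.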
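Several parts of the declaration rule out trait objects: the `Clone` supertrait (which requires `Sized`), methods generic over `const D`, the generic associated type `State<D>`, and `to_device`, which has no `self` receiver. The small hypothetical trait `RankGeneric` below reproduces just one of those features, a const-generic method, and shows the usual workaround of static dispatch through a generic parameter.

```rust
// Hypothetical mini-trait with two features that also appear in
// `SimpleOptimizer`: a `Clone` supertrait and a const-generic method.
trait RankGeneric: Clone {
    fn describe<const D: usize>(&self) -> String;
}

#[derive(Clone)]
struct Sgd;

impl RankGeneric for Sgd {
    fn describe<const D: usize>(&self) -> String {
        format!("sgd step on a rank-{} tensor", D)
    }
}

// The following would be rejected by the compiler, because the trait is not
// object safe:
// let boxed: Box<dyn RankGeneric> = Box::new(Sgd);

// Static dispatch instead: the concrete optimizer type is a generic parameter.
fn run<O: RankGeneric>(optim: &O) -> String {
    optim.describe::<2>()
}

fn main() {
    println!("{}", run(&Sgd));
}
```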