Trait burn::module::Parameter

pub trait Parameter:
    Clone
    + Debug
    + Send {
    type Device: Clone;

    // Required methods
    fn device(&self) -> Self::Device;
    fn is_require_grad(&self) -> bool;
    fn set_require_grad(self, require_grad: bool) -> Self;
}

Defines what a type must provide to be used as a parameter: access to its device, and a way to query and set whether it requires gradients.

Required Associated Types

type Device: Clone

The type of the device that holds the parameter.

Required Methods

fn device(&self) -> Self::Device

Returns the device the parameter is on.

fn is_require_grad(&self) -> bool

Returns `true` if the parameter requires gradients.

fn set_require_grad(self, require_grad: bool) -> Self

Sets whether the parameter requires gradients. Takes `self` by value and returns the updated parameter.
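
As a rough illustration of the three required methods, the sketch below implements the trait for a toy type. It makes several assumptions: the trait is re-declared locally (mirroring the signature above) so the example compiles without the `burn` crate, and `ToyDevice`, `ToyParam`, and `freeze` are hypothetical names that do not exist in Burn. In real code you would import `burn::module::Parameter` instead.

```rust
use std::fmt::Debug;

// Local mirror of the trait shown above (an assumption made so this sketch
// is self-contained). Real code would use `burn::module::Parameter`.
pub trait Parameter: Clone + Debug + Send {
    type Device: Clone;

    fn device(&self) -> Self::Device;
    fn is_require_grad(&self) -> bool;
    fn set_require_grad(self, require_grad: bool) -> Self;
}

// Hypothetical device type; it only needs to be `Clone` to satisfy the bound.
#[derive(Clone, Debug, PartialEq)]
pub enum ToyDevice {
    Cpu,
    Gpu(usize),
}

// Hypothetical parameter type holding some values, a device, and a flag.
#[derive(Clone, Debug)]
pub struct ToyParam {
    values: Vec<f32>,
    device: ToyDevice,
    require_grad: bool,
}

impl ToyParam {
    // New parameters start with gradient tracking disabled in this sketch.
    pub fn new(values: Vec<f32>, device: ToyDevice) -> Self {
        Self { values, device, require_grad: false }
    }

    pub fn values(&self) -> &[f32] {
        &self.values
    }
}

impl Parameter for ToyParam {
    type Device = ToyDevice;

    fn device(&self) -> Self::Device {
        self.device.clone()
    }

    fn is_require_grad(&self) -> bool {
        self.require_grad
    }

    // Consumes the parameter and hands back the updated value.
    fn set_require_grad(mut self, require_grad: bool) -> Self {
        self.require_grad = require_grad;
        self
    }
}

// Generic helper that works for any `Parameter` implementor.
pub fn freeze<P: Parameter>(param: P) -> P {
    param.set_require_grad(false)
}
```

Because `set_require_grad` takes `self` by value, callers rebind the result, for example `let p = p.set_require_grad(true);`, rather than mutating the parameter in place.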

Object Safety

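This trait is not object safe: it has `Clone` as a supertrait, which requires `Self: Sized`, so it cannot be used as `dyn Parameter`.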
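
Since trait objects are unavailable, code that accepts arbitrary parameters is usually written with generics. The sketch below shows one way this might look. It again re-declares the trait locally so it compiles on its own, and `Flag` and `count_trainable` are hypothetical names used only for illustration.

```rust
use std::fmt::Debug;

// Local mirror of the trait shown above (an assumption made so this sketch
// is self-contained). Real code would use `burn::module::Parameter`.
pub trait Parameter: Clone + Debug + Send {
    type Device: Clone;

    fn device(&self) -> Self::Device;
    fn is_require_grad(&self) -> bool;
    fn set_require_grad(self, require_grad: bool) -> Self;
}

// Hypothetical minimal implementor: a wrapper around the gradient flag,
// with the unit type standing in for a device.
#[derive(Clone, Debug)]
pub struct Flag(pub bool);

impl Parameter for Flag {
    type Device = ();

    fn device(&self) -> Self::Device {}

    fn is_require_grad(&self) -> bool {
        self.0
    }

    fn set_require_grad(self, require_grad: bool) -> Self {
        Flag(require_grad)
    }
}

// Static dispatch: `P` is resolved at compile time, so no `dyn` is needed.
// Writing `&[Box<dyn Parameter<Device = ()>>]` instead would be rejected by
// the compiler because the trait is not object safe.
pub fn count_trainable<P: Parameter>(params: &[P]) -> usize {
    params.iter().filter(|p| p.is_require_grad()).count()
}
```

The trade-off is that a single slice can only hold one concrete parameter type. Mixing types would need an enum or a separate object-safe wrapper trait.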

Implementors

impl<B, const D: usize> Parameter for Tensor<B, D>
where B: Backend,

type Device = <B as Backend>::Device

impl<B, const D: usize> Parameter for Tensor<B, D, Bool>
where B: Backend,

type Device = <B as Backend>::Device

impl<B, const D: usize> Parameter for Tensor<B, D, Int>
where B: Backend,

type Device = <B as Backend>::Device