Struct burn::nn::GateController
pub struct GateController<B>
where
    B: Backend,
{
    pub input_transform: Linear<B>,
    pub hidden_transform: Linear<B>,
}
A GateController represents a gate in an LSTM cell. An LSTM cell generally contains three gates: an input gate, a forget gate, and an output gate. In addition, a cell gate is used only to compute the candidate values for the cell state.
An LSTM gate is modeled as two linear transformations, one applied to the input and one applied to the hidden state. The results of these two transformations are summed to compute the gate’s output.
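To make the role of each gate concrete, here is a dependency-free sketch of one LSTM step in plain Rust; it does not use burn. It assumes the standard LSTM update: the input, forget, and output gates pass their pre-activation (the value a gate_product-style computation returns) through a sigmoid, the cell gate passes its pre-activation through tanh, and the results update the cell and hidden states. The names GatePreActivations and lstm_step are hypothetical and are not part of burn's API.

```rust
/// Logistic sigmoid, applied to the input, forget, and output gates.
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Hypothetical container for the pre-activation outputs of the four gates
/// for a single sample (one value per hidden unit).
struct GatePreActivations {
    input: Vec<f64>,
    forget: Vec<f64>,
    cell: Vec<f64>,
    output: Vec<f64>,
}

/// One standard LSTM step for a single sample.
/// Returns (new_cell_state, new_hidden_state).
fn lstm_step(pre: &GatePreActivations, cell_prev: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let n = cell_prev.len();
    let mut cell = Vec::with_capacity(n);
    let mut hidden = Vec::with_capacity(n);
    for k in 0..n {
        let i = sigmoid(pre.input[k]); // how much new information to write
        let f = sigmoid(pre.forget[k]); // how much of the old cell state to keep
        let g = pre.cell[k].tanh(); // candidate cell value from the cell gate
        let o = sigmoid(pre.output[k]); // how much of the cell state to expose
        let c = f * cell_prev[k] + i * g;
        cell.push(c);
        hidden.push(o * c.tanh());
    }
    (cell, hidden)
}
```

For example, a large positive forget pre-activation combined with a large negative input pre-activation drives f toward 1 and i toward 0, so the previous cell state passes through almost unchanged.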
Fields
input_transform: Linear<B>
Represents the affine transformation applied to the input vector.
hidden_transform: Linear<B>
Represents the affine transformation applied to the hidden state.
Implementations
impl<B> GateController<B>
where
    B: Backend,
pub fn new(
    d_input: usize,
    d_output: usize,
    bias: bool,
    initializer: Initializer,
    device: &<B as Backend>::Device,
) -> GateController<B>
Initialize a new gate_controller module.
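The arguments of new determine the sizes of the two affine maps. The plain-Rust sketch below counts the trainable parameters one gate would hold under two assumptions that this page does not state: hidden_transform maps a d_output-sized hidden state to d_output features, and when bias is true each of the two Linear layers carries its own bias vector of length d_output. The function gate_param_count is hypothetical; with burn itself, Module::num_params reports the actual count.

```rust
/// Hypothetical helper: number of trainable scalars in one gate.
/// Assumes input_transform is d_input -> d_output and
/// hidden_transform is d_output -> d_output.
fn gate_param_count(d_input: usize, d_output: usize, bias: bool) -> usize {
    // Weight matrices of the two affine transformations.
    let weights = d_input * d_output + d_output * d_output;
    // One bias vector of length d_output per Linear layer, if enabled.
    let biases = if bias { 2 * d_output } else { 0 };
    weights + biases
}
```

With d_input = 4 and d_output = 3, this gives 4*3 + 3*3 + 2*3 = 27 parameters with bias, or 21 without.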
pub fn gate_product(
    &self,
    input: Tensor<B, 2>,
    hidden: Tensor<B, 2>,
) -> Tensor<B, 2>
Helper function that performs the weighted matrix product for a gate and adds the bias, if any.
Mathematically, it computes Wx*X + Wh*H + b, where:
Wx = weight matrix for the connection to the input vector X
Wh = weight matrix for the connection to the hidden state H
X = input vector
H = hidden state
b = bias terms
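The formula can be checked with a small dependency-free Rust sketch that follows the column-vector convention above (Wx has d_output rows and d_input columns) and applies it to each sample of a batch. It illustrates the math only and is not burn's implementation, which, given the struct's fields, goes through its two Linear modules. The helpers mat_vec and gate_product here are hypothetical, and a single optional bias vector b is assumed.

```rust
/// Multiplies a matrix stored as rows (rows x cols) by a vector of length cols.
fn mat_vec(w: &[Vec<f64>], v: &[f64]) -> Vec<f64> {
    w.iter()
        .map(|row| row.iter().zip(v).map(|(a, b)| a * b).sum::<f64>())
        .collect()
}

/// Computes Wx*x + Wh*h + b for every sample in the batch.
/// input: batch x d_input, hidden: batch x d_output,
/// wx: d_output x d_input, wh: d_output x d_output,
/// bias: optional vector of length d_output.
fn gate_product(
    input: &[Vec<f64>],
    hidden: &[Vec<f64>],
    wx: &[Vec<f64>],
    wh: &[Vec<f64>],
    bias: Option<&[f64]>,
) -> Vec<Vec<f64>> {
    assert_eq!(input.len(), hidden.len(), "input and hidden must share the batch size");
    input
        .iter()
        .zip(hidden)
        .map(|(x, h)| {
            let wx_x = mat_vec(wx, x); // Wx * x
            let wh_h = mat_vec(wh, h); // Wh * h
            wx_x.iter()
                .zip(&wh_h)
                .enumerate()
                // Sum both products and add the bias term when present.
                .map(|(k, (a, b))| a + b + bias.map_or(0.0, |b_vec| b_vec[k]))
                .collect::<Vec<f64>>()
        })
        .collect()
}
```

Each row of input and hidden is treated as one sample, which mirrors a [batch, features] layout for the two-dimensional tensors gate_product accepts; that layout is an assumption here.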
Trait Implementations
impl<B> AutodiffModule<B> for GateController<B>
type InnerModule = GateController<<B as AutodiffBackend>::InnerBackend>
fn valid(&self) -> <GateController<B> as AutodiffModule<B>>::InnerModule
impl<B> Clone for GateController<B>
where
    B: Backend,
fn clone(&self) -> GateController<B>
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl<B> Debug for GateController<B>
impl<B> Display for GateController<B>
where
    B: Backend,
impl<B> Module<B> for GateController<B>
where
    B: Backend,
type Record = GateControllerRecord<B>
fn load_record(
    self,
    record: <GateController<B> as Module<B>>::Record,
) -> GateController<B>
fn into_record(self) -> <GateController<B> as Module<B>>::Record
fn num_params(&self) -> usize
fn visit<Visitor>(&self, visitor: &mut Visitor)
where
    Visitor: ModuleVisitor<B>,
fn map<Mapper>(self, mapper: &mut Mapper) -> GateController<B>
where
    Mapper: ModuleMapper<B>,
fn collect_devices(
    &self,
    devices: Vec<<B as Backend>::Device>,
) -> Vec<<B as Backend>::Device>
fn to_device(self, device: &<B as Backend>::Device) -> GateController<B>
fn fork(self, device: &<B as Backend>::Device) -> GateController<B>
fn devices(&self) -> Vec<<B as Backend>::Device>
fn save_file<FR, PB>(
    self,
    file_path: PB,
    recorder: &FR,
) -> Result<(), RecorderError>
fn load_file<FR, PB>(
    self,
    file_path: PB,
    recorder: &FR,
    device: &<B as Backend>::Device,
) -> Result<Self, RecorderError>
fn quantize_weights<C>(self, quantizer: &mut Quantizer<C>) -> Self
where
    C: Calibration,
impl<B> ModuleDisplay for GateController<B>
where
    B: Backend,
fn format(&self, passed_settings: DisplaySettings) -> String
fn custom_settings(&self) -> Option<DisplaySettings>
impl<B> ModuleDisplayDefault for GateController<B>
where
    B: Backend,
Auto Trait Implementations
impl<B> !Freeze for GateController<B>
impl<B> !RefUnwindSafe for GateController<B>
impl<B> Send for GateController<B>
impl<B> !Sync for GateController<B>
impl<B> Unpin for GateController<B>
where
    <B as Backend>::FloatTensorPrimitive: Unpin,
    <B as Backend>::QuantizedTensorPrimitive: Unpin,
    <B as Backend>::Device: Unpin,
impl<B> UnwindSafe for GateController<B>
where
    <B as Backend>::FloatTensorPrimitive: UnwindSafe,
    <B as Backend>::QuantizedTensorPrimitive: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> CloneToUninit for T
where
    T: Clone,
unsafe fn clone_to_uninit(&self, dst: *mut T)
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
impl<T> Pointable for T
impl<T> ToCompactString for T
where
    T: Display,
fn try_to_compact_string(&self) -> Result<CompactString, ToCompactStringError>
Fallible version of ToCompactString::to_compact_string().
fn to_compact_string(&self) -> CompactString
Converts the given value to a CompactString.