pub trait Module<B>: Clone + Send + Debug
where
    B: Backend,
{
    type Record: Record<B>;

    // Required methods
    fn collect_devices(
        &self,
        devices: Vec<<B as Backend>::Device>
    ) -> Vec<<B as Backend>::Device>;
    fn fork(self, device: &<B as Backend>::Device) -> Self;
    fn to_device(self, device: &<B as Backend>::Device) -> Self;
    fn visit<Visitor>(&self, visitor: &mut Visitor)
       where Visitor: ModuleVisitor<B>;
    fn map<Mapper>(self, mapper: &mut Mapper) -> Self
       where Mapper: ModuleMapper<B>;
    fn load_record(self, record: Self::Record) -> Self;
    fn into_record(self) -> Self::Record;

    // Provided methods
    fn devices(&self) -> Vec<<B as Backend>::Device> { ... }
    fn no_grad(self) -> Self { ... }
    fn num_params(&self) -> usize { ... }
    fn save_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR
    ) -> Result<(), RecorderError>
       where FR: FileRecorder<B>,
             PB: Into<PathBuf> { ... }
    fn load_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
        device: &<B as Backend>::Device
    ) -> Result<Self, RecorderError>
       where FR: FileRecorder<B>,
             PB: Into<PathBuf> { ... }
}
Trait for all neural network modules.
Modules should be created using the #[derive(Module)] attribute. This makes your module trainable, savable and loadable via into_record and load_record.
Example
A module should have a backend defined as a generic parameter B. This will be used by the derive attribute to generate the code necessary to optimize and train the module on any backend.
// Not necessary when using the burn crate directly.
use burn_core as burn;

use burn::{
    nn,
    module::Module,
    tensor::Tensor,
    tensor::backend::Backend,
};

#[derive(Module, Debug)]
struct MyModule<B: Backend> {
    my_param: nn::Linear<B>,
    my_other_field: usize,
}
Required Associated Types
type Record: Record<B>
Type used to save and load the module state.
Required Methods
fn collect_devices(
    &self,
    devices: Vec<<B as Backend>::Device>
) -> Vec<<B as Backend>::Device>
Add every device found in the underlying module tree to the given vector, skipping duplicates, and return the vector.
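The accumulate-without-duplicates contract of collect_devices, and how devices can start it from an empty vector, can be sketched in plain Rust. This is a std-only illustration under stated assumptions: ToyDevice, ToyParam and ToyModule are hypothetical stand-ins for the backend device type and a derived module, not part of Burn.

```rust
/// Hypothetical device identifier standing in for `<B as Backend>::Device`.
#[derive(Clone, Debug, PartialEq)]
pub enum ToyDevice {
    Cpu,
    Gpu(usize),
}

/// Hypothetical parameter: only the device it lives on matters here.
pub struct ToyParam {
    pub device: ToyDevice,
}

/// Hypothetical module tree: leaf parameters plus nested sub-modules.
pub struct ToyModule {
    pub params: Vec<ToyParam>,
    pub children: Vec<ToyModule>,
}

impl ToyModule {
    /// Append every device found in this tree to `devices`, skipping any
    /// device already present, then return the vector.
    pub fn collect_devices(&self, mut devices: Vec<ToyDevice>) -> Vec<ToyDevice> {
        for param in &self.params {
            if !devices.contains(&param.device) {
                devices.push(param.device.clone());
            }
        }
        // Thread the same vector through every sub-module.
        for child in &self.children {
            devices = child.collect_devices(devices);
        }
        devices
    }

    /// Same idea as the provided `devices` method: start from an empty vector.
    pub fn devices(&self) -> Vec<ToyDevice> {
        self.collect_devices(Vec::new())
    }
}
```

Only `PartialEq` is required on the device type in this sketch, which is why it uses a linear `contains` check instead of a hash set.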
fn to_device(self, device: &<B as Backend>::Device) -> Self
Move the module and all of its sub-modules to the given device.
Warnings
The device operations will be registered in the autodiff graph. Therefore, call backward only once, even if the same module lives on multiple devices. If you need to call backward multiple times, consider using fork instead.
fn visit<Visitor>(&self, visitor: &mut Visitor)
where
    Visitor: ModuleVisitor<B>,
Visit each tensor parameter in the module with a visitor.
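The visitor pattern behind visit can be illustrated with a small std-only model: the module walks its tree and hands each parameter, read-only, to a visitor that accumulates whatever it needs. ToyParam, ToyModule, ToyVisitor and IdCollector are hypothetical names for this sketch; the real ModuleVisitor trait has its own methods, which are not reproduced here.

```rust
/// Hypothetical parameter: an identifier and a flat buffer of values.
pub struct ToyParam {
    pub id: String,
    pub values: Vec<f32>,
}

/// Hypothetical module tree.
pub struct ToyModule {
    pub params: Vec<ToyParam>,
    pub children: Vec<ToyModule>,
}

/// Hypothetical visitor trait: called once per parameter with shared access.
pub trait ToyVisitor {
    fn visit_param(&mut self, param: &ToyParam);
}

impl ToyModule {
    /// Walk the tree depth-first and pass every parameter to the visitor.
    pub fn visit<V: ToyVisitor>(&self, visitor: &mut V) {
        for param in &self.params {
            visitor.visit_param(param);
        }
        for child in &self.children {
            child.visit(visitor);
        }
    }
}

/// Example visitor: records parameter ids in visit order.
#[derive(Default)]
pub struct IdCollector {
    pub ids: Vec<String>,
}

impl ToyVisitor for IdCollector {
    fn visit_param(&mut self, param: &ToyParam) {
        self.ids.push(param.id.clone());
    }
}
```

Because `visit` borrows the module immutably, a visitor can only observe parameters; changing them is the job of `map`.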
fn map<Mapper>(self, mapper: &mut Mapper) -> Self
where
    Mapper: ModuleMapper<B>,
Map each tensor parameter in the module with a mapper.
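Unlike visit, map consumes the module and rebuilds it, so a mapper can replace each parameter. A minimal std-only sketch, assuming hypothetical ToyParam, ToyModule, ToyMapper and Scale types (not Burn's ModuleMapper), looks like this:

```rust
/// Hypothetical parameter holding a flat buffer of values.
pub struct ToyParam {
    pub values: Vec<f32>,
}

/// Hypothetical module tree.
pub struct ToyModule {
    pub params: Vec<ToyParam>,
    pub children: Vec<ToyModule>,
}

/// Hypothetical mapper trait: takes a parameter by value, returns its replacement.
pub trait ToyMapper {
    fn map_param(&mut self, param: ToyParam) -> ToyParam;
}

impl ToyModule {
    /// Consume the module and rebuild it with every parameter passed through the mapper.
    pub fn map<M: ToyMapper>(self, mapper: &mut M) -> Self {
        let ToyModule { params, children } = self;
        let mut new_params = Vec::with_capacity(params.len());
        for param in params {
            new_params.push(mapper.map_param(param));
        }
        let mut new_children = Vec::with_capacity(children.len());
        for child in children {
            new_children.push(child.map(mapper));
        }
        ToyModule { params: new_params, children: new_children }
    }
}

/// Example mapper: multiplies every value by a constant factor.
pub struct Scale(pub f32);

impl ToyMapper for Scale {
    fn map_param(&mut self, mut param: ToyParam) -> ToyParam {
        for v in param.values.iter_mut() {
            *v *= self.0;
        }
        param
    }
}
```

Taking `self` by value mirrors the `map(self, ...) -> Self` signature above: the old module is moved into the new one, so no parameter buffers need to be cloned.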
fn load_record(self, record: Self::Record) -> Self
Load the module state from a record.
fn into_record(self) -> Self::Record
Convert the module into a record containing the state.
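The into_record / load_record pair is a round trip: one turns a module into plain state, the other pours that state into an existing module. The sketch below is a std-only illustration with a hand-written record; ToyLinear, ToyLinearRecord and the label field are hypothetical, and the choice to keep label from self (because it is not stored in the record) is an assumption of this sketch, not a description of what the derive attribute generates.

```rust
/// Hypothetical module: two parameter buffers plus a field not stored in the record.
#[derive(Debug, Clone, PartialEq)]
pub struct ToyLinear {
    pub weight: Vec<f32>,
    pub bias: Vec<f32>,
    pub label: String,
}

/// Hypothetical record: a plain-data snapshot of the module's state.
#[derive(Debug, Clone, PartialEq)]
pub struct ToyLinearRecord {
    pub weight: Vec<f32>,
    pub bias: Vec<f32>,
}

impl ToyLinear {
    /// Consume the module and return its state.
    pub fn into_record(self) -> ToyLinearRecord {
        ToyLinearRecord { weight: self.weight, bias: self.bias }
    }

    /// Consume the module and return a module whose state comes from `record`;
    /// fields outside the record (here `label`) are kept from `self`.
    pub fn load_record(self, record: ToyLinearRecord) -> Self {
        ToyLinear { weight: record.weight, bias: record.bias, label: self.label }
    }
}
```

A typical use is building a freshly initialized module and loading a previously saved record into it.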
Provided Methods
fn devices(&self) -> Vec<<B as Backend>::Device>
Return all the devices found in the underlying module tree, without duplicates.
fn no_grad(self) -> Self
Mark every tensor in the module tree as not requiring gradients.
fn num_params(&self) -> usize
Get the number of parameters the module has, including all of its sub-modules.
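As a sketch of the counting rule, assuming (as the description suggests but does not spell out) that "number of parameters" means the total number of scalar elements across every parameter tensor, the recursion looks like this. ToyParam and ToyModule are hypothetical std-only stand-ins that keep only tensor shapes.

```rust
/// Hypothetical parameter: only its shape is needed to count elements.
pub struct ToyParam {
    pub shape: Vec<usize>,
}

impl ToyParam {
    /// Number of scalar elements: the product of the dimensions.
    pub fn num_elements(&self) -> usize {
        self.shape.iter().product()
    }
}

/// Hypothetical module tree.
pub struct ToyModule {
    pub params: Vec<ToyParam>,
    pub children: Vec<ToyModule>,
}

impl ToyModule {
    /// Count this module's own elements, then add every sub-module's count.
    pub fn num_params(&self) -> usize {
        let own: usize = self.params.iter().map(ToyParam::num_elements).sum();
        own + self.children.iter().map(ToyModule::num_params).sum::<usize>()
    }
}
```

For a 3→2 linear layer (weight 3×2, bias 2) containing a 2→1 linear sub-layer (weight 2×1, bias 1), this gives 6 + 2 + 2 + 1 = 11.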
fn save_file<FR, PB>(
    self,
    file_path: PB,
    recorder: &FR
) -> Result<(), RecorderError>
where
    FR: FileRecorder<B>,
    PB: Into<PathBuf>,
Save the module to a file using the provided file recorder.
List of supported file recorders:
- default
- bincode
- bincode compressed with gzip
- json pretty
- json compressed with gzip
- named mpk
- named mpk compressed with gzip
Notes
The file extension is added automatically based on the file recorder provided, so you don’t have to specify it.
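The note about extensions can be pictured as each recorder owning its extension and appending it to the caller's path. This is a std-only sketch: ToyFileRecorder, ToyJsonGz, ToyNamedMpk, resolve_path and the extension strings are all hypothetical and are not taken from Burn's FileRecorder.

```rust
use std::path::PathBuf;

/// Hypothetical stand-in for a file recorder: each format owns its extension.
pub trait ToyFileRecorder {
    fn file_extension(&self) -> &'static str;
}

/// Hypothetical gzip-compressed JSON format.
pub struct ToyJsonGz;
/// Hypothetical named MessagePack format.
pub struct ToyNamedMpk;

impl ToyFileRecorder for ToyJsonGz {
    fn file_extension(&self) -> &'static str {
        "json.gz"
    }
}

impl ToyFileRecorder for ToyNamedMpk {
    fn file_extension(&self) -> &'static str {
        "mpk"
    }
}

/// Build the path the recorder would write to: the caller passes a path
/// without an extension and the recorder supplies its own.
pub fn resolve_path<R: ToyFileRecorder, P: Into<PathBuf>>(file_path: P, recorder: &R) -> PathBuf {
    let mut path: PathBuf = file_path.into();
    // `set_extension` replaces any existing extension on the file name.
    path.set_extension(recorder.file_extension());
    path
}
```

One consequence of using `PathBuf::set_extension` in this sketch is that a name such as `model.v2` would become `model.mpk`; a real recorder may handle dotted names differently.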
fn load_file<FR, PB>(
    self,
    file_path: PB,
    recorder: &FR,
    device: &<B as Backend>::Device
) -> Result<Self, RecorderError>
where
    FR: FileRecorder<B>,
    PB: Into<PathBuf>,
Load the module from a file using the provided file recorder.
The recorder should be the same as the one used to save the module, see save_file.
Notes
The file extension is added automatically based on the file recorder provided, so you don’t have to specify it.
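Why the recorder must match can be seen in a small std-only round trip: the same recorder picks both the extension and the encoding, so loading with a different one looks for the wrong file (or would misread it). Everything here is hypothetical: ToyRecorder, CsvRecorder, LinesRecorder and the free functions save_file / load_file are simplified stand-ins that work on a flat list of values and omit the device argument of the real method.

```rust
use std::fs;
use std::io;
use std::path::PathBuf;

/// Hypothetical recorder: owns a file extension plus an encode/decode pair.
pub trait ToyRecorder {
    fn extension(&self) -> &'static str;
    fn encode(&self, values: &[f32]) -> String;
    fn decode(&self, text: &str) -> io::Result<Vec<f32>>;
}

/// Parse one number, turning a parse failure into an `io::Error`.
fn parse_value(s: &str) -> io::Result<f32> {
    s.trim()
        .parse::<f32>()
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}

/// Hypothetical comma-separated format, written to `<path>.csv`.
pub struct CsvRecorder;

impl ToyRecorder for CsvRecorder {
    fn extension(&self) -> &'static str {
        "csv"
    }
    fn encode(&self, values: &[f32]) -> String {
        values.iter().map(|v| v.to_string()).collect::<Vec<_>>().join(",")
    }
    fn decode(&self, text: &str) -> io::Result<Vec<f32>> {
        if text.trim().is_empty() {
            return Ok(Vec::new());
        }
        text.split(',').map(parse_value).collect()
    }
}

/// Hypothetical one-value-per-line format, written to `<path>.txt`.
pub struct LinesRecorder;

impl ToyRecorder for LinesRecorder {
    fn extension(&self) -> &'static str {
        "txt"
    }
    fn encode(&self, values: &[f32]) -> String {
        values.iter().map(|v| v.to_string()).collect::<Vec<_>>().join("\n")
    }
    fn decode(&self, text: &str) -> io::Result<Vec<f32>> {
        text.lines().filter(|l| !l.trim().is_empty()).map(parse_value).collect()
    }
}

/// Append the recorder's extension, then write the encoded values.
pub fn save_file<R: ToyRecorder>(values: &[f32], file_path: impl Into<PathBuf>, recorder: &R) -> io::Result<()> {
    let mut path = file_path.into();
    path.set_extension(recorder.extension());
    fs::write(path, recorder.encode(values))
}

/// Append the same recorder's extension, then read and decode the file.
pub fn load_file<R: ToyRecorder>(file_path: impl Into<PathBuf>, recorder: &R) -> io::Result<Vec<f32>> {
    let mut path = file_path.into();
    path.set_extension(recorder.extension());
    let text = fs::read_to_string(path)?;
    recorder.decode(&text)
}
```

Saving with `CsvRecorder` and loading with `LinesRecorder` fails here because the loader looks for `<path>.txt`, which was never written.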