pub trait Module<B>: Clone + Send + Debug
where
    B: Backend,
{
    type Record: Record<B>;
    // Required methods
    fn collect_devices(
        &self,
        devices: Vec<<B as Backend>::Device>,
    ) -> Vec<<B as Backend>::Device>;
    fn fork(self, device: &<B as Backend>::Device) -> Self;
    fn to_device(self, device: &<B as Backend>::Device) -> Self;
    fn visit<Visitor>(&self, visitor: &mut Visitor)
       where Visitor: ModuleVisitor<B>;
    fn map<Mapper>(self, mapper: &mut Mapper) -> Self
       where Mapper: ModuleMapper<B>;
    fn load_record(self, record: Self::Record) -> Self;
    fn into_record(self) -> Self::Record;

    // Provided methods
    fn devices(&self) -> Vec<<B as Backend>::Device> { ... }
    fn no_grad(self) -> Self { ... }
    fn num_params(&self) -> usize { ... }
    fn save_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
    ) -> Result<(), RecorderError>
       where FR: FileRecorder<B>,
             PB: Into<PathBuf> { ... }
    fn load_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
        device: &<B as Backend>::Device,
    ) -> Result<Self, RecorderError>
       where FR: FileRecorder<B>,
             PB: Into<PathBuf> { ... }
    fn quantize_weights<C>(self, quantizer: &mut Quantizer<C>) -> Self
       where C: Calibration { ... }
}
Trait for all neural network modules.
Modules should be created using the derive attribute. This will make your module trainable, savable and loadable via into_record and load_record.
§Example
A module should have a backend defined as a generic parameter B. This will be used by the derive attribute to generate the code necessary to optimize and train the module on any backend.
// Not necessary when using the burn crate directly.
use burn_core as burn;

use burn::{
    nn,
    module::Module,
    tensor::Tensor,
    tensor::backend::Backend,
};

#[derive(Module, Debug)]
struct MyModule<B: Backend> {
    my_param: nn::Linear<B>,
    my_other_field: usize,
}
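Required Associated Types§
type Record: Record<B>
The record type that holds the module's state: it is produced by into_record and consumed by load_record.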
Required Methods§
fn collect_devices(
&self,
devices: Vec<<B as Backend>::Device>,
) -> Vec<<B as Backend>::Device>
Return all the devices found in the module tree, appended to the given vector without duplicates.
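The "append without duplicates while recursing" contract can be illustrated with plain std types. This is a toy sketch, not the code generated by the derive attribute: `Device` and `Node` are hypothetical stand-ins for a backend device and a module tree.

```rust
/// Hypothetical stand-in for a backend device (not a burn type).
#[derive(Clone, Debug, PartialEq)]
pub enum Device {
    Cpu,
    Gpu(usize),
}

/// Hypothetical module node: devices of its own parameters plus child modules.
pub struct Node {
    pub param_devices: Vec<Device>,
    pub children: Vec<Node>,
}

impl Node {
    /// Appends every device in this subtree to `devices`, skipping ones already present.
    pub fn collect_devices(&self, mut devices: Vec<Device>) -> Vec<Device> {
        for device in &self.param_devices {
            if !devices.contains(device) {
                devices.push(device.clone());
            }
        }
        // Thread the same vector through every child so duplicates are caught tree-wide.
        for child in &self.children {
            devices = child.collect_devices(devices);
        }
        devices
    }
}
```

Passing the vector by value and returning it lets each level extend the caller's list without extra allocations; the linear `contains` check is fine because a module rarely spans more than a handful of devices.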
fn to_device(self, device: &<B as Backend>::Device) -> Self
Move the module and all of its sub-modules to the given device.
§Warnings
The operation supports autodiff: when autodiff is activated, the move is registered in the computation graph. This may not be what you want, because the output model is an intermediary model, meaning that you can’t optimize it with gradient descent. If you want to optimize the output network on the target device, use fork instead.
fn visit<Visitor>(&self, visitor: &mut Visitor)
where
    Visitor: ModuleVisitor<B>,
Visit each tensor parameter in the module with a visitor.
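The visitor pattern behind this method can be sketched without burn. In the std-only toy below, `Param`, `ParamVisitor`, `TwoLayer` and `Inspector` are hypothetical names, not burn's `ModuleVisitor` API, and parameters are plain `Vec<f32>` rather than tensors.

```rust
/// Hypothetical parameter: an id plus flat f32 values (stand-in for a tensor).
pub struct Param {
    pub id: usize,
    pub values: Vec<f32>,
}

/// Hypothetical visitor trait, loosely mirroring the idea of ModuleVisitor.
pub trait ParamVisitor {
    fn visit_param(&mut self, param: &Param);
}

/// Hypothetical two-layer module; `hidden_size` is plain configuration, not a parameter.
pub struct TwoLayer {
    pub first: Param,
    pub second: Param,
    pub hidden_size: usize,
}

impl TwoLayer {
    /// Read-only traversal: hands every parameter to the visitor in field order.
    pub fn visit<V: ParamVisitor>(&self, visitor: &mut V) {
        visitor.visit_param(&self.first);
        visitor.visit_param(&self.second);
    }
}

/// Example visitor that records parameter ids and the largest absolute value seen.
#[derive(Default)]
pub struct Inspector {
    pub ids: Vec<usize>,
    pub max_abs: f32,
}

impl ParamVisitor for Inspector {
    fn visit_param(&mut self, param: &Param) {
        self.ids.push(param.id);
        for v in &param.values {
            self.max_abs = self.max_abs.max(v.abs());
        }
    }
}
```

Because `visit` borrows the module (`&self`), a visitor can only observe parameters; changing them is the job of `map`.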
fn map<Mapper>(self, mapper: &mut Mapper) -> Self
where
    Mapper: ModuleMapper<B>,
Map each tensor parameter in the module with a mapper.
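A consuming counterpart to the visitor can be sketched the same way. `Param`, `ParamMapper`, `TwoLayer` and `Scale` are hypothetical std-only stand-ins, not burn's `ModuleMapper` API; the point is only that each parameter goes through the mapper and the module is rebuilt from the results.

```rust
/// Hypothetical parameter stand-in: flat f32 values.
pub struct Param {
    pub values: Vec<f32>,
}

/// Hypothetical mapper trait: takes a parameter by value and returns its replacement.
pub trait ParamMapper {
    fn map_param(&mut self, param: Param) -> Param;
}

/// Hypothetical module with two parameters.
pub struct TwoLayer {
    pub first: Param,
    pub second: Param,
}

impl TwoLayer {
    /// Consuming traversal: every parameter is passed through the mapper, then the module is rebuilt.
    pub fn map<M: ParamMapper>(self, mapper: &mut M) -> Self {
        TwoLayer {
            first: mapper.map_param(self.first),
            second: mapper.map_param(self.second),
        }
    }
}

/// Example mapper: multiplies every value by a constant factor.
pub struct Scale(pub f32);

impl ParamMapper for Scale {
    fn map_param(&mut self, mut param: Param) -> Param {
        for v in param.values.iter_mut() {
            *v *= self.0;
        }
        param
    }
}
```

Taking `self` by value, as the signature above does, lets the mapper move each parameter out and return a replacement without cloning it.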
fn load_record(self, record: Self::Record) -> Self
Load the module state from a record.
fn into_record(self) -> Self::Record
Convert the module into a record containing the state.
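The ownership flow of a record round trip is easy to see in a toy model. `Linear` and `LinearRecord` below are hypothetical std-only types; burn's derived records hold tensors and are generated for you, so this only illustrates how state leaves one module and enters a freshly built one.

```rust
/// Hypothetical plain-data record holding a module's state.
#[derive(Clone, Debug, PartialEq)]
pub struct LinearRecord {
    pub weight: Vec<f32>,
    pub bias: Vec<f32>,
}

/// Hypothetical module whose state can be exported to and restored from a record.
#[derive(Debug, PartialEq)]
pub struct Linear {
    pub weight: Vec<f32>,
    pub bias: Vec<f32>,
}

impl Linear {
    /// Freshly initialized module (zeros here, to keep the example deterministic).
    pub fn new(inputs: usize, outputs: usize) -> Self {
        Linear { weight: vec![0.0; inputs * outputs], bias: vec![0.0; outputs] }
    }

    /// Consumes the module and moves its state into a record.
    pub fn into_record(self) -> LinearRecord {
        LinearRecord { weight: self.weight, bias: self.bias }
    }

    /// Consumes the module and returns it with its state replaced by `record`.
    pub fn load_record(self, record: LinearRecord) -> Self {
        Linear { weight: record.weight, bias: record.bias }
    }
}
```

This sketch does not check that the record's shapes match the module it is loaded into.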
Provided Methods§
fn devices(&self) -> Vec<<B as Backend>::Device>
Return all the devices found in the module tree, without duplicates.
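A provided method like this can sit directly on top of collect_devices by seeding it with an empty vector. The sketch below assumes that shape for the default implementation and uses hypothetical std-only types (`Device`, `HasDevices`, `Single`), not burn's.

```rust
/// Hypothetical stand-in for a backend device.
#[derive(Clone, Debug, PartialEq)]
pub enum Device {
    Cpu,
    Gpu(usize),
}

/// Hypothetical trait mirroring the required/provided split described above.
pub trait HasDevices {
    /// Required: append the devices of this subtree to `devices`, without duplicates.
    fn collect_devices(&self, devices: Vec<Device>) -> Vec<Device>;

    /// Provided: start the collection from an empty vector.
    fn devices(&self) -> Vec<Device> {
        self.collect_devices(Vec::new())
    }
}

/// Hypothetical leaf module living on a single device.
pub struct Single(pub Device);

impl HasDevices for Single {
    fn collect_devices(&self, mut devices: Vec<Device>) -> Vec<Device> {
        if !devices.contains(&self.0) {
            devices.push(self.0.clone());
        }
        devices
    }
}
```

Implementors then only write the accumulating version; the convenience method comes for free.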
fn no_grad(self) -> Self
fn num_params(&self) -> usize
Get the number of parameters the module has, including all of its sub-modules.
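Assuming the count is the number of scalar elements summed over every parameter tensor (the usual way parameter counts are reported), it can be sketched recursively. `Node` and `param_shapes` are hypothetical stand-ins for a module tree: a 3 -> 4 linear layer (12 weights + 4 biases) with a 4 -> 2 sub-layer (8 + 2) gives 26 in total.

```rust
/// Hypothetical module tree: each node owns parameter shapes and child modules.
pub struct Node {
    pub param_shapes: Vec<Vec<usize>>,
    pub children: Vec<Node>,
}

/// Number of scalar elements in a tensor of the given shape (product of its dims).
fn num_elements(shape: &[usize]) -> usize {
    shape.iter().product()
}

impl Node {
    /// Counts the scalar parameters of this node and, recursively, of all sub-modules.
    pub fn num_params(&self) -> usize {
        let own: usize = self.param_shapes.iter().map(|s| num_elements(s)).sum();
        let nested: usize = self.children.iter().map(Node::num_params).sum();
        own + nested
    }
}
```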
fn save_file<FR, PB>(
self,
file_path: PB,
recorder: &FR,
) -> Result<(), RecorderError>
Save the module to a file using the provided file recorder.
List of supported file recorders:
- default
- bincode
- bincode compressed with gzip
- json pretty
- json compressed with gzip
- named mpk
- named mpk compressed with gzip
§Notes
The file extension is added automatically based on the file recorder provided, so you don’t have to specify it.
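The extension rule can be pictured as the recorder owning its extension and stamping it onto a bare path. This is a hypothetical std-only sketch: `Format`, `resolve_path` and the extension strings are illustrative assumptions, not burn's recorder types or their exact extensions.

```rust
use std::path::PathBuf;

/// Hypothetical recorder formats; the extension strings are illustrative only.
#[derive(Clone, Copy)]
pub enum Format {
    Bincode,
    BincodeGz,
    PrettyJson,
    NamedMpk,
}

impl Format {
    fn extension(self) -> &'static str {
        match self {
            Format::Bincode => "bin",
            Format::BincodeGz => "bin.gz",
            Format::PrettyJson => "json",
            Format::NamedMpk => "mpk",
        }
    }
}

/// The caller passes a path without extension; the format decides the extension.
pub fn resolve_path<P: Into<PathBuf>>(file_path: P, format: Format) -> PathBuf {
    let mut path: PathBuf = file_path.into();
    path.set_extension(format.extension());
    path
}
```

Note that `PathBuf::set_extension` replaces an existing extension, so in this sketch `"model.v1"` would become `"model.mpk"`.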
fn load_file<FR, PB>(
self,
file_path: PB,
recorder: &FR,
device: &<B as Backend>::Device,
) -> Result<Self, RecorderError>
Load the module from a file using the provided file recorder.
The recorder should be the same as the one used to save the module; see save_file.
§Notes
The file extension is added automatically based on the file recorder provided, so you don’t have to specify it.
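Why the recorder has to match can be shown with two toy encodings of the same weights: bytes written by one cannot be decoded by the other. Everything below (`Format`, `save`, `load`, the `TXT:`/`BIN:` tags) is hypothetical and std-only; real recorders fail in their own ways, and the tags here just make the mismatch explicit.

```rust
/// Hypothetical on-disk formats for a flat list of f32 weights (not burn recorders).
#[derive(Clone, Copy, Debug)]
pub enum Format {
    /// Comma-separated decimal text, tagged "TXT:".
    Text,
    /// Little-endian f32 bytes, tagged "BIN:".
    Binary,
}

/// Serializes the weights with the given format.
pub fn save(weights: &[f32], kind: Format) -> Vec<u8> {
    match kind {
        Format::Text => {
            let parts: Vec<String> = weights.iter().map(|w| w.to_string()).collect();
            format!("TXT:{}", parts.join(",")).into_bytes()
        }
        Format::Binary => {
            let mut out = b"BIN:".to_vec();
            for w in weights {
                out.extend_from_slice(&w.to_le_bytes());
            }
            out
        }
    }
}

/// Deserializes the weights, returning an error if the bytes were written in another format.
pub fn load(bytes: &[u8], kind: Format) -> Result<Vec<f32>, String> {
    match kind {
        Format::Text => {
            let text = std::str::from_utf8(bytes).map_err(|e| e.to_string())?;
            let body = text.strip_prefix("TXT:").ok_or("not a Text record")?;
            if body.is_empty() {
                return Ok(Vec::new());
            }
            body.split(',')
                .map(|s| s.parse::<f32>().map_err(|e| e.to_string()))
                .collect()
        }
        Format::Binary => {
            let body = bytes.strip_prefix(b"BIN:").ok_or("not a Binary record")?;
            if body.len() % 4 != 0 {
                return Err("truncated Binary record".to_string());
            }
            Ok(body
                .chunks_exact(4)
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect())
        }
    }
}
```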
fn quantize_weights<C>(self, quantizer: &mut Quantizer<C>) -> Self
where
    C: Calibration,
Quantize the weights of the module.
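For intuition only, here is a textbook symmetric per-tensor int8 scheme with abs-max calibration, in std-only Rust. It is a swapped-in generic technique: `QParams`, `calibrate`, `quantize` and `dequantize` are hypothetical, and burn's `Quantizer` and `Calibration` types are not modeled here and may use different schemes.

```rust
/// Hypothetical symmetric per-tensor quantization parameters.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QParams {
    pub scale: f32,
}

/// Abs-max calibration: choose the scale so the largest |w| maps to 127.
pub fn calibrate(weights: &[f32]) -> QParams {
    let max_abs = weights.iter().fold(0.0f32, |acc, w| acc.max(w.abs()));
    // Avoid a zero scale for an all-zero tensor.
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    QParams { scale }
}

/// Divide by the scale, round to nearest, and clamp to the symmetric range [-127, 127].
pub fn quantize(weights: &[f32], q: QParams) -> Vec<i8> {
    weights
        .iter()
        .map(|w| (w / q.scale).round().clamp(-127.0, 127.0) as i8)
        .collect()
}

/// Map the int8 codes back to approximate f32 weights.
pub fn dequantize(values: &[i8], q: QParams) -> Vec<f32> {
    values.iter().map(|&v| v as f32 * q.scale).collect()
}
```

Clamping to -127 rather than -128 keeps the range symmetric around zero; the round-trip error per weight is bounded by half the scale.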