Module

The Module derive allows you to create your own neural network modules, similar to PyTorch. The derive only generates the methods needed for your type to act as a parameter container; it makes no assumptions about how the forward pass is declared.

use burn::module::Module;
use burn::nn::{Dropout, Gelu, Linear};
use burn::tensor::{backend::Backend, Tensor};

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: Gelu,
}

impl<B: Backend> PositionWiseFeedForward<B> {
    /// Normal method added to a struct.
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.linear_inner.forward(input);
        let x = self.gelu.forward(x);
        let x = self.dropout.forward(x);

        self.linear_outer.forward(x)
    }
}

Note that all fields declared in the struct must also implement the Module trait.
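To make concrete why every field has to implement the trait, the sketch below hand-writes, in plain Rust without Burn, the kind of field-by-field delegation a derive could generate. `ToyModule`, `ToyLinear`, `ToyActivation`, and `ToyFeedForward` are hypothetical stand-ins invented for this illustration; they are not Burn types, and the code Burn actually generates is more involved.

```rust
/// Hypothetical stand-in for a module trait; NOT Burn's `Module`.
pub trait ToyModule {
    /// Total number of scalar parameters held by this module.
    fn num_params(&self) -> usize;
}

/// Leaf module that owns its parameters directly.
pub struct ToyLinear {
    weight: Vec<f32>, // d_input * d_output values
    bias: Vec<f32>,   // d_output values
}

impl ToyLinear {
    pub fn new(d_input: usize, d_output: usize) -> Self {
        Self {
            weight: vec![0.0; d_input * d_output],
            bias: vec![0.0; d_output],
        }
    }
}

impl ToyModule for ToyLinear {
    fn num_params(&self) -> usize {
        self.weight.len() + self.bias.len()
    }
}

/// Parameter-free module, similar in spirit to an activation.
pub struct ToyActivation;

impl ToyModule for ToyActivation {
    fn num_params(&self) -> usize {
        0
    }
}

/// Composite module: it only forwards the call to each field,
/// so every field must implement the trait as well.
pub struct ToyFeedForward {
    inner: ToyLinear,
    activation: ToyActivation,
    outer: ToyLinear,
}

impl ToyFeedForward {
    pub fn new(d_model: usize, d_ff: usize) -> Self {
        Self {
            inner: ToyLinear::new(d_model, d_ff),
            activation: ToyActivation,
            outer: ToyLinear::new(d_ff, d_model),
        }
    }
}

impl ToyModule for ToyFeedForward {
    fn num_params(&self) -> usize {
        self.inner.num_params() + self.activation.num_params() + self.outer.num_params()
    }
}
```

If one of the fields did not implement `ToyModule`, the composite implementation would not compile, which is the same constraint the note above describes for derived modules.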

Tensor

If your module contains tensors directly, and not just other modules defined with the Module derive, the way you declare each tensor determines how it behaves:

  • Param<Tensor<B, D>>: If you want the tensor to be included as a parameter of your modules, you need to wrap the tensor in a Param struct. This will create an ID that will be used to identify this parameter. This is essential when performing module optimization and when saving states such as optimizer and module checkpoints. Note that a module's record only contains parameters.

  • Param<Tensor<B, D>>.set_require_grad(false): Use this if you want the tensor to be included as a parameter of your modules, and therefore saved with the module's weights, but not updated by the optimizer.

  • Tensor<B, D>: Use this if you want the tensor to act as a constant that can be recreated when instantiating a module. This can be useful when generating sinusoidal embeddings, for example.
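The three cases above can be modeled in a few lines of plain Rust. This is a toy sketch that does not use Burn: `ToyParam`, `ToyEmbedding`, and their methods are hypothetical names made up for illustration, and `Vec<f32>` stands in for a tensor. It only mirrors the behavior described in the list: the record keeps parameters but not constants, and a frozen parameter is saved but skipped by the update step.

```rust
/// Hypothetical stand-in for a parameter; NOT Burn's `Param`.
#[derive(Clone, Debug, PartialEq)]
pub struct ToyParam {
    pub id: u64,            // identifies the parameter in records and optimizer state
    pub values: Vec<f32>,   // stand-in for the tensor data
    pub require_grad: bool, // false = saved with the weights but never updated
}

pub struct ToyEmbedding {
    pub weight: ToyParam,    // trainable parameter
    pub scale: ToyParam,     // frozen parameter
    pub positions: Vec<f32>, // plain constant, rebuilt every time in `new`
}

impl ToyEmbedding {
    pub fn new(len: usize) -> Self {
        Self {
            weight: ToyParam { id: 0, values: vec![1.0; len], require_grad: true },
            scale: ToyParam { id: 1, values: vec![2.0], require_grad: false },
            positions: (0..len).map(|i| i as f32).collect(),
        }
    }

    /// The record holds parameters only; the constant is left out.
    pub fn into_record(self) -> Vec<ToyParam> {
        vec![self.weight, self.scale]
    }

    /// One plain gradient-descent step that skips frozen parameters.
    /// A single scalar `grad` is used for every value to keep the toy short.
    pub fn step(&mut self, lr: f32, grad: f32) {
        for param in [&mut self.weight, &mut self.scale] {
            if param.require_grad {
                for v in param.values.iter_mut() {
                    *v -= lr * grad;
                }
            }
        }
    }
}
```

After a call to `step`, only `weight` changes; `scale` still appears in the record returned by `into_record`, while `positions` does not and is simply recomputed by `new`.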

Methods

These methods are available for all modules.

| Burn API                              | PyTorch Equivalent                      |
|---------------------------------------|-----------------------------------------|
| module.devices()                      | N/A                                     |
| module.fork(device)                   | Similar to module.to(device).detach()   |
| module.to_device(device)              | module.to(device)                       |
| module.no_grad()                      | module.requires_grad_(False)            |
| module.num_params()                   | N/A                                     |
| module.visit(visitor)                 | N/A                                     |
| module.map(mapper)                    | N/A                                     |
| module.into_record()                  | Similar to state_dict                   |
| module.load_record(record)            | Similar to load_state_dict(state_dict)  |
| module.save_file(file_path, recorder) | N/A                                     |
| module.load_file(file_path, recorder) | N/A                                     |

Similar to the backend trait, there is also the AutodiffModule trait to signify a module with autodiff support.

| Burn API       | PyTorch Equivalent |
|----------------|--------------------|
| module.valid() | module.eval()      |

Visitor & Mapper

As mentioned earlier, modules primarily function as parameter containers. Burn therefore offers several ways to apply a function to each parameter. This is distinct from PyTorch, where extending module functionality is not as straightforward.

The map and visitor methods are quite similar but serve different purposes. Mapping is used for potentially mutable operations where each parameter of a module can be updated to a new value. In Burn, optimizers are essentially just sophisticated module mappers. Visitors, on the other hand, are used when you don't intend to modify the module but need to retrieve specific information from it, such as the number of parameters or a list of devices in use.

You can implement your own mapper or visitor by implementing these simple traits:

/// Module visitor trait.
pub trait ModuleVisitor<B: Backend> {
    /// Visit a tensor in the module.
    fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
}

/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
    /// Map a tensor in the module.
    fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
}
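The difference between the two traversals can be shown with a toy model in plain Rust. The sketch below does not use Burn: `ToyModel`, `ToyVisitor`, `ToyMapper`, `CountParams`, and `Scale` are hypothetical names, parameter IDs are plain `u64` values, and `Vec<f32>` stands in for a tensor. The visitor borrows each parameter to gather information, whereas the mapper takes ownership and returns the replacement value.

```rust
/// Hypothetical parameter container; NOT a Burn module.
pub struct ToyModel {
    pub params: Vec<(u64, Vec<f32>)>, // (parameter id, values)
}

/// Read-only traversal, mirroring the role of a visitor.
pub trait ToyVisitor {
    fn visit(&mut self, id: u64, values: &[f32]);
}

/// Consuming traversal that returns a new value for each parameter.
pub trait ToyMapper {
    fn map(&mut self, id: u64, values: Vec<f32>) -> Vec<f32>;
}

impl ToyModel {
    pub fn new(params: Vec<(u64, Vec<f32>)>) -> Self {
        Self { params }
    }

    /// Shows every parameter to the visitor without modifying the model.
    pub fn visit<V: ToyVisitor>(&self, visitor: &mut V) {
        for (id, values) in &self.params {
            visitor.visit(*id, values);
        }
    }

    /// Consumes the model and rebuilds it from the mapped parameters.
    pub fn map<M: ToyMapper>(self, mapper: &mut M) -> Self {
        let params = self
            .params
            .into_iter()
            .map(|(id, values)| (id, mapper.map(id, values)))
            .collect();
        Self { params }
    }
}

/// Visitor that counts scalar parameters.
#[derive(Default)]
pub struct CountParams {
    pub total: usize,
}

impl ToyVisitor for CountParams {
    fn visit(&mut self, _id: u64, values: &[f32]) {
        self.total += values.len();
    }
}

/// Mapper that scales every value, a crude stand-in for an optimizer update.
pub struct Scale {
    pub factor: f32,
}

impl ToyMapper for Scale {
    fn map(&mut self, _id: u64, values: Vec<f32>) -> Vec<f32> {
        values.into_iter().map(|v| v * self.factor).collect()
    }
}
```

In this toy, `model.visit(&mut counter)` leaves the model untouched and fills `counter.total`, while `model.map(&mut Scale { factor: 0.5 })` returns a new model with updated values. The real traits above follow the same split between borrowing and owning the tensor.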

Module Display

Burn provides a simple way to display the structure of a module and its configuration at a glance. You can print the module to see its structure, which is useful for debugging and tracking changes across different versions of a module. (See the print output of the Basic Workflow Model example.)

To customize the display of a module, annotate it with #[module(custom_display)] and implement the ModuleDisplay trait yourself. ModuleDisplay is otherwise implemented automatically for all modules with default settings; a custom implementation changes the display settings for the module and its children.

use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};

#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: Gelu,
}

impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
    /// Custom settings for the display of the module.
    /// If `None` is returned, the default settings will be used.
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            // Will show all attributes (default is false)
            .with_show_all_attributes(false)
            // Will show each attribute on a new line (default is true)
            .with_new_line_after_attribute(true)
            // Will show the number of parameters (default is true)
            .with_show_num_parameters(true)
            // Will indent by 2 spaces (default is 2)
            .with_indentation_size(2)
            // Will show the parameter ID (default is false)
            .with_show_param_id(false)
            // Convenience method to wrap settings in Some()
            .optional()
    }

    /// Custom content to be displayed.
    /// If `None` is returned, the default content will be used
    /// (all attributes of the module).
    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("linear_inner", &self.linear_inner)
            .add("linear_outer", &self.linear_outer)
            .add("anything", "anything_else")
            .optional()
    }
}

Built-in Modules

Burn comes with built-in modules that you can use to build your own modules.

General

| Burn API      | PyTorch Equivalent                        |
|---------------|-------------------------------------------|
| BatchNorm     | nn.BatchNorm1d, nn.BatchNorm2d etc.       |
| Dropout       | nn.Dropout                                |
| Embedding     | nn.Embedding                              |
| Gelu          | nn.GELU                                   |
| GroupNorm     | nn.GroupNorm                              |
| HardSigmoid   | nn.Hardsigmoid                            |
| InstanceNorm  | nn.InstanceNorm1d, nn.InstanceNorm2d etc. |
| LayerNorm     | nn.LayerNorm                              |
| LeakyRelu     | nn.LeakyReLU                              |
| Linear        | nn.Linear                                 |
| Prelu         | nn.PReLU                                  |
| Relu          | nn.ReLU                                   |
| RmsNorm       | No direct equivalent                      |
| SwiGlu        | No direct equivalent                      |
| Interpolate1d | No direct equivalent                      |
| Interpolate2d | No direct equivalent                      |

Convolutions

| Burn API        | PyTorch Equivalent |
|-----------------|--------------------|
| Conv1d          | nn.Conv1d          |
| Conv2d          | nn.Conv2d          |
| Conv3d          | nn.Conv3d          |
| ConvTranspose1d | nn.ConvTranspose1d |
| ConvTranspose2d | nn.ConvTranspose2d |
| ConvTranspose3d | nn.ConvTranspose3d |

Pooling

| Burn API          | PyTorch Equivalent   |
|-------------------|----------------------|
| AdaptiveAvgPool1d | nn.AdaptiveAvgPool1d |
| AdaptiveAvgPool2d | nn.AdaptiveAvgPool2d |
| AvgPool1d         | nn.AvgPool1d         |
| AvgPool2d         | nn.AvgPool2d         |
| MaxPool1d         | nn.MaxPool1d         |
| MaxPool2d         | nn.MaxPool2d         |

RNNs

| Burn API       | PyTorch Equivalent   |
|----------------|----------------------|
| Gru            | nn.GRU               |
| Lstm/BiLstm    | nn.LSTM              |
| GateController | No direct equivalent |

Transformer

| Burn API           | PyTorch Equivalent    |
|--------------------|-----------------------|
| MultiHeadAttention | nn.MultiheadAttention |
| TransformerDecoder | nn.TransformerDecoder |
| TransformerEncoder | nn.TransformerEncoder |
| PositionalEncoding | No direct equivalent  |
| RotaryEncoding     | No direct equivalent  |

Loss

| Burn API         | PyTorch Equivalent  |
|------------------|---------------------|
| CrossEntropyLoss | nn.CrossEntropyLoss |
| MseLoss          | nn.MSELoss          |
| HuberLoss        | nn.HuberLoss        |