# PyTorch Model

## Introduction
Whether you've trained your model in PyTorch or you want to use a pre-trained model from PyTorch,
you can import it into Burn. Burn supports importing PyTorch model weights saved with the `.pt`
file extension. Compared to ONNX models, `.pt` files contain only the weights of the model, so you
will need to reconstruct the model architecture in Burn.

This document shows the full workflow of exporting a PyTorch model and importing it into Burn. You
can also refer to the Transitioning From PyTorch to Burn tutorial for importing a more complex
model.
## How to export a PyTorch model
If you have a PyTorch model that you want to import into Burn, you will need to export it first,
unless you are using a pre-trained published model. To export a PyTorch model, you can use the
`torch.save` function.

The following is an example of how to export a PyTorch model:
```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(2, 2, (2, 2))
        self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


if __name__ == "__main__":
    torch.manual_seed(42)  # To make it reproducible
    model = Net().to(torch.device("cpu"))
    model_weights = model.state_dict()
    torch.save(model_weights, "conv2d.pt")
```
Use [Netron](https://github.com/lutzroeder/netron) to view the exported model and confirm that it
contains the expected weight tensors (`conv1.weight`, `conv1.bias`, and `conv2.weight`).
## How to import a PyTorch model
1. Define the model in Burn:
   ```rust
   use burn::{
       nn::conv::{Conv2d, Conv2dConfig},
       prelude::*,
   };

   #[derive(Module, Debug)]
   pub struct Net<B: Backend> {
       conv1: Conv2d<B>,
       conv2: Conv2d<B>,
   }

   impl<B: Backend> Net<B> {
       /// Create a new model.
       pub fn init(device: &B::Device) -> Self {
           let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device);
           let conv2 = Conv2dConfig::new([2, 2], [2, 2])
               .with_bias(false)
               .init(device);
           Self { conv1, conv2 }
       }

       /// Forward pass of the model.
       pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
           let x = self.conv1.forward(x);
           self.conv2.forward(x)
       }
   }
   ```
2. Load the model weights from the exported PyTorch model (two options):

   a) Dynamically, but this requires `burn-import` as a runtime dependency:
   ```rust
   use crate::model;

   use burn::record::{FullPrecisionSettings, Recorder};
   use burn_import::pytorch::PyTorchFileRecorder;

   type Backend = burn_ndarray::NdArray<f32>;

   fn main() {
       let device = Default::default();
       let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
           .load("./conv2d.pt".into(), &device)
           .expect("Should decode state successfully");

       let model = model::Net::<Backend>::init(&device).load_record(record);
   }
   ```
   b) Pre-converted to Burn's binary format:
   ```rust
   // Convert the PyTorch model to Burn's binary format in
   // build.rs or in a separate executable. Then, include the generated file
   // in your project. See `examples/pytorch-import` for an example.

   use crate::model;

   use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
   use burn_import::pytorch::PyTorchFileRecorder;

   type Backend = burn_ndarray::NdArray<f32>;

   fn main() {
       let device = Default::default();
       let recorder = PyTorchFileRecorder::<FullPrecisionSettings>::default();
       let record: model::NetRecord<Backend> = recorder
           .load("./conv2d.pt".into(), &device)
           .expect("Should decode state successfully");

       // Save the model record to a file.
       let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
       recorder
           .record(record, "MY_FILE_OUTPUT_PATH".into())
           .expect("Failed to save model record");
   }

   /// Load the model from the file in your source code (not in build.rs or script).
   fn load_model() -> model::Net<Backend> {
       let device = Default::default();
       let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
           .load("./MY_FILE_OUTPUT_PATH".into(), &device)
           .expect("Should decode state successfully");

       model::Net::<Backend>::init(&device).load_record(record)
   }
   ```
## Extract Configuration
In some cases, models may require additional configuration settings, which are often included in a
`.pt` file during export. The `config_from_file` function from the `burn-import` cargo package
allows for the extraction of these configurations directly from the `.pt` file. The extracted
configuration can then be used to initialize the model in Burn. Here is an example of how to
extract the configuration from a `.pt` file:
```rust
use std::collections::HashMap;

use burn::config::Config;
use burn_import::pytorch::config_from_file;

#[derive(Debug, Config)]
struct NetConfig {
    n_head: usize,
    n_layer: usize,
    d_model: usize,
    // Candle's pickle has a bug with float serialization
    // https://github.com/huggingface/candle/issues/1729
    // some_float: f64,
    some_int: i32,
    some_bool: bool,
    some_str: String,
    some_list_int: Vec<i32>,
    some_list_str: Vec<String>,
    // Candle's pickle has a bug with float serialization
    // https://github.com/huggingface/candle/issues/1729
    // some_list_float: Vec<f64>,
    some_dict: HashMap<String, String>,
}

fn main() {
    let path = "weights_with_config.pt";
    let top_level_key = Some("my_config");
    let config: NetConfig = config_from_file(path, top_level_key).unwrap();
    println!("{:#?}", config);

    // After extracting, it's recommended to save it as a json file.
    config.save("my_config.json").unwrap();
}
```
## Troubleshooting

### Adjusting the source model architecture
If your target model differs structurally from the model you exported, `PyTorchFileRecorder`
allows you to change the attribute names and the order of the attributes. For example, if you
exported a model with the following structure:
```python
class ConvModule(nn.Module):
    def __init__(self):
        super(ConvModule, self).__init__()
        self.conv1 = nn.Conv2d(2, 2, (2, 2))
        self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = ConvModule()

    def forward(self, x):
        x = self.conv(x)
        return x
```
But you need to import it into a model with the following structure:
```rust
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
}
```
The exported file therefore stores the weights under keys prefixed with `conv.` (e.g.
`conv.conv1.weight`), which you can verify in Netron.
You can use the `PyTorchFileRecorder` to change the attribute names and the order of the
attributes by specifying a regular expression (see
[`regex::Regex::replace`](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace),
which you can also try online) to match the attribute name and a replacement string in `LoadArgs`:
```rust
let device = Default::default();
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
    // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
    .with_key_remap("conv\\.(.*)", "$1");

let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");

let model = Net::<Backend>::init(&device).load_record(record);
```
### Printing the source model keys and tensor information
If you are unsure about the keys in the source model, you can print them using the following code:
```rust
let device = Default::default();
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
    // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
    .with_key_remap("conv\\.(.*)", "$1")
    .with_debug_print(); // Print the keys and remapped keys

let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");

let model = Net::<Backend>::init(&device).load_record(record);
```
Here is an example of the output:
```text
Debug information of keys and tensor shapes:
---
Original Key: conv.conv1.bias
Remapped Key: conv1.bias
Shape: [2]
Dtype: F32
---
Original Key: conv.conv1.weight
Remapped Key: conv1.weight
Shape: [2, 2, 2, 2]
Dtype: F32
---
Original Key: conv.conv2.weight
Remapped Key: conv2.weight
Shape: [2, 2, 2, 2]
Dtype: F32
---
```
### Non-contiguous indices in the source model
Sometimes the indices of the source model are non-contiguous. For example, the source model has:
"model.ax.0.weight",
"model.ax.0.bias",
"model.ax.2.weight",
"model.ax.2.bias",
"model.ax.4.weight",
"model.ax.4.bias",
"model.ax.6.weight",
"model.ax.6.bias",
"model.ax.8.weight",
"model.ax.8.bias",
This may occur when the `model.ax` attribute (in the example above) uses `Sequential` to define
the layers and the skipped items do not have weight tensors, such as `ReLU` layers. PyTorch simply
skips layers without weight tensors, resulting in non-contiguous indices. In this case,
`PyTorchFileRecorder` automatically corrects the indices to be contiguous while preserving the
order of the weights, resulting in:
"model.ax.0.weight",
"model.ax.0.bias",
"model.ax.1.weight",
"model.ax.1.bias",
"model.ax.2.weight",
"model.ax.2.bias",
"model.ax.3.weight",
"model.ax.3.bias",
"model.ax.4.weight",
"model.ax.4.bias",
### Loading the model weights to a partial model
`PyTorchFileRecorder` enables selective weight loading into partial models. For instance, in a
model with both an encoder and a decoder, it's possible to load only the encoder weights. This is
done by defining only the encoder in Burn, allowing its weights to be loaded while the decoder's
are excluded.
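The following is a minimal sketch of this pattern. The struct names (`Encoder`, `PartialNet`), the
file name `full_model.pt`, and the layer dimensions are hypothetical; the source file is assumed
to contain both `encoder.*` and `decoder.*` keys:

```rust
use burn::{
    nn::{Linear, LinearConfig},
    prelude::*,
    record::{FullPrecisionSettings, Recorder},
};
use burn_import::pytorch::PyTorchFileRecorder;

type MyBackend = burn_ndarray::NdArray<f32>;

#[derive(Module, Debug)]
pub struct Encoder<B: Backend> {
    linear: Linear<B>,
}

// Only the encoder half of the source model is declared here;
// the decoder weights present in the file are simply not loaded.
#[derive(Module, Debug)]
pub struct PartialNet<B: Backend> {
    encoder: Encoder<B>,
}

impl<B: Backend> PartialNet<B> {
    pub fn init(device: &B::Device) -> Self {
        Self {
            encoder: Encoder {
                // Hypothetical dimensions; match them to your source model.
                linear: LinearConfig::new(8, 8).init(device),
            },
        }
    }
}

fn main() {
    let device = Default::default();
    // `full_model.pt` is assumed to contain keys such as
    // `encoder.linear.weight` and `decoder.linear.weight`.
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("full_model.pt".into(), &device)
        .expect("Should decode state successfully");

    // Only the `encoder.*` weights are loaded into the partial model.
    let model = PartialNet::<MyBackend>::init(&device).load_record(record);
}
```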
### Specifying the top-level key for state_dict
Sometimes the `state_dict` is nested under a top-level key along with other metadata, as in a
general checkpoint. For example, the `state_dict` of the whisper model is nested under the
`model_state_dict` key. In this case, you can specify the top-level key in `LoadArgs`:
```rust
let device = Default::default();
let load_args = LoadArgs::new("tiny.en.pt".into())
    .with_top_level_key("my_state_dict");

let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");
```
### Models containing enum modules
Burn supports models containing enum modules with new-type variants (a tuple with one item).
Importing weights for such models is automatically supported by the `PyTorchFileRecorder`. Note,
however, that since the source weights file does not contain the enum variant information, the
variant is picked based on the enum variant type. Consider the following example:
```rust
#[derive(Module, Debug)]
pub enum Conv<B: Backend> {
    DwsConv(DwsConv<B>),
    Conv(Conv2d<B>),
}

#[derive(Module, Debug)]
pub struct DwsConv<B: Backend> {
    dconv: Conv2d<B>,
    pconv: Conv2d<B>,
}

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv: Conv<B>,
}
```
If the source weights file contains weights for `DwsConv`, such as the following keys:
```text
---
Key: conv.dconv.bias
Shape: [2]
Dtype: F32
---
Key: conv.dconv.weight
Shape: [2, 1, 3, 3]
Dtype: F32
---
Key: conv.pconv.bias
Shape: [2]
Dtype: F32
---
Key: conv.pconv.weight
Shape: [2, 2, 1, 1]
Dtype: F32
```
the weights will be imported into the `DwsConv` variant of the `Conv` enum module.

If the variant types are identical, the first variant is picked. Generally, this won't be a
problem, since the variant types are usually different.