Backend

We have effectively written most of the necessary code to train our model. However, we have not explicitly designated the backend to be used at any point. This will be defined in the main entrypoint of our program, namely the main function defined in src/main.rs.

mod data;
mod model;
mod training;

use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
    backend::{Autodiff, Wgpu},
    data::dataset::Dataset,
    optim::AdamConfig,
};

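fn main() {
    // Concrete backend: Wgpu with f32 floats and i32 integers.
    type MyBackend = Wgpu<f32, i32>;
    // Same backend wrapped in Autodiff so gradients can be computed during training.
    type MyAutodiffBackend = Autodiff<MyBackend>;

    // Default device for the Wgpu backend.
    let device = burn::backend::wgpu::WgpuDevice::default();
    // Directory where training artifacts are written.
    let artifact_dir = "/tmp/guide";
    crate::training::train::<MyAutodiffBackend>(
        artifact_dir,
        // 10 digit classes, hidden dimension of 512, default Adam optimizer.
        TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
        device.clone(),
    );
}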

In this code snippet, we use the Wgpu backend, which is compatible with any operating system and runs on the GPU. For other options, see the Burn README. As written here, the backend type takes the float type and the int type as generic arguments, and these are the element types used during training. The autodiff backend is simply the same backend wrapped in the Autodiff struct, which imparts differentiability to any backend.
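
To make the wrapping pattern concrete, here is a minimal, standard-library-only sketch of how a generic wrapper can decorate any backend while the training function stays generic. The names ToyBackend, ToyGpu, ToyGpuDevice, ToyAutodiff and this train function are hypothetical stand-ins invented for illustration. They are not Burn's actual traits or types, and the sketch performs no real differentiation.

```rust
use std::fmt::Debug;
use std::marker::PhantomData;

/// Toy stand-in for a backend trait (hypothetical, not Burn's real trait).
trait ToyBackend {
    type Device: Clone + Default + Debug;
    fn name() -> String;
}

/// Toy GPU backend, parameterized by a float type and an int type.
struct ToyGpu<F, I>(PhantomData<(F, I)>);

/// Device handle for the toy GPU backend.
#[derive(Clone, Default, Debug, PartialEq)]
struct ToyGpuDevice;

impl<F, I> ToyBackend for ToyGpu<F, I> {
    type Device = ToyGpuDevice;
    fn name() -> String {
        "toy-gpu".to_string()
    }
}

/// Wrapper that decorates any backend and reuses the inner backend's device.
struct ToyAutodiff<B: ToyBackend>(PhantomData<B>);

impl<B: ToyBackend> ToyBackend for ToyAutodiff<B> {
    type Device = B::Device;
    fn name() -> String {
        format!("autodiff<{}>", B::name())
    }
}

/// Generic over the backend: the caller chooses the concrete type.
fn train<B: ToyBackend>(device: B::Device) -> String {
    format!("training on {} with device {:?}", B::name(), device)
}

fn main() {
    type MyBackend = ToyGpu<f32, i32>;
    type MyAutodiffBackend = ToyAutodiff<MyBackend>;

    // The wrapper shares the inner backend's device type.
    let device = ToyGpuDevice::default();
    println!("{}", train::<MyAutodiffBackend>(device));
}
```

Because the backend is chosen only through the type aliases in main, swapping it out in this sketch means changing one line, which mirrors why the guide defers the backend choice to the entrypoint.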

We call the train function defined earlier with a directory for artifacts, the configuration of the model (the number of digit classes is 10 and the hidden dimension is 512), the optimizer configuration which in our case will be the default Adam configuration, and the device which can be obtained from the backend.

You can now train your freshly created model with the command:

cargo run --release

When you run your project with the command above, you should see the training progress in a basic CLI dashboard:

[Image: training progress shown in the CLI dashboard]