Quantization

Quantization techniques perform computations and store tensors in lower precision data types, such as 8-bit integers, instead of floating point precision. There are multiple approaches to quantize a deep learning model, categorized as:

  • Post-training quantization (PTQ)
  • Quantization aware training (QAT)

In post-training quantization, the model is trained in floating point precision and later converted to the lower precision data type. There are two types of post-training quantization:

  1. Static quantization: quantizes the weights and activations of the model. Quantizing the activations statically requires calibration with representative data (i.e., recording the activation values ahead of time to compute the optimal quantization parameters).
  2. Dynamic quantization: quantizes the weights ahead of time (like static quantization), but the activations are quantized dynamically at runtime.

Sometimes post-training quantization is not able to achieve acceptable task accuracy. This is where quantization aware training comes into play, as it models the effects of quantization during training. Quantization errors are thus modeled in the forward and backward passes using fake quantization modules, which helps the model learn representations that are more robust to the reduction in precision.
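
As a minimal sketch of the idea behind fake quantization (plain Rust with an illustrative scale, not Burn's QAT implementation), a value is rounded to the quantized grid and immediately converted back to floating point, so the rest of the network sees the rounding error during training:

// Minimal sketch of symmetric 8-bit fake quantization with a fixed,
// illustrative scale; not Burn's QAT implementation.
fn fake_quantize(x: f32, scale: f32) -> f32 {
    // Quantize: map to the integer grid and clamp to the 8-bit symmetric range.
    let q = (x / scale).round().clamp(-127.0, 127.0);
    // Dequantize: return to floating point so the forward and backward passes
    // see the error introduced by quantization.
    q * scale
}

fn main() {
    let scale = 0.1;
    // 0.3456 snaps to the nearest grid point, which is approximately 0.3.
    println!("{}", fake_quantize(0.3456, scale));
}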

Quantization support in Burn is currently in active development.

It supports the following modes on some backends:

  • Per-tensor and per-block (linear) quantization to 8-bit, 4-bit and 2-bit representations

No integer operations are currently supported, which means tensors are dequantized to perform the operations in floating point precision.
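
As a conceptual sketch of what this means (plain Rust, not Burn's internal kernels), an element-wise operation on two quantized tensors amounts to dequantizing the 8-bit values with their scales and computing in f32:

// Conceptual sketch: element-wise addition of two symmetric 8-bit quantized
// tensors by dequantizing each operand first and computing in f32.
fn add_dequantized(a: &[i8], a_scale: f32, b: &[i8], b_scale: f32) -> Vec<f32> {
    a.iter()
        .zip(b.iter())
        .map(|(&qa, &qb)| (qa as f32) * a_scale + (qb as f32) * b_scale)
        .collect()
}

fn main() {
    let a = [12_i8, -3, 127];
    let b = [-5_i8, 8, 1];
    // Both operands are dequantized with their own scales before the add.
    println!("{:?}", add_dequantized(&a, 0.02, &b, 0.01));
}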

Module Quantization

Quantizing the weights of your model after training is quite simple. We have access to the weight tensors and can collect their statistics, such as the min and max values when using the MinMax calibration, to compute the quantization parameters.

use burn::module::Quantizer;
use burn::tensor::quantization::{Calibration, QuantLevel, QuantParam, QuantScheme, QuantValue};

// Quantization config
let scheme = QuantScheme::default()
    .with_level(QuantLevel::Block(32))
    .with_value(QuantValue::Q4F)
    .with_param(QuantParam::F16);
let mut quantizer = Quantizer {
    calibration: Calibration::MinMax,
    scheme,
};

// Quantize the weights
let model = model.quantize_weights(&mut quantizer);

Calibration

Calibration is the step during quantization where the range of all floating-point tensors is computed. This is pretty straightforward for weights since the actual range is known at quantization-time (weights are static), but activations require more attention.

To compute the quantization parameters, Burn supports the following Calibration methods.

  Method   Description
  MinMax   Computes the quantization range mapping based on the running min and max values.
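
For example, here is a minimal sketch of how a MinMax-style calibration could derive a symmetric 8-bit scale from the observed values (the [-127, 127] target range is an assumption for illustration, not Burn's exact implementation):

// Sketch of MinMax calibration for symmetric 8-bit quantization:
// the observed range is mapped onto [-127, 127] by a single scale.
fn minmax_symmetric_scale(values: &[f32]) -> f32 {
    let max_abs = values.iter().fold(0.0_f32, |acc, &v| acc.max(v.abs()));
    max_abs / 127.0
}

fn main() {
    let weights = [-0.8_f32, 0.1, 0.5, 1.2];
    // max |value| = 1.2, so scale = 1.2 / 127 ≈ 0.00945
    println!("scale = {}", minmax_symmetric_scale(&weights));
}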

Quantization Scheme

A quantization scheme defines how an input is quantized, including the representation of quantized values, storage format, granularity, and how the values are scaled.

use burn::tensor::quantization::{
    QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
};

let scheme = QuantScheme::default()
    .with_mode(QuantMode::Symmetric)         // Quantization mode
    .with_level(QuantLevel::block([2, 16]))  // Granularity (per-tensor or per-block)
    .with_value(QuantValue::Q8S)             // Data type of quantized values, independent of how they're stored
    .with_store(QuantStore::Native)          // Storage format for quantized values
    .with_param(QuantParam::F16);            // Precision for quantization parameters

Quantization Mode

  Mode        Description
  Symmetric   Values are scaled symmetrically around zero.

Quantization Level

  Level                          Description
  Tensor                         A single quantization parameter set for the entire tensor.
  Block(block_size: BlockSize)   Tensor divided into blocks (1D, 2D, or higher) defined by block_size, each with its own quantization params.
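
For example, a weight tensor of shape [8, 64] quantized with a [2, 16] block size is tiled into (8 / 2) × (64 / 16) = 16 blocks, each with its own scale, whereas per-tensor quantization would use a single scale for all 512 values.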

Quantization Value

  Value   Bits   Description
  Q8F     8      8-bit full-range quantization
  Q4F     4      4-bit full-range quantization
  Q2F     2      2-bit full-range quantization
  Q8S     8      8-bit symmetric quantization
  Q4S     4      4-bit symmetric quantization
  Q2S     2      2-bit symmetric quantization
  E5M2    8      8-bit floating-point (5 exponent, 2 mantissa)
  E4M3    8      8-bit floating-point (4 exponent, 3 mantissa)
  E2M1    4      4-bit floating-point (2 exponent, 1 mantissa)

Quantization Store

  Store    Description
  Native   Each quantized value stored directly in memory.
  U32      Multiple quantized values packed into a 32-bit integer.

Native storage is not supported for sub-byte quantization values.
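
To illustrate why packing is needed for sub-byte values, here is a minimal sketch of fitting eight 4-bit values into a single 32-bit word (the nibble ordering is an assumption for illustration, not Burn's actual layout):

// Pack eight 4-bit values (0..=15) into one u32, lowest nibble first.
// The nibble order is illustrative; the actual layout is backend-defined.
fn pack_u4(values: [u8; 8]) -> u32 {
    values
        .iter()
        .enumerate()
        .fold(0u32, |packed, (i, &v)| packed | (((v & 0x0F) as u32) << (4 * i)))
}

// Recover the i-th 4-bit value from the packed word.
fn unpack_u4(packed: u32, i: usize) -> u8 {
    ((packed >> (4 * i)) & 0x0F) as u8
}

fn main() {
    let packed = pack_u4([1, 2, 3, 4, 5, 6, 7, 8]);
    assert_eq!(unpack_u4(packed, 2), 3); // third value recovered
}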

Quantization Parameters Precision

  Param   Description
  F32     Full floating-point precision.
  F16     Half-precision floating point.
  BF16    Brain float 16-bit precision.