Shichen Peng
Published on 2023-07-13

Quantization Realization in PyTorch


The quantization API in PyTorch is still under active development and may not be stable. This tutorial may become outdated in the future.

Background

Two Modes to Quantize Network in PyTorch

PyTorch provides two modes for quantizing models: Eager Mode and FX Graph Mode.

  • Eager Mode - It was introduced into PyTorch earlier than FX Graph Mode. Users need to perform fusion and specify where quantization and dequantization happen manually, and it only supports modules, not functionals. Users may need to modify their model structure to adapt to this mode.

  • FX Graph Mode - It is a newer feature that works largely automatically: it traces the modules inside a network to find the optimum quantization parameters. It also allows users to adjust how a given model is quantized. This mode is recommended by the PyTorch team.

In this tutorial, quantization will be based on FX Graph Mode, since it is the recommended approach and does not require modifying the model, which makes loading the pre-trained parameters easier.

API Overview

In PyTorch, there are two packages for quantization: torch.quantization and torch.ao.quantization. According to a post by Jerry Zhang of the PyTorch quantization team, torch.quantization is going to be deprecated and users are encouraged to use torch.ao.quantization instead. Here, ao stands for Architecture Optimization. PyTorch offers methods to configure quantization settings, trace models to find suitable quantization parameters, and carry out the quantization itself.

How PyTorch Quantizes a Model

Quantizing a model in FX Graph Mode in PyTorch is easy; you only need to perform the following steps:

  1. Create an instance of the model you want to quantize.
  2. Load the pre-trained parameters in the original float32 data format.
  3. Make a deep copy of this model if necessary. (Later modifications will happen in place)
  4. Configure the quantization settings.
  5. Insert observers into the model.
  6. Feed sample inputs with statistically representative values into the model to calibrate the quantization parameters.
  7. Conduct quantization.

Quantization Demo

In the current stable PyTorch release, inference on quantized models is only supported on x86 and ARM CPUs. GPU quantization is based on TensorRT, which is still an experimental feature. In this tutorial, we will only use the CPU.

The specific steps are hard to describe in words alone, so to make them easier to follow, an example is demonstrated here. Only statements that newly appear in each step are commented, in order to keep the demo readable.

Post-Training Dynamic Quantization (PTDQ)

Not Necessary Now
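
Although this tutorial does not use dynamic quantization, a minimal eager-mode sketch is shown below for reference. It assumes a float32 model instance such as the DemoModel introduced in the next section, and quantizes only the nn.Linear layers:

import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# Weights of nn.Linear layers are converted to int8;
# activations are quantized on the fly at inference time
model_fp32 = DemoModel()
model_fp32.eval()
model_dynamic_quantized = quantize_dynamic(model_fp32, {nn.Linear}, dtype=torch.qint8)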

Post-Training Static Quantization (PTSQ)

Start Point

Assume we have already defined a model in PyTorch like this:

class DemoModel(nn.Module):
    # Some definitions here
    pass
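
The demos below only assume DemoModel is an ordinary float32 network. For concreteness, a hypothetical minimal definition could look like the following (your actual architecture and layer names will differ):

import torch
import torch.nn as nn

class DemoModel(nn.Module):
    # A hypothetical small CNN, used only to make the demo concrete
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.pool(self.relu(self.conv(x)))
        return self.fc(torch.flatten(x, 1))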

The common way to conduct inference looks like this:

import torch

# Prepare input data for inference
input_x = some_function_to_fetch_input()

# Create model instance and load pre-trained parameters
model_fp32 = DemoModel()
model_fp32.load_state_dict(torch.load(FILE_PATH, map_location='cpu'))
# Enable evaluation mode
model_fp32.eval()

# Conduct inference
output_y = model_fp32(input_x)

Now, let us begin quantizing this network.

Deep Copy the Model

Since the modifications will happen in place, it is better to copy the instance first.

You can skip this step if the original model will no longer be used.

Import the deepcopy method and duplicate the model:

# Import package for deep copy
from copy import deepcopy

import torch

input_x = some_function_to_fetch_input()

model_fp32 = DemoModel()
model_fp32.load_state_dict(torch.load(FILE_PATH, map_location='cpu'))
model_fp32.eval()

# Copy to another instance
model_to_quantize = deepcopy(model_fp32)
model_to_quantize.eval()

output_y = model_fp32(input_x)

Configure Quantization Settings

PyTorch allows developers to customize how it quantizes a model through an API called QConfigMapping, with which you can map different QConfigs to different modules. A QConfig describes the detailed methods used to quantize weights and activations, such as the data format, per-channel or per-tensor quantization, and so on.

For more information on how to customize your quantization, you can refer to the source code or the official documents.
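
For example, a QConfigMapping can apply one QConfig globally and override it for specific module types or module names. The following is a minimal sketch, assuming you want the default int8 settings everywhere except embeddings and a module named "head" (both of these names are only illustrative):

import torch.nn as nn
from torch.ao.quantization import QConfigMapping, get_default_qconfig

qconfig_mapping = (
    QConfigMapping()
    .set_global(get_default_qconfig("x86"))   # default QConfig for the whole model
    .set_object_type(nn.Embedding, None)      # skip quantizing all nn.Embedding modules
    .set_module_name("head", None)            # keep a module named "head" in float32
)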

If you do not have specific requirements, you can use the default configuration:

from copy import deepcopy

import torch
# Import method to get default quantization configuration
from torch.ao.quantization import get_default_qconfig_mapping

input_x = some_function_to_fetch_input()

model_fp32 = DemoModel()
model_fp32.load_state_dict(torch.load(FILE_PATH, map_location='cpu'))
model_fp32.eval()

model_to_quantize = deepcopy(model_fp32)
model_to_quantize.eval()

# Generate default quantization configuration
qconfig_mapping = get_default_qconfig_mapping("x86")

output_y = model_fp32(input_x)

Since quantization is an architecture-dependent optimization, get_default_qconfig_mapping() needs a string that tells PyTorch which back-end architecture your model will run on. There are several options (a sketch for checking and setting the matching engine follows this list):

  • "x86" - For x86 CPUs.
  • "fbgemm" - For x86 CPUs. (will be deprecated)
  • "qnnpack" - For ARM CPUs.

Insert Observers into Model

Based on the quantization algorithm, PyTorch needs to find appropriate parameters such as the zero point and scale to quantize tensors with maximum precision, which requires observing the internal values during inference. The following code shows how to insert the observers:

from copy import deepcopy

import torch
from torch.ao.quantization import get_default_qconfig_mapping
# Import quantize_fx package for inserting observers and later conversion
from torch.ao.quantization import quantize_fx

input_x = some_function_to_fetch_input()

# Prepare dataset with statistically referenced value and its dataloader
calibration_set = CalibrationDataset(DATASET_PATH)
calibration_loader = CalibrationDataLoader(dataset=calibration_set, batch_size=1)

model_fp32 = DemoModel()
model_fp32.load_state_dict(torch.load(FILE_PATH, map_location='cpu'))
model_fp32.eval()

model_to_quantize = deepcopy(model_fp32)
model_to_quantize.eval()

qconfig_mapping = get_default_qconfig_mapping("x86")

# Collect the calibration inputs into a tuple (assume data[0] is the input and data[1] is the label)
example_inputs = tuple(data[0] for data in calibration_loader)

# Insert the observers with configuration
model_prepared = quantize_fx.prepare_fx(model=model_to_quantize, 
										qconfig_mapping=qconfig_mapping, 
										example_inputs=example_inputs)

output_y = model_fp32(input_x)

As of PyTorch 2.0.1, on which this tutorial is based, the prepare_fx method does not make any use of its example_inputs parameter, as confirmed by developer comments in the PyTorch source code. The calibration step still needs to be done manually. However, to keep your code forward compatible, please pass example inputs anyway.

How the observers decide the value range can also be configured, for example with MinMaxObserver, MovingAverageMinMaxObserver, or HistogramObserver.
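
A minimal sketch of plugging different observers into a custom QConfig is shown below, assuming you want histogram-based activation ranges and per-channel symmetric weight observers (this would replace the default qconfig_mapping used elsewhere in the demo):

import torch
from torch.ao.quantization import QConfig, QConfigMapping
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver

custom_qconfig = QConfig(
    activation=HistogramObserver.with_args(reduce_range=True),
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric
    ),
)
qconfig_mapping = QConfigMapping().set_global(custom_qconfig)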

Calibrate Model using Sample Inputs

This step is done automatically by PyTorch. All you need to do is feed inputs into the prepared network; PyTorch will analyze the distribution of the internal data and find the optimum parameters.
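
As a rough illustration of what the observers compute, the following is a simplified sketch of asymmetric per-tensor quantization (not the exact PyTorch implementation): the observed range [x_min, x_max] is mapped onto the integer range [qmin, qmax] via a scale and a zero point.

import torch

def affine_quant_params(x_min, x_max, qmin=0, qmax=255):
    # The representable range must include zero so that padding/ReLU stay exact
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = max((x_max - x_min) / (qmax - qmin), 1e-8)
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=0, qmax=255):
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale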

from copy import deepcopy

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization import quantize_fx

input_x = some_function_to_fetch_input()

calibration_set = CalibrationDataset(DATASET_PATH)
calibration_loader = CalibrationDataLoader(dataset=calibration_set, batch_size=1)

model_fp32 = DemoModel()
model_fp32.load_state_dict(torch.load(FILE_PATH, map_location='cpu'))
model_fp32.eval()

model_to_quantize = deepcopy(model_fp32)
model_to_quantize.eval()

qconfig_mapping = get_default_qconfig_mapping("x86")

example_inputs = tuple(data[0] for data in calibration_loader)

model_prepared = quantize_fx.prepare_fx(model=model_to_quantize, 
										qconfig_mapping=qconfig_mapping, 
										example_inputs=example_inputs)

# Feed the calibration data into the prepared model to calibrate the quantization parameters
for data in calibration_loader:
    calibration_data, _ = data
    model_prepared(calibration_data)

output_y = model_fp32(input_x)

Convert Model to Quantized Version

Now it is time to get the quantized model. It can be done by a single statement:

from copy import deepcopy

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization import quantize_fx

input_x = some_function_to_fetch_input()

calibration_set = CalibrationDataset(DATASET_PATH)
calibration_loader = CalibrationDataLoader(dataset=calibration_set, batch_size=1)

model_fp32 = DemoModel()
model_fp32.load_state_dict(torch.load(FILE_PATH, map_location='cpu'))
model_fp32.eval()

model_to_quantize = deepcopy(model_fp32)
model_to_quantize.eval()

qconfig_mapping = get_default_qconfig_mapping("x86")

example_inputs = tuple(data[0] for data in calibration_loader)

model_prepared = quantize_fx.prepare_fx(model=model_to_quantize, 
										qconfig_mapping=qconfig_mapping, 
										example_inputs=example_inputs)

for data in calibration_loader:
    calibration_data, _ = data
    model_prepared(calibration_data)

# Convert model to quantized version
model_quantized = quantize_fx.convert_fx(model_prepared)

output_y = model_fp32(input_x)
# Try to make an inference
output_y_quantized = model_quantized(input_x)

You have now finished all the steps; model_quantized can be used for int8 inference on a supported CPU backend.
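
To sanity-check and persist the result, you can compare the quantized output against the float32 output and export the model via TorchScript. The snippet below is only a sketch; the file name is arbitrary:

# Quantization error on this particular input (expect a small but non-zero difference)
max_abs_err = (output_y - output_y_quantized).abs().max().item()
print(f"max abs difference fp32 vs int8: {max_abs_err:.6f}")

# One common way to persist an FX-quantized model is to script it first
scripted = torch.jit.script(model_quantized)
torch.jit.save(scripted, "demo_model_int8.pt")

# Later: reload and run on a CPU that supports the chosen backend
reloaded = torch.jit.load("demo_model_int8.pt")
output_y_reloaded = reloaded(input_x)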

Quantization-Aware Training

Not Necessary Now
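
Although this tutorial does not cover it, quantization-aware training in FX Graph Mode follows a similar prepare/convert flow. The sketch below only indicates the overall shape; the training loop, optimizer, and hyperparameters are omitted:

from copy import deepcopy

from torch.ao.quantization import get_default_qat_qconfig_mapping, quantize_fx

qat_qconfig_mapping = get_default_qat_qconfig_mapping("x86")

model_to_train = deepcopy(model_fp32)
model_to_train.train()  # QAT inserts fake-quantization modules and needs training mode

model_qat = quantize_fx.prepare_qat_fx(model=model_to_train,
                                       qconfig_mapping=qat_qconfig_mapping,
                                       example_inputs=example_inputs)

# ... fine-tune model_qat on the training set for a few epochs ...

model_qat.eval()
model_qat_int8 = quantize_fx.convert_fx(model_qat)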
