This repository contains the PyTorch code for sparse INT8 training and TensorRT inference.
For background, refer to the blog post titled "Sparsity in INT8: Training Workflow and Best Practices for NVIDIA TensorRT Acceleration".
TensorRT inference requires an NVIDIA Ampere GPU, which provides Sparse Tensor Core support.
Download the ImageNet 2012 dataset and format it according to the instructions in data/.
- PyTorch 1.11.0 (tested, may work with other versions)
- PyTorch Quantization toolkit: pytorch-quantization
- PyTorch Sparsity toolkit: APEX
- (Manual installation) TensorRT engine deployment: TensorRT
See docker.
- Create a Python virtual environment and install dependencies:
virtualenv -p /usr/bin/python3.8 venv38
source venv38/bin/activate
chmod +x && ./
- Download TensorRT and install its Python wheel:
pip install $TRT_PATH/python/tensorrt-8.6.1-cp38-none-linux_x86_64.whl
See each Python script for all supported flags.
Loads the pre-trained dense weights, sparsifies the model, and fine-tunes it.
python --model_name=resnet34 --data_dir=$DATA_DIR --batch_size=128 --eval_baseline --eval_sparse
This saves the best and final sparse checkpoints, along with their respective ONNX files. The best checkpoint is used for the QAT workflow, and the best ONNX file is used for the PTQ workflow.
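Sparse Tensor Cores accelerate 2:4 structured sparsity, where at most two of every four consecutive weights are non-zero. As a rough illustration of the sparsification step, the sketch below applies magnitude-based 2:4 pruning to a single weight row in plain Python. `prune_2_4` is a hypothetical helper, not the APEX implementation, which operates on PyTorch tensors.

```python
def prune_2_4(row):
    """Hypothetical helper: zero the two smallest-magnitude weights in every group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest-magnitude entries in this group of four.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        pruned.extend(w if i in keep else 0.0 for i, w in enumerate(group))
    return pruned


weights = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]
print(prune_2_4(weights))
# → [0.9, 0.0, 0.0, -0.7, 0.0, 0.3, -0.4, 0.0]
```

Keeping the largest magnitudes in each group is a simple heuristic for preserving the most influential weights; fine-tuning afterwards lets the remaining weights compensate for the pruned ones.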
There are two ways of quantizing the network: Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT).
Calibrates an ONNX model via the entropy or minmax calibrator.
python --onnx_path=model_sparse.onnx --onnx_input_name=input.1 --data_dir=$DATA_DIR \
--calibrator_type=entropy --calib_data_size=512
For comparison, generate the dense-PTQ version with the --is_dense_calibration flag.
This disables sparse weights when calibrating the dense model.
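As a rough illustration of what calibration computes, the sketch below derives a symmetric INT8 scale with the minmax rule (scale = amax / 127, where amax is the largest absolute value observed) and applies a quantize-dequantize round trip. The helper names and calibration values are hypothetical. The entropy calibrator instead chooses a clipping threshold that minimizes the KL divergence between the original and quantized activation distributions, which can make it less sensitive to rare outliers than minmax.

```python
def minmax_scale(calib_values):
    """Hypothetical helper: symmetric INT8 scale from the largest absolute value seen."""
    amax = max(abs(v) for v in calib_values)
    return amax / 127.0


def quantize_dequantize(x, scale):
    """Hypothetical helper: map x to an INT8 code in [-127, 127] and back to float."""
    # Python's round() rounds exact halves to even.
    q = max(-127, min(127, round(x / scale)))
    return q * scale


calib = [0.12, -0.83, 0.45, 1.27, -0.05]  # hypothetical activations from a calibration batch
scale = minmax_scale(calib)                # 1.27 / 127
for v in calib:
    print(v, quantize_dequantize(v, scale))
```

Values beyond amax saturate at the INT8 limits, so a single large outlier stretches the minmax scale and coarsens the resolution for every other value.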
Loads the fine-tuned sparse weights, adds QDQ (quantize/dequantize) nodes to the relevant layers, calibrates the model, and fine-tunes it.
python --model_name=resnet34 --data_dir=$DATA_DIR --batch_size=128 --eval_qat
For comparison, generate the dense-QAT version with the --is_dense_training flag.
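During QAT, the QDQ nodes quantize in the forward pass, and gradients are typically passed through the non-differentiable rounding with a straight-through estimator (STE). The toy loop below sketches that idea in plain Python with a single weight and a fixed amax. The function names, learning rate, and data are hypothetical; real QAT updates PyTorch tensors through the quantized modules of pytorch-quantization.

```python
def fake_quant(x, amax):
    """Hypothetical helper: symmetric INT8 quantize-dequantize with a fixed amax."""
    scale = amax / 127.0
    return max(-127, min(127, round(x / scale))) * scale


def fake_quant_grad(x, amax, grad_out):
    """STE: treat rounding as identity, but block gradients outside [-amax, amax]."""
    return grad_out if -amax <= x <= amax else 0.0


# Toy QAT: fit one weight w so that fake_quant(w) * inp approximates target.
amax, lr = 1.0, 0.05
w, inp, target = 0.3, 2.0, 0.9
for _ in range(20):
    wq = fake_quant(w, amax)          # forward pass sees the quantized weight
    out = wq * inp
    grad_out = 2.0 * (out - target)   # d(loss)/d(out) for squared error
    grad_wq = grad_out * inp          # d(loss)/d(wq)
    w -= lr * fake_quant_grad(w, amax, grad_wq)  # update the full-precision weight

print(w, fake_quant(w, amax) * inp)
```

The full-precision weight keeps accumulating small updates even when its quantized value does not change, which is what lets QAT settle on INT8 values that suit the task.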
For results, please refer to our blog post.