Skip to content

Commit

Permalink
✨ update
Browse files Browse the repository at this point in the history
  • Loading branch information
Linaom1214 committed Jul 3, 2022
1 parent b9294ae commit 9525cbc
Show file tree
Hide file tree
Showing 12 changed files with 477 additions and 111 deletions.
18 changes: 15 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
# YOLOv6、 YOLOX、 YOLOV5、 TensorRT Python/C++ API
## Update 2022.7.3 support TRT int8 post-training quantization


## Prepare TRT Python

```
pip install --upgrade setuptools pip --user
pip install nvidia-pyindex
pip install --upgrade nvidia-tensorrt
pip install pycuda
```


Here is a Python Demo mybe help you quickly understand this repo [Link](https://aistudio.baidu.com/aistudio/projectdetail/4263301?contributionType=1&shared=1)
## YOLOv6 [C++, Python Support]
Expand All @@ -24,7 +36,7 @@ python deploy/ONNX/export_onnx.py --weights yolov6s.pt --img 640 --batch 1
### 转化为TensorRT Engine

```
python export_trt.py -m onnx-name -o trt-name
python export.py -o onnx-name -e trt-name -p fp32/16/int8
```
### 测试

Expand Down Expand Up @@ -73,7 +85,7 @@ python3 tools/export_onnx.py --output-name yolox_s.onnx -n yolox-s -c yolox_s.pt
```
### 转化为TensorRT Engine
```
python export_trt.py -m onnx-name -o trt-name
python export.py -o onnx-name -e trt-name -p fp32/16/int8
```
### 测试

Expand All @@ -98,7 +110,7 @@ python path/to/export.py --weights yolov5s.pt --include onnx
### 转化为TensorRT Engine

```
python export_trt.py -m onnx-name -o trt-name
python export.py -o onnx-name -e trt-name -p fp32/16/int8
```
### 测试

Expand Down
224 changes: 224 additions & 0 deletions export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import os
import sys
import logging
import argparse

import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

from image_batch import ImageBatcher

logging.basicConfig(level=logging.INFO)
logging.getLogger("EngineBuilder").setLevel(logging.INFO)
log = logging.getLogger("EngineBuilder")

class EngineCalibrator(trt.IInt8EntropyCalibrator2):
"""
Implements the INT8 Entropy Calibrator 2.
"""

def __init__(self, cache_file):
"""
:param cache_file: The location of the cache file.
"""
super().__init__()
self.cache_file = cache_file
self.image_batcher = None
self.batch_allocation = None
self.batch_generator = None

def set_image_batcher(self, image_batcher: ImageBatcher):
"""
Define the image batcher to use, if any. If using only the cache file, an image batcher doesn't need
to be defined.
:param image_batcher: The ImageBatcher object
"""
self.image_batcher = image_batcher
size = int(np.dtype(self.image_batcher.dtype).itemsize * np.prod(self.image_batcher.shape))
self.batch_allocation = cuda.mem_alloc(size)
self.batch_generator = self.image_batcher.get_batch()

def get_batch_size(self):
"""
Overrides from trt.IInt8EntropyCalibrator2.
Get the batch size to use for calibration.
:return: Batch size.
"""
if self.image_batcher:
return self.image_batcher.batch_size
return 1

def get_batch(self, names):
"""
Overrides from trt.IInt8EntropyCalibrator2.
Get the next batch to use for calibration, as a list of device memory pointers.
:param names: The names of the inputs, if useful to define the order of inputs.
:return: A list of int-casted memory pointers.
"""
if not self.image_batcher:
return None
try:
batch, _, _ = next(self.batch_generator)
log.info("Calibrating image {} / {}".format(self.image_batcher.image_index, self.image_batcher.num_images))
cuda.memcpy_htod(self.batch_allocation, np.ascontiguousarray(batch))
return [int(self.batch_allocation)]
except StopIteration:
log.info("Finished calibration batches")
return None

def read_calibration_cache(self):
"""
Overrides from trt.IInt8EntropyCalibrator2.
Read the calibration cache file stored on disk, if it exists.
:return: The contents of the cache file, if any.
"""
if os.path.exists(self.cache_file):
with open(self.cache_file, "rb") as f:
log.info("Using calibration cache file: {}".format(self.cache_file))
return f.read()

def write_calibration_cache(self, cache):
"""
Overrides from trt.IInt8EntropyCalibrator2.
Store the calibration cache to a file on disk.
:param cache: The contents of the calibration cache to store.
"""
with open(self.cache_file, "wb") as f:
log.info("Writing calibration cache data to: {}".format(self.cache_file))
f.write(cache)

class EngineBuilder:
"""
Parses an ONNX graph and builds a TensorRT engine from it.
"""
def __init__(self, verbose=False, workspace=8):
"""
:param verbose: If enabled, a higher verbosity level will be set on the TensorRT logger.
:param workspace: Max memory workspace to allow, in Gb.
"""
self.trt_logger = trt.Logger(trt.Logger.INFO)
if verbose:
self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE

trt.init_libnvinfer_plugins(self.trt_logger, namespace="")

self.builder = trt.Builder(self.trt_logger)
self.config = self.builder.create_builder_config()
self.config.max_workspace_size = workspace * (2 ** 30)

self.batch_size = None
self.network = None
self.parser = None

def create_network(self, onnx_path):
"""
Parse the ONNX graph and create the corresponding TensorRT network definition.
:param onnx_path: The path to the ONNX graph to load.
"""
network_flags = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

self.network = self.builder.create_network(network_flags)
self.parser = trt.OnnxParser(self.network, self.trt_logger)

onnx_path = os.path.realpath(onnx_path)
with open(onnx_path, "rb") as f:
if not self.parser.parse(f.read()):
print("Failed to load ONNX file: {}".format(onnx_path))
for error in range(self.parser.num_errors):
print(self.parser.get_error(error))
sys.exit(1)

inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]

print("Network Description")
for input in inputs:
self.batch_size = input.shape[0]
print("Input '{}' with shape {} and dtype {}".format(input.name, input.shape, input.dtype))
for output in outputs:
print("Output '{}' with shape {} and dtype {}".format(output.name, output.shape, output.dtype))
assert self.batch_size > 0
self.builder.max_batch_size = self.batch_size

def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=5000,
calib_batch_size=8):
"""
Build the TensorRT engine and serialize it to disk.
:param engine_path: The path where to serialize the engine to.
:param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
:param calib_input: The path to a directory holding the calibration images.
:param calib_cache: The path where to write the calibration cache to, or if it already exists, load it from.
:param calib_num_images: The maximum number of images to use for calibration.
:param calib_batch_size: The batch size to use for the calibration process.
"""
engine_path = os.path.realpath(engine_path)
engine_dir = os.path.dirname(engine_path)
os.makedirs(engine_dir, exist_ok=True)
print("Building {} Engine in {}".format(precision, engine_path))
inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]

# TODO: Strict type is only needed If the per-layer precision overrides are used
# If a better method is found to deal with that issue, this flag can be removed.
self.config.set_flag(trt.BuilderFlag.STRICT_TYPES)

if precision == "fp16":
if not self.builder.platform_has_fast_fp16:
print("FP16 is not supported natively on this platform/device")
else:
self.config.set_flag(trt.BuilderFlag.FP16)
elif precision == "int8":
if not self.builder.platform_has_fast_int8:
print("INT8 is not supported natively on this platform/device")
else:
if self.builder.platform_has_fast_fp16:
# Also enable fp16, as some layers may be even more efficient in fp16 than int8
self.config.set_flag(trt.BuilderFlag.FP16)
self.config.set_flag(trt.BuilderFlag.INT8)
self.config.int8_calibrator = EngineCalibrator(calib_cache)
if not os.path.exists(calib_cache):
calib_shape = [calib_batch_size] + list(inputs[0].shape[1:])
calib_dtype = trt.nptype(inputs[0].dtype)
self.config.int8_calibrator.set_image_batcher(
ImageBatcher(calib_input, calib_shape, calib_dtype, max_num_images=calib_num_images,
exact_batches=True))

with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f:
print("Serializing engine to file: {:}".format(engine_path))
f.write(engine.serialize())

def main(args):
builder = EngineBuilder(args.verbose, args.workspace)
builder.create_network(args.onnx)
builder.create_engine(args.engine, args.precision, args.calib_input, args.calib_cache, args.calib_num_images,
args.calib_batch_size)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--onnx", help="The input ONNX model file to load")
parser.add_argument("-e", "--engine", help="The output path for the TRT engine")
parser.add_argument("-p", "--precision", default="fp16", choices=["fp32", "fp16", "int8"],
help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'")
parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output")
parser.add_argument("-w", "--workspace", default=1, type=int, help="The max memory workspace size to allow in Gb, "
"default: 1")
parser.add_argument("--calib_input", help="The directory holding images to use for calibration")
parser.add_argument("--calib_cache", default="./calibration.cache",
help="The file path for INT8 calibration cache to use, default: ./calibration.cache")
parser.add_argument("--calib_num_images", default=5000, type=int,
help="The maximum number of images to use for calibration, default: 5000")
parser.add_argument("--calib_batch_size", default=8, type=int,
help="The batch size for the calibration process, default: 8")
args = parser.parse_args()
if not all([args.onnx, args.engine]):
parser.print_help()
log.error("These arguments are required: --onnx and --engine")
sys.exit(1)
if args.precision == "int8" and not (args.calib_input or os.path.exists(args.calib_cache)):
parser.print_help()
log.error("When building in int8 precision, --calib_input or an existing --calib_cache file is required")
sys.exit(1)
main(args)


56 changes: 0 additions & 56 deletions export_trt.py

This file was deleted.

Loading

0 comments on commit 9525cbc

Please sign in to comment.