Skip to content

Commit

Permalink
🎉 support for end2end
Browse files Browse the repository at this point in the history
  • Loading branch information
Linaom1214 committed Aug 11, 2022
1 parent 08bb095 commit 245e339
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 18 deletions.
19 changes: 15 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
## [简体中文](README_CN.md)

## Support
YOLOv7、YOLOv6、 YOLOX、 YOLOV5、
YOLOv7、YOLOv6、 YOLOX、 YOLOV5

The C++ code for YOLOv7/YOLOv6 also can be used for YOLOx or YOLOv5

## Update
- 2022.8.11 nms plugin support ==> more simple
- 2022.7.8 support YOLOV7
- 2022.7.3 support TRT int8 post-training quantization

Expand Down Expand Up @@ -48,15 +51,23 @@ python models/export.py --weights ../yolov7.pt --grid

```
python export.py -o onnx-name -e trt-name -p fp32/16/int8
--end2end export the model include nms plugin
```
### Test

```
cd yolov7
python trt.py
```
tips!

### C++
if you use the end2end model please modift the code as such

`origin_img = pred.inference(img_path, conf=0.5, end2end=True)`

### C++ [Now don't support end2end model]

C++ [Demo](yolov7/cpp/README.md)

Expand Down Expand Up @@ -84,7 +95,7 @@ python deploy/ONNX/export_onnx.py --weights yolov6s.pt --img 640 --batch 1
### Convert to TensorRT Engine

```
python export.py -o onnx-name -e trt-name -p fp32/16/int8
python export.py -o onnx-name -e trt-name -p fp32/16/int8 --end2end
```
### Test

Expand All @@ -93,7 +104,7 @@ cd yolov6
python trt.py
```

### C++
### C++ [Now don't support end2end model]

C++ [Demo](yolov6/cpp/README.md)

Expand Down
1 change: 1 addition & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
YOLOv7、YOLOv6、 YOLOX、 YOLOV5、

## 更新
- 2022.8.11 端到端导出支持, 更简洁的端到端导出方法
- 2022.7.8 支持YOLOV7
- 2022.7.3 支持 TRT int8 post-training quantization

Expand Down
73 changes: 70 additions & 3 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(self, verbose=False, workspace=8):
self.network = None
self.parser = None

def create_network(self, onnx_path):
def create_network(self, onnx_path, end2end, conf_thres, iou_thres, max_det):
"""
Parse the ONNX graph and create the corresponding TensorRT network definition.
:param onnx_path: The path to the ONNX graph to load.
Expand Down Expand Up @@ -142,6 +142,61 @@ def create_network(self, onnx_path):
assert self.batch_size > 0
self.builder.max_batch_size = self.batch_size

if end2end:
previous_output = self.network.get_output(0)
self.network.unmark_output(previous_output)
# output [1, 8400, 85]
# slice boxes, obj_score, class_scores
strides = trt.Dims([1,1,1])
starts = trt.Dims([0,0,0])
bs, num_boxes, temp = previous_output.shape
shapes = trt.Dims([bs, num_boxes, 4])
# [0, 0, 0] [1, 8400, 4] [1, 1, 1]
boxes = self.network.add_slice(previous_output, starts, shapes, strides)
num_classes = temp -5
starts[2] = 4
shapes[2] = 1
# [0, 0, 4] [1, 8400, 1] [1, 1, 1]
obj_score = self.network.add_slice(previous_output, starts, shapes, strides)
starts[2] = 5
shapes[2] = num_classes
# [0, 0, 5] [1, 8400, 80] [1, 1, 1]
scores = self.network.add_slice(previous_output, starts, shapes, strides)
# scores = obj_score * class_scores => [bs, num_boxes, nc]
updated_scores = self.network.add_elementwise(obj_score.get_output(0), scores.get_output(0), trt.ElementWiseOperation.PROD)

'''
"plugin_version": "1",
"background_class": -1, # no background class
"max_output_boxes": detections_per_img,
"score_threshold": score_thresh,
"iou_threshold": nms_thresh,
"score_activation": False,
"box_coding": 1,
'''
registry = trt.get_plugin_registry()
assert(registry)
creator = registry.get_plugin_creator("EfficientNMS_TRT", "1")
assert(creator)
fc = []
fc.append(trt.PluginField("background_class", np.array([-1], dtype=np.int32), trt.PluginFieldType.INT32))
fc.append(trt.PluginField("max_output_boxes", np.array([max_det], dtype=np.int32), trt.PluginFieldType.INT32))
fc.append(trt.PluginField("score_threshold", np.array([conf_thres], dtype=np.float32), trt.PluginFieldType.FLOAT32))
fc.append(trt.PluginField("iou_threshold", np.array([iou_thres], dtype=np.float32), trt.PluginFieldType.FLOAT32))
fc.append(trt.PluginField("box_coding", np.array([1], dtype=np.int32), trt.PluginFieldType.INT32))

fc = trt.PluginFieldCollection(fc)
nms_layer = creator.create_plugin("nms_layer", fc)

layer = self.network.add_plugin_v2([boxes.get_output(0), updated_scores.get_output(0)], nms_layer)
layer.get_output(0).name = "num"
layer.get_output(1).name = "boxes"
layer.get_output(2).name = "scores"
layer.get_output(3).name = "classes"
for i in range(4):
self.network.mark_output(layer.get_output(i))


def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=5000,
calib_batch_size=8):
"""
Expand Down Expand Up @@ -176,7 +231,8 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No
# Also enable fp16, as some layers may be even more efficient in fp16 than int8
self.config.set_flag(trt.BuilderFlag.FP16)
self.config.set_flag(trt.BuilderFlag.INT8)
self.config.int8_calibrator = EngineCalibrator(calib_cache)
# self.config.int8_calibrator = EngineCalibrator(calib_cache)
self.config.int8_calibrator = SwinCalibrator(calib_cache)
if not os.path.exists(calib_cache):
calib_shape = [calib_batch_size] + list(inputs[0].shape[1:])
calib_dtype = trt.nptype(inputs[0].dtype)
Expand All @@ -190,7 +246,7 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No

def main(args):
builder = EngineBuilder(args.verbose, args.workspace)
builder.create_network(args.onnx)
builder.create_network(args.onnx, args.end2end, args.conf_thres, args.iou_thres, args.max_det)
builder.create_engine(args.engine, args.precision, args.calib_input, args.calib_cache, args.calib_num_images,
args.calib_batch_size)

Expand All @@ -210,7 +266,17 @@ def main(args):
help="The maximum number of images to use for calibration, default: 5000")
parser.add_argument("--calib_batch_size", default=8, type=int,
help="The batch size for the calibration process, default: 8")
parser.add_argument("--end2end", default=False, action="store_true",
help="export the engine include nms plugin, default: False")
parser.add_argument("--conf_thres", default=0.4, type=float,
help="The conf threshold for the nms, default: 0.4")
parser.add_argument("--iou_thres", default=0.5, type=float,
help="The iou threshold for the nms, default: 0.5")
parser.add_argument("--max_det", default=100, type=int,
help="The total num for results, default: 100")

args = parser.parse_args()
print(args)
if not all([args.onnx, args.engine]):
parser.print_help()
log.error("These arguments are required: --onnx and --engine")
Expand All @@ -219,6 +285,7 @@ def main(args):
parser.print_help()
log.error("When building in int8 precision, --calib_input or an existing --calib_cache file is required")
sys.exit(1)

main(args)


29 changes: 21 additions & 8 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self, engine_path, imgsz=(640,640)):

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
trt.init_libnvinfer_plugins(logger,'') # initialize TensorRT plugins
with open(engine_path, "rb") as f:
serialized_engine = f.read()
engine = runtime.deserialize_cuda_engine(serialized_engine)
Expand Down Expand Up @@ -59,33 +60,45 @@ def infer(self, img):
data = [out['host'] for out in self.outputs]
return data

def detect_video(self, video_path):
def detect_video(self, video_path, conf=0.5, end2end=False):
cap = cv2.VideoCapture(video_path)
while True:
ret, frame = cap.read()
if not ret:
break
blob, ratio = preproc(frame, self.imgsz, self.mean, self.std)
data = self.infer(blob)
predictions = np.reshape(data, (1, -1, int(5+self.n_classes)))[0]
dets = self.postprocess(predictions,ratio)
if end2end:
num, final_boxes, final_scores, final_cls_inds = data
final_boxes = np.reshape(final_boxes/ratio, (-1, 4))
dets = np.concatenate([final_boxes[:num[0]], np.array(final_scores)[:num[0]].reshape(-1, 1), np.array(final_cls_inds)[:num[0]].reshape(-1, 1)], axis=-1)
else:
predictions = np.reshape(data, (1, -1, int(5+self.n_classes)))[0]
dets = self.postprocess(predictions,ratio)

if dets is not None:
final_boxes, final_scores, final_cls_inds = dets[:,
:4], dets[:, 4], dets[:, 5]
frame = vis(frame, final_boxes, final_scores, final_cls_inds,
conf=0.5, class_names=self.class_names)
cv2.imshow('frame', frame)
conf=conf, class_names=self.class_names)
cv2.imshow('frame', frame)
if cv2.waitKey(25) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()

def inference(self, img_path, conf=0.5):
def inference(self, img_path, conf=0.5, end2end=False):
origin_img = cv2.imread(img_path)
img, ratio = preproc(origin_img, self.imgsz, self.mean, self.std)
data = self.infer(img)
predictions = np.reshape(data, (1, -1, int(5+self.n_classes)))[0]
dets = self.postprocess(predictions,ratio)
if end2end:
num, final_boxes, final_scores, final_cls_inds = data
final_boxes = np.reshape(final_boxes/ratio, (-1, 4))
dets = np.concatenate([final_boxes[:num[0]], np.array(final_scores)[:num[0]].reshape(-1, 1), np.array(final_cls_inds)[:num[0]].reshape(-1, 1)], axis=-1)
else:
predictions = np.reshape(data, (1, -1, int(5+self.n_classes)))[0]
dets = self.postprocess(predictions,ratio)

if dets is not None:
final_boxes, final_scores, final_cls_inds = dets[:,
:4], dets[:, 4], dets[:, 5]
Expand Down
6 changes: 3 additions & 3 deletions yolov6/trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def __init__(self, engine_path , imgsz=(640,640)):


if __name__ == '__main__':
pred = Predictor(engine_path='yolov6.trt')
pred = Predictor(engine_path='yolov6-new.trt')
img_path = '../src/3.jpg'
origin_img = pred.inference(img_path)
origin_img = pred.inference(img_path, conf=0.5, end2end=True)
cv2.imwrite("%s_yolov6.jpg" % os.path.splitext(
os.path.split(img_path)[-1])[0], origin_img)
pred.detect_video('../src/video1.mp4') # set 0 use a webcam
pred.detect_video('../src/video1.mp4', conf=0.5, end2end=False) # set 0 use a webcam
pred.get_fps()

0 comments on commit 245e339

Please sign in to comment.