diff --git a/README.md b/README.md index 254411d..848fafa 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,10 @@ # YOLO Series TensorRT Python/C++ ## Support -[YOLOv8](https://v8docs.ultralytics.com/)、[YOLOv7](https://github.com/WongKinYiu/yolov7)、[YOLOv6](https://github.com/meituan/YOLOv6)、 [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)、 [YOLOV5](https://github.com/ultralytics/yolov5)、[YOLOv3](https://github.com/ultralytics/yolov3) +[YOLOv10](https://github.com/THU-MIG/yolov10)、[YOLOv9](https://github.com/WongKinYiu/yolov9)、[YOLOv8](https://v8docs.ultralytics.com/)、[YOLOv7](https://github.com/WongKinYiu/yolov7)、[YOLOv6](https://github.com/meituan/YOLOv6)、 [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX)、 [YOLOV5](https://github.com/ultralytics/yolov5)、[YOLOv3](https://github.com/ultralytics/yolov3) +- [x] YOLOv10 +- [x] YOLOv9 - [x] YOLOv8 - [x] YOLOv7 - [x] YOLOv6 @@ -10,7 +12,8 @@ - [x] YOLOv5 - [x] YOLOv3 -## Update +## Update +- 2024.6.16 Support YOLOv9, YOLOv10, changing the TensorRT version to 10.0 - 2023.8.15 Support cuda-python - 2023.5.12 Update - 2023.1.7 support YOLOv8 @@ -30,7 +33,37 @@ pip install cuda-python [By Docker](https://github.com/NVIDIA/TensorRT/blob/main/docker/ubuntu-20.04.Dockerfile) -## Try YOLOv8 +## YOLOv10 +### Generate TRT File +```shell +python export.py -o yolov10n.onnx -e yolov10.trt --end2end --v10 -p fp32 +``` +### Inference +```shell +python trt.py -e yolov10.trt -i src/1.jpg -o yolov10-1.jpg --end2end +``` + +## YOLOv9 +### Generate TRT File +```shell +python export.py -o yolov9-c.onnx -e yolov9.trt --end2end --v8 -p fp32 +``` +### Inference +```shell +python trt.py -e yolov9.trt -i src/1.jpg -o yolov9-1.jpg --end2end +``` + +## Python Demo +
Expand + +1. [YOLOv5](##YOLOv5) +2. [YOLOx](##YOLOX) +3. [YOLOv6](##YOLOV6) +4. [YOLOv7](##YOLOv7) +5. [YOLOv8](##YOLOv8) + +## YOLOv8 + ### Install && Download [Weights](https://github.com/ultralytics/assets/) ```shell pip install ultralytics @@ -53,14 +86,6 @@ python export.py -o yolov8n.onnx -e yolov8n.trt --end2end --v8 --fp32 python trt.py -e yolov8n.trt -i src/1.jpg -o yolov8n-1.jpg --end2end ``` -## Python Demo -
Expand - -1. [YOLOv5](##YOLOv5) -2. [YOLOx](##YOLOX) -3. [YOLOv6](##YOLOV6) -4. [YOLOv7](##YOLOv7) - ## YOLOv5 diff --git a/export.py b/export.py index 83d398b..fb61206 100644 --- a/export.py +++ b/export.py @@ -121,6 +121,7 @@ def create_network(self, onnx_path, end2end, conf_thres, iou_thres, max_det, **k :param onnx_path: The path to the ONNX graph to load. """ v8 = kwargs['v8'] + v10 = kwargs['v10'] network_flags = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) self.network = self.builder.create_network(network_flags) @@ -133,10 +134,8 @@ def create_network(self, onnx_path, end2end, conf_thres, iou_thres, max_det, **k for error in range(self.parser.num_errors): print(self.parser.get_error(error)) sys.exit(1) - inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)] outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)] - print("Network Description") for input in inputs: self.batch_size = input.shape[0] @@ -146,10 +145,78 @@ def create_network(self, onnx_path, end2end, conf_thres, iou_thres, max_det, **k assert self.batch_size > 0 # self.builder.max_batch_size = self.batch_size # This no effect for networks created with explicit batch dimension mode. Also DEPRECATED. - if end2end: - previous_output = self.network.get_output(0) - self.network.unmark_output(previous_output) - if not v8: + if v10: + try: + for previous_output in outputs: + self.network.unmark_output(previous_output) + except: + previous_output = self.network.get_output(0) + self.network.unmark_output(previous_output) + # output [1, 300, 6] + # 添加 TopK 层,在第二个维度上找到前 100 个最大值 [1, 100, 6] + strides = trt.Dims([1,1,1]) + starts = trt.Dims([0,0,0]) + bs, num_boxes, temp = previous_output.shape + shapes = trt.Dims([bs, num_boxes, 4]) + boxes = self.network.add_slice(previous_output, starts, shapes, strides) + starts[2] = 4 + shapes[2] = 1 + # [0, 0, 4] [1, 300, 1] [1, 1, 1] + obj_score = self.network.add_slice(previous_output, starts, shapes, strides) + starts[2] = 5 + # [0, 0, 5] [1, 300, 1] [1, 1, 1] + cls = self.network.add_slice(previous_output, starts, shapes, strides) + outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)] + print("YOLOv10 Modify") + def squeeze(previous_output): + reshape_dims = (bs, 300) + previous_output = self.network.add_shuffle(previous_output.get_output(0)) + previous_output.reshape_dims = reshape_dims + return previous_output + + # 定义常量值和形状 + constant_value = 300.0 + constant_shape = (300,) + constant_data = np.full(constant_shape, constant_value, dtype=np.float32) + num = self.network.add_constant(constant_shape, trt.Weights(constant_data)) + num.get_output(0).name = "num" + self.network.mark_output(num.get_output(0)) + boxes.get_output(0).name = "boxes" + self.network.mark_output(boxes.get_output(0)) + obj_score= squeeze(obj_score) + obj_score.get_output(0).name = "scores" + self.network.mark_output(obj_score.get_output(0)) + cls = squeeze(cls) + cls.get_output(0).name = "classes" + self.network.mark_output(cls.get_output(0)) + + for output in outputs: + print("Output '{}' with shape {} and dtype {}".format(output.name, output.shape, output.dtype)) + + if end2end and not v10: + try: + for previous_output in outputs: + self.network.unmark_output(previous_output) + except: + previous_output = self.network.get_output(0) + self.network.unmark_output(previous_output) + if v8: + # output [1, 84, 8400] + strides = trt.Dims([1,1,1]) + starts = trt.Dims([0,0,0]) + previous_output = self.network.add_shuffle(previous_output) + previous_output.second_transpose = (0, 2, 1) + # output [1, 8400, 84] + bs, num_boxes, temp = previous_output.get_output(0).shape + shapes = trt.Dims([bs, num_boxes, 4]) + # [0, 0, 0] [1, 8400, 4] [1, 1, 1] + boxes = self.network.add_slice(previous_output.get_output(0), starts, shapes, strides) + num_classes = temp -4 + starts[2] = 4 + shapes[2] = num_classes + # [0, 0, 4] [1, 8400, 80] [1, 1, 1] + scores = self.network.add_slice(previous_output.get_output(0), starts, shapes, strides) + else: # output [1, 8400, 85] # slice boxes, obj_score, class_scores strides = trt.Dims([1,1,1]) @@ -169,21 +236,6 @@ def create_network(self, onnx_path, end2end, conf_thres, iou_thres, max_det, **k scores = self.network.add_slice(previous_output, starts, shapes, strides) # scores = obj_score * class_scores => [bs, num_boxes, nc] scores = self.network.add_elementwise(obj_score.get_output(0), scores.get_output(0), trt.ElementWiseOperation.PROD) - else: - strides = trt.Dims([1,1,1]) - starts = trt.Dims([0,0,0]) - previous_output = self.network.add_shuffle(previous_output) - previous_output.second_transpose = (0, 2, 1) - print(previous_output.get_output(0).shape) - bs, num_boxes, temp = previous_output.get_output(0).shape - shapes = trt.Dims([bs, num_boxes, 4]) - # [0, 0, 0] [1, 8400, 4] [1, 1, 1] - boxes = self.network.add_slice(previous_output.get_output(0), starts, shapes, strides) - num_classes = temp -4 - starts[2] = 4 - shapes[2] = num_classes - # [0, 0, 4] [1, 8400, 80] [1, 1, 1] - scores = self.network.add_slice(previous_output.get_output(0), starts, shapes, strides) ''' "plugin_version": "1", "background_class": -1, # no background class @@ -204,7 +256,7 @@ def create_network(self, onnx_path, end2end, conf_thres, iou_thres, max_det, **k fc.append(trt.PluginField("iou_threshold", np.array([iou_thres], dtype=np.float32), trt.PluginFieldType.FLOAT32)) fc.append(trt.PluginField("box_coding", np.array([1], dtype=np.int32), trt.PluginFieldType.INT32)) fc.append(trt.PluginField("score_activation", np.array([0], dtype=np.int32), trt.PluginFieldType.INT32)) - + fc = trt.PluginFieldCollection(fc) nms_layer = creator.create_plugin("nms_layer", fc) @@ -236,7 +288,6 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No # TODO: Strict type is only needed If the per-layer precision overrides are used # If a better method is found to deal with that issue, this flag can be removed. - self.config.set_flag(trt.BuilderFlag.STRICT_TYPES) if precision == "fp16": if not self.builder.platform_has_fast_fp16: @@ -266,7 +317,7 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No def main(args): builder = EngineBuilder(args.verbose, args.workspace) - builder.create_network(args.onnx, args.end2end, args.conf_thres, args.iou_thres, args.max_det, v8=args.v8) + builder.create_network(args.onnx, args.end2end, args.conf_thres, args.iou_thres, args.max_det, v8=args.v8, v10=args.v10) builder.create_engine(args.engine, args.precision, args.calib_input, args.calib_cache, args.calib_num_images, args.calib_batch_size) @@ -295,7 +346,9 @@ def main(args): parser.add_argument("--max_det", default=100, type=int, help="The total num for results, default: 100") parser.add_argument("--v8", default=False, action="store_true", - help="use yolov8 model, default: False") + help="use yolov8/9 model, default: False") + parser.add_argument("--v10", default=False, action="store_true", + help="use yolov10 model, default: False") args = parser.parse_args() print(args) if not all([args.onnx, args.engine]): diff --git a/utils/utils.py b/utils/utils.py index 7cefc25..e508f5a 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -4,7 +4,7 @@ import cv2 import matplotlib.pyplot as plt -import common +from utils import common class BaseEngine(object): def __init__(self, engine_path): @@ -28,19 +28,19 @@ def __init__(self, engine_path): with open(engine_path, "rb") as f: serialized_engine = f.read() self.engine = runtime.deserialize_cuda_engine(serialized_engine) - self.imgsz = self.engine.get_binding_shape(0)[2:] # get the read shape of model, in case user input it wrong + self.imgsz = self.engine.get_tensor_shape(self.engine.get_tensor_name(0))[2:] # get the read shape of model, in case user input it wrong self.context = self.engine.create_execution_context() # Setup I/O bindings self.inputs = [] self.outputs = [] self.allocations = [] - for i in range(self.engine.num_bindings): + for i in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(i) + dtype = self.engine.get_tensor_dtype(name) + shape = self.engine.get_tensor_shape(name) is_input = False - if self.engine.binding_is_input(i): + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: is_input = True - name = self.engine.get_binding_name(i) - dtype = self.engine.get_binding_dtype(i) - shape = self.engine.get_binding_shape(i) if is_input: self.batch_size = shape[0] size = np.dtype(trt.nptype(dtype)).itemsize @@ -56,7 +56,7 @@ def __init__(self, engine_path): 'size': size } self.allocations.append(allocation) - if self.engine.binding_is_input(i): + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: self.inputs.append(binding) else: self.outputs.append(binding) @@ -90,7 +90,7 @@ def infer(self, img): self.context.execute_v2(self.allocations) for o in range(len(outputs)): - memcpy_device_to_host(outputs[o], self.outputs[o]['allocation']) + common.memcpy_device_to_host(outputs[o], self.outputs[o]['allocation']) return outputs def detect_video(self, video_path, conf=0.5, end2end=False): @@ -135,11 +135,17 @@ def detect_video(self, video_path, conf=0.5, end2end=False): def inference(self, img_path, conf=0.5, end2end=False): origin_img = cv2.imread(img_path) - img, ratio = preproc(origin_img, self.imgsz, self.mean, self.std) + # img, ratio = preproc(origin_img, self.imgsz, self.mean, self.std) + img, ratio, dwdh = letterbox(origin_img, self.imgsz) data = self.infer(img) if end2end: - num, final_boxes, final_scores, final_cls_inds = data + num, final_boxes, final_scores, final_cls_inds = data + # final_boxes, final_scores, final_cls_inds = data + dwdh = np.asarray(dwdh * 2, dtype=np.float32) + final_boxes -= dwdh final_boxes = np.reshape(final_boxes/ratio, (-1, 4)) + final_scores = np.reshape(final_scores, (-1, 1)) + final_cls_inds = np.reshape(final_cls_inds, (-1, 1)) dets = np.concatenate([np.array(final_boxes)[:int(num[0])], np.array(final_scores)[:int(num[0])], np.array(final_cls_inds)[:int(num[0])]], axis=-1) else: predictions = np.reshape(data, (1, -1, int(5+self.n_classes)))[0] @@ -258,6 +264,41 @@ def preproc(image, input_size, mean, std, swap=(2, 0, 1)): padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) return padded_img, r +def letterbox(im, + new_shape = (640, 640), + color = (114, 114, 114), + swap=(2, 0, 1)): + shape = im.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + # new_shape: [width, height] + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[1], new_shape[1] / shape[0]) + # Compute padding [width, height] + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[0] - new_unpad[0], new_shape[1] - new_unpad[ + 1] # wh padding + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, + top, + bottom, + left, + right, + cv2.BORDER_CONSTANT, + value=color) # add border + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = im.transpose(swap) + im = np.ascontiguousarray(im, dtype=np.float32) / 255. + return im, r, (dw, dh) + def rainbow_fill(size=50): # simpler way to generate rainbow color cmap = plt.get_cmap('jet')