Skip to content

Commit

Permalink
separate PPYOLOE architecture from YOLOv3 (PaddlePaddle#7634)
Browse files Browse the repository at this point in the history
* add ppyoloe architectures

* fix deploy ppyoloe arch

* add ppyoloe arch coments, test=document_fix
  • Loading branch information
nemonameless authored Jan 28, 2023
1 parent c61c68d commit b7a6bb6
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 16 deletions.
1 change: 1 addition & 0 deletions deploy/pptracking/python/det_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
# Global dictionary
SUPPORT_MODELS = {
'YOLO',
'PPYOLOE',
'PicoDet',
'JDE',
'FairMOT',
Expand Down
8 changes: 4 additions & 4 deletions deploy/python/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@

# Global dictionary
SUPPORT_MODELS = {
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet', 'PPLCNet', 'DETR',
'CenterTrack'
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet',
'PPLCNet', 'DETR', 'CenterTrack'
}

TUNED_TRT_DYNAMIC_MODELS = {'DETR'}
Expand Down
6 changes: 3 additions & 3 deletions deploy/serving/python/web_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@

# Global dictionary
SUPPORT_MODELS = {
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
}

GLOBAL_VAR = {}
Expand Down
10 changes: 5 additions & 5 deletions deploy/third_engine/demo_onnx_trt/trt_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
# Global dictionary
SUPPORT_MODELS = {
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
}


Expand Down Expand Up @@ -205,8 +205,8 @@ def create_trt_bindings(engine, context):
"is_input": True if engine.binding_is_input(name) else False
}
if engine.binding_is_input(name):
bindings[name]['cpu_data'] = np.random.randn(
*shape).astype(np.float32)
bindings[name]['cpu_data'] = np.random.randn(*shape).astype(
np.float32)
bindings[name]['cuda_ptr'] = cuda.mem_alloc(bindings[name][
'cpu_data'].nbytes)
else:
Expand Down
6 changes: 3 additions & 3 deletions deploy/third_engine/onnx/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@

# Global dictionary
SUPPORT_MODELS = {
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
}

parser = argparse.ArgumentParser(description=__doc__)
Expand Down
3 changes: 2 additions & 1 deletion ppdet/engine/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# Global dictionary
TRT_MIN_SUBGRAPH = {
'YOLO': 3,
'PPYOLOE': 3,
'SSD': 60,
'RCNN': 40,
'RetinaNet': 40,
Expand Down Expand Up @@ -193,7 +194,7 @@ def _dump_infer_config(config, path, image_shape, model):
arch_state = True
break

if infer_arch in ['YOLOX', 'YOLOF']:
if infer_arch in ['PPYOLOE', 'YOLOX', 'YOLOF']:
infer_cfg['arch'] = infer_arch
infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
arch_state = True
Expand Down
2 changes: 2 additions & 0 deletions ppdet/modeling/architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from . import faster_rcnn
from . import mask_rcnn
from . import yolo
from . import ppyoloe
from . import cascade_rcnn
from . import ssd
from . import fcos
Expand Down Expand Up @@ -44,6 +45,7 @@
from .faster_rcnn import *
from .mask_rcnn import *
from .yolo import *
from .ppyoloe import *
from .cascade_rcnn import *
from .ssd import *
from .fcos import *
Expand Down
99 changes: 99 additions & 0 deletions ppdet/modeling/architectures/ppyoloe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ppdet.core.workspace import register, create
from .meta_arch import BaseArch

__all__ = ['PPYOLOE']
# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture
# PP-YOLOE and PP-YOLOE+ can also use the same architecture of YOLOv3 in yolo.py


@register
class PPYOLOE(BaseArch):
__category__ = 'architecture'
__inject__ = ['post_process']

def __init__(self,
backbone='CSPResNet',
neck='CustomCSPPAN',
yolo_head='PPYOLOEHead',
post_process='BBoxPostProcess',
for_mot=False):
"""
PPYOLOE network, see https://arxiv.org/abs/2203.16250
Args:
backbone (nn.Layer): backbone instance
neck (nn.Layer): neck instance
yolo_head (nn.Layer): anchor_head instance
post_process (object): `BBoxPostProcess` instance
for_mot (bool): whether return other features for multi-object tracking
models, default False in pure object detection models.
"""
super(PPYOLOE, self).__init__()
self.backbone = backbone
self.neck = neck
self.yolo_head = yolo_head
self.post_process = post_process
self.for_mot = for_mot

@classmethod
def from_config(cls, cfg, *args, **kwargs):
# backbone
backbone = create(cfg['backbone'])

# fpn
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)

# head
kwargs = {'input_shape': neck.out_shape}
yolo_head = create(cfg['yolo_head'], **kwargs)

return {
'backbone': backbone,
'neck': neck,
"yolo_head": yolo_head,
}

def _forward(self):
body_feats = self.backbone(self.inputs)
neck_feats = self.neck(body_feats, self.for_mot)

if self.training:
yolo_losses = self.yolo_head(neck_feats, self.inputs)
return yolo_losses
else:
yolo_head_outs = self.yolo_head(neck_feats)
if self.post_process is not None:
bbox, bbox_num = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
else:
bbox, bbox_num = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
output = {'bbox': bbox, 'bbox_num': bbox_num}

return output

def get_loss(self):
return self._forward()

def get_pred(self):
return self._forward()
6 changes: 6 additions & 0 deletions ppdet/modeling/architectures/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from ..post_process import JDEBBoxPostProcess

__all__ = ['YOLOv3']
# YOLOv3,PP-YOLO,PP-YOLOv2,PP-YOLOE,PP-YOLOE+ use the same architecture as YOLOv3
# PP-YOLOE and PP-YOLOE+ are recommended to use PPYOLOE architecture in ppyoloe.py


@register
Expand Down Expand Up @@ -99,6 +101,7 @@ def _forward(self):
yolo_head_outs = self.yolo_head(neck_feats)

if self.for_mot:
# the detection part of JDE MOT model
boxes_idx, bbox, bbox_num, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors)
output = {
Expand All @@ -110,13 +113,16 @@ def _forward(self):
}
else:
if self.return_idx:
# the detection part of JDE MOT model
_, bbox, bbox_num, _ = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors)
elif self.post_process is not None:
# anchor based YOLOs: YOLOv3,PP-YOLO,PP-YOLOv2 use mask_anchors
bbox, bbox_num = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
else:
# anchor free YOLOs: PP-YOLOE, PP-YOLOE+
bbox, bbox_num = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
output = {'bbox': bbox, 'bbox_num': bbox_num}
Expand Down

0 comments on commit b7a6bb6

Please sign in to comment.