Skip to content

Commit

Permalink
Fix onnx unitest (open-mmlab#6369)
Browse files Browse the repository at this point in the history
Fix all unitests
  • Loading branch information
jshilong authored and ZwwWayne committed Oct 28, 2021
1 parent 82c4e77 commit cc5f8fb
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 97 deletions.
2 changes: 1 addition & 1 deletion mmdet/core/bbox/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def distance2bbox(points, distance, max_shape=None):
bboxes = torch.stack([x1, y1, x2, y2], -1)

if max_shape is not None:
if points.dim() == 2 and not torch.onnx.is_in_onnx_export():
if bboxes.dim() == 2 and not torch.onnx.is_in_onnx_export():
# speed up
bboxes[:, 0::2].clamp_(min=0, max=max_shape[1])
bboxes[:, 1::2].clamp_(min=0, max=max_shape[0])
Expand Down
82 changes: 8 additions & 74 deletions mmdet/models/dense_heads/rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,97 +235,31 @@ def _bbox_post_process(self, mlvl_scores, mlvl_bboxes, mlvl_valid_anchors,

return dets[:cfg.max_per_img]

# TODO: waiting for refactor the anchor_head and anchor_free head
def onnx_export(self, x, img_metas):
"""Test without augmentation.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
img_metas (list[dict]): Meta info of each image.
Returns:
tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
and class labels of shape [N, num_det].
Tensor: dets of shape [N, num_det, 5].
"""
cls_scores, bbox_preds = self(x)

assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)

device = cls_scores[0].device
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_anchors = self.prior_generator.grid_anchors(
featmap_sizes, device=device)

cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]

assert len(
img_metas
) == 1, 'Only support one input image while in exporting to ONNX'
img_shapes = img_metas[0]['img_shape_for_onnx']

cfg = copy.deepcopy(self.test_cfg)

mlvl_scores = []
mlvl_bbox_preds = []
mlvl_valid_anchors = []
batch_size = cls_scores[0].shape[0]
nms_pre_tensor = torch.tensor(
cfg.nms_pre, device=cls_scores[0].device, dtype=torch.long)
for idx in range(len(cls_scores)):
rpn_cls_score = cls_scores[idx]
rpn_bbox_pred = bbox_preds[idx]
assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
rpn_cls_score = rpn_cls_score.permute(0, 2, 3, 1)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.reshape(batch_size, -1)
scores = rpn_cls_score.sigmoid()
else:
rpn_cls_score = rpn_cls_score.reshape(batch_size, -1, 2)
# We set FG labels to [0, num_class-1] and BG label to
# num_class in RPN head since mmdet v2.5, which is unified to
# be consistent with other head since mmdet v2.0. In mmdet v2.0
# to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
scores = rpn_cls_score.softmax(-1)[..., 0]
rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).reshape(
batch_size, -1, 4)
anchors = mlvl_anchors[idx]
anchors = anchors.expand_as(rpn_bbox_pred)
# Get top-k prediction
from mmdet.core.export import get_k_for_topk
nms_pre = get_k_for_topk(nms_pre_tensor, rpn_bbox_pred.shape[1])
if nms_pre > 0:
_, topk_inds = scores.topk(nms_pre)
batch_inds = torch.arange(batch_size).view(
-1, 1).expand_as(topk_inds)
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
# Mind k<=3480 in TensorRT for TopK
transformed_inds = scores.shape[1] * batch_inds + topk_inds
scores = scores.reshape(-1, 1)[transformed_inds].reshape(
batch_size, -1)
rpn_bbox_pred = rpn_bbox_pred.reshape(
-1, 4)[transformed_inds, :].reshape(batch_size, -1, 4)
anchors = anchors.reshape(-1, 4)[transformed_inds, :].reshape(
batch_size, -1, 4)
mlvl_scores.append(scores)
mlvl_bbox_preds.append(rpn_bbox_pred)
mlvl_valid_anchors.append(anchors)

batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
batch_mlvl_rpn_bbox_pred = torch.cat(mlvl_bbox_preds, dim=1)
batch_mlvl_proposals = self.bbox_coder.decode(
batch_mlvl_anchors, batch_mlvl_rpn_bbox_pred, max_shape=img_shapes)

batch_bboxes, batch_scores = super(RPNHead, self).onnx_export(
cls_scores, bbox_preds, img_metas=img_metas, with_nms=False)
# Use ONNX::NonMaxSuppression in deployment
from mmdet.core.export import add_dummy_nms_for_onnx
batch_mlvl_scores = batch_mlvl_scores.unsqueeze(2)
cfg = copy.deepcopy(self.test_cfg)
score_threshold = cfg.nms.get('score_thr', 0.0)
nms_pre = cfg.get('deploy_nms_pre', -1)
dets, _ = add_dummy_nms_for_onnx(batch_mlvl_proposals,
batch_mlvl_scores, cfg.max_per_img,
# Different from the normal forward doing NMS level by level,
# we do NMS across all levels when exporting ONNX.
dets, _ = add_dummy_nms_for_onnx(batch_bboxes, batch_scores,
cfg.max_per_img,
cfg.nms.iou_threshold,
score_threshold, nms_pre,
cfg.max_per_img)
Expand Down
2 changes: 1 addition & 1 deletion mmdet/models/dense_heads/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def num_anchors(self):
"""
warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
'please use "num_base_priors" instead')
return self.prior_generator.num_base_priors[0]
return self.num_base_priors

@property
def num_levels(self):
Expand Down
5 changes: 3 additions & 2 deletions mmdet/models/detectors/single_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def aug_test(self, imgs, img_metas, rescale=False):
]
return bbox_results

def onnx_export(self, img, img_metas):
def onnx_export(self, img, img_metas, with_nms=True):
"""Test function without test time augmentation.
Args:
Expand All @@ -165,6 +165,7 @@ def onnx_export(self, img, img_metas):
# add dummy score_factor
outs = (*outs, None)
# TODO Can we change to `get_bboxes` when `onnx_export` fail
det_bboxes, det_labels = self.bbox_head.onnx_export(*outs, img_metas)
det_bboxes, det_labels = self.bbox_head.onnx_export(
*outs, img_metas, with_nms=with_nms)

return det_bboxes, det_labels
2 changes: 1 addition & 1 deletion tests/test_models/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ def test_rpn_forward():
@pytest.mark.parametrize(
'cfg_file',
[
'reppoints/reppoints_moment_r50_fpn_1x_coco.py',
'retinanet/retinanet_r50_fpn_1x_coco.py',
'guided_anchoring/ga_retinanet_r50_fpn_1x_coco.py',
'ghm/retinanet_ghm_r50_fpn_1x_coco.py',
'fcos/fcos_center_r50_caffe_fpn_gn-head_1x_coco.py',
'foveabox/fovea_align_r50_fpn_gn-head_4x4_2x_coco.py',
# 'free_anchor/retinanet_free_anchor_r50_fpn_1x_coco.py',
# 'atss/atss_r50_fpn_1x_coco.py', # not ready for topk
'reppoints/reppoints_moment_r50_fpn_1x_coco.py',
'yolo/yolov3_mobilenetv2_320_300e_coco.py',
'yolox/yolox_tiny_8x8_300e_coco.py'
])
Expand Down
36 changes: 18 additions & 18 deletions tests/test_onnx/test_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,12 @@ def test_retina_head_forward():
feats = [
torch.rand(1, retina_model.in_channels, s // (2**(i + 2)),
s // (2**(i + 2))) # [32, 16, 8, 4, 2]
for i in range(len(retina_model.anchor_generator.strides))
for i in range(len(retina_model.prior_generator.strides))
]
ort_validate(retina_model.forward, feats)


def test_retinanet_head_get_bboxes():
def test_retinanet_head_onnx_export():
"""Test RetinaNet Head _get_bboxes() in torch and onnxruntime env."""
retina_model = retinanet_config()
s = 128
Expand All @@ -168,9 +168,9 @@ def test_retinanet_head_get_bboxes():
cls_score = feats[:5]
bboxes = feats[5:]

retina_model.get_bboxes = partial(
retina_model.get_bboxes, img_metas=img_metas, with_nms=False)
ort_validate(retina_model.get_bboxes, (cls_score, bboxes))
retina_model.onnx_export = partial(
retina_model.onnx_export, img_metas=img_metas, with_nms=False)
ort_validate(retina_model.onnx_export, (cls_score, bboxes))


def yolo_config():
Expand Down Expand Up @@ -217,7 +217,7 @@ def test_yolov3_head_forward():
ort_validate(yolo_model.forward, feats)


def test_yolov3_head_get_bboxes():
def test_yolov3_head_onnx_export():
"""Test yolov3 head get_bboxes() in torch and ort env."""
yolo_model = yolo_config()
s = 128
Expand Down Expand Up @@ -279,7 +279,7 @@ def test_fcos_head_forward():
ort_validate(fcos_model.forward, feats)


def test_fcos_head_get_bboxes():
def test_fcos_head_onnx_export():
"""Test fcos head get_bboxes() in ort."""
fcos_model = fcos_config()
s = 128
Expand All @@ -303,9 +303,9 @@ def test_fcos_head_get_bboxes():
for feat_size in [4, 8, 16, 32, 64]
]

fcos_model.get_bboxes = partial(
fcos_model.get_bboxes, img_metas=img_metas, with_nms=False)
ort_validate(fcos_model.get_bboxes, (cls_scores, bboxes, centerness))
fcos_model.onnx_export = partial(
fcos_model.onnx_export, img_metas=img_metas, with_nms=False)
ort_validate(fcos_model.onnx_export, (cls_scores, bboxes, centerness))


def fsaf_config():
Expand Down Expand Up @@ -351,7 +351,7 @@ def test_fsaf_head_forward():
ort_validate(fsaf_model.forward, feats)


def test_fsaf_head_get_bboxes():
def test_fsaf_head_onnx_export():
"""Test RetinaNet Head get_bboxes in torch and onnxruntime env."""
fsaf_model = fsaf_config()
s = 256
Expand All @@ -374,9 +374,9 @@ def test_fsaf_head_get_bboxes():
cls_score = feats[:5]
bboxes = feats[5:]

fsaf_model.get_bboxes = partial(
fsaf_model.get_bboxes, img_metas=img_metas, with_nms=False)
ort_validate(fsaf_model.get_bboxes, (cls_score, bboxes))
fsaf_model.onnx_export = partial(
fsaf_model.onnx_export, img_metas=img_metas, with_nms=False)
ort_validate(fsaf_model.onnx_export, (cls_score, bboxes))


def ssd_config():
Expand Down Expand Up @@ -425,7 +425,7 @@ def test_ssd_head_forward():
ort_validate(ssd_model.forward, feats)


def test_ssd_head_get_bboxes():
def test_ssd_head_onnx_export():
"""Test SSD Head get_bboxes in torch and onnxruntime env."""
ssd_model = ssd_config()
s = 300
Expand All @@ -448,6 +448,6 @@ def test_ssd_head_get_bboxes():
cls_score = feats[:6]
bboxes = feats[6:]

ssd_model.get_bboxes = partial(
ssd_model.get_bboxes, img_metas=img_metas, with_nms=False)
ort_validate(ssd_model.get_bboxes, (cls_score, bboxes))
ssd_model.onnx_export = partial(
ssd_model.onnx_export, img_metas=img_metas, with_nms=False)
ort_validate(ssd_model.onnx_export, (cls_score, bboxes))

0 comments on commit cc5f8fb

Please sign in to comment.