From 26562a1d0eaa20ebbe922d0d9f23b61ac083f375 Mon Sep 17 00:00:00 2001 From: Jiaqi Wang Date: Wed, 2 Sep 2020 00:55:14 +0800 Subject: [PATCH] Code Release of ECCV 2020 Spotlight paper for Side-Aware Boundary Localization for More Precise Object Detection (#3603) * add sabl two stage * add sabl retina * ret cfg bug fix * test bug fix * minor update * update * add r101 two stage * update * add cfgs * add cfgs * update cfgs * format * format * add readme * fix isort * update * update readme * add doc string for sabl retina head * add doc string and rename some functions * add docstring for bucketing coder * update docstring * bucket_num -> num_buckets * bucket_pw -> bucket_w bucket_ph -> bucket_h * update label2onehot * update bucketing bbox coder doc * update * typo fix * bboxes_ -> rescaled_bboxes * rename some params in sabl head * init with mmcv.cnn * update doc * rename pos->post * update cfgs * update test cfg * update * add unitest for sabl head * add unitest for sabl retina * rename * minor rename * minor update * update docstring * update * use F.one_hot * update docstring * update test heads * update ReadMe * fix --- README.md | 1 + configs/sabl/README.md | 36 + .../sabl_cascade_rcnn_r101_fpn_1x_coco.py | 88 +++ .../sabl/sabl_cascade_rcnn_r50_fpn_1x_coco.py | 86 +++ .../sabl/sabl_faster_rcnn_r101_fpn_1x_coco.py | 36 + .../sabl/sabl_faster_rcnn_r50_fpn_1x_coco.py | 34 + .../sabl/sabl_retinanet_r101_fpn_1x_coco.py | 52 ++ .../sabl_retinanet_r101_fpn_gn_1x_coco.py | 54 ++ ...etinanet_r101_fpn_gn_2x_ms_480_960_coco.py | 71 ++ ...etinanet_r101_fpn_gn_2x_ms_640_800_coco.py | 71 ++ .../sabl/sabl_retinanet_r50_fpn_1x_coco.py | 50 ++ .../sabl/sabl_retinanet_r50_fpn_gn_1x_coco.py | 52 ++ mmdet/core/bbox/__init__.py | 7 +- mmdet/core/bbox/coder/__init__.py | 4 +- mmdet/core/bbox/coder/bucketing_bbox_coder.py | 339 ++++++++++ mmdet/core/bbox/transforms.py | 32 + mmdet/models/dense_heads/__init__.py | 3 +- mmdet/models/dense_heads/sabl_retina_head.py | 622 ++++++++++++++++++ mmdet/models/roi_heads/bbox_heads/__init__.py | 3 +- .../models/roi_heads/bbox_heads/sabl_head.py | 563 ++++++++++++++++ tests/test_config.py | 35 +- tests/test_models/test_heads.py | 145 +++- 22 files changed, 2365 insertions(+), 19 deletions(-) create mode 100644 configs/sabl/README.md create mode 100644 configs/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco.py create mode 100644 configs/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco.py create mode 100644 configs/sabl/sabl_faster_rcnn_r101_fpn_1x_coco.py create mode 100644 configs/sabl/sabl_faster_rcnn_r50_fpn_1x_coco.py create mode 100644 configs/sabl/sabl_retinanet_r101_fpn_1x_coco.py create mode 100644 configs/sabl/sabl_retinanet_r101_fpn_gn_1x_coco.py create mode 100644 configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco.py create mode 100644 configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco.py create mode 100644 configs/sabl/sabl_retinanet_r50_fpn_1x_coco.py create mode 100644 configs/sabl/sabl_retinanet_r50_fpn_gn_1x_coco.py create mode 100644 mmdet/core/bbox/coder/bucketing_bbox_coder.py create mode 100644 mmdet/models/dense_heads/sabl_retina_head.py create mode 100644 mmdet/models/roi_heads/bbox_heads/sabl_head.py diff --git a/README.md b/README.md index 6bb4929df6a..b9b4ef4ca7d 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,7 @@ Supported methods: - [x] [DetectoRS](configs/detectors/README.md) - [x] [Generalized Focal Loss](configs/gfl/README.md) - [x] [CornerNet](configs/cornernet/README.md) +- [x] [Side-Aware Boundary Localization](configs/sabl/README.md) - [x] [YOLOv3](configs/yolo/README.md) Some other methods are also supported in [projects using MMDetection](./docs/projects.md). diff --git a/configs/sabl/README.md b/configs/sabl/README.md new file mode 100644 index 00000000000..b8f9738a6a2 --- /dev/null +++ b/configs/sabl/README.md @@ -0,0 +1,36 @@ +# Side-Aware Boundary Localization for More Precise Object Detection + +## Introduction + +We provide config files to reproduce the object detection results in the ECCV 2020 Spotlight paper for [Side-Aware Boundary Localization for More Precise Object Detection](https://arxiv.org/abs/1912.04260). + +``` +@inproceedings{Wang_2020_ECCV, + title = {Side-Aware Boundary Localization for More Precise Object Detection}, + author = {Wang, Jiaqi and Zhang, Wenwei and Cao, Yuhang and Chen, Kai and Pang, Jiangmiao and Gong, Tao and Shi, Jianping, Loy, Chen Change and Lin, Dahua}, + booktitle = {ECCV}, + year = {2020} +} +``` + +## Results and Models + +The results on COCO 2017 val is shown in the below table. (results on test-dev are usually slightly higher than val). +Single-scale testing (1333x800) is adopted in all results. + + +| Method | Backbone | Lr schd | ms-train | box AP | Download | +| :----------------: | :-------: | :-----: | :------: | :----: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| SABL Faster R-CNN | R-50-FPN | 1x | N | 39.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r50_fpn_1x_coco/sabl_faster_rcnn_r50_fpn_1x_coco-e867595b.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r50_fpn_1x_coco/20200830_130324.log.json) | +| SABL Faster R-CNN | R-101-FPN | 1x | N | 41.7 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r101_fpn_1x_coco/sabl_faster_rcnn_r101_fpn_1x_coco-f804c6c1.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r101_fpn_1x_coco/20200830_183949.log.json) | +| SABL Cascade R-CNN | R-50-FPN | 1x | N | 41.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco/sabl_cascade_rcnn_r50_fpn_1x_coco-e1748e5e.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco/20200831_033726.log.json) | +| SABL Cascade R-CNN | R-101-FPN | 1x | N | 43.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco/sabl_cascade_rcnn_r101_fpn_1x_coco-2b83e87c.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco/20200831_141745.log.json) | + +| Method | Backbone | GN | Lr schd | ms-train | box AP | Download | +| :------------: | :-------: | :---: | :-----: | :---------: | :----: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| SABL RetinaNet | R-50-FPN | N | 1x | N | 37.7 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_1x_coco/sabl_retinanet_r50_fpn_1x_coco-6c54fd4f.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_1x_coco/20200830_053451.log.json) | +| SABL RetinaNet | R-50-FPN | Y | 1x | N | 38.8 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_gn_1x_coco/sabl_retinanet_r50_fpn_gn_1x_coco-e16dfcf1.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_gn_1x_coco/20200831_141955.log.json) | +| SABL RetinaNet | R-101-FPN | N | 1x | N | 39.7 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_1x_coco/sabl_retinanet_r101_fpn_1x_coco-42026904.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_1x_coco/20200831_034256.log.json) | +| SABL RetinaNet | R-101-FPN | Y | 1x | N | 40.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_1x_coco/sabl_retinanet_r101_fpn_gn_1x_coco-40a893e8.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_1x_coco/20200830_201422.log.json) | +| SABL RetinaNet | R-101-FPN | Y | 2x | Y (640~800) | 42.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco-1e63382c.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco/20200830_144807.log.json) | +| SABL RetinaNet | R-101-FPN | Y | 2x | Y (480~960) | 43.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco-5342f857.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco/20200830_164537.log.json) | diff --git a/configs/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco.py b/configs/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco.py new file mode 100644 index 00000000000..0322006464e --- /dev/null +++ b/configs/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco.py @@ -0,0 +1,88 @@ +_base_ = [ + '../_base_/models/cascade_rcnn_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +# model settings +model = dict( + pretrained='torchvision://resnet101', + backbone=dict(depth=101), + roi_head=dict(bbox_head=[ + dict( + type='SABLHead', + num_classes=80, + cls_in_channels=256, + reg_in_channels=256, + roi_feat_size=7, + reg_feat_up_ratio=2, + reg_pre_kernel=3, + reg_post_kernel=3, + reg_pre_num=2, + reg_post_num=1, + cls_out_channels=1024, + reg_offset_out_channels=256, + reg_cls_out_channels=256, + num_cls_fcs=1, + num_reg_fcs=0, + reg_class_agnostic=True, + norm_cfg=None, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1, + loss_weight=1.0)), + dict( + type='SABLHead', + num_classes=80, + cls_in_channels=256, + reg_in_channels=256, + roi_feat_size=7, + reg_feat_up_ratio=2, + reg_pre_kernel=3, + reg_post_kernel=3, + reg_pre_num=2, + reg_post_num=1, + cls_out_channels=1024, + reg_offset_out_channels=256, + reg_cls_out_channels=256, + num_cls_fcs=1, + num_reg_fcs=0, + reg_class_agnostic=True, + norm_cfg=None, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.5), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1, + loss_weight=1.0)), + dict( + type='SABLHead', + num_classes=80, + cls_in_channels=256, + reg_in_channels=256, + roi_feat_size=7, + reg_feat_up_ratio=2, + reg_pre_kernel=3, + reg_post_kernel=3, + reg_pre_num=2, + reg_post_num=1, + cls_out_channels=1024, + reg_offset_out_channels=256, + reg_cls_out_channels=256, + num_cls_fcs=1, + num_reg_fcs=0, + reg_class_agnostic=True, + norm_cfg=None, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.3), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1, loss_weight=1.0)) + ])) diff --git a/configs/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco.py b/configs/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco.py new file mode 100644 index 00000000000..4b28a59280e --- /dev/null +++ b/configs/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco.py @@ -0,0 +1,86 @@ +_base_ = [ + '../_base_/models/cascade_rcnn_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +# model settings +model = dict( + roi_head=dict(bbox_head=[ + dict( + type='SABLHead', + num_classes=80, + cls_in_channels=256, + reg_in_channels=256, + roi_feat_size=7, + reg_feat_up_ratio=2, + reg_pre_kernel=3, + reg_post_kernel=3, + reg_pre_num=2, + reg_post_num=1, + cls_out_channels=1024, + reg_offset_out_channels=256, + reg_cls_out_channels=256, + num_cls_fcs=1, + num_reg_fcs=0, + reg_class_agnostic=True, + norm_cfg=None, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1, + loss_weight=1.0)), + dict( + type='SABLHead', + num_classes=80, + cls_in_channels=256, + reg_in_channels=256, + roi_feat_size=7, + reg_feat_up_ratio=2, + reg_pre_kernel=3, + reg_post_kernel=3, + reg_pre_num=2, + reg_post_num=1, + cls_out_channels=1024, + reg_offset_out_channels=256, + reg_cls_out_channels=256, + num_cls_fcs=1, + num_reg_fcs=0, + reg_class_agnostic=True, + norm_cfg=None, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.5), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1, + loss_weight=1.0)), + dict( + type='SABLHead', + num_classes=80, + cls_in_channels=256, + reg_in_channels=256, + roi_feat_size=7, + reg_feat_up_ratio=2, + reg_pre_kernel=3, + reg_post_kernel=3, + reg_pre_num=2, + reg_post_num=1, + cls_out_channels=1024, + reg_offset_out_channels=256, + reg_cls_out_channels=256, + num_cls_fcs=1, + num_reg_fcs=0, + reg_class_agnostic=True, + norm_cfg=None, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.3), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1, loss_weight=1.0)) + ])) diff --git a/configs/sabl/sabl_faster_rcnn_r101_fpn_1x_coco.py b/configs/sabl/sabl_faster_rcnn_r101_fpn_1x_coco.py new file mode 100644 index 00000000000..4c797cad1c6 --- /dev/null +++ b/configs/sabl/sabl_faster_rcnn_r101_fpn_1x_coco.py @@ -0,0 +1,36 @@ +_base_ = [ + '../_base_/models/faster_rcnn_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +model = dict( + pretrained='torchvision://resnet101', + backbone=dict(depth=101), + roi_head=dict( + bbox_head=dict( + _delete_=True, + type='SABLHead', + num_classes=80, + cls_in_channels=256, + reg_in_channels=256, + roi_feat_size=7, + reg_feat_up_ratio=2, + reg_pre_kernel=3, + reg_post_kernel=3, + reg_pre_num=2, + reg_post_num=1, + cls_out_channels=1024, + reg_offset_out_channels=256, + reg_cls_out_channels=256, + num_cls_fcs=1, + num_reg_fcs=0, + reg_class_agnostic=True, + norm_cfg=None, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1, + loss_weight=1.0)))) diff --git a/configs/sabl/sabl_faster_rcnn_r50_fpn_1x_coco.py b/configs/sabl/sabl_faster_rcnn_r50_fpn_1x_coco.py new file mode 100644 index 00000000000..732c7ba3f60 --- /dev/null +++ b/configs/sabl/sabl_faster_rcnn_r50_fpn_1x_coco.py @@ -0,0 +1,34 @@ +_base_ = [ + '../_base_/models/faster_rcnn_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +model = dict( + roi_head=dict( + bbox_head=dict( + _delete_=True, + type='SABLHead', + num_classes=80, + cls_in_channels=256, + reg_in_channels=256, + roi_feat_size=7, + reg_feat_up_ratio=2, + reg_pre_kernel=3, + reg_post_kernel=3, + reg_pre_num=2, + reg_post_num=1, + cls_out_channels=1024, + reg_offset_out_channels=256, + reg_cls_out_channels=256, + num_cls_fcs=1, + num_reg_fcs=0, + reg_class_agnostic=True, + norm_cfg=None, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1, + loss_weight=1.0)))) diff --git a/configs/sabl/sabl_retinanet_r101_fpn_1x_coco.py b/configs/sabl/sabl_retinanet_r101_fpn_1x_coco.py new file mode 100644 index 00000000000..7504fe21605 --- /dev/null +++ b/configs/sabl/sabl_retinanet_r101_fpn_1x_coco.py @@ -0,0 +1,52 @@ +_base_ = [ + '../_base_/models/retinanet_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +# model settings +model = dict( + pretrained='torchvision://resnet101', + backbone=dict(depth=101), + bbox_head=dict( + _delete_=True, + type='SABLRetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + approx_anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + square_anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + scales=[4], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5), + loss_bbox_reg=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='ApproxMaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0.0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False) +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) diff --git a/configs/sabl/sabl_retinanet_r101_fpn_gn_1x_coco.py b/configs/sabl/sabl_retinanet_r101_fpn_gn_1x_coco.py new file mode 100644 index 00000000000..8143af21297 --- /dev/null +++ b/configs/sabl/sabl_retinanet_r101_fpn_gn_1x_coco.py @@ -0,0 +1,54 @@ +_base_ = [ + '../_base_/models/retinanet_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +# model settings +norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) +model = dict( + pretrained='torchvision://resnet101', + backbone=dict(depth=101), + bbox_head=dict( + _delete_=True, + type='SABLRetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + approx_anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + square_anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + scales=[4], + strides=[8, 16, 32, 64, 128]), + norm_cfg=norm_cfg, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5), + loss_bbox_reg=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='ApproxMaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0.0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False) +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) diff --git a/configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco.py b/configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco.py new file mode 100644 index 00000000000..4e2b71bfe67 --- /dev/null +++ b/configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco.py @@ -0,0 +1,71 @@ +_base_ = [ + '../_base_/models/retinanet_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py' +] +# model settings +norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) +model = dict( + pretrained='torchvision://resnet101', + backbone=dict(depth=101), + bbox_head=dict( + _delete_=True, + type='SABLRetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + approx_anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + square_anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + scales=[4], + strides=[8, 16, 32, 64, 128]), + norm_cfg=norm_cfg, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5), + loss_bbox_reg=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='ApproxMaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0.0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False) +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='Resize', + img_scale=[(1333, 480), (1333, 960)], + multiscale_mode='range', + keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +data = dict(train=dict(pipeline=train_pipeline)) +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) diff --git a/configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco.py b/configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco.py new file mode 100644 index 00000000000..013020105a0 --- /dev/null +++ b/configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco.py @@ -0,0 +1,71 @@ +_base_ = [ + '../_base_/models/retinanet_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py' +] +# model settings +norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) +model = dict( + pretrained='torchvision://resnet101', + backbone=dict(depth=101), + bbox_head=dict( + _delete_=True, + type='SABLRetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + approx_anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + square_anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + scales=[4], + strides=[8, 16, 32, 64, 128]), + norm_cfg=norm_cfg, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5), + loss_bbox_reg=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='ApproxMaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0.0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False) +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='Resize', + img_scale=[(1333, 640), (1333, 800)], + multiscale_mode='range', + keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +data = dict(train=dict(pipeline=train_pipeline)) +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) diff --git a/configs/sabl/sabl_retinanet_r50_fpn_1x_coco.py b/configs/sabl/sabl_retinanet_r50_fpn_1x_coco.py new file mode 100644 index 00000000000..ce518306b57 --- /dev/null +++ b/configs/sabl/sabl_retinanet_r50_fpn_1x_coco.py @@ -0,0 +1,50 @@ +_base_ = [ + '../_base_/models/retinanet_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +# model settings +model = dict( + bbox_head=dict( + _delete_=True, + type='SABLRetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + approx_anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + square_anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + scales=[4], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5), + loss_bbox_reg=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='ApproxMaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0.0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False) +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) diff --git a/configs/sabl/sabl_retinanet_r50_fpn_gn_1x_coco.py b/configs/sabl/sabl_retinanet_r50_fpn_gn_1x_coco.py new file mode 100644 index 00000000000..bb1dad59b63 --- /dev/null +++ b/configs/sabl/sabl_retinanet_r50_fpn_gn_1x_coco.py @@ -0,0 +1,52 @@ +_base_ = [ + '../_base_/models/retinanet_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +# model settings +norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) +model = dict( + bbox_head=dict( + _delete_=True, + type='SABLRetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + approx_anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + square_anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + scales=[4], + strides=[8, 16, 32, 64, 128]), + norm_cfg=norm_cfg, + bbox_coder=dict( + type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5), + loss_bbox_reg=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='ApproxMaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0.0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False) +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py index d87715369fb..1826cbf39f3 100644 --- a/mmdet/core/bbox/__init__.py +++ b/mmdet/core/bbox/__init__.py @@ -9,8 +9,8 @@ OHEMSampler, PseudoSampler, RandomSampler, SamplingResult, ScoreHLRSampler) from .transforms import (bbox2distance, bbox2result, bbox2roi, bbox_flip, - bbox_mapping, bbox_mapping_back, distance2bbox, - roi2bbox) + bbox_mapping, bbox_mapping_back, bbox_rescale, + distance2bbox, roi2bbox) __all__ = [ 'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner', @@ -20,5 +20,6 @@ 'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance', 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder', - 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner' + 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner', + 'bbox_rescale' ] diff --git a/mmdet/core/bbox/coder/__init__.py b/mmdet/core/bbox/coder/__init__.py index 953ac5a849e..ae455ba8fc0 100644 --- a/mmdet/core/bbox/coder/__init__.py +++ b/mmdet/core/bbox/coder/__init__.py @@ -1,4 +1,5 @@ from .base_bbox_coder import BaseBBoxCoder +from .bucketing_bbox_coder import BucketingBBoxCoder from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder from .pseudo_bbox_coder import PseudoBBoxCoder @@ -7,5 +8,6 @@ __all__ = [ 'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder', - 'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder' + 'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder', + 'BucketingBBoxCoder' ] diff --git a/mmdet/core/bbox/coder/bucketing_bbox_coder.py b/mmdet/core/bbox/coder/bucketing_bbox_coder.py new file mode 100644 index 00000000000..4f67087937e --- /dev/null +++ b/mmdet/core/bbox/coder/bucketing_bbox_coder.py @@ -0,0 +1,339 @@ +import numpy as np +import torch +import torch.nn.functional as F + +from ..builder import BBOX_CODERS +from ..transforms import bbox_rescale +from .base_bbox_coder import BaseBBoxCoder + + +@BBOX_CODERS.register_module() +class BucketingBBoxCoder(BaseBBoxCoder): + """Bucketing BBox Coder for Side-Aware Bounday Localization (SABL). + + Boundary Localization with Bucketing and Bucketing Guided Rescoring + are implemented here. + + Please refer to https://arxiv.org/abs/1912.04260 for more details. + + Args: + num_buckets (int): Number of buckets. + scale_factor (int): Scale factor of proposals to generate buckets. + offset_topk (int): Topk buckets are used to generate + bucket fine regression targets. Defaults to 2. + offset_upperbound (float): Offset upperbound to generate + bucket fine regression targets. + To avoid too large offset displacements. Defaults to 1.0. + cls_ignore_neighbor (bool): Ignore second nearest bucket or Not. + Defaults to True. + """ + + def __init__(self, + num_buckets, + scale_factor, + offset_topk=2, + offset_upperbound=1.0, + cls_ignore_neighbor=True): + super(BucketingBBoxCoder, self).__init__() + self.num_buckets = num_buckets + self.scale_factor = scale_factor + self.offset_topk = offset_topk + self.offset_upperbound = offset_upperbound + self.cls_ignore_neighbor = cls_ignore_neighbor + + def encode(self, bboxes, gt_bboxes): + """Get bucketing estimation and fine regression targets during + training. + + Args: + bboxes (torch.Tensor): source boxes, e.g., object proposals. + gt_bboxes (torch.Tensor): target of the transformation, e.g., + ground truth boxes. + + Returns: + encoded_bboxes(tuple[Tensor]): bucketing estimation + and fine regression targets and weights + """ + + assert bboxes.size(0) == gt_bboxes.size(0) + assert bboxes.size(-1) == gt_bboxes.size(-1) == 4 + encoded_bboxes = bbox2bucket(bboxes, gt_bboxes, self.num_buckets, + self.scale_factor, self.offset_topk, + self.offset_upperbound, + self.cls_ignore_neighbor) + return encoded_bboxes + + def decode(self, bboxes, pred_bboxes, max_shape=None): + """Apply transformation `pred_bboxes` to `boxes`. + Args: + boxes (torch.Tensor): Basic boxes. + pred_bboxes (torch.Tensor): Predictions for bucketing estimation + and fine regression + max_shape (tuple[int], optional): Maximum shape of boxes. + Defaults to None. + + Returns: + torch.Tensor: Decoded boxes. + """ + assert len(pred_bboxes) == 2 + cls_preds, offset_preds = pred_bboxes + assert cls_preds.size(0) == bboxes.size(0) and offset_preds.size( + 0) == bboxes.size(0) + decoded_bboxes = bucket2bbox(bboxes, cls_preds, offset_preds, + self.num_buckets, self.scale_factor, + max_shape) + + return decoded_bboxes + + +def generat_buckets(proposals, num_buckets, scale_factor=1.0): + """Generate buckets w.r.t bucket number and scale factor of proposals. + + Args: + proposals (Tensor): Shape (n, 4) + num_buckets (int): Number of buckets. + scale_factor (float): Scale factor to rescale proposals. + + Returns: + tuple[Tensor]: (bucket_w, bucket_h, l_buckets, r_buckets, + t_buckets, d_buckets) + + - bucket_w: Width of buckets on x-axis. Shape (n, ). + - bucket_h: Height of buckets on y-axis. Shape (n, ). + - l_buckets: Left buckets. Shape (n, ceil(side_num/2)). + - r_buckets: Right buckets. Shape (n, ceil(side_num/2)). + - t_buckets: Top buckets. Shape (n, ceil(side_num/2)). + - d_buckets: Down buckets. Shape (n, ceil(side_num/2)). + """ + proposals = bbox_rescale(proposals, scale_factor) + + # number of buckets in each side + side_num = int(np.ceil(num_buckets / 2.0)) + pw = proposals[..., 2] - proposals[..., 0] + ph = proposals[..., 3] - proposals[..., 1] + px1 = proposals[..., 0] + py1 = proposals[..., 1] + px2 = proposals[..., 2] + py2 = proposals[..., 3] + + bucket_w = pw / num_buckets + bucket_h = ph / num_buckets + + # left buckets + l_buckets = px1[:, None] + (0.5 + torch.arange( + 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None] + # right buckets + r_buckets = px2[:, None] - (0.5 + torch.arange( + 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None] + # top buckets + t_buckets = py1[:, None] + (0.5 + torch.arange( + 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None] + # down buckets + d_buckets = py2[:, None] - (0.5 + torch.arange( + 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None] + return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets + + +def bbox2bucket(proposals, + gt, + num_buckets, + scale_factor, + offset_topk=2, + offset_upperbound=1.0, + cls_ignore_neighbor=True): + """Generate buckets estimation and fine regression targets. + + Args: + proposals (Tensor): Shape (n, 4) + gt (Tensor): Shape (n, 4) + num_buckets (int): Number of buckets. + scale_factor (float): Scale factor to rescale proposals. + offset_topk (int): Topk buckets are used to generate + bucket fine regression targets. Defaults to 2. + offset_upperbound (float): Offset allowance to generate + bucket fine regression targets. + To avoid too large offset displacements. Defaults to 1.0. + cls_ignore_neighbor (bool): Ignore second nearest bucket or Not. + Defaults to True. + + Returns: + tuple[Tensor]: (offsets, offsets_weights, bucket_labels, cls_weights). + + - offsets: Fine regression targets. \ + Shape (n, num_buckets*2). + - offsets_weights: Fine regression weights. \ + Shape (n, num_buckets*2). + - bucket_labels: Bucketing estimation labels. \ + Shape (n, num_buckets*2). + - cls_weights: Bucketing estimation weights. \ + Shape (n, num_buckets*2). + """ + assert proposals.size() == gt.size() + + # generate buckets + proposals = proposals.float() + gt = gt.float() + (bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, + d_buckets) = generat_buckets(proposals, num_buckets, scale_factor) + + gx1 = gt[..., 0] + gy1 = gt[..., 1] + gx2 = gt[..., 2] + gy2 = gt[..., 3] + + # generate offset targets and weights + # offsets from buckets to gts + l_offsets = (l_buckets - gx1[:, None]) / bucket_w[:, None] + r_offsets = (r_buckets - gx2[:, None]) / bucket_w[:, None] + t_offsets = (t_buckets - gy1[:, None]) / bucket_h[:, None] + d_offsets = (d_buckets - gy2[:, None]) / bucket_h[:, None] + + # select top-k nearset buckets + l_topk, l_label = l_offsets.abs().topk( + offset_topk, dim=1, largest=False, sorted=True) + r_topk, r_label = r_offsets.abs().topk( + offset_topk, dim=1, largest=False, sorted=True) + t_topk, t_label = t_offsets.abs().topk( + offset_topk, dim=1, largest=False, sorted=True) + d_topk, d_label = d_offsets.abs().topk( + offset_topk, dim=1, largest=False, sorted=True) + + offset_l_weights = l_offsets.new_zeros(l_offsets.size()) + offset_r_weights = r_offsets.new_zeros(r_offsets.size()) + offset_t_weights = t_offsets.new_zeros(t_offsets.size()) + offset_d_weights = d_offsets.new_zeros(d_offsets.size()) + inds = torch.arange(0, proposals.size(0)).to(proposals).long() + + # generate offset weights of top-k nearset buckets + for k in range(offset_topk): + if k >= 1: + offset_l_weights[inds, l_label[:, + k]] = (l_topk[:, k] < + offset_upperbound).float() + offset_r_weights[inds, r_label[:, + k]] = (r_topk[:, k] < + offset_upperbound).float() + offset_t_weights[inds, t_label[:, + k]] = (t_topk[:, k] < + offset_upperbound).float() + offset_d_weights[inds, d_label[:, + k]] = (d_topk[:, k] < + offset_upperbound).float() + else: + offset_l_weights[inds, l_label[:, k]] = 1.0 + offset_r_weights[inds, r_label[:, k]] = 1.0 + offset_t_weights[inds, t_label[:, k]] = 1.0 + offset_d_weights[inds, d_label[:, k]] = 1.0 + + offsets = torch.cat([l_offsets, r_offsets, t_offsets, d_offsets], dim=-1) + offsets_weights = torch.cat([ + offset_l_weights, offset_r_weights, offset_t_weights, offset_d_weights + ], + dim=-1) + + # generate bucket labels and weight + side_num = int(np.ceil(num_buckets / 2.0)) + labels = torch.stack( + [l_label[:, 0], r_label[:, 0], t_label[:, 0], d_label[:, 0]], dim=-1) + + batch_size = labels.size(0) + bucket_labels = F.one_hot(labels.view(-1), side_num).view(batch_size, + -1).float() + bucket_cls_l_weights = (l_offsets.abs() < 1).float() + bucket_cls_r_weights = (r_offsets.abs() < 1).float() + bucket_cls_t_weights = (t_offsets.abs() < 1).float() + bucket_cls_d_weights = (d_offsets.abs() < 1).float() + bucket_cls_weights = torch.cat([ + bucket_cls_l_weights, bucket_cls_r_weights, bucket_cls_t_weights, + bucket_cls_d_weights + ], + dim=-1) + # ignore second nearest buckets for cls if necessay + if cls_ignore_neighbor: + bucket_cls_weights = (~((bucket_cls_weights == 1) & + (bucket_labels == 0))).float() + else: + bucket_cls_weights[:] = 1.0 + return offsets, offsets_weights, bucket_labels, bucket_cls_weights + + +def bucket2bbox(proposals, + cls_preds, + offset_preds, + num_buckets, + scale_factor=1.0, + max_shape=None): + """Apply bucketing estimation (cls preds) and fine regression (offset + preds) to generate det bboxes. + + Args: + proposals (Tensor): Boxes to be transformed. Shape (n, 4) + cls_preds (Tensor): bucketing estimation. Shape (n, num_buckets*2). + offset_preds (Tensor): fine regression. Shape (n, num_buckets*2). + num_buckets (int): Number of buckets. + scale_factor (float): Scale factor to rescale proposals. + max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W) + + Returns: + tuple[Tensor]: (bboxes, loc_confidence). + + - bboxes: predicted bboxes. Shape (n, 4) + - loc_confidence: localization confidence of predicted bboxes. + Shape (n,). + """ + + side_num = int(np.ceil(num_buckets / 2.0)) + cls_preds = cls_preds.view(-1, side_num) + offset_preds = offset_preds.view(-1, side_num) + + scores = F.softmax(cls_preds, dim=1) + score_topk, score_label = scores.topk(2, dim=1, largest=True, sorted=True) + + rescaled_proposals = bbox_rescale(proposals, scale_factor) + + pw = rescaled_proposals[..., 2] - rescaled_proposals[..., 0] + ph = rescaled_proposals[..., 3] - rescaled_proposals[..., 1] + px1 = rescaled_proposals[..., 0] + py1 = rescaled_proposals[..., 1] + px2 = rescaled_proposals[..., 2] + py2 = rescaled_proposals[..., 3] + + bucket_w = pw / num_buckets + bucket_h = ph / num_buckets + + score_inds_l = score_label[0::4, 0] + score_inds_r = score_label[1::4, 0] + score_inds_t = score_label[2::4, 0] + score_inds_d = score_label[3::4, 0] + l_buckets = px1 + (0.5 + score_inds_l.float()) * bucket_w + r_buckets = px2 - (0.5 + score_inds_r.float()) * bucket_w + t_buckets = py1 + (0.5 + score_inds_t.float()) * bucket_h + d_buckets = py2 - (0.5 + score_inds_d.float()) * bucket_h + + offsets = offset_preds.view(-1, 4, side_num) + inds = torch.arange(proposals.size(0)).to(proposals).long() + l_offsets = offsets[:, 0, :][inds, score_inds_l] + r_offsets = offsets[:, 1, :][inds, score_inds_r] + t_offsets = offsets[:, 2, :][inds, score_inds_t] + d_offsets = offsets[:, 3, :][inds, score_inds_d] + + x1 = l_buckets - l_offsets * bucket_w + x2 = r_buckets - r_offsets * bucket_w + y1 = t_buckets - t_offsets * bucket_h + y2 = d_buckets - d_offsets * bucket_h + + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1] - 1) + y1 = y1.clamp(min=0, max=max_shape[0] - 1) + x2 = x2.clamp(min=0, max=max_shape[1] - 1) + y2 = y2.clamp(min=0, max=max_shape[0] - 1) + bboxes = torch.cat([x1[:, None], y1[:, None], x2[:, None], y2[:, None]], + dim=-1) + + # bucketing guided rescoring + loc_confidence = score_topk[:, 0] + top2_neighbor_inds = (score_label[:, 0] - score_label[:, 1]).abs() == 1 + loc_confidence += score_topk[:, 1] * top2_neighbor_inds.float() + loc_confidence = loc_confidence.view(-1, 4).mean(dim=1) + + return bboxes, loc_confidence diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py index e1b1010eb99..89aa81b2a63 100644 --- a/mmdet/core/bbox/transforms.py +++ b/mmdet/core/bbox/transforms.py @@ -163,3 +163,35 @@ def bbox2distance(points, bbox, max_dis=None, eps=0.1): right = right.clamp(min=0, max=max_dis - eps) bottom = bottom.clamp(min=0, max=max_dis - eps) return torch.stack([left, top, right, bottom], -1) + + +def bbox_rescale(bboxes, scale_factor=1.0): + """Rescale bounding box w.r.t. scale_factor. + + Args: + bboxes (Tensor): Shape (n, 4) for bboxes or (n, 5) for rois + scale_factor (float): rescale factor + + Returns: + Tensor: Rescaled bboxes. + """ + if bboxes.size(1) == 5: + bboxes_ = bboxes[:, 1:] + inds_ = bboxes[:, 0] + else: + bboxes_ = bboxes + cx = (bboxes_[:, 0] + bboxes_[:, 2]) * 0.5 + cy = (bboxes_[:, 1] + bboxes_[:, 3]) * 0.5 + w = bboxes_[:, 2] - bboxes_[:, 0] + h = bboxes_[:, 3] - bboxes_[:, 1] + w = w * scale_factor + h = h * scale_factor + x1 = cx - 0.5 * w + x2 = cx + 0.5 * w + y1 = cy - 0.5 * h + y2 = cy + 0.5 * h + if bboxes.size(1) == 5: + rescaled_bboxes = torch.stack([inds_, x1, y1, x2, y2], dim=-1) + else: + rescaled_bboxes = torch.stack([x1, y1, x2, y2], dim=-1) + return rescaled_bboxes diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index 3c6405b803f..2e93a380f95 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -18,6 +18,7 @@ from .retina_head import RetinaHead from .retina_sepbn_head import RetinaSepBNHead from .rpn_head import RPNHead +from .sabl_retina_head import SABLRetinaHead from .ssd_head import SSDHead from .yolo_head import YOLOV3Head @@ -27,5 +28,5 @@ 'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead', 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'PAAHead', - 'YOLOV3Head' + 'YOLOV3Head', 'SABLRetinaHead' ] diff --git a/mmdet/models/dense_heads/sabl_retina_head.py b/mmdet/models/dense_heads/sabl_retina_head.py new file mode 100644 index 00000000000..6c0a339de76 --- /dev/null +++ b/mmdet/models/dense_heads/sabl_retina_head.py @@ -0,0 +1,622 @@ +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init + +from mmdet.core import (build_anchor_generator, build_assigner, + build_bbox_coder, build_sampler, force_fp32, + images_to_levels, multi_apply, multiclass_nms, unmap) +from ..builder import HEADS, build_loss +from .base_dense_head import BaseDenseHead +from .guided_anchor_head import GuidedAnchorHead + + +@HEADS.register_module +class SABLRetinaHead(BaseDenseHead): + """Side-Aware Boundary Localization (SABL) for RetinaNet. + + The anchor generation, assigning and sampling in SABLRetinaHead + are the same as GuidedAnchorHead for guided anchoring. + + Please refer to https://arxiv.org/abs/1912.04260 for more details. + + Args: + num_classes (int): Number of classes. + in_channels (int): Number of channels in the input feature map. + stacked_convs (int): Number of Convs for classification \ + and regression branches. Defaults to 4. + feat_channels (int): Number of hidden channels. \ + Defaults to 256. + approx_anchor_generator (dict): Config dict for approx generator. + square_anchor_generator (dict): Config dict for square generator. + conv_cfg (dict): Config dict for ConvModule. Defaults to None. + norm_cfg (dict): Config dict for Norm Layer. Defaults to None. + bbox_coder (dict): Config dict for bbox coder. + reg_decoded_bbox (bool): Whether to regress decoded bbox. \ + Defaults to False. + background_label (int): Background label. Defaults to None. + train_cfg (dict): Training config of SABLRetinaHead. + test_cfg (dict): Testing config of SABLRetinaHead. + loss_cls (dict): Config of classification loss. + loss_bbox_cls (dict): Config of classification loss for bbox branch. + loss_bbox_reg (dict): Config of regression loss for bbox branch. + """ + + def __init__(self, + num_classes, + in_channels, + stacked_convs=4, + feat_channels=256, + approx_anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + square_anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + scales=[4], + strides=[8, 16, 32, 64, 128]), + conv_cfg=None, + norm_cfg=None, + bbox_coder=dict( + type='BucketingBBoxCoder', + num_buckets=14, + scale_factor=3.0), + reg_decoded_bbox=False, + background_label=None, + train_cfg=None, + test_cfg=None, + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.5), + loss_bbox_reg=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)): + super(SABLRetinaHead, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + self.feat_channels = feat_channels + self.num_buckets = bbox_coder['num_buckets'] + self.side_num = int(np.ceil(self.num_buckets / 2)) + + assert (approx_anchor_generator['octave_base_scale'] == + square_anchor_generator['scales'][0]) + assert (approx_anchor_generator['strides'] == + square_anchor_generator['strides']) + + self.approx_anchor_generator = build_anchor_generator( + approx_anchor_generator) + self.square_anchor_generator = build_anchor_generator( + square_anchor_generator) + self.approxs_per_octave = ( + self.approx_anchor_generator.num_base_anchors[0]) + + # one anchor per location + self.num_anchors = 1 + self.stacked_convs = stacked_convs + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + self.reg_decoded_bbox = reg_decoded_bbox + self.background_label = ( + num_classes if background_label is None else background_label) + + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + self.sampling = loss_cls['type'] not in [ + 'FocalLoss', 'GHMC', 'QualityFocalLoss' + ] + if self.use_sigmoid_cls: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + + self.bbox_coder = build_bbox_coder(bbox_coder) + self.loss_cls = build_loss(loss_cls) + self.loss_bbox_cls = build_loss(loss_bbox_cls) + self.loss_bbox_reg = build_loss(loss_bbox_reg) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + if self.train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + # use PseudoSampler when sampling is False + if self.sampling and hasattr(self.train_cfg, 'sampler'): + sampler_cfg = self.train_cfg.sampler + else: + sampler_cfg = dict(type='PseudoSampler') + self.sampler = build_sampler(sampler_cfg, context=self) + + self.fp16_enabled = False + self._init_layers() + + def _init_layers(self): + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg)) + self.retina_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + self.retina_bbox_reg = nn.Conv2d( + self.feat_channels, self.side_num * 4, 3, padding=1) + self.retina_bbox_cls = nn.Conv2d( + self.feat_channels, self.side_num * 4, 3, padding=1) + + def init_weights(self): + for m in self.cls_convs: + normal_init(m.conv, std=0.01) + for m in self.reg_convs: + normal_init(m.conv, std=0.01) + bias_cls = bias_init_with_prob(0.01) + normal_init(self.retina_cls, std=0.01, bias=bias_cls) + normal_init(self.retina_bbox_reg, std=0.01) + normal_init(self.retina_bbox_cls, std=0.01) + + def forward_single(self, x): + cls_feat = x + reg_feat = x + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + reg_feat = reg_conv(reg_feat) + cls_score = self.retina_cls(cls_feat) + bbox_cls_pred = self.retina_bbox_cls(reg_feat) + bbox_reg_pred = self.retina_bbox_reg(reg_feat) + bbox_pred = (bbox_cls_pred, bbox_reg_pred) + return cls_score, bbox_pred + + def forward(self, feats): + return multi_apply(self.forward_single, feats) + + def get_anchors(self, featmap_sizes, img_metas, device='cuda'): + """Get squares according to feature map sizes and guided anchors. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + img_metas (list[dict]): Image meta info. + device (torch.device | str): device for returned tensors + + Returns: + tuple: square approxs of each image + """ + num_imgs = len(img_metas) + + # since feature map sizes of all images are the same, we only compute + # squares for one time + multi_level_squares = self.square_anchor_generator.grid_anchors( + featmap_sizes, device=device) + squares_list = [multi_level_squares for _ in range(num_imgs)] + + return squares_list + + def get_target(self, + approx_list, + inside_flag_list, + square_list, + gt_bboxes_list, + img_metas, + gt_bboxes_ignore_list=None, + gt_labels_list=None, + label_channels=None, + sampling=True, + unmap_outputs=True): + """Compute bucketing targets. + Args: + approx_list (list[list]): Multi level approxs of each image. + inside_flag_list (list[list]): Multi level inside flags of each + image. + square_list (list[list]): Multi level squares of each image. + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. + img_metas (list[dict]): Meta info of each image. + gt_bboxes_ignore_list (list[Tensor]): ignore list of gt bboxes. + gt_bboxes_list (list[Tensor]): Gt bboxes of each image. + label_channels (int): Channel of label. + sampling (bool): Sample Anchors or not. + unmap_outputs (bool): unmap outputs or not. + + Returns: + tuple: Returns a tuple containing learning targets. + + - labels_list (list[Tensor]): Labels of each level. + - label_weights_list (list[Tensor]): Label weights of each \ + level. + - bbox_cls_targets_list (list[Tensor]): BBox cls targets of \ + each level. + - bbox_cls_weights_list (list[Tensor]): BBox cls weights of \ + each level. + - bbox_reg_targets_list (list[Tensor]): BBox reg targets of \ + each level. + - bbox_reg_weights_list (list[Tensor]): BBox reg weights of \ + each level. + - num_total_pos (int): Number of positive samples in all \ + images. + - num_total_neg (int): Number of negative samples in all \ + images. + """ + num_imgs = len(img_metas) + assert len(approx_list) == len(inside_flag_list) == len( + square_list) == num_imgs + # anchor number of multi levels + num_level_squares = [squares.size(0) for squares in square_list[0]] + # concat all level anchors and flags to a single tensor + inside_flag_flat_list = [] + approx_flat_list = [] + square_flat_list = [] + for i in range(num_imgs): + assert len(square_list[i]) == len(inside_flag_list[i]) + inside_flag_flat_list.append(torch.cat(inside_flag_list[i])) + approx_flat_list.append(torch.cat(approx_list[i])) + square_flat_list.append(torch.cat(square_list[i])) + + # compute targets for each image + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [None for _ in range(num_imgs)] + if gt_labels_list is None: + gt_labels_list = [None for _ in range(num_imgs)] + (all_labels, all_label_weights, all_bbox_cls_targets, + all_bbox_cls_weights, all_bbox_reg_targets, all_bbox_reg_weights, + pos_inds_list, neg_inds_list) = multi_apply( + self._get_target_single, + approx_flat_list, + inside_flag_flat_list, + square_flat_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + img_metas, + label_channels=label_channels, + sampling=sampling, + unmap_outputs=unmap_outputs) + # no valid anchors + if any([labels is None for labels in all_labels]): + return None + # sampled anchors of all images + num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) + num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) + # split targets to a list w.r.t. multiple levels + labels_list = images_to_levels(all_labels, num_level_squares) + label_weights_list = images_to_levels(all_label_weights, + num_level_squares) + bbox_cls_targets_list = images_to_levels(all_bbox_cls_targets, + num_level_squares) + bbox_cls_weights_list = images_to_levels(all_bbox_cls_weights, + num_level_squares) + bbox_reg_targets_list = images_to_levels(all_bbox_reg_targets, + num_level_squares) + bbox_reg_weights_list = images_to_levels(all_bbox_reg_weights, + num_level_squares) + return (labels_list, label_weights_list, bbox_cls_targets_list, + bbox_cls_weights_list, bbox_reg_targets_list, + bbox_reg_weights_list, num_total_pos, num_total_neg) + + def _get_target_single(self, + flat_approxs, + inside_flags, + flat_squares, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + img_meta, + label_channels=None, + sampling=True, + unmap_outputs=True): + """Compute regression and classification targets for anchors in a + single image. + + Args: + flat_approxs (Tensor): flat approxs of a single image, + shape (n, 4) + inside_flags (Tensor): inside flags of a single image, + shape (n, ). + flat_squares (Tensor): flat squares of a single image, + shape (approxs_per_octave * n, 4) + gt_bboxes (Tensor): Ground truth bboxes of a single image, \ + shape (num_gts, 4). + gt_bboxes_ignore (Tensor): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + gt_labels (Tensor): Ground truth labels of each box, + shape (num_gts,). + img_meta (dict): Meta info of the image. + label_channels (int): Channel of label. + sampling (bool): Sample Anchors or not. + unmap_outputs (bool): unmap outputs or not. + + Returns: + tuple: + + - labels_list (Tensor): Labels in a single image + - label_weights (Tensor): Label weights in a single image + - bbox_cls_targets (Tensor): BBox cls targets in a single image + - bbox_cls_weights (Tensor): BBox cls weights in a single image + - bbox_reg_targets (Tensor): BBox reg targets in a single image + - bbox_reg_weights (Tensor): BBox reg weights in a single image + - num_total_pos (int): Number of positive samples \ + in a single image + - num_total_neg (int): Number of negative samples \ + in a single image + """ + if not inside_flags.any(): + return (None, ) * 8 + # assign gt and sample anchors + expand_inside_flags = inside_flags[:, None].expand( + -1, self.approxs_per_octave).reshape(-1) + approxs = flat_approxs[expand_inside_flags, :] + squares = flat_squares[inside_flags, :] + + assign_result = self.assigner.assign(approxs, squares, + self.approxs_per_octave, + gt_bboxes, gt_bboxes_ignore) + sampling_result = self.sampler.sample(assign_result, squares, + gt_bboxes) + + num_valid_squares = squares.shape[0] + bbox_cls_targets = squares.new_zeros( + (num_valid_squares, self.side_num * 4)) + bbox_cls_weights = squares.new_zeros( + (num_valid_squares, self.side_num * 4)) + bbox_reg_targets = squares.new_zeros( + (num_valid_squares, self.side_num * 4)) + bbox_reg_weights = squares.new_zeros( + (num_valid_squares, self.side_num * 4)) + labels = squares.new_full((num_valid_squares, ), + self.background_label, + dtype=torch.long) + label_weights = squares.new_zeros(num_valid_squares, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + (pos_bbox_reg_targets, pos_bbox_reg_weights, pos_bbox_cls_targets, + pos_bbox_cls_weights) = self.bbox_coder.encode( + sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes) + + bbox_cls_targets[pos_inds, :] = pos_bbox_cls_targets + bbox_reg_targets[pos_inds, :] = pos_bbox_reg_targets + bbox_cls_weights[pos_inds, :] = pos_bbox_cls_weights + bbox_reg_weights[pos_inds, :] = pos_bbox_reg_weights + if gt_labels is None: + labels[pos_inds] = 1 + else: + labels[pos_inds] = gt_labels[ + sampling_result.pos_assigned_gt_inds] + if self.train_cfg.pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg.pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_squares.size(0) + labels = unmap( + labels, + num_total_anchors, + inside_flags, + fill=self.background_label) + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_cls_targets = unmap(bbox_cls_targets, num_total_anchors, + inside_flags) + bbox_cls_weights = unmap(bbox_cls_weights, num_total_anchors, + inside_flags) + bbox_reg_targets = unmap(bbox_reg_targets, num_total_anchors, + inside_flags) + bbox_reg_weights = unmap(bbox_reg_weights, num_total_anchors, + inside_flags) + return (labels, label_weights, bbox_cls_targets, bbox_cls_weights, + bbox_reg_targets, bbox_reg_weights, pos_inds, neg_inds) + + def loss_single(self, cls_score, bbox_pred, labels, label_weights, + bbox_cls_targets, bbox_cls_weights, bbox_reg_targets, + bbox_reg_weights, num_total_samples): + # classification loss + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(-1, self.cls_out_channels) + loss_cls = self.loss_cls( + cls_score, labels, label_weights, avg_factor=num_total_samples) + # regression loss + bbox_cls_targets = bbox_cls_targets.reshape(-1, self.side_num * 4) + bbox_cls_weights = bbox_cls_weights.reshape(-1, self.side_num * 4) + bbox_reg_targets = bbox_reg_targets.reshape(-1, self.side_num * 4) + bbox_reg_weights = bbox_reg_weights.reshape(-1, self.side_num * 4) + (bbox_cls_pred, bbox_reg_pred) = bbox_pred + bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape( + -1, self.side_num * 4) + bbox_reg_pred = bbox_reg_pred.permute(0, 2, 3, 1).reshape( + -1, self.side_num * 4) + loss_bbox_cls = self.loss_bbox_cls( + bbox_cls_pred, + bbox_cls_targets.long(), + bbox_cls_weights, + avg_factor=num_total_samples * 4 * self.side_num) + loss_bbox_reg = self.loss_bbox_reg( + bbox_reg_pred, + bbox_reg_targets, + bbox_reg_weights, + avg_factor=num_total_samples * 4 * self.bbox_coder.offset_topk) + return loss_cls, loss_bbox_cls, loss_bbox_reg + + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def loss(self, + cls_scores, + bbox_preds, + gt_bboxes, + gt_labels, + img_metas, + gt_bboxes_ignore=None): + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.approx_anchor_generator.num_levels + + device = cls_scores[0].device + + # get sampled approxes + approxs_list, inside_flag_list = GuidedAnchorHead.get_sampled_approxs( + self, featmap_sizes, img_metas, device=device) + + square_list = self.get_anchors(featmap_sizes, img_metas, device=device) + + label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + + cls_reg_targets = self.get_target( + approxs_list, + inside_flag_list, + square_list, + gt_bboxes, + img_metas, + gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + label_channels=label_channels, + sampling=self.sampling) + if cls_reg_targets is None: + return None + (labels_list, label_weights_list, bbox_cls_targets_list, + bbox_cls_weights_list, bbox_reg_targets_list, bbox_reg_weights_list, + num_total_pos, num_total_neg) = cls_reg_targets + num_total_samples = ( + num_total_pos + num_total_neg if self.sampling else num_total_pos) + losses_cls, losses_bbox_cls, losses_bbox_reg = multi_apply( + self.loss_single, + cls_scores, + bbox_preds, + labels_list, + label_weights_list, + bbox_cls_targets_list, + bbox_cls_weights_list, + bbox_reg_targets_list, + bbox_reg_weights_list, + num_total_samples=num_total_samples) + return dict( + loss_cls=losses_cls, + loss_bbox_cls=losses_bbox_cls, + loss_bbox_reg=losses_bbox_reg) + + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def get_bboxes(self, + cls_scores, + bbox_preds, + img_metas, + cfg=None, + rescale=False): + assert len(cls_scores) == len(bbox_preds) + num_levels = len(cls_scores) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + + device = cls_scores[0].device + mlvl_anchors = self.get_anchors( + featmap_sizes, img_metas, device=device) + result_list = [] + for img_id in range(len(img_metas)): + cls_score_list = [ + cls_scores[i][img_id].detach() for i in range(num_levels) + ] + bbox_cls_pred_list = [ + bbox_preds[i][0][img_id].detach() for i in range(num_levels) + ] + bbox_reg_pred_list = [ + bbox_preds[i][1][img_id].detach() for i in range(num_levels) + ] + img_shape = img_metas[img_id]['img_shape'] + scale_factor = img_metas[img_id]['scale_factor'] + proposals = self.get_bboxes_single(cls_score_list, + bbox_cls_pred_list, + bbox_reg_pred_list, + mlvl_anchors[img_id], img_shape, + scale_factor, cfg, rescale) + result_list.append(proposals) + return result_list + + def get_bboxes_single(self, + cls_scores, + bbox_cls_preds, + bbox_reg_preds, + mlvl_anchors, + img_shape, + scale_factor, + cfg, + rescale=False): + cfg = self.test_cfg if cfg is None else cfg + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_confids = [] + assert len(cls_scores) == len(bbox_cls_preds) == len( + bbox_reg_preds) == len(mlvl_anchors) + for cls_score, bbox_cls_pred, bbox_reg_pred, anchors in zip( + cls_scores, bbox_cls_preds, bbox_reg_preds, mlvl_anchors): + assert cls_score.size()[-2:] == bbox_cls_pred.size( + )[-2:] == bbox_reg_pred.size()[-2::] + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + scores = cls_score.softmax(-1) + bbox_cls_pred = bbox_cls_pred.permute(1, 2, 0).reshape( + -1, self.side_num * 4) + bbox_reg_pred = bbox_reg_pred.permute(1, 2, 0).reshape( + -1, self.side_num * 4) + nms_pre = cfg.get('nms_pre', -1) + if nms_pre > 0 and scores.shape[0] > nms_pre: + if self.use_sigmoid_cls: + max_scores, _ = scores.max(dim=1) + else: + max_scores, _ = scores[:, :-1].max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + anchors = anchors[topk_inds, :] + bbox_cls_pred = bbox_cls_pred[topk_inds, :] + bbox_reg_pred = bbox_reg_pred[topk_inds, :] + scores = scores[topk_inds, :] + bbox_preds = [ + bbox_cls_pred.contiguous(), + bbox_reg_pred.contiguous() + ] + bboxes, confids = self.bbox_coder.decode( + anchors.contiguous(), bbox_preds, max_shape=img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_confids.append(confids) + mlvl_bboxes = torch.cat(mlvl_bboxes) + if rescale: + mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) + mlvl_scores = torch.cat(mlvl_scores) + mlvl_confids = torch.cat(mlvl_confids) + if self.use_sigmoid_cls: + padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + det_bboxes, det_labels = multiclass_nms( + mlvl_bboxes, + mlvl_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + score_factors=mlvl_confids) + return det_bboxes, det_labels diff --git a/mmdet/models/roi_heads/bbox_heads/__init__.py b/mmdet/models/roi_heads/bbox_heads/__init__.py index 708f8ae6438..acfb71ebebe 100644 --- a/mmdet/models/roi_heads/bbox_heads/__init__.py +++ b/mmdet/models/roi_heads/bbox_heads/__init__.py @@ -2,8 +2,9 @@ from .convfc_bbox_head import (ConvFCBBoxHead, Shared2FCBBoxHead, Shared4Conv1FCBBoxHead) from .double_bbox_head import DoubleConvFCBBoxHead +from .sabl_head import SABLHead __all__ = [ 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead', - 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead' + 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'SABLHead' ] diff --git a/mmdet/models/roi_heads/bbox_heads/sabl_head.py b/mmdet/models/roi_heads/bbox_heads/sabl_head.py new file mode 100644 index 00000000000..683e816078b --- /dev/null +++ b/mmdet/models/roi_heads/bbox_heads/sabl_head.py @@ -0,0 +1,563 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, kaiming_init, normal_init, xavier_init + +from mmdet.core import (build_bbox_coder, force_fp32, multi_apply, + multiclass_nms) +from mmdet.models.builder import HEADS, build_loss +from mmdet.models.losses import accuracy + + +@HEADS.register_module +class SABLHead(nn.Module): + """Side-Aware Boundary Localization (SABL) for RoI-Head. + + Side-Aware features are extracted by conv layers + with an attention mechanism. + Boundary Localization with Bucketing and Bucketing Guided Rescoring + are implemented in BucketingBBoxCoder. + + Please refer to https://arxiv.org/abs/1912.04260 for more details. + + Args: + cls_in_channels (int): Input channels of cls RoI feature. \ + Defaults to 256. + reg_in_channels (int): Input channels of reg RoI feature. \ + Defaults to 256. + roi_feat_size (int): Size of RoI features. Defaults to 7. + reg_feat_up_ratio (int): Upsample ratio of reg features. \ + Defaults to 2. + reg_pre_kernel (int): Kernel of 2D conv layers before \ + attention pooling. Defaults to 3. + reg_post_kernel (int): Kernel of 1D conv layers after \ + attention pooling. Defaults to 3. + reg_pre_num (int): Number of pre convs. Defaults to 2. + reg_post_num (int): Number of post convs. Defaults to 1. + num_classes (int): Number of classes in dataset. Defaults to 80. + cls_out_channels (int): Hidden channels in cls fcs. Defaults to 1024. + reg_offset_out_channels (int): Hidden and output channel \ + of reg offset branch. Defaults to 256. + reg_cls_out_channels (int): Hidden and output channel \ + of reg cls branch. Defaults to 256. + num_cls_fcs (int): Number of fcs for cls branch. Defaults to 1. + num_reg_fcs (int): Number of fcs for reg branch.. Defaults to 0. + reg_class_agnostic (bool): Class agnostic regresion or not. \ + Defaults to True. + norm_cfg (dict): Config of norm layers. Defaults to None. + bbox_coder (dict): Config of bbox coder. Defaults 'BucketingBBoxCoder'. + loss_cls (dict): Config of classification loss. + loss_bbox_cls (dict): Config of classification loss for bbox branch. + loss_bbox_reg (dict): Config of regression loss for bbox branch. + """ + + def __init__(self, + num_classes, + cls_in_channels=256, + reg_in_channels=256, + roi_feat_size=7, + reg_feat_up_ratio=2, + reg_pre_kernel=3, + reg_post_kernel=3, + reg_pre_num=2, + reg_post_num=1, + cls_out_channels=1024, + reg_offset_out_channels=256, + reg_cls_out_channels=256, + num_cls_fcs=1, + num_reg_fcs=0, + reg_class_agnostic=True, + norm_cfg=None, + bbox_coder=dict( + type='BucketingBBoxCoder', + num_buckets=14, + scale_factor=1.7), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + loss_bbox_reg=dict( + type='SmoothL1Loss', beta=0.1, loss_weight=1.0)): + super(SABLHead, self).__init__() + self.cls_in_channels = cls_in_channels + self.reg_in_channels = reg_in_channels + self.roi_feat_size = roi_feat_size + self.reg_feat_up_ratio = int(reg_feat_up_ratio) + self.num_buckets = bbox_coder['num_buckets'] + assert self.reg_feat_up_ratio // 2 >= 1 + self.up_reg_feat_size = roi_feat_size * self.reg_feat_up_ratio + assert self.up_reg_feat_size == bbox_coder['num_buckets'] + self.reg_pre_kernel = reg_pre_kernel + self.reg_post_kernel = reg_post_kernel + self.reg_pre_num = reg_pre_num + self.reg_post_num = reg_post_num + self.num_classes = num_classes + self.cls_out_channels = cls_out_channels + self.reg_offset_out_channels = reg_offset_out_channels + self.reg_cls_out_channels = reg_cls_out_channels + self.num_cls_fcs = num_cls_fcs + self.num_reg_fcs = num_reg_fcs + self.reg_class_agnostic = reg_class_agnostic + assert self.reg_class_agnostic + self.norm_cfg = norm_cfg + + self.bbox_coder = build_bbox_coder(bbox_coder) + self.loss_cls = build_loss(loss_cls) + self.loss_bbox_cls = build_loss(loss_bbox_cls) + self.loss_bbox_reg = build_loss(loss_bbox_reg) + + self.cls_fcs = self._add_fc_branch(self.num_cls_fcs, + self.cls_in_channels, + self.roi_feat_size, + self.cls_out_channels) + + self.side_num = int(np.ceil(self.num_buckets / 2)) + + if self.reg_feat_up_ratio > 1: + self.upsample_x = nn.ConvTranspose1d( + reg_in_channels, + reg_in_channels, + self.reg_feat_up_ratio, + stride=self.reg_feat_up_ratio) + self.upsample_y = nn.ConvTranspose1d( + reg_in_channels, + reg_in_channels, + self.reg_feat_up_ratio, + stride=self.reg_feat_up_ratio) + + self.reg_pre_convs = nn.ModuleList() + for i in range(self.reg_pre_num): + reg_pre_conv = ConvModule( + reg_in_channels, + reg_in_channels, + kernel_size=reg_pre_kernel, + padding=reg_pre_kernel // 2, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')) + self.reg_pre_convs.append(reg_pre_conv) + + self.reg_post_conv_xs = nn.ModuleList() + for i in range(self.reg_post_num): + reg_post_conv_x = ConvModule( + reg_in_channels, + reg_in_channels, + kernel_size=(1, reg_post_kernel), + padding=(0, reg_post_kernel // 2), + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')) + self.reg_post_conv_xs.append(reg_post_conv_x) + self.reg_post_conv_ys = nn.ModuleList() + for i in range(self.reg_post_num): + reg_post_conv_y = ConvModule( + reg_in_channels, + reg_in_channels, + kernel_size=(reg_post_kernel, 1), + padding=(reg_post_kernel // 2, 0), + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU')) + self.reg_post_conv_ys.append(reg_post_conv_y) + + self.reg_conv_att_x = nn.Conv2d(reg_in_channels, 1, 1) + self.reg_conv_att_y = nn.Conv2d(reg_in_channels, 1, 1) + + self.fc_cls = nn.Linear(self.cls_out_channels, self.num_classes + 1) + self.relu = nn.ReLU(inplace=True) + + self.reg_cls_fcs = self._add_fc_branch(self.num_reg_fcs, + self.reg_in_channels, 1, + self.reg_cls_out_channels) + self.reg_offset_fcs = self._add_fc_branch(self.num_reg_fcs, + self.reg_in_channels, 1, + self.reg_offset_out_channels) + self.fc_reg_cls = nn.Linear(self.reg_cls_out_channels, 1) + self.fc_reg_offset = nn.Linear(self.reg_offset_out_channels, 1) + + def _add_fc_branch(self, num_branch_fcs, in_channels, roi_feat_size, + fc_out_channels): + in_channels = in_channels * roi_feat_size * roi_feat_size + branch_fcs = nn.ModuleList() + for i in range(num_branch_fcs): + fc_in_channels = (in_channels if i == 0 else fc_out_channels) + branch_fcs.append(nn.Linear(fc_in_channels, fc_out_channels)) + return branch_fcs + + def init_weights(self): + for module_list in [ + self.reg_cls_fcs, self.reg_offset_fcs, self.cls_fcs + ]: + for m in module_list.modules(): + if isinstance(m, nn.Linear): + xavier_init(m, distribution='uniform') + if self.reg_feat_up_ratio > 1: + kaiming_init(self.upsample_x, distribution='normal') + kaiming_init(self.upsample_y, distribution='normal') + + normal_init(self.reg_conv_att_x, 0, 0.01) + normal_init(self.reg_conv_att_y, 0, 0.01) + normal_init(self.fc_reg_offset, 0, 0.001) + normal_init(self.fc_reg_cls, 0, 0.01) + normal_init(self.fc_cls, 0, 0.01) + + def cls_forward(self, cls_x): + cls_x = cls_x.view(cls_x.size(0), -1) + for fc in self.cls_fcs: + cls_x = self.relu(fc(cls_x)) + cls_score = self.fc_cls(cls_x) + return cls_score + + def attention_pool(self, reg_x): + """Extract direction-specific features fx and fy with attention + methanism.""" + reg_fx = reg_x + reg_fy = reg_x + reg_fx_att = self.reg_conv_att_x(reg_fx).sigmoid() + reg_fy_att = self.reg_conv_att_y(reg_fy).sigmoid() + reg_fx_att = reg_fx_att / reg_fx_att.sum(dim=2).unsqueeze(2) + reg_fy_att = reg_fy_att / reg_fy_att.sum(dim=3).unsqueeze(3) + reg_fx = (reg_fx * reg_fx_att).sum(dim=2) + reg_fy = (reg_fy * reg_fy_att).sum(dim=3) + return reg_fx, reg_fy + + def side_aware_feature_extractor(self, reg_x): + """Refine and extract side-aware features without split them.""" + for reg_pre_conv in self.reg_pre_convs: + reg_x = reg_pre_conv(reg_x) + reg_fx, reg_fy = self.attention_pool(reg_x) + + if self.reg_post_num > 0: + reg_fx = reg_fx.unsqueeze(2) + reg_fy = reg_fy.unsqueeze(3) + for i in range(self.reg_post_num): + reg_fx = self.reg_post_conv_xs[i](reg_fx) + reg_fy = self.reg_post_conv_ys[i](reg_fy) + reg_fx = reg_fx.squeeze(2) + reg_fy = reg_fy.squeeze(3) + if self.reg_feat_up_ratio > 1: + reg_fx = self.relu(self.upsample_x(reg_fx)) + reg_fy = self.relu(self.upsample_y(reg_fy)) + reg_fx = torch.transpose(reg_fx, 1, 2) + reg_fy = torch.transpose(reg_fy, 1, 2) + return reg_fx.contiguous(), reg_fy.contiguous() + + def reg_pred(self, x, offfset_fcs, cls_fcs): + """Predict bucketing esimation (cls_pred) and fine regression (offset + pred) with side-aware features.""" + x_offset = x.view(-1, self.reg_in_channels) + x_cls = x.view(-1, self.reg_in_channels) + + for fc in offfset_fcs: + x_offset = self.relu(fc(x_offset)) + for fc in cls_fcs: + x_cls = self.relu(fc(x_cls)) + offset_pred = self.fc_reg_offset(x_offset) + cls_pred = self.fc_reg_cls(x_cls) + + offset_pred = offset_pred.view(x.size(0), -1) + cls_pred = cls_pred.view(x.size(0), -1) + + return offset_pred, cls_pred + + def side_aware_split(self, feat): + """Split side-aware features aligned with orders of bucketing + targets.""" + l_end = int(np.ceil(self.up_reg_feat_size / 2)) + r_start = int(np.floor(self.up_reg_feat_size / 2)) + feat_fl = feat[:, :l_end] + feat_fr = feat[:, r_start:].flip(dims=(1, )) + feat_fl = feat_fl.contiguous() + feat_fr = feat_fr.contiguous() + feat = torch.cat([feat_fl, feat_fr], dim=-1) + return feat + + def reg_forward(self, reg_x): + outs = self.side_aware_feature_extractor(reg_x) + edge_offset_preds = [] + edge_cls_preds = [] + reg_fx = outs[0] + reg_fy = outs[1] + offset_pred_x, cls_pred_x = self.reg_pred(reg_fx, self.reg_offset_fcs, + self.reg_cls_fcs) + offset_pred_y, cls_pred_y = self.reg_pred(reg_fy, self.reg_offset_fcs, + self.reg_cls_fcs) + offset_pred_x = self.side_aware_split(offset_pred_x) + offset_pred_y = self.side_aware_split(offset_pred_y) + cls_pred_x = self.side_aware_split(cls_pred_x) + cls_pred_y = self.side_aware_split(cls_pred_y) + edge_offset_preds = torch.cat([offset_pred_x, offset_pred_y], dim=-1) + edge_cls_preds = torch.cat([cls_pred_x, cls_pred_y], dim=-1) + + return (edge_cls_preds, edge_offset_preds) + + def forward(self, x): + + bbox_pred = self.reg_forward(x) + cls_score = self.cls_forward(x) + + return cls_score, bbox_pred + + def get_targets(self, sampling_results, gt_bboxes, gt_labels, + rcnn_train_cfg): + pos_proposals = [res.pos_bboxes for res in sampling_results] + neg_proposals = [res.neg_bboxes for res in sampling_results] + pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results] + pos_gt_labels = [res.pos_gt_labels for res in sampling_results] + cls_reg_targets = self.bucket_target(pos_proposals, neg_proposals, + pos_gt_bboxes, pos_gt_labels, + rcnn_train_cfg) + (labels, label_weights, bucket_cls_targets, bucket_cls_weights, + bucket_offset_targets, bucket_offset_weights) = cls_reg_targets + return (labels, label_weights, (bucket_cls_targets, + bucket_offset_targets), + (bucket_cls_weights, bucket_offset_weights)) + + def bucket_target(self, + pos_proposals_list, + neg_proposals_list, + pos_gt_bboxes_list, + pos_gt_labels_list, + rcnn_train_cfg, + concat=True): + (labels, label_weights, bucket_cls_targets, bucket_cls_weights, + bucket_offset_targets, bucket_offset_weights) = multi_apply( + self._bucket_target_single, + pos_proposals_list, + neg_proposals_list, + pos_gt_bboxes_list, + pos_gt_labels_list, + cfg=rcnn_train_cfg) + + if concat: + labels = torch.cat(labels, 0) + label_weights = torch.cat(label_weights, 0) + bucket_cls_targets = torch.cat(bucket_cls_targets, 0) + bucket_cls_weights = torch.cat(bucket_cls_weights, 0) + bucket_offset_targets = torch.cat(bucket_offset_targets, 0) + bucket_offset_weights = torch.cat(bucket_offset_weights, 0) + return (labels, label_weights, bucket_cls_targets, bucket_cls_weights, + bucket_offset_targets, bucket_offset_weights) + + def _bucket_target_single(self, pos_proposals, neg_proposals, + pos_gt_bboxes, pos_gt_labels, cfg): + """Compute bucketing estimation targets and fine regression targets for + a single image. + + Args: + pos_proposals (Tensor): positive proposals of a single image, + Shape (n_pos, 4) + neg_proposals (Tensor): negative proposals of a single image, + Shape (n_neg, 4). + pos_gt_bboxes (Tensor): gt bboxes assigned to positive proposals + of a single image, Shape (n_pos, 4). + pos_gt_labels (Tensor): gt labels assigned to positive proposals + of a single image, Shape (n_pos, ). + cfg (dict): Config of calculating targets + + Returns: + tuple: + + - labels (Tensor): Labels in a single image. \ + Shape (n,). + - label_weights (Tensor): Label weights in a single image.\ + Shape (n,) + - bucket_cls_targets (Tensor): Bucket cls targets in \ + a single image. Shape (n, num_buckets*2). + - bucket_cls_weights (Tensor): Bucket cls weights in \ + a single image. Shape (n, num_buckets*2). + - bucket_offset_targets (Tensor): Bucket offset targets \ + in a single image. Shape (n, num_buckets*2). + - bucket_offset_targets (Tensor): Bucket offset weights \ + in a single image. Shape (n, num_buckets*2). + """ + num_pos = pos_proposals.size(0) + num_neg = neg_proposals.size(0) + num_samples = num_pos + num_neg + labels = pos_gt_bboxes.new_full((num_samples, ), + self.num_classes, + dtype=torch.long) + label_weights = pos_proposals.new_zeros(num_samples) + bucket_cls_targets = pos_proposals.new_zeros(num_samples, + 4 * self.side_num) + bucket_cls_weights = pos_proposals.new_zeros(num_samples, + 4 * self.side_num) + bucket_offset_targets = pos_proposals.new_zeros( + num_samples, 4 * self.side_num) + bucket_offset_weights = pos_proposals.new_zeros( + num_samples, 4 * self.side_num) + if num_pos > 0: + labels[:num_pos] = pos_gt_labels + label_weights[:num_pos] = 1.0 + (pos_bucket_offset_targets, pos_bucket_offset_weights, + pos_bucket_cls_targets, + pos_bucket_cls_weights) = self.bbox_coder.encode( + pos_proposals, pos_gt_bboxes) + bucket_cls_targets[:num_pos, :] = pos_bucket_cls_targets + bucket_cls_weights[:num_pos, :] = pos_bucket_cls_weights + bucket_offset_targets[:num_pos, :] = pos_bucket_offset_targets + bucket_offset_weights[:num_pos, :] = pos_bucket_offset_weights + if num_neg > 0: + label_weights[-num_neg:] = 1.0 + return (labels, label_weights, bucket_cls_targets, bucket_cls_weights, + bucket_offset_targets, bucket_offset_weights) + + def loss(self, + cls_score, + bbox_pred, + rois, + labels, + label_weights, + bbox_targets, + bbox_weights, + reduction_override=None): + losses = dict() + if cls_score is not None: + avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.) + losses['loss_cls'] = self.loss_cls( + cls_score, + labels, + label_weights, + avg_factor=avg_factor, + reduction_override=reduction_override) + losses['acc'] = accuracy(cls_score, labels) + + if bbox_pred is not None: + bucket_cls_preds, bucket_offset_preds = bbox_pred + bucket_cls_targets, bucket_offset_targets = bbox_targets + bucket_cls_weights, bucket_offset_weights = bbox_weights + # edge cls + bucket_cls_preds = bucket_cls_preds.view(-1, self.side_num) + bucket_cls_targets = bucket_cls_targets.view(-1, self.side_num) + bucket_cls_weights = bucket_cls_weights.view(-1, self.side_num) + losses['loss_bbox_cls'] = self.loss_bbox_cls( + bucket_cls_preds, + bucket_cls_targets, + bucket_cls_weights, + avg_factor=bucket_cls_targets.size(0), + reduction_override=reduction_override) + + losses['loss_bbox_reg'] = self.loss_bbox_reg( + bucket_offset_preds, + bucket_offset_targets, + bucket_offset_weights, + avg_factor=bucket_offset_targets.size(0), + reduction_override=reduction_override) + + return losses + + @force_fp32(apply_to=('cls_score', 'bbox_pred')) + def get_bboxes(self, + rois, + cls_score, + bbox_pred, + img_shape, + scale_factor, + rescale=False, + cfg=None): + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + scores = F.softmax(cls_score, dim=1) if cls_score is not None else None + + if bbox_pred is not None: + bboxes, confids = self.bbox_coder.decode(rois[:, 1:], bbox_pred, + img_shape) + else: + bboxes = rois[:, 1:].clone() + confids = None + if img_shape is not None: + bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1] - 1) + bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1) + + if rescale and bboxes.size(0) > 0: + if isinstance(scale_factor, float): + bboxes /= scale_factor + else: + bboxes /= torch.from_numpy(scale_factor).to(bboxes.device) + + if cfg is None: + return bboxes, scores + else: + det_bboxes, det_labels = multiclass_nms( + bboxes, + scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + score_factors=confids) + + return det_bboxes, det_labels + + @force_fp32(apply_to=('bbox_preds', )) + def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas): + """Refine bboxes during training. + + Args: + rois (Tensor): Shape (n*bs, 5), where n is image number per GPU, + and bs is the sampled RoIs per image. + labels (Tensor): Shape (n*bs, ). + bbox_preds (list[Tensor]): Shape [(n*bs, num_buckets*2), \ + (n*bs, num_buckets*2)]. + pos_is_gts (list[Tensor]): Flags indicating if each positive bbox + is a gt bbox. + img_metas (list[dict]): Meta info of each image. + + Returns: + list[Tensor]: Refined bboxes of each image in a mini-batch. + """ + img_ids = rois[:, 0].long().unique(sorted=True) + assert img_ids.numel() == len(img_metas) + + bboxes_list = [] + for i in range(len(img_metas)): + inds = torch.nonzero( + rois[:, 0] == i, as_tuple=False).squeeze(dim=1) + num_rois = inds.numel() + + bboxes_ = rois[inds, 1:] + label_ = labels[inds] + edge_cls_preds, edge_offset_preds = bbox_preds + edge_cls_preds_ = edge_cls_preds[inds] + edge_offset_preds_ = edge_offset_preds[inds] + bbox_pred_ = [edge_cls_preds_, edge_offset_preds_] + img_meta_ = img_metas[i] + pos_is_gts_ = pos_is_gts[i] + + bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_, + img_meta_) + # filter gt bboxes + pos_keep = 1 - pos_is_gts_ + keep_inds = pos_is_gts_.new_ones(num_rois) + keep_inds[:len(pos_is_gts_)] = pos_keep + + bboxes_list.append(bboxes[keep_inds.type(torch.bool)]) + + return bboxes_list + + @force_fp32(apply_to=('bbox_pred', )) + def regress_by_class(self, rois, label, bbox_pred, img_meta): + """Regress the bbox for the predicted class. Used in Cascade R-CNN. + + Args: + rois (Tensor): shape (n, 4) or (n, 5) + label (Tensor): shape (n, ) + bbox_pred (list[Tensor]): shape [(n, num_buckets *2), \ + (n, num_buckets *2)] + img_meta (dict): Image meta info. + + Returns: + Tensor: Regressed bboxes, the same shape as input rois. + """ + assert rois.size(1) == 4 or rois.size(1) == 5 + + if rois.size(1) == 4: + new_rois, _ = self.bbox_coder.decode(rois, bbox_pred, + img_meta['img_shape']) + else: + bboxes, _ = self.bbox_coder.decode(rois[:, 1:], bbox_pred, + img_meta['img_shape']) + new_rois = torch.cat((rois[:, [0]], bboxes), dim=1) + + return new_rois diff --git a/tests/test_config.py b/tests/test_config.py index bbcf4493206..2fc44246e11 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -324,18 +324,27 @@ def _check_bbox_head(bbox_cfg, bbox_head): _check_bbox_head(bbox_cfg, single_bbox_head) else: assert bbox_cfg['type'] == bbox_head.__class__.__name__ - assert bbox_cfg.in_channels == bbox_head.in_channels - with_cls = bbox_cfg.get('with_cls', True) - if with_cls: - fc_out_channels = bbox_cfg.get('fc_out_channels', 2048) - assert (fc_out_channels == bbox_head.fc_cls.in_features) - assert bbox_cfg.num_classes + 1 == bbox_head.fc_cls.out_features + if bbox_cfg['type'] == 'SABLHead': + assert bbox_cfg.cls_in_channels == bbox_head.cls_in_channels + assert bbox_cfg.reg_in_channels == bbox_head.reg_in_channels - with_reg = bbox_cfg.get('with_reg', True) - if with_reg: - out_dim = (4 if bbox_cfg.reg_class_agnostic else 4 * - bbox_cfg.num_classes) - assert bbox_head.fc_reg.out_features == out_dim + cls_out_channels = bbox_cfg.get('cls_out_channels', 1024) + assert (cls_out_channels == bbox_head.fc_cls.in_features) + assert (bbox_cfg.num_classes + 1 == bbox_head.fc_cls.out_features) + else: + assert bbox_cfg.in_channels == bbox_head.in_channels + with_cls = bbox_cfg.get('with_cls', True) + if with_cls: + fc_out_channels = bbox_cfg.get('fc_out_channels', 2048) + assert (fc_out_channels == bbox_head.fc_cls.in_features) + assert (bbox_cfg.num_classes + + 1 == bbox_head.fc_cls.out_features) + + with_reg = bbox_cfg.get('with_reg', True) + if with_reg: + out_dim = (4 if bbox_cfg.reg_class_agnostic else 4 * + bbox_cfg.num_classes) + assert bbox_head.fc_reg.out_features == out_dim def _check_anchorhead(config, head): @@ -350,6 +359,10 @@ def _check_anchorhead(config, head): assert (config.feat_channels == head.atss_cls.in_channels) assert (config.feat_channels == head.atss_reg.in_channels) assert (config.feat_channels == head.atss_centerness.in_channels) + elif config['type'] == 'SABLRetinaHead': + assert (config.feat_channels == head.retina_cls.in_channels) + assert (config.feat_channels == head.retina_bbox_reg.in_channels) + assert (config.feat_channels == head.retina_bbox_cls.in_channels) else: assert (config.in_channels == head.conv_cls.in_channels) assert (config.in_channels == head.conv_reg.in_channels) diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 65aa1a07d32..78c227fc3c7 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -6,9 +6,9 @@ from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps from mmdet.models.dense_heads import (AnchorHead, CornerHead, FCOSHead, FSAFHead, GuidedAnchorHead, PAAHead, - paa_head) + SABLRetinaHead, paa_head) from mmdet.models.dense_heads.paa_head import levels_to_images -from mmdet.models.roi_heads.bbox_heads import BBoxHead +from mmdet.models.roi_heads.bbox_heads import BBoxHead, SABLHead from mmdet.models.roi_heads.mask_heads import FCNMaskHead, MaskIoUHead @@ -473,6 +473,147 @@ def test_bbox_head_loss(): assert losses.get('loss_bbox', 0) > 0, 'box-loss should be non-zero' +def test_sabl_bbox_head_loss(): + """Tests bbox head loss when truth is empty and non-empty.""" + self = SABLHead( + num_classes=4, + cls_in_channels=3, + reg_in_channels=3, + cls_out_channels=3, + reg_offset_out_channels=3, + reg_cls_out_channels=3, + roi_feat_size=7) + + # Dummy proposals + proposal_list = [ + torch.Tensor([[23.6667, 23.8757, 228.6326, 153.8874]]), + ] + + target_cfg = mmcv.Config(dict(pos_weight=1)) + + # Test bbox loss when truth is empty + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + + sampling_results = _dummy_bbox_sampling(proposal_list, gt_bboxes, + gt_labels) + + bbox_targets = self.get_targets(sampling_results, gt_bboxes, gt_labels, + target_cfg) + labels, label_weights, bbox_targets, bbox_weights = bbox_targets + + # Create dummy features "extracted" for each sampled bbox + num_sampled = sum(len(res.bboxes) for res in sampling_results) + rois = bbox2roi([res.bboxes for res in sampling_results]) + dummy_feats = torch.rand(num_sampled, 3, 7, 7) + cls_scores, bbox_preds = self.forward(dummy_feats) + + losses = self.loss(cls_scores, bbox_preds, rois, labels, label_weights, + bbox_targets, bbox_weights) + assert losses.get('loss_cls', 0) > 0, 'cls-loss should be non-zero' + assert losses.get('loss_bbox_cls', + 0) == 0, 'empty gt bbox-cls-loss should be zero' + assert losses.get('loss_bbox_reg', + 0) == 0, 'empty gt bbox-reg-loss should be zero' + + # Test bbox loss when truth is non-empty + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + + sampling_results = _dummy_bbox_sampling(proposal_list, gt_bboxes, + gt_labels) + rois = bbox2roi([res.bboxes for res in sampling_results]) + + bbox_targets = self.get_targets(sampling_results, gt_bboxes, gt_labels, + target_cfg) + labels, label_weights, bbox_targets, bbox_weights = bbox_targets + + # Create dummy features "extracted" for each sampled bbox + num_sampled = sum(len(res.bboxes) for res in sampling_results) + dummy_feats = torch.rand(num_sampled, 3, 7, 7) + cls_scores, bbox_preds = self.forward(dummy_feats) + + losses = self.loss(cls_scores, bbox_preds, rois, labels, label_weights, + bbox_targets, bbox_weights) + assert losses.get('loss_bbox_cls', + 0) > 0, 'empty gt bbox-cls-loss should be zero' + assert losses.get('loss_bbox_reg', + 0) > 0, 'empty gt bbox-reg-loss should be zero' + + +def test_sabl_retina_head_loss(): + """Tests anchor head loss when truth is empty and non-empty.""" + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'scale_factor': 1, + 'pad_shape': (s, s, 3) + }] + + cfg = mmcv.Config( + dict( + assigner=dict( + type='ApproxMaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0.0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False)) + head = SABLRetinaHead( + num_classes=4, + in_channels=3, + feat_channels=10, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + train_cfg=cfg) + if torch.cuda.is_available(): + head.cuda() + # Anchor head expects a multiple levels of features per image + feat = [ + torch.rand(1, 3, s // (2**(i + 2)), s // (2**(i + 2))).cuda() + for i in range(len(head.approx_anchor_generator.base_anchors)) + ] + cls_scores, bbox_preds = head.forward(feat) + + # Test that empty ground truth encourages the network + # to predict background + gt_bboxes = [torch.empty((0, 4)).cuda()] + gt_labels = [torch.LongTensor([]).cuda()] + + gt_bboxes_ignore = None + empty_gt_losses = head.loss(cls_scores, bbox_preds, gt_bboxes, + gt_labels, img_metas, gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there + # should be no box loss. + empty_cls_loss = sum(empty_gt_losses['loss_cls']) + empty_box_cls_loss = sum(empty_gt_losses['loss_bbox_cls']) + empty_box_reg_loss = sum(empty_gt_losses['loss_bbox_reg']) + assert empty_cls_loss.item() > 0, 'cls loss should be non-zero' + assert empty_box_cls_loss.item() == 0, ( + 'there should be no box cls loss when there are no true boxes') + assert empty_box_reg_loss.item() == 0, ( + 'there should be no box reg loss when there are no true boxes') + + # When truth is non-empty then both cls and box loss should + # be nonzero for random inputs + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]).cuda(), + ] + gt_labels = [torch.LongTensor([2]).cuda()] + one_gt_losses = head.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + onegt_cls_loss = sum(one_gt_losses['loss_cls']) + onegt_box_cls_loss = sum(one_gt_losses['loss_bbox_cls']) + onegt_box_reg_loss = sum(one_gt_losses['loss_bbox_reg']) + assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero' + assert onegt_box_cls_loss.item() > 0, 'box loss cls should be non-zero' + assert onegt_box_reg_loss.item() > 0, 'box loss reg should be non-zero' + + def test_refine_boxes(): """Mirrors the doctest in ``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_boxes`` but checks for