diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py index a7b803a3bd7..1e3fa12d8fe 100644 --- a/mmdet/core/bbox/__init__.py +++ b/mmdet/core/bbox/__init__.py @@ -2,8 +2,8 @@ from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner, MaxIoUAssigner, RegionAssigner) from .builder import build_assigner, build_bbox_coder, build_sampler -from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder, - TBLRBBoxCoder) +from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, DistancePointBBoxCoder, + PseudoBBoxCoder, TBLRBBoxCoder) from .iou_calculators import BboxOverlaps2D, bbox_overlaps from .samplers import (BaseSampler, CombinedSampler, InstanceBalancedPosSampler, IoUBalancedNegSampler, @@ -22,7 +22,7 @@ 'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance', 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder', - 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner', - 'bbox_rescale', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh', - 'RegionAssigner' + 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'DistancePointBBoxCoder', + 'CenterRegionAssigner', 'bbox_rescale', 'bbox_cxcywh_to_xyxy', + 'bbox_xyxy_to_cxcywh', 'RegionAssigner' ] diff --git a/mmdet/core/bbox/coder/__init__.py b/mmdet/core/bbox/coder/__init__.py index 4c7db000ae5..e12fd64e12b 100644 --- a/mmdet/core/bbox/coder/__init__.py +++ b/mmdet/core/bbox/coder/__init__.py @@ -2,6 +2,7 @@ from .base_bbox_coder import BaseBBoxCoder from .bucketing_bbox_coder import BucketingBBoxCoder from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder +from .distance_point_bbox_coder import DistancePointBBoxCoder from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder from .pseudo_bbox_coder import PseudoBBoxCoder from .tblr_bbox_coder import TBLRBBoxCoder @@ -10,5 +11,5 @@ __all__ = [ 'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder', 'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 
'YOLOBBoxCoder', - 'BucketingBBoxCoder' + 'BucketingBBoxCoder', 'DistancePointBBoxCoder' ] diff --git a/mmdet/core/bbox/coder/distance_point_bbox_coder.py b/mmdet/core/bbox/coder/distance_point_bbox_coder.py new file mode 100644 index 00000000000..19499e3e270 --- /dev/null +++ b/mmdet/core/bbox/coder/distance_point_bbox_coder.py @@ -0,0 +1,62 @@ +from ..builder import BBOX_CODERS +from ..transforms import bbox2distance, distance2bbox +from .base_bbox_coder import BaseBBoxCoder + + +@BBOX_CODERS.register_module() +class DistancePointBBoxCoder(BaseBBoxCoder): + """Distance Point BBox coder. + + This coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left, + right) and decode it back to the original. + + Args: + clip_border (bool, optional): Whether clip the objects outside the + border of the image. Defaults to True. + """ + + def __init__(self, clip_border=True): + super(BaseBBoxCoder, self).__init__() + self.clip_border = clip_border + + def encode(self, points, gt_bboxes, max_dis=None, eps=0.1): + """Encode bounding box to distances. + + Args: + points (Tensor): Shape (N, 2), The format is [x, y]. + gt_bboxes (Tensor): Shape (N, 4), The format is "xyxy" + max_dis (float): Upper bound of the distance. Default None. + eps (float): a small value to ensure target < max_dis, instead <=. + Default 0.1. + + Returns: + Tensor: Box transformation deltas. The shape is (N, 4). + """ + assert points.size(0) == gt_bboxes.size(0) + assert points.size(-1) == 2 + assert gt_bboxes.size(-1) == 4 + return bbox2distance(points, gt_bboxes, max_dis, eps) + + def decode(self, points, pred_bboxes, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (B, N, 2) or (N, 2). + pred_bboxes (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). 
Shape (B, N, 4) + or (N, 4) + max_shape (Sequence[int] or torch.Tensor or Sequence[ + Sequence[int]],optional): Maximum bounds for boxes, specifies + (H, W, C) or (H, W). If priors shape is (B, N, 4), then + the max_shape should be a Sequence[Sequence[int]], + and the length of max_shape should also be B. + Default None. + Returns: + Tensor: Boxes with shape (N, 4) or (B, N, 4) + """ + assert points.size(0) == pred_bboxes.size(0) + assert points.size(-1) == 2 + assert pred_bboxes.size(-1) == 4 + if self.clip_border is False: + max_shape = None + return distance2bbox(points, pred_bboxes, max_shape) diff --git a/mmdet/core/utils/__init__.py b/mmdet/core/utils/__init__.py index bbd909ff6e4..eea02119ce2 100644 --- a/mmdet/core/utils/__init__.py +++ b/mmdet/core/utils/__init__.py @@ -2,10 +2,10 @@ from .dist_utils import (DistOptimizerHook, all_reduce_dict, allreduce_grads, reduce_mean) from .misc import (center_of_mass, flip_tensor, generate_coordinate, - mask2ndarray, multi_apply, unmap) + mask2ndarray, multi_apply, select_single_mlvl, unmap) __all__ = [ 'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply', 'unmap', 'mask2ndarray', 'flip_tensor', 'all_reduce_dict', - 'center_of_mass', 'generate_coordinate' + 'center_of_mass', 'generate_coordinate', 'select_single_mlvl' ] diff --git a/mmdet/core/utils/misc.py b/mmdet/core/utils/misc.py index 36bb6883ddc..c0e730c5043 100644 --- a/mmdet/core/utils/misc.py +++ b/mmdet/core/utils/misc.py @@ -85,6 +85,37 @@ def flip_tensor(src_tensor, flip_direction): return out_tensor +def select_single_mlvl(mlvl_tensors, batch_id, detach=True): + """Extract a multi-scale single image tensor from a multi-scale batch + tensor based on batch index. + + Note: The default value of detach is True, because the proposal gradient + needs to be detached during the training of the two-stage model. E.g + Cascade Mask R-CNN. + + Args: + mlvl_tensors (list[Tensor]):Batch tensor for all scale levels, + each is a 4D-tensor. 
+ batch_id (int): batch index. + detach (bool): Whether detach gradient. Default True. + + Returns: + list[Tensor]: multi-scale single image tensor. + """ + assert isinstance(mlvl_tensors, (list, tuple)) + num_levels = len(mlvl_tensors) + + if detach: + mlvl_tensor_list = [ + mlvl_tensors[i][batch_id].detach() for i in range(num_levels) + ] + else: + mlvl_tensor_list = [ + mlvl_tensors[i][batch_id] for i in range(num_levels) + ] + return mlvl_tensor_list + + def center_of_mass(mask, esp=1e-6): """Calculate the centroid coordinates of the mask. diff --git a/mmdet/models/dense_heads/anchor_free_head.py b/mmdet/models/dense_heads/anchor_free_head.py index 07dfcb032cd..4789880b899 100644 --- a/mmdet/models/dense_heads/anchor_free_head.py +++ b/mmdet/models/dense_heads/anchor_free_head.py @@ -6,7 +6,8 @@ from mmcv.cnn import ConvModule from mmcv.runner import force_fp32 -from mmdet.core import multi_apply +from mmdet.core import build_bbox_coder, multi_apply +from mmdet.core.anchor.point_generator import MlvlPointGenerator from ..builder import HEADS, build_loss from .base_dense_head import BaseDenseHead from .dense_test_mixins import BBoxTestMixin @@ -30,6 +31,8 @@ class AnchorFreeHead(BaseDenseHead, BBoxTestMixin): None, otherwise False. Default: "auto". loss_cls (dict): Config of classification loss. loss_bbox (dict): Config of localization loss. + bbox_coder (dict): Config of bbox coder. Defaults + 'DistancePointBBoxCoder'. conv_cfg (dict): Config dict for convolution layer. Default: None. norm_cfg (dict): Config dict for normalization layer. Default: None. train_cfg (dict): Training config of anchor head. 
@@ -54,6 +57,7 @@ def __init__(self, alpha=0.25, loss_weight=1.0), loss_bbox=dict(type='IoULoss', loss_weight=1.0), + bbox_coder=dict(type='DistancePointBBoxCoder'), conv_cfg=None, norm_cfg=None, train_cfg=None, @@ -69,7 +73,11 @@ def __init__(self, bias_prob=0.01))): super(AnchorFreeHead, self).__init__(init_cfg) self.num_classes = num_classes - self.cls_out_channels = num_classes + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + if self.use_sigmoid_cls: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 self.in_channels = in_channels self.feat_channels = feat_channels self.stacked_convs = stacked_convs @@ -79,6 +87,8 @@ def __init__(self, self.conv_bias = conv_bias self.loss_cls = build_loss(loss_cls) self.loss_bbox = build_loss(loss_bbox) + self.bbox_coder = build_bbox_coder(bbox_coder) + self.prior_generator = MlvlPointGenerator(strides) self.train_cfg = train_cfg self.test_cfg = test_cfg self.conv_cfg = conv_cfg @@ -247,30 +257,6 @@ def loss(self, raise NotImplementedError - @abstractmethod - @force_fp32(apply_to=('cls_scores', 'bbox_preds')) - def get_bboxes(self, - cls_scores, - bbox_preds, - img_metas, - cfg=None, - rescale=None): - """Transform network output for a batch into bbox predictions. - - Args: - cls_scores (list[Tensor]): Box scores for each scale level - Has shape (N, num_points * num_classes, H, W) - bbox_preds (list[Tensor]): Box energies / deltas for each scale - level with shape (N, num_points * 4, H, W) - img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. 
- cfg (mmcv.Config): Test / postprocessing configuration, - if None, test_cfg would be used - rescale (bool): If True, return boxes in original image space - """ - - raise NotImplementedError - @abstractmethod def get_targets(self, points, gt_bboxes_list, gt_labels_list): """Compute regression, classification and centerness targets for points diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py index a123c873586..b53433ac2e5 100644 --- a/mmdet/models/dense_heads/anchor_head.py +++ b/mmdet/models/dense_heads/anchor_head.py @@ -5,9 +5,9 @@ import torch.nn as nn from mmcv.runner import force_fp32 -from mmdet.core import (anchor_inside_flags, build_anchor_generator, - build_assigner, build_bbox_coder, build_sampler, - images_to_levels, multi_apply, multiclass_nms, unmap) +from mmdet.core import (anchor_inside_flags, build_assigner, build_bbox_coder, + build_prior_generator, build_sampler, images_to_levels, + multi_apply, unmap) from ..builder import HEADS, build_loss from .base_dense_head import BaseDenseHead from .dense_test_mixins import BBoxTestMixin @@ -103,12 +103,18 @@ def __init__(self, self.sampler = build_sampler(sampler_cfg, context=self) self.fp16_enabled = False - self.anchor_generator = build_anchor_generator(anchor_generator) + self.prior_generator = build_prior_generator(anchor_generator) # usually the numbers of anchors for each level are the same # except SSD detectors - self.num_anchors = self.anchor_generator.num_base_anchors[0] + self.num_anchors = self.prior_generator.num_base_priors[0] self._init_layers() + @property + def anchor_generator(self): + warnings.warn('DeprecationWarning: anchor_generator is deprecated, ' + 'please use "prior_generator" instead') + return self.prior_generator + def _init_layers(self): """Initialize layers of the head.""" self.conv_cls = nn.Conv2d(self.in_channels, @@ -168,14 +174,14 @@ def get_anchors(self, featmap_sizes, img_metas, device='cuda'): # since feature map sizes of 
all images are the same, we only compute # anchors for one time - multi_level_anchors = self.anchor_generator.grid_anchors( + multi_level_anchors = self.prior_generator.grid_anchors( featmap_sizes, device) anchor_list = [multi_level_anchors for _ in range(num_imgs)] # for each image, we compute valid flags of multi level anchors valid_flag_list = [] for img_id, img_meta in enumerate(img_metas): - multi_level_flags = self.anchor_generator.valid_flags( + multi_level_flags = self.prior_generator.valid_flags( featmap_sizes, img_meta['pad_shape'], device) valid_flag_list.append(multi_level_flags) @@ -459,7 +465,7 @@ def loss(self, dict[str, Tensor]: A dictionary of loss components. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - assert len(featmap_sizes) == self.anchor_generator.num_levels + assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device @@ -502,240 +508,6 @@ def loss(self, num_total_samples=num_total_samples) return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) - @force_fp32(apply_to=('cls_scores', 'bbox_preds')) - def get_bboxes(self, - cls_scores, - bbox_preds, - img_metas, - cfg=None, - rescale=False, - with_nms=True): - """Transform network output for a batch into bbox predictions. - - Args: - cls_scores (list[Tensor]): Box scores for each level in the - feature pyramid, has shape - (N, num_anchors * num_classes, H, W). - bbox_preds (list[Tensor]): Box energies / deltas for each - level in the feature pyramid, has shape - (N, num_anchors * 4, H, W). - img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. - cfg (mmcv.Config | None): Test / postprocessing configuration, - if None, test_cfg would be used - rescale (bool): If True, return boxes in original image space. - Default: False. - with_nms (bool): If True, do nms before return boxes. - Default: True. - - Returns: - list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. 
- The first item is an (n, 5) tensor, where 5 represent - (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. - The shape of the second tensor in the tuple is (n,), and - each element represents the class label of the corresponding - box. - - Example: - >>> import mmcv - >>> self = AnchorHead( - >>> num_classes=9, - >>> in_channels=1, - >>> anchor_generator=dict( - >>> type='AnchorGenerator', - >>> scales=[8], - >>> ratios=[0.5, 1.0, 2.0], - >>> strides=[4,])) - >>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}] - >>> cfg = mmcv.Config(dict( - >>> score_thr=0.00, - >>> nms=dict(type='nms', iou_thr=1.0), - >>> max_per_img=10)) - >>> feat = torch.rand(1, 1, 3, 3) - >>> cls_score, bbox_pred = self.forward_single(feat) - >>> # note the input lists are over different levels, not images - >>> cls_scores, bbox_preds = [cls_score], [bbox_pred] - >>> result_list = self.get_bboxes(cls_scores, bbox_preds, - >>> img_metas, cfg) - >>> det_bboxes, det_labels = result_list[0] - >>> assert len(result_list) == 1 - >>> assert det_bboxes.shape[1] == 5 - >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img - """ - assert len(cls_scores) == len(bbox_preds) - num_levels = len(cls_scores) - - device = cls_scores[0].device - featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] - mlvl_anchors = self.anchor_generator.grid_anchors( - featmap_sizes, device=device) - - mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] - mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] - - if torch.onnx.is_in_onnx_export(): - assert len( - img_metas - ) == 1, 'Only support one input image while in exporting to ONNX' - img_shapes = img_metas[0]['img_shape_for_onnx'] - else: - img_shapes = [ - img_metas[i]['img_shape'] - for i in range(cls_scores[0].shape[0]) - ] - scale_factors = [ - img_metas[i]['scale_factor'] for i in range(cls_scores[0].shape[0]) - ] - - if with_nms: - # some heads don't support with_nms argument - 
result_list = self._get_bboxes(mlvl_cls_scores, mlvl_bbox_preds, - mlvl_anchors, img_shapes, - scale_factors, cfg, rescale) - else: - result_list = self._get_bboxes(mlvl_cls_scores, mlvl_bbox_preds, - mlvl_anchors, img_shapes, - scale_factors, cfg, rescale, - with_nms) - return result_list - - def _get_bboxes(self, - mlvl_cls_scores, - mlvl_bbox_preds, - mlvl_anchors, - img_shapes, - scale_factors, - cfg, - rescale=False, - with_nms=True): - """Transform outputs for a batch item into bbox predictions. - - Args: - mlvl_cls_scores (list[Tensor]): Each element in the list is - the scores of bboxes of single level in the feature pyramid, - has shape (N, num_anchors * num_classes, H, W). - mlvl_bbox_preds (list[Tensor]): Each element in the list is the - bboxes predictions of single level in the feature pyramid, - has shape (N, num_anchors * 4, H, W). - mlvl_anchors (list[Tensor]): Each element in the list is - the anchors of single level in feature pyramid, has shape - (num_anchors, 4). - img_shapes (list[tuple[int]]): Each tuple in the list represent - the shape(height, width, 3) of single image in the batch. - scale_factors (list[ndarray]): Scale factor of the batch - image arange as list[(w_scale, h_scale, w_scale, h_scale)]. - cfg (mmcv.Config): Test / postprocessing configuration, - if None, test_cfg would be used. - rescale (bool): If True, return boxes in original image space. - Default: False. - with_nms (bool): If True, do nms before return boxes. - Default: True. - - Returns: - list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. - The first item is an (n, 5) tensor, where 5 represent - (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. - The shape of the second tensor in the tuple is (n,), and - each element represents the class label of the corresponding - box. 
- """ - cfg = self.test_cfg if cfg is None else cfg - assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len( - mlvl_anchors) - batch_size = mlvl_cls_scores[0].shape[0] - # convert to tensor to keep tracing - nms_pre_tensor = torch.tensor( - cfg.get('nms_pre', -1), - device=mlvl_cls_scores[0].device, - dtype=torch.long) - - mlvl_bboxes = [] - mlvl_scores = [] - for cls_score, bbox_pred, anchors in zip(mlvl_cls_scores, - mlvl_bbox_preds, - mlvl_anchors): - assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - cls_score = cls_score.permute(0, 2, 3, - 1).reshape(batch_size, -1, - self.cls_out_channels) - if self.use_sigmoid_cls: - scores = cls_score.sigmoid() - else: - scores = cls_score.softmax(-1) - bbox_pred = bbox_pred.permute(0, 2, 3, - 1).reshape(batch_size, -1, 4) - anchors = anchors.expand_as(bbox_pred) - # Always keep topk op for dynamic input in onnx - from mmdet.core.export import get_k_for_topk - nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1]) - if nms_pre > 0: - # Get maximum scores for foreground classes. 
- if self.use_sigmoid_cls: - max_scores, _ = scores.max(-1) - else: - # remind that we set FG labels to [0, num_class-1] - # since mmdet v2.0 - # BG cat_id: num_class - max_scores, _ = scores[..., :-1].max(-1) - - _, topk_inds = max_scores.topk(nms_pre) - batch_inds = torch.arange(batch_size).view( - -1, 1).expand_as(topk_inds) - anchors = anchors[batch_inds, topk_inds, :] - bbox_pred = bbox_pred[batch_inds, topk_inds, :] - scores = scores[batch_inds, topk_inds, :] - - bboxes = self.bbox_coder.decode( - anchors, bbox_pred, max_shape=img_shapes) - mlvl_bboxes.append(bboxes) - mlvl_scores.append(scores) - - batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1) - if rescale: - batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor( - scale_factors).unsqueeze(1) - batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) - - # Replace multiclass_nms with ONNX::NonMaxSuppression in deployment - if torch.onnx.is_in_onnx_export() and with_nms: - from mmdet.core.export import add_dummy_nms_for_onnx - # ignore background class - if not self.use_sigmoid_cls: - num_classes = batch_mlvl_scores.shape[2] - 1 - batch_mlvl_scores = batch_mlvl_scores[..., :num_classes] - max_output_boxes_per_class = cfg.nms.get( - 'max_output_boxes_per_class', 200) - iou_threshold = cfg.nms.get('iou_threshold', 0.5) - score_threshold = cfg.score_thr - nms_pre = cfg.get('deploy_nms_pre', -1) - return add_dummy_nms_for_onnx(batch_mlvl_bboxes, batch_mlvl_scores, - max_output_boxes_per_class, - iou_threshold, score_threshold, - nms_pre, cfg.max_per_img) - if self.use_sigmoid_cls: - # Add a dummy background class to the backend when using sigmoid - # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 - # BG cat_id: num_class - padding = batch_mlvl_scores.new_zeros(batch_size, - batch_mlvl_scores.shape[1], - 1) - batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1) - - if with_nms: - det_results = [] - for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes, - batch_mlvl_scores): - det_bbox, 
det_label = multiclass_nms(mlvl_bboxes, mlvl_scores, - cfg.score_thr, cfg.nms, - cfg.max_per_img) - det_results.append(tuple([det_bbox, det_label])) - else: - det_results = [ - tuple(mlvl_bs) - for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores) - ] - return det_results - def aug_test(self, feats, img_metas, rescale=False): """Test function with test time augmentation. diff --git a/mmdet/models/dense_heads/atss_head.py b/mmdet/models/dense_heads/atss_head.py index 0ebc59791bf..66942804440 100644 --- a/mmdet/models/dense_heads/atss_head.py +++ b/mmdet/models/dense_heads/atss_head.py @@ -5,8 +5,7 @@ from mmcv.runner import force_fp32 from mmdet.core import (anchor_inside_flags, build_assigner, build_sampler, - images_to_levels, multi_apply, multiclass_nms, - reduce_mean, unmap) + images_to_levels, multi_apply, reduce_mean, unmap) from ..builder import HEADS, build_loss from .anchor_head import AnchorHead @@ -28,6 +27,7 @@ def __init__(self, stacked_convs=4, conv_cfg=None, norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), + reg_decoded_bbox=True, loss_centerness=dict( type='CrossEntropyLoss', use_sigmoid=True, @@ -46,7 +46,11 @@ def __init__(self, self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg super(ATSSHead, self).__init__( - num_classes, in_channels, init_cfg=init_cfg, **kwargs) + num_classes, + in_channels, + reg_decoded_bbox=reg_decoded_bbox, + init_cfg=init_cfg, + **kwargs) self.sampling = False if self.train_cfg: @@ -91,7 +95,7 @@ def _init_layers(self): self.atss_centerness = nn.Conv2d( self.feat_channels, self.num_anchors * 1, 3, padding=1) self.scales = nn.ModuleList( - [Scale(1.0) for _ in self.anchor_generator.strides]) + [Scale(1.0) for _ in self.prior_generator.strides]) def forward(self, feats): """Forward features from the upstream network. 
@@ -192,13 +196,11 @@ def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels, pos_anchors, pos_bbox_targets) pos_decode_bbox_pred = self.bbox_coder.decode( pos_anchors, pos_bbox_pred) - pos_decode_bbox_targets = self.bbox_coder.decode( - pos_anchors, pos_bbox_targets) # regression loss loss_bbox = self.loss_bbox( pos_decode_bbox_pred, - pos_decode_bbox_targets, + pos_bbox_targets, weight=centerness_targets, avg_factor=1.0) @@ -245,7 +247,7 @@ def loss(self, dict[str, Tensor]: A dictionary of loss components. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - assert len(featmap_sizes) == self.anchor_generator.num_levels + assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( @@ -291,9 +293,8 @@ def loss(self, loss_bbox=losses_bbox, loss_centerness=loss_centerness) - def centerness_target(self, anchors, bbox_targets): + def centerness_target(self, anchors, gts): # only calculate pos centerness targets, otherwise there may be nan - gts = self.bbox_coder.decode(anchors, bbox_targets) anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 l_ = anchors_cx - gts[:, 0] @@ -309,200 +310,6 @@ def centerness_target(self, anchors, bbox_targets): assert not torch.isnan(centerness).any() return centerness - @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses')) - def get_bboxes(self, - cls_scores, - bbox_preds, - centernesses, - img_metas, - cfg=None, - rescale=False, - with_nms=True): - """Transform network output for a batch into bbox predictions. - - Args: - cls_scores (list[Tensor]): Box scores for each scale level - with shape (N, num_anchors * num_classes, H, W). - bbox_preds (list[Tensor]): Box energies / deltas for each scale - level with shape (N, num_anchors * 4, H, W). - centernesses (list[Tensor]): Centerness for each scale level with - shape (N, num_anchors * 1, H, W). 
- img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. - cfg (mmcv.Config | None): Test / postprocessing configuration, - if None, test_cfg would be used. Default: None. - rescale (bool): If True, return boxes in original image space. - Default: False. - with_nms (bool): If True, do nms before return boxes. - Default: True. - - Returns: - list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. - The first item is an (n, 5) tensor, where 5 represent - (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. - The shape of the second tensor in the tuple is (n,), and - each element represents the class label of the corresponding - box. - """ - cfg = self.test_cfg if cfg is None else cfg - assert len(cls_scores) == len(bbox_preds) - num_levels = len(cls_scores) - device = cls_scores[0].device - featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] - mlvl_anchors = self.anchor_generator.grid_anchors( - featmap_sizes, device=device) - - cls_score_list = [cls_scores[i].detach() for i in range(num_levels)] - bbox_pred_list = [bbox_preds[i].detach() for i in range(num_levels)] - centerness_pred_list = [ - centernesses[i].detach() for i in range(num_levels) - ] - img_shapes = [ - img_metas[i]['img_shape'] for i in range(cls_scores[0].shape[0]) - ] - scale_factors = [ - img_metas[i]['scale_factor'] for i in range(cls_scores[0].shape[0]) - ] - result_list = self._get_bboxes(cls_score_list, bbox_pred_list, - centerness_pred_list, mlvl_anchors, - img_shapes, scale_factors, cfg, rescale, - with_nms) - return result_list - - def _get_bboxes(self, - cls_scores, - bbox_preds, - centernesses, - mlvl_anchors, - img_shapes, - scale_factors, - cfg, - rescale=False, - with_nms=True): - """Transform outputs for a single batch item into labeled boxes. - - Args: - cls_scores (list[Tensor]): Box scores for a single scale level - with shape (N, num_anchors * num_classes, H, W). 
- bbox_preds (list[Tensor]): Box energies / deltas for a single - scale level with shape (N, num_anchors * 4, H, W). - centernesses (list[Tensor]): Centerness for a single scale level - with shape (N, num_anchors * 1, H, W). - mlvl_anchors (list[Tensor]): Box reference for a single scale level - with shape (num_total_anchors, 4). - img_shapes (list[tuple[int]]): Shape of the input image, - list[(height, width, 3)]. - scale_factors (list[ndarray]): Scale factor of the image arrange as - (w_scale, h_scale, w_scale, h_scale). - cfg (mmcv.Config | None): Test / postprocessing configuration, - if None, test_cfg would be used. - rescale (bool): If True, return boxes in original image space. - Default: False. - with_nms (bool): If True, do nms before return boxes. - Default: True. - - Returns: - list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. - The first item is an (n, 5) tensor, where 5 represent - (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. - The shape of the second tensor in the tuple is (n,), and - each element represents the class label of the corresponding - box. 
- """ - assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) - device = cls_scores[0].device - batch_size = cls_scores[0].shape[0] - # convert to tensor to keep tracing - nms_pre_tensor = torch.tensor( - cfg.get('nms_pre', -1), device=device, dtype=torch.long) - mlvl_bboxes = [] - mlvl_scores = [] - mlvl_centerness = [] - for cls_score, bbox_pred, centerness, anchors in zip( - cls_scores, bbox_preds, centernesses, mlvl_anchors): - assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - scores = cls_score.permute(0, 2, 3, 1).reshape( - batch_size, -1, self.cls_out_channels).sigmoid() - centerness = centerness.permute(0, 2, 3, - 1).reshape(batch_size, - -1).sigmoid() - bbox_pred = bbox_pred.permute(0, 2, 3, - 1).reshape(batch_size, -1, 4) - - # Always keep topk op for dynamic input in onnx - if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export() - or scores.shape[-2] > nms_pre_tensor): - from torch import _shape_as_tensor - # keep shape as tensor and get k - num_anchor = _shape_as_tensor(scores)[-2].to(device) - nms_pre = torch.where(nms_pre_tensor < num_anchor, - nms_pre_tensor, num_anchor) - - max_scores, _ = (scores * centerness[..., None]).max(-1) - _, topk_inds = max_scores.topk(nms_pre) - anchors = anchors[topk_inds, :] - batch_inds = torch.arange(batch_size).view( - -1, 1).expand_as(topk_inds).long() - bbox_pred = bbox_pred[batch_inds, topk_inds, :] - scores = scores[batch_inds, topk_inds, :] - centerness = centerness[batch_inds, topk_inds] - else: - anchors = anchors.expand_as(bbox_pred) - - bboxes = self.bbox_coder.decode( - anchors, bbox_pred, max_shape=img_shapes) - mlvl_bboxes.append(bboxes) - mlvl_scores.append(scores) - mlvl_centerness.append(centerness) - - batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1) - if rescale: - batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor( - scale_factors).unsqueeze(1) - batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) - batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1) - - # Set max number of box 
to be feed into nms in deployment - deploy_nms_pre = cfg.get('deploy_nms_pre', -1) - if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export(): - batch_mlvl_scores, _ = ( - batch_mlvl_scores * - batch_mlvl_centerness.unsqueeze(2).expand_as(batch_mlvl_scores) - ).max(-1) - _, topk_inds = batch_mlvl_scores.topk(deploy_nms_pre) - batch_inds = torch.arange(batch_size).view(-1, - 1).expand_as(topk_inds) - batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds, :] - batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds, :] - batch_mlvl_centerness = batch_mlvl_centerness[batch_inds, - topk_inds] - # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 - # BG cat_id: num_class - padding = batch_mlvl_scores.new_zeros(batch_size, - batch_mlvl_scores.shape[1], 1) - batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1) - - if with_nms: - det_results = [] - for (mlvl_bboxes, mlvl_scores, - mlvl_centerness) in zip(batch_mlvl_bboxes, batch_mlvl_scores, - batch_mlvl_centerness): - det_bbox, det_label = multiclass_nms( - mlvl_bboxes, - mlvl_scores, - cfg.score_thr, - cfg.nms, - cfg.max_per_img, - score_factors=mlvl_centerness) - det_results.append(tuple([det_bbox, det_label])) - else: - det_results = [ - tuple(mlvl_bs) - for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores, - batch_mlvl_centerness) - ] - return det_results - def get_targets(self, anchor_list, valid_flag_list, @@ -641,12 +448,12 @@ def _get_target_single(self, pos_inds = sampling_result.pos_inds neg_inds = sampling_result.neg_inds if len(pos_inds) > 0: - if hasattr(self, 'bbox_coder'): + if self.reg_decoded_bbox: + pos_bbox_targets = sampling_result.pos_gt_bboxes + else: pos_bbox_targets = self.bbox_coder.encode( sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes) - else: - # used in VFNetHead - pos_bbox_targets = sampling_result.pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets bbox_weights[pos_inds, :] = 1.0 if gt_labels is None: diff --git 
a/mmdet/models/dense_heads/autoassign_head.py b/mmdet/models/dense_heads/autoassign_head.py index d6d5d831cb1..6d6316400f1 100644 --- a/mmdet/models/dense_heads/autoassign_head.py +++ b/mmdet/models/dense_heads/autoassign_head.py @@ -6,6 +6,7 @@ from mmcv.runner import force_fp32 from mmdet.core import distance2bbox, multi_apply +from mmdet.core.anchor.point_generator import MlvlPointGenerator from mmdet.core.bbox import bbox_overlaps from mmdet.models import HEADS from mmdet.models.dense_heads.atss_head import reduce_mean @@ -159,6 +160,7 @@ def __init__(self, self.pos_loss_weight = pos_loss_weight self.neg_loss_weight = neg_loss_weight self.center_loss_weight = center_loss_weight + self.prior_generator = MlvlPointGenerator(self.strides, offset=0) def init_weights(self): """Initialize weights of the head. diff --git a/mmdet/models/dense_heads/base_dense_head.py b/mmdet/models/dense_heads/base_dense_head.py index 0a2d05219b4..bc62ac3fef2 100644 --- a/mmdet/models/dense_heads/base_dense_head.py +++ b/mmdet/models/dense_heads/base_dense_head.py @@ -1,7 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABCMeta, abstractmethod -from mmcv.runner import BaseModule +import torch +from mmcv.runner import BaseModule, force_fp32 + +from mmdet.core import multiclass_nms +from mmdet.core.utils import select_single_mlvl class BaseDenseHead(BaseModule, metaclass=ABCMeta): @@ -15,10 +19,265 @@ def loss(self, **kwargs): """Compute losses of the head.""" pass - @abstractmethod - def get_bboxes(self, **kwargs): - """Transform network output for a batch into bbox predictions.""" - pass + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def get_bboxes(self, + cls_scores, + bbox_preds, + score_factors=None, + img_metas=None, + cfg=None, + rescale=False, + with_nms=True, + **kwargs): + """Transform network outputs of a batch into bbox results. 
+ + Note: When score_factors is not None, the cls_scores are + usually multiplied by it then obtain the real score used in NMS, + such as CenterNess in FCOS, IoU branch in ATSS. + + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + score_factors (list[Tensor], Optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Default None. + img_metas (list[dict], Optional): Image meta info. Default None. + cfg (mmcv.Config, Optional): Test / postprocessing configuration, + if None, test_cfg would be used. Default None. + rescale (bool): If True, return boxes in original image space. + Default False. + with_nms (bool): If True, do nms before return boxes. + Default True. + + Returns: + list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is an (n, 5) tensor, where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. The second item is a + (n,) tensor where each item is the predicted class label of + the corresponding box. + """ + assert len(cls_scores) == len(bbox_preds) + + if score_factors is None: + # e.g. Retina, FreeAnchor, Foveabox, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, AutoAssign, etc. 
+ with_score_factors = True + assert len(cls_scores) == len(score_factors) + + num_levels = len(cls_scores) + + result_list = [] + + for img_id in range(len(img_metas)): + img_meta = img_metas[img_id] + cls_score_list = select_single_mlvl(cls_scores, img_id) + bbox_pred_list = select_single_mlvl(bbox_preds, img_id) + if with_score_factors: + score_factor_list = select_single_mlvl(score_factors, img_id) + else: + score_factor_list = [None for _ in range(num_levels)] + + results = self._get_bboxes_single(cls_score_list, bbox_pred_list, + score_factor_list, img_meta, cfg, + rescale, with_nms, **kwargs) + result_list.append(results) + return result_list + + def _get_bboxes_single(self, + cls_score_list, + bbox_pred_list, + score_factor_list, + img_meta, + cfg, + rescale=False, + with_nms=True, + **kwargs): + """Transform outputs of a single image into bbox predictions. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + img_meta (dict): Image meta info. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + tuple[Tensor]: Results of detected bboxes and labels. If with_nms + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. Usually with_nms is False is used for aug + test. 
If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape \ + [num_bbox, 5], where the first 4 columns are bounding box \ + positions (tl_x, tl_y, br_x, br_y) and the 5-th column \ + are scores between 0 and 1. + - det_labels (Tensor): Predicted labels of the corresponding \ + box with shape [num_bbox]. + """ + if score_factor_list[0] is None: + # e.g. Retina, FreeAnchor, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, etc. + with_score_factors = True + + cfg = self.test_cfg if cfg is None else cfg + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bboxes = [] + mlvl_scores = [] + if with_score_factors: + mlvl_score_factors = [] + else: + mlvl_score_factors = None + for level_idx, (cls_score, bbox_pred, score_factor) in enumerate( + zip(cls_score_list, bbox_pred_list, score_factor_list)): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + featmap_size_hw = cls_score.shape[-2:] + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + scores_ = scores + else: + scores = cls_score.softmax(-1) + # remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + scores_ = scores[:, :-1] + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + if with_score_factors: + score_factor = score_factor.permute(1, 2, + 0).reshape(-1).sigmoid() + + if 0 < nms_pre < scores.shape[0]: + # Get maximum scores for foreground classes. 
+ if with_score_factors: + max_scores, _ = (scores_ * + score_factor[:, None]).max(dim=1) + else: + max_scores, _ = scores_.max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + + priors = self.prior_generator.sparse_priors( + topk_inds, featmap_size_hw, level_idx, scores.dtype, + scores.device) + bbox_pred = bbox_pred[topk_inds, :] + scores = scores[topk_inds, :] + if with_score_factors: + score_factor = score_factor[topk_inds] + else: + priors = self.prior_generator.single_level_grid_priors( + featmap_size_hw, level_idx, scores.device) + + bboxes = self.bbox_coder.decode( + priors, bbox_pred, max_shape=img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + if with_score_factors: + mlvl_score_factors.append(score_factor) + + return self._bbox_post_process(mlvl_scores, mlvl_bboxes, + img_meta['scale_factor'], cfg, rescale, + with_nms, mlvl_score_factors, **kwargs) + + def _bbox_post_process(self, + mlvl_scores, + mlvl_bboxes, + scale_factor, + cfg, + rescale=False, + with_nms=True, + mlvl_score_factors=None, + **kwargs): + """bbox post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. Usually with_nms is False is used for aug test. + + Args: + mlvl_scores (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num, num_class). + mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale + levels of a single image, each item has shape (num, 4). + scale_factor (ndarray, optional): Scale factor of the image arange + as (w_scale, h_scale, w_scale, h_scale). + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + mlvl_score_factors (list[Tensor], optional): Score factor from + all scale levels of a single image, each item has shape + (num, ). Default: None. 
+ + Returns: + tuple[Tensor]: Results of detected bboxes and labels. If with_nms + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. Usually with_nms is False is used for aug + test. If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape \ + [num_bbox, 5], where the first 4 columns are bounding box \ + positions (tl_x, tl_y, br_x, br_y) and the 5-th column \ + are scores between 0 and 1. + - det_labels (Tensor): Predicted labels of the corresponding \ + box with shape [num_bbox]. + """ + assert len(mlvl_scores) == len(mlvl_bboxes) + + mlvl_bboxes = torch.cat(mlvl_bboxes) + if rescale: + mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) + mlvl_scores = torch.cat(mlvl_scores) + + if mlvl_score_factors is not None: + mlvl_score_factors = torch.cat(mlvl_score_factors) + + if self.use_sigmoid_cls: + # Add a dummy background class to the backend when using sigmoid + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class + padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + + if with_nms: + det_bboxes, det_labels = multiclass_nms( + mlvl_bboxes, + mlvl_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + score_factors=mlvl_score_factors) + return det_bboxes, det_labels + else: + if mlvl_score_factors is not None: + return mlvl_bboxes, mlvl_scores, mlvl_score_factors + else: + return mlvl_bboxes, mlvl_scores def forward_train(self, x, @@ -56,7 +315,8 @@ def forward_train(self, if proposal_cfg is None: return losses else: - proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg) + proposal_list = self.get_bboxes( + *outs, img_metas=img_metas, cfg=proposal_cfg) return losses, proposal_list def simple_test(self, feats, img_metas, rescale=False): @@ -74,6 +334,6 @@ def simple_test(self, feats, img_metas, 
rescale=False): The first item is ``bboxes`` with shape (n, 5), where 5 represent (tl_x, tl_y, br_x, br_y, score). The shape of the second tensor in the tuple is ``labels`` - with shape (n,) + with shape (n, ). """ return self.simple_test_bboxes(feats, img_metas, rescale=rescale) diff --git a/mmdet/models/dense_heads/cascade_rpn_head.py b/mmdet/models/dense_heads/cascade_rpn_head.py index f6257b7fca2..c4facc58d63 100644 --- a/mmdet/models/dense_heads/cascade_rpn_head.py +++ b/mmdet/models/dense_heads/cascade_rpn_head.py @@ -508,7 +508,27 @@ def get_bboxes(self, img_metas, cfg, rescale=False): - """Get proposal predict.""" + """Get proposal predict. + + Args: + anchor_list (list[list]): Multi level anchors of each image. + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + img_metas (list[dict], Optional): Image meta info. Default None. + cfg (mmcv.Config, Optional): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + + Returns: + Tensor: Labeled boxes in shape (n, 5), where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. 
+ """ assert len(cls_scores) == len(bbox_preds) num_levels = len(cls_scores) @@ -528,23 +548,6 @@ def get_bboxes(self, result_list.append(proposals) return result_list - def refine_bboxes(self, anchor_list, bbox_preds, img_metas): - """Refine bboxes through stages.""" - num_levels = len(bbox_preds) - new_anchor_list = [] - for img_id in range(len(img_metas)): - mlvl_anchors = [] - for i in range(num_levels): - bbox_pred = bbox_preds[i][img_id].detach() - bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) - img_shape = img_metas[img_id]['img_shape'] - bboxes = self.bbox_coder.decode(anchor_list[img_id][i], - bbox_pred, img_shape) - mlvl_anchors.append(bboxes) - new_anchor_list.append(mlvl_anchors) - return new_anchor_list - - # TODO: temporary plan def _get_bboxes_single(self, cls_scores, bbox_preds, @@ -553,15 +556,18 @@ def _get_bboxes_single(self, scale_factor, cfg, rescale=False): - """Transform outputs for a single batch item into bbox predictions. + """Transform outputs of a single image into bbox predictions. Args: - cls_scores (list[Tensor]): Box scores for each scale level - Has shape (num_anchors * num_classes, H, W). - bbox_preds (list[Tensor]): Box energies / deltas for each scale - level with shape (num_anchors * 4, H, W). - mlvl_anchors (list[Tensor]): Box reference for each scale level - with shape (num_total_anchors, 4). + cls_scores (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has + shape (num_anchors * 4, H, W). + mlvl_anchors (list[Tensor]): Box reference from all scale + levels of a single image, each item has shape + (num_total_anchors, 4). img_shape (tuple[int]): Shape of the input image, (height, width, 3). 
scale_factor (ndarray): Scale factor of the image arange as @@ -569,12 +575,12 @@ def _get_bboxes_single(self, cfg (mmcv.Config): Test / postprocessing configuration, if None, test_cfg would be used. rescale (bool): If True, return boxes in original image space. + Default False. Returns: - Tensor: Labeled boxes have the shape of (n,5), where the - first 4 columns are bounding box positions - (tl_x, tl_y, br_x, br_y) and the 5-th column is a score - between 0 and 1. + Tensor: Labeled boxes in shape (n, 5), where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. """ cfg = self.test_cfg if cfg is None else cfg cfg = copy.deepcopy(cfg) @@ -584,6 +590,7 @@ def _get_bboxes_single(self, mlvl_scores = [] mlvl_bbox_preds = [] mlvl_valid_anchors = [] + nms_pre = cfg.get('nms_pre', -1) for idx in range(len(cls_scores)): rpn_cls_score = cls_scores[idx] rpn_bbox_pred = bbox_preds[idx] @@ -601,18 +608,13 @@ def _get_bboxes_single(self, scores = rpn_cls_score.softmax(dim=1)[:, 0] rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4) anchors = mlvl_anchors[idx] - if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre: + + if 0 < nms_pre < scores.shape[0]: # sort is faster than topk # _, topk_inds = scores.topk(cfg.nms_pre) - if torch.onnx.is_in_onnx_export(): - # sort op will be converted to TopK in onnx - # and k<=3480 in TensorRT - _, topk_inds = scores.topk(cfg.nms_pre) - scores = scores[topk_inds] - else: - ranked_scores, rank_inds = scores.sort(descending=True) - topk_inds = rank_inds[:cfg.nms_pre] - scores = ranked_scores[:cfg.nms_pre] + ranked_scores, rank_inds = scores.sort(descending=True) + topk_inds = rank_inds[:nms_pre] + scores = ranked_scores[:nms_pre] rpn_bbox_pred = rpn_bbox_pred[topk_inds, :] anchors = anchors[topk_inds, :] mlvl_scores.append(scores) @@ -628,8 +630,7 @@ def _get_bboxes_single(self, anchors, rpn_bbox_pred, max_shape=img_shape) ids = torch.cat(level_ids) - # Skip nonzero 
op while exporting to ONNX - if cfg.min_bbox_size >= 0 and (not torch.onnx.is_in_onnx_export()): + if cfg.min_bbox_size >= 0: w = proposals[:, 2] - proposals[:, 0] h = proposals[:, 3] - proposals[:, 1] valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) @@ -665,9 +666,29 @@ def _get_bboxes_single(self, f' respectively. Please delete the nms_thr ' \ f'which will be deprecated.' - dets, keep = batched_nms(proposals, scores, ids, cfg.nms) + if proposals.numel() > 0: + dets, keep = batched_nms(proposals, scores, ids, cfg.nms) + else: + return proposals.new_zeros(0, 5) + return dets[:cfg.max_per_img] + def refine_bboxes(self, anchor_list, bbox_preds, img_metas): + """Refine bboxes through stages.""" + num_levels = len(bbox_preds) + new_anchor_list = [] + for img_id in range(len(img_metas)): + mlvl_anchors = [] + for i in range(num_levels): + bbox_pred = bbox_preds[i][img_id].detach() + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + img_shape = img_metas[img_id]['img_shape'] + bboxes = self.bbox_coder.decode(anchor_list[img_id][i], + bbox_pred, img_shape) + mlvl_anchors.append(bboxes) + new_anchor_list.append(mlvl_anchors) + return new_anchor_list + @HEADS.register_module() class CascadeRPNHead(BaseDenseHead): diff --git a/mmdet/models/dense_heads/centernet_head.py b/mmdet/models/dense_heads/centernet_head.py index e5b2c2bd45f..b9d5d2f01fb 100644 --- a/mmdet/models/dense_heads/centernet_head.py +++ b/mmdet/models/dense_heads/centernet_head.py @@ -248,6 +248,7 @@ def get_targets(self, gt_bboxes, gt_labels, feat_shape, img_shape): wh_offset_target_weight=wh_offset_target_weight) return target_result, avg_factor + @force_fp32(apply_to=('center_heatmap_preds', 'wh_preds', 'offset_preds')) def get_bboxes(self, center_heatmap_preds, wh_preds, @@ -258,11 +259,11 @@ def get_bboxes(self, """Transform network output for a batch into bbox predictions. 
Args: - center_heatmap_preds (list[Tensor]): center predict heatmaps for + center_heatmap_preds (list[Tensor]): Center predict heatmaps for all levels with shape (B, num_classes, H, W). - wh_preds (list[Tensor]): wh predicts for all levels with + wh_preds (list[Tensor]): WH predicts for all levels with shape (B, 2, H, W). - offset_preds (list[Tensor]): offset predicts for all levels + offset_preds (list[Tensor]): Offset predicts for all levels with shape (B, 2, H, W). img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. @@ -281,37 +282,71 @@ def get_bboxes(self, """ assert len(center_heatmap_preds) == len(wh_preds) == len( offset_preds) == 1 - scale_factors = [img_meta['scale_factor'] for img_meta in img_metas] - border_pixs = [img_meta['border'] for img_meta in img_metas] + result_list = [] + for img_id in range(len(img_metas)): + result_list.append( + self._get_bboxes_single( + center_heatmap_preds[0][img_id:img_id + 1, ...], + wh_preds[0][img_id:img_id + 1, ...], + offset_preds[0][img_id:img_id + 1, ...], + img_metas[img_id], + rescale=rescale, + with_nms=with_nms)) + return result_list + + def _get_bboxes_single(self, + center_heatmap_pred, + wh_pred, + offset_pred, + img_meta, + rescale=False, + with_nms=True): + """Transform outputs of a single image into bbox results. + Args: + center_heatmap_pred (Tensor): Center heatmap for current level with + shape (1, num_classes, H, W). + wh_pred (Tensor): WH heatmap for current level with shape + (1, num_classes, H, W). + offset_pred (Tensor): Offset for current level with shape + (1, corner_offset_channels, H, W). + img_meta (dict): Meta information of current image, e.g., + image size, scaling factor, etc. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. 
+ + Returns: + tuple[Tensor, Tensor]: The first item is an (n, 5) tensor, where + 5 represent (tl_x, tl_y, br_x, br_y, score) and the score + between 0 and 1. The shape of the second tensor in the tuple + is (n,), and each element represents the class label of the + corresponding box. + """ batch_det_bboxes, batch_labels = self.decode_heatmap( - center_heatmap_preds[0], - wh_preds[0], - offset_preds[0], - img_metas[0]['batch_input_shape'], + center_heatmap_pred, + wh_pred, + offset_pred, + img_meta['batch_input_shape'], k=self.test_cfg.topk, kernel=self.test_cfg.local_maximum_kernel) - batch_border = batch_det_bboxes.new_tensor( - border_pixs)[:, [2, 0, 2, 0]].unsqueeze(1) - batch_det_bboxes[..., :4] -= batch_border + det_bboxes = batch_det_bboxes.view([-1, 5]) + det_labels = batch_labels.view(-1) + + batch_border = det_bboxes.new_tensor(img_meta['border'])[..., + [2, 0, 2, 0]] + det_bboxes[..., :4] -= batch_border if rescale: - batch_det_bboxes[..., :4] /= batch_det_bboxes.new_tensor( - scale_factors).unsqueeze(1) + det_bboxes[..., :4] /= det_bboxes.new_tensor( + img_meta['scale_factor']) if with_nms: - det_results = [] - for (det_bboxes, det_labels) in zip(batch_det_bboxes, - batch_labels): - det_bbox, det_label = self._bboxes_nms(det_bboxes, det_labels, - self.test_cfg) - det_results.append(tuple([det_bbox, det_label])) - else: - det_results = [ - tuple(bs) for bs in zip(batch_det_bboxes, batch_labels) - ] - return det_results + det_bboxes, det_labels = self._bboxes_nms(det_bboxes, det_labels, + self.test_cfg) + return det_bboxes, det_labels def decode_heatmap(self, center_heatmap_pred, @@ -365,18 +400,13 @@ def decode_heatmap(self, return batch_bboxes, batch_topk_labels def _bboxes_nms(self, bboxes, labels, cfg): - if labels.numel() == 0: - return bboxes, labels - - out_bboxes, keep = batched_nms(bboxes[:, :4].contiguous(), - bboxes[:, -1].contiguous(), labels, - cfg.nms_cfg) - out_labels = labels[keep] - - if len(out_bboxes) > 0: - idx = 
torch.argsort(out_bboxes[:, -1], descending=True) - idx = idx[:cfg.max_per_img] - out_bboxes = out_bboxes[idx] - out_labels = out_labels[idx] - - return out_bboxes, out_labels + if labels.numel() > 0: + max_num = cfg.max_per_img + bboxes, keep = batched_nms(bboxes[:, :4], bboxes[:, + -1].contiguous(), + labels, cfg.nms) + if max_num > 0: + bboxes = bboxes[:max_num] + labels = labels[keep][:max_num] + + return bboxes, labels diff --git a/mmdet/models/dense_heads/corner_head.py b/mmdet/models/dense_heads/corner_head.py index ff9a97f720a..184fd48e382 100644 --- a/mmdet/models/dense_heads/corner_head.py +++ b/mmdet/models/dense_heads/corner_head.py @@ -766,15 +766,8 @@ def _get_bboxes_single(self, batch_bboxes /= batch_bboxes.new_tensor(img_meta['scale_factor']) bboxes = batch_bboxes.view([-1, 4]) - scores = batch_scores.view([-1, 1]) - clses = batch_clses.view([-1, 1]) - - # use `sort` instead of `argsort` here, since currently exporting - # `argsort` to ONNX opset version 11 is not supported - scores, idx = scores.sort(dim=0, descending=True) - bboxes = bboxes[idx].view([-1, 4]) - scores = scores.view(-1) - clses = clses[idx].view(-1) + scores = batch_scores.view(-1) + clses = batch_clses.view(-1) detections = torch.cat([bboxes, scores.unsqueeze(-1)], -1) keepinds = (detections[:, -1] > -0.1) @@ -788,33 +781,22 @@ def _get_bboxes_single(self, return detections, labels def _bboxes_nms(self, bboxes, labels, cfg): - if labels.numel() == 0: - return bboxes, labels - if 'nms_cfg' in cfg: warning.warn('nms_cfg in test_cfg will be deprecated. 
' 'Please rename it as nms') if 'nms' not in cfg: cfg.nms = cfg.nms_cfg - out_bboxes, keep = batched_nms(bboxes[:, :4], bboxes[:, -1], labels, - cfg.nms) - out_labels = labels[keep] - - if len(out_bboxes) > 0: - # use `sort` to replace with `argsort` here - _, idx = torch.sort(out_bboxes[:, -1], descending=True) - max_per_img = out_bboxes.new_tensor(cfg.max_per_img).to(torch.long) - nms_after = max_per_img - if torch.onnx.is_in_onnx_export(): - # Always keep topk op for dynamic input in onnx - from mmdet.core.export import get_k_for_topk - nms_after = get_k_for_topk(max_per_img, out_bboxes.shape[0]) - idx = idx[:nms_after] - out_bboxes = out_bboxes[idx] - out_labels = out_labels[idx] - - return out_bboxes, out_labels + if labels.numel() > 0: + max_num = cfg.max_per_img + bboxes, keep = batched_nms(bboxes[:, :4], bboxes[:, + -1].contiguous(), + labels, cfg.nms) + if max_num > 0: + bboxes = bboxes[:max_num] + labels = labels[keep][:max_num] + + return bboxes, labels def decode_heatmap(self, tl_heat, diff --git a/mmdet/models/dense_heads/dense_test_mixins.py b/mmdet/models/dense_heads/dense_test_mixins.py index cb9df1a987e..56407fd5465 100644 --- a/mmdet/models/dense_heads/dense_test_mixins.py +++ b/mmdet/models/dense_heads/dense_test_mixins.py @@ -33,7 +33,8 @@ def simple_test_bboxes(self, feats, img_metas, rescale=False): with shape (n,) """ outs = self.forward(feats) - results_list = self.get_bboxes(*outs, img_metas, rescale=rescale) + results_list = self.get_bboxes( + *outs, img_metas=img_metas, rescale=rescale) return results_list def aug_test_bboxes(self, feats, img_metas, rescale=False): @@ -61,10 +62,7 @@ def aug_test_bboxes(self, feats, img_metas, rescale=False): # check with_nms argument gb_sig = signature(self.get_bboxes) gb_args = [p.name for p in gb_sig.parameters.values()] - if hasattr(self, '_get_bboxes'): - gbs_sig = signature(self._get_bboxes) - else: - gbs_sig = signature(self._get_bboxes_single) + gbs_sig = signature(self._get_bboxes_single) 
gbs_args = [p.name for p in gbs_sig.parameters.values()] assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \ f'{self.__class__.__name__}' \ @@ -76,8 +74,12 @@ def aug_test_bboxes(self, feats, img_metas, rescale=False): for x, img_meta in zip(feats, img_metas): # only one image in the batch outs = self.forward(x) - bbox_inputs = outs + (img_meta, self.test_cfg, False, False) - bbox_outputs = self.get_bboxes(*bbox_inputs)[0] + bbox_outputs = self.get_bboxes( + *outs, + img_metas=img_meta, + cfg=self.test_cfg, + rescale=False, + with_nms=False)[0] aug_bboxes.append(bbox_outputs[0]) aug_scores.append(bbox_outputs[1]) # bbox_outputs of some detectors (e.g., ATSS, FCOS, YOLOv3) @@ -122,7 +124,7 @@ def simple_test_rpn(self, x, img_metas): where 5 represent (tl_x, tl_y, br_x, br_y, score). """ rpn_outs = self(x) - proposal_list = self.get_bboxes(*rpn_outs, img_metas) + proposal_list = self.get_bboxes(*rpn_outs, img_metas=img_metas) return proposal_list def aug_test_rpn(self, feats, img_metas): @@ -168,7 +170,7 @@ async def async_simple_test_rpn(self, x, img_metas): sleep_interval=sleep_interval): rpn_outs = self(x) - proposal_list = self.get_bboxes(*rpn_outs, img_metas) + proposal_list = self.get_bboxes(*rpn_outs, img_metas=img_metas) return proposal_list def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas): diff --git a/mmdet/models/dense_heads/fcos_head.py b/mmdet/models/dense_heads/fcos_head.py index 8eedd788113..dff7c3740d8 100644 --- a/mmdet/models/dense_heads/fcos_head.py +++ b/mmdet/models/dense_heads/fcos_head.py @@ -5,7 +5,7 @@ from mmcv.cnn import Scale from mmcv.runner import force_fp32 -from mmdet.core import distance2bbox, multi_apply, multiclass_nms, reduce_mean +from mmdet.core import multi_apply, reduce_mean from ..builder import HEADS, build_loss from .anchor_free_head import AnchorFreeHead @@ -241,9 +241,10 @@ def loss(self, if len(pos_inds) > 0: pos_points = flatten_points[pos_inds] - pos_decoded_bbox_preds = distance2bbox(pos_points, 
pos_bbox_preds) - pos_decoded_target_preds = distance2bbox(pos_points, - pos_bbox_targets) + pos_decoded_bbox_preds = self.bbox_coder.decode( + pos_points, pos_bbox_preds) + pos_decoded_target_preds = self.bbox_coder.decode( + pos_points, pos_bbox_targets) loss_bbox = self.loss_bbox( pos_decoded_bbox_preds, pos_decoded_target_preds, @@ -260,216 +261,6 @@ def loss(self, loss_bbox=loss_bbox, loss_centerness=loss_centerness) - @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses')) - def get_bboxes(self, - cls_scores, - bbox_preds, - centernesses, - img_metas, - cfg=None, - rescale=False, - with_nms=True): - """Transform network output for a batch into bbox predictions. - - Args: - cls_scores (list[Tensor]): Box scores for each scale level - with shape (N, num_points * num_classes, H, W). - bbox_preds (list[Tensor]): Box energies / deltas for each scale - level with shape (N, num_points * 4, H, W). - centernesses (list[Tensor]): Centerness for each scale level with - shape (N, num_points * 1, H, W). - img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. - cfg (mmcv.Config | None): Test / postprocessing configuration, - if None, test_cfg would be used. Default: None. - rescale (bool): If True, return boxes in original image space. - Default: False. - with_nms (bool): If True, do nms before return boxes. - Default: True. - - Returns: - list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. - The first item is an (n, 5) tensor, where 5 represent - (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. - The shape of the second tensor in the tuple is (n,), and - each element represents the class label of the corresponding - box. 
- """ - assert len(cls_scores) == len(bbox_preds) - num_levels = len(cls_scores) - - featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype, - bbox_preds[0].device) - - cls_score_list = [cls_scores[i].detach() for i in range(num_levels)] - bbox_pred_list = [bbox_preds[i].detach() for i in range(num_levels)] - centerness_pred_list = [ - centernesses[i].detach() for i in range(num_levels) - ] - if torch.onnx.is_in_onnx_export(): - assert len( - img_metas - ) == 1, 'Only support one input image while in exporting to ONNX' - img_shapes = img_metas[0]['img_shape_for_onnx'] - else: - img_shapes = [ - img_metas[i]['img_shape'] - for i in range(cls_scores[0].shape[0]) - ] - scale_factors = [ - img_metas[i]['scale_factor'] for i in range(cls_scores[0].shape[0]) - ] - result_list = self._get_bboxes(cls_score_list, bbox_pred_list, - centerness_pred_list, mlvl_points, - img_shapes, scale_factors, cfg, rescale, - with_nms) - return result_list - - def _get_bboxes(self, - cls_scores, - bbox_preds, - centernesses, - mlvl_points, - img_shapes, - scale_factors, - cfg, - rescale=False, - with_nms=True): - """Transform outputs for a single batch item into bbox predictions. - - Args: - cls_scores (list[Tensor]): Box scores for a single scale level - with shape (N, num_points * num_classes, H, W). - bbox_preds (list[Tensor]): Box energies / deltas for a single scale - level with shape (N, num_points * 4, H, W). - centernesses (list[Tensor]): Centerness for a single scale level - with shape (N, num_points, H, W). - mlvl_points (list[Tensor]): Box reference for a single scale level - with shape (num_total_points, 4). - img_shapes (list[tuple[int]]): Shape of the input image, - list[(height, width, 3)]. - scale_factors (list[ndarray]): Scale factor of the image arrange as - (w_scale, h_scale, w_scale, h_scale). - cfg (mmcv.Config | None): Test / postprocessing configuration, - if None, test_cfg would be used. 
- rescale (bool): If True, return boxes in original image space. - Default: False. - with_nms (bool): If True, do nms before return boxes. - Default: True. - - Returns: - tuple(Tensor): - det_bboxes (Tensor): BBox predictions in shape (n, 5), where - the first 4 columns are bounding box positions - (tl_x, tl_y, br_x, br_y) and the 5-th column is a score - between 0 and 1. - det_labels (Tensor): A (n,) tensor where each item is the - predicted class label of the corresponding box. - """ - cfg = self.test_cfg if cfg is None else cfg - assert len(cls_scores) == len(bbox_preds) == len(mlvl_points) - device = cls_scores[0].device - batch_size = cls_scores[0].shape[0] - # convert to tensor to keep tracing - nms_pre_tensor = torch.tensor( - cfg.get('nms_pre', -1), device=device, dtype=torch.long) - mlvl_bboxes = [] - mlvl_scores = [] - mlvl_centerness = [] - for cls_score, bbox_pred, centerness, points in zip( - cls_scores, bbox_preds, centernesses, mlvl_points): - assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - scores = cls_score.permute(0, 2, 3, 1).reshape( - batch_size, -1, self.cls_out_channels).sigmoid() - centerness = centerness.permute(0, 2, 3, - 1).reshape(batch_size, - -1).sigmoid() - - bbox_pred = bbox_pred.permute(0, 2, 3, - 1).reshape(batch_size, -1, 4) - points = points.expand(batch_size, -1, 2) - # Get top-k prediction - from mmdet.core.export import get_k_for_topk - nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1]) - if nms_pre > 0: - max_scores, _ = (scores * centerness[..., None]).max(-1) - _, topk_inds = max_scores.topk(nms_pre) - batch_inds = torch.arange(batch_size).view( - -1, 1).expand_as(topk_inds).long() - # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 - if torch.onnx.is_in_onnx_export(): - transformed_inds = bbox_pred.shape[ - 1] * batch_inds + topk_inds - points = points.reshape(-1, - 2)[transformed_inds, :].reshape( - batch_size, -1, 2) - bbox_pred = bbox_pred.reshape( - -1, 
4)[transformed_inds, :].reshape(batch_size, -1, 4) - scores = scores.reshape( - -1, self.num_classes)[transformed_inds, :].reshape( - batch_size, -1, self.num_classes) - centerness = centerness.reshape( - -1, 1)[transformed_inds].reshape(batch_size, -1) - else: - points = points[batch_inds, topk_inds, :] - bbox_pred = bbox_pred[batch_inds, topk_inds, :] - scores = scores[batch_inds, topk_inds, :] - centerness = centerness[batch_inds, topk_inds] - - bboxes = distance2bbox(points, bbox_pred, max_shape=img_shapes) - mlvl_bboxes.append(bboxes) - mlvl_scores.append(scores) - mlvl_centerness.append(centerness) - - batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1) - if rescale: - batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor( - scale_factors).unsqueeze(1) - batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) - batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1) - - # Replace multiclass_nms with ONNX::NonMaxSuppression in deployment - if torch.onnx.is_in_onnx_export() and with_nms: - from mmdet.core.export import add_dummy_nms_for_onnx - batch_mlvl_scores = batch_mlvl_scores * ( - batch_mlvl_centerness.unsqueeze(2)) - max_output_boxes_per_class = cfg.nms.get( - 'max_output_boxes_per_class', 200) - iou_threshold = cfg.nms.get('iou_threshold', 0.5) - score_threshold = cfg.score_thr - nms_pre = cfg.get('deploy_nms_pre', -1) - return add_dummy_nms_for_onnx(batch_mlvl_bboxes, batch_mlvl_scores, - max_output_boxes_per_class, - iou_threshold, score_threshold, - nms_pre, cfg.max_per_img) - # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 - # BG cat_id: num_class - padding = batch_mlvl_scores.new_zeros(batch_size, - batch_mlvl_scores.shape[1], 1) - batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1) - - if with_nms: - det_results = [] - for (mlvl_bboxes, mlvl_scores, - mlvl_centerness) in zip(batch_mlvl_bboxes, batch_mlvl_scores, - batch_mlvl_centerness): - det_bbox, det_label = multiclass_nms( - mlvl_bboxes, - mlvl_scores, - 
cfg.score_thr, - cfg.nms, - cfg.max_per_img, - score_factors=mlvl_centerness) - det_results.append(tuple([det_bbox, det_label])) - else: - det_results = [ - tuple(mlvl_bs) - for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores, - batch_mlvl_centerness) - ] - return det_results - def _get_points_single(self, featmap_size, stride, diff --git a/mmdet/models/dense_heads/fovea_head.py b/mmdet/models/dense_heads/fovea_head.py index 5dab829ef36..0173c2f52c7 100644 --- a/mmdet/models/dense_heads/fovea_head.py +++ b/mmdet/models/dense_heads/fovea_head.py @@ -5,7 +5,7 @@ from mmcv.ops import DeformConv2d from mmcv.runner import BaseModule -from mmdet.core import multi_apply, multiclass_nms +from mmdet.core import multi_apply from ..builder import HEADS from .anchor_free_head import AnchorFreeHead @@ -265,85 +265,97 @@ def _get_target_single(self, bbox_target_list.append(torch.log(bbox_targets)) return label_list, bbox_target_list - def get_bboxes(self, - cls_scores, - bbox_preds, - img_metas, - cfg=None, - rescale=None): - assert len(cls_scores) == len(bbox_preds) - num_levels = len(cls_scores) - featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - points = self.get_points( - featmap_sizes, - bbox_preds[0].dtype, - bbox_preds[0].device, - flatten=True) - result_list = [] - for img_id in range(len(img_metas)): - cls_score_list = [ - cls_scores[i][img_id].detach() for i in range(num_levels) - ] - bbox_pred_list = [ - bbox_preds[i][img_id].detach() for i in range(num_levels) - ] - img_shape = img_metas[img_id]['img_shape'] - scale_factor = img_metas[img_id]['scale_factor'] - det_bboxes = self._get_bboxes_single(cls_score_list, - bbox_pred_list, featmap_sizes, - points, img_shape, - scale_factor, cfg, rescale) - result_list.append(det_bboxes) - return result_list - + # Same as base_dense_head/_get_bboxes_single except self._bbox_decode def _get_bboxes_single(self, - cls_scores, - bbox_preds, - featmap_sizes, - point_list, - img_shape, - scale_factor, + 
cls_score_list, + bbox_pred_list, + score_factor_list, + img_meta, cfg, - rescale=False): + rescale=False, + with_nms=True, + **kwargs): + """Transform outputs of a single image into bbox predictions. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image. Fovea head does not need this value. + img_meta (dict): Image meta info. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + tuple[Tensor]: Results of detected bboxes and labels. If with_nms + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. Usually with_nms is False is used for aug + test. If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape \ + [num_bbox, 5], where the first 4 columns are bounding box \ + positions (tl_x, tl_y, br_x, br_y) and the 5-th column \ + are scores between 0 and 1. + - det_labels (Tensor): Predicted labels of the corresponding \ + box with shape [num_bbox]. 
+ """ cfg = self.test_cfg if cfg is None else cfg - assert len(cls_scores) == len(bbox_preds) == len(point_list) + assert len(cls_score_list) == len(bbox_pred_list) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + det_bboxes = [] det_scores = [] - for cls_score, bbox_pred, featmap_size, stride, base_len, (y, x) \ - in zip(cls_scores, bbox_preds, featmap_sizes, self.strides, - self.base_edge_list, point_list): + for level_idx, (cls_score, bbox_pred, stride, base_len) in enumerate( + zip(cls_score_list, bbox_pred_list, self.strides, + self.base_edge_list)): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + featmap_size_hw = cls_score.shape[-2:] scores = cls_score.permute(1, 2, 0).reshape( -1, self.cls_out_channels).sigmoid() - bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4).exp() - nms_pre = cfg.get('nms_pre', -1) - if (nms_pre > 0) and (scores.shape[0] > nms_pre): + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + if 0 < nms_pre < scores.shape[0]: max_scores, _ = scores.max(dim=1) _, topk_inds = max_scores.topk(nms_pre) + priors = self.prior_generator.sparse_priors( + topk_inds, featmap_size_hw, level_idx, scores.dtype, + scores.device) bbox_pred = bbox_pred[topk_inds, :] scores = scores[topk_inds, :] - y = y[topk_inds] - x = x[topk_inds] - x1 = (stride * x - base_len * bbox_pred[:, 0]). \ - clamp(min=0, max=img_shape[1] - 1) - y1 = (stride * y - base_len * bbox_pred[:, 1]). \ - clamp(min=0, max=img_shape[0] - 1) - x2 = (stride * x + base_len * bbox_pred[:, 2]). \ - clamp(min=0, max=img_shape[1] - 1) - y2 = (stride * y + base_len * bbox_pred[:, 3]). 
\ - clamp(min=0, max=img_shape[0] - 1) - bboxes = torch.stack([x1, y1, x2, y2], -1) + else: + priors = self.prior_generator.single_level_grid_priors( + featmap_size_hw, level_idx, scores.device) + + bboxes = self._bbox_decode(priors, bbox_pred, base_len, img_shape) + det_bboxes.append(bboxes) det_scores.append(scores) - det_bboxes = torch.cat(det_bboxes) - if rescale: - det_bboxes /= det_bboxes.new_tensor(scale_factor) - det_scores = torch.cat(det_scores) - padding = det_scores.new_zeros(det_scores.shape[0], 1) - # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 - # BG cat_id: num_class - det_scores = torch.cat([det_scores, padding], dim=1) - det_bboxes, det_labels = multiclass_nms(det_bboxes, det_scores, - cfg.score_thr, cfg.nms, - cfg.max_per_img) - return det_bboxes, det_labels + + return self._bbox_post_process(det_scores, det_bboxes, + img_meta['scale_factor'], cfg, rescale, + with_nms) + + def _bbox_decode(self, priors, bbox_pred, base_len, max_shape): + bbox_pred = bbox_pred.exp() + + y = priors[:, 1] + x = priors[:, 0] + x1 = (x - base_len * bbox_pred[:, 0]). \ + clamp(min=0, max=max_shape[1] - 1) + y1 = (y - base_len * bbox_pred[:, 1]). \ + clamp(min=0, max=max_shape[0] - 1) + x2 = (x + base_len * bbox_pred[:, 2]). \ + clamp(min=0, max=max_shape[1] - 1) + y2 = (y + base_len * bbox_pred[:, 3]). \ + clamp(min=0, max=max_shape[0] - 1) + decoded_bboxes = torch.stack([x1, y1, x2, y2], -1) + return decoded_bboxes diff --git a/mmdet/models/dense_heads/free_anchor_retina_head.py b/mmdet/models/dense_heads/free_anchor_retina_head.py index 6859ee3f295..441dc13a1ab 100644 --- a/mmdet/models/dense_heads/free_anchor_retina_head.py +++ b/mmdet/models/dense_heads/free_anchor_retina_head.py @@ -77,7 +77,7 @@ def loss(self, dict[str, Tensor]: A dictionary of loss components. 
""" featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - assert len(featmap_sizes) == len(self.anchor_generator.base_anchors) + assert len(featmap_sizes) == len(self.prior_generator.base_anchors) anchor_list, _ = self.get_anchors(featmap_sizes, img_metas) anchors = [torch.cat(anchor) for anchor in anchor_list] diff --git a/mmdet/models/dense_heads/fsaf_head.py b/mmdet/models/dense_heads/fsaf_head.py index 25f58042829..2d2b7879694 100644 --- a/mmdet/models/dense_heads/fsaf_head.py +++ b/mmdet/models/dense_heads/fsaf_head.py @@ -215,7 +215,7 @@ def loss(self, bbox_preds[i] = bbox_preds[i].clamp(min=1e-4) # TODO: It may directly use the base-class loss function. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - assert len(featmap_sizes) == self.anchor_generator.num_levels + assert len(featmap_sizes) == self.prior_generator.num_levels batch_size = len(gt_bboxes) device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( diff --git a/mmdet/models/dense_heads/gfl_head.py b/mmdet/models/dense_heads/gfl_head.py index 869e4b62dc1..5bdfe43f72f 100644 --- a/mmdet/models/dense_heads/gfl_head.py +++ b/mmdet/models/dense_heads/gfl_head.py @@ -7,8 +7,7 @@ from mmdet.core import (anchor_inside_flags, bbox2distance, bbox_overlaps, build_assigner, build_sampler, distance2bbox, - images_to_levels, multi_apply, multiclass_nms, - reduce_mean, unmap) + images_to_levels, multi_apply, reduce_mean, unmap) from ..builder import HEADS, build_loss from .anchor_head import AnchorHead @@ -73,6 +72,8 @@ class GFLHead(AnchorHead): norm_cfg (dict): dictionary to construct and config norm layer. Default: dict(type='GN', num_groups=32, requires_grad=True). loss_qfl (dict): Config of Quality Focal Loss (QFL). + bbox_coder (dict): Config of bbox coder. Defaults + 'DistancePointBBoxCoder'. reg_max (int): Max value of integral set :math: `{0, ..., reg_max}` in QFL setting. Default: 16. init_cfg (dict or list[dict], optional): Initialization config dict. 
@@ -90,6 +91,7 @@ def __init__(self, conv_cfg=None, norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25), + bbox_coder=dict(type='DistancePointBBoxCoder'), reg_max=16, init_cfg=dict( type='Normal', @@ -106,7 +108,11 @@ def __init__(self, self.norm_cfg = norm_cfg self.reg_max = reg_max super(GFLHead, self).__init__( - num_classes, in_channels, init_cfg=init_cfg, **kwargs) + num_classes, + in_channels, + bbox_coder=bbox_coder, + init_cfg=init_cfg, + **kwargs) self.sampling = False if self.train_cfg: @@ -149,7 +155,7 @@ def _init_layers(self): self.gfl_reg = nn.Conv2d( self.feat_channels, 4 * (self.reg_max + 1), 3, padding=1) self.scales = nn.ModuleList( - [Scale(1.0) for _ in self.anchor_generator.strides]) + [Scale(1.0) for _ in self.prior_generator.strides]) def forward(self, feats): """Forward features from the upstream network. @@ -325,7 +331,7 @@ def loss(self, """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - assert len(featmap_sizes) == self.anchor_generator.num_levels + assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( @@ -360,7 +366,7 @@ def loss(self, labels_list, label_weights_list, bbox_targets_list, - self.anchor_generator.strides, + self.prior_generator.strides, num_total_samples=num_total_samples) avg_factor = sum(avg_factor) @@ -370,30 +376,28 @@ def loss(self, return dict( loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dfl=losses_dfl) - def _get_bboxes(self, - cls_scores, - bbox_preds, - mlvl_anchors, - img_shapes, - scale_factors, - cfg, - rescale=False, - with_nms=True): - """Transform outputs for a single batch item into labeled boxes. + def _get_bboxes_single(self, + cls_score_list, + bbox_pred_list, + score_factor_list, + img_meta, + cfg, + rescale=False, + with_nms=True, + **kwargs): + """Transform outputs of a single image into bbox predictions. 
Args: - cls_scores (list[Tensor]): Box scores for a single scale level - has shape (N, num_classes, H, W). - bbox_preds (list[Tensor]): Box distribution logits for a single - scale level with shape (N, 4*(n+1), H, W), n is max value of - integral set. - mlvl_anchors (list[Tensor]): Box reference for a single scale level - with shape (num_total_anchors, 4). - img_shapes (list[tuple[int]]): Shape of the input image, - list[(height, width, 3)]. - scale_factors (list[ndarray]): Scale factor of the image arange as - (w_scale, h_scale, w_scale, h_scale). - cfg (mmcv.Config | None): Test / postprocessing configuration, + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image. GFL head does not need this value. + img_meta (dict): Image meta info. + cfg (mmcv.Config): Test / postprocessing configuration, if None, test_cfg would be used. rescale (bool): If True, return boxes in original image space. Default: False. @@ -401,75 +405,61 @@ def _get_bboxes(self, Default: True. Returns: - list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. - The first item is an (n, 5) tensor, where 5 represent - (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. - The shape of the second tensor in the tuple is (n,), and - each element represents the class label of the corresponding - box. + tuple[Tensor]: Results of detected bboxes and labels. If with_nms + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. Usually with_nms is False is used for aug + test. 
If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape \ + [num_bbox, 5], where the first 4 columns are bounding box \ + positions (tl_x, tl_y, br_x, br_y) and the 5-th column \ + are scores between 0 and 1. + - det_labels (Tensor): Predicted labels of the corresponding \ + box with shape [num_bbox]. """ cfg = self.test_cfg if cfg is None else cfg - assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) - batch_size = cls_scores[0].shape[0] + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) mlvl_bboxes = [] mlvl_scores = [] - for cls_score, bbox_pred, stride, anchors in zip( - cls_scores, bbox_preds, self.anchor_generator.strides, - mlvl_anchors): + for level_idx, (cls_score, bbox_pred, stride) in enumerate( + zip(cls_score_list, bbox_pred_list, + self.prior_generator.strides)): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] assert stride[0] == stride[1] - scores = cls_score.permute(0, 2, 3, 1).reshape( - batch_size, -1, self.cls_out_channels).sigmoid() - bbox_pred = bbox_pred.permute(0, 2, 3, 1) + featmap_size_hw = cls_score.shape[-2:] + scores = cls_score.permute(1, 2, 0).reshape( + -1, self.cls_out_channels).sigmoid() + bbox_pred = bbox_pred.permute(1, 2, 0) bbox_pred = self.integral(bbox_pred) * stride[0] - bbox_pred = bbox_pred.reshape(batch_size, -1, 4) - nms_pre = cfg.get('nms_pre', -1) - if nms_pre > 0 and scores.shape[1] > nms_pre: - max_scores, _ = scores.max(-1) + if 0 < nms_pre < scores.shape[0]: + max_scores, _ = scores.max(dim=1) _, topk_inds = max_scores.topk(nms_pre) - batch_inds = torch.arange(batch_size).view( - -1, 1).expand_as(topk_inds).long() - anchors = anchors[topk_inds, :] - bbox_pred = bbox_pred[batch_inds, topk_inds, :] - scores = scores[batch_inds, topk_inds, :] + bbox_pred = bbox_pred[topk_inds, :] + scores = scores[topk_inds, :] + anchors = self.prior_generator.sparse_priors( + topk_inds, featmap_size_hw, level_idx, scores.dtype, + scores.device) 
else: - anchors = anchors.expand_as(bbox_pred) + anchors = self.prior_generator.single_level_grid_priors( + featmap_size_hw, level_idx, scores.device) - bboxes = distance2bbox( - self.anchor_center(anchors), bbox_pred, max_shape=img_shapes) + bboxes = self.bbox_coder.decode( + self.anchor_center(anchors), bbox_pred, max_shape=img_shape) mlvl_bboxes.append(bboxes) mlvl_scores.append(scores) - batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1) - if rescale: - batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor( - scale_factors).unsqueeze(1) - - batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) - # Add a dummy background class to the backend when using sigmoid - # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 - # BG cat_id: num_class - padding = batch_mlvl_scores.new_zeros(batch_size, - batch_mlvl_scores.shape[1], 1) - batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1) - - if with_nms: - det_results = [] - for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes, - batch_mlvl_scores): - det_bbox, det_label = multiclass_nms(mlvl_bboxes, mlvl_scores, - cfg.score_thr, cfg.nms, - cfg.max_per_img) - det_results.append(tuple([det_bbox, det_label])) - else: - det_results = [ - tuple(mlvl_bs) - for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores) - ] - return det_results + return self._bbox_post_process( + mlvl_scores, + mlvl_bboxes, + img_meta['scale_factor'], + cfg, + rescale=rescale, + with_nms=with_nms) def get_targets(self, anchor_list, diff --git a/mmdet/models/dense_heads/ld_head.py b/mmdet/models/dense_heads/ld_head.py index bd23355358e..6bd6e8972b5 100644 --- a/mmdet/models/dense_heads/ld_head.py +++ b/mmdet/models/dense_heads/ld_head.py @@ -212,7 +212,7 @@ def loss(self, """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - assert len(featmap_sizes) == self.anchor_generator.num_levels + assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device anchor_list, valid_flag_list = 
self.get_anchors( @@ -247,7 +247,7 @@ def loss(self, labels_list, label_weights_list, bbox_targets_list, - self.anchor_generator.strides, + self.prior_generator.strides, soft_target, num_total_samples=num_total_samples) diff --git a/mmdet/models/dense_heads/paa_head.py b/mmdet/models/dense_heads/paa_head.py index f007d07b307..89643540d3a 100644 --- a/mmdet/models/dense_heads/paa_head.py +++ b/mmdet/models/dense_heads/paa_head.py @@ -114,7 +114,7 @@ def loss(self, """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - assert len(featmap_sizes) == self.anchor_generator.num_levels + assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( @@ -506,7 +506,7 @@ def _get_targets_single(self, This method is same as `AnchorHead._get_targets_single()`. """ assert unmap_outputs, 'We must map outputs back to the original' \ - 'set of anchors in PAAhead' + 'set of anchors in PAAhead' return super(ATSSHead, self)._get_targets_single( flat_anchors, valid_flags, @@ -517,94 +517,78 @@ def _get_targets_single(self, label_channels=1, unmap_outputs=True) - def _get_bboxes(self, - cls_scores, - bbox_preds, - iou_preds, - mlvl_anchors, - img_shapes, - scale_factors, - cfg, - rescale=False, - with_nms=True): - """Transform outputs for a single batch item into labeled boxes. - - This method is almost same as `ATSSHead._get_bboxes()`. - We use sqrt(iou_preds * cls_scores) in NMS process instead of just - cls_scores. Besides, score voting is used when `` score_voting`` - is set to True. + def _bbox_post_process(self, + mlvl_scores, + mlvl_bboxes, + scale_factor, + cfg, + rescale=False, + with_nms=True, + mlvl_score_factors=None, + **kwargs): + """bbox post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. Usually with_nms is False is used for aug test. 
+ + Args: + mlvl_scores (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num, num_class). + mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale + levels of a single image, each item has shape (num, 4). + scale_factor (ndarray, optional): Scale factor of the image arange + as (w_scale, h_scale, w_scale, h_scale). + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + mlvl_score_factors (list[Tensor], optional): Score factor from + all scale levels of a single image, each item has shape + (num, ). Default: None. + + Returns: + tuple[Tensor]: Results of detected bboxes and labels. If with_nms + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. Usually with_nms is False is used for aug + test. If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape \ + [num_bbox, 5], where the first 4 columns are bounding box \ + positions (tl_x, tl_y, br_x, br_y) and the 5-th column \ + are scores between 0 and 1. + - det_labels (Tensor): Predicted labels of the corresponding \ + box with shape [num_bbox]. 
""" - assert with_nms, 'PAA only supports "with_nms=True" now and it ' \ - 'means PAAHead does not support ' \ - 'test-time augmentation' - assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) - batch_size = cls_scores[0].shape[0] - - mlvl_bboxes = [] - mlvl_scores = [] - mlvl_iou_preds = [] - for cls_score, bbox_pred, iou_preds, anchors in zip( - cls_scores, bbox_preds, iou_preds, mlvl_anchors): - assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - - scores = cls_score.permute(0, 2, 3, 1).reshape( - batch_size, -1, self.cls_out_channels).sigmoid() - bbox_pred = bbox_pred.permute(0, 2, 3, - 1).reshape(batch_size, -1, 4) - iou_preds = iou_preds.permute(0, 2, 3, 1).reshape(batch_size, - -1).sigmoid() - - nms_pre = cfg.get('nms_pre', -1) - if nms_pre > 0 and scores.shape[1] > nms_pre: - max_scores, _ = (scores * iou_preds[..., None]).sqrt().max(-1) - _, topk_inds = max_scores.topk(nms_pre) - batch_inds = torch.arange(batch_size).view( - -1, 1).expand_as(topk_inds).long() - anchors = anchors[topk_inds, :] - bbox_pred = bbox_pred[batch_inds, topk_inds, :] - scores = scores[batch_inds, topk_inds, :] - iou_preds = iou_preds[batch_inds, topk_inds] - else: - anchors = anchors.expand_as(bbox_pred) - - bboxes = self.bbox_coder.decode( - anchors, bbox_pred, max_shape=img_shapes) - mlvl_bboxes.append(bboxes) - mlvl_scores.append(scores) - mlvl_iou_preds.append(iou_preds) - - batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1) + mlvl_bboxes = torch.cat(mlvl_bboxes) if rescale: - batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor( - scale_factors).unsqueeze(1) - batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) + mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) + mlvl_scores = torch.cat(mlvl_scores) # Add a dummy background class to the backend when using sigmoid # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 # BG cat_id: num_class - padding = batch_mlvl_scores.new_zeros(batch_size, - batch_mlvl_scores.shape[1], 1) - batch_mlvl_scores = 
torch.cat([batch_mlvl_scores, padding], dim=-1) - batch_mlvl_iou_preds = torch.cat(mlvl_iou_preds, dim=1) - batch_mlvl_nms_scores = (batch_mlvl_scores * - batch_mlvl_iou_preds[..., None]).sqrt() - - det_results = [] - for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes, - batch_mlvl_nms_scores): - det_bbox, det_label = multiclass_nms( - mlvl_bboxes, - mlvl_scores, - cfg.score_thr, - cfg.nms, - cfg.max_per_img, - score_factors=None) - if self.with_score_voting and len(det_bbox) > 0: - det_bbox, det_label = self.score_voting( - det_bbox, det_label, mlvl_bboxes, mlvl_scores, - cfg.score_thr) - det_results.append(tuple([det_bbox, det_label])) - - return det_results + padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + + mlvl_iou_preds = torch.cat(mlvl_score_factors) + mlvl_nms_scores = (mlvl_scores * mlvl_iou_preds[:, None]).sqrt() + det_bboxes, det_labels = multiclass_nms( + mlvl_bboxes, + mlvl_nms_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + score_factors=None) + if self.with_score_voting and len(det_bboxes) > 0: + det_bboxes, det_labels = self.score_voting(det_bboxes, det_labels, + mlvl_bboxes, + mlvl_nms_scores, + cfg.score_thr) + + return det_bboxes, det_labels def score_voting(self, det_bboxes, det_labels, mlvl_bboxes, mlvl_nms_scores, score_thr): @@ -621,8 +605,6 @@ def score_voting(self, det_bboxes, det_labels, mlvl_bboxes, with shape (num_anchors,4). mlvl_nms_scores (Tensor): The scores of all boxes which is used in the NMS procedure, with shape (num_anchors, num_class) - mlvl_iou_preds (Tensor): The predictions of IOU of all boxes - before the NMS procedure, with shape (num_anchors, 1) score_thr (float): The score threshold of bboxes. Returns: @@ -635,7 +617,7 @@ def score_voting(self, det_bboxes, det_labels, mlvl_bboxes, after voting, with shape (num_anchors,). 
""" candidate_mask = mlvl_nms_scores > score_thr - candidate_mask_nonzeros = candidate_mask.nonzero() + candidate_mask_nonzeros = candidate_mask.nonzero(as_tuple=False) candidate_inds = candidate_mask_nonzeros[:, 0] candidate_labels = candidate_mask_nonzeros[:, 1] candidate_bboxes = mlvl_bboxes[candidate_inds] diff --git a/mmdet/models/dense_heads/reppoints_head.py b/mmdet/models/dense_heads/reppoints_head.py index 81658616755..c533d23f076 100644 --- a/mmdet/models/dense_heads/reppoints_head.py +++ b/mmdet/models/dense_heads/reppoints_head.py @@ -6,7 +6,7 @@ from mmcv.ops import DeformConv2d from mmdet.core import (build_assigner, build_sampler, images_to_levels, - multi_apply, multiclass_nms, unmap) + multi_apply, unmap) from mmdet.core.anchor.point_generator import MlvlPointGenerator from ..builder import HEADS, build_loss from .anchor_free_head import AnchorFreeHead @@ -94,7 +94,7 @@ def __init__(self, self.gradient_mul = gradient_mul self.point_base_scale = point_base_scale self.point_strides = point_strides - self.point_generator = MlvlPointGenerator( + self.prior_generator = MlvlPointGenerator( self.point_strides, offset=0.) self.sampling = loss_cls['type'] not in ['FocalLoss'] @@ -294,7 +294,11 @@ def forward_single(self, x): pts_out_refine, bbox_out_init.detach()) else: pts_out_refine = pts_out_refine + pts_out_init.detach() - return cls_out, pts_out_init, pts_out_refine + + if self.training: + return cls_out, pts_out_init, pts_out_refine + else: + return cls_out, self.points2bbox(pts_out_refine) def get_points(self, featmap_sizes, img_metas, device): """Get points according to feature map sizes. 
@@ -310,7 +314,7 @@ def get_points(self, featmap_sizes, img_metas, device): # since feature map sizes of all images are the same, we only compute # points center for one time - multi_level_points = self.point_generator.grid_priors( + multi_level_points = self.prior_generator.grid_priors( featmap_sizes, device, with_stride=True) points_list = [[point.clone() for point in multi_level_points] for _ in range(num_imgs)] @@ -318,7 +322,7 @@ def get_points(self, featmap_sizes, img_metas, device): # for each image, we compute valid flags of multi level grids valid_flag_list = [] for img_id, img_meta in enumerate(img_metas): - multi_level_flags = self.point_generator.valid_flags( + multi_level_flags = self.prior_generator.valid_flags( featmap_sizes, img_meta['pad_shape']) valid_flag_list.append(multi_level_flags) @@ -650,58 +654,61 @@ def loss(self, } return loss_dict_all - def get_bboxes(self, - cls_scores, - pts_preds_init, - pts_preds_refine, - img_metas, - cfg=None, - rescale=False, - with_nms=True): - assert len(cls_scores) == len(pts_preds_refine) - device = cls_scores[0].device - bbox_preds_refine = [ - self.points2bbox(pts_pred_refine) - for pts_pred_refine in pts_preds_refine - ] - num_levels = len(cls_scores) - featmap_sizes = [ - cls_scores[i].size()[-2:] for i in range(len(cls_scores)) - ] - multi_level_points = self.point_generator.grid_priors( - featmap_sizes, device) - - result_list = [] - for img_id in range(len(img_metas)): - cls_score_list = [cls_scores[i][img_id] for i in range(num_levels)] - bbox_pred_list = [ - bbox_preds_refine[i][img_id] for i in range(num_levels) - ] - img_shape = img_metas[img_id]['img_shape'] - scale_factor = img_metas[img_id]['scale_factor'] - proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list, - multi_level_points, img_shape, - scale_factor, cfg, rescale, - with_nms) - result_list.append(proposals) - return result_list - + # Same as base_dense_head/_get_bboxes_single except self._bbox_decode def 
_get_bboxes_single(self, - cls_scores, - bbox_preds, - mlvl_points, - img_shape, - scale_factor, + cls_score_list, + bbox_pred_list, + score_factor_list, + img_meta, cfg, rescale=False, - with_nms=True): + with_nms=True, + **kwargs): + """Transform outputs of a single image into bbox predictions. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image. RepPoints head does not need + this value. + img_meta (dict): Image meta info. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + tuple[Tensor]: Results of detected bboxes and labels. If with_nms + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. Usually with_nms is False is used for aug + test. If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape \ + [num_bbox, 5], where the first 4 columns are bounding box \ + positions (tl_x, tl_y, br_x, br_y) and the 5-th column \ + are scores between 0 and 1. + - det_labels (Tensor): Predicted labels of the corresponding \ + box with shape [num_bbox]. 
+ """ cfg = self.test_cfg if cfg is None else cfg - assert len(cls_scores) == len(bbox_preds) == len(mlvl_points) + assert len(cls_score_list) == len(bbox_pred_list) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + mlvl_bboxes = [] mlvl_scores = [] - for i_lvl, (cls_score, bbox_pred, points) in enumerate( - zip(cls_scores, bbox_preds, mlvl_points)): + for level_idx, (cls_score, bbox_pred) in enumerate( + zip(cls_score_list, bbox_pred_list)): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + featmap_size_hw = cls_score.shape[-2:] cls_score = cls_score.permute(1, 2, 0).reshape(-1, self.cls_out_channels) if self.use_sigmoid_cls: @@ -709,8 +716,8 @@ def _get_bboxes_single(self, else: scores = cls_score.softmax(-1) bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) - nms_pre = cfg.get('nms_pre', -1) - if nms_pre > 0 and scores.shape[0] > nms_pre: + + if 0 < nms_pre < scores.shape[0]: if self.use_sigmoid_cls: max_scores, _ = scores.max(dim=1) else: @@ -719,32 +726,37 @@ def _get_bboxes_single(self, # BG cat_id: num_class max_scores, _ = scores[:, :-1].max(dim=1) _, topk_inds = max_scores.topk(nms_pre) - points = points[topk_inds, :] + + points = self.prior_generator.sparse_priors( + topk_inds, featmap_size_hw, level_idx, scores.dtype, + scores.device) bbox_pred = bbox_pred[topk_inds, :] scores = scores[topk_inds, :] - bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1) - bboxes = bbox_pred * self.point_strides[i_lvl] + bbox_pos_center - x1 = bboxes[:, 0].clamp(min=0, max=img_shape[1]) - y1 = bboxes[:, 1].clamp(min=0, max=img_shape[0]) - x2 = bboxes[:, 2].clamp(min=0, max=img_shape[1]) - y2 = bboxes[:, 3].clamp(min=0, max=img_shape[0]) - bboxes = torch.stack([x1, y1, x2, y2], dim=-1) + else: + points = self.prior_generator.single_level_grid_priors( + featmap_size_hw, level_idx, scores.device) + + bboxes = self._bbox_decode(points, bbox_pred, + self.point_strides[level_idx], + img_shape) + mlvl_bboxes.append(bboxes) 
mlvl_scores.append(scores) - mlvl_bboxes = torch.cat(mlvl_bboxes) - if rescale: - mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) - mlvl_scores = torch.cat(mlvl_scores) - if self.use_sigmoid_cls: - # Add a dummy background class to the backend when using sigmoid - # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 - # BG cat_id: num_class - padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) - mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) - if with_nms: - det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, - cfg.score_thr, cfg.nms, - cfg.max_per_img) - return det_bboxes, det_labels - else: - return mlvl_bboxes, mlvl_scores + + return self._bbox_post_process( + mlvl_scores, + mlvl_bboxes, + img_meta['scale_factor'], + cfg, + rescale=rescale, + with_nms=with_nms) + + def _bbox_decode(self, points, bbox_pred, stride, max_shape): + bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1) + bboxes = bbox_pred * stride + bbox_pos_center + x1 = bboxes[:, 0].clamp(min=0, max=max_shape[1]) + y1 = bboxes[:, 1].clamp(min=0, max=max_shape[0]) + x2 = bboxes[:, 2].clamp(min=0, max=max_shape[1]) + y2 = bboxes[:, 3].clamp(min=0, max=max_shape[0]) + decoded_bboxes = torch.stack([x1, y1, x2, y2], dim=-1) + return decoded_bboxes diff --git a/mmdet/models/dense_heads/rpn_head.py b/mmdet/models/dense_heads/rpn_head.py index 3eef10f92d3..1024d782130 100644 --- a/mmdet/models/dense_heads/rpn_head.py +++ b/mmdet/models/dense_heads/rpn_head.py @@ -6,7 +6,6 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule from mmcv.ops import batched_nms -from mmcv.runner import force_fp32 from ..builder import HEADS from .anchor_head import AnchorHead @@ -99,87 +98,33 @@ def loss(self, return dict( loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox']) - @force_fp32(apply_to=('cls_scores', 'bbox_preds')) - def get_bboxes(self, - cls_scores, - bbox_preds, - img_metas, - cfg=None, - rescale=False, - with_nms=True): - 
"""Transform network output for a batch into bbox predictions. - - Args: - cls_scores (list[Tensor]): Box scores for each scale level - Has shape (N, num_anchors * num_classes, H, W) - bbox_preds (list[Tensor]): Box energies / deltas for each scale - level with shape (N, num_anchors * 4, H, W) - img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. - cfg (mmcv.Config | None): Test / postprocessing configuration, - if None, test_cfg would be used - rescale (bool): If True, return boxes in original image space. - Default: False. - with_nms (bool): If True, do nms before return boxes. - Default: True. - - Returns: - list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. - The first item is an (n, 5) tensor, where the first 4 columns - are bounding box positions (tl_x, tl_y, br_x, br_y) and the - 5-th column is a score between 0 and 1. The second item is a - (n,) tensor where each item is the predicted class label of the - corresponding box. 
- """ - assert with_nms, '``with_nms`` in RPNHead should always True' - assert len(cls_scores) == len(bbox_preds) - num_levels = len(cls_scores) - device = cls_scores[0].device - featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] - mlvl_anchors = self.anchor_generator.grid_anchors( - featmap_sizes, device=device) - - result_list = [] - for img_id in range(len(img_metas)): - cls_score_list = [ - cls_scores[i][img_id].detach() for i in range(num_levels) - ] - bbox_pred_list = [ - bbox_preds[i][img_id].detach() for i in range(num_levels) - ] - img_shape = img_metas[img_id]['img_shape'] - scale_factor = img_metas[img_id]['scale_factor'] - proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list, - mlvl_anchors, img_shape, - scale_factor, cfg, rescale) - result_list.append(proposals) - return result_list - def _get_bboxes_single(self, - cls_scores, - bbox_preds, - mlvl_anchors, - img_shape, - scale_factor, + cls_score_list, + bbox_pred_list, + score_factor_list, + img_meta, cfg, - rescale=False): - """Transform outputs for a single batch item into bbox predictions. - - Args: - cls_scores (list[Tensor]): Box scores of all scale level - each item has shape (num_anchors * num_classes, H, W). - bbox_preds (list[Tensor]): Box energies / deltas of all - scale level, each item has shape (num_anchors * 4, H, W). - mlvl_anchors (list[Tensor]): Anchors of all scale level - each item has shape (num_total_anchors, 4). - img_shape (tuple[int]): Shape of the input image, - (height, width, 3). - scale_factor (ndarray): Scale factor of the image arrange as - (w_scale, h_scale, w_scale, h_scale). + rescale=False, + with_nms=True, + **kwargs): + """Transform outputs of a single image into bbox predictions. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_anchors * num_classes, H, W). 
+ bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has + shape (num_anchors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image. RPN head does not need this value. + img_meta (dict): Image meta info. cfg (mmcv.Config): Test / postprocessing configuration, if None, test_cfg would be used. rescale (bool): If True, return boxes in original image space. Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. Returns: Tensor: Labeled boxes in shape (n, 5), where the first 4 columns @@ -188,15 +133,19 @@ def _get_bboxes_single(self, """ cfg = self.test_cfg if cfg is None else cfg cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + # bboxes from different level should be independent during NMS, # level_ids are used as labels for batched NMS to separate them level_ids = [] mlvl_scores = [] mlvl_bbox_preds = [] mlvl_valid_anchors = [] - for idx in range(len(cls_scores)): - rpn_cls_score = cls_scores[idx] - rpn_bbox_pred = bbox_preds[idx] + nms_pre = cfg.get('nms_pre', -1) + for level_idx in range(len(cls_score_list)): + rpn_cls_score = cls_score_list[level_idx] + featmap_size_hw = rpn_cls_score.shape[-2:] + rpn_bbox_pred = bbox_pred_list[level_idx] assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:] rpn_cls_score = rpn_cls_score.permute(1, 2, 0) if self.use_sigmoid_cls: @@ -210,24 +159,62 @@ def _get_bboxes_single(self, # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. 
scores = rpn_cls_score.softmax(dim=1)[:, 0] rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4) - anchors = mlvl_anchors[idx] - if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre: + + if 0 < nms_pre < scores.shape[0]: # sort is faster than topk # _, topk_inds = scores.topk(cfg.nms_pre) ranked_scores, rank_inds = scores.sort(descending=True) - topk_inds = rank_inds[:cfg.nms_pre] - scores = ranked_scores[:cfg.nms_pre] + topk_inds = rank_inds[:nms_pre] + scores = ranked_scores[:nms_pre] rpn_bbox_pred = rpn_bbox_pred[topk_inds, :] - anchors = anchors[topk_inds, :] + anchors = self.prior_generator.sparse_priors( + topk_inds, featmap_size_hw, level_idx, scores.dtype, + scores.device) + else: + anchors = self.prior_generator.single_level_grid_priors( + featmap_size_hw, level_idx, scores.device) mlvl_scores.append(scores) mlvl_bbox_preds.append(rpn_bbox_pred) mlvl_valid_anchors.append(anchors) level_ids.append( - scores.new_full((scores.size(0), ), idx, dtype=torch.long)) + scores.new_full((scores.size(0), ), + level_idx, + dtype=torch.long)) + + return self._bbox_post_process(mlvl_scores, mlvl_bbox_preds, + mlvl_valid_anchors, level_ids, cfg, + img_shape) + + def _bbox_post_process(self, mlvl_scores, mlvl_bboxes, mlvl_valid_anchors, + level_ids, cfg, img_shape, **kwargs): + """bbox post-processing method. + The boxes would be rescaled to the original image scale and do + the nms operation. Usually with_nms is False is used for aug test. + + Args: + mlvl_scores (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num, num_class). + mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale + levels of a single image, each item has shape (num, 4). + mlvl_valid_anchors (list[Tensor]): Box reference from all scale + levels of a single image, each item has shape + (num, 4). + level_ids (list[Tensor]): Indexes from all scale levels of a + single image, each item has shape (num, ). 
+ cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + img_shape (tuple(int)): Shape of current image. + + Returns: + Tensor: Labeled boxes in shape (n, 5), where the first 4 columns + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. + """ scores = torch.cat(mlvl_scores) anchors = torch.cat(mlvl_valid_anchors) - rpn_bbox_pred = torch.cat(mlvl_bbox_preds) + rpn_bbox_pred = torch.cat(mlvl_bboxes) proposals = self.bbox_coder.decode( anchors, rpn_bbox_pred, max_shape=img_shape) ids = torch.cat(level_ids) @@ -240,6 +227,7 @@ def _get_bboxes_single(self, proposals = proposals[valid_mask] scores = scores[valid_mask] ids = ids[valid_mask] + if proposals.numel() > 0: dets, keep = batched_nms(proposals, scores, ids, cfg.nms) else: @@ -247,6 +235,7 @@ def _get_bboxes_single(self, return dets[:cfg.max_per_img] + # TODO: waiting for refactor the anchor_head and anchor_free head def onnx_export(self, x, img_metas): """Test without augmentation. 
@@ -266,7 +255,7 @@ def onnx_export(self, x, img_metas): device = cls_scores[0].device featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] - mlvl_anchors = self.anchor_generator.grid_anchors( + mlvl_anchors = self.prior_generator.grid_anchors( featmap_sizes, device=device) cls_scores = [cls_scores[i].detach() for i in range(num_levels)] diff --git a/mmdet/models/dense_heads/sabl_retina_head.py b/mmdet/models/dense_heads/sabl_retina_head.py index 16e04d007bb..4822008f490 100644 --- a/mmdet/models/dense_heads/sabl_retina_head.py +++ b/mmdet/models/dense_heads/sabl_retina_head.py @@ -7,7 +7,7 @@ from mmdet.core import (build_anchor_generator, build_assigner, build_bbox_coder, build_sampler, images_to_levels, - multi_apply, multiclass_nms, unmap) + multi_apply, unmap) from ..builder import HEADS, build_loss from .base_dense_head import BaseDenseHead from .dense_test_mixins import BBoxTestMixin @@ -566,9 +566,11 @@ def get_bboxes_single(self, cfg, rescale=False): cfg = self.test_cfg if cfg is None else cfg + nms_pre = cfg.get('nms_pre', -1) + mlvl_bboxes = [] mlvl_scores = [] - mlvl_confidences = [] + mlvl_confids = [] assert len(cls_scores) == len(bbox_cls_preds) == len( bbox_reg_preds) == len(mlvl_anchors) for cls_score, bbox_cls_pred, bbox_reg_pred, anchors in zip( @@ -585,8 +587,8 @@ def get_bboxes_single(self, -1, self.side_num * 4) bbox_reg_pred = bbox_reg_pred.permute(1, 2, 0).reshape( -1, self.side_num * 4) - nms_pre = cfg.get('nms_pre', -1) - if nms_pre > 0 and scores.shape[0] > nms_pre: + + if 0 < nms_pre < scores.shape[0]: if self.use_sigmoid_cls: max_scores, _ = scores.max(dim=1) else: @@ -600,24 +602,10 @@ def get_bboxes_single(self, bbox_cls_pred.contiguous(), bbox_reg_pred.contiguous() ] - bboxes, confidences = self.bbox_coder.decode( + bboxes, confids = self.bbox_coder.decode( anchors.contiguous(), bbox_preds, max_shape=img_shape) mlvl_bboxes.append(bboxes) mlvl_scores.append(scores) - mlvl_confidences.append(confidences) - mlvl_bboxes 
= torch.cat(mlvl_bboxes) - if rescale: - mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) - mlvl_scores = torch.cat(mlvl_scores) - mlvl_confidences = torch.cat(mlvl_confidences) - if self.use_sigmoid_cls: - padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) - mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) - det_bboxes, det_labels = multiclass_nms( - mlvl_bboxes, - mlvl_scores, - cfg.score_thr, - cfg.nms, - cfg.max_per_img, - score_factors=mlvl_confidences) - return det_bboxes, det_labels + mlvl_confids.append(confids) + return self._bbox_post_process(mlvl_scores, mlvl_bboxes, scale_factor, + cfg, rescale, True, mlvl_confids) diff --git a/mmdet/models/dense_heads/ssd_head.py b/mmdet/models/dense_heads/ssd_head.py index 5f56abcc7f2..a588690ff54 100644 --- a/mmdet/models/dense_heads/ssd_head.py +++ b/mmdet/models/dense_heads/ssd_head.py @@ -86,8 +86,8 @@ def __init__(self, self.act_cfg = act_cfg self.cls_out_channels = num_classes + 1 # add background class - self.anchor_generator = build_anchor_generator(anchor_generator) - self.num_anchors = self.anchor_generator.num_base_anchors + self.prior_generator = build_anchor_generator(anchor_generator) + self.num_anchors = self.prior_generator.num_base_anchors self._init_layers() @@ -285,7 +285,7 @@ def loss(self, dict[str, Tensor]: A dictionary of loss components. 
""" featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - assert len(featmap_sizes) == self.anchor_generator.num_levels + assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device diff --git a/mmdet/models/dense_heads/vfnet_head.py b/mmdet/models/dense_heads/vfnet_head.py index 7b7515c6a57..97b09de68cc 100644 --- a/mmdet/models/dense_heads/vfnet_head.py +++ b/mmdet/models/dense_heads/vfnet_head.py @@ -6,9 +6,9 @@ from mmcv.ops import DeformConv2d from mmcv.runner import force_fp32 -from mmdet.core import (bbox2distance, bbox_overlaps, build_anchor_generator, - build_assigner, build_sampler, distance2bbox, - multi_apply, multiclass_nms, reduce_mean) +from mmdet.core import (MlvlPointGenerator, bbox2distance, bbox_overlaps, + build_anchor_generator, build_assigner, build_sampler, + distance2bbox, multi_apply, reduce_mean) from ..builder import HEADS, build_loss from .atss_head import ATSSHead from .fcos_head import FCOSHead @@ -91,6 +91,7 @@ def __init__(self, loss_bbox_refine=dict(type='GIoULoss', loss_weight=2.0), norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), use_atss=True, + reg_decoded_bbox=True, anchor_generator=dict( type='AnchorGenerator', ratios=[1.0], @@ -146,16 +147,22 @@ def __init__(self, # for getting ATSS targets self.use_atss = use_atss + self.reg_decoded_bbox = reg_decoded_bbox self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) - self.anchor_generator = build_anchor_generator(anchor_generator) + self.prior_generator = build_anchor_generator(anchor_generator) self.anchor_center_offset = anchor_generator['center_offset'] - self.num_anchors = self.anchor_generator.num_base_anchors[0] + self.num_anchors = self.prior_generator.num_base_anchors[0] self.sampling = False if self.train_cfg: self.assigner = build_assigner(self.train_cfg.assigner) sampler_cfg = dict(type='PseudoSampler') self.sampler = build_sampler(sampler_cfg, context=self) + # in order to unify the get_bbox logic. 
Not needed during training. + self.test_prior_generator = MlvlPointGenerator( + anchor_generator['strides'], + self.anchor_center_offset if self.use_atss else 0.5) + def _init_layers(self): """Initialize layers of the head.""" super(FCOSHead, self)._init_cls_convs() @@ -269,7 +276,12 @@ def forward_single(self, x, scale, scale_refine, stride, reg_denom): cls_feat = self.relu(self.vfnet_cls_dconv(cls_feat, dcn_offset)) cls_score = self.vfnet_cls(cls_feat) - return cls_score, bbox_pred, bbox_pred_refine + if self.training: + return cls_score, bbox_pred, bbox_pred_refine + else: + # TODO: Find a better way + self.prior_generator = self.test_prior_generator + return cls_score, bbox_pred_refine def star_dcn_offset(self, bbox_pred, gradient_mul, stride): """Compute the star deformable conv offsets. @@ -460,140 +472,6 @@ def loss(self, loss_bbox=loss_bbox, loss_bbox_rf=loss_bbox_refine) - @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'bbox_preds_refine')) - def get_bboxes(self, - cls_scores, - bbox_preds, - bbox_preds_refine, - img_metas, - cfg=None, - rescale=None, - with_nms=True): - """Transform network outputs for a batch into bbox predictions. - - Args: - cls_scores (list[Tensor]): Box iou-aware scores for each scale - level with shape (N, num_points * num_classes, H, W). - bbox_preds (list[Tensor]): Box offsets for each scale - level with shape (N, num_points * 4, H, W). - bbox_preds_refine (list[Tensor]): Refined Box offsets for - each scale level with shape (N, num_points * 4, H, W). - img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. - cfg (mmcv.Config): Test / postprocessing configuration, - if None, test_cfg would be used. Default: None. - rescale (bool): If True, return boxes in original image space. - Default: False. - with_nms (bool): If True, do nms before returning boxes. - Default: True. - - Returns: - list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
- The first item is an (n, 5) tensor, where the first 4 columns - are bounding box positions (tl_x, tl_y, br_x, br_y) and the - 5-th column is a score between 0 and 1. The second item is a - (n,) tensor where each item is the predicted class label of - the corresponding box. - """ - assert len(cls_scores) == len(bbox_preds) == len(bbox_preds_refine) - num_levels = len(cls_scores) - - featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype, - bbox_preds[0].device) - result_list = [] - for img_id in range(len(img_metas)): - cls_score_list = [ - cls_scores[i][img_id].detach() for i in range(num_levels) - ] - bbox_pred_list = [ - bbox_preds_refine[i][img_id].detach() - for i in range(num_levels) - ] - img_shape = img_metas[img_id]['img_shape'] - scale_factor = img_metas[img_id]['scale_factor'] - det_bboxes = self._get_bboxes_single(cls_score_list, - bbox_pred_list, mlvl_points, - img_shape, scale_factor, cfg, - rescale, with_nms) - result_list.append(det_bboxes) - return result_list - - def _get_bboxes_single(self, - cls_scores, - bbox_preds, - mlvl_points, - img_shape, - scale_factor, - cfg, - rescale=False, - with_nms=True): - """Transform outputs for a single batch item into bbox predictions. - - Args: - cls_scores (list[Tensor]): Box iou-aware scores for a single scale - level with shape (num_points * num_classes, H, W). - bbox_preds (list[Tensor]): Box offsets for a single scale - level with shape (num_points * 4, H, W). - mlvl_points (list[Tensor]): Box reference for a single scale level - with shape (num_total_points, 4). - img_shape (tuple[int]): Shape of the input image, - (height, width, 3). - scale_factor (ndarray): Scale factor of the image arrange as - (w_scale, h_scale, w_scale, h_scale). - cfg (mmcv.Config | None): Test / postprocessing configuration, - if None, test_cfg would be used. - rescale (bool): If True, return boxes in original image space. - Default: False. 
- with_nms (bool): If True, do nms before returning boxes. - Default: True. - - Returns: - tuple(Tensor): - det_bboxes (Tensor): BBox predictions in shape (n, 5), where - the first 4 columns are bounding box positions - (tl_x, tl_y, br_x, br_y) and the 5-th column is a score - between 0 and 1. - det_labels (Tensor): A (n,) tensor where each item is the - predicted class label of the corresponding box. - """ - cfg = self.test_cfg if cfg is None else cfg - assert len(cls_scores) == len(bbox_preds) == len(mlvl_points) - mlvl_bboxes = [] - mlvl_scores = [] - for cls_score, bbox_pred, points in zip(cls_scores, bbox_preds, - mlvl_points): - assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - scores = cls_score.permute(1, 2, 0).reshape( - -1, self.cls_out_channels).contiguous().sigmoid() - bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4).contiguous() - - nms_pre = cfg.get('nms_pre', -1) - if 0 < nms_pre < scores.shape[0]: - max_scores, _ = scores.max(dim=1) - _, topk_inds = max_scores.topk(nms_pre) - points = points[topk_inds, :] - bbox_pred = bbox_pred[topk_inds, :] - scores = scores[topk_inds, :] - bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape) - mlvl_bboxes.append(bboxes) - mlvl_scores.append(scores) - mlvl_bboxes = torch.cat(mlvl_bboxes) - if rescale: - mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) - mlvl_scores = torch.cat(mlvl_scores) - padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) - # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 - # BG cat_id: num_class - mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) - if with_nms: - det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, - cfg.score_thr, cfg.nms, - cfg.max_per_img) - return det_bboxes, det_labels - else: - return mlvl_bboxes, mlvl_scores - def _get_points_single(self, featmap_size, stride, @@ -717,7 +595,7 @@ def get_atss_targets(self, bbox_weights (Tensor): Bbox weights of all levels. 
""" featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - assert len(featmap_sizes) == self.anchor_generator.num_levels + assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( diff --git a/mmdet/models/dense_heads/yolact_head.py b/mmdet/models/dense_heads/yolact_head.py index c5e2bad62b7..bc75e985c6c 100644 --- a/mmdet/models/dense_heads/yolact_head.py +++ b/mmdet/models/dense_heads/yolact_head.py @@ -169,7 +169,7 @@ def loss(self, List[:obj:``SamplingResult``]: Sampler results for each image. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] - assert len(featmap_sizes) == self.anchor_generator.num_levels + assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device @@ -333,7 +333,7 @@ def get_bboxes(self, device = cls_scores[0].device featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] - mlvl_anchors = self.anchor_generator.grid_anchors( + mlvl_anchors = self.prior_generator.grid_anchors( featmap_sizes, device=device) det_bboxes = [] diff --git a/mmdet/models/dense_heads/yolof_head.py b/mmdet/models/dense_heads/yolof_head.py index 8c9a4861ebb..b03ba11af76 100644 --- a/mmdet/models/dense_heads/yolof_head.py +++ b/mmdet/models/dense_heads/yolof_head.py @@ -160,7 +160,7 @@ def loss(self, dict[str, Tensor]: A dictionary of loss components. 
""" assert len(cls_scores) == 1 - assert self.anchor_generator.num_levels == 1 + assert self.prior_generator.num_levels == 1 device = cls_scores[0].device featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] diff --git a/tests/test_models/test_dense_heads/test_autoassign_head.py b/tests/test_models/test_dense_heads/test_autoassign_head.py index 72cdddf00b8..b059e30d045 100644 --- a/tests/test_models/test_dense_heads/test_autoassign_head.py +++ b/tests/test_models/test_dense_heads/test_autoassign_head.py @@ -70,9 +70,6 @@ def test_autoassign_head_loss(): cls_scores = [torch.ones(2, 4, 5, 5)] bbox_preds = [torch.ones(2, 4, 5, 5)] iou_preds = [torch.ones(2, 1, 5, 5)] - mlvl_anchors = [torch.ones(5 * 5, 2)] - img_shape = None - scale_factor = [0.5, 0.5] cfg = mmcv.Config( dict( nms_pre=1000, @@ -81,12 +78,5 @@ def test_autoassign_head_loss(): nms=dict(type='nms', iou_threshold=0.6), max_per_img=100)) rescale = False - self._get_bboxes( - cls_scores, - bbox_preds, - iou_preds, - mlvl_anchors, - img_shape, - scale_factor, - cfg, - rescale=rescale) + self.get_bboxes( + cls_scores, bbox_preds, iou_preds, img_metas, cfg, rescale=rescale) diff --git a/tests/test_models/test_dense_heads/test_paa_head.py b/tests/test_models/test_dense_heads/test_paa_head.py index 0fceb9aad89..bc1ba73abc9 100644 --- a/tests/test_models/test_dense_heads/test_paa_head.py +++ b/tests/test_models/test_dense_heads/test_paa_head.py @@ -50,6 +50,12 @@ def score_samples(self, loss): num_classes=4, in_channels=1, train_cfg=train_cfg, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), loss_bbox=dict(type='GIoULoss', loss_weight=1.3), @@ -101,9 +107,6 @@ def score_samples(self, loss): cls_scores = [torch.ones(2, 4, 5, 5)] bbox_preds = [torch.ones(2, 4, 5, 5)] iou_preds = [torch.ones(2, 1, 5, 5)] - mlvl_anchors = [torch.ones(2, 
5 * 5, 4)] - img_shape = None - scale_factor = [0.5, 0.5] cfg = mmcv.Config( dict( nms_pre=1000, @@ -112,12 +115,5 @@ def score_samples(self, loss): nms=dict(type='nms', iou_threshold=0.6), max_per_img=100)) rescale = False - self._get_bboxes( - cls_scores, - bbox_preds, - iou_preds, - mlvl_anchors, - img_shape, - scale_factor, - cfg, - rescale=rescale) + self.get_bboxes( + cls_scores, bbox_preds, iou_preds, img_metas, cfg, rescale=rescale) diff --git a/tests/test_utils/test_coder.py b/tests/test_utils/test_coder.py index 4e8877b3a86..f23649d1736 100644 --- a/tests/test_utils/test_coder.py +++ b/tests/test_utils/test_coder.py @@ -2,8 +2,8 @@ import pytest import torch -from mmdet.core.bbox.coder import (DeltaXYWHBBoxCoder, TBLRBBoxCoder, - YOLOBBoxCoder) +from mmdet.core.bbox.coder import (DeltaXYWHBBoxCoder, DistancePointBBoxCoder, + TBLRBBoxCoder, YOLOBBoxCoder) def test_yolo_bbox_coder(): @@ -108,3 +108,20 @@ def test_tblr_bbox_coder(): deltas = torch.zeros((0, 4)) out = coder.decode(rois, deltas, max_shape=(32, 32)) assert rois.shape == out.shape + + +def test_distance_point_bbox_coder(): + coder = DistancePointBBoxCoder() + + points = torch.Tensor([[74., 61.], [-29., 106.], [138., 61.], [29., 170.]]) + gt_bboxes = torch.Tensor([[74., 61., 75., 62.], [0., 104., 0., 112.], + [100., 90., 100., 120.], [0., 120., 100., 120.]]) + expected_distance = torch.Tensor([[0., 0., 1., 1.], [0., 2., 29., 6.], + [38., 0., 0., 50.], [29., 50., 50., 0.]]) + out_distance = coder.encode(points, gt_bboxes, max_dis=50, eps=0) + assert expected_distance.allclose(out_distance) + + distance = torch.Tensor([[0., 0, 1., 1.], [1., 2., 10., 6.], + [22., -29., 138., 61.], [54., -29., 170., 61.]]) + out_bbox = coder.decode(points, distance, max_shape=(120, 100)) + assert gt_bboxes.allclose(out_bbox)