Skip to content

Commit

Permalink
MMDet RTMDet Inst decoupling (#3433)
Browse files Browse the repository at this point in the history
* migrate mmdet maskrcnn modules

* add cross-entropy loss

* mypy changes and style changes

* remove box structures

* modify resnet

* add annotation

* fix all mypy issues

* fix mypy issues

* style changes

* remove unused losses

* remove focal_loss_pb

* fix all ruff and mypy issues

* remove mmdet mask structures

* change device for unit test

* add deployment files

* remove deployment from inst-seg

* update deployment

* add mmdeploy maskrcnn opset

* replace mmcv.cnn module

* remove upsample building

* add swin transformer

* update instance_segmentation/maskrcnn.py

* decouple mmdeploy and replace with native exporter

* remove duplicates import

* fix rpn_head training issue

* remove maskrcnn r50 mmconfigs

* fix anchor head and related fixes

* remove gather_topk

* remove maskrcnn efficientnet mmconfig

* remove maskrcnn-swint mmconfig

* revert some changes

* update recipes

* replace mmcv.ops.roi_align with torchvision.ops.roi_align

* fix format issue

* update anchor head

* add CrossSigmoidFocalLoss back

* remove mmdet decouple test

* skip xai test for inst-seg for now

* update todo roi_align comment

* add custom otx roi align

* reformat OTXRoIAlign

* remove config from MMDetInstanceSegCompatibleModel

* add rtmdet inst test

* add unit tests

* rename unit tests

* update rtmdet recipe

* remove RTMDetSepBNHead

* skip xai test for rtmdet inst for now

* update license docstring

* revert src/otx/core/model/instance_segmentation.py
  • Loading branch information
eugene123tw authored May 9, 2024
1 parent 4cf4cc5 commit 3f32845
Show file tree
Hide file tree
Showing 36 changed files with 3,458 additions and 782 deletions.
229 changes: 229 additions & 0 deletions src/otx/algo/detection/backbones/cspnext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.

"""CSPNeXt backbone used in RTMDet."""

from __future__ import annotations

import math
from typing import ClassVar

from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm

from otx.algo.detection.backbones.csp_darknet import SPPBottleneck
from otx.algo.detection.layers.csp_layer import CSPLayer
from otx.algo.modules.base_module import BaseModule
from otx.algo.modules.conv_module import ConvModule
from otx.algo.modules.depthwise_separable_conv_module import DepthwiseSeparableConvModule


class CSPNeXt(BaseModule):
    """CSPNeXt backbone used in RTMDet.

    The network is a three-convolution stem followed by one stage per entry
    of the selected architecture setting. Every stage is a stride-2
    convolution, an optional SPP bottleneck and a CSP layer built with
    CSPNeXt blocks. The stem is registered as ``stem`` and the stages as
    ``stage1``, ``stage2``, ...; ``self.layers`` keeps these names in
    execution order.

    Args:
        arch (str): Architecture of CSPNeXt, from {P5, P6}.
            Defaults to P5.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        out_indices (Sequence[int]): Output from which stages. Index 0 is
            the stem and index ``i`` is ``stage{i}``. Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). Layers ``0..frozen_stages`` are frozen, so 0 freezes only
            the stem. -1 means not freezing any parameters. Defaults to -1.
        use_depthwise (bool): Whether to use depthwise separable convolution.
            Defaults to False.
        expand_ratio (float): Ratio to adjust the number of channels of the
            hidden layer. Defaults to 0.5.
        arch_ovewrite (list): Overwrite default arch settings. Each entry
            follows the ``arch_settings`` layout. Defaults to None.
        spp_kernel_sizes: (tuple[int]): Sequential of kernel sizes of SPP
            layers. Defaults to (5, 9, 13).
        channel_attention (bool): Whether to add channel attention in each
            stage. Defaults to True.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Defaults to None.
        norm_cfg (dict, optional): Dictionary to construct and config norm
            layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict, optional): Config dict for activation layer.
            Defaults to dict(type='SiLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict, optional): Initialization config dict. Defaults to
            Kaiming-uniform initialization of ``Conv2d`` layers.

    Raises:
        ValueError: If ``out_indices`` or ``frozen_stages`` fall outside the
            range allowed by the number of stages.
    """

    # From left to right:
    # in_channels, out_channels, num_blocks, add_identity, use_spp
    arch_settings: ClassVar = {
        "P5": [
            [64, 128, 3, True, False],
            [128, 256, 6, True, False],
            [256, 512, 6, True, False],
            [512, 1024, 3, False, True],
        ],
        "P6": [
            [64, 128, 3, True, False],
            [128, 256, 6, True, False],
            [256, 512, 6, True, False],
            [512, 768, 3, True, False],
            [768, 1024, 3, False, True],
        ],
    }

    def __init__(
        self,
        arch: str = "P5",
        deepen_factor: float = 1.0,
        widen_factor: float = 1.0,
        out_indices: tuple[int, int, int] = (2, 3, 4),
        frozen_stages: int = -1,
        use_depthwise: bool = False,
        expand_ratio: float = 0.5,
        arch_ovewrite: dict | None = None,
        spp_kernel_sizes: tuple[int, int, int] = (5, 9, 13),
        channel_attention: bool = True,
        conv_cfg: dict | None = None,
        norm_cfg: dict | None = None,
        act_cfg: dict | None = None,
        norm_eval: bool = False,
        init_cfg: dict | None = None,
    ) -> None:
        if init_cfg is None:
            init_cfg = {
                "type": "Kaiming",
                "layer": "Conv2d",
                "a": math.sqrt(5),
                "distribution": "uniform",
                "mode": "fan_in",
                "nonlinearity": "leaky_relu",
            }

        super().__init__(init_cfg=init_cfg)

        arch_setting = self.arch_settings[arch]
        if arch_ovewrite:
            arch_setting = arch_ovewrite  # type: ignore[assignment]

        # The stem occupies index 0, so valid layer indices are 0..len(arch_setting).
        num_layers = len(arch_setting) + 1
        if not set(out_indices).issubset(range(num_layers)):
            msg = f"out_indices must be in range(0, len(arch_setting) + 1). But received {out_indices}"
            raise ValueError(msg)

        if frozen_stages not in range(-1, num_layers):
            msg = f"frozen_stages must be in (-1, len(arch_setting) + 1). But received {frozen_stages}"
            raise ValueError(msg)

        if norm_cfg is None:
            norm_cfg = {"type": "BN", "momentum": 0.03, "eps": 0.001}

        if act_cfg is None:
            act_cfg = {"type": "SiLU"}

        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.use_depthwise = use_depthwise
        self.norm_eval = norm_eval
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

        # The stem runs at half of the first stage's input width until its last conv.
        stem_half_channels = int(arch_setting[0][0] * widen_factor // 2)
        stem_out_channels = int(arch_setting[0][0] * widen_factor)
        self.stem = nn.Sequential(
            ConvModule(
                3,
                stem_half_channels,
                3,
                padding=1,
                stride=2,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            ),
            ConvModule(
                stem_half_channels,
                stem_half_channels,
                3,
                padding=1,
                stride=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            ),
            ConvModule(
                stem_half_channels,
                stem_out_channels,
                3,
                padding=1,
                stride=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            ),
        )
        self.layers = ["stem"]

        for stage_idx, (base_in, base_out, base_blocks, add_identity, use_spp) in enumerate(arch_setting, start=1):
            in_channels = int(base_in * widen_factor)
            out_channels = int(base_out * widen_factor)
            # Every stage keeps at least one block, however small deepen_factor is.
            num_blocks = max(round(base_blocks * deepen_factor), 1)

            stage = [
                conv(
                    in_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                ),
            ]
            if use_spp:
                stage.append(
                    SPPBottleneck(
                        out_channels,
                        out_channels,
                        kernel_sizes=spp_kernel_sizes,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                    ),
                )
            stage.append(
                CSPLayer(
                    out_channels,
                    out_channels,
                    num_blocks=num_blocks,
                    add_identity=add_identity,
                    use_depthwise=use_depthwise,
                    use_cspnext_block=True,
                    expand_ratio=expand_ratio,
                    channel_attention=channel_attention,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                ),
            )

            stage_name = f"stage{stage_idx}"
            self.add_module(stage_name, nn.Sequential(*stage))
            self.layers.append(stage_name)

    def _freeze_stages(self) -> None:
        """Freeze stages param and norm stats.

        Puts layers ``0..frozen_stages`` of ``self.layers`` into eval mode and
        disables gradients for their parameters. Does nothing when
        ``frozen_stages`` is negative.
        """
        if self.frozen_stages < 0:
            return
        for layer_name in self.layers[: self.frozen_stages + 1]:
            module = getattr(self, layer_name)
            module.eval()
            for param in module.parameters():
                param.requires_grad = False

    def train(self, mode: bool = True) -> CSPNeXt:
        """Set modules in training mode.

        Frozen stages are re-frozen after the mode switch, and when
        ``norm_eval`` is set all batch-norm layers stay in eval mode while
        training.

        Args:
            mode (bool): ``True`` for training mode, ``False`` for evaluation
                mode. Defaults to True.

        Returns:
            CSPNeXt: This module, matching the ``nn.Module.train`` contract so
            that calls can be chained.
        """
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
        """Forward function.

        Args:
            x (Tensor): Input passed to the stem.

        Returns:
            tuple[Tensor, ...]: Outputs of the layers whose index in
            ``self.layers`` is listed in ``out_indices``, in execution order.
        """
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
7 changes: 3 additions & 4 deletions src/otx/algo/detection/heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from otx.algo.detection.heads.anchor_generator import AnchorGenerator
from otx.algo.detection.heads.base_head import BaseDenseHead
from otx.algo.detection.heads.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from otx.algo.detection.utils.utils import anchor_inside_flags, images_to_levels, multi_apply, unmap
from otx.algo.utils.mmengine_utils import InstanceData

Expand Down Expand Up @@ -49,14 +48,14 @@ def __init__(
self,
num_classes: int,
in_channels: tuple[int, ...] | int,
anchor_generator: AnchorGenerator,
bbox_coder: DeltaXYWHBBoxCoder,
anchor_generator: nn.Module,
bbox_coder: nn.Module,
loss_cls: nn.Module,
loss_bbox: nn.Module,
train_cfg: dict,
test_cfg: DictConfig,
feat_channels: int = 256,
reg_decoded_bbox: bool = False,
test_cfg: DictConfig | None = None,
init_cfg: dict | list[dict] | None = None,
) -> None:
super().__init__(init_cfg=init_cfg)
Expand Down
86 changes: 86 additions & 0 deletions src/otx/algo/detection/heads/distance_point_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Distance Point BBox coder."""

from __future__ import annotations

from typing import TYPE_CHECKING

from otx.algo.detection.utils.utils import bbox2distance, distance2bbox

if TYPE_CHECKING:
from torch import Tensor


class DistancePointBBoxCoder:
    """Coder between corner-format boxes and point-to-border distances.

    ``encode`` converts gt bboxes given as (x1, y1, x2, y2) into four
    distances measured from reference points, and ``decode`` converts
    predicted distances back into boxes. Both methods validate their input
    shapes and delegate the arithmetic to ``bbox2distance`` and
    ``distance2bbox``.

    Args:
        clip_border (bool, optional): Whether clip the objects outside the
            border of the image. Defaults to True.
        encode_size (int): Stored as ``self.encode_size``; not read by the
            methods of this class. Defaults to 4.
        use_box_type (bool): Stored as ``self.use_box_type``; not read by the
            methods of this class. Defaults to False.
    """

    def __init__(
        self,
        clip_border: bool = True,
        encode_size: int = 4,
        use_box_type: bool = False,
    ) -> None:
        self.clip_border = clip_border
        self.encode_size = encode_size
        self.use_box_type = use_box_type

    def encode(
        self,
        points: Tensor,
        gt_bboxes: Tensor,
        max_dis: float | None = None,
        eps: float = 0.1,
    ) -> Tensor:
        """Encode bounding box to distances.

        Args:
            points (Tensor): Shape (N, 2), The format is [x, y].
            gt_bboxes (Tensor): Shape (N, 4), The format is "xyxy".
            max_dis (float): Upper bound of the distance. Default None.
            eps (float): a small value to ensure target < max_dis, instead <=.
                Default 0.1.

        Returns:
            Tensor: Box transformation deltas. The shape is (N, 4).

        Raises:
            ValueError: If the first dimensions of ``points`` and
                ``gt_bboxes`` differ, or a last dimension is not 2 / 4.
        """
        # Checked in order; the first failing requirement decides the message.
        requirements = (
            (points.size(0) == gt_bboxes.size(0), "The number of points should be equal to the number of boxes."),
            (points.size(-1) == 2, "The last dimension of points should be 2."),
            (gt_bboxes.size(-1) == 4, "The last dimension of gt_bboxes should be 4."),
        )
        for satisfied, msg in requirements:
            if not satisfied:
                raise ValueError(msg)
        return bbox2distance(points, gt_bboxes, max_dis, eps)

    def decode(
        self,
        points: Tensor,
        pred_bboxes: Tensor,
        max_shape: tuple[int, ...] | Tensor | tuple[tuple[int, ...], ...] | None = None,
    ) -> Tensor:
        """Decode distance prediction to bounding box.

        Args:
            points (Tensor): Reference points; the last dimension must be 2.
            pred_bboxes (Tensor): Predicted distances; the last dimension
                must be 4 and the first dimension must match ``points``.
            max_shape (tuple | Tensor, optional): Forwarded to
                ``distance2bbox``. Replaced by ``None`` when ``clip_border``
                is ``False``. Defaults to None.

        Returns:
            Tensor: The result of ``distance2bbox`` for the given inputs.

        Raises:
            ValueError: If the first dimensions of ``points`` and
                ``pred_bboxes`` differ, or a last dimension is not 2 / 4.
        """
        # Checked in order; the first failing requirement decides the message.
        requirements = (
            (points.size(0) == pred_bboxes.size(0), "The number of points should be equal to the number of boxes."),
            (points.size(-1) == 2, "The last dimension of points should be 2."),
            (pred_bboxes.size(-1) == 4, "The last dimension of pred_bboxes should be 4."),
        )
        for satisfied, msg in requirements:
            if not satisfied:
                raise ValueError(msg)
        # Border clipping is switched off by withholding the shape limit.
        shape_limit = None if self.clip_border is False else max_shape
        return distance2bbox(points, pred_bboxes, shape_limit)
Loading

0 comments on commit 3f32845

Please sign in to comment.