diff --git a/.gitignore b/.gitignore
index 6937521..05e1ee9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -113,3 +113,4 @@
 data
 trash/
 experiments
+work_dirs
\ No newline at end of file
diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md
index 044278a..94e009c 100644
--- a/GETTING_STARTED.md
+++ b/GETTING_STARTED.md
@@ -4,7 +4,6 @@ This page provides basic tutorials about the usage of ReDet.
 For installation instructions, please see [INSTALL.md](INSTALL.md).
 
-
 ## Prepare DOTA dataset.
 It is recommended to symlink the dataset root to `ReDet/data`.
@@ -15,12 +14,12 @@ First, make sure your initial data are in the following structure.
 ```
 data/dota15
 ├── train
 │   ├──images
-│   └── labelTxt
+│   └──labelTxt
 ├── val
-│   ├── images
-│   └── labelTxt
+│   ├──images
+│   └──labelTxt
 └── test
-   └── images
+   └──images
 ```
 Split the original images and create COCO format json.
 ```
@@ -30,11 +29,11 @@ Then you will get data in the following structure
 ```
 dota15_1024
 ├── test1024
-│   ├── DOTA_test1024.json
-│   └── images
+│   ├──DOTA_test1024.json
+│   └──images
 └── trainval1024
-     ├── DOTA_trainval1024.json
-    └── images
+     ├──DOTA_trainval1024.json
+    └──images
 ```
 For data preparation with data augmentation, refer to "DOTA_devkit/prepare_dota1_5_v2.py"
@@ -47,16 +46,15 @@ First, make sure your initial data are in the following structure.
 data/HRSC2016
 ├── Train
 │   ├──AllImages
-│   └── Annotations
+│   └──Annotations
 └── Test
 │   ├──AllImages
-│   └── Annotations
+│   └──Annotations
 ```
 
 Then you need to convert HRSC2016 to DOTA's format, i.e., rename `AllImages` to `images`, convert xml `Annotations` to DOTA's `txt` format.
-Here we provide a script from s2anet: [HRSC2DOTA.py](https://github.com/csuhan/s2anet/blob/original_version/DOTA_devkit/HRSC2DOTA.py). It will be added to this repo later.
-After that, your `data/HRSC2016` should contain the following folders.
+Here we provide a script from s2anet: [HRSC2DOTA.py](https://github.com/csuhan/s2anet/blob/original_version/DOTA_devkit/HRSC2DOTA.py). Now, your `data/HRSC2016` should contain the following folders.
 
 ```
 data/HRSC2016
@@ -90,10 +88,6 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}]
 
 # multi-gpu testing
 ./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}]
-
-# If you want to test ReDet under Cyclic group C_4 (default C_8), you need to pass the ENV: Orientation=4
-# See mmdet/models/backbones/re_resnet.py for details
-Orientation=4 python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}]
 ```
 
 Optional arguments:
@@ -103,7 +97,7 @@ Examples:
 
 Assume that you have already downloaded the checkpoints to `work_dirs/`.
 
-1. Test ReDet.
+1. Test ReDet with 1 GPU.
 ```shell
 python tools/test.py configs/ReDet/ReDet_re50_refpn_1x_dota15.py \
     work_dirs/ReDet_re50_refpn_1x_dota15/ReDet_re50_refpn_1x_dota15-7f2d6dda.pth \
@@ -117,7 +111,7 @@ python tools/test.py configs/ReDet/ReDet_re50_refpn_1x_dota15.py \
     4 --out work_dirs/ReDet_re50_refpn_1x_dota15/results.pkl
 ```
 
-3. Parse the results.pkl to the format needed for [DOTA evaluation](https://captain-whu.github.io/DOTA/evaluation.html)
+3. Parse results for [DOTA evaluation](https://captain-whu.github.io/DOTA/evaluation.html)
 ```
 python tools/parse_results.py --config configs/ReDet/ReDet_re50_refpn_1x_dota15.py --type OBB
 ```
@@ -134,6 +128,28 @@ python tools/test.py configs/ReDet/ReDet_re50_refpn_3x_hrsc2016.py \
 python DOTA_devkit/hrsc2016_evaluation.py
 ```
 
+### Convert ReResNet+ReFPN to standard PyTorch layers
+
+We provide a [script](tools/convert_ReDet_to_torch.py) to convert the pre-trained weights of ReResNet+ReFPN to standard PyTorch layers. Take ReDet on DOTA-v1.5 as an example.
+
+1. Download the pretrained weights [here](https://drive.google.com/file/d/1AjG3-Db_hmZF1YSKRVnq8j_yuxzualRo/view?usp=sharing) and convert them to standard PyTorch layers.
+```
+python tools/convert_ReDet_to_torch.py configs/ReDet/ReDet_re50_refpn_1x_dota15.py \
+    work_dirs/ReDet_re50_refpn_1x_dota15/ReDet_re50_refpn_1x_dota15-7f2d6dda.pth \
+    work_dirs/ReDet_re50_refpn_1x_dota15/ReDet_r50_fpn_1x_dota15.pth
+```
+
+2. Use standard ResNet+FPN as the backbone of ReDet and test it on DOTA-v1.5.
+```
+mkdir work_dirs/ReDet_r50_fpn_1x_dota15
+
+bash ./tools/dist_test.sh configs/ReDet/ReDet_r50_fpn_1x_dota15.py \
+    work_dirs/ReDet_re50_refpn_1x_dota15/ReDet_r50_fpn_1x_dota15.pth 8 \
+    --out work_dirs/ReDet_r50_fpn_1x_dota15/results.pkl
+
+# Submit the parsed results to the evaluation server.
+python tools/parse_results.py --config configs/ReDet/ReDet_r50_fpn_1x_dota15.py
+```
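+
+The conversion relies on the `export()` method of `ReResNet` and `ReFPN`: in eval mode, each equivariant `e2cnn` module is exported to an equivalent standard PyTorch module, and the detection head is copied as-is. A minimal sketch of the core of `tools/convert_ReDet_to_torch.py` (paths as in step 1):
+
+```python
+import torch
+from mmdet.apis import init_detector
+
+model = init_detector(config, in_weight, device='cuda:0')
+backbone_dict = model.backbone.export().state_dict()  # e2cnn -> torch.nn
+neck_dict = model.neck.export().state_dict()
+```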
 
 ### Demo of inference in a large size image.
 
@@ -159,10 +175,6 @@ to the GPU num, e.g., 0.01 for 4 GPUs and 0.04 for 16 GPUs.
 
 ```shell
 python tools/train.py ${CONFIG_FILE}
-
-# If you want to train a model under Cyclic group C_4 (default C_8), you need to pass the ENV: Orientation=4
-# See mmdet/models/backbones/re_resnet.py for details
-Orientation=4 python tools/train.py ${CONFIG_FILE}
 ```
 
 If you want to specify the working directory in the command, you can add an argument `--work_dir ${YOUR_WORK_DIR}`.
@@ -199,124 +211,3 @@ You can check [slurm_train.sh](tools/slurm_train.sh) for full arguments and envi
 
 If you have just multiple machines connected with ethernet, you can refer to
 pytorch [launch utility](https://pytorch.org/docs/stable/distributed_deprecated.html#launch-utility).
 Usually it is slow if you do not have high speed networking like infiniband.
-
-
-## How-to
-
-### Use my own datasets
-
-The simplest way is to convert your dataset to existing dataset formats (COCO or PASCAL VOC).
-
-Here we show an example of adding a custom dataset of 5 classes, assuming it is also in COCO format.
-
-In `mmdet/datasets/my_dataset.py`:
-
-```python
-from .coco import CocoDataset
-
-
-class MyDataset(CocoDataset):
-
-    CLASSES = ('a', 'b', 'c', 'd', 'e')
-```
-
-In `mmdet/datasets/__init__.py`:
-
-```python
-from .my_dataset import MyDataset
-```
-
-Then you can use `MyDataset` in config files, with the same API as CocoDataset.
-
-
-It is also fine if you do not want to convert the annotation format to COCO or PASCAL format.
-Actually, we define a simple annotation format and all existing datasets are
-processed to be compatible with it, either online or offline.
-
-The annotation of a dataset is a list of dict, each dict corresponds to an image.
-There are 3 field `filename` (relative path), `width`, `height` for testing,
-and an additional field `ann` for training. `ann` is also a dict containing at least 2 fields:
-`bboxes` and `labels`, both of which are numpy arrays. Some datasets may provide
-annotations like crowd/difficult/ignored bboxes, we use `bboxes_ignore` and `labels_ignore`
-to cover them.
-
-Here is an example.
-```
-[
-    {
-        'filename': 'a.jpg',
-        'width': 1280,
-        'height': 720,
-        'ann': {
-            'bboxes': <np.ndarray> (n, 4),
-            'labels': <np.ndarray> (n, ),
-            'bboxes_ignore': <np.ndarray> (k, 4),
-            'labels_ignore': <np.ndarray> (k, ) (optional field)
-        }
-    },
-    ...
-]
-```
-
-There are two ways to work with custom datasets.
-
-- online conversion
-
-  You can write a new Dataset class inherited from `CustomDataset`, and overwrite two methods
-  `load_annotations(self, ann_file)` and `get_ann_info(self, idx)`,
-  like [CocoDataset](mmdet/datasets/coco.py) and [VOCDataset](mmdet/datasets/voc.py).
-
-- offline conversion
-
-  You can convert the annotation format to the expected format above and save it to
-  a pickle or json file, like [pascal_voc.py](tools/convert_datasets/pascal_voc.py).
-  Then you can simply use `CustomDataset`.
-
-### Develop new components
-
-We basically categorize model components into 4 types.
-
-- backbone: usually a FCN network to extract feature maps, e.g., ResNet, MobileNet.
-- neck: the component between backbones and heads, e.g., FPN, PAFPN.
-- head: the component for specific tasks, e.g., bbox prediction and mask prediction.
-- roi extractor: the part for extracting RoI features from feature maps, e.g., RoI Align.
-
-Here we show how to develop new components with an example of MobileNet.
-
-1. Create a new file `mmdet/models/backbones/mobilenet.py`.
-
-```python
-import torch.nn as nn
-
-from ..registry import BACKBONES
-
-
-@BACKBONES.register
-class MobileNet(nn.Module):
-
-    def __init__(self, arg1, arg2):
-        pass
-
-    def forward(x):  # should return a tuple
-        pass
-```
-
-2. Import the module in `mmdet/models/backbones/__init__.py`.
-
-```python
-from .mobilenet import MobileNet
-```
-
-3. Use it in your config file.
-
-```python
-model = dict(
-    ...
-    backbone=dict(
-        type='MobileNet',
-        arg1=xxx,
-        arg2=xxx),
-    ...
-```
-
-For more information on how it works, you can refer to [TECHNICAL_DETAILS.md](TECHNICAL_DETAILS.md) (TODO).
diff --git a/README.md b/README.md
index ed40867..abeaa20 100644
--- a/README.md
+++ b/README.md
@@ -18,9 +18,11 @@ More precisely, we incorporate rotation-equivariant networks into the detector t
 Based on the rotation-equivariant features, we also present Rotation-invariant RoI Align (RiRoI Align), which adaptively extracts rotation-invariant features from equivariant features according to the orientation of RoI.
 Extensive experiments on several challenging aerial image datasets DOTA-v1.0, DOTA-v1.5 and HRSC2016, show that our method can achieve state-of-the-art performance on the task of aerial object detection.
 Compared with previous best results, our ReDet gains 1.2, 3.5 and 2.6 mAP on DOTA-v1.0, DOTA-v1.5 and HRSC2016 respectively while reducing the number of parameters by 60% (313 Mb vs. 121 Mb).
 
+
 ## Changelog
 
-* **2021-04-13**. Update our [pretrained ReResNet](https://drive.google.com/file/d/1FshfREfLZaNl5FcaKrH0lxFyZt50Uyu2/view) and fix by [this commit](https://github.com/csuhan/ReDet/commit/88f8170db12a34ec342ab61571db217c9589888d). For the users that can not reach our reported mAP, please download it and train again.
+* **2022-03-28**. Speed up ReDet now! We convert the pre-trained weights of ReResNet+ReFPN to standard PyTorch layers (see [GETTING_STARTED.md](GETTING_STARTED.md)). In the testing phase, you can directly use ResNet+FPN as the backbone of ReDet without compromising its rotation equivariance. Besides, you can also convert ReResNet to standard ResNet with [this script](https://github.com/csuhan/ReDet/blob/ReDet_mmcls/tools/convert_re_resnet_to_torch.py).
+* **2021-04-13**. Update our [pretrained ReResNet](https://drive.google.com/file/d/1FshfREfLZaNl5FcaKrH0lxFyZt50Uyu2/view), fixed by [this commit](https://github.com/csuhan/ReDet/commit/88f8170db12a34ec342ab61571db217c9589888d). If you cannot reach the reported mAP, please download it and try again.
 * **2021-03-09**. Code released.
 
 ## Benchmark and model zoo
@@ -64,7 +66,7 @@ Please see [GETTING_STARTED.md](GETTING_STARTED.md) for the basic usage.
 
 ## Citation
 
-```
+```BibTeX
 @InProceedings{han2021ReDet,
     author = {Han, Jiaming and Ding, Jian and Xue, Nan and Xia, Gui-Song},
     title = {ReDet: A Rotation-equivariant Detector for Aerial Object Detection},
diff --git a/configs/ReDet/ReDet_r50_fpn_1x_dota15.py b/configs/ReDet/ReDet_r50_fpn_1x_dota15.py
new file mode 100644
index 0000000..d5a549d
--- /dev/null
+++ b/configs/ReDet/ReDet_r50_fpn_1x_dota15.py
@@ -0,0 +1,203 @@
+# model settings
+model = dict(
+    type='ReDet',
+    pretrained=None,
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_scales=[8],
+        anchor_ratios=[0.5, 1.0, 2.0],
+        anchor_strides=[4, 8, 16, 32, 64],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[1.0, 1.0, 1.0, 1.0],
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+    bbox_roi_extractor=dict(
+        type='SingleRoIExtractor',
+        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    bbox_head=dict(
+        type='SharedFCBBoxHeadRbbox',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=17,
+        target_means=[0., 0., 0., 0., 0.],
+        target_stds=[0.1, 0.1, 0.2, 0.2, 0.1],
+        reg_class_agnostic=True,
+        with_module=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
+    rbbox_roi_extractor=dict(
+        type='RboxSingleRoIExtractor',
+        roi_layer=dict(type='RiRoIAlign', out_size=7, sample_num=2),
+        out_channels=256,
+        featmap_strides=[4, 8, 16, 32]),
+    rbbox_head=dict(
+        type='SharedFCBBoxHeadRbbox',
+        num_fcs=2,
+        in_channels=256,
+        fc_out_channels=1024,
+        roi_feat_size=7,
+        num_classes=17,
+        target_means=[0., 0., 0., 0., 0.],
+        target_stds=[0.05, 0.05, 0.1, 0.1, 0.05],
+        reg_class_agnostic=False,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
+)
+# model training and testing settings
+train_cfg = dict(
+    rpn=dict(
+        assigner=dict(
+            type='MaxIoUAssignerCy',
+            pos_iou_thr=0.7,
+            neg_iou_thr=0.3,
+            min_pos_iou=0.3,
+            ignore_iof_thr=-1),
+        sampler=dict(
+            type='RandomSampler',
+            num=256,
+            pos_fraction=0.5,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=False),
+        allowed_border=0,
+        pos_weight=-1,
+        debug=False),
+    rpn_proposal=dict(
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=[
+        dict(
+            assigner=dict(
+                type='MaxIoUAssignerCy',
+                pos_iou_thr=0.5,
+                neg_iou_thr=0.5,
+                min_pos_iou=0.5,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                type='RandomSampler',
+                num=512,
+                pos_fraction=0.25,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=True),
+            pos_weight=-1,
+            debug=False),
+        dict(
+            assigner=dict(
+                type='MaxIoUAssignerRbbox',
+                pos_iou_thr=0.5,
+                neg_iou_thr=0.5,
+                min_pos_iou=0.5,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                type='RandomRbboxSampler',
+                num=512,
+                pos_fraction=0.25,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=True),
+            pos_weight=-1,
+            debug=False)
+    ])
+test_cfg = dict(
+    rpn=dict(
+        # TODO: test nms 2000
+        nms_across_levels=False,
+        nms_pre=2000,
+        nms_post=2000,
+        max_num=2000,
+        nms_thr=0.7,
+        min_bbox_size=0),
+    rcnn=dict(
+        score_thr=0.05, nms=dict(type='py_cpu_nms_poly_fast', iou_thr=0.1), max_per_img=2000)
+)
+# dataset settings
+dataset_type = 'DOTA1_5Dataset_v2'
+data_root = 'data/dota15_1024/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'trainval1024/DOTA1_5_trainval1024.json',
+        img_prefix=data_root + 'trainval1024/images/',
+        img_scale=(1024, 1024),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0.5,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'trainval1024/DOTA1_5_trainval1024.json',
+        img_prefix=data_root + 'trainval1024/images/',
+        img_scale=(1024, 1024),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=True,
+        with_crowd=True,
+        with_label=True),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'test1024/DOTA1_5_test1024.json',
+        img_prefix=data_root + 'test1024/images/',
+        img_scale=(1024, 1024),
+        img_norm_cfg=img_norm_cfg,
+        size_divisor=32,
+        flip_ratio=0,
+        with_mask=False,
+        with_label=False,
+        test_mode=True))
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=12)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = 'work_dirs/ReDet_r50_fpn_1x_dota15'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
\ No newline at end of file
diff --git a/mmdet/models/backbones/re_resnet.py b/mmdet/models/backbones/re_resnet.py
index b3ab660..33be0e7 100644
--- a/mmdet/models/backbones/re_resnet.py
+++ b/mmdet/models/backbones/re_resnet.py
@@ -1,160 +1,25 @@
 """
-This file contains our implementation of ReResNet.
+Implementation of ReResNet V2.
 @author: Jiaming Han
 """
-import e2cnn.nn as enn
 import math
 import os
+from collections import OrderedDict
+
+import e2cnn.nn as enn
+import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from e2cnn import gspaces
-from mmcv.cnn import (constant_init, kaiming_init)
+from mmcv.cnn import constant_init, kaiming_init
 from torch.nn.modules.batchnorm import _BatchNorm
-from .base_backbone import BaseBackbone
 from ..builder import BACKBONES
-
-# Set default Orientation=8, .i.e, the group C8
-# One can change it by passing the env Orientation=xx
-Orientation = 8
-# keep similar computation or similar params
-# One can change it by passing the env fixparams=True
-fixparams = False
-if 'Orientation' in os.environ:
-    Orientation = int(os.environ['Orientation'])
-if 'fixparams' in os.environ:
-    fixparams = True
-print('ReResNet Orientation: {}\tFix Params: {}'.format(Orientation, fixparams))
-
-# define the equivariant group. We use C8 group by default.
-gspace = gspaces.Rot2dOnR2(N=Orientation)
-
-
-def regular_feature_type(gspace: gspaces.GSpace, planes: int):
-    """ build a regular feature map with the specified number of channels"""
-    assert gspace.fibergroup.order() > 0
-    N = gspace.fibergroup.order()
-    if fixparams:
-        planes *= math.sqrt(N)
-    planes = planes / N
-    planes = int(planes)
-    return enn.FieldType(gspace, [gspace.regular_repr] * planes)
-
-
-def trivial_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = True):
-    """ build a trivial feature map with the specified number of channels"""
-    if fixparams:
-        planes *= math.sqrt(gspace.fibergroup.order())
-    planes = int(planes)
-    return enn.FieldType(gspace, [gspace.trivial_repr] * planes)
-
-
-FIELD_TYPE = {
-    "trivial": trivial_feature_type,
-    "regular": regular_feature_type,
-}
-
-
-def conv7x7(inplanes, out_planes, stride=2, padding=3, bias=False):
-    """7x7 convolution with padding"""
-    in_type = enn.FieldType(gspace, inplanes * [gspace.trivial_repr])
-    out_type = FIELD_TYPE['regular'](gspace, out_planes)
-    return enn.R2Conv(in_type, out_type, 7,
-                      stride=stride,
-                      padding=padding,
-                      bias=bias,
-                      sigma=None,
-                      frequencies_cutoff=lambda r: 3 * r, )
-
-
-def conv3x3(inplanes, out_planes, stride=1, padding=1, groups=1, dilation=1):
-    """3x3 convolution with padding"""
-    in_type = FIELD_TYPE['regular'](gspace, inplanes)
-    out_type = FIELD_TYPE['regular'](gspace, out_planes)
-    return enn.R2Conv(in_type, out_type, 3,
-                      stride=stride,
-                      padding=padding,
-                      groups=groups,
-                      bias=False,
-                      dilation=dilation,
-                      sigma=None,
-                      frequencies_cutoff=lambda r: 3 * r,
-                      initialize=False)
-
-
-def conv1x1(inplanes, out_planes, stride=1):
-    """1x1 convolution"""
-    in_type = FIELD_TYPE['regular'](gspace, inplanes)
-    out_type = FIELD_TYPE['regular'](gspace, out_planes)
-    return enn.R2Conv(in_type, out_type, 1,
-                      stride=stride,
-                      bias=False,
-                      sigma=None,
-                      frequencies_cutoff=lambda r: 3 * r,
-                      initialize=False)
-
-
-def convnxn(inplanes, outplanes, kernel_size=3, stride=1, padding=0, groups=1, bias=False, dilation=1):
-    in_type = FIELD_TYPE['regular'](gspace, inplanes)
-    out_type = FIELD_TYPE['regular'](gspace, outplanes)
-    return enn.R2Conv(in_type, out_type, kernel_size,
-                      stride=stride,
-                      padding=padding,
-                      groups=groups,
-                      bias=bias,
-                      dilation=dilation,
-                      sigma=None,
-                      frequencies_cutoff=lambda r: 3 * r, )
-
-
-def ennReLU(inplanes):
-    in_type = FIELD_TYPE['regular'](gspace, inplanes)
-    return enn.ReLU(in_type, inplace=True)
-
-
-def ennAvgPool(inplanes, kernel_size=1, stride=None, padding=0, ceil_mode=False):
-    in_type = FIELD_TYPE['regular'](gspace, inplanes)
-    return enn.PointwiseAvgPool(in_type, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode)
-
-
-def ennMaxPool(inplanes, kernel_size, stride=1, padding=0):
-    in_type = FIELD_TYPE['regular'](gspace, inplanes)
-    return enn.PointwiseMaxPool(in_type, kernel_size=kernel_size, stride=stride, padding=padding)
-
-
-def build_conv_layer(cfg, *args, **kwargs):
-    layer = convnxn(*args, **kwargs)
-    return layer
-
-
-def build_norm_layer(cfg, num_features, postfix=''):
-    in_type = FIELD_TYPE['regular'](gspace, num_features)
-    return 'bn' + str(postfix), enn.InnerBatchNorm(in_type)
+from ..utils.enn_layers import FIELD_TYPE, build_norm_layer, conv1x1, conv3x3
+from .base_backbone import BaseBackbone
 
 
 class BasicBlock(enn.EquivariantModule):
-    """BasicBlock for ReResNet.
-
-    Args:
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        expansion (int): The ratio of ``out_channels/mid_channels`` where
-            ``mid_channels`` is the output channels of conv1. This is a
-            reserved argument in BasicBlock and should always be 1. Default: 1.
-        stride (int): stride of the block. Default: 1
-        dilation (int): dilation of convolution. Default: 1
-        downsample (nn.Module): downsample operation on identity branch.
-            Default: None.
-        style (str): `pytorch` or `caffe`. It is unused and reserved for
-            unified API with Bottleneck.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
-        conv_cfg (dict): dictionary to construct and config conv layer.
-            Default: None
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-    """
-
     def __init__(self,
                  in_channels,
                  out_channels,
@@ -165,10 +30,14 @@ def __init__(self,
                  style='pytorch',
                  with_cp=False,
                  conv_cfg=None,
-                 norm_cfg=dict(type='BN')):
+                 norm_cfg=dict(type='BN'),
+                 gspace=None,
+                 fixparams=False):
         super(BasicBlock, self).__init__()
-        self.in_type = FIELD_TYPE['regular'](gspace, in_channels)
-        self.out_type = FIELD_TYPE['regular'](gspace, out_channels)
+        self.in_type = FIELD_TYPE['regular'](
+            gspace, in_channels, fixparams=fixparams)
+        self.out_type = FIELD_TYPE['regular'](
+            gspace, out_channels, fixparams=fixparams)
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.expansion = expansion
@@ -183,31 +52,31 @@ def __init__(self,
         self.norm_cfg = norm_cfg
 
         self.norm1_name, norm1 = build_norm_layer(
-            norm_cfg, self.mid_channels, postfix=1)
+            norm_cfg, gspace, self.mid_channels, postfix=1)
         self.norm2_name, norm2 = build_norm_layer(
-            norm_cfg, out_channels, postfix=2)
+            norm_cfg, gspace, out_channels, postfix=2)
 
-        self.conv1 = build_conv_layer(
-            conv_cfg,
+        self.conv1 = conv3x3(
+            gspace,
             in_channels,
             self.mid_channels,
-            3,
             stride=stride,
             padding=dilation,
             dilation=dilation,
-            bias=False)
+            bias=False,
+            fixparams=fixparams)
         self.add_module(self.norm1_name, norm1)
-        self.relu1 = ennReLU(self.mid_channels)
-        self.conv2 = build_conv_layer(
-            conv_cfg,
+        self.relu1 = enn.ReLU(self.conv1.out_type, inplace=True)
+        self.conv2 = conv3x3(
+            gspace,
             self.mid_channels,
             out_channels,
-            3,
             padding=1,
-            bias=False)
+            bias=False,
+            fixparams=fixparams)
         self.add_module(self.norm2_name, norm2)
-        self.relu2 = ennReLU(out_channels)
+        self.relu2 = enn.ReLU(self.conv2.out_type, inplace=True)
         self.downsample = downsample
 
     @property
@@ -254,30 +123,18 @@ def evaluate_output_shape(self, input_shape):
         else:
             return input_shape
 
+    def export(self):
+        self.eval()
+        submodules = []
+        # convert all the submodules if necessary
+        for name, module in self._modules.items():
+            if hasattr(module, 'export'):
+                module = module.export()
+            submodules.append((name, module))
+        return torch.nn.ModuleDict(OrderedDict(submodules))
 
-class Bottleneck(enn.EquivariantModule):
-    """Bottleneck block for ReResNet.
-
-    Args:
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        expansion (int): The ratio of ``out_channels/mid_channels`` where
-            ``mid_channels`` is the input/output channels of conv2. Default: 4.
-        stride (int): stride of the block. Default: 1
-        dilation (int): dilation of convolution. Default: 1
-        downsample (nn.Module): downsample operation on identity branch.
-            Default: None.
-        style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
-            stride-two layer is the 3x3 conv layer, otherwise the stride-two
-            layer is the first 1x1 conv layer. Default: "pytorch".
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
-        conv_cfg (dict): dictionary to construct and config conv layer.
-            Default: None
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-    """
 
+class Bottleneck(enn.EquivariantModule):
     def __init__(self,
                  in_channels,
                  out_channels,
@@ -288,11 +145,15 @@ def __init__(self,
                  style='pytorch',
                  with_cp=False,
                  conv_cfg=None,
-                 norm_cfg=dict(type='BN')):
+                 norm_cfg=dict(type='BN'),
+                 gspace=None,
+                 fixparams=False):
         super(Bottleneck, self).__init__()
         assert style in ['pytorch', 'caffe']
-        self.in_type = FIELD_TYPE['regular'](gspace, in_channels)
-        self.out_type = FIELD_TYPE['regular'](gspace, out_channels)
+        self.in_type = FIELD_TYPE['regular'](
+            gspace, in_channels, fixparams=fixparams)
+        self.out_type = FIELD_TYPE['regular'](
+            gspace, out_channels, fixparams=fixparams)
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.expansion = expansion
@@ -312,41 +173,41 @@ def __init__(self,
             self.conv2_stride = 1
 
         self.norm1_name, norm1 = build_norm_layer(
-            norm_cfg, self.mid_channels, postfix=1)
+            norm_cfg, gspace, self.mid_channels, postfix=1)
         self.norm2_name, norm2 = build_norm_layer(
-            norm_cfg, self.mid_channels, postfix=2)
+            norm_cfg, gspace, self.mid_channels, postfix=2)
         self.norm3_name, norm3 = build_norm_layer(
-            norm_cfg, out_channels, postfix=3)
+            norm_cfg, gspace, out_channels, postfix=3)
 
-        self.conv1 = build_conv_layer(
-            conv_cfg,
+        self.conv1 = conv1x1(
+            gspace,
             in_channels,
             self.mid_channels,
-            kernel_size=1,
             stride=self.conv1_stride,
-            bias=False)
+            bias=False,
+            fixparams=fixparams)
         self.add_module(self.norm1_name, norm1)
-        self.relu1 = ennReLU(self.mid_channels)
-        self.conv2 = build_conv_layer(
-            conv_cfg,
+        self.relu1 = enn.ReLU(self.conv1.out_type, inplace=True)
+        self.conv2 = conv3x3(
+            gspace,
             self.mid_channels,
             self.mid_channels,
-            kernel_size=3,
            stride=self.conv2_stride,
             padding=dilation,
             dilation=dilation,
-            bias=False)
+            bias=False,
+            fixparams=fixparams)
         self.add_module(self.norm2_name, norm2)
-        self.relu2 = ennReLU(self.mid_channels)
-        self.conv3 = build_conv_layer(
-            conv_cfg,
+        self.relu2 = enn.ReLU(self.conv2.out_type, inplace=True)
+        self.conv3 = conv1x1(
+            gspace,
             self.mid_channels,
             out_channels,
-            kernel_size=1,
-            bias=False)
+            bias=False,
+            fixparams=fixparams)
         self.add_module(self.norm3_name, norm3)
-        self.relu3 = ennReLU(out_channels)
+        self.relu3 = enn.ReLU(self.conv3.out_type, inplace=True)
 
         self.downsample = downsample
 
@@ -402,25 +263,18 @@ def evaluate_output_shape(self, input_shape):
         else:
             return input_shape
 
+    def export(self):
+        self.eval()
+        submodules = []
+        # convert all the submodules if necessary
+        for name, module in self._modules.items():
+            if hasattr(module, 'export'):
+                module = module.export()
+            submodules.append((name, module))
+        return torch.nn.ModuleDict(OrderedDict(submodules))
 
-def get_expansion(block, expansion=None):
-    """Get the expansion of a residual block.
-
-    The block expansion will be obtained by the following order:
-
-    1. If ``expansion`` is given, just return it.
-    2. If ``block`` has the attribute ``expansion``, then return
-       ``block.expansion``.
-    3. Return the default value according the the block type:
-       1 for ``BasicBlock`` and 4 for ``Bottleneck``.
-
-    Args:
-        block (class): The block class.
-        expansion (int | None): The given expansion ratio.
-
-    Returns:
-        int: The expansion of the block.
-    """
+
+def get_expansion(block, expansion=None):
     if isinstance(expansion, int):
         assert expansion > 0
     elif expansion is None:
@@ -439,27 +293,6 @@ def get_expansion(block, expansion=None):
 
 
 class ResLayer(nn.Sequential):
-    """ResLayer to build ReResNet style backbone.
-
-    Args:
-        block (nn.Module): Residual block used to build ResLayer.
-        num_blocks (int): Number of blocks.
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        expansion (int, optional): The expansion for BasicBlock/Bottleneck.
-            If not specified, it will firstly be obtained via
-            ``block.expansion``. If the block has no attribute "expansion",
-            the following default values will be used: 1 for BasicBlock and
-            4 for Bottleneck. Default: None.
-        stride (int): stride of the first block. Default: 1.
-        avg_down (bool): Use AvgPool instead of stride conv when
-            downsampling in the bottleneck. Default: False
-        conv_cfg (dict): dictionary to construct and config conv layer.
-            Default: None
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-    """
-
     def __init__(self,
                  block,
                  num_blocks,
@@ -470,6 +303,8 @@ def __init__(self,
                  avg_down=False,
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
+                 gspace=None,
+                 fixparams=False,
                  **kwargs):
         self.block = block
         self.expansion = get_expansion(block, expansion)
@@ -480,21 +315,18 @@ def __init__(self,
             conv_stride = stride
             if avg_down and stride != 1:
                 conv_stride = 1
+                in_type = FIELD_TYPE["regular"](
+                    gspace, in_channels, fixparams=fixparams)
                 downsample.append(
-                    ennAvgPool(
-                        in_channels,
+                    enn.PointwiseAvgPool(
+                        in_type,
                         kernel_size=stride,
                         stride=stride,
                         ceil_mode=True))
             downsample.extend([
-                build_conv_layer(
-                    conv_cfg,
-                    in_channels,
-                    out_channels,
-                    kernel_size=1,
-                    stride=conv_stride,
-                    bias=False),
-                build_norm_layer(norm_cfg, out_channels)[1]
+                conv1x1(gspace, in_channels, out_channels,
+                        stride=conv_stride, bias=False),
+                build_norm_layer(norm_cfg, gspace, out_channels)[1]
             ])
             downsample = enn.SequentialModule(*downsample)
 
@@ -508,6 +340,8 @@ def __init__(self,
                 downsample=downsample,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
+                gspace=gspace,
+                fixparams=fixparams,
                 **kwargs))
         in_channels = out_channels
         for i in range(1, num_blocks):
@@ -519,64 +353,24 @@ def __init__(self,
                     stride=1,
                    conv_cfg=conv_cfg,
                     norm_cfg=norm_cfg,
+                    gspace=gspace,
+                    fixparams=fixparams,
                     **kwargs))
         super(ResLayer, self).__init__(*layers)
 
+    def export(self):
+        self.eval()
+        submodules = []
+        # convert all the submodules if necessary
+        for name, module in self._modules.items():
+            if hasattr(module, 'export'):
+                module = module.export()
+            submodules.append((name, module))
+        return torch.nn.ModuleDict(OrderedDict(submodules))
+
 
 @BACKBONES.register_module
 class ReResNet(BaseBackbone):
-    """ReResNet backbone.
-
-    Please refer to the `paper `_ for
-    details.
-
-    Args:
-        depth (int): Network depth, from {18, 34, 50, 101, 152}.
-        in_channels (int): Number of input image channels. Default: 3.
-        stem_channels (int): Output channels of the stem layer. Default: 64.
-        base_channels (int): Middle channels of the first stage. Default: 64.
-        num_stages (int): Stages of the network. Default: 4.
-        strides (Sequence[int]): Strides of the first block of each stage.
-            Default: ``(1, 2, 2, 2)``.
-        dilations (Sequence[int]): Dilation of each stage.
-            Default: ``(1, 1, 1, 1)``.
-        out_indices (Sequence[int]): Output from which stages. If only one
-            stage is specified, a single tensor (feature map) is returned,
-            otherwise multiple stages are specified, a tuple of tensors will
-            be returned. Default: ``(3, )``.
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer.
-        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
-            Default: False.
-        avg_down (bool): Use AvgPool instead of stride conv when
-            downsampling in the bottleneck. Default: False.
-        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-            -1 means not freezing any parameters. Default: -1.
-        conv_cfg (dict | None): The config dict for conv layers. Default: None.
-        norm_cfg (dict): The config dict for norm layers.
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-        zero_init_residual (bool): Whether to use zero init for last norm layer
-            in resblocks to let them behave as identity. Default: True.
-
-    Example:
-        >>> from mmcls.models import ReResNet
-        >>> import torch
-        >>> self = ReResNet(depth=18)
-        >>> self.eval()
-        >>> inputs = torch.rand(1, 3, 32, 32)
-        >>> level_outputs = self.forward(inputs)
-        >>> for level_out in level_outputs:
-        ...     print(tuple(level_out.shape))
-        (1, 64, 8, 8)
-        (1, 128, 4, 4)
-        (1, 256, 2, 2)
-        (1, 512, 1, 1)
-    """
 
     arch_settings = {
         18: (BasicBlock, (2, 2, 2, 2)),
@@ -604,9 +398,10 @@ def __init__(self,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
-                 zero_init_residual=True):
+                 zero_init_residual=True,
+                 orientation=8,
+                 fixparams=False):
         super(ReResNet, self).__init__()
-        self.in_type = enn.FieldType(gspace, 3 * [gspace.trivial_repr])
         if depth not in self.arch_settings:
             raise KeyError(f'invalid depth {depth} for resnet')
         self.depth = depth
@@ -628,12 +423,17 @@ def __init__(self,
         self.with_cp = with_cp
         self.norm_eval = norm_eval
         self.zero_init_residual = zero_init_residual
-
         self.block, stage_blocks = self.arch_settings[depth]
         self.stage_blocks = stage_blocks[:num_stages]
         self.expansion = get_expansion(self.block, expansion)
 
-        self._make_stem_layer(in_channels, stem_channels)
+        self.orientation = orientation
+        self.fixparams = fixparams
+        self.gspace = gspaces.Rot2dOnR2(orientation)
+        self.in_type = enn.FieldType(
+            self.gspace, [self.gspace.trivial_repr] * 3)
+
+        self._make_stem_layer(self.gspace, in_channels, stem_channels)
 
         self.res_layers = []
         _in_channels = stem_channels
@@ -653,7 +453,9 @@ def __init__(self,
                 avg_down=self.avg_down,
                 with_cp=with_cp,
                 conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg)
+                norm_cfg=norm_cfg,
+                gspace=self.gspace,
+                fixparams=self.fixparams)
             _in_channels = _out_channels
             _out_channels *= 2
             layer_name = f'layer{i + 1}'
@@ -671,14 +473,23 @@ def make_res_layer(self, **kwargs):
     def norm1(self):
         return getattr(self, self.norm1_name)
 
-    def _make_stem_layer(self, in_channels, stem_channels):
+    def _make_stem_layer(self, gspace, in_channels, stem_channels):
         if not self.deep_stem:
-            self.conv1 = conv7x7(in_channels, stem_channels)
+            in_type = enn.FieldType(
+                gspace, in_channels * [gspace.trivial_repr])
+            out_type = FIELD_TYPE['regular'](gspace, stem_channels)
+            self.conv1 = enn.R2Conv(in_type, out_type, 7,
+                                    stride=2,
+                                    padding=3,
+                                    bias=False,
+                                    sigma=None,
+                                    frequencies_cutoff=lambda r: 3 * r)
             self.norm1_name, norm1 = build_norm_layer(
-                self.norm_cfg, stem_channels, postfix=1)
+                self.norm_cfg, gspace, stem_channels, postfix=1)
             self.add_module(self.norm1_name, norm1)
-            self.relu = ennReLU(stem_channels)
-            self.maxpool = ennMaxPool(stem_channels, kernel_size=3, stride=2, padding=1)
+            self.relu = enn.ReLU(self.conv1.out_type, inplace=True)
+            self.maxpool = enn.PointwiseMaxPool(
+                self.conv1.out_type, kernel_size=3, stride=2, padding=1)
 
     def _freeze_stages(self):
         if self.frozen_stages >= 0:
@@ -716,7 +527,6 @@ def forward(self, x):
             x = res_layer(x)
             if i in self.out_indices:
                 outs.append(x)
-
         if len(outs) == 1:
             return outs[0]
         else:
@@ -730,3 +540,13 @@ def train(self, mode=True):
                 # trick: eval have effect on BatchNorm only
                 if isinstance(m, _BatchNorm):
                     m.eval()
+
+    def export(self):
+        self.eval()
+        submodules = []
+        # convert all the submodules if necessary
+        for name, module in self._modules.items():
+            if hasattr(module, 'export'):
+                module = module.export()
+            submodules.append((name, module))
+        return torch.nn.ModuleDict(OrderedDict(submodules))
\ No newline at end of file
diff --git a/mmdet/models/necks/re_fpn.py b/mmdet/models/necks/re_fpn.py
index 088f8b4..86d1aad 100644
--- a/mmdet/models/necks/re_fpn.py
+++ b/mmdet/models/necks/re_fpn.py
@@ -1,93 +1,18 @@
-import e2cnn.nn as enn
 import math
 import os
+import warnings
+from collections import OrderedDict
+
+import e2cnn.nn as enn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import warnings
 from e2cnn import gspaces
 from mmcv.cnn import constant_init, kaiming_init, xavier_init
 
 from ..registry import NECKS
-
-# Set default Orientation=8, .i.e, the group C8
-# One can change it by passing the env Orientation=xx
-Orientation = 8
-# keep similar computation or similar params
-# One can change it by passing the env fixparams=True
-fixparams = False
-if 'Orientation' in os.environ:
-    Orientation = int(os.environ['Orientation'])
-if 'fixparams' in os.environ:
-    fixparams = True
-
-gspace = gspaces.Rot2dOnR2(N=Orientation)
-
-
-def regular_feature_type(gspace: gspaces.GSpace, planes: int):
-    """ build a regular feature map with the specified number of channels"""
-    assert gspace.fibergroup.order() > 0
-
-    N = gspace.fibergroup.order()
-    if fixparams:
-        planes *= math.sqrt(N)
-    planes = planes / N
-    planes = int(planes)
-    return enn.FieldType(gspace, [gspace.regular_repr] * planes)
-
-
-def trivial_feature_type(gspace: gspaces.GSpace, planes: int):
-    """ build a trivial feature map with the specified number of channels"""
-
-    if fixparams:
-        planes *= math.sqrt(gspace.fibergroup.order())
-
-    planes = int(planes)
-    return enn.FieldType(gspace, [gspace.trivial_repr] * planes)
-
-
-FIELD_TYPE = {
-    "trivial": trivial_feature_type,
-    "regular": regular_feature_type,
-}
-
-
-def convnxn(inplanes, outplanes, kernel_size=3, stride=1, padding=0, groups=1, bias=False, dilation=1):
-    in_type = FIELD_TYPE['regular'](gspace, inplanes)
-    out_type = FIELD_TYPE['regular'](gspace, outplanes)
-    return enn.R2Conv(in_type, out_type, kernel_size,
-                      stride=stride,
-                      padding=padding,
-                      groups=groups,
-                      bias=bias,
-                      dilation=dilation,
-                      sigma=None,
-                      frequencies_cutoff=lambda r: 3 * r, )
-
-
-def ennReLU(inplanes, inplace=True):
-    in_type = FIELD_TYPE['regular'](gspace, inplanes)
-    return enn.ReLU(in_type, inplace=inplace)
-
-
-def ennInterpolate(inplanes, scale_factor, mode='nearest', align_corners=False):
-    in_type = FIELD_TYPE['regular'](gspace, inplanes)
-    return enn.R2Upsampling(in_type, scale_factor, mode=mode, align_corners=align_corners)
-
-
-def ennMaxPool(inplanes, kernel_size, stride=1, padding=0):
-    in_type = FIELD_TYPE['regular'](gspace, inplanes)
-    return enn.PointwiseMaxPool(in_type, kernel_size=kernel_size, stride=stride, padding=padding)
-
-
-def build_conv_layer(cfg, *args, **kwargs):
-    layer = convnxn(*args, **kwargs)
-    return layer
-
-
-def build_norm_layer(cfg, num_features, postfix=''):
-    in_type = FIELD_TYPE['regular'](gspace, num_features)
-    return 'bn' + str(postfix), enn.InnerBatchNorm(in_type)
+from ..utils.enn_layers import (FIELD_TYPE, build_norm_layer, convnxn,
+                                ennInterpolate, ennMaxPool, ennReLU)
 
 
 class ConvModule(enn.EquivariantModule):
@@ -104,12 +29,17 @@ def __init__(self,
                 norm_cfg=None,
                 activation='relu',
                 inplace=True,
-                 order=('conv', 'norm', 'act')):
+                 order=('conv', 'norm', 'act'),
+                 gspace=None,
+                 fixparams=False):
         super(ConvModule, self).__init__()
         assert conv_cfg is None or isinstance(conv_cfg, dict)
         assert norm_cfg is None or isinstance(norm_cfg, dict)
-        self.in_type = enn.FieldType(gspace, [gspace.regular_repr] * in_channels)
-        self.out_type = enn.FieldType(gspace, [gspace.regular_repr] * out_channels)
+        self.gspace = gspace
+        self.in_type = enn.FieldType(
+            gspace, [gspace.regular_repr] * in_channels)
+        self.out_type = enn.FieldType(
+            gspace, [gspace.regular_repr] * out_channels)
         self.conv_cfg = conv_cfg
         self.norm_cfg = norm_cfg
         self.activation = activation
@@ -128,8 +58,8 @@ def __init__(self,
         if self.with_norm and self.with_bias:
             warnings.warn('ConvModule has norm and bias at the same time')
 
         # build convolution layer
-        self.conv = build_conv_layer(
-            conv_cfg,
+        self.conv = convnxn(
+            gspace,
             in_channels,
             out_channels,
             kernel_size,
@@ -158,7 +88,8 @@ def __init__(self,
             norm_channels = in_channels
             if conv_cfg != None and conv_cfg['type'] == 'ORConv':
                 norm_channels = int(norm_channels * 8)
-            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+            self.norm_name, norm = build_norm_layer(
+                norm_cfg, gspace, norm_channels)
             self.add_module(self.norm_name, norm)
 
         # build activation layer
@@ -168,7 +99,8 @@ def __init__(self,
                 raise ValueError('{} is currently not supported.'.format(
                     self.activation))
             if self.activation == 'relu':
-                self.activate = ennReLU(out_channels, inplace=self.inplace)
+                self.activate = ennReLU(
+                    gspace, out_channels, inplace=self.inplace)
 
         # Use msra init by default
         self.init_weights()
@@ -196,6 +128,15 @@ def forward(self, x, activate=True, norm=True):
     def evaluate_output_shape(self, input_shape):
         return input_shape
 
+    def export(self):
+        self.eval()
+        submodules = []
+        for name, module in self._modules.items():
+            if hasattr(module, 'export'):
+                module = module.export()
+            submodules.append((name, module))
+        return torch.nn.ModuleDict(OrderedDict(submodules))
+
 
 @NECKS.register_module
 class ReFPN(nn.Module):
@@ -212,7 +153,9 @@ def __init__(self,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
-                 activation=None):
+                 activation=None,
+                 orientation=8,
+                 fixparams=False):
         super(ReFPN, self).__init__()
         assert isinstance(in_channels, list)
         self.in_channels = in_channels
@@ -220,6 +163,13 @@ def __init__(self,
         self.num_ins = len(in_channels)
         self.num_outs = num_outs
         self.activation = activation
+
+        self.orientation = orientation
+        self.fixparams = fixparams
+        self.gspace = gspaces.Rot2dOnR2(orientation)
+        self.in_type = enn.FieldType(
+            self.gspace, [self.gspace.trivial_repr] * 3)
+
         self.relu_before_extra_convs = relu_before_extra_convs
         self.no_norm_on_lateral = no_norm_on_lateral
         self.fp16_enabled = False
@@ -236,9 +186,9 @@ def __init__(self,
         self.add_extra_convs = add_extra_convs
         self.extra_convs_on_inputs = extra_convs_on_inputs
 
-        self.lateral_convs = nn.ModuleList()
-        self.up_samples = nn.ModuleList()
-        self.fpn_convs = nn.ModuleList()
+        self.lateral_convs = enn.ModuleList()
+        self.up_samples = enn.ModuleList()
+        self.fpn_convs = enn.ModuleList()
 
         for i in range(self.start_level, self.backbone_end_level):
             l_conv = ConvModule(
@@ -248,8 +198,10 @@ def __init__(self,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                 activation=self.activation,
-                inplace=False)
-            up_sample = ennInterpolate(out_channels, 2)
+                inplace=False,
+                gspace=self.gspace,
+                fixparams=fixparams)
+            up_sample = ennInterpolate(self.gspace, out_channels, 2)
             fpn_conv = ConvModule(
                 out_channels,
                 out_channels,
@@ -258,7 +210,9 @@ def __init__(self,
                 conv_cfg=conv_cfg,
                 norm_cfg=norm_cfg,
                 activation=self.activation,
-                inplace=False)
+                inplace=False,
+                gspace=self.gspace,
+                fixparams=fixparams)
 
             self.lateral_convs.append(l_conv)
             self.up_samples.append(up_sample)
@@ -281,11 +235,13 @@ def __init__(self,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    activation=self.activation,
-                    inplace=False)
+                    inplace=False,
+                    gspace=self.gspace,
+                    fixparams=fixparams)
                 self.fpn_convs.append(extra_fpn_conv)
 
-        self.max_pools = nn.ModuleList()
-        self.relus = nn.ModuleList()
+        self.max_pools = enn.ModuleList()
+        self.relus = enn.ModuleList()
         used_backbone_levels = len(self.lateral_convs)
         if self.num_outs > used_backbone_levels:
@@ -293,11 +249,12 @@ def __init__(self,
             # (e.g., Faster R-CNN, Mask R-CNN)
             if not self.add_extra_convs:
                 for i in range(self.num_outs - used_backbone_levels):
-                    self.max_pools.append(ennMaxPool(out_channels, 1, stride=2))
+                    self.max_pools.append(
+                        ennMaxPool(self.gspace, out_channels, 1, stride=2))
             # add conv layers on top of original feature maps (RetinaNet)
             else:
                 for i in range(used_backbone_levels + 1, self.num_outs):
-                    self.relus.append(ennReLU(out_channels))
+                    self.relus.append(ennReLU(self.gspace, out_channels))
 
     # default init_weights for conv(msra) and norm in ConvModule
     def init_weights(self):
@@ -350,3 +307,12 @@ def forward(self, inputs):
         outs = [out.tensor for out in outs]
 
         return tuple(outs)
+
+    def export(self):
+        self.eval()
+        submodules = []
+        for name, module in self._modules.items():
+            if hasattr(module, 'export'):
+                module = module.export()
+            submodules.append((name, module))
+        return torch.nn.ModuleDict(OrderedDict(submodules))
diff --git a/mmdet/models/utils/enn_layers.py b/mmdet/models/utils/enn_layers.py
new file mode 100644
index 0000000..4085049
--- /dev/null
+++ b/mmdet/models/utils/enn_layers.py
@@ -0,0 +1,94 @@
+import math
+
+import e2cnn.nn as enn
+from e2cnn import gspaces
+
+
+def regular_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = False):
+    """ build a regular feature map with the specified number of channels"""
+    assert gspace.fibergroup.order() > 0
+
+    N = gspace.fibergroup.order()
+
+    if fixparams:
+        planes *= math.sqrt(N)
+
+    planes = planes / N
+    planes = int(planes)
+
+    return enn.FieldType(gspace, [gspace.regular_repr] * planes)
+
+
+def trivial_feature_type(gspace: gspaces.GSpace, planes: int, fixparams: bool = False):
+    """ build a trivial feature map with the specified number of channels"""
+
+    if fixparams:
+        planes *= math.sqrt(gspace.fibergroup.order())
+
+    planes = int(planes)
+    return enn.FieldType(gspace, [gspace.trivial_repr] * planes)
+
+
+FIELD_TYPE = {
+    "trivial": trivial_feature_type,
+    "regular": regular_feature_type,
+}
+
+
+def conv3x3(gspace, inplanes, out_planes, stride=1, padding=1, dilation=1, bias=False, fixparams=False):
+    """3x3 convolution with padding"""
+    in_type = FIELD_TYPE['regular'](gspace, inplanes, fixparams=fixparams)
+    out_type = FIELD_TYPE['regular'](gspace, out_planes, fixparams=fixparams)
+    return enn.R2Conv(in_type, out_type, 3,
+                      stride=stride,
+                      padding=padding,
+                      dilation=dilation,
+                      bias=bias,
+                      sigma=None,
+                      frequencies_cutoff=lambda r: 3 * r)
+
+
+def conv1x1(gspace, inplanes, out_planes, stride=1, padding=0, dilation=1, bias=False, fixparams=False):
+    """1x1 convolution"""
+    in_type = FIELD_TYPE['regular'](gspace, inplanes, fixparams=fixparams)
+    out_type = FIELD_TYPE['regular'](gspace, out_planes, fixparams=fixparams)
+    return enn.R2Conv(in_type, out_type, 1,
+                      stride=stride,
+                      padding=padding,
+                      dilation=dilation,
+                      bias=bias,
+                      sigma=None,
+                      frequencies_cutoff=lambda r: 3 * r)
+
+
+def convnxn(gspace, inplanes, out_planes, kernel_size=3, stride=1, padding=0, groups=1, bias=False, dilation=1, fixparams=False):
+    in_type = FIELD_TYPE['regular'](gspace, inplanes, fixparams=fixparams)
+    out_type = FIELD_TYPE['regular'](gspace, out_planes, fixparams=fixparams)
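+
+
+# Usage sketch (channel count is illustrative): every helper takes the gspace
+# explicitly instead of reading the old module-level Orientation globals.
+#   gspace = gspaces.Rot2dOnR2(8)                  # cyclic group C8
+#   conv = conv3x3(gspace, 64, 64)                 # C8-equivariant 3x3 conv
+#   name, bn = build_norm_layer(None, gspace, 64)  # cfg argument is unused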
+    return enn.R2Conv(in_type, out_type, kernel_size,
+                      stride=stride,
+                      padding=padding,
+                      groups=groups,
+                      bias=bias,
+                      dilation=dilation,
+                      sigma=None,
+                      frequencies_cutoff=lambda r: 3 * r)
+
+
+def build_norm_layer(cfg, gspace, num_features, postfix=''):
+    in_type = FIELD_TYPE['regular'](gspace, num_features)
+    return 'bn' + str(postfix), enn.InnerBatchNorm(in_type)
+
+
+def ennReLU(gspace, inplanes, inplace=True):
+    in_type = FIELD_TYPE['regular'](gspace, inplanes)
+    return enn.ReLU(in_type, inplace=inplace)
+
+
+def ennInterpolate(gspace, inplanes, scale_factor, mode='nearest', align_corners=False):
+    in_type = FIELD_TYPE['regular'](gspace, inplanes)
+    return enn.R2Upsampling(in_type, scale_factor, mode=mode, align_corners=align_corners)
+
+
+def ennMaxPool(gspace, inplanes, kernel_size, stride=1, padding=0):
+    in_type = FIELD_TYPE['regular'](gspace, inplanes)
+    return enn.PointwiseMaxPool(in_type, kernel_size=kernel_size, stride=stride, padding=padding)
diff --git a/tools/convert_ReDet_to_torch.py b/tools/convert_ReDet_to_torch.py
new file mode 100644
index 0000000..e8309fc
--- /dev/null
+++ b/tools/convert_ReDet_to_torch.py
@@ -0,0 +1,48 @@
+import argparse
+from collections import OrderedDict
+
+import torch
+from mmdet.apis import init_detector
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert ReDet to standard pytorch layers')
+    parser.add_argument('config', help="config file path")
+    parser.add_argument('in_weight', help="input weights of ReDet")
+    parser.add_argument(
+        'out_weight', help="output weights of standard pytorch layers")
+    args = parser.parse_args()
+
+    return args
+
+
+def convert_ReDet_to_pytorch(config, in_weight, out_weight):
+
+    ckpt = torch.load(in_weight)
+    old_state_dict = ckpt["state_dict"]
+
+    model = init_detector(config, in_weight, device='cuda:0')
+    # export to pytorch layers
+    backbone_dict = model.backbone.export().state_dict()
+    neck_dict = model.neck.export().state_dict()
+
+    new_state_dict = OrderedDict()
+    print("copy detection head of the original model")
+    for key in old_state_dict.keys():
+        if 'backbone' in key or 'neck' in key:
+            continue
+        new_state_dict[key] = old_state_dict[key]
+    print("copy converted backbone and neck")
+    for key in backbone_dict.keys():
+        new_state_dict["backbone." + key] = backbone_dict[key]
+    for key in neck_dict.keys():
+        new_state_dict["neck." + key] = neck_dict[key]
+
+    ckpt["state_dict"] = new_state_dict
+    print("save converted weights to {}".format(out_weight))
+    torch.save(ckpt, out_weight)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    convert_ReDet_to_pytorch(args.config, args.in_weight, args.out_weight)
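+
+
+# Usage sketch (paths follow GETTING_STARTED.md):
+#   python tools/convert_ReDet_to_torch.py \
+#       configs/ReDet/ReDet_re50_refpn_1x_dota15.py \
+#       work_dirs/ReDet_re50_refpn_1x_dota15/ReDet_re50_refpn_1x_dota15-7f2d6dda.pth \
+#       work_dirs/ReDet_re50_refpn_1x_dota15/ReDet_r50_fpn_1x_dota15.pth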