From 17c0f8e369998f69cbb5bd7e21322dd9570ea473 Mon Sep 17 00:00:00 2001 From: Wang Xinjiang Date: Fri, 24 Jul 2020 21:54:22 +0800 Subject: [PATCH] Add custom-defined hooks in train api (#3395) * Add hooks * change hooks to custom_hooks --- configs/lvis/README.md | 2 +- demo/MMDet_Tutorial.ipynb | 2 +- mmdet/apis/train.py | 13 +++++++++++-- mmdet/models/dense_heads/corner_head.py | 15 +++++++++------ requirements/runtime.txt | 2 +- tests/test_models/test_heads.py | 4 +--- 6 files changed, 24 insertions(+), 14 deletions(-) diff --git a/configs/lvis/README.md b/configs/lvis/README.md index 3539e5f9169..c36e9470a77 100644 --- a/configs/lvis/README.md +++ b/configs/lvis/README.md @@ -16,7 +16,7 @@ ``` pip install "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=lvis" ``` - or + or ``` pip install -r requirements/optional.txt ``` diff --git a/demo/MMDet_Tutorial.ipynb b/demo/MMDet_Tutorial.ipynb index 348b9a655fe..e5fa16217d7 100644 --- a/demo/MMDet_Tutorial.ipynb +++ b/demo/MMDet_Tutorial.ipynb @@ -1696,4 +1696,4 @@ "outputs": [] } ] -} \ No newline at end of file +} diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index 7e76539cf0b..6f00eb86fa4 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -3,8 +3,9 @@ import numpy as np import torch from mmcv.parallel import MMDataParallel, MMDistributedDataParallel -from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook, - build_optimizer) +from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, + OptimizerHook, build_optimizer) +from mmcv.utils import build_from_cfg from mmdet.core import DistEvalHook, EvalHook, Fp16OptimizerHook from mmdet.datasets import build_dataloader, build_dataset @@ -121,6 +122,14 @@ def train_detector(model, eval_hook = DistEvalHook if distributed else EvalHook runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) + # user-defined hooks + if cfg.get('custom_hooks', None): + for hook_cfg in cfg.hooks: + hook_cfg = hook_cfg.copy() + priority = hook_cfg.pop('priority', 'NORMAL') + hook = build_from_cfg(hook_cfg, HOOKS) + runner.register_hook(hook, priority=priority) + if cfg.resume_from: runner.resume(cfg.resume_from) elif cfg.load_from: diff --git a/mmdet/models/dense_heads/corner_head.py b/mmdet/models/dense_heads/corner_head.py index 43973d7142f..7b5a970817a 100644 --- a/mmdet/models/dense_heads/corner_head.py +++ b/mmdet/models/dense_heads/corner_head.py @@ -147,9 +147,10 @@ def _make_layers(self, out_channels, in_channels=256, feat_channels=256): feat_channels, out_channels, 1, norm_cfg=None, act_cfg=None)) def _init_corner_kpt_layers(self): - """Initialize corner keypoint layers. Including corner heatmap branch - and corner offset branch. Each branch has two parts: prefix `tl_` for - top-left and `br_` for bottom-right. + """Initialize corner keypoint layers. + + Including corner heatmap branch and corner offset branch. Each branch + has two parts: prefix `tl_` for top-left and `br_` for bottom-right. """ self.tl_pool, self.br_pool = nn.ModuleList(), nn.ModuleList() self.tl_heat, self.br_heat = nn.ModuleList(), nn.ModuleList() @@ -184,9 +185,10 @@ def _init_corner_kpt_layers(self): in_channels=self.in_channels)) def _init_corner_emb_layers(self): - """Initialize corner embedding layers. Only include corner embedding - branch with two parts: prefix `tl_` for top-left and `br_` for - bottom-right. + """Initialize corner embedding layers. + + Only include corner embedding branch with two parts: prefix `tl_` for + top-left and `br_` for bottom-right. """ self.tl_emb, self.br_emb = nn.ModuleList(), nn.ModuleList() @@ -202,6 +204,7 @@ def _init_corner_emb_layers(self): def _init_layers(self): """Initialize layers for CornerHead. + Including two parts: corner keypoint layers and corner embedding layers """ self._init_corner_kpt_layers() diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 786c9065ca3..6430b156f59 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -2,8 +2,8 @@ matplotlib numpy # need older pillow until torchvision is fixed Pillow<=6.2.2 +pycocotools@git+https://github.com/open-mmlab/cocoapi.git#subdirectory=pycocotools six terminaltables torch>=1.3 torchvision -pycocotools@git+https://github.com/open-mmlab/cocoapi.git#subdirectory=pycocotools diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 898c1db2710..3badd9b1194 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -617,9 +617,7 @@ def _dummy_bbox_sampling(proposal_list, gt_bboxes, gt_labels): def test_corner_head_loss(): - """ - Tests corner head loss when truth is empty and non-empty - """ + """Tests corner head loss when truth is empty and non-empty.""" s = 256 img_metas = [{ 'img_shape': (s, s, 3),