Skip to content

Commit

Permalink
Add custom-defined hooks in train api (open-mmlab#3395)
Browse files Browse the repository at this point in the history
* Add hooks

* change hooks to custom_hooks
  • Loading branch information
Johnson-Wang authored Jul 24, 2020
1 parent 352cf7f commit 17c0f8e
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 14 deletions.
2 changes: 1 addition & 1 deletion configs/lvis/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
```
pip install "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=lvis"
```
or
or
```
pip install -r requirements/optional.txt
```
Expand Down
2 changes: 1 addition & 1 deletion demo/MMDet_Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1696,4 +1696,4 @@
"outputs": []
}
]
}
}
13 changes: 11 additions & 2 deletions mmdet/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
build_optimizer)
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
OptimizerHook, build_optimizer)
from mmcv.utils import build_from_cfg

from mmdet.core import DistEvalHook, EvalHook, Fp16OptimizerHook
from mmdet.datasets import build_dataloader, build_dataset
Expand Down Expand Up @@ -121,6 +122,14 @@ def train_detector(model,
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

# user-defined hooks
if cfg.get('custom_hooks', None):
for hook_cfg in cfg.hooks:
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)

if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
Expand Down
15 changes: 9 additions & 6 deletions mmdet/models/dense_heads/corner_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,10 @@ def _make_layers(self, out_channels, in_channels=256, feat_channels=256):
feat_channels, out_channels, 1, norm_cfg=None, act_cfg=None))

def _init_corner_kpt_layers(self):
"""Initialize corner keypoint layers. Including corner heatmap branch
and corner offset branch. Each branch has two parts: prefix `tl_` for
top-left and `br_` for bottom-right.
"""Initialize corner keypoint layers.
Including corner heatmap branch and corner offset branch. Each branch
has two parts: prefix `tl_` for top-left and `br_` for bottom-right.
"""
self.tl_pool, self.br_pool = nn.ModuleList(), nn.ModuleList()
self.tl_heat, self.br_heat = nn.ModuleList(), nn.ModuleList()
Expand Down Expand Up @@ -184,9 +185,10 @@ def _init_corner_kpt_layers(self):
in_channels=self.in_channels))

def _init_corner_emb_layers(self):
"""Initialize corner embedding layers. Only include corner embedding
branch with two parts: prefix `tl_` for top-left and `br_` for
bottom-right.
"""Initialize corner embedding layers.
Only include corner embedding branch with two parts: prefix `tl_` for
top-left and `br_` for bottom-right.
"""
self.tl_emb, self.br_emb = nn.ModuleList(), nn.ModuleList()

Expand All @@ -202,6 +204,7 @@ def _init_corner_emb_layers(self):

def _init_layers(self):
"""Initialize layers for CornerHead.
Including two parts: corner keypoint layers and corner embedding layers
"""
self._init_corner_kpt_layers()
Expand Down
2 changes: 1 addition & 1 deletion requirements/runtime.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ matplotlib
numpy
# need older pillow until torchvision is fixed
Pillow<=6.2.2
pycocotools@git+https://github.com/open-mmlab/cocoapi.git#subdirectory=pycocotools
six
terminaltables
torch>=1.3
torchvision
pycocotools@git+https://github.com/open-mmlab/cocoapi.git#subdirectory=pycocotools
4 changes: 1 addition & 3 deletions tests/test_models/test_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,7 @@ def _dummy_bbox_sampling(proposal_list, gt_bboxes, gt_labels):


def test_corner_head_loss():
"""
Tests corner head loss when truth is empty and non-empty
"""
"""Tests corner head loss when truth is empty and non-empty."""
s = 256
img_metas = [{
'img_shape': (s, s, 3),
Expand Down

0 comments on commit 17c0f8e

Please sign in to comment.