Skip to content

Commit

Permalink
fix dynamic_shape bug of SyncRandomSizeHook (open-mmlab#6144)
Browse files Browse the repository at this point in the history
* fix SyncRandomSizeHook dynamic_shape bug

* fix unittest

* fix unittest

* update comments

* fix scale error

* update docstr
  • Loading branch information
hhaAndroid authored Sep 23, 2021
1 parent c91451f commit 8e89779
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 13 deletions.
1 change: 0 additions & 1 deletion configs/yolox/yolox_s_8x8_300e_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@
type='SyncRandomSizeHook',
ratio_range=(14, 26),
img_scale=img_scale,
interval=interval,
priority=48),
dict(
type='SyncNormHook',
Expand Down
1 change: 0 additions & 1 deletion configs/yolox/yolox_tiny_8x8_300e_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
type='SyncRandomSizeHook',
ratio_range=(10, 20),
img_scale=img_scale,
interval=interval,
priority=48),
dict(
type='SyncNormHook',
Expand Down
27 changes: 21 additions & 6 deletions mmdet/core/hook/sync_random_size_hook.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings

import torch
from mmcv.runner import get_dist_info
Expand All @@ -9,34 +10,48 @@

@HOOKS.register_module()
class SyncRandomSizeHook(Hook):
"""Change and synchronize the random image size across ranks, currently
used in YOLOX.
"""Change and synchronize the random image size across ranks.
SyncRandomSizeHook is deprecated, please use Resize pipeline
to achieve similar functions. Such as `dict(type='Resize', img_scale=[(448,
448), (832, 832)], multiscale_mode='range', keep_ratio=True)`.
Note: Due to the multi-process dataloader, its behavior is different
from YOLOX's official implementation, the official is to change the
size every fixed iteration interval and what we achieved is a fixed
epoch interval.
Args:
ratio_range (tuple[int]): Random ratio range. It will be multiplied
by 32, and then change the dataset output image size.
Default: (14, 26).
img_scale (tuple[int]): Size of input image. Default: (640, 640).
interval (int): The interval of change image size. Default: 10.
interval (int): The epoch interval of change image size. Default: 1.
device (torch.device | str): device for returned tensors.
Default: 'cuda'.
"""

def __init__(self,
ratio_range=(14, 26),
img_scale=(640, 640),
interval=10,
interval=1,
device='cuda'):
warnings.warn('DeprecationWarning: SyncRandomSizeHook is deprecated. '
'Please use Resize pipeline to achieve similar '
'functions. Due to the multi-process dataloader, '
'its behavior is different from YOLOX\'s official '
'implementation, the official is to change the size '
'every fixed iteration interval and what we achieved '
'is a fixed epoch interval.')
self.rank, world_size = get_dist_info()
self.is_distributed = world_size > 1
self.ratio_range = ratio_range
self.img_scale = img_scale
self.interval = interval
self.device = device

def after_train_iter(self, runner):
def after_train_epoch(self, runner):
"""Change the dataset output image size."""
if self.ratio_range is not None and (runner.iter +
if self.ratio_range is not None and (runner.epoch +
1) % self.interval == 0:
# Due to DDP and DP get the device behavior inconsistent,
# so we did not get the device from runner.model.
Expand Down
2 changes: 0 additions & 2 deletions mmdet/datasets/dataset_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,6 @@ def __getitem__(self, idx):

if 'mix_results' in results:
results.pop('mix_results')
if 'img_scale' in results:
results.pop('img_scale')

return results

Expand Down
25 changes: 22 additions & 3 deletions tests/test_utils/test_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mmcv.runner import (CheckpointHook, IterTimerHook, PaviLoggerHook,
build_runner)
from torch.nn.init import constant_
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset

from mmdet.core.hook import ExpMomentumEMAHook, YOLOXLrUpdaterHook
from mmdet.core.hook.sync_norm_hook import SyncNormHook
Expand Down Expand Up @@ -254,12 +254,31 @@ def test_sync_random_size_hook():
# Only used to prevent program errors
SyncRandomSizeHook()

loader = DataLoader(torch.ones((5, 2)))
class DemoDataset(Dataset):

def __getitem__(self, item):
return torch.ones(2)

def __len__(self):
return 5

def update_dynamic_scale(self, dynamic_scale):
pass

loader = DataLoader(DemoDataset())
runner = _build_demo_runner()
runner.register_hook_from_cfg(dict(type='SyncRandomSizeHook'))
runner.register_hook_from_cfg(
dict(type='SyncRandomSizeHook', device='cpu'))
runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir)

if torch.cuda.is_available():
runner = _build_demo_runner()
runner.register_hook_from_cfg(
dict(type='SyncRandomSizeHook', device='cuda'))
runner.run([loader, loader], [('train', 1), ('val', 1)])
shutil.rmtree(runner.work_dir)


@pytest.mark.parametrize('set_loss', [
dict(set_loss_nan=False, set_loss_inf=False),
Expand Down

0 comments on commit 8e89779

Please sign in to comment.