diff --git a/configs/yolox/yolox_s_8x8_300e_coco.py b/configs/yolox/yolox_s_8x8_300e_coco.py index 71374953607..c414051bb73 100644 --- a/configs/yolox/yolox_s_8x8_300e_coco.py +++ b/configs/yolox/yolox_s_8x8_300e_coco.py @@ -129,7 +129,6 @@ type='SyncRandomSizeHook', ratio_range=(14, 26), img_scale=img_scale, - interval=interval, priority=48), dict( type='SyncNormHook', diff --git a/configs/yolox/yolox_tiny_8x8_300e_coco.py b/configs/yolox/yolox_tiny_8x8_300e_coco.py index 4d517cbc6d1..292bae70300 100644 --- a/configs/yolox/yolox_tiny_8x8_300e_coco.py +++ b/configs/yolox/yolox_tiny_8x8_300e_coco.py @@ -66,7 +66,6 @@ type='SyncRandomSizeHook', ratio_range=(10, 20), img_scale=img_scale, - interval=interval, priority=48), dict( type='SyncNormHook', diff --git a/mmdet/core/hook/sync_random_size_hook.py b/mmdet/core/hook/sync_random_size_hook.py index 3e0199a4d5d..c968991f0e4 100644 --- a/mmdet/core/hook/sync_random_size_hook.py +++ b/mmdet/core/hook/sync_random_size_hook.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import random +import warnings import torch from mmcv.runner import get_dist_info @@ -9,15 +10,22 @@ @HOOKS.register_module() class SyncRandomSizeHook(Hook): - """Change and synchronize the random image size across ranks, currently - used in YOLOX. + """Change and synchronize the random image size across ranks. + SyncRandomSizeHook is deprecated, please use Resize pipeline + to achieve similar functions. Such as `dict(type='Resize', img_scale=[(448, + 448), (832, 832)], multiscale_mode='range', keep_ratio=True)`. + + Note: Due to the multi-process dataloader, its behavior is different + from YOLOX's official implementation, the official is to change the + size every fixed iteration interval and what we achieved is a fixed + epoch interval. Args: ratio_range (tuple[int]): Random ratio range. It will be multiplied by 32, and then change the dataset output image size. Default: (14, 26). img_scale (tuple[int]): Size of input image. Default: (640, 640). - interval (int): The interval of change image size. Default: 10. + interval (int): The epoch interval of change image size. Default: 1. device (torch.device | str): device for returned tensors. Default: 'cuda'. """ @@ -25,8 +33,15 @@ class SyncRandomSizeHook(Hook): def __init__(self, ratio_range=(14, 26), img_scale=(640, 640), - interval=10, + interval=1, device='cuda'): + warnings.warn('DeprecationWarning: SyncRandomSizeHook is deprecated. ' + 'Please use Resize pipeline to achieve similar ' + 'functions. Due to the multi-process dataloader, ' + 'its behavior is different from YOLOX\'s official ' + 'implementation, the official is to change the size ' + 'every fixed iteration interval and what we achieved ' + 'is a fixed epoch interval.') self.rank, world_size = get_dist_info() self.is_distributed = world_size > 1 self.ratio_range = ratio_range @@ -34,9 +49,9 @@ def __init__(self, self.interval = interval self.device = device - def after_train_iter(self, runner): + def after_train_epoch(self, runner): """Change the dataset output image size.""" - if self.ratio_range is not None and (runner.iter + + if self.ratio_range is not None and (runner.epoch + 1) % self.interval == 0: # Due to DDP and DP get the device behavior inconsistent, # so we did not get the device from runner.model. diff --git a/mmdet/datasets/dataset_wrappers.py b/mmdet/datasets/dataset_wrappers.py index 5a9bc8d11ac..84cdc0ff980 100644 --- a/mmdet/datasets/dataset_wrappers.py +++ b/mmdet/datasets/dataset_wrappers.py @@ -368,8 +368,6 @@ def __getitem__(self, idx): if 'mix_results' in results: results.pop('mix_results') - if 'img_scale' in results: - results.pop('img_scale') return results diff --git a/tests/test_utils/test_hook.py b/tests/test_utils/test_hook.py index ca7964e1691..afd176788f0 100644 --- a/tests/test_utils/test_hook.py +++ b/tests/test_utils/test_hook.py @@ -12,7 +12,7 @@ from mmcv.runner import (CheckpointHook, IterTimerHook, PaviLoggerHook, build_runner) from torch.nn.init import constant_ -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from mmdet.core.hook import ExpMomentumEMAHook, YOLOXLrUpdaterHook from mmdet.core.hook.sync_norm_hook import SyncNormHook @@ -254,12 +254,31 @@ def test_sync_random_size_hook(): # Only used to prevent program errors SyncRandomSizeHook() - loader = DataLoader(torch.ones((5, 2))) + class DemoDataset(Dataset): + + def __getitem__(self, item): + return torch.ones(2) + + def __len__(self): + return 5 + + def update_dynamic_scale(self, dynamic_scale): + pass + + loader = DataLoader(DemoDataset()) runner = _build_demo_runner() - runner.register_hook_from_cfg(dict(type='SyncRandomSizeHook')) + runner.register_hook_from_cfg( + dict(type='SyncRandomSizeHook', device='cpu')) runner.run([loader, loader], [('train', 1), ('val', 1)]) shutil.rmtree(runner.work_dir) + if torch.cuda.is_available(): + runner = _build_demo_runner() + runner.register_hook_from_cfg( + dict(type='SyncRandomSizeHook', device='cuda')) + runner.run([loader, loader], [('train', 1), ('val', 1)]) + shutil.rmtree(runner.work_dir) + @pytest.mark.parametrize('set_loss', [ dict(set_loss_nan=False, set_loss_inf=False),