Skip to content

Commit

Permalink
[Enchance] Set a random seed when the user does not set a seed. (open…
Browse files Browse the repository at this point in the history
…-mmlab#6457)

* fix random seed bug

* add comment

* enchance random seed

* rename

Co-authored-by: Haobo Yuan <[email protected]>
  • Loading branch information
hhaAndroid and HarborYuan authored Nov 8, 2021
1 parent 374d4c0 commit 1ab934e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 10 deletions.
5 changes: 3 additions & 2 deletions mmdet/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from .inference import (async_inference_detector, inference_detector,
init_detector, show_result_pyplot)
from .test import multi_gpu_test, single_gpu_test
from .train import get_root_logger, set_random_seed, train_detector
from .train import (get_root_logger, init_random_seed, set_random_seed,
train_detector)

__all__ = [
'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
'async_inference_detector', 'inference_detector', 'show_result_pyplot',
'multi_gpu_test', 'single_gpu_test'
'multi_gpu_test', 'single_gpu_test', 'init_random_seed'
]
36 changes: 35 additions & 1 deletion mmdet/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner)
build_runner, get_dist_info)
from mmcv.utils import build_from_cfg

from mmdet.core import DistEvalHook, EvalHook
Expand All @@ -16,6 +17,39 @@
from mmdet.utils import get_root_logger


def init_random_seed(seed=None, device='cuda'):
"""Initialize random seed.
If the seed is not set, the seed will be automatically randomized,
and then broadcast to all processes to prevent some potential bugs.
Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is not None:
return seed

# Make sure all ranks share the same random seed to prevent
# some potential bugs. Please refer to
# https://github.com/open-mmlab/mmdetection/issues/6339
rank, world_size = get_dist_info()
seed = np.random.randint(2**31)
if world_size == 1:
return seed

if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()


def set_random_seed(seed, deterministic=False):
"""Set random seed.
Expand Down
14 changes: 7 additions & 7 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mmcv.utils import get_git_hash

from mmdet import __version__
from mmdet.apis import set_random_seed, train_detector
from mmdet.apis import init_random_seed, set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger
Expand Down Expand Up @@ -148,12 +148,12 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')

# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
seed = init_random_seed(args.seed)
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)

model = build_detector(
Expand Down

0 comments on commit 1ab934e

Please sign in to comment.