diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c8edd013c6..73dbbb8b71 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,10 +12,14 @@ repos: rev: 5.11.5 hooks: - id: isort - - repo: https://github.com/pre-commit/mirrors-yapf - rev: v0.32.0 + - repo: local hooks: - id: yapf + name: yapf + entry: yapf + language: system + types: [python] + args: ["-i"] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: diff --git a/examples/distributed_training.py b/examples/distributed_training.py index 236bee234c..c9af4929fa 100644 --- a/examples/distributed_training.py +++ b/examples/distributed_training.py @@ -42,11 +42,10 @@ def compute_metrics(self, results): def parse_args(): parser = argparse.ArgumentParser(description='Distributed Training') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') + parser.add_argument('--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() @@ -73,16 +72,14 @@ def main(): transform=transforms.Compose( [transforms.ToTensor(), transforms.Normalize(**norm_cfg)])) - train_dataloader = dict( - batch_size=32, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate')) - val_dataloader = dict( - batch_size=32, - dataset=valid_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=dict(type='default_collate')) + train_dataloader = dict(batch_size=32, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate')) + val_dataloader = dict(batch_size=32, + dataset=valid_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=dict(type='default_collate')) runner = Runner( model=MMResNet50(), work_dir='./work_dirs', diff --git a/examples/distributed_training_with_flexible_runner.py b/examples/distributed_training_with_flexible_runner.py index 99d2cf257d..43772fbf81 100644 --- a/examples/distributed_training_with_flexible_runner.py +++ b/examples/distributed_training_with_flexible_runner.py @@ -70,16 +70,14 @@ def main(): transform=transforms.Compose( [transforms.ToTensor(), transforms.Normalize(**norm_cfg)])) - train_dataloader = dict( - batch_size=128, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate')) - val_dataloader = dict( - batch_size=128, - dataset=valid_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=dict(type='default_collate')) + train_dataloader = dict(batch_size=128, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate')) + val_dataloader = dict(batch_size=128, + dataset=valid_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=dict(type='default_collate')) if args.use_deepspeed: strategy = dict( @@ -97,30 +95,28 @@ def main(): # bf16=dict( # enabled=True, # ), - zero_optimization=dict( - stage=3, - allgather_partitions=True, - reduce_scatter=True, - allgather_bucket_size=50000000, - reduce_bucket_size=50000000, - overlap_comm=True, - contiguous_gradients=True, - cpu_offload=False), + zero_optimization=dict(stage=3, + allgather_partitions=True, + reduce_scatter=True, + allgather_bucket_size=50000000, + reduce_bucket_size=50000000, + overlap_comm=True, + contiguous_gradients=True, + cpu_offload=False), ) - optim_wrapper = dict( - type='DeepSpeedOptimWrapper', - optimizer=dict(type='AdamW', lr=1e-3)) + optim_wrapper = dict(type='DeepSpeedOptimWrapper', + optimizer=dict(type='AdamW', lr=1e-3)) elif args.use_fsdp: from functools import partial from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy - size_based_auto_wrap_policy = partial( - size_based_auto_wrap_policy, min_num_params=1e7) + size_based_auto_wrap_policy = partial(size_based_auto_wrap_policy, + min_num_params=1e7) strategy = dict( type='FSDPStrategy', model_wrapper=dict(auto_wrap_policy=size_based_auto_wrap_policy)) - optim_wrapper = dict( - type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3)) + optim_wrapper = dict(type='AmpOptimWrapper', + optimizer=dict(type='AdamW', lr=1e-3)) elif args.use_colossalai: from colossalai.tensor.op_wrapper import colo_op_impl @@ -142,20 +138,21 @@ def main(): optim_wrapper = dict(optimizer=dict(type='HybridAdam', lr=1e-3)) else: strategy = None - optim_wrapper = dict( - type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3)) - - runner = FlexibleRunner( - model=MMResNet50(), - work_dir='./work_dirs', - strategy=strategy, - train_dataloader=train_dataloader, - optim_wrapper=optim_wrapper, - param_scheduler=dict(type='LinearLR'), - train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1), - val_dataloader=val_dataloader, - val_cfg=dict(), - val_evaluator=dict(type=Accuracy)) + optim_wrapper = dict(type='AmpOptimWrapper', + optimizer=dict(type='AdamW', lr=1e-3)) + + runner = FlexibleRunner(model=MMResNet50(), + work_dir='./work_dirs', + strategy=strategy, + train_dataloader=train_dataloader, + optim_wrapper=optim_wrapper, + param_scheduler=dict(type='LinearLR'), + train_cfg=dict(by_epoch=True, + max_epochs=10, + val_interval=1), + val_dataloader=val_dataloader, + val_cfg=dict(), + val_evaluator=dict(type=Accuracy)) runner.train() diff --git a/examples/llama2/fsdp_finetune.py b/examples/llama2/fsdp_finetune.py index 0d7e2751b7..d1879c9e1c 100644 --- a/examples/llama2/fsdp_finetune.py +++ b/examples/llama2/fsdp_finetune.py @@ -92,17 +92,14 @@ def parse_args(): def train(): args = parse_args() # Setup distributed related component in Strategy. - strategy = FSDPStrategy( - model_wrapper=dict( - auto_wrap_policy=partial( - transformer_auto_wrap_policy, - transformer_layer_cls={LlamaDecoderLayer})), - state_dict_cfg='full', - env_kwargs=dict(randomness=dict(seed=42))) - visualizer = Visualizer( - name='mmengine', - save_dir=args.output_dir, - vis_backends=[dict(type=WandbVisBackend)]) + strategy = FSDPStrategy(model_wrapper=dict( + auto_wrap_policy=partial(transformer_auto_wrap_policy, + transformer_layer_cls={LlamaDecoderLayer})), + state_dict_cfg='full', + env_kwargs=dict(randomness=dict(seed=42))) + visualizer = Visualizer(name='mmengine', + save_dir=args.output_dir, + vis_backends=[dict(type=WandbVisBackend)]) # Prepare model tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint) @@ -112,21 +109,20 @@ def train(): model.train() # Prepare dataset - train_dataset = AlpacaDataset( - tokenizer=tokenizer, data_path=args.data_root) - train_dataloader = DataLoader( - train_dataset, - batch_size=args.batch_size, - sampler=DefaultSampler(train_dataset, seed=0), - collate_fn=default_data_collator, - drop_last=True) + train_dataset = AlpacaDataset(tokenizer=tokenizer, + data_path=args.data_root) + train_dataloader = DataLoader(train_dataset, + batch_size=args.batch_size, + sampler=DefaultSampler(train_dataset, + seed=0), + collate_fn=default_data_collator, + drop_last=True) # Get the prepared model, scheduler and optimizer from strategy epoch_length = len(train_dataloader) max_iters = epoch_length * args.max_epoch - optim_cfg = dict( - optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0), - accumulative_counts=ORI_BATCH_SIZE / args.batch_size) + optim_cfg = dict(optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0), + accumulative_counts=ORI_BATCH_SIZE / args.batch_size) scheduler_cfgs = [dict(type=StepLR, step_size=1, gamma=0.85)] model, optimizer, schedulers = strategy.prepare( model, diff --git a/examples/llama2/generate.py b/examples/llama2/generate.py index 85635c37ae..83f80ccaa5 100644 --- a/examples/llama2/generate.py +++ b/examples/llama2/generate.py @@ -30,7 +30,6 @@ def parse_args(): with torch.no_grad(): generate_ids = model.generate(inputs.input_ids.cuda(), max_length=300) print( - tokenizer.batch_decode( - generate_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=False)[0]) + tokenizer.batch_decode(generate_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False)[0]) diff --git a/examples/segmentation/train.py b/examples/segmentation/train.py index dc045a18b9..a26654952f 100644 --- a/examples/segmentation/train.py +++ b/examples/segmentation/train.py @@ -40,8 +40,9 @@ def __init__(self, mask_folder, transform=None, target_transform=None): - super().__init__( - root, transform=transform, target_transform=target_transform) + super().__init__(root, + transform=transform, + target_transform=target_transform) self.img_folder = img_folder self.mask_folder = mask_folder self.images = list( @@ -72,8 +73,9 @@ def __getitem__(self, index): if self.target_transform is not None: labels = self.target_transform(labels) - data_samples = dict( - labels=labels, img_path=img_path, mask_path=mask_path) + data_samples = dict(labels=labels, + img_path=img_path, + mask_path=mask_path) return img, data_samples def __len__(self): @@ -102,8 +104,8 @@ def process(self, data_batch, data_samples): intersect = (labels == preds).sum() union = (torch.logical_or(preds, labels)).sum() iou = (intersect / union).cpu() - self.results.append( - dict(batch_size=len(labels), iou=iou * len(labels))) + self.results.append(dict(batch_size=len(labels), + iou=iou * len(labels))) def compute_metrics(self, results): total_iou = sum(result['iou'] for result in self.results) @@ -151,18 +153,16 @@ def after_val_iter(self, osp.join(saved_dir, osp.basename(img_path))) shutil.copyfile(mask_path, osp.join(saved_dir, osp.basename(mask_path))) - cv2.imwrite( - osp.join(saved_dir, f'pred_{osp.basename(img_path)}'), - pred_mask) + cv2.imwrite(osp.join(saved_dir, f'pred_{osp.basename(img_path)}'), + pred_mask) def parse_args(): parser = argparse.ArgumentParser(description='Distributed Training') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') + parser.add_argument('--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() @@ -181,37 +181,33 @@ def main(): target_transform = transforms.Lambda( lambda x: torch.tensor(np.array(x), dtype=torch.long)) - train_set = CamVid( - 'data/CamVid', - img_folder='train', - mask_folder='train_labels', - transform=transform, - target_transform=target_transform) - - valid_set = CamVid( - 'data/CamVid', - img_folder='val', - mask_folder='val_labels', - transform=transform, - target_transform=target_transform) - - train_dataloader = dict( - batch_size=3, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate')) - val_dataloader = dict( - batch_size=3, - dataset=valid_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=dict(type='default_collate')) + train_set = CamVid('data/CamVid', + img_folder='train', + mask_folder='train_labels', + transform=transform, + target_transform=target_transform) + + valid_set = CamVid('data/CamVid', + img_folder='val', + mask_folder='val_labels', + transform=transform, + target_transform=target_transform) + + train_dataloader = dict(batch_size=3, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate')) + val_dataloader = dict(batch_size=3, + dataset=valid_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=dict(type='default_collate')) runner = Runner( model=MMDeeplabV3(num_classes), work_dir='./work_dir', train_dataloader=train_dataloader, - optim_wrapper=dict( - type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)), + optim_wrapper=dict(type=AmpOptimWrapper, + optimizer=dict(type=AdamW, lr=2e-4)), train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10), val_dataloader=val_dataloader, val_cfg=dict(), diff --git a/examples/test_time_augmentation.py b/examples/test_time_augmentation.py index 0a896a05a2..f2ed739c22 100644 --- a/examples/test_time_augmentation.py +++ b/examples/test_time_augmentation.py @@ -28,15 +28,14 @@ def _merge_single_sample(self, data_samples): cfg.work_dir = 'work_dirs/resnet50_8xb16_cifar10' cfg.model = dict(type='ClsTTAModel', module=cfg.model) test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline) - flip_tta = dict( - type='TestTimeAug', - transforms=[ - [ - dict(type='RandomFlip', prob=1.), - dict(type='RandomFlip', prob=0.) - ], - [test_pipeline[-1]], - ]) + flip_tta = dict(type='TestTimeAug', + transforms=[ + [ + dict(type='RandomFlip', prob=1.), + dict(type='RandomFlip', prob=0.) + ], + [test_pipeline[-1]], + ]) # Replace the last transform with `TestTimeAug` cfg.test_dataloader.dataset.pipeline[-1] = flip_tta cfg.load_from = 'https://download.openmmlab.com/mmclassification/v0' \ diff --git a/examples/text_classification/train.py b/examples/text_classification/train.py index 84a2841729..81e1d17ba3 100644 --- a/examples/text_classification/train.py +++ b/examples/text_classification/train.py @@ -17,11 +17,10 @@ def __init__(self, model): self.model = model def forward(self, label, input_ids, token_type_ids, attention_mask, mode): - output = self.model( - input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - labels=label) + output = self.model(input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + labels=label) if mode == 'loss': return {'loss': output.loss} elif mode == 'predict': @@ -45,11 +44,10 @@ def compute_metrics(self, results): def parse_args(): parser = argparse.ArgumentParser(description='Distributed Training') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') + parser.add_argument('--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() @@ -71,41 +69,36 @@ def collate_fn(data): token_type_ids = torch.stack(token_type_ids) attention_mask = torch.stack(attention_mask) label = torch.tensor(labels) - return dict( - label=label, - input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask) + return dict(label=label, + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask) def main(): args = parse_args() - model = BertForSequenceClassification.from_pretrained( - 'bert-base-uncased', num_labels=2) + model = BertForSequenceClassification.from_pretrained('bert-base-uncased', + num_labels=2) tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') train_set = load_dataset('imdb', split='train') test_set = load_dataset('imdb', split='test') - train_set = train_set.map( - lambda x: tokenizer( - x['text'], truncation=True, padding=True, max_length=128), - batched=True) - test_set = test_set.map( - lambda x: tokenizer( - x['text'], truncation=True, padding=True, max_length=128), - batched=True) - - train_loader = dict( - batch_size=32, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=collate_fn) - test_loader = dict( - batch_size=32, - dataset=test_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=collate_fn) + train_set = train_set.map(lambda x: tokenizer( + x['text'], truncation=True, padding=True, max_length=128), + batched=True) + test_set = test_set.map(lambda x: tokenizer( + x['text'], truncation=True, padding=True, max_length=128), + batched=True) + + train_loader = dict(batch_size=32, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=collate_fn) + test_loader = dict(batch_size=32, + dataset=test_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=collate_fn) runner = Runner( model=MMBertForClassify(model), train_dataloader=train_loader, diff --git a/examples/text_translation/train.py b/examples/text_translation/train.py index 61f43bafef..12cc11455a 100644 --- a/examples/text_translation/train.py +++ b/examples/text_translation/train.py @@ -19,10 +19,9 @@ def __init__(self, model): def forward(self, label, input_ids, attention_mask, mode): if mode == 'loss': - output = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - labels=label) + output = self.model(input_ids=input_ids, + attention_mask=attention_mask, + labels=label) return {'loss': output.loss} elif mode == 'predict': output = self.model.generate(input_ids) @@ -80,10 +79,9 @@ def collate_fn(data): ).input_ids label[label == tokenizer.pad_token_id] = -100 # ignore contribution to loss - return dict( - label=label, - input_ids=input_dict.input_ids, - attention_mask=input_dict.attention_mask) + return dict(label=label, + input_ids=input_dict.input_ids, + attention_mask=input_dict.attention_mask) def main(): @@ -93,16 +91,14 @@ def main(): books = books['train'].train_test_split(test_size=0.2) train_set, test_set = books['train'], books['test'] - train_loader = dict( - batch_size=16, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=collate_fn) - test_loader = dict( - batch_size=32, - dataset=test_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=collate_fn) + train_loader = dict(batch_size=16, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=collate_fn) + test_loader = dict(batch_size=32, + dataset=test_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=collate_fn) runner = Runner( model=MMT5ForTranslation(model), train_dataloader=train_loader, diff --git a/mmengine/_strategy/__init__.py b/mmengine/_strategy/__init__.py index 764abcf868..2e1a3b2c19 100644 --- a/mmengine/_strategy/__init__.py +++ b/mmengine/_strategy/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION + from .base import BaseStrategy from .colossalai import ColossalAIStrategy from .deepspeed import DeepSpeedStrategy diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py index a713da9a70..af444a0d99 100644 --- a/mmengine/_strategy/base.py +++ b/mmengine/_strategy/base.py @@ -270,10 +270,9 @@ def _set_randomness( more details. """ from mmengine.runner import set_random_seed - self._seed = set_random_seed( - seed=seed, - deterministic=deterministic, - diff_rank_seed=diff_rank_seed) + self._seed = set_random_seed(seed=seed, + deterministic=deterministic, + diff_rank_seed=diff_rank_seed) def build_model(self, model: Union[nn.Module, dict]) -> nn.Module: """Build model. @@ -322,7 +321,8 @@ def compile_model( Returns: nn.Module: Compiled model. """ - if isinstance(compile, bool) and not compile: + if isinstance(compile, bool) and not compile or \ + isinstance(compile, dict) and not compile.get('disable', False): return model assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), ( @@ -560,10 +560,10 @@ def _build_param_scheduler( 'Use the max epochs/iters of train loop as default.') param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict( - optimizer=optim_wrapper, **default_args))) + PARAM_SCHEDULERS.build(_scheduler, + default_args=dict( + optimizer=optim_wrapper, + **default_args))) else: raise TypeError( 'scheduler should be a _ParamScheduler object or dict, ' @@ -799,8 +799,10 @@ def load_model_state_dict( else: model = self.model - _load_checkpoint_to_model( - model, state_dict, strict=strict, revise_keys=revise_keys) + _load_checkpoint_to_model(model, + state_dict, + strict=strict, + revise_keys=revise_keys) def load_optim_state_dict(self, state_dict: dict) -> None: """Load optimizer state from dict.""" diff --git a/mmengine/_strategy/colossalai.py b/mmengine/_strategy/colossalai.py index 13d9f38fc3..1a2eac6143 100644 --- a/mmengine/_strategy/colossalai.py +++ b/mmengine/_strategy/colossalai.py @@ -365,8 +365,9 @@ def resume( directly.""" self.logger.info(f'Resume checkpoint from {filename}') - extra_ckpt = self.load_checkpoint( - filename, map_location=map_location, callback=callback) + extra_ckpt = self.load_checkpoint(filename, + map_location=map_location, + callback=callback) if resume_optimizer: self.booster.load_optimizer( @@ -438,10 +439,11 @@ def save_checkpoint( extra_ckpt = dict() if 'meta' not in extra_ckpt: extra_ckpt['meta'] = dict() - extra_ckpt['meta'].update( - seed=self.seed, - time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), - mmengine=mmengine.__version__ + get_git_hash()) + extra_ckpt['meta'].update(seed=self.seed, + time=time.strftime('%Y%m%d_%H%M%S', + time.localtime()), + mmengine=mmengine.__version__ + + get_git_hash()) model_dir = join_path(filename, self.MODEL_DIR) optimizer_dir = join_path(filename, self.OPTIMIZER_DIR) @@ -450,14 +452,14 @@ def save_checkpoint( mkdir_or_exist(optimizer_dir) mkdir_or_exist(schedulers_dir) - self.booster.save_model( - self.model.model_wrapper, checkpoint=model_dir, shard=True) + self.booster.save_model(self.model.model_wrapper, + checkpoint=model_dir, + shard=True) if save_optimizer: - self.booster.save_optimizer( - self.optim_wrapper.optimizer, - checkpoint=optimizer_dir, - shard=True) + self.booster.save_optimizer(self.optim_wrapper.optimizer, + checkpoint=optimizer_dir, + shard=True) if is_main_process() and save_param_scheduler: for i, scheduler in enumerate(self.param_schedulers): @@ -470,8 +472,8 @@ def _build_plugin(self, plugin: Union[str, dict]): if isinstance(plugin, str): if plugin == 'gemini': try: - plugin = colo_plugin.GeminiPlugin( - precision='bf16', placement_policy='auto') + plugin = colo_plugin.GeminiPlugin(precision='bf16', + placement_policy='auto') except AssertionError: from colossalai.zero.gemini.placement_policy import \ PlacementPolicyFactory as colo_placement @@ -545,14 +547,14 @@ def _wrap( model_wrapper, optimizer, *_ = self.booster.boost(model, optimizer) optim_wrapper.optimizer = optimizer default_args = {'model_wrapper': model_wrapper, 'model': model} - model_wrapper = MODEL_WRAPPERS.build( - self.model_wrapper, default_args=default_args) + model_wrapper = MODEL_WRAPPERS.build(self.model_wrapper, + default_args=default_args) return model_wrapper, optim_wrapper # type: ignore else: model_wrapper, *_ = self.booster.boost(model) default_args = {'model_wrapper': model_wrapper, 'model': model} - model_wrapper = MODEL_WRAPPERS.build( - self.model_wrapper, default_args=default_args) + model_wrapper = MODEL_WRAPPERS.build(self.model_wrapper, + default_args=default_args) return model_wrapper def _setup_distributed( # type: ignore @@ -561,5 +563,7 @@ def _setup_distributed( # type: ignore backend: str = 'nccl', **kwargs, ): - init_dist( - launcher, backend, init_backend='colossalai', config=self.config) + init_dist(launcher, + backend, + init_backend='colossalai', + config=self.config) diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py index 3f89ff760d..7ff827e9d8 100644 --- a/mmengine/_strategy/deepspeed.py +++ b/mmengine/_strategy/deepspeed.py @@ -24,6 +24,7 @@ STRATEGIES) from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu from mmengine.utils import apply_to, digit_version, get_git_hash + from .base import BaseStrategy @@ -310,10 +311,10 @@ def __init__( self.config.setdefault('gradient_accumulation_steps', 1) self.config['steps_per_print'] = steps_per_print self._inputs_to_half = inputs_to_half - assert (exclude_frozen_parameters is None or - digit_version(deepspeed.__version__) >= digit_version('0.13.2') - ), ('DeepSpeed >= 0.13.2 is required to enable ' - 'exclude_frozen_parameters') + assert (exclude_frozen_parameters is None or digit_version( + deepspeed.__version__) >= digit_version('0.13.2')), ( + 'DeepSpeed >= 0.13.2 is required to enable ' + 'exclude_frozen_parameters') self.exclude_frozen_parameters = exclude_frozen_parameters register_deepspeed_optimizers() @@ -405,8 +406,8 @@ def _wrap_model(self, model: nn.Module) -> nn.Module: else: engine, *_ = deepspeed.initialize(model=model, config=self.config) - wrapper = MMDeepSpeedEngineWrapper( - model=engine, inputs_to_half=self._inputs_to_half) + wrapper = MMDeepSpeedEngineWrapper(model=engine, + inputs_to_half=self._inputs_to_half) return wrapper def load_checkpoint( @@ -563,12 +564,11 @@ def save_checkpoint( extra_ckpt['optim_wrapper'] = self.optim_state_dict() dirname, basename = osp.split(filename) - self.model.save_checkpoint( - dirname, - tag=basename, - client_state=extra_ckpt, - save_latest=False, - **state_dict_kwargs) + self.model.save_checkpoint(dirname, + tag=basename, + client_state=extra_ckpt, + save_latest=False, + **state_dict_kwargs) else: if self.model.zero_optimization_partition_weights(): state_dict = self.model._zero3_consolidated_16bit_state_dict( diff --git a/mmengine/_strategy/distributed.py b/mmengine/_strategy/distributed.py index dbe17d5aeb..057f8de38d 100644 --- a/mmengine/_strategy/distributed.py +++ b/mmengine/_strategy/distributed.py @@ -9,6 +9,7 @@ from mmengine.dist import init_dist, is_distributed, master_only from mmengine.model import convert_sync_batchnorm, is_model_wrapper from mmengine.registry import MODEL_WRAPPERS, STRATEGIES + from .single_device import SingleDeviceStrategy @@ -93,15 +94,14 @@ def _wrap_model(self, model: nn.Module) -> DistributedDataParallel: if self.model_wrapper is None: # set broadcast_buffers as False to keep compatibility with # OpenMMLab repos - self.model_wrapper = dict( - type='MMDistributedDataParallel', broadcast_buffers=False) - - default_args = dict( - type='MMDistributedDataParallel', - module=model, - device_ids=[int(os.environ['LOCAL_RANK'])]) - model = MODEL_WRAPPERS.build( - self.model_wrapper, default_args=default_args) + self.model_wrapper = dict(type='MMDistributedDataParallel', + broadcast_buffers=False) + + default_args = dict(type='MMDistributedDataParallel', + module=model, + device_ids=[int(os.environ['LOCAL_RANK'])]) + model = MODEL_WRAPPERS.build(self.model_wrapper, + default_args=default_args) return model @master_only @@ -114,9 +114,8 @@ def save_checkpoint( extra_ckpt: Optional[dict] = None, callback: Optional[Callable] = None, ) -> None: - super().save_checkpoint( - filename=filename, - save_optimizer=save_optimizer, - save_param_scheduler=save_param_scheduler, - extra_ckpt=extra_ckpt, - callback=callback) + super().save_checkpoint(filename=filename, + save_optimizer=save_optimizer, + save_param_scheduler=save_param_scheduler, + extra_ckpt=extra_ckpt, + callback=callback) diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py index 0788fafdab..1a1cec07c4 100644 --- a/mmengine/_strategy/fsdp.py +++ b/mmengine/_strategy/fsdp.py @@ -29,6 +29,7 @@ from mmengine.registry import (FUNCTIONS, MODEL_WRAPPERS, OPTIM_WRAPPERS, PARAM_SCHEDULERS, STRATEGIES, Registry) from mmengine.utils import get_git_hash, mkdir_or_exist + from .distributed import DDPStrategy from .utils import MetaTensorContext @@ -151,12 +152,11 @@ def _wrap_model(self, model: nn.Module) -> None: if self.model_wrapper is None: self.model_wrapper = dict(type='MMFullyShardedDataParallel') - default_args = dict( - module=model, - device_id=int(os.environ['LOCAL_RANK']), - type='MMFullyShardedDataParallel') - model = MODEL_WRAPPERS.build( - self.model_wrapper, default_args=default_args) + default_args = dict(module=model, + device_id=int(os.environ['LOCAL_RANK']), + type='MMFullyShardedDataParallel') + model = MODEL_WRAPPERS.build(self.model_wrapper, + default_args=default_args) model.set_state_dict_type(model, self.state_dict_type, self.state_dict_config, self.optim_state_dict_config) @@ -408,7 +408,9 @@ def load_optim_state_dict(self, state_dict: dict) -> None: ``optimizer.state_dict()`` """ optim_state_dict = FSDP.optim_state_dict_to_load( - state_dict, self.model, self.optim_wrapper.optimizer) + optim_state_dict=state_dict, + model=self.model, + optim=self.optim_wrapper.optimizer) self.optim_wrapper.load_state_dict(optim_state_dict) def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None: @@ -539,7 +541,9 @@ def build_optim_wrapper( # Force to load the converted optim_state_dict in full mode. with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT): optim_state_dict = FSDP.optim_state_dict_to_load( - optim_state_dict, model, new_optimizer) + optim_state_dict=optim_state_dict, + model=model, + optim=new_optimizer) new_optimizer.load_state_dict(optim_state_dict) optim_wrapper.optimizer = new_optimizer return optim_wrapper @@ -632,10 +636,10 @@ def _build_param_scheduler( 'Use the max epochs/iters of train loop as default.') param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict( - optimizer=optim_wrapper, **default_args))) + PARAM_SCHEDULERS.build(_scheduler, + default_args=dict( + optimizer=optim_wrapper, + **default_args))) else: raise TypeError( 'scheduler should be a _ParamScheduler object or dict, ' diff --git a/mmengine/_strategy/single_device.py b/mmengine/_strategy/single_device.py index c7d8accd5a..ddcdce8966 100644 --- a/mmengine/_strategy/single_device.py +++ b/mmengine/_strategy/single_device.py @@ -10,6 +10,7 @@ from mmengine.optim import BaseOptimWrapper, _ParamScheduler from mmengine.registry import STRATEGIES from mmengine.utils import get_git_hash + from .base import BaseStrategy @@ -150,8 +151,9 @@ def load_checkpoint( callback(checkpoint) state_dict = checkpoint.pop('state_dict') - self.load_model_state_dict( - state_dict, strict=strict, revise_keys=revise_keys) + self.load_model_state_dict(state_dict, + strict=strict, + revise_keys=revise_keys) return checkpoint @@ -191,8 +193,9 @@ def resume( """ self.logger.info(f'Resume checkpoint from {filename}') - checkpoint = self.load_checkpoint( - filename, map_location=map_location, callback=callback) + checkpoint = self.load_checkpoint(filename, + map_location=map_location, + callback=callback) if resume_optimizer: self.load_optim_state_dict(checkpoint.pop('optimizer')) diff --git a/mmengine/analysis/complexity_analysis.py b/mmengine/analysis/complexity_analysis.py index 435e5fe5d3..6daeb925b6 100644 --- a/mmengine/analysis/complexity_analysis.py +++ b/mmengine/analysis/complexity_analysis.py @@ -342,8 +342,8 @@ def fill(lvl: int, prefix: str) -> None: rows.append(('model', format_size(count.pop('')))) fill(0, '') - table = Table( - title=f'parameter count of {model.__class__.__name__}', box=box.ASCII2) + table = Table(title=f'parameter count of {model.__class__.__name__}', + box=box.ASCII2) table.add_column('name') table.add_column('#elements or shape') diff --git a/mmengine/analysis/jit_analysis.py b/mmengine/analysis/jit_analysis.py index 17b294863a..4c4a628291 100644 --- a/mmengine/analysis/jit_analysis.py +++ b/mmengine/analysis/jit_analysis.py @@ -20,6 +20,7 @@ from torch.jit import TracerWarning, _get_trace_graph from mmengine.logging import print_log + from .jit_handles import Handle T = TypeVar('T', bound='JitModelAnalysis') @@ -628,10 +629,9 @@ def _analyze(self) -> 'Statistics': counts[name] += op_counts uncalled_mods = set(self._aliases.values()) - all_seen - stats = Statistics( - counts=counts, - unsupported_ops=unsupported_ops, - uncalled_mods=uncalled_mods) + stats = Statistics(counts=counts, + unsupported_ops=unsupported_ops, + uncalled_mods=uncalled_mods) self._stats = stats self._warn_unsupported_ops(unsupported_ops['']) self._warn_uncalled_mods(uncalled_mods) diff --git a/mmengine/analysis/print_helper.py b/mmengine/analysis/print_helper.py index 3b87d42373..9c4fadc2e5 100644 --- a/mmengine/analysis/print_helper.py +++ b/mmengine/analysis/print_helper.py @@ -13,6 +13,7 @@ from torch import nn from mmengine.utils import is_tuple_of + from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer, parameter_count) diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 801243c82d..3ca4a13066 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -26,6 +26,7 @@ from mmengine.utils import (check_file_exist, digit_version, get_installed_path, import_modules_from_strings, is_installed) + from .lazy import LazyAttr, LazyObject from .utils import (ConfigParsingError, ImportTransformer, RemoveAssignFromAST, _gather_abs_import_lazyobj, _get_external_cfg_base_path, @@ -46,9 +47,10 @@ def _lazy2string(cfg_dict, dict_type=None): if isinstance(cfg_dict, dict): dict_type = dict_type or type(cfg_dict) - return dict_type( - {k: _lazy2string(v, dict_type) - for k, v in dict.items(cfg_dict)}) + return dict_type({ + k: _lazy2string(v, dict_type) + for k, v in dict.items(cfg_dict) + }) elif isinstance(cfg_dict, (tuple, list)): return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict) elif isinstance(cfg_dict, (LazyAttr, LazyObject)): @@ -254,8 +256,8 @@ def _merge_a_into_b(a, b): b.clear() all_keys = list(b.keys()) + list(a.keys()) return { - key: - _merge_a_into_b(a.get(key, default), b.get(key, default)) + key: _merge_a_into_b(a.get(key, default), + b.get(key, default)) for key in all_keys if key != DELETE_KEY } else: @@ -271,13 +273,15 @@ def __reduce_ex__(self, proto): # called by CPython interpreter during pickling. See more details in # https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 # noqa: E501 if digit_version(platform.python_version()) < digit_version('3.8'): - return (self.__class__, ({k: v - for k, v in super().items()}, ), None, - None, None) + return (self.__class__, ({ + k: v + for k, v in super().items() + }, ), None, None, None) else: - return (self.__class__, ({k: v - for k, v in super().items()}, ), None, - None, None, None) + return (self.__class__, ({ + k: v + for k, v in super().items() + }, ), None, None, None, None) def __eq__(self, other): if isinstance(other, ConfigDict): @@ -338,12 +342,12 @@ def add_args(parser: ArgumentParser, elif isinstance(v, dict): add_args(parser, v, prefix + k + '.') elif isinstance(v, abc.Iterable): - parser.add_argument( - '--' + prefix + k, type=type(next(iter(v))), nargs='+') + parser.add_argument('--' + prefix + k, + type=type(next(iter(v))), + nargs='+') else: - print_log( - f'cannot parse key {prefix + k} of type {type(v)}', - logger='current') + print_log(f'cannot parse key {prefix + k} of type {type(v)}', + logger='current') return parser @@ -495,10 +499,9 @@ def fromfile(filename: Union[str, Path], # about lazy in the docstring of ConfigDict ConfigDict.lazy = False - cfg = Config( - cfg_dict, - filename=filename, - format_python_code=format_python_code) + cfg = Config(cfg_dict, + filename=filename, + format_python_code=format_python_code) object.__setattr__(cfg, '_imported_names', imported_names) return cfg @@ -529,9 +532,10 @@ def fromstring(cfg_str: str, file_format: str) -> 'Config': # As a workaround we set `delete=False` and close the temporary file # before opening again. - with tempfile.NamedTemporaryFile( - 'w', encoding='utf-8', suffix=file_format, - delete=False) as temp_file: + with tempfile.NamedTemporaryFile('w', + encoding='utf-8', + suffix=file_format, + delete=False) as temp_file: temp_file.write(cfg_str) cfg = Config.fromfile(temp_file.name) @@ -1094,19 +1098,17 @@ def _parse_lazy_import(filename: str) -> Tuple[ConfigDict, set]: # the global dict. After the ast transformation, most of import # syntax will be removed (except for the builtin import) and # replaced with the `LazyObject` - transform = ImportTransformer( - global_dict=global_dict, - base_dict=base_dict, - filename=filename) + transform = ImportTransformer(global_dict=global_dict, + base_dict=base_dict, + filename=filename) modified_code = transform.visit(parsed_codes) modified_code, abs_imported = _gather_abs_import_lazyobj( modified_code, filename=filename) imported_names = transform.imported_obj | abs_imported imported_names |= base_imported_names modified_code = ast.fix_missing_locations(modified_code) - exec( - compile(modified_code, filename, mode='exec'), global_dict, - global_dict) + exec(compile(modified_code, filename, mode='exec'), global_dict, + global_dict) ret: dict = {} for key, value in global_dict.items(): @@ -1138,8 +1140,8 @@ def _dict_to_config_dict_lazy(cfg: dict): cfg_dict[key] = Config._dict_to_config_dict_lazy(value) return cfg_dict if isinstance(cfg, (tuple, list)): - return type(cfg)( - Config._dict_to_config_dict_lazy(_cfg) for _cfg in cfg) + return type(cfg)(Config._dict_to_config_dict_lazy(_cfg) + for _cfg in cfg) return cfg @staticmethod @@ -1165,8 +1167,9 @@ def _dict_to_config_dict(cfg: dict, cfg = ConfigDict(cfg) dict.__setattr__(cfg, 'scope', scope) for key, value in cfg.items(): - cfg[key] = Config._dict_to_config_dict( - value, scope=scope, has_scope=has_scope) + cfg[key] = Config._dict_to_config_dict(value, + scope=scope, + has_scope=has_scope) elif isinstance(cfg, tuple): cfg = tuple( Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) @@ -1475,16 +1478,16 @@ def _format_dict(input_dict, outest_level=False): text = _format_dict(cfg_dict, outest_level=True) if self._format_python_code: # copied from setup.cfg - yapf_style = dict( - based_on_style='pep8', - blank_line_before_nested_class_or_def=True, - split_before_expression_after_opening_paren=True) + yapf_style = dict(based_on_style='pep8', + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True) try: if digit_version(yapf.__version__) >= digit_version('0.40.2'): text, _ = FormatCode(text, style_config=yapf_style) else: - text, _ = FormatCode( - text, style_config=yapf_style, verify=True) + text, _ = FormatCode(text, + style_config=yapf_style, + verify=True) except: # noqa: E722 raise SyntaxError('Failed to format the config file, please ' f'check the syntax of: \n{text}') @@ -1622,8 +1625,9 @@ def merge_from_dict(self, cfg_dict = super().__getattribute__('_cfg_dict') super().__setattr__( '_cfg_dict', - Config._merge_a_into_b( - option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) + Config._merge_a_into_b(option_cfg_dict, + cfg_dict, + allow_list_keys=allow_list_keys)) @staticmethod def diff(cfg1: Union[str, 'Config'], cfg2: Union[str, 'Config']) -> str: @@ -1633,8 +1637,8 @@ def diff(cfg1: Union[str, 'Config'], cfg2: Union[str, 'Config']) -> str: if isinstance(cfg2, str): cfg2 = Config.fromfile(cfg2) - res = difflib.unified_diff( - cfg1.pretty_text.split('\n'), cfg2.pretty_text.split('\n')) + res = difflib.unified_diff(cfg1.pretty_text.split('\n'), + cfg2.pretty_text.split('\n')) # Convert into rich format for better visualization console = Console() diff --git a/mmengine/config/utils.py b/mmengine/config/utils.py index 81b58fb49a..bb15d689bd 100644 --- a/mmengine/config/utils.py +++ b/mmengine/config/utils.py @@ -175,6 +175,8 @@ def _is_builtin_module(module_name: str) -> bool: origin_path = getattr(spec, 'origin', None) if origin_path is None: return False + if origin_path == 'frozen': + return True origin_path = osp.abspath(origin_path) if ('site-package' in origin_path or 'dist-package' in origin_path or not origin_path.startswith( diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py index e63860bee0..8e167ba650 100644 --- a/mmengine/dataset/dataset_wrapper.py +++ b/mmengine/dataset/dataset_wrapper.py @@ -11,6 +11,7 @@ from mmengine.logging import print_log from mmengine.registry import DATASETS + from .base_dataset import BaseDataset, force_full_init diff --git a/mmengine/dataset/utils.py b/mmengine/dataset/utils.py index 2c9cf96497..d140cc8dc4 100644 --- a/mmengine/dataset/utils.py +++ b/mmengine/dataset/utils.py @@ -158,7 +158,8 @@ def default_collate(data_batch: Sequence) -> Any: return [default_collate(samples) for samples in transposed] elif isinstance(data_item, Mapping): return data_item_type({ - key: default_collate([d[key] for d in data_batch]) + key: + default_collate([d[key] for d in data_batch]) for key in data_item }) else: diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py index f70cc3ef46..88e9b4559d 100644 --- a/mmengine/dist/dist.py +++ b/mmengine/dist/dist.py @@ -646,8 +646,9 @@ def _all_gather_object(object_list: List[Any], # Gather all local sizes. This is so that we can find the max size, and # index until the correct size when deserializing the tensors. group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=current_device) + object_sizes_tensor = torch.zeros(group_size, + dtype=torch.long, + device=current_device) object_size_list = [ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) ] @@ -656,8 +657,9 @@ def _all_gather_object(object_list: List[Any], max_object_size = int(max(object_size_list).item()) # Resize tensor to max size across all ranks. input_tensor.resize_(max_object_size) - coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8, device=current_device) + coalesced_output_tensor = torch.empty(max_object_size * group_size, + dtype=torch.uint8, + device=current_device) # Output tensors are nonoverlapping views of coalesced_output_tensor output_tensors = [ coalesced_output_tensor[max_object_size * i:max_object_size * (i + 1)] @@ -800,8 +802,9 @@ def _gather_object(obj: Any, # Gather all local sizes. This is so that we can find the max size, and # index until the correct size when deserializing the tensors. group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=current_device) + object_sizes_tensor = torch.zeros(group_size, + dtype=torch.long, + device=current_device) object_size_list = [ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) ] @@ -815,10 +818,9 @@ def _gather_object(obj: Any, # Avoid populating output tensors if the result won't be gathered on this # rank. if my_rank == dst: - coalesced_output_tensor = torch.empty( - max_object_size * group_size, - dtype=torch.uint8, - device=current_device) + coalesced_output_tensor = torch.empty(max_object_size * group_size, + dtype=torch.uint8, + device=current_device) # Output tensors are nonoverlapping views of coalesced_output_tensor output_tensors = [ coalesced_output_tensor[max_object_size * i:max_object_size * @@ -996,8 +998,8 @@ def collect_results_cpu(result_part: list, if rank == 0: mmengine.mkdir_or_exist('.dist_test') tmpdir = tempfile.mkdtemp(dir='.dist_test') - tmpdir = torch.tensor( - bytearray(tmpdir.encode()), dtype=torch.uint8) + tmpdir = torch.tensor(bytearray(tmpdir.encode()), + dtype=torch.uint8) dir_tensor[:len(tmpdir)] = tmpdir broadcast(dir_tensor, 0) tmpdir = dir_tensor.numpy().tobytes().decode().rstrip() diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py index 5d32cec36b..f41b938155 100644 --- a/mmengine/dist/utils.py +++ b/mmengine/dist/utils.py @@ -105,27 +105,24 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None: if is_mlu_available(): import torch_mlu # noqa: F401 torch.mlu.set_device(local_rank) - torch_dist.init_process_group( - backend='cncl', - rank=rank, - world_size=int(os.environ['WORLD_SIZE']), - **kwargs) + torch_dist.init_process_group(backend='cncl', + rank=rank, + world_size=int(os.environ['WORLD_SIZE']), + **kwargs) elif is_npu_available(): import torch_npu # noqa: F401 torch.npu.set_device(local_rank) - torch_dist.init_process_group( - backend='hccl', - rank=rank, - world_size=int(os.environ['WORLD_SIZE']), - **kwargs) + torch_dist.init_process_group(backend='hccl', + rank=rank, + world_size=int(os.environ['WORLD_SIZE']), + **kwargs) elif is_musa_available(): import torch_musa # noqa: F401 torch.musa.set_device(rank) - torch_dist.init_process_group( - backend='mccl', - rank=rank, - world_size=int(os.environ['WORLD_SIZE']), - **kwargs) + torch_dist.init_process_group(backend='mccl', + rank=rank, + world_size=int(os.environ['WORLD_SIZE']), + **kwargs) else: torch.cuda.set_device(local_rank) diff --git a/mmengine/evaluator/evaluator.py b/mmengine/evaluator/evaluator.py index 930ce93028..065e057aa8 100644 --- a/mmengine/evaluator/evaluator.py +++ b/mmengine/evaluator/evaluator.py @@ -4,6 +4,7 @@ from mmengine.dataset import pseudo_collate from mmengine.registry import EVALUATOR, METRICS from mmengine.structures import BaseDataElement + from .metric import BaseMetric diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py index 1292ce61ec..06396e103f 100644 --- a/mmengine/evaluator/metric.py +++ b/mmengine/evaluator/metric.py @@ -119,11 +119,10 @@ def evaluate(self, size: int) -> dict: level=logging.WARNING) if self.collect_device == 'cpu': - results = collect_results( - self.results, - size, - self.collect_device, - tmpdir=self.collect_dir) + results = collect_results(self.results, + size, + self.collect_device, + tmpdir=self.collect_dir) else: results = collect_results(self.results, size, self.collect_device) @@ -168,8 +167,8 @@ def __init__(self, out_file_path: str, collect_device: str = 'cpu', collect_dir: Optional[str] = None) -> None: - super().__init__( - collect_device=collect_device, collect_dir=collect_dir) + super().__init__(collect_device=collect_device, + collect_dir=collect_dir) if not out_file_path.endswith(('.pkl', '.pickle')): raise ValueError('The output file must be a pkl file.') self.out_file_path = out_file_path @@ -181,9 +180,8 @@ def process(self, data_batch: Any, predictions: Sequence[dict]) -> None: def compute_metrics(self, results: list) -> dict: """Dump the prediction results to a pickle file.""" dump(results, self.out_file_path) - print_log( - f'Results has been saved to {self.out_file_path}.', - logger='current') + print_log(f'Results has been saved to {self.out_file_path}.', + logger='current') return {} diff --git a/mmengine/fileio/backends/base.py b/mmengine/fileio/backends/base.py index 9331edf598..6759d8b2a8 100644 --- a/mmengine/fileio/backends/base.py +++ b/mmengine/fileio/backends/base.py @@ -21,10 +21,9 @@ class BaseStorageBackend(metaclass=ABCMeta): @property def allow_symlink(self): - print_log( - 'allow_symlink will be deprecated in future', - logger='current', - level=logging.WARNING) + print_log('allow_symlink will be deprecated in future', + logger='current', + level=logging.WARNING) return self._allow_symlink @property diff --git a/mmengine/fileio/backends/lmdb_backend.py b/mmengine/fileio/backends/lmdb_backend.py index eb47923e56..60cce2145a 100644 --- a/mmengine/fileio/backends/lmdb_backend.py +++ b/mmengine/fileio/backends/lmdb_backend.py @@ -70,12 +70,11 @@ def get_text(self, filepath, encoding=None): def _get_client(self): import lmdb - return lmdb.open( - self.db_path, - readonly=self.readonly, - lock=self.lock, - readahead=self.readahead, - **self.kwargs) + return lmdb.open(self.db_path, + readonly=self.readonly, + lock=self.lock, + readahead=self.readahead, + **self.kwargs) def __del__(self): if self._client is not None: diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py index c7d5f04621..ea7bd9fdc3 100644 --- a/mmengine/fileio/backends/local_backend.py +++ b/mmengine/fileio/backends/local_backend.py @@ -7,6 +7,7 @@ from typing import Generator, Iterator, Optional, Tuple, Union import mmengine + from .base import BaseStorageBackend @@ -156,8 +157,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return osp.isfile(filepath) - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: + Union[str, Path]) -> str: r"""Concatenate all file paths. Join one or more filepath components intelligently. The return value diff --git a/mmengine/fileio/backends/petrel_backend.py b/mmengine/fileio/backends/petrel_backend.py index 3994372f66..21deaf3839 100644 --- a/mmengine/fileio/backends/petrel_backend.py +++ b/mmengine/fileio/backends/petrel_backend.py @@ -10,6 +10,7 @@ import mmengine from mmengine.utils import has_method + from .base import BaseStorageBackend @@ -605,8 +606,9 @@ def rmtree(self, dir_path: Union[str, Path]) -> None: >>> dir_path = 'petrel://path/of/dir' >>> backend.rmtree(dir_path) """ - for path in self.list_dir_or_file( - dir_path, list_dir=False, recursive=True): + for path in self.list_dir_or_file(dir_path, + list_dir=False, + recursive=True): filepath = self.join_path(dir_path, path) self.remove(filepath) diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py index 61551d3d1d..6393f56163 100644 --- a/mmengine/fileio/file_client.py +++ b/mmengine/fileio/file_client.py @@ -7,6 +7,7 @@ from mmengine.logging import print_log from mmengine.utils import is_filepath + from .backends import (BaseStorageBackend, HTTPBackend, LmdbBackend, LocalBackend, MemcachedBackend, PetrelBackend) @@ -271,13 +272,17 @@ def get_text(self, filepath): `New in version 1.3.15.` """ if backend is not None: - cls._register_backend( - name, backend, force=force, prefixes=prefixes) + cls._register_backend(name, + backend, + force=force, + prefixes=prefixes) return def _register(backend_cls): - cls._register_backend( - name, backend_cls, force=force, prefixes=prefixes) + cls._register_backend(name, + backend_cls, + force=force, + prefixes=prefixes) return backend_cls return _register @@ -385,8 +390,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return self.client.isfile(filepath) - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: + Union[str, Path]) -> str: r"""Concatenate all file paths. Join one or more filepath components intelligently. The return value diff --git a/mmengine/fileio/handlers/registry_utils.py b/mmengine/fileio/handlers/registry_utils.py index 106fc881f2..49a50d35fc 100644 --- a/mmengine/fileio/handlers/registry_utils.py +++ b/mmengine/fileio/handlers/registry_utils.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.utils import is_list_of + from .base import BaseFileHandler from .json_handler import JsonHandler from .pickle_handler import PickleHandler diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py index fdeb4dc6df..00b7b52f6d 100644 --- a/mmengine/fileio/io.py +++ b/mmengine/fileio/io.py @@ -38,6 +38,7 @@ from typing import Generator, Iterator, Optional, Tuple, Union from mmengine.utils import is_filepath, is_str + from .backends import backends, prefix_to_backends from .file_client import FileClient # file_handlers and register_handler had been moved to @@ -176,8 +177,9 @@ def get( >>> get(filepath) b'hello world' """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.get(filepath) @@ -203,8 +205,9 @@ def get_text( >>> get_text(filepath) 'hello world' """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.get_text(filepath, encoding) @@ -229,8 +232,9 @@ def put( >>> filepath = '/path/of/file' >>> put(b'hello world', filepath) """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) backend.put(obj, filepath) @@ -257,8 +261,9 @@ def put_text( >>> filepath = '/path/of/file' >>> put_text('hello world', filepath) """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) backend.put_text(obj, filepath) @@ -281,8 +286,9 @@ def exists( >>> exists(filepath) True """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.exists(filepath) @@ -307,8 +313,9 @@ def isdir( >>> isdir(filepath) True """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.isdir(filepath) @@ -332,8 +339,9 @@ def isfile( >>> isfile(filepath) True """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.isfile(filepath) @@ -363,8 +371,9 @@ def join_path( >>> join_path(filepath1, filepath2, filepath3) '/path/of/dir/dir2/path/of/file' """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.join_path(filepath, *filepaths) @@ -395,8 +404,9 @@ def get_local_path( >>> with get_local_path('s3://bucket/abc.jpg') as path: ... # do something here """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) with backend.get_local_path(str(filepath)) as local_path: yield local_path @@ -439,8 +449,9 @@ def copyfile( >>> copyfile(src, dst) '/path1/of/dir/file' """ - backend = get_file_backend( - src, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(src, + backend_args=backend_args, + enable_singleton=True) return backend.copyfile(src, dst) @@ -473,8 +484,9 @@ def copytree( >>> copytree(src, dst) '/path/of/dir2' """ - backend = get_file_backend( - src, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(src, + backend_args=backend_args, + enable_singleton=True) return backend.copytree(src, dst) @@ -513,8 +525,9 @@ def copyfile_from_local( >>> copyfile_from_local(src, dst) 's3://openmmlab/mmengine/file' """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dst, + backend_args=backend_args, + enable_singleton=True) return backend.copyfile_from_local(src, dst) @@ -545,8 +558,9 @@ def copytree_from_local( >>> copyfile_from_local(src, dst) 's3://openmmlab/mmengine/dir' """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dst, + backend_args=backend_args, + enable_singleton=True) return backend.copytree_from_local(src, dst) @@ -589,8 +603,9 @@ def copyfile_to_local( >>> copyfile_to_local(src, dst) '/path/of/dir/file' """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dst, + backend_args=backend_args, + enable_singleton=True) return backend.copyfile_to_local(src, dst) @@ -621,8 +636,9 @@ def copytree_to_local( >>> copytree_to_local(src, dst) '/path/of/dir' """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dst, + backend_args=backend_args, + enable_singleton=True) return backend.copytree_to_local(src, dst) @@ -647,8 +663,9 @@ def remove( >>> filepath = '/path/of/file' >>> remove(filepath) """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) backend.remove(filepath) @@ -667,8 +684,9 @@ def rmtree( >>> dir_path = '/path/of/dir' >>> rmtree(dir_path) """ - backend = get_file_backend( - dir_path, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dir_path, + backend_args=backend_args, + enable_singleton=True) backend.rmtree(dir_path) @@ -702,8 +720,9 @@ def copy_if_symlink_fails( >>> copy_if_symlink_fails(src, dst) True """ - backend = get_file_backend( - src, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(src, + backend_args=backend_args, + enable_singleton=True) return backend.copy_if_symlink_fails(src, dst) @@ -755,8 +774,9 @@ def list_dir_or_file( >>> for file_path in list_dir_or_file(dir_path, recursive=True): ... print(file_path) """ - backend = get_file_backend( - dir_path, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dir_path, + backend_args=backend_args, + enable_singleton=True) yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) @@ -784,8 +804,9 @@ def generate_presigned_url( Returns: str: Generated presigned url. """ - backend = get_file_backend( - url, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(url, + backend_args=backend_args, + enable_singleton=True) return backend.generate_presigned_url(url, client_method, expires_in) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 92a4867bb9..c3e62914b3 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -13,6 +13,7 @@ from mmengine.logging import print_log from mmengine.registry import HOOKS from mmengine.utils import is_list_of, is_seq_of + from .hook import Hook DATA_BATCH = Optional[Union[dict, tuple, list]] @@ -196,10 +197,10 @@ def __init__(self, self.save_best = save_best # rule logic - assert (isinstance(rule, str) or is_list_of(rule, str) - or (rule is None)), ( - '"rule" should be a str or list of str or None, ' - f'but got {type(rule)}') + assert (isinstance(rule, str) or is_list_of(rule, str) or + (rule + is None)), ('"rule" should be a str or list of str or None, ' + f'but got {type(rule)}') if isinstance(rule, list): # check the length of rule list assert len(rule) in [ @@ -440,16 +441,15 @@ def _save_checkpoint_with_step(self, runner, step, meta): ckpt_filename) runner.message_hub.update_info('last_ckpt', self.last_ckpt) - runner.save_checkpoint( - self.out_dir, - ckpt_filename, - self.file_client_args, - save_optimizer=self.save_optimizer, - save_param_scheduler=self.save_param_scheduler, - meta=meta, - by_epoch=self.by_epoch, - backend_args=self.backend_args, - **self.args) + runner.save_checkpoint(self.out_dir, + ckpt_filename, + self.file_client_args, + save_optimizer=self.save_optimizer, + save_param_scheduler=self.save_param_scheduler, + meta=meta, + by_epoch=self.by_epoch, + backend_args=self.backend_args, + **self.args) # Model parallel-like training should involve pulling sharded states # from all ranks, but skip the following procedure. @@ -557,15 +557,14 @@ def _save_best_checkpoint(self, runner, metrics) -> None: runner.message_hub.update_info( runtime_best_ckpt_key, self.best_ckpt_path_dict[key_indicator]) - runner.save_checkpoint( - self.out_dir, - filename=best_ckpt_name, - file_client_args=self.file_client_args, - save_optimizer=False, - save_param_scheduler=False, - meta=meta, - by_epoch=False, - backend_args=self.backend_args) + runner.save_checkpoint(self.out_dir, + filename=best_ckpt_name, + file_client_args=self.file_client_args, + save_optimizer=False, + save_param_scheduler=False, + meta=meta, + by_epoch=False, + backend_args=self.backend_args) runner.logger.info( f'The best checkpoint with {best_score:0.4f} {key_indicator} ' f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.') diff --git a/mmengine/hooks/early_stopping_hook.py b/mmengine/hooks/early_stopping_hook.py index 5533ebc84c..517265f93e 100644 --- a/mmengine/hooks/early_stopping_hook.py +++ b/mmengine/hooks/early_stopping_hook.py @@ -4,6 +4,7 @@ from typing import Optional, Tuple, Union from mmengine.registry import HOOKS + from .hook import Hook DATA_BATCH = Optional[Union[dict, tuple, list]] diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py index 5bc1051d0b..504fcd30ab 100644 --- a/mmengine/hooks/ema_hook.py +++ b/mmengine/hooks/ema_hook.py @@ -7,6 +7,7 @@ from mmengine.logging import print_log from mmengine.model import is_model_wrapper from mmengine.registry import HOOKS, MODELS + from .hook import DATA_BATCH, Hook @@ -71,8 +72,8 @@ def before_run(self, runner) -> None: if is_model_wrapper(model): model = model.module self.src_model = model - self.ema_model = MODELS.build( - self.ema_cfg, default_args=dict(model=self.src_model)) + self.ema_model = MODELS.build(self.ema_cfg, + default_args=dict(model=self.src_model)) def before_train(self, runner) -> None: """Check the begin_epoch/iter is smaller than max_epochs/iters. @@ -181,8 +182,8 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: # The original model parameters are actually saved in ema # field swap the weights back to resume ema state. self._swap_ema_state_dict(checkpoint) - self.ema_model.load_state_dict( - checkpoint['ema_state_dict'], strict=self.strict_load) + self.ema_model.load_state_dict(checkpoint['ema_state_dict'], + strict=self.strict_load) # Support load checkpoint without ema state dict. else: @@ -191,22 +192,20 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: 'There is no `ema_state_dict` in checkpoint. ' '`EMAHook` will make a copy of `state_dict` as the ' 'initial `ema_state_dict`', 'current', logging.WARNING) - load_state_dict( - self.ema_model.module, - copy.deepcopy(checkpoint['state_dict']), - strict=self.strict_load) + load_state_dict(self.ema_model.module, + copy.deepcopy(checkpoint['state_dict']), + strict=self.strict_load) def _swap_ema_parameters(self) -> None: """Swap the parameter of model with ema_model.""" - avg_param = ( - itertools.chain(self.ema_model.module.parameters(), - self.ema_model.module.buffers()) - if self.ema_model.update_buffers else - self.ema_model.module.parameters()) - src_param = ( - itertools.chain(self.src_model.parameters(), - self.src_model.buffers()) - if self.ema_model.update_buffers else self.src_model.parameters()) + avg_param = (itertools.chain(self.ema_model.module.parameters(), + self.ema_model.module.buffers()) + if self.ema_model.update_buffers else + self.ema_model.module.parameters()) + src_param = (itertools.chain(self.src_model.parameters(), + self.src_model.buffers()) + if self.ema_model.update_buffers else + self.src_model.parameters()) for p_avg, p_src in zip(avg_param, src_param): tmp = p_avg.data.clone() p_avg.data.copy_(p_src.data) diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index 9a92cdebfe..7b691d107d 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -4,6 +4,7 @@ import torch from mmengine.registry import HOOKS + from ..device import is_cuda_available, is_musa_available from .hook import Hook diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 4e1c4ce8bd..27230c6fad 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -183,8 +183,10 @@ def before_train_iter(self, batch_idx (int): The index of the current batch in the train loop. data_batch (dict or tuple or list, optional): Data from dataloader. """ - self._before_iter( - runner, batch_idx=batch_idx, data_batch=data_batch, mode='train') + self._before_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + mode='train') def before_val_iter(self, runner, @@ -199,8 +201,10 @@ def before_val_iter(self, data_batch (dict, optional): Data from dataloader. Defaults to None. """ - self._before_iter( - runner, batch_idx=batch_idx, data_batch=data_batch, mode='val') + self._before_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + mode='val') def before_test_iter(self, runner, @@ -215,8 +219,10 @@ def before_test_iter(self, data_batch (dict or tuple or list, optional): Data from dataloader. Defaults to None. """ - self._before_iter( - runner, batch_idx=batch_idx, data_batch=data_batch, mode='test') + self._before_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + mode='test') def after_train_iter(self, runner, @@ -232,12 +238,11 @@ def after_train_iter(self, data_batch (dict tuple or list, optional): Data from dataloader. outputs (dict, optional): Outputs from model. """ - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode='train') + self._after_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + outputs=outputs, + mode='train') def after_val_iter(self, runner, @@ -253,12 +258,11 @@ def after_val_iter(self, data_batch (dict or tuple or list, optional): Data from dataloader. outputs (Sequence, optional): Outputs from model. """ - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode='val') + self._after_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + outputs=outputs, + mode='val') def after_test_iter(self, runner, @@ -274,12 +278,11 @@ def after_test_iter(self, data_batch (dict or tuple or list, optional): Data from dataloader. outputs (Sequence, optional): Outputs from model. """ - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode='test') + self._after_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + outputs=outputs, + mode='test') def _before_epoch(self, runner, mode: str = 'train') -> None: """All subclasses should override this method, if they need any diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index 5632c2b25e..edbf209d3a 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -3,6 +3,7 @@ from typing import Optional, Sequence, Union from mmengine.registry import HOOKS + from .hook import Hook DATA_BATCH = Optional[Union[dict, tuple, list]] @@ -90,8 +91,8 @@ def _after_iter(self, if mode == 'train': self.time_sec_tot += iter_time.current() # Calculate average iterative time. - time_sec_avg = self.time_sec_tot / ( - runner.iter - self.start_iter + 1) + time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + + 1) # Calculate eta. eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) runner.message_hub.update_info('eta', eta_sec) diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index fa0b79dcf9..cfd0a2dd36 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -137,8 +137,8 @@ def __init__(self, self.file_client = FileClient.infer_client(file_client_args, self.out_dir) if file_client_args is None: - self.file_backend = get_file_backend( - self.out_dir, backend_args=backend_args) + self.file_backend = get_file_backend(self.out_dir, + backend_args=backend_args) else: self.file_backend = self.file_client @@ -196,8 +196,9 @@ def after_train_iter(self, else: return runner.logger.info(log_str) - runner.visualizer.add_scalars( - tag, step=runner.iter + 1, file_path=self.json_log_path) + runner.visualizer.add_scalars(tag, + step=runner.iter + 1, + file_path=self.json_log_path) def after_val_iter(self, runner, @@ -262,16 +263,18 @@ def after_val_epoch(self, epoch = 0 else: epoch = runner.epoch - runner.visualizer.add_scalars( - tag, step=epoch, file_path=self.json_log_path) + runner.visualizer.add_scalars(tag, + step=epoch, + file_path=self.json_log_path) else: if (isinstance(runner._train_loop, dict) or runner._train_loop is None): iter = 0 else: iter = runner.iter - runner.visualizer.add_scalars( - tag, step=iter, file_path=self.json_log_path) + runner.visualizer.add_scalars(tag, + step=iter, + file_path=self.json_log_path) def after_test_epoch(self, runner, @@ -288,9 +291,8 @@ def after_test_epoch(self, tag, log_str = runner.log_processor.get_log_after_epoch( runner, len(runner.test_dataloader), 'test', with_non_scalar=True) runner.logger.info(log_str) - dump( - self._process_tags(tag), - osp.join(runner.log_dir, self.json_log_path)) # type: ignore + dump(self._process_tags(tag), + osp.join(runner.log_dir, self.json_log_path)) # type: ignore @staticmethod def _process_tags(tags: dict): diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index 3b2f1e610a..60cb0270fd 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -4,6 +4,7 @@ from mmengine.optim import _ParamScheduler from mmengine.registry import HOOKS from mmengine.utils import is_list_of + from .hook import Hook DATA_BATCH = Optional[Union[dict, tuple, list]] diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py index 49407e4563..34487caa78 100644 --- a/mmengine/hooks/runtime_info_hook.py +++ b/mmengine/hooks/runtime_info_hook.py @@ -7,6 +7,7 @@ from mmengine.registry import HOOKS from mmengine.utils import get_git_hash from mmengine.version import __version__ + from .hook import Hook DATA_BATCH = Optional[Union[dict, tuple, list]] @@ -47,11 +48,10 @@ def before_run(self, runner) -> None: Args: runner (Runner): The runner of the training process. """ - metainfo = dict( - cfg=runner.cfg.pretty_text, - seed=runner.seed, - experiment_name=runner.experiment_name, - mmengine_version=__version__ + get_git_hash()) + metainfo = dict(cfg=runner.cfg.pretty_text, + seed=runner.seed, + experiment_name=runner.experiment_name, + mmengine_version=__version__ + get_git_hash()) runner.message_hub.update_info_dict(metainfo) self.last_loop_stage = None diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py index 9aed9dbcf5..6317fb5a3a 100644 --- a/mmengine/hooks/sampler_seed_hook.py +++ b/mmengine/hooks/sampler_seed_hook.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.registry import HOOKS + from .hook import Hook diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py index 7cc75757fe..5e85bc24bd 100644 --- a/mmengine/hooks/sync_buffer_hook.py +++ b/mmengine/hooks/sync_buffer_hook.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.dist import all_reduce_params, is_distributed from mmengine.registry import HOOKS + from .hook import Hook diff --git a/mmengine/infer/infer.py b/mmengine/infer/infer.py index 322d885224..c46b7b5a7b 100644 --- a/mmengine/infer/infer.py +++ b/mmengine/infer/infer.py @@ -51,9 +51,8 @@ def __init__(self, *args, **kwargs): assert isinstance(self.visualize_kwargs, set) assert isinstance(self.postprocess_kwargs, set) - all_kwargs = ( - self.preprocess_kwargs | self.forward_kwargs - | self.visualize_kwargs | self.postprocess_kwargs) + all_kwargs = (self.preprocess_kwargs | self.forward_kwargs + | self.visualize_kwargs | self.postprocess_kwargs) assert len(all_kwargs) == ( len(self.preprocess_kwargs) + len(self.forward_kwargs) + @@ -215,8 +214,9 @@ def __call__( ) = self._dispatch_kwargs(**kwargs) ori_inputs = self._inputs_to_list(inputs) - inputs = self.preprocess( - ori_inputs, batch_size=batch_size, **preprocess_kwargs) + inputs = self.preprocess(ori_inputs, + batch_size=batch_size, + **preprocess_kwargs) preds = [] for data in (track(inputs, description='Inference') if self.show_progress else inputs): @@ -286,8 +286,8 @@ def __call__(self, inputs, batch_size=1, **kwargs): Yields: Any: Data processed by the ``pipeline`` and ``collate_fn``. """ - chunked_data = self._get_chunk_data( - map(self.pipeline, inputs), batch_size) + chunked_data = self._get_chunk_data(map(self.pipeline, inputs), + batch_size) yield from map(self.collate_fn, chunked_data) @torch.no_grad() diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py index e6cf9fe37d..b35024f5e5 100644 --- a/mmengine/logging/logger.py +++ b/mmengine/logging/logger.py @@ -57,8 +57,10 @@ class MMFormatter(logging.Formatter): **kwargs: Keyword arguments passed to :meth:`logging.Formatter.__init__`. """ - _color_mapping: dict = dict( - ERROR='red', WARNING='yellow', INFO='white', DEBUG='green') + _color_mapping: dict = dict(ERROR='red', + WARNING='yellow', + INFO='white', + DEBUG='green') def __init__(self, color: bool = True, blink: bool = False, **kwargs): super().__init__(**kwargs) diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index 82565d8832..f056b5eabb 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -2,17 +2,16 @@ import copy import logging from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import Any, Optional, Union import numpy as np +import torch from mmengine.utils import ManagerMixin + from .history_buffer import HistoryBuffer from .logger import print_log -if TYPE_CHECKING: - import torch - class MessageHub(ManagerMixin): """Message hub for component interaction. MessageHub is created and @@ -92,6 +91,7 @@ def get_current_instance(cls) -> 'MessageHub': cls.get_instance('mmengine') return super().get_current_instance() + @torch.compiler.disable def update_scalar(self, key: str, value: Union[int, float, np.ndarray, 'torch.Tensor'], @@ -342,8 +342,11 @@ def _get_valid_value( else: # check whether value is torch.Tensor but don't want # to import torch in this file - assert hasattr(value, 'numel') and value.numel() == 1 - value = value.item() + if hasattr(value, 'numel') and value.numel() == 1: + value = value.item() + else: + print_log(f"MessageHub got unexpceted log: {value}", + level=logging.WARN) return value # type: ignore def state_dict(self) -> dict: @@ -374,10 +377,9 @@ def state_dict(self) -> dict: logger='current', level=logging.WARNING) saved_info[key] = value - return dict( - log_scalars=saved_scalars, - runtime_info=saved_info, - resumed_keys=self._resumed_keys) + return dict(log_scalars=saved_scalars, + runtime_info=saved_info, + resumed_keys=self._resumed_keys) def load_state_dict(self, state_dict: Union['MessageHub', dict]) -> None: """Loads log scalars, runtime information and resumed keys from diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py index 033512a985..41f65f41fd 100644 --- a/mmengine/model/__init__.py +++ b/mmengine/model/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.utils.version_utils import digit_version + from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage, MomentumAnnealingEMA, StochasticWeightAverage) from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index 58457c2a6e..eb14294cf4 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -96,9 +96,8 @@ def update_parameters(self, model: nn.Module) -> None: Args: model (nn.Module): The model whose parameters will be averaged. """ - src_parameters = ( - model.state_dict() - if self.update_buffers else dict(model.named_parameters())) + src_parameters = (model.state_dict() if self.update_buffers else dict( + model.named_parameters())) if self.steps == 0: for k, p_avg in self.avg_parameters.items(): p_avg.data.copy_(src_parameters[k].data) @@ -138,9 +137,8 @@ def avg_func(self, averaged_param: Tensor, source_param: Tensor, steps (int): The number of times the parameters have been updated. """ - averaged_param.add_( - source_param - averaged_param, - alpha=1 / float(steps // self.interval + 1)) + averaged_param.add_(source_param - averaged_param, + alpha=1 / float(steps // self.interval + 1)) @MODELS.register_module() @@ -238,12 +236,11 @@ def __init__(self, interval: int = 1, device: Optional[torch.device] = None, update_buffers: bool = False) -> None: - super().__init__( - model=model, - momentum=momentum, - interval=interval, - device=device, - update_buffers=update_buffers) + super().__init__(model=model, + momentum=momentum, + interval=interval, + device=device, + update_buffers=update_buffers) assert gamma > 0, f'gamma must be greater than 0, but got {gamma}' self.gamma = gamma diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py index 299cd67557..660054dc6c 100644 --- a/mmengine/model/base_model/base_model.py +++ b/mmengine/model/base_model/base_model.py @@ -9,6 +9,7 @@ from mmengine.optim import OptimWrapper from mmengine.registry import MODELS from mmengine.utils import is_list_of + from ..base_module import BaseModule from .data_preprocessor import BaseDataPreprocessor diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 4d621851b0..3cd38b4286 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -9,6 +9,7 @@ from mmengine.registry import MODELS from mmengine.structures import BaseDataElement from mmengine.utils import is_seq_of + from ..utils import stack_batch CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str, diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py index 3cfe0b14a8..6eee81b4c0 100644 --- a/mmengine/model/base_module.py +++ b/mmengine/model/base_module.py @@ -10,6 +10,7 @@ from mmengine.dist import master_only from mmengine.logging import MMLogger, print_log + from .weight_init import PretrainedInit, initialize, update_init_info from .wrappers.utils import is_model_wrapper @@ -135,11 +136,10 @@ def init_weights(self): m, 'is_init', False): m.init_weights() # users may overload the `init_weights` - update_init_info( - m, - init_info=f'Initialized by ' - f'user-defined `init_weights`' - f' in {m.__class__.__name__} ') + update_init_info(m, + init_info=f'Initialized by ' + f'user-defined `init_weights`' + f' in {m.__class__.__name__} ') if self.init_cfg and pretrained_cfg: initialize(self, pretrained_cfg) self._is_init = True diff --git a/mmengine/model/efficient_conv_bn_eval.py b/mmengine/model/efficient_conv_bn_eval.py index 9cb2ad6199..ef12ffa818 100644 --- a/mmengine/model/efficient_conv_bn_eval.py +++ b/mmengine/model/efficient_conv_bn_eval.py @@ -111,10 +111,12 @@ def efficient_conv_bn_eval_graph_transform(fx_model): # note that we directly call `create_node` to fill the `name` # argument. `fx_model.graph.get_attr` and # `fx_model.graph.call_function` does not allow the `name` argument. - conv_get_node = fx_model.graph.create_node( - op='get_attr', target=conv_node.target, name='get_conv') - bn_get_node = fx_model.graph.create_node( - op='get_attr', target=bn_node.target, name='get_bn') + conv_get_node = fx_model.graph.create_node(op='get_attr', + target=conv_node.target, + name='get_conv') + bn_get_node = fx_model.graph.create_node(op='get_attr', + target=bn_node.target, + name='get_bn') # prepare args for the fused function args = (bn_get_node, conv_get_node, conv_node.args[0]) # create a new node diff --git a/mmengine/model/test_time_aug.py b/mmengine/model/test_time_aug.py index c623eec8bc..65fcab5405 100644 --- a/mmengine/model/test_time_aug.py +++ b/mmengine/model/test_time_aug.py @@ -7,6 +7,7 @@ from mmengine.registry import MODELS from mmengine.structures import BaseDataElement + from .base_model import BaseModel # multi-batch inputs processed by different augmentations from the same batch. @@ -124,9 +125,10 @@ def test_step(self, data): data_list: Union[List[dict], List[list]] if isinstance(data, dict): num_augs = len(data[next(iter(data))]) - data_list = [{key: value[idx] - for key, value in data.items()} - for idx in range(num_augs)] + data_list = [{ + key: value[idx] + for key, value in data.items() + } for idx in range(num_augs)] elif isinstance(data, (tuple, list)): num_augs = len(data[0]) data_list = [[_data[idx] for _data in data] diff --git a/mmengine/model/utils.py b/mmengine/model/utils.py index c78ea3134d..0d30aa44ca 100644 --- a/mmengine/model/utils.py +++ b/mmengine/model/utils.py @@ -199,10 +199,9 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module: try: module_output.add_module(name, revert_sync_batchnorm(child)) except Exception: - print_log( - F'Failed to convert {child} from SyncBN to BN!', - logger='current', - level=logging.WARNING) + print_log(F'Failed to convert {child} from SyncBN to BN!', + logger='current', + level=logging.WARNING) del module return module_output diff --git a/mmengine/model/weight_init.py b/mmengine/model/weight_init.py index b6d0186ed7..c1d5a07d08 100644 --- a/mmengine/model/weight_init.py +++ b/mmengine/model/weight_init.py @@ -97,11 +97,15 @@ def kaiming_init(module, assert distribution in ['uniform', 'normal'] if hasattr(module, 'weight') and module.weight is not None: if distribution == 'uniform': - nn.init.kaiming_uniform_( - module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + nn.init.kaiming_uniform_(module.weight, + a=a, + mode=mode, + nonlinearity=nonlinearity) else: - nn.init.kaiming_normal_( - module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + nn.init.kaiming_normal_(module.weight, + a=a, + mode=mode, + nonlinearity=nonlinearity) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) @@ -109,13 +113,12 @@ def kaiming_init(module, def caffe2_xavier_init(module, bias=0): # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch # Acknowledgment to FAIR's internal code - kaiming_init( - module, - a=1, - mode='fan_in', - nonlinearity='leaky_relu', - bias=bias, - distribution='uniform') + kaiming_init(module, + a=1, + mode='fan_in', + nonlinearity='leaky_relu', + bias=bias, + distribution='uniform') def bias_init_with_prob(prior_prob): @@ -450,12 +453,11 @@ class Caffe2XavierInit(KaimingInit): # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch # Acknowledgment to FAIR's internal code def __init__(self, **kwargs): - super().__init__( - a=1, - mode='fan_in', - nonlinearity='leaky_relu', - distribution='uniform', - **kwargs) + super().__init__(a=1, + mode='fan_in', + nonlinearity='leaky_relu', + distribution='uniform', + **kwargs) def __call__(self, module): super().__call__(module) @@ -487,16 +489,14 @@ def __call__(self, module): load_state_dict) if self.prefix is None: print_log(f'load model from: {self.checkpoint}', logger='current') - load_checkpoint( - module, - self.checkpoint, - map_location=self.map_location, - strict=False, - logger='current') + load_checkpoint(module, + self.checkpoint, + map_location=self.map_location, + strict=False, + logger='current') else: - print_log( - f'load {self.prefix} in model from: {self.checkpoint}', - logger='current') + print_log(f'load {self.prefix} in model from: {self.checkpoint}', + logger='current') state_dict = _load_checkpoint_with_prefix( self.prefix, self.checkpoint, map_location=self.map_location) load_state_dict(module, state_dict, strict=False, logger='current') diff --git a/mmengine/model/wrappers/__init__.py b/mmengine/model/wrappers/__init__.py index 90eddabbe1..35480c8df3 100644 --- a/mmengine/model/wrappers/__init__.py +++ b/mmengine/model/wrappers/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.utils.version_utils import digit_version + from .distributed import MMDistributedDataParallel from .seperate_distributed import MMSeparateDistributedDataParallel from .utils import is_model_wrapper diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py index 4113aebf9e..dda05ff685 100644 --- a/mmengine/model/wrappers/distributed.py +++ b/mmengine/model/wrappers/distributed.py @@ -6,6 +6,7 @@ from mmengine.optim import OptimWrapper from mmengine.registry import MODEL_WRAPPERS + from ..utils import detect_anomalous_params MODEL_WRAPPERS.register_module(module=DistributedDataParallel) diff --git a/mmengine/model/wrappers/fully_sharded_distributed.py b/mmengine/model/wrappers/fully_sharded_distributed.py index df128597b1..d991b7d703 100644 --- a/mmengine/model/wrappers/fully_sharded_distributed.py +++ b/mmengine/model/wrappers/fully_sharded_distributed.py @@ -233,17 +233,16 @@ def parse_dtype(dtype): kwargs['ignored_modules'] = self._get_ignored_modules( module, kwargs['ignored_modules']) - super().__init__( - module=module, - process_group=process_group, - sharding_strategy=sharding_strategy, - auto_wrap_policy=auto_wrap_policy, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - mixed_precision=mixed_precision, - param_init_fn=param_init_fn, - use_orig_params=use_orig_params, - **kwargs) + super().__init__(module=module, + process_group=process_group, + sharding_strategy=sharding_strategy, + auto_wrap_policy=auto_wrap_policy, + cpu_offload=cpu_offload, + backward_prefetch=backward_prefetch, + mixed_precision=mixed_precision, + param_init_fn=param_init_fn, + use_orig_params=use_orig_params, + **kwargs) def train_step(self, data: dict, optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: diff --git a/mmengine/model/wrappers/seperate_distributed.py b/mmengine/model/wrappers/seperate_distributed.py index ac9c2383c3..43e860c124 100644 --- a/mmengine/model/wrappers/seperate_distributed.py +++ b/mmengine/model/wrappers/seperate_distributed.py @@ -9,6 +9,7 @@ from mmengine.device import get_device from mmengine.optim import OptimWrapperDict from mmengine.registry import MODEL_WRAPPERS + from .distributed import MMDistributedDataParallel diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index 4f3323f2cc..b3beb4ef2e 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from contextlib import contextmanager +from functools import partial from typing import Union import torch @@ -10,6 +11,7 @@ from mmengine.registry import OPTIM_WRAPPERS from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION + from .optimizer_wrapper import OptimWrapper if is_npu_available(): @@ -17,7 +19,8 @@ elif is_mlu_available(): from torch.mlu.amp import GradScaler else: - from torch.cuda.amp import GradScaler + from torch.amp import GradScaler as amp_GradScaler + GradScaler = partial(amp_GradScaler, device='cuda') @OPTIM_WRAPPERS.register_module() diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index a2e6190460..ad38dad21a 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -9,6 +9,7 @@ # from mmengine.model.wrappers import is_model_wrapper import mmengine from mmengine.registry import OPTIM_WRAPPERS + from .optimizer_wrapper import OptimWrapper try: diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index fef95f729a..edb36a3c56 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import inspect +import warnings from typing import List, Union import torch @@ -9,6 +10,8 @@ from mmengine.config import Config, ConfigDict from mmengine.device import is_npu_available, is_npu_support_full_precision from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS + +from .default_constructor import DefaultOptimWrapperConstructor from .optimizer_wrapper import OptimWrapper @@ -26,8 +29,8 @@ def register_torch_optimizers() -> List[str]: if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer): if module_name == 'Adafactor': - OPTIMIZERS.register_module( - name='TorchAdafactor', module=_optim) + OPTIMIZERS.register_module(name='TorchAdafactor', + module=_optim) else: OPTIMIZERS.register_module(module=_optim) torch_optimizers.append(module_name) @@ -115,7 +118,7 @@ def register_sophia_optimizers() -> List[str]: Returns: List[str]: A list of registered optimizers' name. """ - optimizers = [] + optimizers = [] # type: ignore try: import Sophia except ImportError: @@ -128,7 +131,7 @@ def register_sophia_optimizers() -> List[str]: try: OPTIMIZERS.register_module(module=_optim) except Exception as e: - warnings.warn(f"Failed to import {optim_cls.__name__} for {e}") + warnings.warn(f'Failed to import {Sophia} for {e}') return optimizers @@ -161,7 +164,7 @@ def register_bitsandbytes_optimizers() -> List[str]: try: OPTIMIZERS.register_module(module=optim_cls, name=name) except Exception as e: - warnings.warn(f"Failed to import {optim_cls.__name__} for {e}") + warnings.warn(f'Failed to import {optim_cls.__name__} for {e}') dadaptation_optimizers.append(name) return dadaptation_optimizers @@ -169,8 +172,8 @@ def register_bitsandbytes_optimizers() -> List[str]: BITSANDBYTES_OPTIMIZERS = register_bitsandbytes_optimizers() -def register_transformers_optimizers(): - transformer_optimizers = [] +def register_transformers_optimizers() -> List[str]: + transformer_optimizers: List[str] = [] try: from transformers import Adafactor except ImportError: @@ -179,7 +182,7 @@ def register_transformers_optimizers(): try: OPTIMIZERS.register_module(name='Adafactor', module=Adafactor) except Exception as e: - warnings.warn(f"Failed to import {optim_cls.__name__} for {e}") + warnings.warn(f'Failed to import Adafactor for {e}') transformer_optimizers.append('Adafactor') return transformer_optimizers @@ -205,8 +208,9 @@ def build_optim_wrapper(model: nn.Module, OptimWrapper: The built optimizer wrapper. """ optim_wrapper_cfg = copy.deepcopy(cfg) - constructor_type = optim_wrapper_cfg.pop('constructor', - 'DefaultOptimWrapperConstructor') + constructor_cfg = optim_wrapper_cfg.pop('constructor', None) + if constructor_cfg is None: + constructor_cfg = dict(type=DefaultOptimWrapperConstructor) paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None) # Since the current generation of NPU(Ascend 910) only supports @@ -215,10 +219,10 @@ def build_optim_wrapper(model: nn.Module, if is_npu_available() and not is_npu_support_full_precision(): optim_wrapper_cfg['type'] = 'AmpOptimWrapper' + constructor_cfg.update( + dict(optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg)) + optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( - dict( - type=constructor_type, - optim_wrapper_cfg=optim_wrapper_cfg, - paramwise_cfg=paramwise_cfg)) + constructor_cfg) optim_wrapper = optim_wrapper_constructor(model) return optim_wrapper diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py index b623a3e70e..344a57d0cd 100644 --- a/mmengine/optim/optimizer/default_constructor.py +++ b/mmengine/optim/optimizer/default_constructor.py @@ -13,6 +13,7 @@ from mmengine.utils import is_list_of from mmengine.utils.dl_utils import mmcv_full_available from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm + from .optimizer_wrapper import OptimWrapper @@ -199,9 +200,8 @@ def add_params(self, # special rules for norm layers and depth-wise conv layers is_norm = isinstance(module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) - is_dwconv = ( - isinstance(module, torch.nn.Conv2d) - and module.in_channels == module.groups) + is_dwconv = (isinstance(module, torch.nn.Conv2d) + and module.in_channels == module.groups) for name, param in module.named_parameters(recurse=False): param_group = {'params': [param]} @@ -272,9 +272,8 @@ def add_params(self, if key == 'params': continue full_name = f'{prefix}.{name}' if prefix else name - print_log( - f'paramwise_options -- {full_name}:{key}={value}', - logger='current') + print_log(f'paramwise_options -- {full_name}:{key}={value}', + logger='current') if mmcv_full_available(): from mmcv.ops import DeformConv2d, ModulatedDeformConv2d @@ -284,11 +283,10 @@ def add_params(self, is_dcn_module = False for child_name, child_mod in module.named_children(): child_prefix = f'{prefix}.{child_name}' if prefix else child_name - self.add_params( - params, - child_mod, - prefix=child_prefix, - is_dcn_module=is_dcn_module) + self.add_params(params, + child_mod, + prefix=child_prefix, + is_dcn_module=is_dcn_module) def __call__(self, model: nn.Module) -> OptimWrapper: if hasattr(model, 'module'): @@ -304,8 +302,8 @@ def __call__(self, model: nn.Module) -> OptimWrapper: if isinstance(optimizer_cls, str): with OPTIMIZERS.switch_scope_and_registry(None) as registry: optimizer_cls = registry.get(self.optimizer_cfg['type']) - fisrt_arg_name = next( - iter(inspect.signature(optimizer_cls).parameters)) + fisrt_arg_name = next(iter( + inspect.signature(optimizer_cls).parameters)) # if no paramwise option is specified, just use the global setting if not self.paramwise_cfg: optimizer_cfg[fisrt_arg_name] = model.parameters() diff --git a/mmengine/optim/optimizer/optimizer_wrapper.py b/mmengine/optim/optimizer/optimizer_wrapper.py index 41218ef768..75aa4d08b9 100644 --- a/mmengine/optim/optimizer/optimizer_wrapper.py +++ b/mmengine/optim/optimizer/optimizer_wrapper.py @@ -10,6 +10,7 @@ from mmengine.logging import MessageHub, print_log from mmengine.registry import OPTIM_WRAPPERS from mmengine.utils.dl_utils import has_batch_norm + from .base import BaseOptimWrapper diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py index 13bc61d542..48405f2770 100644 --- a/mmengine/optim/scheduler/lr_scheduler.py +++ b/mmengine/optim/scheduler/lr_scheduler.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.registry import PARAM_SCHEDULERS + # yapf: disable from .param_scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py index e356e70f7b..50df22347a 100644 --- a/mmengine/optim/scheduler/momentum_scheduler.py +++ b/mmengine/optim/scheduler/momentum_scheduler.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.registry import PARAM_SCHEDULERS + # yapf: disable from .param_scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py index 2dcb1af072..9f46034ac7 100644 --- a/mmengine/optim/scheduler/param_scheduler.py +++ b/mmengine/optim/scheduler/param_scheduler.py @@ -258,14 +258,13 @@ def __init__(self, verbose: bool = False): self.step_size = step_size self.gamma = gamma - super().__init__( - optimizer=optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer=optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -288,13 +287,12 @@ def build_iter_from_epoch(cls, begin = int(begin * epoch_length) if end != INF: end = int(end * epoch_length) - return cls( - *args, - step_size=step_size, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) + return cls(*args, + step_size=step_size, + begin=begin, + end=end, + by_epoch=by_epoch, + **kwargs) def _get_value(self): """Compute value using chainable form of the scheduler.""" @@ -346,14 +344,13 @@ def __init__(self, verbose: bool = False): self.milestones = Counter(milestones) self.gamma = gamma - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -376,13 +373,12 @@ def build_iter_from_epoch(cls, begin = int(begin * epoch_length) if end != INF: end = int(end * epoch_length) - return cls( - *args, - milestones=milestones, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) + return cls(*args, + milestones=milestones, + begin=begin, + end=end, + by_epoch=by_epoch, + **kwargs) def _get_value(self): """Compute value using chainable form of the scheduler.""" @@ -438,14 +434,13 @@ def __init__(self, self.factor = factor self.total_iters = end - begin - 1 - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -521,14 +516,13 @@ def __init__(self, by_epoch: bool = True, verbose: bool = False): self.gamma = gamma - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -638,14 +632,13 @@ def __init__(self, self.T_max = T_max or (end - begin) self.eta_min = eta_min self.eta_min_ratio = eta_min_ratio - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -669,13 +662,12 @@ def build_iter_from_epoch(cls, begin = int(begin * epoch_length) if end != INF: end = int(end * epoch_length) - return cls( - *args, - T_max=T_max, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) + return cls(*args, + T_max=T_max, + begin=begin, + end=end, + by_epoch=by_epoch, + **kwargs) def _get_value(self) -> list: """Compute value using chainable form of the scheduler.""" @@ -756,14 +748,13 @@ def __init__(self, self.start_factor = start_factor self.end_factor = end_factor self.total_iters = end - begin - 1 - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -846,14 +837,13 @@ def __init__(self, self.power = power self.total_iters = end - begin - 1 - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -1043,14 +1033,13 @@ def __init__(self, group[f'min_{param_name}'] = \ group[f'initial_{param_name}'] / final_div_factor - super().__init__( - optimizer=optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer=optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) def _format_param(self, name, optimizer, param): """Return correctly formatted lr/momentum for each param group.""" @@ -1098,13 +1087,12 @@ def build_iter_from_epoch(cls, end = int(end * epoch_length) if total_steps is not None: total_steps = total_steps * epoch_length - return cls( - *args, - begin=begin, - end=end, - total_steps=total_steps, - by_epoch=by_epoch, - **kwargs) + return cls(*args, + begin=begin, + end=end, + total_steps=total_steps, + by_epoch=by_epoch, + **kwargs) def _get_value(self): """Compute value using chainable form of the scheduler.""" @@ -1190,14 +1178,13 @@ def __init__(self, sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) ] - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -1220,13 +1207,12 @@ def build_iter_from_epoch(cls, begin = int(begin * epoch_length) if end != INF: end = int(end * epoch_length) - return cls( - *args, - periods=periods, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) + return cls(*args, + periods=periods, + begin=begin, + end=end, + by_epoch=by_epoch, + **kwargs) def _get_value(self): """Compute value using chainable form of the scheduler.""" @@ -1444,8 +1430,9 @@ def __init__(self, self.eps = eps self.monitor = monitor - self._init_is_better( - rule=rule, threshold=threshold, threshold_rule=threshold_rule) + self._init_is_better(rule=rule, + threshold=threshold, + threshold_rule=threshold_rule) self._reset() # remove call self.step() and init self._global_step = 0 diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py index 3de6798514..2856bacdf6 100644 --- a/mmengine/registry/build_functions.py +++ b/mmengine/registry/build_functions.py @@ -5,6 +5,7 @@ from mmengine.config import Config, ConfigDict from mmengine.utils import ManagerMixin, digit_version + from .registry import Registry if TYPE_CHECKING: diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index e7d8962be4..387b3e3d43 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -12,6 +12,7 @@ from mmengine.config.utils import MODULE2PACKAGE from mmengine.utils import get_object_from_string, is_seq_of + from .default_scope import DefaultScope diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py index eb9a225a91..06a4817ea0 100644 --- a/mmengine/registry/root.py +++ b/mmengine/registry/root.py @@ -41,8 +41,8 @@ # manage constructors that customize the optimization hyperparameters. OPTIM_WRAPPER_CONSTRUCTORS = Registry('optimizer wrapper constructor') # mangage all kinds of parameter schedulers like `MultiStepLR` -PARAM_SCHEDULERS = Registry( - 'parameter scheduler', build_func=build_scheduler_from_cfg) +PARAM_SCHEDULERS = Registry('parameter scheduler', + build_func=build_scheduler_from_cfg) # manage all kinds of metrics METRICS = Registry('metric') diff --git a/mmengine/registry/utils.py b/mmengine/registry/utils.py index 2737e879a7..66b1ac4cfa 100644 --- a/mmengine/registry/utils.py +++ b/mmengine/registry/utils.py @@ -6,6 +6,7 @@ from mmengine.fileio import dump from mmengine.logging import print_log + from . import root from .default_scope import DefaultScope from .registry import Registry @@ -85,8 +86,8 @@ def count_registered_modules(save_path: Optional[str] = None, scan_date=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), registries=registries_info) if verbose: - print_log( - f'Finish registry analysis, got: {scan_data}', logger='current') + print_log(f'Finish registry analysis, got: {scan_data}', + logger='current') if save_path is not None: json_path = osp.join(save_path, 'modules_statistic_results.json') dump(scan_data, json_path, indent=2) diff --git a/mmengine/runner/_flexible_runner.py b/mmengine/runner/_flexible_runner.py index 5160a5cfb0..3ad936c3be 100644 --- a/mmengine/runner/_flexible_runner.py +++ b/mmengine/runner/_flexible_runner.py @@ -26,6 +26,7 @@ from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.visualization import Visualizer + from .base_loop import BaseLoop from .checkpoint import find_latest_checkpoint from .log_processor import LogProcessor @@ -708,10 +709,9 @@ def build_visualizer( Visualizer: A Visualizer object build from ``visualizer``. """ if visualizer is None: - visualizer = dict( - name=self.experiment_name, - vis_backends=[dict(type='LocalVisBackend')], - save_dir=self.log_dir) + visualizer = dict(name=self.experiment_name, + vis_backends=[dict(type='LocalVisBackend')], + save_dir=self.log_dir) return Visualizer.get_instance(**visualizer) if isinstance(visualizer, Visualizer): @@ -833,9 +833,9 @@ def build_dataloader( sampler_cfg = dataloader_cfg.pop('sampler') if isinstance(sampler_cfg, dict): sampler_seed = None if diff_rank_seed else seed - sampler = DATA_SAMPLERS.build( - sampler_cfg, - default_args=dict(dataset=dataset, seed=sampler_seed)) + sampler = DATA_SAMPLERS.build(sampler_cfg, + default_args=dict(dataset=dataset, + seed=sampler_seed)) else: # fallback to raise error in dataloader # if `sampler_cfg` is not a valid type @@ -848,9 +848,8 @@ def build_dataloader( elif isinstance(batch_sampler_cfg, dict): batch_sampler = DATA_SAMPLERS.build( batch_sampler_cfg, - default_args=dict( - sampler=sampler, - batch_size=dataloader_cfg.pop('batch_size'))) + default_args=dict(sampler=sampler, + batch_size=dataloader_cfg.pop('batch_size'))) else: # fallback to raise error in dataloader # if `batch_sampler_cfg` is not a valid type @@ -955,18 +954,20 @@ def build_train_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: 'Only one of `type` or `by_epoch` can exist in `loop_cfg`.') if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, dataloader=self._train_dataloader)) + loop = LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._train_dataloader)) else: by_epoch = loop_cfg.pop('by_epoch') if by_epoch: - loop = EpochBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) + loop = EpochBasedTrainLoop(**loop_cfg, + runner=self, + dataloader=self._train_dataloader) else: - loop = IterBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) + loop = IterBasedTrainLoop(**loop_cfg, + runner=self, + dataloader=self._train_dataloader) return loop # type: ignore def build_val_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: @@ -997,18 +998,16 @@ def build_val_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: loop_cfg = copy.deepcopy(loop) if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator)) + loop = LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._val_dataloader, + evaluator=self._val_evaluator)) else: - loop = ValLoop( - **loop_cfg, - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator) # type: ignore + loop = ValLoop(**loop_cfg, + runner=self, + dataloader=self._val_dataloader, + evaluator=self._val_evaluator) # type: ignore return loop # type: ignore @@ -1039,18 +1038,16 @@ def build_test_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: loop_cfg = copy.deepcopy(loop) # type: ignore if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator)) + loop = LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._test_dataloader, + evaluator=self._test_evaluator)) else: - loop = TestLoop( - **loop_cfg, - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator) # type: ignore + loop = TestLoop(**loop_cfg, + runner=self, + dataloader=self._test_dataloader, + evaluator=self._test_evaluator) # type: ignore return loop # type: ignore @@ -1172,12 +1169,11 @@ def train(self) -> nn.Module: compile = copy.copy(self._compile) compile.setdefault('target', 'train_step') - dispatch_kwargs = dict( - epoch_length=len(self.train_dataloader), - max_epochs=self.max_epochs, - max_iters=self.max_iters, - train_micro_batch_size_per_gpu=_get_batch_size( - self.train_dataloader)) # type: ignore + dispatch_kwargs = dict(epoch_length=len(self.train_dataloader), + max_epochs=self.max_epochs, + max_iters=self.max_iters, + train_micro_batch_size_per_gpu=_get_batch_size( + self.train_dataloader)) # type: ignore self.strategy.prepare( self.model, @@ -1215,9 +1211,8 @@ def val(self) -> dict: self._val_loop = self.build_val_loop(self._val_loop) # type: ignore - dispatch_kwargs = dict( - init_weights_for_test_or_val=self.cfg.get( - 'init_weights_for_test_or_val', True)) + dispatch_kwargs = dict(init_weights_for_test_or_val=self.cfg.get( + 'init_weights_for_test_or_val', True)) self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs) self.model = self.strategy.model @@ -1242,9 +1237,8 @@ def test(self) -> dict: '`test_evaluator` arguments when initializing runner.') self._test_loop = self.build_test_loop(self._test_loop) # type: ignore - dispatch_kwargs = dict( - init_weights_for_test_or_val=self.cfg.get( - 'init_weights_for_test_or_val', True)) + dispatch_kwargs = dict(init_weights_for_test_or_val=self.cfg.get( + 'init_weights_for_test_or_val', True)) self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs) self.model = self.strategy.model @@ -1467,8 +1461,8 @@ def callback(checkpoint): # check whether the number of GPU used for current experiment # is consistent with resuming from checkpoint if 'config' in checkpoint['meta']: - config = mmengine.Config.fromstring( - checkpoint['meta']['config'], file_format='.py') + config = mmengine.Config.fromstring(checkpoint['meta']['config'], + file_format='.py') previous_gpu_ids = config.get('gpu_ids', None) if (previous_gpu_ids is not None and len(previous_gpu_ids) > 0 and len(previous_gpu_ids) != self.world_size): @@ -1525,12 +1519,11 @@ def load_checkpoint(self, def callback(checkpoint): self.call_hook('after_load_checkpoint', checkpoint=checkpoint) - self.strategy.load_checkpoint( - filename, - map_location=map_location, - strict=strict, - revise_keys=revise_keys, - callback=callback) + self.strategy.load_checkpoint(filename, + map_location=map_location, + strict=strict, + revise_keys=revise_keys, + callback=callback) def save_checkpoint( self, @@ -1596,8 +1589,8 @@ def save_checkpoint( filepath = join_path( # type: ignore out_dir, filename, backend_args=backend_args) - meta.update( - cfg=self.cfg.pretty_text, experiment_name=self.experiment_name) + meta.update(cfg=self.cfg.pretty_text, + experiment_name=self.experiment_name) if hasattr(self.train_dataloader.dataset, 'metainfo'): meta.update(dataset_meta=self.train_dataloader.dataset.metainfo) diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py index 198babc582..0bd23a1f84 100644 --- a/mmengine/runner/amp.py +++ b/mmengine/runner/amp.py @@ -138,8 +138,9 @@ def autocast(device_type: Optional[str] = None, elif device_type == 'musa': if dtype is None: dtype = torch.get_autocast_gpu_dtype() - with torch.musa.amp.autocast( - enabled=enabled, dtype=dtype, cache_enabled=cache_enabled): + with torch.musa.amp.autocast(enabled=enabled, + dtype=dtype, + cache_enabled=cache_enabled): yield return else: @@ -153,9 +154,8 @@ def autocast(device_type: Optional[str] = None, raise ValueError('User specified autocast device_type must be ' f'cuda or cpu, but got {device_type}') - with torch.autocast( - device_type=device_type, - enabled=enabled, - dtype=dtype, - cache_enabled=cache_enabled): + with torch.autocast(device_type=device_type, + enabled=enabled, + dtype=dtype, + cache_enabled=cache_enabled): yield diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 2bf5f50f7c..20c8f9c814 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -48,8 +48,8 @@ def _get_mmengine_home(): mmengine_home = os.path.expanduser( os.getenv( ENV_MMENGINE_HOME, - os.path.join( - os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmengine'))) + os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), + 'mmengine'))) mkdir_or_exist(mmengine_home) return mmengine_home @@ -344,7 +344,9 @@ def load_from_local(filename, map_location): filename = osp.expanduser(filename) if not osp.isfile(filename): raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + checkpoint = torch.load(filename, + map_location=map_location, + weights_only=False) return checkpoint @@ -368,19 +370,17 @@ def load_from_http(filename, """ rank, world_size = get_dist_info() if rank == 0: - checkpoint = load_url( - filename, - model_dir=model_dir, - map_location=map_location, - progress=progress) + checkpoint = load_url(filename, + model_dir=model_dir, + map_location=map_location, + progress=progress) if world_size > 1: torch.distributed.barrier() if rank > 0: - checkpoint = load_url( - filename, - model_dir=model_dir, - map_location=map_location, - progress=progress) + checkpoint = load_url(filename, + model_dir=model_dir, + map_location=map_location, + progress=progress) return checkpoint @@ -412,7 +412,9 @@ def load_from_pavi(filename, map_location=None): with TemporaryDirectory() as tmp_dir: downloaded_file = osp.join(tmp_dir, model.name) model.download(downloaded_file) - checkpoint = torch.load(downloaded_file, map_location=map_location) + checkpoint = torch.load(downloaded_file, + map_location=map_location, + weights_only=False) return checkpoint @@ -432,10 +434,12 @@ def load_from_ceph(filename, map_location=None, backend='petrel'): Returns: dict or OrderedDict: The loaded checkpoint. """ - file_backend = get_file_backend( - filename, backend_args={'backend': backend}) + file_backend = get_file_backend(filename, + backend_args={'backend': backend}) with io.BytesIO(file_backend.get(filename)) as buffer: - checkpoint = torch.load(buffer, map_location=map_location) + checkpoint = torch.load(buffer, + map_location=map_location, + weights_only=False) return checkpoint @@ -504,7 +508,9 @@ def load_from_openmmlab(filename, map_location=None): filename = osp.join(_get_mmengine_home(), model_url) if not osp.isfile(filename): raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + checkpoint = torch.load(filename, + map_location=map_location, + weights_only=False) return checkpoint @@ -522,8 +528,8 @@ def load_from_mmcls(filename, map_location=None): model_urls = get_mmcls_models() model_name = filename[8:] - checkpoint = load_from_http( - model_urls[model_name], map_location=map_location) + checkpoint = load_from_http(model_urls[model_name], + map_location=map_location) checkpoint = _process_mmcls_checkpoint(checkpoint) return checkpoint @@ -597,9 +603,10 @@ def _load_checkpoint_to_model(model, # strip prefix of state_dict metadata = getattr(state_dict, '_metadata', OrderedDict()) for p, r in revise_keys: - state_dict = OrderedDict( - {re.sub(p, r, k): v - for k, v in state_dict.items()}) + state_dict = OrderedDict({ + re.sub(p, r, k): v + for k, v in state_dict.items() + }) # Keep metadata in state_dict state_dict._metadata = metadata @@ -720,8 +727,10 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False): module._save_to_state_dict(destination, prefix, keep_vars) for name, child in module._modules.items(): if child is not None: - get_state_dict( - child, destination, prefix + name + '.', keep_vars=keep_vars) + get_state_dict(child, + destination, + prefix + name + '.', + keep_vars=keep_vars) for hook in module._state_dict_hooks.values(): hook_result = hook(module, destination, prefix, local_metadata) if hook_result is not None: @@ -783,8 +792,8 @@ def save_checkpoint(checkpoint, else: file_client = FileClient.infer_client(file_client_args, filename) if file_client_args is None: - file_backend = get_file_backend( - filename, backend_args=backend_args) + file_backend = get_file_backend(filename, + backend_args=backend_args) else: file_backend = file_client diff --git a/mmengine/runner/log_processor.py b/mmengine/runner/log_processor.py index 98183ae317..404000f510 100644 --- a/mmengine/runner/log_processor.py +++ b/mmengine/runner/log_processor.py @@ -301,10 +301,9 @@ def get_log_after_epoch(self, dict(data_src='time', window_size='epoch', method_name='mean')) if 'data_time' not in custom_keys: custom_cfg_copy.append( - dict( - data_src='data_time', - window_size='epoch', - method_name='mean')) + dict(data_src='data_time', + window_size='epoch', + method_name='mean')) parsed_cfg = self._parse_windows_size(runner, batch_idx, custom_cfg_copy) # tag is used to write log information to different backends. diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 5a678db7b9..b7ae43b7b5 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -7,11 +7,13 @@ import torch from torch.utils.data import DataLoader +from mmengine.dataset.sampler import InfiniteSampler from mmengine.evaluator import Evaluator from mmengine.logging import HistoryBuffer, print_log from mmengine.registry import LOOPS from mmengine.structures import BaseDataElement from mmengine.utils import is_list_of + from .amp import autocast from .base_loop import BaseLoop from .utils import calc_dynamic_intervals @@ -123,19 +125,19 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: Args: data_batch (Sequence[dict]): Batch of data from dataloader. """ - self.runner.call_hook( - 'before_train_iter', batch_idx=idx, data_batch=data_batch) + self.runner.call_hook('before_train_iter', + batch_idx=idx, + data_batch=data_batch) # Enable gradient accumulation mode and avoid unnecessary gradient # synchronization during gradient accumulation process. # outputs should be a dict of loss. outputs = self.runner.model.train_step( data_batch, optim_wrapper=self.runner.optim_wrapper) - self.runner.call_hook( - 'after_train_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) + self.runner.call_hook('after_train_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) self._iter += 1 def _decide_current_val_interval(self) -> None: @@ -274,7 +276,8 @@ def run(self) -> None: # In iteration-based training loop, we treat the whole training process # as a big epoch and execute the corresponding hook. self.runner.call_hook('before_train_epoch') - if self._iter > 0: + if self._iter > 0 and not isinstance(self.dataloader.sampler, + InfiniteSampler): print_log( f'Advance dataloader {self._iter} steps to skip data ' 'that has already been trained', @@ -305,19 +308,19 @@ def run_iter(self, data_batch: Sequence[dict]) -> None: Args: data_batch (Sequence[dict]): Batch of data from dataloader. """ - self.runner.call_hook( - 'before_train_iter', batch_idx=self._iter, data_batch=data_batch) + self.runner.call_hook('before_train_iter', + batch_idx=self._iter, + data_batch=data_batch) # Enable gradient accumulation mode and avoid unnecessary gradient # synchronization during gradient accumulation process. # outputs should be a dict of loss. outputs = self.runner.model.train_step( data_batch, optim_wrapper=self.runner.optim_wrapper) - self.runner.call_hook( - 'after_train_iter', - batch_idx=self._iter, - data_batch=data_batch, - outputs=outputs) + self.runner.call_hook('after_train_iter', + batch_idx=self._iter, + data_batch=data_batch, + outputs=outputs) self._iter += 1 def _decide_current_val_interval(self) -> None: @@ -397,8 +400,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]): data_batch (Sequence[dict]): Batch of data from dataloader. """ - self.runner.call_hook( - 'before_val_iter', batch_idx=idx, data_batch=data_batch) + self.runner.call_hook('before_val_iter', + batch_idx=idx, + data_batch=data_batch) # outputs should be sequence of BaseDataElement with autocast(enabled=self.fp16): outputs = self.runner.model.val_step(data_batch) @@ -406,11 +410,10 @@ def run_iter(self, idx, data_batch: Sequence[dict]): outputs, self.val_loss = _update_losses(outputs, self.val_loss) self.evaluator.process(data_samples=outputs, data_batch=data_batch) - self.runner.call_hook( - 'after_val_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) + self.runner.call_hook('after_val_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) @LOOPS.register_module() @@ -480,8 +483,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: Args: data_batch (Sequence[dict]): Batch of data from dataloader. """ - self.runner.call_hook( - 'before_test_iter', batch_idx=idx, data_batch=data_batch) + self.runner.call_hook('before_test_iter', + batch_idx=idx, + data_batch=data_batch) # predictions should be sequence of BaseDataElement with autocast(enabled=self.fp16): outputs = self.runner.model.test_step(data_batch) @@ -489,11 +493,10 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: outputs, self.test_loss = _update_losses(outputs, self.test_loss) self.evaluator.process(data_samples=outputs, data_batch=data_batch) - self.runner.call_hook( - 'after_test_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) + self.runner.call_hook('after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) def _parse_losses(losses: Dict[str, HistoryBuffer], diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 7d1f655aad..764d6e7d46 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import inspect import logging import os import os.path as osp @@ -41,6 +42,7 @@ from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env, set_multi_processing) from mmengine.visualization import Visualizer + from .activation_checkpointing import turn_on_activation_checkpointing from .base_loop import BaseLoop from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, @@ -428,8 +430,8 @@ def __init__( model.setdefault('data_preprocessor', data_preprocessor) self.model = self.build_model(model) # wrap model - self.model = self.wrap_model( - self.cfg.get('model_wrapper_cfg'), self.model) + self.model = self.wrap_model(self.cfg.get('model_wrapper_cfg'), + self.model) # get model name from the model class if hasattr(self.model, 'module'): @@ -713,10 +715,9 @@ def set_randomness(self, more details. """ self._deterministic = deterministic - self._seed = set_random_seed( - seed=seed, - deterministic=deterministic, - diff_rank_seed=diff_rank_seed) + self._seed = set_random_seed(seed=seed, + deterministic=deterministic, + diff_rank_seed=diff_rank_seed) def build_logger(self, log_level: Union[int, str] = 'INFO', @@ -787,10 +788,9 @@ def build_visualizer( Visualizer: A Visualizer object build from ``visualizer``. """ if visualizer is None: - visualizer = dict( - name=self._experiment_name, - vis_backends=[dict(type='LocalVisBackend')], - save_dir=self._log_dir) + visualizer = dict(name=self._experiment_name, + vis_backends=[dict(type='LocalVisBackend')], + save_dir=self._log_dir) return Visualizer.get_instance(**visualizer) if isinstance(visualizer, Visualizer): @@ -902,16 +902,28 @@ def wrap_model( find_unused_parameters=find_unused_parameters) else: model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel') - model_wrapper_type = MODEL_WRAPPERS.get( - model_wrapper_cfg.get('type')) # type: ignore + + model_wrapper_type = model_wrapper_cfg.get('type') + if isinstance(model_wrapper_type, str): + model_wrapper_type = MODEL_WRAPPERS.get( + model_wrapper_type) # type: ignore + elif inspect.isclass(model_wrapper_type): + pass + else: + raise KeyError( + f'{model_wrapper_type} is not in the ' + 'registry. Please check whether the value of ' + f'`{model_wrapper_type}` is correct or it was registered ' + 'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501 + ) default_args: dict = dict() if issubclass( model_wrapper_type, # type: ignore DistributedDataParallel): default_args['device_ids'] = [int(os.environ['LOCAL_RANK'])] default_args['module'] = model - model = MODEL_WRAPPERS.build( - model_wrapper_cfg, default_args=default_args) + model = MODEL_WRAPPERS.build(model_wrapper_cfg, + default_args=default_args) return model def _init_model_weights(self) -> None: @@ -1176,11 +1188,11 @@ def _build_param_scheduler( 'Use the max epochs/iters of train loop as default.') param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict( - optimizer=optim_wrapper, - epoch_length=len(self.train_dataloader)))) + PARAM_SCHEDULERS.build(_scheduler, + default_args=dict( + optimizer=optim_wrapper, + epoch_length=len( + self.train_dataloader)))) else: raise TypeError( 'scheduler should be a _ParamScheduler object or dict, ' @@ -1378,18 +1390,17 @@ def build_dataloader(dataloader: Union[DataLoader, Dict], num_batch_per_epoch = dataloader_cfg.pop('num_batch_per_epoch', None) if num_batch_per_epoch is not None: world_size = get_world_size() - num_samples = ( - num_batch_per_epoch * _get_batch_size(dataloader_cfg) * - world_size) + num_samples = (num_batch_per_epoch * + _get_batch_size(dataloader_cfg) * world_size) dataset = _SlicedDataset(dataset, num_samples) # build sampler sampler_cfg = dataloader_cfg.pop('sampler') if isinstance(sampler_cfg, dict): sampler_seed = None if diff_rank_seed else seed - sampler = DATA_SAMPLERS.build( - sampler_cfg, - default_args=dict(dataset=dataset, seed=sampler_seed)) + sampler = DATA_SAMPLERS.build(sampler_cfg, + default_args=dict(dataset=dataset, + seed=sampler_seed)) else: # fallback to raise error in dataloader # if `sampler_cfg` is not a valid type @@ -1402,9 +1413,8 @@ def build_dataloader(dataloader: Union[DataLoader, Dict], elif isinstance(batch_sampler_cfg, dict): batch_sampler = DATA_SAMPLERS.build( batch_sampler_cfg, - default_args=dict( - sampler=sampler, - batch_size=dataloader_cfg.pop('batch_size'))) + default_args=dict(sampler=sampler, + batch_size=dataloader_cfg.pop('batch_size'))) else: # fallback to raise error in dataloader # if `batch_sampler_cfg` is not a valid type @@ -1517,18 +1527,20 @@ def build_train_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: 'Only one of `type` or `by_epoch` can exist in `loop_cfg`.') if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, dataloader=self._train_dataloader)) + loop = LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._train_dataloader)) else: by_epoch = loop_cfg.pop('by_epoch') if by_epoch: - loop = EpochBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) + loop = EpochBasedTrainLoop(**loop_cfg, + runner=self, + dataloader=self._train_dataloader) else: - loop = IterBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) + loop = IterBasedTrainLoop(**loop_cfg, + runner=self, + dataloader=self._train_dataloader) return loop # type: ignore def build_val_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: @@ -1559,18 +1571,16 @@ def build_val_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: loop_cfg = copy.deepcopy(loop) if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator)) + loop = LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._val_dataloader, + evaluator=self._val_evaluator)) else: - loop = ValLoop( - **loop_cfg, - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator) # type: ignore + loop = ValLoop(**loop_cfg, + runner=self, + dataloader=self._val_dataloader, + evaluator=self._val_evaluator) # type: ignore return loop # type: ignore @@ -1601,18 +1611,16 @@ def build_test_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: loop_cfg = copy.deepcopy(loop) # type: ignore if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator)) + loop = LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._test_dataloader, + evaluator=self._test_evaluator)) else: - loop = TestLoop( - **loop_cfg, - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator) # type: ignore + loop = TestLoop(**loop_cfg, + runner=self, + dataloader=self._test_dataloader, + evaluator=self._test_evaluator) # type: ignore return loop # type: ignore @@ -1838,7 +1846,7 @@ def call_hook(self, fn_name: str, **kwargs) -> None: try: getattr(hook, fn_name)(self, **kwargs) except TypeError as e: - raise TypeError(f'{e} in {hook}') from None + raise TypeError(f'{e} in {hook}') from e def register_hook( self, @@ -2016,8 +2024,8 @@ def resume(self, device = get_device() checkpoint = self.load_checkpoint(filename, map_location=device) else: - checkpoint = self.load_checkpoint( - filename, map_location=map_location) + checkpoint = self.load_checkpoint(filename, + map_location=map_location) self.train_loop._epoch = checkpoint['meta']['epoch'] self.train_loop._iter = checkpoint['meta']['iter'] @@ -2025,8 +2033,8 @@ def resume(self, # check whether the number of GPU used for current experiment # is consistent with resuming from checkpoint if 'config' in checkpoint['meta']: - config = mmengine.Config.fromstring( - checkpoint['meta']['config'], file_format='.py') + config = mmengine.Config.fromstring(checkpoint['meta']['config'], + file_format='.py') previous_gpu_ids = config.get('gpu_ids', None) if (previous_gpu_ids is not None and len(previous_gpu_ids) > 0 and len(previous_gpu_ids) != self._world_size): @@ -2134,8 +2142,10 @@ def load_checkpoint(self, else: model = self.model - checkpoint = _load_checkpoint_to_model( - model, checkpoint, strict, revise_keys=revise_keys) + checkpoint = _load_checkpoint_to_model(model, + checkpoint, + strict, + revise_keys=revise_keys) self._has_loaded = True @@ -2211,12 +2221,11 @@ def save_checkpoint( filepath = join_path( # type: ignore out_dir, filename, backend_args=backend_args) - meta.update( - cfg=self.cfg.pretty_text, - seed=self.seed, - experiment_name=self.experiment_name, - time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), - mmengine_version=mmengine.__version__ + get_git_hash()) + meta.update(cfg=self.cfg.pretty_text, + seed=self.seed, + experiment_name=self.experiment_name, + time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), + mmengine_version=mmengine.__version__ + get_git_hash()) if hasattr(self.train_dataloader.dataset, 'metainfo'): meta.update(dataset_meta=self.train_dataloader.dataset.metainfo) @@ -2268,11 +2277,10 @@ def save_checkpoint( checkpoint['param_schedulers'].append(state_dict) self.call_hook('before_save_checkpoint', checkpoint=checkpoint) - save_checkpoint( - checkpoint, - filepath, - file_client_args=file_client_args, - backend_args=backend_args) + save_checkpoint(checkpoint, + filepath, + file_client_args=file_client_args, + backend_args=backend_args) @master_only def dump_config(self) -> None: diff --git a/mmengine/structures/base_data_element.py b/mmengine/structures/base_data_element.py index 8ac5a3d27d..da27a4b16e 100644 --- a/mmengine/structures/base_data_element.py +++ b/mmengine/structures/base_data_element.py @@ -395,8 +395,10 @@ def __setattr__(self, name: str, value: Any): raise AttributeError(f'{name} has been used as a ' 'private attribute, which is immutable.') else: - self.set_field( - name=name, value=value, field_type='data', dtype=None) + self.set_field(name=name, + value=value, + field_type='data', + dtype=None) def __delattr__(self, item: str): """Delete the item in dataelement. diff --git a/mmengine/structures/instance_data.py b/mmengine/structures/instance_data.py index 8633b86037..e841a4d73a 100644 --- a/mmengine/structures/instance_data.py +++ b/mmengine/structures/instance_data.py @@ -7,6 +7,7 @@ import torch from mmengine.device import get_device + from .base_data_element import BaseDataElement BoolTypeTensor: Union[Any] diff --git a/mmengine/testing/compare.py b/mmengine/testing/compare.py index 14c7a97ba7..549fbe64ef 100644 --- a/mmengine/testing/compare.py +++ b/mmengine/testing/compare.py @@ -42,18 +42,20 @@ def assert_allclose( """ if 'parrots' not in TORCH_VERSION and \ digit_version(TORCH_VERSION) >= digit_version('1.6'): - _assert_allclose( - actual, - expected, - rtol=rtol, - atol=atol, - equal_nan=equal_nan, - msg=msg) + _assert_allclose(actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + msg=msg) else: # torch.testing.assert_allclose has no ``msg`` argument # when PyTorch < 1.6 - _assert_allclose( - actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan) + _assert_allclose(actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan) def check_python_script(cmd): @@ -180,8 +182,8 @@ def assert_params_all_zeros(module) -> bool: if hasattr(module, 'bias') and module.bias is not None: bias_data = module.bias.data - is_bias_zero = bias_data.allclose( - bias_data.new_zeros(bias_data.size())) + is_bias_zero = bias_data.allclose(bias_data.new_zeros( + bias_data.size())) else: is_bias_zero = True diff --git a/mmengine/testing/runner_test_case.py b/mmengine/testing/runner_test_case.py index f64594acef..c1dea6bdb4 100644 --- a/mmengine/testing/runner_test_case.py +++ b/mmengine/testing/runner_test_case.py @@ -91,12 +91,11 @@ class RunnerTestCase(TestCase): 3. Provides `build_runner` method to build runner easily. 4. Clean the global variable used by the runner. """ - dist_cfg = dict( - MASTER_ADDR='127.0.0.1', - MASTER_PORT=29600, - RANK='0', - WORLD_SIZE='1', - LOCAL_RANK='0') + dist_cfg = dict(MASTER_ADDR='127.0.0.1', + MASTER_PORT=29600, + RANK='0', + WORLD_SIZE='1', + LOCAL_RANK='0') def setUp(self) -> None: self.temp_dir = tempfile.TemporaryDirectory() @@ -108,22 +107,22 @@ def setUp(self) -> None: epoch_based_cfg = dict( work_dir=self.temp_dir.name, model=dict(type='ToyModel'), - train_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=3, - num_workers=0), - val_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), + train_dataloader=dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', + shuffle=True), + batch_size=3, + num_workers=0), + val_dataloader=dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', + shuffle=False), + batch_size=3, + num_workers=0), val_evaluator=[dict(type='ToyMetric')], - test_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), + test_dataloader=dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', + shuffle=False), + batch_size=3, + num_workers=0), test_evaluator=[dict(type='ToyMetric')], optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)), train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1), @@ -145,10 +144,12 @@ def setUp(self) -> None: self.iter_based_cfg.log_processor = dict(by_epoch=False) self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12) - self.iter_based_cfg.default_hooks = dict( - logger=dict(type='LoggerHook', interval=1), - checkpoint=dict( - type='CheckpointHook', interval=12, by_epoch=False)) + self.iter_based_cfg.default_hooks = dict(logger=dict(type='LoggerHook', + interval=1), + checkpoint=dict( + type='CheckpointHook', + interval=12, + by_epoch=False)) def tearDown(self): # `FileHandler` should be closed in Windows, otherwise we cannot diff --git a/mmengine/utils/dl_utils/collect_env.py b/mmengine/utils/dl_utils/collect_env.py index 0ee99abad2..83882425c8 100644 --- a/mmengine/utils/dl_utils/collect_env.py +++ b/mmengine/utils/dl_utils/collect_env.py @@ -11,6 +11,7 @@ import mmengine from mmengine.device import is_cuda_available, is_musa_available + from .parrots_wrapper import TORCH_VERSION, get_build_config, is_rocm_pytorch @@ -77,8 +78,8 @@ def collect_env(): if CUDA_HOME == '/opt/rocm': try: nvcc = osp.join(CUDA_HOME, 'hip/bin/hipcc') - nvcc = subprocess.check_output( - f'"{nvcc}" --version', shell=True) + nvcc = subprocess.check_output(f'"{nvcc}" --version', + shell=True) nvcc = nvcc.decode('utf-8').strip() release = nvcc.rfind('HIP version:') build = nvcc.rfind('') @@ -134,8 +135,9 @@ def collect_env(): from distutils.ccompiler import new_compiler ccompiler = new_compiler() ccompiler.initialize() - cc = subprocess.check_output( - f'{ccompiler.cc}', stderr=subprocess.STDOUT, shell=True) + cc = subprocess.check_output(f'{ccompiler.cc}', + stderr=subprocess.STDOUT, + shell=True) encoding = os.device_encoding( sys.stdout.fileno()) or locale.getpreferredencoding() env_info['MSVC'] = cc.decode(encoding).partition('\n')[0].strip() diff --git a/mmengine/utils/dl_utils/hub.py b/mmengine/utils/dl_utils/hub.py index 7f7f1a087d..41deaa0b1a 100644 --- a/mmengine/utils/dl_utils/hub.py +++ b/mmengine/utils/dl_utils/hub.py @@ -107,8 +107,10 @@ def load_url(url, if check_hash: r = HASH_REGEX.search(filename) # r is Optional[Match[str]] hash_prefix = r.group(1) if r else None - download_url_to_file( - url, cached_file, hash_prefix, progress=progress) + download_url_to_file(url, + cached_file, + hash_prefix, + progress=progress) if _is_legacy_zip_format(cached_file): return _legacy_zip_load(cached_file, model_dir, map_location) diff --git a/mmengine/utils/dl_utils/torch_ops.py b/mmengine/utils/dl_utils/torch_ops.py index 2550ae6986..85dc3100d2 100644 --- a/mmengine/utils/dl_utils/torch_ops.py +++ b/mmengine/utils/dl_utils/torch_ops.py @@ -4,9 +4,9 @@ from ..version_utils import digit_version from .parrots_wrapper import TORCH_VERSION -_torch_version_meshgrid_indexing = ( - 'parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0')) +_torch_version_meshgrid_indexing = ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) + >= digit_version('1.10.0a0')) def torch_meshgrid(*tensors): diff --git a/mmengine/utils/dl_utils/visualize.py b/mmengine/utils/dl_utils/visualize.py index f3361e1d50..6f7b05e095 100644 --- a/mmengine/utils/dl_utils/visualize.py +++ b/mmengine/utils/dl_utils/visualize.py @@ -49,11 +49,11 @@ def fake_run(cfg): cfg.pop('test_cfg') extra_cfg = dict( model=dict(type='ToyModel'), - visualizer=dict( - type='Visualizer', - vis_backends=[ - dict(type='TensorboardVisBackend', save_dir='temp_dir') - ]), + visualizer=dict(type='Visualizer', + vis_backends=[ + dict(type='TensorboardVisBackend', + save_dir='temp_dir') + ]), ) cfg.merge_from_dict(extra_cfg) # build the runner from config diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py index 1816f47f07..3d78194a5d 100644 --- a/mmengine/utils/package_utils.py +++ b/mmengine/utils/package_utils.py @@ -1,6 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp import subprocess +from typing import Any + +# Import distribution function with fallback for older Python versions +try: + from importlib.metadata import PackageNotFoundError, distribution +except ImportError: + from importlib_metadata import ( # type: ignore[import-untyped, no-redef, import-not-found] # noqa: E501 + PackageNotFoundError, distribution) def is_installed(package: str) -> bool: @@ -9,21 +17,21 @@ def is_installed(package: str) -> bool: Args: package (str): Name of package to be checked. """ - # When executing `import mmengine.runner`, - # pkg_resources will be imported and it takes too much time. - # Therefore, import it in function scope to save time. + # Use importlib.metadata instead of deprecated pkg_resources + # importlib.metadata is available in Python 3.8+ + # For Python 3.7, importlib_metadata backport can be used import importlib.util - import pkg_resources - from pkg_resources import get_distribution + import pkg_resources # type: ignore # refresh the pkg_resources # more datails at https://github.com/pypa/setuptools/issues/373 importlib.reload(pkg_resources) try: - get_distribution(package) + distribution(package) return True - except pkg_resources.DistributionNotFound: + except Exception: + # If distribution not found, check if module can be imported spec = importlib.util.find_spec(package) if spec is None: return False @@ -45,15 +53,31 @@ def get_installed_path(package: str) -> str: """ import importlib.util - from pkg_resources import DistributionNotFound, get_distribution - # if the package name is not the same as module name, module name should be # inferred. For example, mmcv-full is the package name, but mmcv is module # name. If we want to get the installed path of mmcv-full, we should concat # the pkg.location and module name try: - pkg = get_distribution(package) - except DistributionNotFound as e: + dist = distribution(package) + # In importlib.metadata, we use dist.locate_file() or files + if hasattr(dist, 'locate_file'): + # Python 3.9+ + # locate_file returns PathLike, need to access parent + locate_result: Any = dist.locate_file('') + location = str(locate_result.parent) + elif hasattr(dist, '_path'): + # Python 3.8 - _path is a pathlib.Path object + # We know _path exists because we checked with hasattr + dist_any: Any = dist + location = str(dist_any._path.parent) # type: ignore[attr-defined] + else: + # Fallback: try to find via importlib + spec = importlib.util.find_spec(package) + if spec is not None and spec.origin is not None: + return osp.dirname(spec.origin) + raise RuntimeError( + f'Cannot determine installation path for {package}') + except PackageNotFoundError as e: # if the package is not installed, package path set in PYTHONPATH # can be detected by `find_spec` spec = importlib.util.find_spec(package) @@ -69,23 +93,26 @@ def get_installed_path(package: str) -> str: else: raise e - possible_path = osp.join(pkg.location, package) # type: ignore + possible_path = osp.join(location, package) if osp.exists(possible_path): return possible_path else: - return osp.join(pkg.location, package2module(package)) # type: ignore + return osp.join(location, package2module(package)) -def package2module(package: str): +def package2module(package: str) -> str: """Infer module name from package. Args: package (str): Package to infer module name. """ - from pkg_resources import get_distribution - pkg = get_distribution(package) - if pkg.has_metadata('top_level.txt'): - module_name = pkg.get_metadata('top_level.txt').split('\n')[0] + dist = distribution(package) + + # In importlib.metadata, + # top-level modules are in dist.read_text('top_level.txt') + top_level_text = dist.read_text('top_level.txt') + if top_level_text: + module_name = top_level_text.split('\n')[0] return module_name else: raise ValueError(f'can not infer the module name of {package}') diff --git a/mmengine/utils/progressbar_rich.py b/mmengine/utils/progressbar_rich.py index f8e04d8041..44162e2160 100644 --- a/mmengine/utils/progressbar_rich.py +++ b/mmengine/utils/progressbar_rich.py @@ -121,8 +121,9 @@ def track_progress_rich(func: Callable, ) worker = _Worker(func) - task_id = prog_bar.add_task( - total=task_num, color=color, description=description) + task_id = prog_bar.add_task(total=task_num, + color=color, + description=description) tasks = _tasks_with_index(tasks) # Use single process when nproc is 1, else use multiprocess. diff --git a/mmengine/utils/version_utils.py b/mmengine/utils/version_utils.py index 620180547a..2e02ecddd4 100644 --- a/mmengine/utils/version_utils.py +++ b/mmengine/utils/version_utils.py @@ -58,9 +58,10 @@ def _minimal_ext_cmd(cmd): env['LANGUAGE'] = 'C' env['LANG'] = 'C' env['LC_ALL'] = 'C' - out, err = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - env=env).communicate() + out, err = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env).communicate() return out diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index b752ec85a7..3a350daf07 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -437,8 +437,8 @@ def add_config(self, config: Config, **kwargs) -> None: """ assert isinstance(self._init_kwargs, dict) allow_val_change = self._init_kwargs.get('allow_val_change', False) - self._wandb.config.update( - config.to_dict(), allow_val_change=allow_val_change) + self._wandb.config.update(config.to_dict(), + allow_val_change=allow_val_change) self._wandb.run.log_code(name=self._log_code_name) @force_init_env @@ -604,7 +604,8 @@ def add_scalar(self, (int, float, torch.Tensor, np.ndarray, np.number)): self._tensorboard.add_scalar(name, value, step) else: - warnings.warn(f'Got {type(value)}, but numpy array, torch tensor, ' + warnings.warn(f'Got type {type(value)} with name {name}, ' + 'but numpy array, torch tensor, ' f'int or float are expected. skip it!') @force_init_env @@ -938,8 +939,10 @@ def add_image(self, should be RGB. step (int): Global step value to record. Defaults to 0. """ - self._logger.report_image( - title=name, series=name, iteration=step, image=image) + self._logger.report_image(title=name, + series=name, + iteration=step, + image=image) @force_init_env def add_scalar(self, @@ -954,8 +957,10 @@ def add_scalar(self, value (int, float, torch.Tensor, np.ndarray): Value to save. step (int): Global step value to record. Defaults to 0. """ - self._logger.report_scalar( - title=name, series=name, value=value, iteration=step) + self._logger.report_scalar(title=name, + series=name, + value=value, + iteration=step) @force_init_env def add_scalars(self, @@ -975,8 +980,10 @@ def add_scalars(self, assert 'step' not in scalar_dict, 'Please set it directly ' \ 'through the step parameter' for key, value in scalar_dict.items(): - self._logger.report_scalar( - title=key, series=key, value=value, iteration=step) + self._logger.report_scalar(title=key, + series=key, + value=value, + iteration=step) def close(self) -> None: """Close the clearml.""" @@ -1092,8 +1099,9 @@ def add_image(self, # values in the array need to be in the [0, 1] range img = image.astype(np.float32) / 255.0 - self._neptune['images'].append( - File.as_image(img), name=name, step=step) + self._neptune['images'].append(File.as_image(img), + name=name, + step=step) @force_init_env def add_scalar(self, diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index 6979395aca..e1525a86e3 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -271,9 +271,8 @@ def show(self, # will be updated with `win_name`. cv2.namedWindow(winname=f'{id(self)}') cv2.setWindowTitle(f'{id(self)}', win_name) - cv2.imshow( - str(id(self)), - self.get_image() if drawn_img is None else drawn_img) + cv2.imshow(str(id(self)), + self.get_image() if drawn_img is None else drawn_img) cv2.waitKey(int(np.ceil(wait_time * 1000))) else: raise ValueError('backend should be "matplotlib" or "cv2", ' @@ -300,10 +299,9 @@ def set_image(self, image: np.ndarray) -> None: # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) self.ax_save.cla() self.ax_save.axis(False) - self.ax_save.imshow( - image, - extent=(0, self.width, self.height, 0), - interpolation='none') + self.ax_save.imshow(image, + extent=(0, self.width, self.height, 0), + interpolation='none') @master_only def get_image(self) -> np.ndarray: @@ -344,14 +342,16 @@ def _init_manager(self, win_name: str) -> None: from matplotlib.figure import Figure from matplotlib.pyplot import new_figure_manager if getattr(self, 'manager', None) is None: - self.manager = new_figure_manager( - num=1, FigureClass=Figure, **self.fig_show_cfg) + self.manager = new_figure_manager(num=1, + FigureClass=Figure, + **self.fig_show_cfg) try: self.manager.set_window_title(win_name) except Exception: - self.manager = new_figure_manager( - num=1, FigureClass=Figure, **self.fig_show_cfg) + self.manager = new_figure_manager(num=1, + FigureClass=Figure, + **self.fig_show_cfg) self.manager.set_window_title(win_name) @master_only @@ -413,8 +413,11 @@ def draw_points(self, 'The shape of `positions` should be (N, 2), ' f'but got {positions.shape}') colors = color_val_matplotlib(colors) # type: ignore - self.ax_save.scatter( - positions[:, 0], positions[:, 1], c=colors, s=sizes, marker=marker) + self.ax_save.scatter(positions[:, 0], + positions[:, 1], + c=colors, + s=sizes, + marker=marker) return self @master_only @@ -616,11 +619,10 @@ def draw_lines( warnings.warn( 'Warning: The line is out of bounds,' ' the drawn line may not be in the image', UserWarning) - line_collect = LineCollection( - lines.tolist(), - colors=colors, - linestyles=line_styles, - linewidths=line_widths) + line_collect = LineCollection(lines.tolist(), + colors=colors, + linestyles=line_styles, + linewidths=line_widths) self.ax_save.add_collection(line_collect) return self @@ -676,10 +678,9 @@ def draw_circles( assert center.shape == (radius.shape[0], 2), ( 'The shape of `center` should be (radius.shape, 2), ' f'but got {center.shape}') - if not (self._is_posion_valid(center - - np.tile(radius.reshape((-1, 1)), (1, 2))) - and self._is_posion_valid( - center + np.tile(radius.reshape((-1, 1)), (1, 2)))): + if not (self._is_posion_valid(center - np.tile(radius.reshape( + (-1, 1)), (1, 2))) and self._is_posion_valid( + center + np.tile(radius.reshape((-1, 1)), (1, 2)))): warnings.warn( 'Warning: The circle is out of bounds,' ' the drawn circle may not be in the image', UserWarning) @@ -698,13 +699,12 @@ def draw_circles( min(max(linewidth, 1), self._default_font_size / 4) for linewidth in line_widths ] - p = PatchCollection( - circles, - alpha=alpha, - facecolors=face_colors, - edgecolors=edge_colors, - linewidths=line_widths, - linestyles=line_styles) + p = PatchCollection(circles, + alpha=alpha, + facecolors=face_colors, + edgecolors=edge_colors, + linewidths=line_widths, + linestyles=line_styles) self.ax_save.add_collection(p) return self @@ -754,8 +754,9 @@ def draw_bboxes( assert bboxes.shape[-1] == 4, ( f'The shape of `bboxes` should be (N, 4), but got {bboxes.shape}') - assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] <= - bboxes[:, 3]).all() + assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] + <= bboxes[:, + 3]).all() if not self._is_posion_valid(bboxes.reshape((-1, 2, 2))): warnings.warn( 'Warning: The bbox is out of bounds,' @@ -765,13 +766,12 @@ def draw_bboxes( bboxes[:, 2], bboxes[:, 3], bboxes[:, 0], bboxes[:, 3]), axis=-1).reshape(-1, 4, 2) poly = [p for p in poly] - return self.draw_polygons( - poly, - alpha=alpha, - edge_colors=edge_colors, - line_styles=line_styles, - line_widths=line_widths, - face_colors=face_colors) + return self.draw_polygons(poly, + alpha=alpha, + edge_colors=edge_colors, + line_styles=line_styles, + line_widths=line_widths, + face_colors=face_colors) @master_only def draw_polygons( @@ -837,13 +837,12 @@ def draw_polygons( min(max(linewidth, 1), self._default_font_size / 4) for linewidth in line_widths ] - polygon_collection = PolyCollection( - polygons, - alpha=alpha, - facecolor=face_colors, - linestyles=line_styles, - edgecolors=edge_colors, - linewidths=line_widths) + polygon_collection = PolyCollection(polygons, + alpha=alpha, + facecolor=face_colors, + linestyles=line_styles, + edgecolors=edge_colors, + linewidths=line_widths) self.ax_save.add_collection(polygon_collection) return self @@ -903,14 +902,14 @@ def draw_binary_masks( rgb = np.zeros_like(img) rgb[...] = color rgb = cv2.bitwise_and(rgb, rgb, mask=binary_mask) - img_complement = cv2.bitwise_and( - img, img, mask=binary_mask_complement) + img_complement = cv2.bitwise_and(img, + img, + mask=binary_mask_complement) rgb = rgb + img_complement img = cv2.addWeighted(img, 1 - alpha, rgb, alpha, 0) - self.ax_save.imshow( - img, - extent=(0, self.width, self.height, 0), - interpolation='nearest') + self.ax_save.imshow(img, + extent=(0, self.width, self.height, 0), + interpolation='nearest') return self @staticmethod @@ -991,18 +990,16 @@ def draw_featmap(featmap: torch.Tensor, f'the feature map will be interpolated. ' f'This may cause mismatch problems !') if resize_shape is None: - featmap = F.interpolate( - featmap[None], - overlaid_image.shape[:2], - mode='bilinear', - align_corners=False)[0] + featmap = F.interpolate(featmap[None], + overlaid_image.shape[:2], + mode='bilinear', + align_corners=False)[0] if resize_shape is not None: - featmap = F.interpolate( - featmap[None], - resize_shape, - mode='bilinear', - align_corners=False)[0] + featmap = F.interpolate(featmap[None], + resize_shape, + mode='bilinear', + align_corners=False)[0] if overlaid_image is not None: overlaid_image = cv2.resize(overlaid_image, resize_shape[::-1]) @@ -1044,8 +1041,12 @@ def draw_featmap(featmap: torch.Tensor, fig = plt.figure(frameon=False) # Set the window layout - fig.subplots_adjust( - left=0, right=1, bottom=0, top=1, wspace=0, hspace=0) + fig.subplots_adjust(left=0, + right=1, + bottom=0, + top=1, + wspace=0, + hspace=0) dpi = fig.get_dpi() fig.set_size_inches((width * col + 1e-2) / dpi, (height * row + 1e-2) / dpi) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..5895ecbc3e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,97 @@ +[build-system] +requires = ["setuptools>=72", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "mmengine" +# Version is dynamically set from mmengine/version.py +dynamic = ["version"] +description = "Engine of OpenMMLab projects" +readme = "README.md" +license = { text = "Apache License 2.0" } +authors = [ + { name = "MMEngine Authors", email = "openmmlab@gmail.com" }, + { name = "MGAM", email = "312065559@qq.com" } +] +requires-python = ">=3.7" +classifiers = [ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Utilities", +] +keywords = ["OpenMMLab", "Engine"] +# Core dependencies from requirements/runtime.txt +dependencies = [ + "addict", + "matplotlib", + "numpy", + "pyyaml", + "regex; sys_platform=='win32'", + "rich", + "termcolor", + "yapf", +] + +# Optional dependency groups +[project.optional-dependencies] +# All dependencies (runtime + tests) +all = [ + # Runtime dependencies are already included in base + # Test dependencies from requirements/tests.txt + # Note: aim excluded due to dependency issues (aimrocks not available) + # "aim<=3.17.5; sys_platform!='win32'", + "bitsandbytes", + "clearml", + "coverage", + "dadaptation", + "dvclive", + "lion-pytorch", + "lmdb", + "mlflow", + "parameterized", + "pydantic==1.10.9", + "pytest", + "transformers", +] +# Test dependencies only +tests = [ + "bitsandbytes", + "clearml", + "coverage", + "dadaptation", + "dvclive", + "lion-pytorch", + "lmdb", + "mlflow", + "parameterized", + "pydantic==1.10.9", + "pytest", + "transformers", +] + +[project.urls] +Homepage = "https://github.com/open-mmlab/mmengine" +Repository = "https://github.com/open-mmlab/mmengine" +Documentation = "https://mmengine.readthedocs.io" + +# Setuptools configuration +[tool.setuptools] +# Include package data files (similar to include_package_data=True) +include-package-data = true + +[tool.setuptools.packages.find] +where = ["."] +include = ["mmengine*"] +exclude = ["tests*", "docs*", "examples*"] + +# Dynamic version from mmengine/version.py +[tool.setuptools.dynamic] +version = {attr = "mmengine.version.__version__"} diff --git a/setup.cfg b/setup.cfg.bak similarity index 100% rename from setup.cfg rename to setup.cfg.bak diff --git a/setup.py b/setup.py deleted file mode 100644 index 5b1f7fc803..0000000000 --- a/setup.py +++ /dev/null @@ -1,144 +0,0 @@ -import os -import re -from setuptools import find_packages, setup # type: ignore - -from pkg_resources import DistributionNotFound, get_distribution - - -def readme(): - with open('README.md', encoding='utf-8') as f: - content = f.read() - return content - - -version_file = 'mmengine/version.py' - - -def choose_requirement(primary, secondary): - """If some version of primary requirement installed, return primary, else - return secondary.""" - try: - name = re.split(r'[!<>=]', primary)[0] - get_distribution(name) - except DistributionNotFound: - return secondary - - return str(primary) - - -def get_version(): - with open(version_file) as f: - exec(compile(f.read(), version_file, 'exec')) - return locals()['__version__'] - - -def parse_requirements(fname='requirements/runtime.txt', with_version=True): - """Parse the package dependencies listed in a requirements file but strips - specific versioning information. - - Args: - fname (str): path to requirements file - with_version (bool, default=False): if True include version specs - - Returns: - List[str]: list of requirements items - - CommandLine: - python -c "import setup; print(setup.parse_requirements())" - """ - import re - import sys - from os.path import exists - require_fpath = fname - - def parse_line(line): - """Parse information from a line in a requirements text file.""" - if line.startswith('-r '): - # Allow specifying requirements in other files - target = line.split(' ')[1] - for info in parse_require_file(target): - yield info - else: - info = {'line': line} - if line.startswith('-e '): - info['package'] = line.split('#egg=')[1] - else: - # Remove versioning from the package - pat = '(' + '|'.join(['>=', '==', '>']) + ')' - parts = re.split(pat, line, maxsplit=1) - parts = [p.strip() for p in parts] - - info['package'] = parts[0] - if len(parts) > 1: - op, rest = parts[1:] - if ';' in rest: - # Handle platform specific dependencies - # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies - version, platform_deps = map(str.strip, - rest.split(';')) - info['platform_deps'] = platform_deps - else: - version = rest # NOQA - info['version'] = (op, version) - yield info - - def parse_require_file(fpath): - with open(fpath) as f: - for line in f.readlines(): - line = line.strip() - if line and not line.startswith('#'): - yield from parse_line(line) - - def gen_packages_items(): - if exists(require_fpath): - for info in parse_require_file(require_fpath): - parts = [info['package']] - if with_version and 'version' in info: - parts.extend(info['version']) - if not sys.version.startswith('3.4'): - # apparently package_deps are broken in 3.4 - platform_deps = info.get('platform_deps') - if platform_deps is not None: - parts.append(';' + platform_deps) - item = ''.join(parts) - yield item - - packages = list(gen_packages_items()) - return packages - - -if int(os.getenv('MMENGINE_LITE', '0')) == 1: - install_requires = parse_requirements('requirements/runtime_lite.txt') -else: - install_requires = parse_requirements() - -setup( - name='mmengine' - if os.getenv('MMENGINE_LITE', '0') == '0' else 'mmengine-lite', - version=get_version(), - description='Engine of OpenMMLab projects', - long_description=readme(), - long_description_content_type='text/markdown', - url='https://github.com/open-mmlab/mmengine', - author='MMEngine Authors', - author_email='openmmlab@gmail.com', - packages=find_packages(), - include_package_data=True, - classifiers=[ - 'Development Status :: 4 - Beta', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Topic :: Utilities', - ], - python_requires='>=3.7', - install_requires=install_requires, - extras_require={ - 'all': parse_requirements('requirements.txt'), - 'tests': parse_requirements('requirements/tests.txt'), - }, -) diff --git a/tests/data/config/lazy_module_config/test_ast_transform.py b/tests/data/config/lazy_module_config/test_ast_transform.py index a8803dde24..6f0ada1736 100644 --- a/tests/data/config/lazy_module_config/test_ast_transform.py +++ b/tests/data/config/lazy_module_config/test_ast_transform.py @@ -3,7 +3,7 @@ from importlib.util import find_spec as find_module import numpy -import numpy.compat +import numpy.fft import numpy.linalg as linalg from mmengine.config import Config diff --git a/tests/test_analysis/test_flop_count.py b/tests/test_analysis/test_flop_count.py index 20749a0bab..99d096cbf8 100644 --- a/tests/test_analysis/test_flop_count.py +++ b/tests/test_analysis/test_flop_count.py @@ -243,8 +243,8 @@ def addmm_dummy_flop_jit( custom_ops2: Dict[str, Handle] = { f'aten::{self.lin_op}': addmm_dummy_flop_jit } - flop_dict2, _ = flop_count( - custom_net, (x, ), supported_ops=custom_ops2) + flop_dict2, _ = flop_count(custom_net, (x, ), + supported_ops=custom_ops2) flop = 400000 / 1e9 self.assertEqual( flop_dict2[self.lin_op], @@ -365,9 +365,9 @@ def _test_conv( else: spatial_size = ( (spatial_dim + 2 * padding) - kernel_size) // stride + 1 - gt_flop = ( - batch_size * input_dim * output_dim * (kernel_size**conv_dim) * - (spatial_size**conv_dim) / group_size / 1e9) + gt_flop = (batch_size * input_dim * output_dim * + (kernel_size**conv_dim) * (spatial_size**conv_dim) / + group_size / 1e9) gt_dict = defaultdict(float) gt_dict['conv'] = gt_flop self.assertDictEqual( @@ -849,8 +849,8 @@ def _count_function(self, func, inputs, name) -> Tuple[Any, Any]: def f(*args): return func(*inputs) - graph = torch.jit.trace( - f, tuple(tensor_inputs), check_trace=False).graph + graph = torch.jit.trace(f, tuple(tensor_inputs), + check_trace=False).graph nodes = [k for k in graph.nodes() if k.kind() == name] self.assertEqual(len(nodes), 1) node = nodes[0] diff --git a/tests/test_analysis/test_jit_analysis.py b/tests/test_analysis/test_jit_analysis.py index be10309d0f..b66dcb7f85 100644 --- a/tests/test_analysis/test_jit_analysis.py +++ b/tests/test_analysis/test_jit_analysis.py @@ -44,8 +44,8 @@ def __init__(self, lin_op: str = 'addmm') -> None: fc_flops_ = fc_in * fc_out fc_flops = Counter({lin_op: fc_flops_}) - spatial_pos = (conv_input_size[1] + 2 * padding) - 2 * ( - kernel_size // 2) + spatial_pos = (conv_input_size[1] + + 2 * padding) - 2 * (kernel_size // 2) conv_flops_ = spatial_pos * kernel_size * conv_in * conv_out conv_flops = Counter({'conv': conv_flops_}) model_flops = conv_flops + fc_flops @@ -95,8 +95,8 @@ def __init__(self, lin_op: str = 'addmm') -> None: fc_flops_ = fc_in * fc_out fc_flops = Counter({lin_op: fc_flops_}) - spatial_pos = (self.input_size[1] + 2 * padding) - 2 * ( - kernel_size // 2) + spatial_pos = (self.input_size[1] + + 2 * padding) - 2 * (kernel_size // 2) conv_flops_ = spatial_pos * kernel_size * conv_in * conv_out conv_flops = Counter({'conv': conv_flops_}) @@ -428,8 +428,8 @@ def test_non_forward_func_call(self) -> None: model = NonForwardNet() inputs = (torch.randn((1, 10)), ) - analyzer = FlopAnalyzer( - model=model, inputs=inputs).ancestor_mode('caller') + analyzer = FlopAnalyzer(model=model, + inputs=inputs).ancestor_mode('caller') inner_fc_count = model.submod.fc_flops total_count = model.fc_flops + inner_fc_count @@ -441,8 +441,8 @@ def test_non_forward_func_call(self) -> None: # The mod not directly called is registered as such self.assertEqual(analyzer.uncalled_modules(), {'submod'}) - analyzer = FlopAnalyzer( - model=model, inputs=inputs).ancestor_mode('owner') + analyzer = FlopAnalyzer(model=model, + inputs=inputs).ancestor_mode('owner') self.assertEqual(analyzer.total('submod'), inner_fc_count) self.assertEqual(analyzer.total('submod.fc'), inner_fc_count) self.assertEqual(analyzer.total(''), total_count) @@ -455,9 +455,9 @@ def test_shared_module(self) -> None: model = SharedModuleNet() inputs = (torch.randn((1, *model.input_size)), ) - analyzer = ( - FlopAnalyzer(model=model, inputs=inputs).unsupported_ops_warnings( - enabled=False).ancestor_mode('caller')) + analyzer = (FlopAnalyzer(model=model, + inputs=inputs).unsupported_ops_warnings( + enabled=False).ancestor_mode('caller')) # The names `submod2.submod` and `multiname2` are not included, # since only the first name of a module is made the canonical one. @@ -487,14 +487,14 @@ def test_shared_module(self) -> None: ) # Test getting canonical name - self.assertEqual( - analyzer.canonical_module_name('multiname2'), 'multiname1') - self.assertEqual( - analyzer.canonical_module_name('multiname1'), 'multiname1') - self.assertEqual( - analyzer.canonical_module_name('submod2.submod'), 'submod1.submod') - self.assertEqual( - analyzer.canonical_module_name('submod1.submod'), 'submod1.submod') + self.assertEqual(analyzer.canonical_module_name('multiname2'), + 'multiname1') + self.assertEqual(analyzer.canonical_module_name('multiname1'), + 'multiname1') + self.assertEqual(analyzer.canonical_module_name('submod2.submod'), + 'submod1.submod') + self.assertEqual(analyzer.canonical_module_name('submod1.submod'), + 'submod1.submod') # Tests no uncalled modules self.assertEqual(analyzer.uncalled_modules(), set()) @@ -561,13 +561,12 @@ def test_unsupported_ops(self) -> None: model = NestedNet(lin_op=self.lin_op) inputs = (torch.randn((1, *model.input_size)), ) - analyzer = JitModelAnalysis( - model=model, inputs=inputs).set_op_handle( - 'aten::addmm', - addmm_flop_jit, - 'aten::linear', - linear_flop_jit, - ) + analyzer = JitModelAnalysis(model=model, inputs=inputs).set_op_handle( + 'aten::addmm', + addmm_flop_jit, + 'aten::linear', + linear_flop_jit, + ) analyzer.total() skipped_inner_conv = Counter({'aten::_convolution': 1}) @@ -606,8 +605,8 @@ def test_changing_handles(self) -> None: 'aten::linear': linear_flop_jit, } - analyzer = JitModelAnalysis( - model=model, inputs=inputs).set_op_handle(**op_handles) + analyzer = JitModelAnalysis(model=model, + inputs=inputs).set_op_handle(**op_handles) analyzer.unsupported_ops_warnings(enabled=False) # Request a result once to cache flop counts @@ -634,9 +633,10 @@ def dummy_ops_handle(inputs: List[Any], dummy_flops = {} for name, counts in model.flops.items(): - dummy_flops[name] = Counter( - {op: flop - for op, flop in counts.items() if op != self.lin_op}) + dummy_flops[name] = Counter({ + op: flop + for op, flop in counts.items() if op != self.lin_op + }) dummy_flops[''][dummy_name] = 2 * dummy_out dummy_flops['fc'][dummy_name] = dummy_out dummy_flops['submod'][dummy_name] = dummy_out @@ -657,14 +657,12 @@ def test_copy(self) -> None: model = RepeatedNet() inputs = (torch.randn((1, *model.input_size)), ) - analyzer = ( - JitModelAnalysis(model=model, inputs=inputs).set_op_handle( - 'aten::addmm', - addmm_flop_jit, - 'aten::linear', - linear_flop_jit, - ).unsupported_ops_warnings(enabled=False).tracer_warnings( - mode='none')) + analyzer = (JitModelAnalysis(model=model, inputs=inputs).set_op_handle( + 'aten::addmm', + addmm_flop_jit, + 'aten::linear', + linear_flop_jit, + ).unsupported_ops_warnings(enabled=False).tracer_warnings(mode='none')) repeated_net_flops = model.fc1_num * model.fc1_flops repeated_net_flops += model.fc2_num * model.fc2_flops @@ -699,8 +697,8 @@ def test_copy(self) -> None: new_model = NonForwardNet() bs = 5 new_inputs = (torch.randn((bs, *new_model.input_size)), ) - analyzer_new = analyzer.copy( - new_model=new_model, new_inputs=new_inputs) + analyzer_new = analyzer.copy(new_model=new_model, + new_inputs=new_inputs) non_forward_flops = new_model.fc_flops + new_model.submod.fc_flops diff --git a/tests/test_analysis/test_print_helper.py b/tests/test_analysis/test_print_helper.py index 14366583d5..3abd0a0bd9 100644 --- a/tests/test_analysis/test_print_helper.py +++ b/tests/test_analysis/test_print_helper.py @@ -60,23 +60,24 @@ def test_get_model_complexity_info(): assert complexity_info['flops'] == flops assert complexity_info['params'] == params - complexity_info = get_model_complexity_info( - model=model, input_shape=input_shape1) - flops = FlopAnalyzer( - model=model, inputs=(torch.randn(1, *input_shape1), )).total() + complexity_info = get_model_complexity_info(model=model, + input_shape=input_shape1) + flops = FlopAnalyzer(model=model, + inputs=(torch.randn(1, *input_shape1), )).total() assert complexity_info['flops'] == flops # test a network that accepts two tensors as input model = NetAcceptTwoTensors() - complexity_info = get_model_complexity_info( - model=model, inputs=(input1, input2)) + complexity_info = get_model_complexity_info(model=model, + inputs=(input1, input2)) flops = FlopAnalyzer(model=model, inputs=(input1, input2)).total() params = parameter_count(model=model)[''] assert complexity_info['flops'] == flops assert complexity_info['params'] == params - complexity_info = get_model_complexity_info( - model=model, input_shape=(input_shape1, input_shape2)) + complexity_info = get_model_complexity_info(model=model, + input_shape=(input_shape1, + input_shape2)) inputs = (torch.randn(1, *input_shape1), torch.randn(1, *input_shape2)) flops = FlopAnalyzer(model=model, inputs=inputs).total() assert complexity_info['flops'] == flops @@ -88,8 +89,8 @@ def test_get_model_complexity_info(): scalar = torch.tensor([ scalar ]) if digit_version(TORCH_VERSION) < digit_version('1.9.0') else scalar - complexity_info = get_model_complexity_info( - model=model, inputs=(input1, scalar)) + complexity_info = get_model_complexity_info(model=model, + inputs=(input1, scalar)) flops = FlopAnalyzer(model=model, inputs=(input1, scalar)).total() params = parameter_count(model=model)[''] assert complexity_info['flops'] == flops @@ -104,5 +105,6 @@ def test_get_model_complexity_info(): # when both `inputs` and `input_shape` are specified model = NetAcceptOneTensor() with pytest.raises(ValueError, match='cannot be both set'): - get_model_complexity_info( - model, inputs=input1, input_shape=input_shape1) + get_model_complexity_info(model, + inputs=input1, + input_shape=input_shape1) diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index e783431441..8d6c2bef5f 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -40,8 +40,10 @@ def test_init(self, file_format): Config([0, 1]) # test `filename` parameter - cfg_dict = dict( - item1=[1, 2], item2=dict(a=0), item3=True, item4='test') + cfg_dict = dict(item1=[1, 2], + item2=dict(a=0), + item3=True, + item4='test') cfg_file = osp.join( self.data_path, f'config/{file_format}_config/simple_config.{file_format}') @@ -54,9 +56,9 @@ def test_init(self, file_format): self.data_path, f'config/{file_format}_config/test_reserved_key.{file_format}') # reserved keys cannot be set in config - with pytest.raises( - KeyError, match='filename is reserved for config ' - 'file'): + with pytest.raises(KeyError, + match='filename is reserved for config ' + 'file'): Config.fromfile(cfg_file) def test_fromfile(self): @@ -74,8 +76,8 @@ def test_fromfile(self): Config.fromfile(cfg_file, import_custom_modules=False) assert 'TEST_VALUE' not in os.environ sys.modules.pop('test_custom_import_module') - with pytest.raises( - ImportError, match='Failed to import custom modules from'): + with pytest.raises(ImportError, + match='Failed to import custom modules from'): Config.fromfile(cfg_file, import_custom_modules=True) @pytest.mark.parametrize('file_format', ['py', 'json', 'yaml']) @@ -100,8 +102,10 @@ def test_fromstring(self, file_format): Config.fromstring(cfg_str, '.xml') def test_magic_methods(self): - cfg_dict = dict( - item1=[1, 2], item2=dict(a=0), item3=True, item4='test') + cfg_dict = dict(item1=[1, 2], + item2=dict(a=0), + item3=True, + item4='test') filename = 'py_config/simple_config.py' cfg_file = osp.join(self.data_path, 'config', filename) cfg = Config.fromfile(cfg_file) @@ -218,8 +222,9 @@ def test_auto_argparser(self): sys.argv.extend(tmp) def test_dict_to_config_dict(self): - cfg_dict = dict( - a=1, b=dict(c=dict()), d=[dict(e=dict(f=(dict(g=1), [])))]) + cfg_dict = dict(a=1, + b=dict(c=dict()), + d=[dict(e=dict(f=(dict(g=1), [])))]) cfg_dict = Config._dict_to_config_dict(cfg_dict) assert isinstance(cfg_dict, ConfigDict) assert isinstance(cfg_dict.a, int) @@ -316,8 +321,10 @@ def test_repr(self, tmp_path): def test_dict_action(self): parser = argparse.ArgumentParser(description='Train a detector') - parser.add_argument( - '--options', nargs='+', action=DictAction, help='custom options') + parser.add_argument('--options', + nargs='+', + action=DictAction, + help='custom options') # Nested brackets args = parser.parse_args( ['--options', 'item2.a=a,b', 'item2.b=[(a,b), [1,2], false]']) @@ -471,10 +478,9 @@ def test_pre_substitute_base_vars(self, tmp_path): assert cfg_module_dict['item10'].startswith('_item7') def test_substitute_base_vars(self): - cfg = dict( - item4='_item1.12345', - item5=dict(item3='1', item2='_item2_.fswf'), - item0=('_item0_.12ed21wq', 1)) + cfg = dict(item4='_item1.12345', + item5=dict(item3='1', item2='_item2_.fswf'), + item0=('_item0_.12ed21wq', 1)) cfg_base = dict(item1=0, item2=[1, 2, 3], item0=(1, 2, 3)) base_var_dict = { '_item1.12345': 'item1', @@ -517,9 +523,8 @@ def test_get_cfg_path_local(self): assert scope is None osp.isfile(cfg_path) - @pytest.mark.skipif( - not is_installed('mmdet') or not is_installed('mmcls'), - reason='mmdet and mmcls should be installed') + @pytest.mark.skipif(not is_installed('mmdet') or not is_installed('mmcls'), + reason='mmdet and mmcls should be installed') def test_get_cfg_path_external(self): filename = 'py_config/simple_config.py' filename = osp.join(self.data_path, 'config', filename) @@ -559,20 +564,18 @@ def _predefined_vars(self): path = osp.join(self.data_path, 'config/py_config') path = Path(path).as_posix() - cfg_dict_dst = dict( - item1='test_predefined_var.py', - item2=path, - item3='abc_test_predefined_var') + cfg_dict_dst = dict(item1='test_predefined_var.py', + item2=path, + item3='abc_test_predefined_var') assert Config._file2dict(cfg_file)[0]['item1'] == cfg_dict_dst['item1'] assert Config._file2dict(cfg_file)[0]['item2'] == cfg_dict_dst['item2'] assert Config._file2dict(cfg_file)[0]['item3'] == cfg_dict_dst['item3'] # test `use_predefined_variable=False` - cfg_dict_ori = dict( - item1='{{fileBasename}}', - item2='{{ fileDirname}}', - item3='abc_{{ fileBasenameNoExtension }}') + cfg_dict_ori = dict(item1='{{fileBasename}}', + item2='{{ fileDirname}}', + item3='abc_{{ fileBasenameNoExtension }}') assert Config._file2dict(cfg_file, False)[0]['item1'] == cfg_dict_ori['item1'] @@ -652,8 +655,8 @@ def _merge_from_multiple_bases(self): assert cfg_dict['item4'] == 'test' assert cfg_dict['item5'] == dict(a=0, b=1) assert cfg_dict['item6'] == [dict(a=0), dict(b=1)] - assert cfg_dict['item7'] == dict( - a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg_dict['item7'] == dict(a=[0, 1, 2], + b=dict(c=[3.1, 4.2, 5.3])) # Redefine key with pytest.raises(KeyError): Config.fromfile( @@ -674,8 +677,8 @@ def _base_variables(self): assert cfg_dict['item4'] == 'test' assert cfg_dict['item5'] == dict(a=0, b=1) assert cfg_dict['item6'] == [dict(a=0), dict(b=1)] - assert cfg_dict['item7'] == dict( - a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg_dict['item7'] == dict(a=[0, 1, 2], + b=dict(c=[3.1, 4.2, 5.3])) assert cfg_dict['item8'] == file.split('/')[-1] assert cfg_dict['item9'] == dict(a=0) assert cfg_dict['item10'] == [3.1, 4.2, 5.3] @@ -696,8 +699,8 @@ def _base_variables(self): assert cfg_dict['item4'] == 'test' assert cfg_dict['item5'] == dict(a=0, b=1) assert cfg_dict['item6'] == [dict(a=0), dict(b=1)] - assert cfg_dict['item7'] == dict( - a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg_dict['item7'] == dict(a=[0, 1, 2], + b=dict(c=[3.1, 4.2, 5.3])) assert cfg_dict['item8'] == 'test_base_variables.py' assert cfg_dict['item9'] == dict(a=0) assert cfg_dict['item10'] == [3.1, 4.2, 5.3] @@ -705,18 +708,17 @@ def _base_variables(self): assert cfg_dict['item12'] == dict(a=0) assert cfg_dict['item13'] == [3.1, 4.2, 5.3] assert cfg_dict['item14'] == [1, 2] - assert cfg_dict['item15'] == dict( - a=dict(b=dict(a=0)), - b=[False], - c=['test'], - d=[[{ - 'e': 0 - }], [{ - 'a': 0 - }, { - 'b': 1 - }]], - e=[1, 2]) + assert cfg_dict['item15'] == dict(a=dict(b=dict(a=0)), + b=[False], + c=['test'], + d=[[{ + 'e': 0 + }], [{ + 'a': 0 + }, { + 'b': 1 + }]], + e=[1, 2]) # test reference assignment for py cfg_file = osp.join( @@ -728,17 +730,16 @@ def _base_variables(self): assert cfg_dict['item22'] == 'test_base_variables.py' assert cfg_dict['item23'] == [3.1, 4.2, 5.3] assert cfg_dict['item24'] == [3.1, 4.2, 5.3] - assert cfg_dict['item25'] == dict( - a=dict(b=[3.1, 4.2, 5.3]), - b=[[3.1, 4.2, 5.3]], - c=[[{ - 'e': 'test_base_variables.py' - }], [{ - 'a': 0 - }, { - 'b': 1 - }]], - e='test_base_variables.py') + assert cfg_dict['item25'] == dict(a=dict(b=[3.1, 4.2, 5.3]), + b=[[3.1, 4.2, 5.3]], + c=[[{ + 'e': 'test_base_variables.py' + }], [{ + 'a': 0 + }, { + 'b': 1 + }]], + e='test_base_variables.py') cfg_file = osp.join(self.data_path, 'config/py_config/test_py_base.py') cfg = Config.fromfile(cfg_file) @@ -780,18 +781,17 @@ def _base_variables(self): assert cfg.item12 == 'test_py_base.py' assert cfg.item13 == 3.1 assert cfg.item14 == [1, 2] - assert cfg.item15 == dict( - a=dict(b=dict(a=0, b=[5, 6])), - b=[False], - c=['test'], - d=[[{ - 'e': 0 - }], [{ - 'c': 0 - }, { - 'b': 1 - }]], - e=[1, 2]) + assert cfg.item15 == dict(a=dict(b=dict(a=0, b=[5, 6])), + b=[False], + c=['test'], + d=[[{ + 'e': 0 + }], [{ + 'c': 0 + }, { + 'b': 1 + }]], + e=[1, 2]) # Test use global variable in config function cfg_file = osp.join(self.data_path, @@ -913,8 +913,8 @@ def test_copy(self): assert new_cfg._filename == cfg._filename assert new_cfg._text == cfg._text - @pytest.mark.skipif( - not is_installed('mmdet'), reason='mmdet should be installed') + @pytest.mark.skipif(not is_installed('mmdet'), + reason='mmdet should be installed') def test_get_external_cfg(self): ext_cfg_path = osp.join(self.data_path, 'config/py_config/test_get_external_cfg.py') @@ -927,8 +927,8 @@ def test_get_external_cfg(self): ) assert '_scope_' in ext_cfg._cfg_dict.model - @pytest.mark.skipif( - not is_installed('mmdet'), reason='mmdet should be installed') + @pytest.mark.skipif(not is_installed('mmdet'), + reason='mmdet should be installed') def test_build_external_package(self): # Test load base config. ext_cfg_path = osp.join(self.data_path, @@ -1062,10 +1062,9 @@ def _compare_dict(a, b): 'config/lazy_module_config/error_mix_using1.py')) # Force to import in non-lazy-import mode - Config.fromfile( - osp.join(self.data_path, - 'config/lazy_module_config/error_mix_using1.py'), - lazy_import=False) + Config.fromfile(osp.join( + self.data_path, 'config/lazy_module_config/error_mix_using1.py'), + lazy_import=False) # current lazy-import config, base text config with pytest.raises(RuntimeError, match='_base_ ='): @@ -1131,15 +1130,12 @@ def test_build_lazy(self): self.assertDictEqual(cfg_dict, raw) # Check `items` and `values` will only return the build object - raw = dict( - a=LazyObject('mmengine'), - b=dict( - c=2, - e=[ - dict( - f=dict(h=LazyObject('mmengine')), - g=LazyObject('mmengine')) - ])) + raw = dict(a=LazyObject('mmengine'), + b=dict(c=2, + e=[ + dict(f=dict(h=LazyObject('mmengine')), + g=LazyObject('mmengine')) + ])) cfg_dict = ConfigDict(raw) # check `items` and values self.assertDictEqual(cfg_dict._to_lazy_dict(), raw) diff --git a/tests/test_config/test_lazy.py b/tests/test_config/test_lazy.py index d69822814b..1dda04fdaa 100644 --- a/tests/test_config/test_lazy.py +++ b/tests/test_config/test_lazy.py @@ -8,7 +8,7 @@ from unittest import TestCase import numpy -import numpy.compat +import numpy.fft import numpy.linalg as linalg from rich.progress import Progress @@ -56,17 +56,17 @@ def test_lazy_module(self): # 1.2 getattr as LazyAttr self.assertIsInstance(lazy_numpy.linalg, LazyAttr) - self.assertIsInstance(lazy_numpy.compat, LazyAttr) + self.assertIsInstance(lazy_numpy.fft, LazyAttr) - # 1.3 Build module from LazyObject. amp and functional can be accessed + # 1.3 Build module from LazyObject. linalg and fft can be accessed imported_numpy = lazy_numpy.build() self.assertIs(imported_numpy.linalg, linalg) - self.assertIs(imported_numpy.compat, numpy.compat) + self.assertIs(imported_numpy.fft, numpy.fft) # 1.4.1 Build module from LazyAttr imported_linalg = lazy_numpy.linalg.build() - imported_compat = lazy_numpy.compat.build() - self.assertIs(imported_compat, numpy.compat) + imported_fft = lazy_numpy.fft.build() + self.assertIs(imported_fft, numpy.fft) self.assertIs(imported_linalg, linalg) # 1.4.2 build class method from LazyAttr diff --git a/tests/test_data/test_data_utils.py b/tests/test_data/test_data_utils.py index 76e30e8642..255a849bd2 100644 --- a/tests/test_data/test_data_utils.py +++ b/tests/test_data/test_data_utils.py @@ -49,30 +49,26 @@ def test_pseudo_collate(self): self.assertIs(batch_data_sample[1], data_sample2) # Test with list of tuple, each tuple is a nested dict instance - data_batch = [(dict( - inputs=input1, - data_sample=data_sample1, - value=1, - name='1', - nested=dict(data_sample=data_sample1)), - dict( - inputs=input2, - data_sample=data_sample2, - value=2, - name='2', - nested=dict(data_sample=data_sample2))), - (dict( - inputs=input1, - data_sample=data_sample1, - value=1, - name='1', - nested=dict(data_sample=data_sample1)), - dict( - inputs=input2, - data_sample=data_sample2, - value=2, - name='2', - nested=dict(data_sample=data_sample2)))] + data_batch = [(dict(inputs=input1, + data_sample=data_sample1, + value=1, + name='1', + nested=dict(data_sample=data_sample1)), + dict(inputs=input2, + data_sample=data_sample2, + value=2, + name='2', + nested=dict(data_sample=data_sample2))), + (dict(inputs=input1, + data_sample=data_sample1, + value=1, + name='1', + nested=dict(data_sample=data_sample1)), + dict(inputs=input2, + data_sample=data_sample2, + value=2, + name='2', + nested=dict(data_sample=data_sample2)))] data_batch = pseudo_collate(data_batch) batch_inputs_0 = data_batch[0]['inputs'] batch_inputs_1 = data_batch[1]['inputs'] diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py index f4ec815ec2..24a1091a85 100644 --- a/tests/test_dataset/test_base_dataset.py +++ b/tests/test_dataset/test_base_dataset.py @@ -36,8 +36,10 @@ class CustomDataset(BaseDataset): class TestBaseDataset: def setup_method(self): - self.data_info = dict( - filename='test_img.jpg', height=604, width=640, sample_idx=0) + self.data_info = dict(filename='test_img.jpg', + height=604, + width=640, + sample_idx=0) self.imgs = torch.rand((2, 3, 32, 32)) self.ori_meta = BaseDataset.METAINFO self.ori_parse_data_info = BaseDataset.parse_data_info @@ -50,28 +52,28 @@ def teardown_method(self): def test_init(self): # test the instantiation of self.base_dataset - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') assert dataset._fully_initialized assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_address') - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=''), - ann_file='annotations/dummy_annotation.json') + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path=''), + ann_file='annotations/dummy_annotation.json') assert dataset._fully_initialized assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_address') # test the instantiation of self.base_dataset with # `serialize_data=False` - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - serialize_data=False) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + serialize_data=False) assert dataset._fully_initialized assert hasattr(dataset, 'data_list') assert not hasattr(dataset, 'data_address') @@ -79,54 +81,49 @@ def test_init(self): assert dataset.get_data_info(0) == self.data_info # test the instantiation of self.base_dataset with lazy init - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=True) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + lazy_init=True) assert not dataset._fully_initialized assert not dataset.data_list # test the instantiation of self.base_dataset if ann_file is not # existed. with pytest.raises(FileNotFoundError): - BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/not_existed_annotation.json') + BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/not_existed_annotation.json') # Use the default value of ann_file, i.e., '' with pytest.raises(TypeError): - BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs')) + BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img_path='imgs')) # test the instantiation of self.base_dataset when the ann_file is # wrong with pytest.raises(ValueError): - BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/annotation_wrong_keys.json') + BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/annotation_wrong_keys.json') with pytest.raises(TypeError): - BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/annotation_wrong_format.json') + BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/annotation_wrong_format.json') with pytest.raises(TypeError): - BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=['img']), - ann_file='annotations/annotation_wrong_format.json') + BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img_path=['img']), + ann_file='annotations/annotation_wrong_format.json') # test the instantiation of self.base_dataset when `parse_data_info` # return `list[dict]` BaseDataset.parse_data_info = MagicMock( return_value=[self.data_info, self.data_info.copy()]) - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') dataset.pipeline = self.pipeline assert dataset._fully_initialized assert hasattr(dataset, 'data_list') @@ -139,25 +136,24 @@ def test_init(self): # return unsupported data. with pytest.raises(TypeError): BaseDataset.parse_data_info = MagicMock(return_value='xxx') - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') with pytest.raises(TypeError): BaseDataset.parse_data_info = MagicMock( return_value=[self.data_info, 'xxx']) - BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') # test the instantiation of self.base_dataset without `ann_file` BaseDataset.parse_data_info = self.ori_parse_data_info - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='', - serialize_data=False, - lazy_init=True) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='', + serialize_data=False, + lazy_init=True) assert not dataset.ann_file # Test `ann_file` and `data_root` could be None. @@ -166,125 +162,119 @@ def test_init(self): def test_meta(self): # test dataset.metainfo with setting the metainfo from annotation file # as the metainfo of self.base_dataset. - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') - assert dataset.metainfo == dict( - dataset_type='test_dataset', task_name='test_task', empty_list=[]) + assert dataset.metainfo == dict(dataset_type='test_dataset', + task_name='test_task', + empty_list=[]) # test dataset.metainfo with setting METAINFO in self.base_dataset dataset_type = 'new_dataset' - BaseDataset.METAINFO = dict( - dataset_type=dataset_type, classes=('dog', 'cat')) - - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') - assert dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='test_task', - classes=('dog', 'cat'), - empty_list=[]) + BaseDataset.METAINFO = dict(dataset_type=dataset_type, + classes=('dog', 'cat')) + + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') + assert dataset.metainfo == dict(dataset_type=dataset_type, + task_name='test_task', + classes=('dog', 'cat'), + empty_list=[]) # test dataset.metainfo with passing metainfo into self.base_dataset metainfo = dict(classes=('dog', ), task_name='new_task') - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=metainfo) - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat')) - assert dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='new_task', - classes=('dog', ), - empty_list=[]) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=metainfo) + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, + classes=('dog', 'cat')) + assert dataset.metainfo == dict(dataset_type=dataset_type, + task_name='new_task', + classes=('dog', ), + empty_list=[]) # test dataset.metainfo with passing metainfo as Config into # self.base_dataset metainfo = Config(dict(classes=('dog', ), task_name='new_task')) - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=metainfo) - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat')) - assert dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='new_task', - classes=('dog', ), - empty_list=[]) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=metainfo) + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, + classes=('dog', 'cat')) + assert dataset.metainfo == dict(dataset_type=dataset_type, + task_name='new_task', + classes=('dog', ), + empty_list=[]) # test dataset.metainfo with passing metainfo as ConfigDict (Mapping) # into self.base_dataset metainfo = ConfigDict(dict(classes=('dog', ), task_name='new_task')) - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=metainfo) - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat')) - assert dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='new_task', - classes=('dog', ), - empty_list=[]) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=metainfo) + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, + classes=('dog', 'cat')) + assert dataset.metainfo == dict(dataset_type=dataset_type, + task_name='new_task', + classes=('dog', ), + empty_list=[]) # reset `base_dataset.METAINFO`, the `dataset.metainfo` should not # change BaseDataset.METAINFO['classes'] = ('dog', 'cat', 'fish') - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) - assert dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='new_task', - classes=('dog', ), - empty_list=[]) + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, + classes=('dog', 'cat', 'fish')) + assert dataset.metainfo == dict(dataset_type=dataset_type, + task_name='new_task', + classes=('dog', ), + empty_list=[]) # test dataset.metainfo with passing metainfo containing a file into # self.base_dataset - metainfo = dict( - classes=osp.join( - osp.dirname(__file__), '../data/meta/classes.txt')) - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=metainfo) - assert dataset.metainfo == dict( - dataset_type=dataset_type, - task_name='test_task', - classes=['dog'], - empty_list=[]) + metainfo = dict(classes=osp.join(osp.dirname(__file__), + '../data/meta/classes.txt')) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=metainfo) + assert dataset.metainfo == dict(dataset_type=dataset_type, + task_name='test_task', + classes=['dog'], + empty_list=[]) # test dataset.metainfo with passing unsupported metainfo into # self.base_dataset with pytest.raises(TypeError): metainfo = 'dog' - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=metainfo) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=metainfo) # test dataset.metainfo with passing metainfo into self.base_dataset # and lazy_init is True metainfo = dict(classes=('dog', )) - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=metainfo, - lazy_init=True) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=metainfo, + lazy_init=True) # 'task_name' and 'empty_list' not in dataset.metainfo - assert dataset.metainfo == dict( - dataset_type=dataset_type, classes=('dog', )) + assert dataset.metainfo == dict(dataset_type=dataset_type, + classes=('dog', )) # test whether self.base_dataset.METAINFO is changed when a customize # dataset inherit self.base_dataset @@ -293,26 +283,26 @@ class ToyDataset(BaseDataset): METAINFO = dict(xxx='xxx') assert ToyDataset.METAINFO == dict(xxx='xxx') - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, + classes=('dog', 'cat', 'fish')) # test update METAINFO in ToyDataset. class ToyDataset(BaseDataset): METAINFO = copy.deepcopy(BaseDataset.METAINFO) METAINFO['classes'] = ('bird', ) - assert ToyDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('bird', )) - assert BaseDataset.METAINFO == dict( - dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) + assert ToyDataset.METAINFO == dict(dataset_type=dataset_type, + classes=('bird', )) + assert BaseDataset.METAINFO == dict(dataset_type=dataset_type, + classes=('dog', 'cat', 'fish')) @pytest.mark.parametrize('lazy_init', [True, False]) def test_length(self, lazy_init): - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=lazy_init) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + lazy_init=lazy_init) if not lazy_init: assert dataset._fully_initialized assert hasattr(dataset, 'data_list') @@ -364,11 +354,11 @@ def test_compose(self): @pytest.mark.parametrize('lazy_init', [True, False]) def test_getitem(self, lazy_init): - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=lazy_init) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + lazy_init=lazy_init) dataset.pipeline = self.pipeline if not lazy_init: assert dataset._fully_initialized @@ -406,11 +396,11 @@ def fake_prepare_data(idx): @pytest.mark.parametrize('lazy_init', [True, False]) def test_get_data_info(self, lazy_init): - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=lazy_init) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + lazy_init=lazy_init) if not lazy_init: assert dataset._fully_initialized @@ -427,10 +417,10 @@ def test_get_data_info(self, lazy_init): # Test parse_data_info with `data_prefix` BaseDataset.parse_data_info = self.ori_parse_data_info data_root = osp.join(osp.dirname(__file__), '../data/') - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') data_info = dataset.get_data_info(0) assert data_info['img_path'] == osp.join(data_root, 'imgs', 'test_img.jpg') @@ -448,11 +438,11 @@ def foo(self): class_without_full_init.foo() def test_full_init(self): - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=True) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + lazy_init=True) dataset.pipeline = self.pipeline # test `full_init()` when lazy_init is True assert not dataset._fully_initialized @@ -465,11 +455,11 @@ def test_full_init(self): assert dataset[0] == dict(imgs=self.imgs) assert dataset.get_data_info(0) == self.data_info - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=False) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + lazy_init=False) dataset.pipeline = self.pipeline assert dataset._fully_initialized @@ -479,10 +469,10 @@ def test_full_init(self): assert dataset.get_data_info(0) == self.data_info # test the instantiation of self.base_dataset when passing indices - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=''), - ann_file='annotations/dummy_annotation.json') + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path=''), + ann_file='annotations/dummy_annotation.json') dataset_sliced = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img_path=''), @@ -497,12 +487,12 @@ def test_full_init(self): def test_get_subset_(self, lazy_init, serialize_data): # Test positive int indices. indices = 2 - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=''), - ann_file='annotations/dummy_annotation.json', - lazy_init=lazy_init, - serialize_data=serialize_data) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path=''), + ann_file='annotations/dummy_annotation.json', + lazy_init=lazy_init, + serialize_data=serialize_data) dataset_copy = copy.deepcopy(dataset) dataset_copy.get_subset_(indices) @@ -575,12 +565,12 @@ def test_get_subset_(self, lazy_init, serialize_data): def test_get_subset(self, lazy_init, serialize_data): # Test positive indices. indices = 2 - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=''), - ann_file='annotations/dummy_annotation.json', - lazy_init=lazy_init, - serialize_data=serialize_data) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path=''), + ann_file='annotations/dummy_annotation.json', + lazy_init=lazy_init, + serialize_data=serialize_data) dataset_sliced = dataset.get_subset(indices) assert len(dataset_sliced) == 2 assert dataset_sliced[0] == dataset[0] @@ -621,11 +611,11 @@ def test_get_subset(self, lazy_init, serialize_data): def test_rand_another(self): # test the instantiation of self.base_dataset when passing num_samples - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=''), - ann_file='annotations/dummy_annotation.json', - indices=1) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path=''), + ann_file='annotations/dummy_annotation.json', + indices=1) assert dataset._rand_another() >= 0 assert dataset._rand_another() < len(dataset) @@ -640,20 +630,20 @@ def setup_method(self): dataset.parse_data_info = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) - self.dataset_a = dataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + self.dataset_a = dataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') self.dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs)) # create dataset_b data_info = dict(filename='gray.jpg', height=288, width=512) dataset.parse_data_info = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) - self.dataset_b = dataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + self.dataset_b = dataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs)) # test init self.cat_datasets = ConcatDataset( @@ -661,11 +651,11 @@ def setup_method(self): def test_init(self): # Test build dataset from cfg. - dataset_cfg_b = dict( - type=CustomDataset, - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + dataset_cfg_b = dict(type=CustomDataset, + data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b]) cat_datasets.datasets[1].pipeline = self.dataset_b.pipeline assert len(cat_datasets) == len(self.cat_datasets) @@ -678,8 +668,8 @@ def test_init(self): ConcatDataset(datasets=[0]) with pytest.raises(TypeError): - ConcatDataset( - datasets=[self.dataset_a, dataset_cfg_b], ignore_keys=1) + ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b], + ignore_keys=1) def test_full_init(self): # test init with lazy_init=True @@ -696,11 +686,11 @@ def test_full_init(self): with pytest.raises(NotImplementedError): self.cat_datasets.get_subset(1) - dataset_b = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=dict(classes=('cat'))) + dataset_b = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=dict(classes=('cat'))) # Regardless of order, different meta information without # `ignore_keys` will raise error. with pytest.raises(ValueError): @@ -710,11 +700,11 @@ def test_full_init(self): # `ignore_keys` does not contain different meta information keys will # raise error. with pytest.raises(ValueError): - ConcatDataset( - datasets=[self.dataset_a, dataset_b], ignore_keys=['a']) + ConcatDataset(datasets=[self.dataset_a, dataset_b], + ignore_keys=['a']) # Different meta information with `ignore_keys` will not raise error. - cat_datasets = ConcatDataset( - datasets=[self.dataset_a, dataset_b], ignore_keys='classes') + cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_b], + ignore_keys='classes') cat_datasets.full_init() assert len(cat_datasets) == 6 cat_datasets.full_init() @@ -727,19 +717,19 @@ def test_metainfo(self): assert self.cat_datasets.metainfo == self.dataset_a.metainfo def test_length(self): - assert len(self.cat_datasets) == ( - len(self.dataset_a) + len(self.dataset_b)) + assert len(self.cat_datasets) == (len(self.dataset_a) + + len(self.dataset_b)) def test_getitem(self): assert ( self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all() - assert (self.cat_datasets[0]['imgs'] != - self.dataset_b[0]['imgs']).all() + assert (self.cat_datasets[0]['imgs'] + != self.dataset_b[0]['imgs']).all() assert ( self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all() - assert (self.cat_datasets[-1]['imgs'] != - self.dataset_a[-1]['imgs']).all() + assert (self.cat_datasets[-1]['imgs'] + != self.dataset_a[-1]['imgs']).all() def test_get_data_info(self): assert self.cat_datasets.get_data_info( @@ -768,26 +758,26 @@ def setup_method(self): data_info = dict(filename='test_img.jpg', height=604, width=640) dataset.parse_data_info = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) - self.dataset = dataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + self.dataset = dataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) self.repeat_times = 5 # test init - self.repeat_datasets = RepeatDataset( - dataset=self.dataset, times=self.repeat_times) + self.repeat_datasets = RepeatDataset(dataset=self.dataset, + times=self.repeat_times) def test_init(self): # Test build dataset from cfg. - dataset_cfg = dict( - type=CustomDataset, - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') - repeat_dataset = RepeatDataset( - dataset=dataset_cfg, times=self.repeat_times) + dataset_cfg = dict(type=CustomDataset, + data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') + repeat_dataset = RepeatDataset(dataset=dataset_cfg, + times=self.repeat_times) repeat_dataset.dataset.pipeline = self.dataset.pipeline assert len(repeat_dataset) == len(self.repeat_datasets) for i in range(len(repeat_dataset)): @@ -840,10 +830,10 @@ def setup_method(self): dataset.parse_data_info = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) dataset.get_cat_ids = MagicMock(return_value=[0]) - self.dataset = dataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + self.dataset = dataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) self.repeat_indices = [0, 0, 1, 1, 1] @@ -854,13 +844,13 @@ def setup_method(self): def test_init(self): # Test build dataset from cfg. - dataset_cfg = dict( - type=CustomDataset, - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') - cls_banlanced_datasets = ClassBalancedDataset( - dataset=dataset_cfg, oversample_thr=1e-3) + dataset_cfg = dict(type=CustomDataset, + data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') + cls_banlanced_datasets = ClassBalancedDataset(dataset=dataset_cfg, + oversample_thr=1e-3) cls_banlanced_datasets.repeat_indices = self.repeat_indices cls_banlanced_datasets.dataset.pipeline = self.dataset.pipeline assert len(cls_banlanced_datasets) == len(self.cls_banlanced_datasets) diff --git a/tests/test_dataset/test_sampler.py b/tests/test_dataset/test_sampler.py index 31582a8679..70d510159c 100644 --- a/tests/test_dataset/test_sampler.py +++ b/tests/test_dataset/test_sampler.py @@ -44,9 +44,8 @@ def test_dist(self, mock): self.assertEqual(sampler.num_samples, np.ceil(self.data_length / 3)) self.assertEqual(sampler.total_size, sampler.num_samples * 3) self.assertEqual(len(sampler), sampler.num_samples) - self.assertEqual( - list(sampler), - list(range(self.data_length))[2::3] + [1]) + self.assertEqual(list(sampler), + list(range(self.data_length))[2::3] + [1]) # test round_up=False sampler = DefaultSampler(self.dataset, round_up=False, shuffle=False) diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py index a2ef07b713..95db0f8bd7 100644 --- a/tests/test_dist/test_dist.py +++ b/tests/test_dist/test_dist.py @@ -126,8 +126,9 @@ def _init_dist_env(self, rank, world_size): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29505' os.environ['RANK'] = str(rank) - torch_dist.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='gloo', + rank=rank, + world_size=world_size) def setUp(self): super().setUp() @@ -193,9 +194,8 @@ def test_broadcast_dist(self): def test_sync_random_seed(self): self._init_dist_env(self.rank, self.world_size) - with patch.object( - torch, 'tensor', - return_value=torch.tensor(1024)) as mock_tensor: + with patch.object(torch, 'tensor', + return_value=torch.tensor(1024)) as mock_tensor: output = dist.sync_random_seed() assert output == 1024 mock_tensor.assert_called() @@ -333,20 +333,17 @@ def test_all_reduce_params(self): torch.tensor([0, 1], dtype=tensor_type) for _ in range(100) ] else: - data = ( - torch.tensor([2, 3], dtype=tensor_type) - for _ in range(100)) + data = (torch.tensor([2, 3], dtype=tensor_type) + for _ in range(100)) data_gen = (item for item in data) if reduce_op == 'sum': - expected = ( - torch.tensor([2, 4], dtype=tensor_type) - for _ in range(100)) + expected = (torch.tensor([2, 4], dtype=tensor_type) + for _ in range(100)) else: - expected = ( - torch.tensor([1, 2], dtype=tensor_type) - for _ in range(100)) + expected = (torch.tensor([1, 2], dtype=tensor_type) + for _ in range(100)) dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op) @@ -354,8 +351,8 @@ def test_all_reduce_params(self): self.assertTrue(torch.allclose(item1, item2)) -@unittest.skipIf( - torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl') +@unittest.skipIf(torch.cuda.device_count() < 2, + reason='need 2 gpu to test nccl') class TestDistWithNCCLBackend(MultiProcessTestCase): def _init_dist_env(self, rank, world_size): @@ -366,8 +363,9 @@ def _init_dist_env(self, rank, world_size): num_gpus = torch.cuda.device_count() torch.cuda.set_device(rank % num_gpus) - torch_dist.init_process_group( - backend='nccl', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='nccl', + rank=rank, + world_size=world_size) def setUp(self): super().setUp() @@ -431,9 +429,8 @@ def test_broadcast_dist(self): def test_sync_random_seed(self): self._init_dist_env(self.rank, self.world_size) - with patch.object( - torch, 'tensor', - return_value=torch.tensor(1024)) as mock_tensor: + with patch.object(torch, 'tensor', + return_value=torch.tensor(1024)) as mock_tensor: output = dist.sync_random_seed() assert output == 1024 mock_tensor.assert_called() @@ -580,8 +577,10 @@ def test_collect_results(self): # broadcast tmpdir to all ranks to make it consistent object_list = [tmpdir] dist.broadcast_object_list(object_list) - output = dist.collect_results( - data, size, device='cpu', tmpdir=object_list[0]) + output = dist.collect_results(data, + size, + device='cpu', + tmpdir=object_list[0]) if dist.get_rank() == 0: self.assertEqual(output, expected) else: @@ -646,13 +645,13 @@ def test_all_reduce_params(self): dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op) if reduce_op == 'sum': - expected = ( - torch.tensor([2, 4], dtype=tensor_type).to(device_type) - for _ in range(100)) + expected = (torch.tensor([2, 4], + dtype=tensor_type).to(device_type) + for _ in range(100)) else: - expected = ( - torch.tensor([1, 2], dtype=tensor_type).to(device_type) - for _ in range(100)) + expected = (torch.tensor([1, 2], + dtype=tensor_type).to(device_type) + for _ in range(100)) for item1, item2 in zip(data_gen, expected): self.assertTrue(torch.allclose(item1, item2)) diff --git a/tests/test_dist/test_utils.py b/tests/test_dist/test_utils.py index d9af72f964..4cffb385fa 100644 --- a/tests/test_dist/test_utils.py +++ b/tests/test_dist/test_utils.py @@ -101,8 +101,8 @@ def test_get_data_device(self): 'data should be a Tensor, sequence of tensor or dict'): dist.get_data_device('123') - @unittest.skipIf( - torch.cuda.device_count() == 0, reason='at lest need 1 gpu to test') + @unittest.skipIf(torch.cuda.device_count() == 0, + reason='at lest need 1 gpu to test') def test_cast_data_device(self): expected_device = torch.device('cuda', torch.cuda.current_device()) # data is a Tensor @@ -181,8 +181,8 @@ def test_cast_data_device(self): self.assertEqual(output['key1'].device, expected_device) self.assertTrue(torch.allclose(output['key1'].cpu(), out['key1'])) self.assertEqual(output['key2'][0].device, expected_device) - self.assertTrue( - torch.allclose(output['key2'][0].cpu(), out['key2'][0])) + self.assertTrue(torch.allclose(output['key2'][0].cpu(), + out['key2'][0])) # data is not a valid type with self.assertRaisesRegex( @@ -218,8 +218,9 @@ def _init_dist_env(self, rank, world_size): os.environ['MASTER_PORT'] = '29505' os.environ['RANK'] = str(rank) - torch_dist.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='gloo', + rank=rank, + world_size=world_size) dist.init_local_group(0, world_size) def setUp(self): @@ -247,8 +248,8 @@ def test_local_size(self): def test_local_rank(self): self._init_dist_env(self.rank, self.world_size) - self.assertEqual( - torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank()) + self.assertEqual(torch_dist.get_rank(dist.get_local_group()), + dist.get_local_rank()) def test_get_dist_info(self): self._init_dist_env(self.rank, self.world_size) @@ -337,8 +338,8 @@ def test_get_comm_device(self): assert dist.get_comm_device(group) == torch.device('cpu') -@unittest.skipIf( - torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl') +@unittest.skipIf(torch.cuda.device_count() < 2, + reason='need 2 gpu to test nccl') class TestUtilsWithNCCLBackend(MultiProcessTestCase): def _init_dist_env(self, rank, world_size): @@ -349,8 +350,9 @@ def _init_dist_env(self, rank, world_size): num_gpus = torch.cuda.device_count() torch.cuda.set_device(rank % num_gpus) - torch_dist.init_process_group( - backend='nccl', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='nccl', + rank=rank, + world_size=world_size) dist.init_local_group(0, world_size) def setUp(self): @@ -378,8 +380,8 @@ def test_local_size(self): def test_local_rank(self): self._init_dist_env(self.rank, self.world_size) - self.assertEqual( - torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank()) + self.assertEqual(torch_dist.get_rank(dist.get_local_group()), + dist.get_local_rank()) def test_get_dist_info(self): self._init_dist_env(self.rank, self.world_size) @@ -579,8 +581,8 @@ def test_cast_data_device(self): self.assertEqual(output['key1'].device, expected_device) self.assertTrue(torch.allclose(output['key1'].cpu(), out['key1'])) self.assertEqual(output['key2'][0].device, expected_device) - self.assertTrue( - torch.allclose(output['key2'][0].cpu(), out['key2'][0])) + self.assertTrue(torch.allclose(output['key2'][0].cpu(), + out['key2'][0])) # data is not a valid type with self.assertRaisesRegex( diff --git a/tests/test_evaluator/test_evaluator.py b/tests/test_evaluator/test_evaluator.py index 58b7e1e6fe..c9f4100b40 100644 --- a/tests/test_evaluator/test_evaluator.py +++ b/tests/test_evaluator/test_evaluator.py @@ -100,8 +100,10 @@ def test_single_metric(self): size = 10 batch_size = 4 - for data_samples, outputs in generate_test_results( - size, batch_size, pred=1, label=1): + for data_samples, outputs in generate_test_results(size, + batch_size, + pred=1, + label=1): evaluator.process(data_samples=outputs, data_batch=data_samples) metrics = evaluator.evaluate(size=size) @@ -126,8 +128,10 @@ def test_composed_metrics(self): size = 10 batch_size = 4 - for data_samples, outputs in generate_test_results( - size, batch_size, pred=1, label=1): + for data_samples, outputs in generate_test_results(size, + batch_size, + pred=1, + label=1): evaluator.process(data_samples=outputs, data_batch=data_samples) metrics = evaluator.evaluate(size=size) @@ -147,8 +151,10 @@ def test_ambiguous_metric(self): size = 10 batch_size = 4 - for data_samples, outputs in generate_test_results( - size, batch_size, pred=1, label=1): + for data_samples, outputs in generate_test_results(size, + batch_size, + pred=1, + label=1): evaluator.process(data_samples=outputs, data_batch=data_samples) with self.assertRaisesRegex( @@ -175,10 +181,9 @@ def test_dataset_meta(self): def test_collect_device(self): cfg = [ dict(type='ToyMetric', collect_device='cpu'), - dict( - type='ToyMetric', - collect_device='gpu', - dummy_metrics=dict(mAP=0.0)) + dict(type='ToyMetric', + collect_device='gpu', + dummy_metrics=dict(mAP=0.0)) ] evaluator = Evaluator(cfg) @@ -262,16 +267,15 @@ def test_evaluate_cast_cpu(self): size = 10 all_data = [ - dict( - inputs=torch.zeros((3, 10, 10), device='cuda'), - data_sample=BaseDataElement( - label=torch.ones((1, ), device='cuda'))) + dict(inputs=torch.zeros((3, 10, 10), device='cuda'), + data_sample=BaseDataElement( + label=torch.ones((1, ), device='cuda'))) for _ in range(size) ] all_predictions = [ - BaseDataElement( - pred=torch.zeros((1, ), device='cuda'), - label=torch.ones((1, ), device='cuda')) for _ in range(size) + BaseDataElement(pred=torch.zeros((1, ), device='cuda'), + label=torch.ones((1, ), device='cuda')) + for _ in range(size) ] for data, pred in zip(all_data, all_predictions): evaluator.process([pred], [data]) diff --git a/tests/test_evaluator/test_metric.py b/tests/test_evaluator/test_metric.py index 055bd73ca1..d1a5608ef4 100644 --- a/tests/test_evaluator/test_metric.py +++ b/tests/test_evaluator/test_metric.py @@ -19,10 +19,9 @@ def test_init(self): # collect_dir could only be configured when collect_device='cpu' with self.assertRaises(ValueError): - DumpResults( - out_file_path='./results.json', - collect_device='gpu', - collect_dir='./tmp') + DumpResults(out_file_path='./results.json', + collect_device='gpu', + collect_dir='./tmp') def test_process(self): metric = DumpResults(out_file_path='./results.pkl') diff --git a/tests/test_fileio/test_backends/test_backend_utils.py b/tests/test_fileio/test_backends/test_backend_utils.py index 7903f5574e..9ed38ff701 100644 --- a/tests/test_fileio/test_backends/test_backend_utils.py +++ b/tests/test_fileio/test_backends/test_backend_utils.py @@ -57,8 +57,8 @@ def get(self, filepath): def get_text(self, filepath): return filepath - with pytest.raises( - TypeError, match='not a subclass of BaseStorageBackend'): + with pytest.raises(TypeError, + match='not a subclass of BaseStorageBackend'): register_backend('example3', ExampleBackend2) # 4. test `force` parameter @@ -85,8 +85,9 @@ def get_text(self, filepath): assert 'prefix1' in prefix_to_backends # 5.2 prefixes is a list (tuple) of strings - register_backend( - 'example4', ExampleBackend3, prefixes=['prefix2', 'prefix3']) + register_backend('example4', + ExampleBackend3, + prefixes=['prefix2', 'prefix3']) assert 'example4' in backends assert 'prefix2' in prefix_to_backends assert 'prefix3' in prefix_to_backends @@ -108,7 +109,9 @@ def get(self, filepath): def get_text(self, filepath): return filepath - register_backend( - 'example6', ExampleBackend4, prefixes='prefix2', force=True) + register_backend('example6', + ExampleBackend4, + prefixes='prefix2', + force=True) assert 'example6' in backends assert 'prefix2' in prefix_to_backends diff --git a/tests/test_fileio/test_backends/test_local_backend.py b/tests/test_fileio/test_backends/test_local_backend.py index 427ebf789a..71b2423504 100644 --- a/tests/test_fileio/test_backends/test_local_backend.py +++ b/tests/test_fileio/test_backends/test_local_backend.py @@ -146,15 +146,15 @@ def test_isfile(self, path_type): @parameterized.expand([[Path], [str]]) def test_join_path(self, path_type): backend = LocalBackend() - filepath = backend.join_path( - path_type(self.test_data_dir), path_type('file')) + filepath = backend.join_path(path_type(self.test_data_dir), + path_type('file')) expected = osp.join(path_type(self.test_data_dir), path_type('file')) self.assertEqual(filepath, expected) - filepath = backend.join_path( - path_type(self.test_data_dir), path_type('dir'), path_type('file')) - expected = osp.join( - path_type(self.test_data_dir), path_type('dir'), path_type('file')) + filepath = backend.join_path(path_type(self.test_data_dir), + path_type('dir'), path_type('file')) + expected = osp.join(path_type(self.test_data_dir), path_type('dir'), + path_type('file')) self.assertEqual(filepath, expected) @parameterized.expand([[Path], [str]]) @@ -170,17 +170,15 @@ def test_copyfile(self, path_type): src = Path(tmp_dir) / 'test.txt' backend.put_text('disk', src) dst = Path(tmp_dir) / 'test.txt.bak' - self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - path_type(dst)) + self.assertEqual(backend.copyfile(path_type(src), path_type(dst)), + path_type(dst)) self.assertEqual(backend.get_text(dst), 'disk') # dst is a directory dst = Path(tmp_dir) / 'dir' dst.mkdir() - self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - backend.join_path(path_type(dst), 'test.txt')) + self.assertEqual(backend.copyfile(path_type(src), path_type(dst)), + backend.join_path(path_type(dst), 'test.txt')) self.assertEqual( backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') @@ -195,17 +193,16 @@ def test_copytree(self, path_type): # src and dst are Path objects src = Path(tmp_dir) / 'dir1' dst = Path(tmp_dir) / 'dir100' - self.assertEqual( - backend.copytree(path_type(src), path_type(dst)), - path_type(dst)) + self.assertEqual(backend.copytree(path_type(src), path_type(dst)), + path_type(dst)) self.assertTrue(backend.isdir(dst)) self.assertTrue(backend.isfile(dst / 'text3.txt')) self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') # dst should not exist with self.assertRaises(FileExistsError): - backend.copytree( - path_type(src), path_type(Path(tmp_dir) / 'dir2')) + backend.copytree(path_type(src), + path_type(Path(tmp_dir) / 'dir2')) @parameterized.expand([[Path], [str]]) def test_copyfile_from_local(self, path_type): @@ -214,16 +211,14 @@ def test_copyfile_from_local(self, path_type): src = Path(tmp_dir) / 'test.txt' backend.put_text('disk', src) dst = Path(tmp_dir) / 'test.txt.bak' - self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - path_type(dst)) + self.assertEqual(backend.copyfile(path_type(src), path_type(dst)), + path_type(dst)) self.assertEqual(backend.get_text(dst), 'disk') dst = Path(tmp_dir) / 'dir' dst.mkdir() - self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - backend.join_path(path_type(dst), 'test.txt')) + self.assertEqual(backend.copyfile(path_type(src), path_type(dst)), + backend.join_path(path_type(dst), 'test.txt')) self.assertEqual( backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') @@ -238,17 +233,16 @@ def test_copytree_from_local(self, path_type): # src and dst are Path objects src = Path(tmp_dir) / 'dir1' dst = Path(tmp_dir) / 'dir100' - self.assertEqual( - backend.copytree(path_type(src), path_type(dst)), - path_type(dst)) + self.assertEqual(backend.copytree(path_type(src), path_type(dst)), + path_type(dst)) self.assertTrue(backend.isdir(dst)) self.assertTrue(backend.isfile(dst / 'text3.txt')) self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') # dst should not exist with self.assertRaises(FileExistsError): - backend.copytree( - path_type(src), path_type(Path(tmp_dir) / 'dir2')) + backend.copytree(path_type(src), + path_type(Path(tmp_dir) / 'dir2')) @parameterized.expand([[Path], [str]]) def test_copyfile_to_local(self, path_type): @@ -257,16 +251,14 @@ def test_copyfile_to_local(self, path_type): src = Path(tmp_dir) / 'test.txt' backend.put_text('disk', src) dst = Path(tmp_dir) / 'test.txt.bak' - self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - path_type(dst)) + self.assertEqual(backend.copyfile(path_type(src), path_type(dst)), + path_type(dst)) self.assertEqual(backend.get_text(dst), 'disk') dst = Path(tmp_dir) / 'dir' dst.mkdir() - self.assertEqual( - backend.copyfile(path_type(src), path_type(dst)), - backend.join_path(path_type(dst), 'test.txt')) + self.assertEqual(backend.copyfile(path_type(src), path_type(dst)), + backend.join_path(path_type(dst), 'test.txt')) self.assertEqual( backend.get_text(backend.join_path(dst, 'test.txt')), 'disk') @@ -281,17 +273,16 @@ def test_copytree_to_local(self, path_type): # src and dst are Path objects src = Path(tmp_dir) / 'dir1' dst = Path(tmp_dir) / 'dir100' - self.assertEqual( - backend.copytree(path_type(src), path_type(dst)), - path_type(dst)) + self.assertEqual(backend.copytree(path_type(src), path_type(dst)), + path_type(dst)) self.assertTrue(backend.isdir(dst)) self.assertTrue(backend.isfile(dst / 'text3.txt')) self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3') # dst should not exist with self.assertRaises(FileExistsError): - backend.copytree( - path_type(src), path_type(Path(tmp_dir) / 'dir2')) + backend.copytree(path_type(src), + path_type(Path(tmp_dir) / 'dir2')) @parameterized.expand([[Path], [str]]) def test_remove(self, path_type): @@ -361,8 +352,8 @@ def symlink(src, dst): with patch.object(os, 'symlink', side_effect=symlink): src = Path(tmp_dir) / 'test.txt' dst = Path(tmp_dir) / 'test_link1.txt' - res = backend.copy_if_symlink_fails( - path_type(src), path_type(dst)) + res = backend.copy_if_symlink_fails(path_type(src), + path_type(dst)) self.assertFalse(res) self.assertFalse(osp.islink(dst)) self.assertTrue(backend.exists(dst)) @@ -371,8 +362,8 @@ def symlink(src, dst): with patch.object(os, 'symlink', side_effect=symlink): src = Path(tmp_dir) / 'dir' dst = Path(tmp_dir) / 'dir_link1' - res = backend.copy_if_symlink_fails( - path_type(src), path_type(dst)) + res = backend.copy_if_symlink_fails(path_type(src), + path_type(dst)) self.assertFalse(res) self.assertFalse(osp.islink(dst)) self.assertTrue(backend.exists(dst)) @@ -382,15 +373,14 @@ def test_list_dir_or_file(self, path_type): backend = LocalBackend() with build_temporary_directory() as tmp_dir: # list directories and files - self.assertEqual( - set(backend.list_dir_or_file(path_type(tmp_dir))), - {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) + self.assertEqual(set(backend.list_dir_or_file(path_type(tmp_dir))), + {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) # list directories and files recursively self.assertEqual( set( - backend.list_dir_or_file( - path_type(tmp_dir), recursive=True)), + backend.list_dir_or_file(path_type(tmp_dir), + recursive=True)), { 'dir1', osp.join('dir1', 'text3.txt'), 'dir2', @@ -402,35 +392,38 @@ def test_list_dir_or_file(self, path_type): # only list directories self.assertEqual( set( - backend.list_dir_or_file( - path_type(tmp_dir), list_file=False)), + backend.list_dir_or_file(path_type(tmp_dir), + list_file=False)), {'dir1', 'dir2'}) with self.assertRaisesRegex( TypeError, '`suffix` should be None when `list_dir` is True'): - backend.list_dir_or_file( - path_type(tmp_dir), list_file=False, suffix='.txt') + backend.list_dir_or_file(path_type(tmp_dir), + list_file=False, + suffix='.txt') # only list directories recursively self.assertEqual( set( - backend.list_dir_or_file( - path_type(tmp_dir), list_file=False, recursive=True)), + backend.list_dir_or_file(path_type(tmp_dir), + list_file=False, + recursive=True)), {'dir1', 'dir2', osp.join('dir2', 'dir3')}) # only list files self.assertEqual( set( - backend.list_dir_or_file( - path_type(tmp_dir), list_dir=False)), + backend.list_dir_or_file(path_type(tmp_dir), + list_dir=False)), {'text1.txt', 'text2.txt'}) # only list files recursively self.assertEqual( set( - backend.list_dir_or_file( - path_type(tmp_dir), list_dir=False, recursive=True)), + backend.list_dir_or_file(path_type(tmp_dir), + list_dir=False, + recursive=True)), { osp.join('dir1', 'text3.txt'), osp.join('dir2', 'dir3', 'text4.txt'), @@ -440,45 +433,44 @@ def test_list_dir_or_file(self, path_type): # only list files ending with suffix self.assertEqual( set( - backend.list_dir_or_file( - path_type(tmp_dir), list_dir=False, suffix='.txt')), + backend.list_dir_or_file(path_type(tmp_dir), + list_dir=False, + suffix='.txt')), {'text1.txt', 'text2.txt'}) self.assertEqual( set( - backend.list_dir_or_file( - path_type(tmp_dir), - list_dir=False, - suffix=('.txt', '.jpg'))), {'text1.txt', 'text2.txt'}) + backend.list_dir_or_file(path_type(tmp_dir), + list_dir=False, + suffix=('.txt', '.jpg'))), + {'text1.txt', 'text2.txt'}) with self.assertRaisesRegex( TypeError, '`suffix` must be a string or tuple of strings'): - backend.list_dir_or_file( - path_type(tmp_dir), - list_dir=False, - suffix=['.txt', '.jpg']) + backend.list_dir_or_file(path_type(tmp_dir), + list_dir=False, + suffix=['.txt', '.jpg']) # only list files ending with suffix recursively self.assertEqual( set( - backend.list_dir_or_file( - path_type(tmp_dir), - list_dir=False, - suffix='.txt', - recursive=True)), { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', - 'text2.txt' - }) + backend.list_dir_or_file(path_type(tmp_dir), + list_dir=False, + suffix='.txt', + recursive=True)), + { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', + 'text2.txt' + }) # only list files ending with suffix self.assertEqual( set( - backend.list_dir_or_file( - path_type(tmp_dir), - list_dir=False, - suffix=('.txt', '.jpg'), - recursive=True)), + backend.list_dir_or_file(path_type(tmp_dir), + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)), { osp.join('dir1', 'text3.txt'), osp.join('dir2', 'dir3', 'text4.txt'), diff --git a/tests/test_fileio/test_backends/test_petrel_backend.py b/tests/test_fileio/test_backends/test_petrel_backend.py index 6f379c3f23..3af60276b5 100644 --- a/tests/test_fileio/test_backends/test_petrel_backend.py +++ b/tests/test_fileio/test_backends/test_petrel_backend.py @@ -124,13 +124,13 @@ def test_name(self): def test_map_path(self): backend = PetrelBackend(path_mapping=None) - self.assertEqual( - backend._map_path(self.petrel_path), self.petrel_path) + self.assertEqual(backend._map_path(self.petrel_path), + self.petrel_path) backend = PetrelBackend( path_mapping={'data/': 'petrel://user/data/'}) - self.assertEqual( - backend._map_path('data/test.jpg'), self.petrel_path) + self.assertEqual(backend._map_path('data/test.jpg'), + self.petrel_path) def test_format_path(self): backend = PetrelBackend() @@ -140,37 +140,31 @@ def test_format_path(self): def test_replace_prefix(self): backend = PetrelBackend() - self.assertEqual( - backend._replace_prefix(self.petrel_path), self.expected_path) + self.assertEqual(backend._replace_prefix(self.petrel_path), + self.expected_path) def test_join_path(self): backend = PetrelBackend() - self.assertEqual( - backend.join_path(self.petrel_dir, 'file'), - f'{self.petrel_dir}/file') - self.assertEqual( - backend.join_path(f'{self.petrel_dir}/', 'file'), - f'{self.petrel_dir}/file') - self.assertEqual( - backend.join_path(f'{self.petrel_dir}/', '/file'), - f'{self.petrel_dir}/file') - self.assertEqual( - backend.join_path(self.petrel_dir, 'dir', 'file'), - f'{self.petrel_dir}/dir/file') + self.assertEqual(backend.join_path(self.petrel_dir, 'file'), + f'{self.petrel_dir}/file') + self.assertEqual(backend.join_path(f'{self.petrel_dir}/', 'file'), + f'{self.petrel_dir}/file') + self.assertEqual(backend.join_path(f'{self.petrel_dir}/', '/file'), + f'{self.petrel_dir}/file') + self.assertEqual(backend.join_path(self.petrel_dir, 'dir', 'file'), + f'{self.petrel_dir}/dir/file') def test_get(self): backend = PetrelBackend() - with patch.object( - backend._client, 'Get', - return_value=b'petrel') as patched_get: + with patch.object(backend._client, 'Get', + return_value=b'petrel') as patched_get: self.assertEqual(backend.get(self.petrel_path), b'petrel') patched_get.assert_called_once_with(self.expected_path) def test_get_text(self): backend = PetrelBackend() - with patch.object( - backend._client, 'Get', - return_value=b'petrel') as patched_get: + with patch.object(backend._client, 'Get', + return_value=b'petrel') as patched_get: self.assertEqual(backend.get_text(self.petrel_path), 'petrel') patched_get.assert_called_once_with(self.expected_path) @@ -201,9 +195,8 @@ def test_exists(self): with self.assertRaises(NotImplementedError): backend.exists(self.petrel_path) - with patch.object( - backend._client, 'contains', - return_value=True) as patched_contains: + with patch.object(backend._client, 'contains', + return_value=True) as patched_contains: self.assertTrue(backend.exists(self.petrel_path)) patched_contains.assert_called_once_with(self.expected_path) @@ -216,9 +209,8 @@ def test_isdir(self): with self.assertRaises(NotImplementedError): backend.isdir(self.petrel_path) - with patch.object( - backend._client, 'isdir', - return_value=True) as patched_contains: + with patch.object(backend._client, 'isdir', + return_value=True) as patched_contains: self.assertTrue(backend.isdir(self.petrel_path)) patched_contains.assert_called_once_with(self.expected_path) @@ -231,9 +223,8 @@ def test_isfile(self): with self.assertRaises(NotImplementedError): backend.isfile(self.petrel_path) - with patch.object( - backend._client, 'contains', - return_value=True) as patched_contains: + with patch.object(backend._client, 'contains', + return_value=True) as patched_contains: self.assertTrue(backend.isfile(self.petrel_path)) patched_contains.assert_called_once_with(self.expected_path) @@ -335,8 +326,8 @@ def test_copyfile_from_local(self): src = self.img_path dst = f'{self.petrel_dir}/dir' expected_dst = f'{self.expected_dir}/dir/color.jpg' - self.assertEqual( - backend.copyfile_from_local(src, dst), f'{dst}/color.jpg') + self.assertEqual(backend.copyfile_from_local(src, dst), + f'{dst}/color.jpg') patched_put.assert_called_once_with(expected_dst, src.open('rb').read()) patched_isdir.assert_called_once_with( @@ -380,8 +371,8 @@ def test_copyfile_to_local(self): src = self.petrel_path dst = Path(tmp_dir) / 'dir' dst.mkdir() - self.assertEqual( - backend.copyfile_to_local(src, dst), dst / 'test.jpg') + self.assertEqual(backend.copyfile_to_local(src, dst), + dst / 'test.jpg') patched_get.assert_called_once_with(self.expected_path) self.assertEqual((dst / 'test.jpg').open('rb').read(), b'petrel') @@ -468,9 +459,8 @@ def test_list_dir_or_file(self): with build_temporary_directory() as tmp_dir: # list directories and files - self.assertEqual( - set(backend.list_dir_or_file(tmp_dir)), - {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) + self.assertEqual(set(backend.list_dir_or_file(tmp_dir)), + {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) # list directories and files recursively self.assertEqual( @@ -489,14 +479,16 @@ def test_list_dir_or_file(self): TypeError, '`list_dir` should be False when `suffix` is not None' ): - backend.list_dir_or_file( - tmp_dir, list_file=False, suffix='.txt') + backend.list_dir_or_file(tmp_dir, + list_file=False, + suffix='.txt') # only list directories recursively self.assertEqual( set( - backend.list_dir_or_file( - tmp_dir, list_file=False, recursive=True)), + backend.list_dir_or_file(tmp_dir, + list_file=False, + recursive=True)), {'dir1', 'dir2', '/'.join(('dir2', 'dir3'))}) # only list files @@ -507,8 +499,9 @@ def test_list_dir_or_file(self): # only list files recursively self.assertEqual( set( - backend.list_dir_or_file( - tmp_dir, list_dir=False, recursive=True)), + backend.list_dir_or_file(tmp_dir, + list_dir=False, + recursive=True)), { '/'.join(('dir1', 'text3.txt')), '/'.join( ('dir2', 'dir3', 'text4.txt')), '/'.join( @@ -518,41 +511,43 @@ def test_list_dir_or_file(self): # only list files ending with suffix self.assertEqual( set( - backend.list_dir_or_file( - tmp_dir, list_dir=False, suffix='.txt')), + backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix='.txt')), {'text1.txt', 'text2.txt'}) self.assertEqual( set( - backend.list_dir_or_file( - tmp_dir, list_dir=False, suffix=('.txt', '.jpg'))), + backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'))), {'text1.txt', 'text2.txt'}) with self.assertRaisesRegex( TypeError, '`suffix` must be a string or tuple of strings'): - backend.list_dir_or_file( - tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=['.txt', '.jpg']) # only list files ending with suffix recursively self.assertEqual( set( - backend.list_dir_or_file( - tmp_dir, - list_dir=False, - suffix='.txt', - recursive=True)), { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), - 'text1.txt', 'text2.txt' - }) + backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix='.txt', + recursive=True)), + { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), 'text1.txt', + 'text2.txt' + }) # only list files ending with suffix self.assertEqual( set( - backend.list_dir_or_file( - tmp_dir, - list_dir=False, - suffix=('.txt', '.jpg'), - recursive=True)), + backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)), { '/'.join(('dir1', 'text3.txt')), '/'.join( ('dir2', 'dir3', 'text4.txt')), '/'.join( @@ -673,9 +668,8 @@ def test_copytree(self): dst = f'{self.petrel_dir}/dir3' self.assertFalse(backend.exists(dst)) self.assertEqual(backend.copytree(src, dst), dst) - self.assertEqual( - list(backend.list_dir_or_file(src)), - list(backend.list_dir_or_file(dst))) + self.assertEqual(list(backend.list_dir_or_file(src)), + list(backend.list_dir_or_file(dst))) # dst should not exist with self.assertRaises(FileExistsError): @@ -696,8 +690,8 @@ def test_copyfile_from_local(self): dst = f'{self.petrel_dir}/dir1' expected_dst = f'{self.petrel_dir}/dir1/color.jpg' self.assertFalse(backend.exists(expected_dst)) - self.assertEqual( - backend.copyfile_from_local(src, dst), expected_dst) + self.assertEqual(backend.copyfile_from_local(src, dst), + expected_dst) self.assertTrue(backend.isfile(expected_dst)) def test_copytree_from_local(self): @@ -705,8 +699,8 @@ def test_copytree_from_local(self): backend.rmtree(self.petrel_dir) with build_temporary_directory() as tmp_dir: backend.copytree_from_local(tmp_dir, self.petrel_dir) - files = backend.list_dir_or_file( - self.petrel_dir, recursive=True) + files = backend.list_dir_or_file(self.petrel_dir, + recursive=True) self.assertEqual(len(list(files)), 8) def test_copyfile_to_local(self): @@ -721,8 +715,8 @@ def test_copyfile_to_local(self): # dst is a directory dst = Path(tmp_dir) / 'dir' dst.mkdir() - self.assertEqual( - backend.copyfile_to_local(src, dst), dst / 'img.jpg') + self.assertEqual(backend.copyfile_to_local(src, dst), + dst / 'img.jpg') self.assertEqual((dst / 'img.jpg').open('rb').read(), b'img') def test_copytree_to_local(self): @@ -767,9 +761,8 @@ def test_list_dir_or_file(self): backend = PetrelBackend() # list directories and files - self.assertEqual( - set(backend.list_dir_or_file(self.petrel_dir)), - {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) + self.assertEqual(set(backend.list_dir_or_file(self.petrel_dir)), + {'dir1', 'dir2', 'text1.txt', 'text2.txt'}) # list directories and files recursively self.assertEqual( @@ -783,21 +776,22 @@ def test_list_dir_or_file(self): # only list directories self.assertEqual( - set( - backend.list_dir_or_file(self.petrel_dir, + set(backend.list_dir_or_file(self.petrel_dir, list_file=False)), {'dir1', 'dir2'}) with self.assertRaisesRegex( TypeError, '`list_dir` should be False when `suffix` is not None'): - backend.list_dir_or_file( - self.petrel_dir, list_file=False, suffix='.txt') + backend.list_dir_or_file(self.petrel_dir, + list_file=False, + suffix='.txt') # only list directories recursively self.assertEqual( set( - backend.list_dir_or_file( - self.petrel_dir, list_file=False, recursive=True)), + backend.list_dir_or_file(self.petrel_dir, + list_file=False, + recursive=True)), {'dir1', 'dir2', '/'.join(('dir2', 'dir3'))}) # only list files @@ -808,8 +802,9 @@ def test_list_dir_or_file(self): # only list files recursively self.assertEqual( set( - backend.list_dir_or_file( - self.petrel_dir, list_dir=False, recursive=True)), + backend.list_dir_or_file(self.petrel_dir, + list_dir=False, + recursive=True)), { '/'.join(('dir1', 'text3.txt')), '/'.join( ('dir2', 'dir3', 'text4.txt')), '/'.join( @@ -819,42 +814,43 @@ def test_list_dir_or_file(self): # only list files ending with suffix self.assertEqual( set( - backend.list_dir_or_file( - self.petrel_dir, list_dir=False, suffix='.txt')), + backend.list_dir_or_file(self.petrel_dir, + list_dir=False, + suffix='.txt')), {'text1.txt', 'text2.txt'}) self.assertEqual( set( - backend.list_dir_or_file( - self.petrel_dir, - list_dir=False, - suffix=('.txt', '.jpg'))), {'text1.txt', 'text2.txt'}) + backend.list_dir_or_file(self.petrel_dir, + list_dir=False, + suffix=('.txt', '.jpg'))), + {'text1.txt', 'text2.txt'}) with self.assertRaisesRegex( TypeError, '`suffix` must be a string or tuple of strings'): - backend.list_dir_or_file( - self.petrel_dir, list_dir=False, suffix=['.txt', '.jpg']) + backend.list_dir_or_file(self.petrel_dir, + list_dir=False, + suffix=['.txt', '.jpg']) # only list files ending with suffix recursively self.assertEqual( set( - backend.list_dir_or_file( - self.petrel_dir, - list_dir=False, - suffix='.txt', - recursive=True)), { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), 'text1.txt', - 'text2.txt' - }) + backend.list_dir_or_file(self.petrel_dir, + list_dir=False, + suffix='.txt', + recursive=True)), + { + '/'.join(('dir1', 'text3.txt')), '/'.join( + ('dir2', 'dir3', 'text4.txt')), 'text1.txt', + 'text2.txt' + }) # only list files ending with suffix self.assertEqual( set( - backend.list_dir_or_file( - self.petrel_dir, - list_dir=False, - suffix=('.txt', '.jpg'), - recursive=True)), + backend.list_dir_or_file(self.petrel_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)), { '/'.join(('dir1', 'text3.txt')), '/'.join( ('dir2', 'dir3', 'text4.txt')), '/'.join( diff --git a/tests/test_fileio/test_fileclient.py b/tests/test_fileio/test_fileclient.py index 345832a026..72eea97b88 100644 --- a/tests/test_fileio/test_fileclient.py +++ b/tests/test_fileio/test_fileclient.py @@ -16,8 +16,6 @@ from mmengine.utils import has_method sys.modules['ceph'] = MagicMock() -sys.modules['petrel_client'] = MagicMock() -sys.modules['petrel_client.client'] = MagicMock() sys.modules['mc'] = MagicMock() @@ -226,23 +224,24 @@ def test_disk_backend(self): osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' } # 3. only list directories - assert set( - disk_backend.list_dir_or_file( - tmp_dir, list_file=False)) == {'dir1', 'dir2'} + assert set(disk_backend.list_dir_or_file( + tmp_dir, list_file=False)) == {'dir1', 'dir2'} with pytest.raises( TypeError, match='`suffix` should be None when `list_dir` is True'): # Exception is raised among the `list_dir_or_file` of client, # so we need to invode the client to trigger the exception - disk_backend.client.list_dir_or_file( - tmp_dir, list_file=False, suffix='.txt') + disk_backend.client.list_dir_or_file(tmp_dir, + list_file=False, + suffix='.txt') # 4. only list directories recursively assert set( - disk_backend.list_dir_or_file( - tmp_dir, list_file=False, recursive=True)) == { - 'dir1', 'dir2', - osp.join('dir2', 'dir3') - } + disk_backend.list_dir_or_file(tmp_dir, + list_file=False, + recursive=True)) == { + 'dir1', 'dir2', + osp.join('dir2', 'dir3') + } # 5. only list files assert set(disk_backend.list_dir_or_file( tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'} @@ -256,18 +255,23 @@ def test_disk_backend(self): } # 7. only list files ending with suffix assert set( - disk_backend.list_dir_or_file( - tmp_dir, list_dir=False, - suffix='.txt')) == {'text1.txt', 'text2.txt'} + disk_backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix='.txt')) == { + 'text1.txt', 'text2.txt' + } assert set( - disk_backend.list_dir_or_file( - tmp_dir, list_dir=False, - suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'} + disk_backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'))) == { + 'text1.txt', 'text2.txt' + } with pytest.raises( TypeError, match='`suffix` must be a string or tuple of strings'): - disk_backend.client.list_dir_or_file( - tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + disk_backend.client.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=['.txt', '.jpg']) # 8. only list files ending with suffix recursively assert set( disk_backend.list_dir_or_file( @@ -289,7 +293,9 @@ def test_disk_backend(self): osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' } - @patch('petrel_client.client.Client', MockPetrelClient) + @patch.dict( + sys.modules, + {'petrel_client': MagicMock(**{'client.Client': MockPetrelClient})}) @pytest.mark.parametrize('backend,prefix', [('petrel', None), (None, 's3')]) def test_petrel_backend(self, backend, prefix): @@ -326,16 +332,16 @@ def test_petrel_backend(self, backend, prefix): == petrel_path # test `get` - with patch.object( - petrel_backend.client._client, 'Get', - return_value=b'petrel') as mock_get: + with patch.object(petrel_backend.client._client, + 'Get', + return_value=b'petrel') as mock_get: assert petrel_backend.get(petrel_path) == b'petrel' mock_get.assert_called_once_with(petrel_path) # test `get_text` - with patch.object( - petrel_backend.client._client, 'Get', - return_value=b'petrel') as mock_get: + with patch.object(petrel_backend.client._client, + 'Get', + return_value=b'petrel') as mock_get: assert petrel_backend.get_text(petrel_path) == 'petrel' mock_get.assert_called_once_with(petrel_path) @@ -381,9 +387,9 @@ def test_petrel_backend(self, backend, prefix): with pytest.raises(NotImplementedError): petrel_backend.exists(petrel_path) - with patch.object( - petrel_backend.client._client, 'contains', - return_value=True) as mock_contains: + with patch.object(petrel_backend.client._client, + 'contains', + return_value=True) as mock_contains: assert petrel_backend.exists(petrel_path) mock_contains.assert_called_once_with(petrel_path) @@ -394,9 +400,9 @@ def test_petrel_backend(self, backend, prefix): with pytest.raises(NotImplementedError): petrel_backend.isdir(petrel_path) - with patch.object( - petrel_backend.client._client, 'isdir', - return_value=True) as mock_isdir: + with patch.object(petrel_backend.client._client, + 'isdir', + return_value=True) as mock_isdir: assert petrel_backend.isdir(petrel_dir) mock_isdir.assert_called_once_with(petrel_dir) @@ -408,9 +414,9 @@ def test_petrel_backend(self, backend, prefix): with pytest.raises(NotImplementedError): petrel_backend.isfile(petrel_path) - with patch.object( - petrel_backend.client._client, 'contains', - return_value=True) as mock_contains: + with patch.object(petrel_backend.client._client, + 'contains', + return_value=True) as mock_contains: assert petrel_backend.isfile(petrel_path) mock_contains.assert_called_once_with(petrel_path) @@ -447,8 +453,8 @@ def test_petrel_backend(self, backend, prefix): 'dir1', 'dir2', 'text1.txt', 'text2.txt' } # 2. list directories and files recursively - assert set( - petrel_backend.list_dir_or_file(tmp_dir, recursive=True)) == { + assert set(petrel_backend.list_dir_or_file( + tmp_dir, recursive=True)) == { 'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', '/'.join( ('dir2', 'dir3')), '/'.join( ('dir2', 'dir3', 'text4.txt')), '/'.join( @@ -464,18 +470,20 @@ def test_petrel_backend(self, backend, prefix): 'None')): # Exception is raised among the `list_dir_or_file` of client, # so we need to invode the client to trigger the exception - petrel_backend.client.list_dir_or_file( - tmp_dir, list_file=False, suffix='.txt') + petrel_backend.client.list_dir_or_file(tmp_dir, + list_file=False, + suffix='.txt') # 4. only list directories recursively assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_file=False, recursive=True)) == { - 'dir1', 'dir2', '/'.join(('dir2', 'dir3')) - } + petrel_backend.list_dir_or_file(tmp_dir, + list_file=False, + recursive=True)) == { + 'dir1', 'dir2', '/'.join( + ('dir2', 'dir3')) + } # 5. only list files - assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'} + assert set(petrel_backend.list_dir_or_file( + tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'} # 6. only list files recursively assert set( petrel_backend.list_dir_or_file( @@ -486,27 +494,35 @@ def test_petrel_backend(self, backend, prefix): } # 7. only list files ending with suffix assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False, - suffix='.txt')) == {'text1.txt', 'text2.txt'} + petrel_backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix='.txt')) == { + 'text1.txt', 'text2.txt' + } assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False, - suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'} + petrel_backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'))) == { + 'text1.txt', 'text2.txt' + } with pytest.raises( TypeError, match='`suffix` must be a string or tuple of strings'): - petrel_backend.client.list_dir_or_file( - tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + petrel_backend.client.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=['.txt', '.jpg']) # 8. only list files ending with suffix recursively assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False, suffix='.txt', - recursive=True)) == { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), 'text1.txt', - 'text2.txt' - } + petrel_backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix='.txt', + recursive=True)) == { + '/'.join( + ('dir1', 'text3.txt')), + '/'.join(('dir2', 'dir3', + 'text4.txt')), + 'text1.txt', 'text2.txt' + } # 7. only list files ending with suffix assert set( petrel_backend.list_dir_or_file( @@ -782,11 +798,10 @@ def get(self, filepath): def get_text(self, filepath, encoding='utf-8'): return 'text6' - FileClient.register_backend( - 'example4', - Example6Backend, - force=True, - prefixes='example4_prefix') + FileClient.register_backend('example4', + Example6Backend, + force=True, + prefixes='example4_prefix') example_backend = FileClient('example4') assert example_backend.get(self.img_path) == b'bytes6' assert example_backend.get_text(self.text_path) == 'text6' @@ -830,11 +845,10 @@ def get(self, filepath): def get_text(self, filepath, encoding='utf-8'): return 'text8' - FileClient.register_backend( - 'example6', - Example8Backend, - force=True, - prefixes='example6_prefix') + FileClient.register_backend('example6', + Example8Backend, + force=True, + prefixes='example6_prefix') example_backend = FileClient('example6') assert example_backend.get(self.img_path) == b'bytes8' assert example_backend.get_text(self.text_path) == 'text8' diff --git a/tests/test_fileio/test_fileio.py b/tests/test_fileio/test_fileio.py index 33a0956fed..13fba651dc 100644 --- a/tests/test_fileio/test_fileio.py +++ b/tests/test_fileio/test_fileio.py @@ -152,35 +152,37 @@ def test_list_from_file(): # get list from http filename = 'http://path/of/your/file' - with patch.object( - HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): + with patch.object(HTTPBackend, + 'get_text', + return_value='1.jpg\n2.jpg\n3.jpg'): filelist = mmengine.list_from_file( filename, file_client_args={'backend': 'http'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, file_client_args={'prefix': 'http'}) + filelist = mmengine.list_from_file(filename, + file_client_args={'prefix': 'http'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmengine.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, backend_args={'backend': 'http'}) + filelist = mmengine.list_from_file(filename, + backend_args={'backend': 'http'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] # get list from petrel filename = 's3://path/of/your/file' - with patch.object( - PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): + with patch.object(PetrelBackend, + 'get_text', + return_value='1.jpg\n2.jpg\n3.jpg'): filelist = mmengine.list_from_file( filename, file_client_args={'backend': 'petrel'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, file_client_args={'prefix': 's3'}) + filelist = mmengine.list_from_file(filename, + file_client_args={'prefix': 's3'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmengine.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, backend_args={'backend': 'petrel'}) + filelist = mmengine.list_from_file(filename, + backend_args={'backend': 'petrel'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] @@ -194,35 +196,36 @@ def test_dict_from_file(): # get dict from http filename = 'http://path/of/your/file' - with patch.object( - HTTPBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'): - mapping = mmengine.dict_from_file( - filename, file_client_args={'backend': 'http'}) + with patch.object(HTTPBackend, + 'get_text', + return_value='1 cat\n2 dog cow\n3 panda'): + mapping = mmengine.dict_from_file(filename, + file_client_args={'backend': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, file_client_args={'prefix': 'http'}) + mapping = mmengine.dict_from_file(filename, + file_client_args={'prefix': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmengine.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, backend_args={'backend': 'http'}) + mapping = mmengine.dict_from_file(filename, + backend_args={'backend': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} # get dict from petrel filename = 's3://path/of/your/file' - with patch.object( - PetrelBackend, 'get_text', - return_value='1 cat\n2 dog cow\n3 panda'): + with patch.object(PetrelBackend, + 'get_text', + return_value='1 cat\n2 dog cow\n3 panda'): mapping = mmengine.dict_from_file( filename, file_client_args={'backend': 'petrel'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, file_client_args={'prefix': 's3'}) + mapping = mmengine.dict_from_file(filename, + file_client_args={'prefix': 's3'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmengine.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, backend_args={'backend': 'petrel'}) + mapping = mmengine.dict_from_file(filename, + backend_args={'backend': 'petrel'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} diff --git a/tests/test_fileio/test_io.py b/tests/test_fileio/test_io.py index c34af47e0b..5fb4c9b596 100644 --- a/tests/test_fileio/test_io.py +++ b/tests/test_fileio/test_io.py @@ -139,8 +139,9 @@ def test_get_file_backend(): backend_args = {'path_mapping': {'src': 'dst'}, 'enable_mc': True} uri = 'petrel://your_bucket/img.png' - backend4 = fileio.get_file_backend( - uri=uri, backend_args=backend_args, enable_singleton=True) + backend4 = fileio.get_file_backend(uri=uri, + backend_args=backend_args, + enable_singleton=True) assert isinstance(backend4, fileio.backends.PetrelBackend) assert len(fileio.io.backend_instances) == 2 unique_key = 'petrel:{"path_mapping": {"src": "dst"}, "enable_mc": true}' @@ -148,16 +149,18 @@ def test_get_file_backend(): assert backend4 is not backend2 uri = 'petrel://your_bucket/img1.png' - backend5 = fileio.get_file_backend( - uri=uri, backend_args=backend_args, enable_singleton=True) + backend5 = fileio.get_file_backend(uri=uri, + backend_args=backend_args, + enable_singleton=True) assert isinstance(backend5, fileio.backends.PetrelBackend) assert len(fileio.io.backend_instances) == 2 assert backend5 is backend4 assert backend5 is not backend2 backend_args = {'path_mapping': {'src1': 'dst1'}, 'enable_mc': True} - backend6 = fileio.get_file_backend( - uri=uri, backend_args=backend_args, enable_singleton=True) + backend6 = fileio.get_file_backend(uri=uri, + backend_args=backend_args, + enable_singleton=True) assert isinstance(backend6, fileio.backends.PetrelBackend) assert len(fileio.io.backend_instances) == 3 unique_key = 'petrel:{"path_mapping": {"src1": "dst1"}, "enable_mc": true}' @@ -165,8 +168,9 @@ def test_get_file_backend(): assert backend6 is not backend4 assert backend6 is not backend5 - backend7 = fileio.get_file_backend( - uri=uri, backend_args=backend_args, enable_singleton=False) + backend7 = fileio.get_file_backend(uri=uri, + backend_args=backend_args, + enable_singleton=False) assert isinstance(backend7, fileio.backends.PetrelBackend) assert len(fileio.io.backend_instances) == 3 assert backend7 is not backend6 @@ -472,8 +476,9 @@ def test_list_dir_or_file(): TypeError, match='`suffix` should be None when `list_dir` is True'): list( - fileio.list_dir_or_file( - tmp_dir, list_file=False, suffix='.txt')) + fileio.list_dir_or_file(tmp_dir, + list_file=False, + suffix='.txt')) # only list directories recursively assert set( @@ -502,34 +507,39 @@ def test_list_dir_or_file(): tmp_dir, list_dir=False, suffix='.txt')) == {'text1.txt', 'text2.txt'} assert set( - fileio.list_dir_or_file( - tmp_dir, list_dir=False, - suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'} + fileio.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'))) == { + 'text1.txt', 'text2.txt' + } with pytest.raises( TypeError, match='`suffix` must be a string or tuple of strings'): list( - fileio.list_dir_or_file( - tmp_dir, list_dir=False, suffix=['.txt', '.jpg'])) + fileio.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=['.txt', '.jpg'])) # only list files ending with suffix recursively assert set( - fileio.list_dir_or_file( - tmp_dir, list_dir=False, suffix='.txt', recursive=True)) == { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt', - 'text2.txt' - } + fileio.list_dir_or_file(tmp_dir, + list_dir=False, + suffix='.txt', + recursive=True)) == { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + 'text1.txt', 'text2.txt' + } # only list files ending with suffix assert set( - fileio.list_dir_or_file( - tmp_dir, - list_dir=False, - suffix=('.txt', '.jpg'), - recursive=True)) == { - osp.join('dir1', 'text3.txt'), - osp.join('dir2', 'dir3', 'text4.txt'), - osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' - } + fileio.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'), + recursive=True)) == { + osp.join('dir1', 'text3.txt'), + osp.join('dir2', 'dir3', 'text4.txt'), + osp.join('dir2', 'img.jpg'), + 'text1.txt', 'text2.txt' + } diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index d731a42b76..fa95d0b5ce 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -57,9 +57,8 @@ def test_init(self): ValueError, '"file_client_args" and "backend_args" cannot be set ' 'at the same time'): - CheckpointHook( - file_client_args={'backend': 'disk'}, - backend_args={'backend': 'local'}) + CheckpointHook(file_client_args={'backend': 'disk'}, + backend_args={'backend': 'local'}) # Test save best CheckpointHook(save_best='acc') @@ -88,8 +87,9 @@ def test_init(self): hook = CheckpointHook(greater_keys=['acc']) self.assertEqual(hook.greater_keys, ['acc']) - hook = CheckpointHook( - interval=2, by_epoch=False, save_best=['acc', 'mIoU']) + hook = CheckpointHook(interval=2, + by_epoch=False, + save_best=['acc', 'mIoU']) self.assertEqual(hook.key_indicators, ['acc', 'mIoU']) self.assertEqual(hook.rules, ['greater', 'greater']) @@ -123,8 +123,9 @@ def test_before_train(self): self.assertEqual(checkpoint_hook.out_dir, runner.work_dir) # the out_dir of the checkpoint hook is not None - checkpoint_hook = CheckpointHook( - interval=1, by_epoch=True, out_dir='test_dir') + checkpoint_hook = CheckpointHook(interval=1, + by_epoch=True, + out_dir='test_dir') checkpoint_hook.before_train(runner) self.assertEqual(checkpoint_hook.out_dir, osp.join('test_dir', osp.basename(cfg.work_dir))) @@ -162,13 +163,15 @@ def test_after_val_epoch(self): # if metrics is an empty dict, print a warning information with self.assertLogs(runner.logger, level='WARNING'): - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='auto') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='auto') checkpoint_hook.after_val_epoch(runner, {}) # if save_best is None,no best_ckpt meta should be stored - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best=None) + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best=None) checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, {}) self.assertNotIn('best_score', runner.message_hub.runtime_info) @@ -176,8 +179,9 @@ def test_after_val_epoch(self): # when `save_best` is set to `auto`, first metric will be used. metrics = {'acc': 0.5, 'map': 0.3} - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='auto') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='auto') checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, metrics) best_ckpt_name = 'best_acc_epoch_9.pth' @@ -186,20 +190,22 @@ def test_after_val_epoch(self): self.assertEqual(checkpoint_hook.key_indicators, ['acc']) self.assertEqual(checkpoint_hook.rules, ['greater']) self.assertEqual(runner.message_hub.get_info('best_score'), 0.5) - self.assertEqual( - runner.message_hub.get_info('best_ckpt'), best_ckpt_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt'), + best_ckpt_path) # # when `save_best` is set to `acc`, it should update greater value - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='acc') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='acc') checkpoint_hook.before_train(runner) metrics['acc'] = 0.8 checkpoint_hook.after_val_epoch(runner, metrics) self.assertEqual(runner.message_hub.get_info('best_score'), 0.8) # # when `save_best` is set to `loss`, it should update less value - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='loss') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='loss') checkpoint_hook.before_train(runner) metrics['loss'] = 0.8 checkpoint_hook.after_val_epoch(runner, metrics) @@ -209,8 +215,10 @@ def test_after_val_epoch(self): # when `rule` is set to `less`,then it should update less value # no matter what `save_best` is - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='acc', rule='less') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='acc', + rule='less') checkpoint_hook.before_train(runner) metrics['acc'] = 0.3 checkpoint_hook.after_val_epoch(runner, metrics) @@ -218,22 +226,26 @@ def test_after_val_epoch(self): # # when `rule` is set to `greater`,then it should update greater value # # no matter what `save_best` is - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='loss', rule='greater') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='loss', + rule='greater') checkpoint_hook.before_train(runner) metrics['loss'] = 1.0 checkpoint_hook.after_val_epoch(runner, metrics) self.assertEqual(runner.message_hub.get_info('best_score'), 1.0) # test multi `save_best` with one rule - checkpoint_hook = CheckpointHook( - interval=2, save_best=['acc', 'mIoU'], rule='greater') + checkpoint_hook = CheckpointHook(interval=2, + save_best=['acc', 'mIoU'], + rule='greater') self.assertEqual(checkpoint_hook.key_indicators, ['acc', 'mIoU']) self.assertEqual(checkpoint_hook.rules, ['greater', 'greater']) # test multi `save_best` with multi rules - checkpoint_hook = CheckpointHook( - interval=2, save_best=['FID', 'IS'], rule=['less', 'greater']) + checkpoint_hook = CheckpointHook(interval=2, + save_best=['FID', 'IS'], + rule=['less', 'greater']) self.assertEqual(checkpoint_hook.key_indicators, ['FID', 'IS']) self.assertEqual(checkpoint_hook.rules, ['less', 'greater']) @@ -254,10 +266,10 @@ def test_after_val_epoch(self): checkpoint_hook.out_dir, best_mIoU_name) self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5) self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6) - self.assertEqual( - runner.message_hub.get_info('best_ckpt_acc'), best_acc_path) - self.assertEqual( - runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt_acc'), + best_acc_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt_mIoU'), + best_mIoU_path) # test behavior when by_epoch is False cfg = copy.deepcopy(self.iter_based_cfg) @@ -266,8 +278,10 @@ def test_after_val_epoch(self): # check best ckpt name and best score metrics = {'acc': 0.5, 'map': 0.3} - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=False, save_best='acc', rule='greater') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=False, + save_best='acc', + rule='greater') checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, metrics) self.assertEqual(checkpoint_hook.key_indicators, ['acc']) @@ -276,8 +290,8 @@ def test_after_val_epoch(self): best_ckpt_path = checkpoint_hook.file_client.join_path( checkpoint_hook.out_dir, best_ckpt_name) - self.assertEqual( - runner.message_hub.get_info('best_ckpt'), best_ckpt_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt'), + best_ckpt_path) self.assertEqual(runner.message_hub.get_info('best_score'), 0.5) # check best score updating @@ -286,13 +300,14 @@ def test_after_val_epoch(self): best_ckpt_name = 'best_acc_iter_9.pth' best_ckpt_path = checkpoint_hook.file_client.join_path( checkpoint_hook.out_dir, best_ckpt_name) - self.assertEqual( - runner.message_hub.get_info('best_ckpt'), best_ckpt_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt'), + best_ckpt_path) self.assertEqual(runner.message_hub.get_info('best_score'), 0.666) # check best checkpoint name with `by_epoch` is False - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=False, save_best=['acc', 'mIoU']) + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=False, + save_best=['acc', 'mIoU']) checkpoint_hook.before_train(runner) metrics = dict(acc=0.5, mIoU=0.6) checkpoint_hook.after_val_epoch(runner, metrics) @@ -305,10 +320,10 @@ def test_after_val_epoch(self): self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5) self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6) - self.assertEqual( - runner.message_hub.get_info('best_ckpt_acc'), best_acc_path) - self.assertEqual( - runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt_acc'), + best_acc_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt_mIoU'), + best_mIoU_path) # after_val_epoch should not save last_checkpoint self.assertFalse( @@ -321,8 +336,9 @@ def test_after_val_epoch(self): self.clear_work_dir() cfg = copy.deepcopy(cfg) runner = self.build_runner(cfg) - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=by_epoch, save_best='acc') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=by_epoch, + save_best='acc') checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, metrics) all_files = os.listdir(runner.work_dir) @@ -373,9 +389,8 @@ def test_after_train_epoch(self): checkpoint_hook.before_train(runner) checkpoint_hook.after_train_epoch(runner) self.assertEqual((runner.epoch + 1) % 2, 0) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'epoch_10.pth')) + self.assertEqual(runner.message_hub.get_info('last_ckpt'), + osp.join(cfg.work_dir, 'epoch_10.pth')) last_ckpt_path = osp.join(cfg.work_dir, 'last_checkpoint') self.assertTrue(osp.isfile(last_ckpt_path)) @@ -387,9 +402,8 @@ def test_after_train_epoch(self): # epoch can not be evenly divided by 2 runner.train_loop._epoch = 10 checkpoint_hook.after_train_epoch(runner) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'epoch_10.pth')) + self.assertEqual(runner.message_hub.get_info('last_ckpt'), + osp.join(cfg.work_dir, 'epoch_10.pth')) runner.message_hub.runtime_info.clear() # by epoch is False @@ -416,25 +430,22 @@ def test_after_train_iter(self): checkpoint_hook.before_train(runner) checkpoint_hook.after_train_iter(runner, batch_idx=9) self.assertIn('last_ckpt', runner.message_hub.runtime_info) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'iter_10.pth')) + self.assertEqual(runner.message_hub.get_info('last_ckpt'), + osp.join(cfg.work_dir, 'iter_10.pth')) # epoch can not be evenly divided by 2 runner.train_loop._iter = 10 checkpoint_hook.after_train_epoch(runner) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'iter_10.pth')) + self.assertEqual(runner.message_hub.get_info('last_ckpt'), + osp.join(cfg.work_dir, 'iter_10.pth')) @parameterized.expand([['iter'], ['epoch']]) def test_with_runner(self, training_type): common_cfg = getattr(self, f'{training_type}_based_cfg') setattr(common_cfg.train_cfg, f'max_{training_type}s', 11) - checkpoint_cfg = dict( - type='CheckpointHook', - interval=1, - by_epoch=training_type == 'epoch') + checkpoint_cfg = dict(type='CheckpointHook', + interval=1, + by_epoch=training_type == 'epoch') common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg) # Test interval in epoch based training @@ -458,34 +469,41 @@ def test_with_runner(self, training_type): cfg = copy.deepcopy(common_cfg) runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertIn('optimizer', ckpt) cfg.default_hooks.checkpoint.save_optimizer = False runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertNotIn('optimizer', ckpt) # Test save_param_scheduler=False cfg = copy.deepcopy(common_cfg) cfg.param_scheduler = [ - dict( - type='LinearLR', - start_factor=0.1, - begin=0, - end=500, - by_epoch=training_type == 'epoch') + dict(type='LinearLR', + start_factor=0.1, + begin=0, + end=500, + by_epoch=training_type == 'epoch') ] runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertIn('param_schedulers', ckpt) cfg.default_hooks.checkpoint.save_param_scheduler = False runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertNotIn('param_schedulers', ckpt) self.clear_work_dir() @@ -533,7 +551,9 @@ def test_with_runner(self, training_type): self.assertFalse( osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'], [9, 10, 11]) @@ -574,9 +594,11 @@ def test_with_runner(self, training_type): runner.train() best_ckpt_path = osp.join(cfg.work_dir, f'best_test_acc_{training_type}_5.pth') - best_ckpt = torch.load(best_ckpt_path) + best_ckpt = torch.load(best_ckpt_path, weights_only=False) - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_5.pth'), + weights_only=False) self.assertEqual(best_ckpt_path, ckpt['message_hub']['runtime_info']['best_ckpt']) @@ -603,11 +625,13 @@ def test_with_runner(self, training_type): runner.train() best_ckpt_path = osp.join(cfg.work_dir, f'best_test_acc_{training_type}_5.pth') - best_ckpt = torch.load(best_ckpt_path) + best_ckpt = torch.load(best_ckpt_path, weights_only=False) # if the current ckpt is the best, the interval will be ignored the # the ckpt will also be saved - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_5.pth'), + weights_only=False) self.assertEqual(best_ckpt_path, ckpt['message_hub']['runtime_info']['best_ckpt']) diff --git a/tests/test_hooks/test_early_stopping_hook.py b/tests/test_hooks/test_early_stopping_hook.py index 16f8fd981c..08fe4cbac5 100644 --- a/tests/test_hooks/test_early_stopping_hook.py +++ b/tests/test_hooks/test_early_stopping_hook.py @@ -149,8 +149,9 @@ def test_after_val_epoch(self): # if `monitor` does not match and strict=True, crash the training. with self.assertRaises(RuntimeError): metrics = {'accuracy/top1': 0.5, 'loss': 0.23} - hook = EarlyStoppingHook( - monitor='acc', rule='greater', strict=True) + hook = EarlyStoppingHook(monitor='acc', + rule='greater', + strict=True) hook.after_val_epoch(runner, metrics) # Check largest value @@ -176,8 +177,9 @@ def test_after_val_epoch(self): # Check stop training runner = get_mock_runner() metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', rule='greater', min_delta=1) + hook = EarlyStoppingHook(monitor='accuracy/top1', + rule='greater', + min_delta=1) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -187,8 +189,9 @@ def test_after_val_epoch(self): # Check finite runner = get_mock_runner() metrics = [{'accuracy/top1': math.inf} for i in range(5)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', rule='greater', min_delta=1) + hook = EarlyStoppingHook(monitor='accuracy/top1', + rule='greater', + min_delta=1) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -198,8 +201,10 @@ def test_after_val_epoch(self): # Check patience runner = get_mock_runner() metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', rule='greater', min_delta=1, patience=10) + hook = EarlyStoppingHook(monitor='accuracy/top1', + rule='greater', + min_delta=1, + patience=10) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -209,11 +214,10 @@ def test_after_val_epoch(self): # Check stopping_threshold runner = get_mock_runner() metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', - rule='greater', - stopping_threshold=98.5, - patience=0) + hook = EarlyStoppingHook(monitor='accuracy/top1', + rule='greater', + stopping_threshold=98.5, + patience=0) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -230,26 +234,27 @@ def test_with_runner(self): min_delta=1, patience=3, ) - runner = Runner( - model=ToyModel(), - work_dir=work_dir, - train_dataloader=dict( - dataset=DummyDataset(), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=3, - num_workers=0), - val_dataloader=dict( - dataset=DummyDataset(), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), - val_evaluator=dict(type=DummyMetric, length=max_epoch), - optim_wrapper=OptimWrapper( - torch.optim.Adam(ToyModel().parameters())), - train_cfg=dict( - by_epoch=True, max_epochs=max_epoch, val_interval=1), - val_cfg=dict(), - custom_hooks=[early_stop_cfg], - experiment_name='earlystop_test') + runner = Runner(model=ToyModel(), + work_dir=work_dir, + train_dataloader=dict(dataset=DummyDataset(), + sampler=dict( + type='DefaultSampler', + shuffle=True), + batch_size=3, + num_workers=0), + val_dataloader=dict(dataset=DummyDataset(), + sampler=dict(type='DefaultSampler', + shuffle=False), + batch_size=3, + num_workers=0), + val_evaluator=dict(type=DummyMetric, length=max_epoch), + optim_wrapper=OptimWrapper( + torch.optim.Adam(ToyModel().parameters())), + train_cfg=dict(by_epoch=True, + max_epochs=max_epoch, + val_interval=1), + val_cfg=dict(), + custom_hooks=[early_stop_cfg], + experiment_name='earlystop_test') runner.train() self.assertEqual(runner.epoch, 6) diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py index 6dad7ba4f0..3c54d40ae2 100644 --- a/tests/test_hooks/test_ema_hook.py +++ b/tests/test_hooks/test_ema_hook.py @@ -208,9 +208,9 @@ def test_after_load_checkpoint(self): # Check the weight of state_dict and ema_state_dict have been swapped. # when runner._resume is True runner._resume = True - checkpoint = dict( - state_dict=ToyModel().state_dict(), - ema_state_dict=ExponentialMovingAverage(ToyModel()).state_dict()) + checkpoint = dict(state_dict=ToyModel().state_dict(), + ema_state_dict=ExponentialMovingAverage( + ToyModel()).state_dict()) ori_checkpoint = copy.deepcopy(checkpoint) ema_hook.after_load_checkpoint(runner, checkpoint) for key in ori_checkpoint['state_dict'].keys(): @@ -230,7 +230,8 @@ def test_with_runner(self): self.assertTrue( isinstance(ema_hook.ema_model, ExponentialMovingAverage)) - checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth')) + checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'), + weights_only=False) self.assertTrue('ema_state_dict' in checkpoint) self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8) @@ -245,7 +246,8 @@ def test_with_runner(self): runner.test() # Test load checkpoint without ema_state_dict - checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth')) + checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'), + weights_only=False) checkpoint.pop('ema_state_dict') torch.save(checkpoint, osp.join(self.temp_dir.name, 'without_ema_state_dict.pth')) @@ -273,8 +275,9 @@ def test_with_runner(self): cfg.custom_hooks = [ConfigDict(type='EMAHook', begin_epoch=5)] runner = self.build_runner(cfg) runner.train() - state_dict = torch.load( - osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu') + state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'), + map_location='cpu', + weights_only=False) self.assertIn('ema_state_dict', state_dict) for k, v in state_dict['state_dict'].items(): assert_allclose(v, state_dict['ema_state_dict']['module.' + k]) @@ -286,13 +289,15 @@ def test_with_runner(self): cfg.default_hooks.checkpoint.interval = 1 runner = self.build_runner(cfg) runner.train() - state_dict = torch.load( - osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu') + state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'), + map_location='cpu', + weights_only=False) self.assertIn('ema_state_dict', state_dict) for k, v in state_dict['state_dict'].items(): assert_allclose(v, state_dict['ema_state_dict']['module.' + k]) - state_dict = torch.load( - osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu') + state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'), + map_location='cpu', + weights_only=False) self.assertIn('ema_state_dict', state_dict) def _test_swap_parameters(self, func_name, *args, **kwargs): diff --git a/tests/test_hooks/test_empty_cache_hook.py b/tests/test_hooks/test_empty_cache_hook.py index d30972d360..02b0e9970e 100644 --- a/tests/test_hooks/test_empty_cache_hook.py +++ b/tests/test_hooks/test_empty_cache_hook.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy from unittest.mock import patch import pytest @@ -9,11 +10,11 @@ class TestEmptyCacheHook(RunnerTestCase): - @pytest.mark.skipif( - not is_cuda_available(), reason='cuda should be available') + @pytest.mark.skipif(not is_cuda_available(), + reason='cuda should be available') def test_with_runner(self): with patch('torch.cuda.empty_cache') as mock_empty_cache: - cfg = self.epoch_based_cfg + cfg = deepcopy(self.epoch_based_cfg) cfg.custom_hooks = [dict(type='EmptyCacheHook')] cfg.train_cfg.val_interval = 1e6 # disable validation during training # noqa: E501 runner = self.build_runner(cfg) @@ -24,12 +25,14 @@ def test_with_runner(self): # Call `torch.cuda.empty_cache` after each epoch: # runner.train: `max_epochs` times. + # runner.val: last epoch will always trigger validation (BC caused by `e258c848`) # noqa: E501 # runner.val: `1` time. # runner.test: `1` time. - target_called_times = runner.max_epochs + 2 + target_called_times = runner.max_epochs + 3 self.assertEqual(mock_empty_cache.call_count, target_called_times) - + # with patch('torch.cuda.empty_cache') as mock_empty_cache: + cfg = deepcopy(self.epoch_based_cfg) cfg.custom_hooks = [dict(type='EmptyCacheHook', before_epoch=True)] runner = self.build_runner(cfg) @@ -39,16 +42,17 @@ def test_with_runner(self): # Call `torch.cuda.empty_cache` after/before each epoch: # runner.train: `max_epochs*2` times. - # runner.val: `1*2` times. + # runner.val: (max_epochs + 1)*2 times, last epoch will always trigger validation (BC caused by `e258c848`) # noqa: E501 # runner.test: `1*2` times. - target_called_times = runner.max_epochs * 2 + 4 + target_called_times = runner.max_epochs * 2 + (runner.max_epochs + + 1) * 2 + 1 * 2 self.assertEqual(mock_empty_cache.call_count, target_called_times) with patch('torch.cuda.empty_cache') as mock_empty_cache: + cfg = deepcopy(self.epoch_based_cfg) cfg.custom_hooks = [ - dict( - type='EmptyCacheHook', after_iter=True, before_epoch=True) + dict(type='EmptyCacheHook', after_iter=True, before_epoch=True) ] runner = self.build_runner(cfg) @@ -58,13 +62,13 @@ def test_with_runner(self): # Call `torch.cuda.empty_cache` after/before each epoch, # after each iteration: - # runner.train: `max_epochs*2 + len(dataloader)*max_epochs` times. # noqa: E501 - # runner.val: `1*2 + len(val_dataloader)` times. - # runner.test: `1*2 + len(val_dataloader)` times. + # runner.train: max_epochs * (2 + len(train_dataloader)) times. + # runner.val: (max_epochs + 1(interval) + 1(last)) * (2 + len(val_dataloader)) times # noqa: E501 + # runner.test: 1 * (2 + len(test_dataloader)) times target_called_times = \ - runner.max_epochs * 2 + 4 + \ - len(runner.train_dataloader) * runner.max_epochs + \ - len(runner.val_dataloader) + \ - len(runner.test_dataloader) + runner.max_epochs * (2 + len(runner.train_dataloader)) + \ + (runner.max_epochs + 1) * (2 + len(runner.val_dataloader)) + \ + 1 * (2 + len(runner.test_dataloader)) + self.assertEqual(mock_empty_cache.call_count, target_called_times) diff --git a/tests/test_hooks/test_logger_hook.py b/tests/test_hooks/test_logger_hook.py index 52b8bc1fa3..925226b98a 100644 --- a/tests/test_hooks/test_logger_hook.py +++ b/tests/test_hooks/test_logger_hook.py @@ -49,17 +49,15 @@ def test_init(self): # test deprecated warning raised by `file_client_args` logger = MMLogger.get_current_instance() with self.assertLogs(logger, level='WARNING'): - LoggerHook( - out_dir=self.temp_dir.name, - file_client_args=dict(backend='disk')) + LoggerHook(out_dir=self.temp_dir.name, + file_client_args=dict(backend='disk')) with self.assertRaisesRegex( ValueError, '"file_client_args" and "backend_args" cannot be '): - LoggerHook( - out_dir=self.temp_dir.name, - file_client_args=dict(enable_mc=True), - backend_args=dict(enable_mc=True)) + LoggerHook(out_dir=self.temp_dir.name, + file_client_args=dict(enable_mc=True), + backend_args=dict(enable_mc=True)) def test_after_train_iter(self): # Test LoggerHook by iter. @@ -138,8 +136,8 @@ def test_after_val_epoch(self): 'acc': 0.8 }, **args), ] - self.assertEqual( - len(calls), len(runner.visualizer.add_scalars.mock_calls)) + self.assertEqual(len(calls), + len(runner.visualizer.add_scalars.mock_calls)) runner.visualizer.add_scalars.assert_has_calls(calls) # Test when `log_metric_by_epoch` is False @@ -165,8 +163,8 @@ def test_after_val_epoch(self): 'acc': 0.5 }, **args), ] - self.assertEqual( - len(calls), len(runner.visualizer.add_scalars.mock_calls)) + self.assertEqual(len(calls), + len(runner.visualizer.add_scalars.mock_calls)) runner.visualizer.add_scalars.assert_has_calls(calls) def test_after_test_epoch(self): @@ -174,10 +172,9 @@ def test_after_test_epoch(self): runner = MagicMock() runner.log_dir = self.temp_dir.name runner.timestamp = 'test_after_test_epoch' - runner.log_processor.get_log_after_epoch = MagicMock( - return_value=( - dict(a=1, b=2, c={'list': [1, 2]}, d=torch.tensor([1, 2, 3])), - 'log_str')) + runner.log_processor.get_log_after_epoch = MagicMock(return_value=( + dict(a=1, b=2, c={'list': [1, 2]}, d=torch.tensor([1, 2, 3])), + 'log_str')) logger_hook.before_run(runner) logger_hook.after_test_epoch(runner) runner.log_processor.get_log_after_epoch.assert_called() @@ -232,8 +229,9 @@ def test_with_runner(self): shutil.rmtree(osp.join(out_dir, filename)) # Test out_suffix - cfg.default_hooks.logger = dict( - type='LoggerHook', out_dir=out_dir, out_suffix='.log') + cfg.default_hooks.logger = dict(type='LoggerHook', + out_dir=out_dir, + out_suffix='.log') runner = self.build_runner(cfg) runner.train() filenames = scandir(out_dir, recursive=True) @@ -241,8 +239,9 @@ def test_with_runner(self): all(filename.endswith('.log') for filename in filenames)) # Test keep_local=False - cfg.default_hooks.logger = dict( - type='LoggerHook', out_dir=out_dir, keep_local=False) + cfg.default_hooks.logger = dict(type='LoggerHook', + out_dir=out_dir, + keep_local=False) runner = self.build_runner(cfg) runner.train() filenames = scandir(runner._log_dir, recursive=True) diff --git a/tests/test_hooks/test_naive_visualization_hook.py b/tests/test_hooks/test_naive_visualization_hook.py index 2e39e94527..2345dcbc0d 100644 --- a/tests/test_hooks/test_naive_visualization_hook.py +++ b/tests/test_hooks/test_naive_visualization_hook.py @@ -16,47 +16,40 @@ def test_after_train_iter(self): inputs = torch.randn(1, 3, 15, 15) batch_idx = 10 # test with normalize, resize, pad - gt_datasamples = BaseDataElement( - metainfo=dict( - img_norm_cfg=dict( - mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), - scale=(10, 10), - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')) + gt_datasamples = BaseDataElement(metainfo=dict(img_norm_cfg=dict( + mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), + scale=(10, 10), + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with resize, pad - gt_datasamples = BaseDataElement( - metainfo=dict( - scale=(10, 10), - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')) + gt_datasamples = BaseDataElement(metainfo=dict(scale=(10, 10), + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only resize - gt_datasamples = BaseDataElement( - metainfo=dict( - scale=(15, 15), ori_height=5, ori_width=5, img_path='tmp.jpg')) + gt_datasamples = BaseDataElement(metainfo=dict( + scale=(15, 15), ori_height=5, ori_width=5, img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only pad - gt_datasamples = BaseDataElement( - metainfo=dict( - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')) + gt_datasamples = BaseDataElement(metainfo=dict(pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, diff --git a/tests/test_hooks/test_prepare_tta_hook.py b/tests/test_hooks/test_prepare_tta_hook.py index a356164ef6..0de30d788a 100644 --- a/tests/test_hooks/test_prepare_tta_hook.py +++ b/tests/test_hooks/test_prepare_tta_hook.py @@ -82,9 +82,8 @@ def test_before_test(self): # Test with epoch based runner. cfg = copy.deepcopy(self.epoch_based_cfg) cfg.custom_hooks.append( - dict( - type='PrepareTTAHook', - tta_cfg=dict(type='ToyTestTimeAugModel'))) + dict(type='PrepareTTAHook', + tta_cfg=dict(type='ToyTestTimeAugModel'))) cfg.model = dict(type='ToyModel') cfg.test_dataloader.dataset = dict( type='ToyDatasetTTA', pipeline=dict(type='ToyTTAPipeline')) @@ -96,9 +95,8 @@ def test_before_test(self): # Test with iteration based runner cfg = copy.deepcopy(self.iter_based_cfg) cfg.custom_hooks.append( - dict( - type='PrepareTTAHook', - tta_cfg=dict(type='ToyTestTimeAugModel'))) + dict(type='PrepareTTAHook', + tta_cfg=dict(type='ToyTestTimeAugModel'))) cfg.model = dict(type='ToyModel') cfg.test_dataloader.dataset = dict( type='ToyDatasetTTA', pipeline=dict(type='ToyTTAPipeline')) diff --git a/tests/test_hooks/test_profiler_hook.py b/tests/test_hooks/test_profiler_hook.py index 2db6df01b6..8021664bbd 100644 --- a/tests/test_hooks/test_profiler_hook.py +++ b/tests/test_hooks/test_profiler_hook.py @@ -52,13 +52,13 @@ def deal_profile(_profile): hook.on_trace_ready = dict(type='unknown') hook._parse_trace_config(runner) - hook.on_trace_ready = dict( - type='log_trace', sort_by='self_cpu_time_total', row_limit=10) + hook.on_trace_ready = dict(type='log_trace', + sort_by='self_cpu_time_total', + row_limit=10) hook._parse_trace_config(runner) - @unittest.skipIf( - not is_installed('torch-tb-profiler'), - reason='required torch-tb-profiler') + @unittest.skipIf(not is_installed('torch-tb-profiler'), + reason='required torch-tb-profiler') def test_parse_trace_config_tensorboard(self): # Test on_trace_ready_args runner = MagicMock() @@ -76,16 +76,15 @@ def test_parse_trace_config_tensorboard(self): hook._parse_trace_config(runner) # with self.assertWarns(DeprecationWarning): - hook = ProfilerHook( - on_trace_ready=dict(type='tb_trace'), - json_trace_path=ops.join(self.temp_dir.name, 'demo.json')) + hook = ProfilerHook(on_trace_ready=dict(type='tb_trace'), + json_trace_path=ops.join(self.temp_dir.name, + 'demo.json')) hook._parse_trace_config(runner) self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='ProfilerHook', - on_trace_ready=dict( - type='tb_trace', dir_name=self.temp_dir.name)) + dict(type='ProfilerHook', + on_trace_ready=dict(type='tb_trace', + dir_name=self.temp_dir.name)) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() @@ -148,19 +147,18 @@ def test_after_train_iter(self): hook.profiler.__exit__.assert_called_once() hook.profiler.step.assert_called_once() - hook = ProfilerHook( - by_epoch=False, - schedule=dict(wait=1, warmup=1, active=3, repeat=1)) + hook = ProfilerHook(by_epoch=False, + schedule=dict(wait=1, warmup=1, active=3, + repeat=1)) hook.profiler = MagicMock() hook.after_train_iter(runner, 1, 1, 1) hook.profiler.step.assert_called_once() def test_with_runner(self): self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='ProfilerHook', - activity_with_cpu=False, - activity_with_cuda=False) + dict(type='ProfilerHook', + activity_with_cpu=False, + activity_with_cuda=False) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() @@ -171,16 +169,14 @@ def test_with_runner(self): ] runner = self.build_runner(self.epoch_based_cfg) runner.train() - self.assertTrue( - ops.exists(json_path), 'ERROR::json file is not generated!') + self.assertTrue(ops.exists(json_path), + 'ERROR::json file is not generated!') self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='ProfilerHook', - on_trace_ready=dict( - type='log_trace', - sort_by='self_cpu_time_total', - row_limit=10)) + dict(type='ProfilerHook', + on_trace_ready=dict(type='log_trace', + sort_by='self_cpu_time_total', + row_limit=10)) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() @@ -200,8 +196,8 @@ def test_with_runner(self): runner.train() -@unittest.skipIf( - not is_npu_available(), reason='Ascend PyTorch and npu devices not exist') +@unittest.skipIf(not is_npu_available(), + reason='Ascend PyTorch and npu devices not exist') class TestNPUProfilerHook(RunnerTestCase): def test_init(self): @@ -243,27 +239,25 @@ def test_after_train_iter(self): def test_with_runner(self): result_path = ops.join(self.temp_dir.name, 'test/cann_profiling') self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='NPUProfilerHook', - begin=0, - result_path=result_path, - exit_after_profiling=False) + dict(type='NPUProfilerHook', + begin=0, + result_path=result_path, + exit_after_profiling=False) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='NPUProfilerHook', - result_path=result_path, - ge_profiling_to_std_out=True, - exit_after_profiling=False) + dict(type='NPUProfilerHook', + result_path=result_path, + ge_profiling_to_std_out=True, + exit_after_profiling=False) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() - self.assertTrue( - ops.exists(result_path), 'profiler result path is not generated!') + self.assertTrue(ops.exists(result_path), + 'profiler result path is not generated!') self.assertTrue( os.getenv('GE_PROFILING_TO_STD_OUT', '0') == '1', diff --git a/tests/test_hooks/test_runtime_info_hook.py b/tests/test_hooks/test_runtime_info_hook.py index c7e7a3c339..5f15f7ddd8 100644 --- a/tests/test_hooks/test_runtime_info_hook.py +++ b/tests/test_hooks/test_runtime_info_hook.py @@ -95,8 +95,8 @@ def test_before_train_iter(self): optim2 = SGD(model.layer2.parameters(), lr=0.02) optim_wrapper1 = OptimWrapper(optim1) optim_wrapper2 = OptimWrapper(optim2) - optim_wrapper_dict = OptimWrapperDict( - key1=optim_wrapper1, key2=optim_wrapper2) + optim_wrapper_dict = OptimWrapperDict(key1=optim_wrapper1, + key2=optim_wrapper2) runner.optim_wrapper = optim_wrapper_dict hook.before_train_iter(runner, batch_idx=2, data_batch=None) self.assertEqual( @@ -108,8 +108,10 @@ def test_after_train_iter(self): cfg = copy.deepcopy(self.epoch_based_cfg) runner = self.build_runner(cfg) hook = self._get_runtime_info_hook(runner) - hook.after_train_iter( - runner, batch_idx=2, data_batch=None, outputs={'loss_cls': 1.111}) + hook.after_train_iter(runner, + batch_idx=2, + data_batch=None, + outputs={'loss_cls': 1.111}) self.assertEqual( runner.message_hub.get_scalar('train/loss_cls').current(), 1.111) @@ -167,14 +169,13 @@ def test_scalar_check(self): # check other scalar dtypes val = np.mean([5]) # this is not ndarray but dtype is np.float64. - hook.after_val_epoch( - runner, - metrics={ - 'acc_f32': val.astype(np.float32), - 'acc_i32': val.astype(np.int32), - 'acc_u8': val.astype(np.uint8), - 'acc_ndarray': np.array([5]), - }) + hook.after_val_epoch(runner, + metrics={ + 'acc_f32': val.astype(np.float32), + 'acc_i32': val.astype(np.int32), + 'acc_u8': val.astype(np.uint8), + 'acc_ndarray': np.array([5]), + }) self.assertEqual( runner.message_hub.get_scalar('val/acc_f32').current(), 5) self.assertEqual( @@ -185,13 +186,12 @@ def test_scalar_check(self): runner.message_hub.get_scalar('val/acc_ndarray').current(), 5) val = torch.tensor([5.0]).mean() - hook.after_val_epoch( - runner, - metrics={ - 'acc_f32': val.float(), - 'acc_i64': val.long(), - 'acc_tensor': torch.tensor([5]), - }) + hook.after_val_epoch(runner, + metrics={ + 'acc_f32': val.float(), + 'acc_i64': val.long(), + 'acc_tensor': torch.tensor([5]), + }) self.assertEqual( runner.message_hub.get_scalar('val/acc_f32').current(), 5) self.assertEqual( diff --git a/tests/test_hooks/test_sync_buffers_hook.py b/tests/test_hooks/test_sync_buffers_hook.py index 6d4019dc58..71db44e38a 100644 --- a/tests/test_hooks/test_sync_buffers_hook.py +++ b/tests/test_hooks/test_sync_buffers_hook.py @@ -1,15 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os from unittest.mock import MagicMock import torch import torch.distributed as torch_dist import torch.nn as nn +from torch.testing._internal.common_distributed import DistributedTestBase from mmengine.dist import all_gather from mmengine.hooks import SyncBuffersHook from mmengine.registry import MODELS -from mmengine.testing._internal import MultiProcessTestCase from mmengine.testing.runner_test_case import RunnerTestCase, ToyModel @@ -23,22 +22,14 @@ def __init__(self, data_preprocessor=None): def init_weights(self): for buffer in self.buffers(): buffer.fill_( - torch.tensor(int(os.environ['RANK']), dtype=torch.float32)) + torch.tensor(torch_dist.get_rank(), dtype=torch.float32)) return super().init_weights() -class TestSyncBuffersHook(MultiProcessTestCase, RunnerTestCase): - - def setUp(self) -> None: - super().setUp() - self._spawn_processes() - - def prepare_subprocess(self): - MODELS.register_module(module=ToyModuleWithNorm, force=True) - super(MultiProcessTestCase, self).setUp() +class TestSyncBuffersHook(DistributedTestBase, RunnerTestCase): def test_sync_buffers_hook(self): - self.setup_dist_env() + self.create_pg('cuda') runner = MagicMock() runner.model = ToyModuleWithNorm() runner.model.init_weights() @@ -53,9 +44,12 @@ def test_sync_buffers_hook(self): for buffer in runner.model.buffers(): buffer1, buffer2 = all_gather(buffer) self.assertTrue(torch.allclose(buffer1, buffer2)) + torch_dist.destroy_process_group() def test_with_runner(self): - self.setup_dist_env() + MODELS.register_module(module=ToyModuleWithNorm, force=True) + self.create_pg('cuda') + RunnerTestCase.setUp(self) cfg = self.epoch_based_cfg cfg.model = dict(type='ToyModuleWithNorm') cfg.launch = 'pytorch' @@ -67,8 +61,6 @@ def test_with_runner(self): buffer1, buffer2 = all_gather(buffer) self.assertTrue(torch.allclose(buffer1, buffer2)) - def setup_dist_env(self): - super().setup_dist_env() - os.environ['RANK'] = str(self.rank) - torch_dist.init_process_group( - backend='gloo', rank=self.rank, world_size=self.world_size) + @property + def world_size(self) -> int: + return 2 diff --git a/tests/test_hub/test_hub.py b/tests/test_hub/test_hub.py index ae21d3dab4..5dd951e478 100644 --- a/tests/test_hub/test_hub.py +++ b/tests/test_hub/test_hub.py @@ -12,9 +12,8 @@ # mmdet has a more typical config structure, while mmpose has a complex # config structure -@pytest.mark.skipif( - not (is_installed('mmdet') and is_installed('mmpose')), - reason='mmdet and mmpose should be installed') +@pytest.mark.skipif(not (is_installed('mmdet') and is_installed('mmpose')), + reason='mmdet and mmpose should be installed') def test_get_config(): # Test load base config. base_cfg = get_config('mmdet::_base_/models/faster-rcnn_r50_fpn.py') @@ -32,8 +31,8 @@ def test_get_config(): assert cfg._cfg_dict == test_cfg._cfg_dict # Test pretrained - cfg = get_config( - 'mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True) + cfg = get_config('mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', + pretrained=True) assert cfg.model_path == 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' # noqa E301 # Test load mmpose @@ -42,8 +41,8 @@ def test_get_config(): ) -@pytest.mark.skipif( - not is_installed('mmdet'), reason='mmdet and mmpose should be installed') +@pytest.mark.skipif(not is_installed('mmdet'), + reason='mmdet and mmpose should be installed') def test_get_model(): # TODO compatible with downstream codebase. DefaultScope.get_instance('test_get_model', scope_name='test_scope') diff --git a/tests/test_infer/test_infer.py b/tests/test_infer/test_infer.py index 2d020b6300..c0142c98a5 100644 --- a/tests/test_infer/test_infer.py +++ b/tests/test_infer/test_infer.py @@ -133,8 +133,8 @@ def test_call(self): inferencer(imgs) inferencer(img_paths) - @pytest.mark.skipif( - not is_imported('mmdet'), reason='mmdet is not installed') + @pytest.mark.skipif(not is_imported('mmdet'), + reason='mmdet is not installed') def test_load_model_from_meta(self): from mmdet.utils import register_all_modules @@ -154,8 +154,8 @@ def test_get_chunk_data(self): inferencer = ToyInferencer(self.cfg_path, self.ckpt_path) data = list(range(1, 11)) chunk_data = inferencer._get_chunk_data(data, 3) - self.assertEqual( - list(chunk_data), [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]) + self.assertEqual(list(chunk_data), + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]) def test_init_visualizer(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -173,11 +173,10 @@ def test_init_visualizer(self): def test_dispatch_kwargs(self): inferencer = ToyInferencer(self.cfg_path, self.ckpt_path) - kwargs = dict( - pre_arg=dict(a=1), - for_arg=dict(c=2), - vis_arg=dict(b=3), - pos_arg=dict(d=4)) + kwargs = dict(pre_arg=dict(a=1), + for_arg=dict(c=2), + vis_arg=dict(b=3), + pos_arg=dict(d=4)) pre_arg, for_arg, vis_arg, pos_arg = inferencer._dispatch_kwargs( **kwargs) self.assertEqual(pre_arg, dict(pre_arg=dict(a=1))) @@ -217,8 +216,8 @@ def test_preprocess(self): for data in dataloader: self.assertTrue(is_list_of(data, torch.Tensor)) - @pytest.mark.skipif( - not is_imported('mmdet'), reason='mmdet is not installed') + @pytest.mark.skipif(not is_imported('mmdet'), + reason='mmdet is not installed') def test_list_models(self): model_list = BaseInferencer.list_models('mmdet') self.assertTrue(len(model_list) > 0) diff --git a/tests/test_logging/test_logger.py b/tests/test_logging/test_logger.py index 2ac2b3548e..2826c349e1 100644 --- a/tests/test_logging/test_logger.py +++ b/tests/test_logging/test_logger.py @@ -34,16 +34,18 @@ def test_init_rank0(self, tmp_path): # If `rank=0`, the `log_level` of stream_handler and file_handler # depends on the given arguments. tmp_file = tmp_path / 'tmp_file.log' - logger = MMLogger.get_instance( - 'rank0.pkg2', log_level='INFO', log_file=str(tmp_file)) + logger = MMLogger.get_instance('rank0.pkg2', + log_level='INFO', + log_file=str(tmp_file)) assert isinstance(logger, logging.Logger) assert len(logger.handlers) == 2 assert isinstance(logger.handlers[0], logging.StreamHandler) assert isinstance(logger.handlers[1], logging.FileHandler) logger_pkg3 = MMLogger.get_instance('rank0.pkg2') assert id(logger_pkg3) == id(logger) - logger = MMLogger.get_instance( - 'rank0.pkg3', logger_name='logger_test', log_level='INFO') + logger = MMLogger.get_instance('rank0.pkg3', + logger_name='logger_test', + log_level='INFO') assert logger.name == 'logger_test' assert logger.instance_name == 'rank0.pkg3' # `FileHandler` should be closed in Windows, otherwise we cannot @@ -59,14 +61,14 @@ def test_init_rank1(self, tmp_path): # If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`. tmp_file = tmp_path / 'tmp_file.log' log_path = tmp_path / 'tmp_file_test_device1_rank1.log' - logger = MMLogger.get_instance( - 'rank1.pkg2', log_level='INFO', log_file=str(tmp_file)) + logger = MMLogger.get_instance('rank1.pkg2', + log_level='INFO', + log_file=str(tmp_file)) assert len(logger.handlers) == 1 - logger = MMLogger.get_instance( - 'rank1.pkg3', - log_level='INFO', - log_file=str(tmp_file), - distributed=True) + logger = MMLogger.get_instance('rank1.pkg3', + log_level='INFO', + log_file=str(tmp_file), + distributed=True) assert logger.handlers[0].level == logging.ERROR assert logger.handlers[1].level == logging.INFO assert len(logger.handlers) == 2 @@ -94,8 +96,9 @@ def test_handler(self, capsys, tmp_path, log_level): # test file_handler output plain text without color. tmp_file = tmp_path / 'tmp_file.log' instance_name = f'test_file_{log_level}' - logger = MMLogger.get_instance( - instance_name, log_level=log_level, log_file=tmp_file) + logger = MMLogger.get_instance(instance_name, + log_level=log_level, + log_file=tmp_file) logger.log(level=log_level, msg='welcome') with open(tmp_file) as f: @@ -209,27 +212,32 @@ def test_filter(self, capsys): def test_file_handlers(self, tmp_path): tmp_file = tmp_path / 'tmp_file.log' fh = None - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) + logger = MMLogger(name='test_file_handlers', + log_file=tmp_file, + file_handler_cfg=fh) assert isinstance(logger.handlers[-1], logging.FileHandler) fh = dict(type='BaseRotatingHandler', mode='a') - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) + logger = MMLogger(name='test_file_handlers', + log_file=tmp_file, + file_handler_cfg=fh) assert isinstance(logger.handlers[-1], logging.handlers.BaseRotatingHandler) fh = dict(type='RotatingFileHandler', maxBytes=1024) - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) + logger = MMLogger(name='test_file_handlers', + log_file=tmp_file, + file_handler_cfg=fh) assert isinstance(logger.handlers[-1], logging.handlers.RotatingFileHandler) fh = dict(type='TimedRotatingFileHandler', when='MIDNIGHT') - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) + logger = MMLogger(name='test_file_handlers', + log_file=tmp_file, + file_handler_cfg=fh) assert isinstance(logger.handlers[-1], logging.handlers.TimedRotatingFileHandler) fh = dict(type='WatchedFileHandler') - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) + logger = MMLogger(name='test_file_handlers', + log_file=tmp_file, + file_handler_cfg=fh) assert isinstance(logger.handlers[-1], logging.handlers.WatchedFileHandler) # `FileHandler` should be closed in Windows, otherwise we cannot diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py index 3dc5cef748..b82211ea2d 100644 --- a/tests/test_logging/test_message_hub.py +++ b/tests/test_logging/test_message_hub.py @@ -27,10 +27,9 @@ def test_init(self): MessageHub('hello', log_scalars=OrderedDict(a=1)) # `Resumed_keys` with pytest.raises(AssertionError): - MessageHub( - 'hello', - runtime_info=OrderedDict(iter=1), - resumed_keys=OrderedDict(iters=False)) + MessageHub('hello', + runtime_info=OrderedDict(iter=1), + resumed_keys=OrderedDict(iters=False)) def test_update_scalar(self): message_hub = MessageHub.get_instance('mmengine') @@ -99,11 +98,10 @@ def test_get_runtime(self): def test_get_scalars(self): import torch message_hub = MessageHub.get_instance('mmengine') - log_dict = dict( - loss=1, - loss_cls=torch.tensor(2), - loss_bbox=np.array(3), - loss_iou=dict(value=1, count=2)) + log_dict = dict(loss=1, + loss_cls=torch.tensor(2), + loss_bbox=np.array(3), + loss_iou=dict(value=1, count=2)) message_hub.update_scalars(log_dict) loss = message_hub.get_scalar('loss') loss_cls = message_hub.get_scalar('loss_cls') @@ -169,8 +167,11 @@ def test_load_state_dict(self, capsys): state_dict = OrderedDict() state_dict['log_scalars'] = dict(a=1, b=HistoryBuffer()) state_dict['runtime_info'] = dict(c=1, d=NoDeepCopy(), e=1) - state_dict['resumed_keys'] = dict( - a=True, b=True, c=True, e=False, f=True) + state_dict['resumed_keys'] = dict(a=True, + b=True, + c=True, + e=False, + f=True) message_hub4 = MessageHub.get_instance('test_load_state_dict4') message_hub4.load_state_dict(state_dict) @@ -179,8 +180,9 @@ def test_load_state_dict(self, capsys): assert 'c' in message_hub4.runtime_info and \ state_dict['runtime_info']['d'] is \ message_hub4.runtime_info['d'] - assert message_hub4._resumed_keys == OrderedDict( - b=True, c=True, e=False) + assert message_hub4._resumed_keys == OrderedDict(b=True, + c=True, + e=False) def test_getstate(self): message_hub = MessageHub.get_instance('name') diff --git a/tests/test_model/test_averaged_model.py b/tests/test_model/test_averaged_model.py index 6438b8bde5..f9d3d38ca0 100644 --- a/tests/test_model/test_averaged_model.py +++ b/tests/test_model/test_averaged_model.py @@ -18,9 +18,8 @@ class TestAveragedModel(TestCase): """ # noqa: E501 def _test_swa_model(self, net_device, avg_device): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.Linear(5, 10)).to(net_device) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)).to(net_device) averaged_model = StochasticWeightAverage(model, device=avg_device) averaged_params = [ @@ -52,8 +51,8 @@ def test_averaged_model_all_devices(self): def test_swa_mixed_device(self): if not torch.cuda.is_available(): return - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)) model[0].cuda() model[1].cpu() averaged_model = StochasticWeightAverage(model) @@ -73,8 +72,8 @@ def test_swa_mixed_device(self): self.assertTrue(p_avg.device == p_swa.device) def test_swa_state_dict(self): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)) averaged_model = StochasticWeightAverage(model) averaged_model2 = StochasticWeightAverage(model) n_updates = 10 @@ -92,19 +91,19 @@ def test_ema(self): # test invalid momentum with self.assertRaisesRegex(AssertionError, 'momentum must be in range'): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)) ExponentialMovingAverage(model, momentum=3) # Warning should be raised if the value of momentum in EMA is # a large number with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)) ExponentialMovingAverage(model, momentum=0.9) # test EMA - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)) momentum = 0.1 ema_model = ExponentialMovingAverage(model, momentum=momentum) @@ -129,13 +128,14 @@ def test_ema(self): def test_ema_update_buffers(self): # Test EMA and update_buffers as True. - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.BatchNorm2d(5, momentum=0.3), + torch.nn.Linear(5, 10)) momentum = 0.1 - ema_model = ExponentialMovingAverage( - model, momentum=momentum, update_buffers=True) + ema_model = ExponentialMovingAverage(model, + momentum=momentum, + update_buffers=True) averaged_params = [ torch.zeros_like(param) for param in itertools.chain(model.parameters(), model.buffers()) @@ -168,9 +168,9 @@ def test_ema_update_buffers(self): assert_allclose(p_target, p_ema) def test_momentum_annealing_ema(self): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.BatchNorm2d(5, momentum=0.3), + torch.nn.Linear(5, 10)) # Test invalid gamma with self.assertRaisesRegex(AssertionError, 'gamma must be greater than 0'): @@ -180,8 +180,10 @@ def test_momentum_annealing_ema(self): momentum = 0.1 gamma = 4 - ema_model = MomentumAnnealingEMA( - model, gamma=gamma, momentum=momentum, update_buffers=True) + ema_model = MomentumAnnealingEMA(model, + gamma=gamma, + momentum=momentum, + update_buffers=True) averaged_params = [ torch.zeros_like(param) for param in itertools.chain(model.parameters(), model.buffers()) @@ -216,19 +218,18 @@ def test_momentum_annealing_ema(self): def test_momentum_annealing_ema_with_interval(self): # Test EMA with momentum annealing and interval - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.BatchNorm2d(5, momentum=0.3), + torch.nn.Linear(5, 10)) momentum = 0.1 gamma = 4 interval = 3 - ema_model = MomentumAnnealingEMA( - model, - gamma=gamma, - momentum=momentum, - interval=interval, - update_buffers=True) + ema_model = MomentumAnnealingEMA(model, + gamma=gamma, + momentum=momentum, + interval=interval, + update_buffers=True) averaged_params = [ torch.zeros_like(param) for param in itertools.chain(model.parameters(), model.buffers()) diff --git a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py index 8dc23eec86..484c95ec71 100644 --- a/tests/test_model/test_base_model/test_base_model.py +++ b/tests/test_model/test_base_model/test_base_model.py @@ -94,10 +94,9 @@ def test_parse_losses(self): ] losses = dict(loss_cls=loss_cls, loss_list=loss_list) target_parsed_losses = torch.tensor(6, dtype=torch.float32) - targe_log_vars = dict( - loss=torch.tensor(6, dtype=torch.float32), - loss_cls=torch.tensor(1, dtype=torch.float32), - loss_list=torch.tensor(5, dtype=torch.float32)) + targe_log_vars = dict(loss=torch.tensor(6, dtype=torch.float32), + loss_cls=torch.tensor(1, dtype=torch.float32), + loss_list=torch.tensor(5, dtype=torch.float32)) parse_losses, log_vars = model.parse_losses(losses) assert_allclose(parse_losses, target_parsed_losses) for key in log_vars: diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py index c409260a50..e429db032c 100644 --- a/tests/test_model/test_base_model/test_data_preprocessor.py +++ b/tests/test_model/test_base_model/test_data_preprocessor.py @@ -97,12 +97,11 @@ def test_init(self): assert_allclose(data_processor.pad_value, torch.tensor(0)) # Initiate model with bgr2rgb, mean, std .etc.. - data_processor = ImgDataPreprocessor( - bgr_to_rgb=True, - mean=[0, 0, 0], - std=[255, 255, 255], - pad_size_divisor=16, - pad_value=10) + data_processor = ImgDataPreprocessor(bgr_to_rgb=True, + mean=[0, 0, 0], + std=[255, 255, 255], + pad_size_divisor=16, + pad_value=10) self.assertTrue(data_processor._enable_normalize) self.assertTrue(data_processor._channel_conversion, True) assert_allclose(data_processor.mean, @@ -122,15 +121,15 @@ def test_init(self): ImgDataPreprocessor(bgr_to_rgb=True, rgb_to_bgr=True) with self.assertRaisesRegex(AssertionError, 'mean and std should be'): - ImgDataPreprocessor( - bgr_to_rgb=True, - mean=None, - std=[255, 255, 255], - pad_size_divisor=16, - pad_value=10) - - data_processor = ImgDataPreprocessor( - bgr_to_rgb=True, pad_size_divisor=16, pad_value=10) + ImgDataPreprocessor(bgr_to_rgb=True, + mean=None, + std=[255, 255, 255], + pad_size_divisor=16, + pad_value=10) + + data_processor = ImgDataPreprocessor(bgr_to_rgb=True, + pad_size_divisor=16, + pad_value=10) self.assertFalse(data_processor._enable_normalize) def test_forward(self): @@ -147,10 +146,9 @@ def test_forward(self): data_sample1 = InstanceData(bboxes=torch.randn(5, 4)) data_sample2 = InstanceData(bboxes=torch.randn(5, 4)) - data = dict( - inputs=[inputs1.clone(), inputs2.clone()], - data_sample=[data_sample1.clone(), - data_sample2.clone()]) + data = dict(inputs=[inputs1.clone(), inputs2.clone()], + data_sample=[data_sample1.clone(), + data_sample2.clone()]) std = torch.tensor([1, 2, 3]).view(-1, 1, 1) target_inputs1 = (inputs1.clone()[[2, 1, 0], ...] - 127.5) / std @@ -193,26 +191,27 @@ def test_forward(self): assert_allclose(data_sample.bboxes, target_data_sample.bboxes) # Test gray image with 3 dim mean will raise error - data_preprocessor = ImgDataPreprocessor( - mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)) - data = dict( - inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None) + data_preprocessor = ImgDataPreprocessor(mean=(127.5, 127.5, 127.5), + std=(127.5, 127.5, 127.5)) + data = dict(inputs=[torch.ones(10, 10), + torch.ones(10, 10)], + data_sample=None) with self.assertRaisesRegex(AssertionError, 'If the mean has 3 values'): data_preprocessor(data) - data = dict( - inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None) + data = dict(inputs=[torch.ones(10, 10), + torch.ones(10, 10)], + data_sample=None) with self.assertRaisesRegex(AssertionError, 'If the mean has 3 values'): data_preprocessor(data) # Test stacked batch inputs and batch data samples - data_preprocessor = ImgDataPreprocessor( - mean=(127.5, 127.5, 127.5), - std=(127.5, 127.5, 127.5), - rgb_to_bgr=True, - pad_size_divisor=16) + data_preprocessor = ImgDataPreprocessor(mean=(127.5, 127.5, 127.5), + std=(127.5, 127.5, 127.5), + rgb_to_bgr=True, + pad_size_divisor=16) _batch_inputs = torch.randn(2, 3, 10, 10) _batch_labels = [torch.randn(1), torch.randn(1)] data = dict(inputs=_batch_inputs, data_sample=_batch_labels) @@ -226,8 +225,8 @@ def test_forward(self): assert_allclose(target_batch_inputs, inputs) # Test batch inputs without convert channel order and pad - data_preprocessor = ImgDataPreprocessor( - mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)) + data_preprocessor = ImgDataPreprocessor(mean=(127.5, 127.5, 127.5), + std=(127.5, 127.5, 127.5)) _batch_inputs = torch.randn(2, 3, 10, 10) _batch_labels = [torch.randn(1), torch.randn(1)] data = dict(inputs=_batch_inputs, data_sample=_batch_labels) @@ -239,8 +238,8 @@ def test_forward(self): assert_allclose(target_batch_inputs, inputs) # Test empty `data_sample` - data = dict( - inputs=[inputs1.clone(), inputs2.clone()], data_sample=None) + data = dict(inputs=[inputs1.clone(), inputs2.clone()], + data_sample=None) output = data_preprocessor(data, True) inputs, data_samples = output['inputs'], output['data_sample'] self.assertIsNone(data_samples) diff --git a/tests/test_model/test_base_module.py b/tests/test_model/test_base_module.py index 1401eed298..bf9489aa76 100644 --- a/tests/test_model/test_base_module.py +++ b/tests/test_model/test_base_module.py @@ -97,20 +97,27 @@ class TestBaseModule(TestCase): def setUp(self) -> None: self.temp_dir = tempfile.TemporaryDirectory() self.BaseModule = BaseModule() - self.model_cfg = dict( - type='FooModel', - init_cfg=[ - dict(type='Constant', val=1, bias=2, layer='Linear'), - dict(type='Constant', val=3, bias=4, layer='Conv1d'), - dict(type='Constant', val=5, bias=6, layer='Conv2d') - ], - component1=dict(type='FooConv1d'), - component2=dict(type='FooConv2d'), - component3=dict(type='FooLinear'), - component4=dict( - type='FooLinearConv1d', - linear=dict(type='FooLinear'), - conv1d=dict(type='FooConv1d'))) + self.model_cfg = dict(type='FooModel', + init_cfg=[ + dict(type='Constant', + val=1, + bias=2, + layer='Linear'), + dict(type='Constant', + val=3, + bias=4, + layer='Conv1d'), + dict(type='Constant', + val=5, + bias=6, + layer='Conv2d') + ], + component1=dict(type='FooConv1d'), + component2=dict(type='FooConv2d'), + component3=dict(type='FooLinear'), + component4=dict(type='FooLinearConv1d', + linear=dict(type='FooLinear'), + conv1d=dict(type='FooConv1d'))) self.model = build_from_cfg(self.model_cfg, FOOMODELS) self.logger = MMLogger.get_instance(self._testMethodName) @@ -212,8 +219,8 @@ def __init__(self, torch.save(self.model.state_dict(), checkpoint_path) model_cfg = copy.deepcopy(self.model_cfg) model_cfg['type'] = 'PratrainedModel' - model_cfg['init_cfg'] = dict( - type='Pretrained', checkpoint=checkpoint_path) + model_cfg['init_cfg'] = dict(type='Pretrained', + checkpoint=checkpoint_path) model = FOOMODELS.build(model_cfg) ori_layer_weight = model.linear.linear.weight.clone() ori_layer_bias = model.linear.linear.bias.clone() @@ -280,8 +287,8 @@ def test_dump_init_info(self): model1.init_weights() assert len(os.listdir(dump_dir)) == 0 log_path = os.path.join(dump_dir, 'out.log') - MMLogger.get_instance( - 'logger2', log_file=log_path) # add logger with FileHandler + MMLogger.get_instance('logger2', + log_file=log_path) # add logger with FileHandler model2 = build_from_cfg(self.model_cfg, FOOMODELS) model2.init_weights() assert len(os.listdir(dump_dir)) == 1 @@ -297,14 +304,16 @@ class TestModuleList(TestCase): def test_modulelist_weight_init(self): models_cfg = [ - dict( - type='FooConv1d', - init_cfg=dict( - type='Constant', layer='Conv1d', val=0., bias=1.)), - dict( - type='FooConv2d', - init_cfg=dict( - type='Constant', layer='Conv2d', val=2., bias=3.)), + dict(type='FooConv1d', + init_cfg=dict(type='Constant', + layer='Conv1d', + val=0., + bias=1.)), + dict(type='FooConv2d', + init_cfg=dict(type='Constant', + layer='Conv2d', + val=2., + bias=3.)), ] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] modellist = ModuleList(layers) @@ -323,10 +332,11 @@ def test_modulelist_weight_init(self): torch.full(modellist[1].conv2d.bias.shape, 3.))) # inner init_cfg has higher priority layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] - modellist = ModuleList( - layers, - init_cfg=dict( - type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) + modellist = ModuleList(layers, + init_cfg=dict(type='Constant', + layer=['Conv1d', 'Conv2d'], + val=4., + bias=5.)) modellist.init_weights() self.assertTrue( torch.equal(modellist[0].conv1d.weight, @@ -346,14 +356,16 @@ class TestModuleDict(TestCase): def test_moduledict_weight_init(self): models_cfg = dict( - foo_conv_1d=dict( - type='FooConv1d', - init_cfg=dict( - type='Constant', layer='Conv1d', val=0., bias=1.)), - foo_conv_2d=dict( - type='FooConv2d', - init_cfg=dict( - type='Constant', layer='Conv2d', val=2., bias=3.)), + foo_conv_1d=dict(type='FooConv1d', + init_cfg=dict(type='Constant', + layer='Conv1d', + val=0., + bias=1.)), + foo_conv_2d=dict(type='FooConv2d', + init_cfg=dict(type='Constant', + layer='Conv2d', + val=2., + bias=3.)), ) layers = { name: build_from_cfg(cfg, COMPONENTS) @@ -382,10 +394,11 @@ def test_moduledict_weight_init(self): name: build_from_cfg(cfg, COMPONENTS) for name, cfg in models_cfg.items() } - modeldict = ModuleDict( - layers, - init_cfg=dict( - type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) + modeldict = ModuleDict(layers, + init_cfg=dict(type='Constant', + layer=['Conv1d', 'Conv2d'], + val=4., + bias=5.)) modeldict.init_weights() self.assertTrue( torch.equal( @@ -409,14 +422,16 @@ class TestSequential(TestCase): def test_sequential_model_weight_init(self): seq_model_cfg = [ - dict( - type='FooConv1d', - init_cfg=dict( - type='Constant', layer='Conv1d', val=0., bias=1.)), - dict( - type='FooConv2d', - init_cfg=dict( - type='Constant', layer='Conv2d', val=2., bias=3.)), + dict(type='FooConv1d', + init_cfg=dict(type='Constant', + layer='Conv1d', + val=0., + bias=1.)), + dict(type='FooConv2d', + init_cfg=dict(type='Constant', + layer='Conv2d', + val=2., + bias=3.)), ] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] seq_model = Sequential(*layers) @@ -435,10 +450,11 @@ def test_sequential_model_weight_init(self): torch.full(seq_model[1].conv2d.bias.shape, 3.))) # inner init_cfg has higher priority layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] - seq_model = Sequential( - *layers, - init_cfg=dict( - type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) + seq_model = Sequential(*layers, + init_cfg=dict(type='Constant', + layer=['Conv1d', 'Conv2d'], + val=4., + bias=5.)) seq_model.init_weights() self.assertTrue( torch.equal(seq_model[0].conv1d.weight, diff --git a/tests/test_model/test_efficient_conv_bn_eval.py b/tests/test_model/test_efficient_conv_bn_eval.py index eb91a6d090..e8cee21f89 100644 --- a/tests/test_model/test_efficient_conv_bn_eval.py +++ b/tests/test_model/test_efficient_conv_bn_eval.py @@ -46,9 +46,8 @@ def forward(self, x): return x -@unittest.skipIf( - digit_version(TORCH_VERSION) < digit_version('1.8'), - reason='torch.fx needs Pytorch 1.8 or higher') +@unittest.skipIf(digit_version(TORCH_VERSION) < digit_version('1.8'), + reason='torch.fx needs Pytorch 1.8 or higher') class TestEfficientConvBNEval(TestCase): """Test the turn_on_efficient_conv_bn_eval function.""" diff --git a/tests/test_model/test_model_utils.py b/tests/test_model/test_model_utils.py index a08ff67d77..203e6000e4 100644 --- a/tests/test_model/test_model_utils.py +++ b/tests/test_model/test_model_utils.py @@ -25,8 +25,8 @@ def add_module(self, name, module): raise ValueError() -@pytest.mark.skipif( - torch.__version__ == 'parrots', reason='not supported in parrots now') +@pytest.mark.skipif(torch.__version__ == 'parrots', + reason='not supported in parrots now') def test_revert_syncbn(): # conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')) conv = nn.Sequential(nn.Conv2d(3, 8, 2), nn.SyncBatchNorm(8)) @@ -40,8 +40,8 @@ def test_revert_syncbn(): revert_sync_batchnorm(conv) -@pytest.mark.skipif( - torch.__version__ == 'parrots', reason='not supported in parrots now') +@pytest.mark.skipif(torch.__version__ == 'parrots', + reason='not supported in parrots now') def test_convert_syncbn(): # conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')) conv = nn.Sequential(nn.Conv2d(3, 8, 2), nn.BatchNorm2d(8)) diff --git a/tests/test_model/test_test_aug_time.py b/tests/test_model/test_test_aug_time.py index d2b8c97190..62f44bb1cc 100644 --- a/tests/test_model/test_test_aug_time.py +++ b/tests/test_model/test_test_aug_time.py @@ -79,10 +79,12 @@ def test_test_step(self): ] tuple_dataset = [([1, 2], [3, 4]) for _ in range(10)] - dict_dataloader = DataLoader( - dict_dataset, batch_size=2, collate_fn=pseudo_collate) - tuple_dataloader = DataLoader( - tuple_dataset, batch_size=2, collate_fn=pseudo_collate) + dict_dataloader = DataLoader(dict_dataset, + batch_size=2, + collate_fn=pseudo_collate) + tuple_dataloader = DataLoader(tuple_dataset, + batch_size=2, + collate_fn=pseudo_collate) for data in dict_dataloader: result = tta_model.test_step(data) @@ -103,8 +105,8 @@ def test_init(self): def test_with_runner(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.model = dict( - type='ToyTestTimeAugModel', module=dict(type='ToyModel')) + cfg.model = dict(type='ToyTestTimeAugModel', + module=dict(type='ToyModel')) cfg.test_dataloader.dataset = dict(type='ToyDatasetTTA') cfg.test_dataloader.dataset['pipeline'] = dict(type='ToyTTAPipeline') runner = self.build_runner(cfg) diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py index ea657acac1..31f7beeea0 100644 --- a/tests/test_model/test_wrappers/test_model_wrapper.py +++ b/tests/test_model/test_wrappers/test_model_wrapper.py @@ -79,8 +79,8 @@ def setUp(self): super().setUp() self._spawn_processes() - @unittest.skipIf( - not torch.cuda.is_available(), reason='cuda should be available') + @unittest.skipIf(not torch.cuda.is_available(), + reason='cuda should be available') def test_train_step(self): self._init_dist_env(self.rank, self.world_size) # Mixed precision training and gradient asynchronous should be valid at @@ -88,8 +88,8 @@ def test_train_step(self): model = ToyModel().cuda() ddp_model = MMDistributedDataParallel(module=model) optimizer = SGD(ddp_model.parameters(), lr=0) - optim_wrapper = AmpOptimWrapper( - optimizer=optimizer, accumulative_counts=3) + optim_wrapper = AmpOptimWrapper(optimizer=optimizer, + accumulative_counts=3) inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255 data = dict(inputs=inputs, data_sample=None) res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss'] @@ -113,11 +113,11 @@ def test_train_step(self): self.assertIsNone(grad) # Test enable detect_anomalous_params. - ddp_model = MMDistributedDataParallel( - module=model, detect_anomalous_params=True) + ddp_model = MMDistributedDataParallel(module=model, + detect_anomalous_params=True) optimizer = SGD(ddp_model.parameters(), lr=0) - optim_wrapper = AmpOptimWrapper( - optimizer=optimizer, accumulative_counts=3) + optim_wrapper = AmpOptimWrapper(optimizer=optimizer, + accumulative_counts=3) inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255 data = dict(inputs=inputs, data_sample=None) res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss'] @@ -148,12 +148,13 @@ def _init_dist_env(self, rank, world_size): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29510' os.environ['RANK'] = str(rank) - torch_dist.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='gloo', + rank=rank, + world_size=world_size) -@unittest.skipIf( - not torch.cuda.is_available(), reason='cuda should be available') +@unittest.skipIf(not torch.cuda.is_available(), + reason='cuda should be available') class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel): def test_init(self): @@ -178,8 +179,8 @@ def test_train_step(self): optimizer2 = SGD(model.conv1.parameters(), lr=0.2) optim_wrapper1 = OptimWrapper(optimizer1, 1) optim_wrapper2 = OptimWrapper(optimizer2, 1) - optim_wrapper_dict = OptimWrapperDict( - optim_wrapper1=optim_wrapper1, optim_wrapper2=optim_wrapper2) + optim_wrapper_dict = OptimWrapperDict(optim_wrapper1=optim_wrapper1, + optim_wrapper2=optim_wrapper2) inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255 data = dict(inputs=inputs, data_sample=None) # Automatically sync grads of `optim_wrapper1` since @@ -215,15 +216,15 @@ def _init_dist_env(self, rank, world_size): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29515' os.environ['RANK'] = str(rank) - torch_dist.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='gloo', + rank=rank, + world_size=world_size) -@unittest.skipIf( - torch.cuda.device_count() < 2, reason='need 2 gpu to test fsdp') -@unittest.skipIf( - digit_version(TORCH_VERSION) < digit_version('2.0.0'), - reason='fsdp needs Pytorch 2.0.0 or higher') +@unittest.skipIf(torch.cuda.device_count() < 2, + reason='need 2 gpu to test fsdp') +@unittest.skipIf(digit_version(TORCH_VERSION) < digit_version('2.0.0'), + reason='fsdp needs Pytorch 2.0.0 or higher') class TestMMFullyShardedDataParallel(MultiProcessTestCase): def _init_dist_env(self, rank, world_size): @@ -234,8 +235,9 @@ def _init_dist_env(self, rank, world_size): num_gpus = torch.cuda.device_count() torch.cuda.set_device(rank % num_gpus) - torch_dist.init_process_group( - backend='nccl', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='nccl', + rank=rank, + world_size=world_size) def setUp(self) -> None: super().setUp() @@ -266,8 +268,8 @@ def wrap_policy(module, recurse=True, *args, **kwargs): return True return isinstance(module, nn.Conv2d) - fsdp_model = MMFullyShardedDataParallel( - module=model.cuda(), auto_wrap_policy=wrap_policy) + fsdp_model = MMFullyShardedDataParallel(module=model.cuda(), + auto_wrap_policy=wrap_policy) optimizer = SGD(fsdp_model.parameters(), lr=0.1) optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1) inputs = torch.randn(1, 3, 1, 1) * self.rank * 255 diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py index 113aacd6c8..cbebdd9b49 100644 --- a/tests/test_optim/test_optimizer/test_optimizer.py +++ b/tests/test_optim/test_optimizer/test_optimizer.py @@ -26,8 +26,8 @@ MMCV_FULL_AVAILABLE = mmcv_full_available() if not MMCV_FULL_AVAILABLE: - sys.modules['mmcv.ops'] = MagicMock( - DeformConv2d=dict, ModulatedDeformConv2d=dict) + sys.modules['mmcv.ops'] = MagicMock(DeformConv2d=dict, + ModulatedDeformConv2d=dict) def has_dadaptation() -> bool: @@ -73,8 +73,10 @@ def __init__(self): self.sub = SubModel() if MMCV_FULL_AVAILABLE: from mmcv.ops import DeformConv2dPack - self.dcn = DeformConv2dPack( - 3, 4, kernel_size=3, deformable_groups=1) + self.dcn = DeformConv2dPack(3, + 4, + kernel_size=3, + deformable_groups=1) class ExampleDuplicateModel(nn.Module): @@ -90,8 +92,10 @@ def __init__(self): self.conv3[0] = self.conv1[0] if MMCV_FULL_AVAILABLE: from mmcv.ops import DeformConv2dPack - self.dcn = DeformConv2dPack( - 3, 4, kernel_size=3, deformable_groups=1) + self.dcn = DeformConv2dPack(3, + 4, + kernel_size=3, + deformable_groups=1) def forward(self, x): return x @@ -271,23 +275,19 @@ def test_transformers_optimizers(self): def test_build_optimizer(self): # test build function without ``constructor`` and ``paramwise_cfg`` - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg) self._check_default_optimizer(optim_wrapper.optimizer, self.model) # test build optimizer without type in optim_wrapper_cfg - optim_wrapper_cfg = dict( - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg) self.assertIsInstance(optim_wrapper, OptimWrapper) self._check_default_optimizer(optim_wrapper.optimizer, self.model) @@ -310,24 +310,20 @@ def test_build_optimizer(self): lambda: build_optim_wrapper(self.model, optim_wrapper_cfg)) def test_build_default_optimizer_constructor(self): - optim_wrapper = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3) - optim_constructor_cfg = dict( - type='DefaultOptimWrapperConstructor', - optim_wrapper_cfg=optim_wrapper, - paramwise_cfg=paramwise_cfg) + optim_wrapper = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1, + flat_decay_mult=0.3) + optim_constructor_cfg = dict(type='DefaultOptimWrapperConstructor', + optim_wrapper_cfg=optim_wrapper, + paramwise_cfg=paramwise_cfg) optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( optim_constructor_cfg) optim_wrapper = optim_constructor(self.model) @@ -335,13 +331,11 @@ def test_build_default_optimizer_constructor(self): **paramwise_cfg) def test_build_custom_optimizer_constructor(self): - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) @OPTIM_WRAPPER_CONSTRUCTORS.register_module() class MyOptimizerConstructor(DefaultOptimWrapperConstructor): @@ -363,10 +357,9 @@ def __call__(self, model): return build_from_cfg(self.optimizer_cfg, OPTIMIZERS) paramwise_cfg = dict(conv1_lr_mult=5) - optim_constructor_cfg = dict( - type='MyOptimizerConstructor', - optim_wrapper_cfg=optim_wrapper_cfg, - paramwise_cfg=paramwise_cfg) + optim_constructor_cfg = dict(type='MyOptimizerConstructor', + optim_wrapper_cfg=optim_wrapper_cfg, + paramwise_cfg=paramwise_cfg) optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( optim_constructor_cfg) optimizer = optim_constructor(self.model) @@ -394,9 +387,9 @@ def test_default_optimizer_constructor(self): with self.assertRaises(TypeError): # paramwise_cfg must be a dict or None - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict(lr=0.0001, weight_decay=None)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(lr=0.0001, + weight_decay=None)) paramwise_cfg = ['error'] optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) @@ -405,29 +398,28 @@ def test_default_optimizer_constructor(self): with self.assertRaises(ValueError): # bias_decay_mult/norm_decay_mult is specified but weight_decay # is None - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict(lr=0.0001, weight_decay=None)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(lr=0.0001, + weight_decay=None)) paramwise_cfg = dict(bias_decay_mult=1, norm_decay_mult=1) optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) optim_constructor(self.model) # basic config with ExampleModel - optimizer_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optimizer_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg) optim_wrapper = optim_constructor(self.model) self._check_default_optimizer(optim_wrapper.optimizer, self.model) # Support building custom optimizers - CUSTOM_OPTIMIZERS = Registry( - 'custom optimizer', scope='custom optimizer', parent=OPTIMIZERS) + CUSTOM_OPTIMIZERS = Registry('custom optimizer', + scope='custom optimizer', + parent=OPTIMIZERS) class CustomOptimizer(torch.optim.SGD): @@ -444,93 +436,84 @@ def __init__(self, model_params, *args, **kwargs): def test_default_optimizer_constructor_with_model_wrapper(self): # basic config with pseudo data parallel model = PseudoDataParallel() - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = None optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg) optim_wrapper = optim_constructor(model) - self._check_default_optimizer( - optim_wrapper.optimizer, model, prefix='module.') + self._check_default_optimizer(optim_wrapper.optimizer, + model, + prefix='module.') # paramwise_cfg with pseudo data parallel model = PseudoDataParallel() - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1, + flat_decay_mult=0.3) optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(model) - self._check_sgd_optimizer( - optim_wrapper.optimizer, model, prefix='module.', **paramwise_cfg) + self._check_sgd_optimizer(optim_wrapper.optimizer, + model, + prefix='module.', + **paramwise_cfg) # basic config with DataParallel if torch.cuda.is_available(): model = torch.nn.DataParallel(ExampleModel()) - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = None optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg) optim_wrapper = optim_constructor(model) - self._check_default_optimizer( - optim_wrapper.optimizer, model, prefix='module.') + self._check_default_optimizer(optim_wrapper.optimizer, + model, + prefix='module.') # paramwise_cfg with DataParallel if torch.cuda.is_available(): model = torch.nn.DataParallel(self.model) - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1, + flat_decay_mult=0.3) optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(model) - self._check_sgd_optimizer( - optim_wrapper.optimizer, - model, - prefix='module.', - **paramwise_cfg) + self._check_sgd_optimizer(optim_wrapper.optimizer, + model, + prefix='module.', + **paramwise_cfg) def test_default_optimizer_constructor_with_empty_paramwise_cfg(self): # Empty paramwise_cfg with ExampleModel - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict() optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) @@ -541,13 +524,11 @@ def test_default_optimizer_constructor_with_empty_paramwise_cfg(self): model = ExampleModel() for param in model.parameters(): param.requires_grad = False - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict() optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) @@ -556,20 +537,17 @@ def test_default_optimizer_constructor_with_empty_paramwise_cfg(self): def test_default_optimizer_constructor_with_paramwise_cfg(self): # paramwise_cfg with ExampleModel - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1, + flat_decay_mult=0.3) optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(self.model) @@ -578,19 +556,16 @@ def test_default_optimizer_constructor_with_paramwise_cfg(self): def test_default_optimizer_constructor_no_grad(self): # paramwise_cfg with ExampleModel and no grad - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1) self.model.conv1.requires_grad_(False) optim_constructor = DefaultOptimWrapperConstructor( @@ -606,18 +581,15 @@ def test_default_optimizer_constructor_no_grad(self): def test_default_optimizer_constructor_bypass_duplicate(self): # paramwise_cfg with bypass_duplicate option model = ExampleDuplicateModel() - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1) with self.assertRaisesRegex( ValueError, @@ -626,14 +598,13 @@ def test_default_optimizer_constructor_bypass_duplicate(self): optim_wrapper_cfg, paramwise_cfg) optim_constructor(model) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3, - bypass_duplicate=True) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1, + flat_decay_mult=0.3, + bypass_duplicate=True) optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) @@ -663,21 +634,18 @@ def test_default_optimizer_constructor_bypass_duplicate(self): def test_default_optimizer_constructor_custom_key(self): # test DefaultOptimWrapperConstructor with custom_keys and # ExampleModel - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - custom_keys={ - 'param1': dict(lr_mult=10), - 'sub': dict(lr_mult=0.1, decay_mult=0), - 'sub.gn': dict(lr_mult=0.01), - 'non_exist_key': dict(lr_mult=0.0) - }, - norm_decay_mult=0.5) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(custom_keys={ + 'param1': dict(lr_mult=10), + 'sub': dict(lr_mult=0.1, decay_mult=0), + 'sub.gn': dict(lr_mult=0.01), + 'non_exist_key': dict(lr_mult=0.0) + }, + norm_decay_mult=0.5) with self.assertRaises(TypeError): # custom_keys should be a dict @@ -689,8 +657,8 @@ def test_default_optimizer_constructor_custom_key(self): with self.assertRaises(ValueError): # if 'decay_mult' is specified in custom_keys, weight_decay # should be specified - optim_wrapper_cfg_ = dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)) + optim_wrapper_cfg_ = dict(type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01)) paramwise_cfg_ = dict( custom_keys={'.backbone': dict(decay_mult=0.5)}) optim_constructor = DefaultOptimWrapperConstructor( @@ -760,10 +728,10 @@ def test_default_optimizer_constructor_custom_key(self): # test DefaultOptimWrapperConstructor with custom_keys and # ExampleModel 2 - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', lr=self.base_lr, momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + momentum=self.momentum)) paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)}) optim_constructor = DefaultOptimWrapperConstructor( @@ -849,24 +817,21 @@ def test_zero_redundancy_optimizer(self): self.base_wd = 0.9 # test build function - optim_wrapper_cfg = dict( - optimizer=dict( - type='ZeroRedundancyOptimizer', - optimizer_type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(optimizer=dict(type='ZeroRedundancyOptimizer', + optimizer_type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg) self._check_default_optimizer(optim_wrapper.optimizer, model) # test build optimizer without ``optimizer_type`` with self.assertRaises(TypeError): optim_wrapper_cfg = dict( - optimizer=dict( - type='ZeroRedundancyOptimizer', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optimizer=dict(type='ZeroRedundancyOptimizer', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg) @unittest.skipIf( @@ -885,14 +850,12 @@ def test_zero_redundancy_optimizer_with_paramwise_cfg(self): 'conv1': dict(lr_mult=0.0, decay_mult=0.0), 'conv2': dict(lr_mult=1.0, decay_mult=2.0) }) - optim_wrapper_cfg = dict( - optimizer=dict( - type='ZeroRedundancyOptimizer', - optimizer_type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum), - paramwise_cfg=paramwise_cfg) + optim_wrapper_cfg = dict(optimizer=dict(type='ZeroRedundancyOptimizer', + optimizer_type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum), + paramwise_cfg=paramwise_cfg) optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg) self._check_default_optimizer(optim_wrapper.optimizer, model) @@ -901,5 +864,6 @@ def _init_dist_env(self, rank, world_size): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29510' os.environ['RANK'] = str(rank) - torch.distributed.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + torch.distributed.init_process_group(backend='gloo', + rank=rank, + world_size=world_size) diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index ef1db241dd..12723686d3 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -176,8 +176,8 @@ def test_ger_lr(self): optim_wrapper = OptimWrapper(optim) self.assertEqual(optim_wrapper.get_lr(), dict(lr=[0.1])) model = ToyModel() - optimizer_cfg = dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.1)) + optimizer_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.1)) paramwise_cfg = dict(custom_keys={'conv1.weight': dict(lr_mult=0.1)}) optim_constructor = DefaultOptimWrapperConstructor( optimizer_cfg, paramwise_cfg) @@ -226,8 +226,8 @@ def test_step(self): @unittest.skipIf(True, reason='Solved in the future') def test_clip_grads(self): # Test `clip_grad` with `clip_norm_` - optim_wrapper = OptimWrapper( - self.optimizer, clip_grad=dict(max_norm=35)) + optim_wrapper = OptimWrapper(self.optimizer, + clip_grad=dict(max_norm=35)) loss = self.model(torch.Tensor(1, 1, 1, 1)) loss.backward() optim_wrapper._clip_grad() @@ -236,8 +236,9 @@ def test_clip_grads(self): self.message_hub._log_scalars.clear() # Test `clip_grad` with `clip_value_` - optim_wrapper = OptimWrapper( - self.optimizer, clip_grad=dict(type='value', clip_value=0.5)) + optim_wrapper = OptimWrapper(self.optimizer, + clip_grad=dict(type='value', + clip_value=0.5)) loss = self.model(torch.Tensor(1, 1, 1, 1)) loss.backward() optim_wrapper._clip_grad() @@ -300,8 +301,9 @@ def _init_dist_env(self, rank, world_size): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29515' os.environ['RANK'] = str(rank) - torch_dist.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='gloo', + rank=rank, + world_size=world_size) # TODO Test the real interface after add testing tool function which can # test the function or method is read called. @@ -328,8 +330,9 @@ def setUp(self) -> None: reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_init(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): pass @@ -339,8 +342,9 @@ def test_init(self): 'https://www.github.com/nvidia/apex') def test_step(self): optimizer = MagicMock(spec=Optimizer) - apex_optim_wrapper = ApexOptimWrapper( - optimizer=optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) apex_optim_wrapper.backward(loss) @@ -351,8 +355,9 @@ def test_step(self): reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_backward(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) apex_optim_wrapper.backward(loss) @@ -362,8 +367,9 @@ def test_backward(self): reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_state_dict(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) apex_optim_wrapper.update_params(loss) @@ -380,8 +386,9 @@ def test_state_dict(self): reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_load_state_dict(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): # Test load from optimizer optimizer = SGD(self.model.parameters(), lr=0.1) @@ -403,8 +410,9 @@ def test_load_state_dict(self): reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_optim_context(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): x = torch.randn(1, 1, 1, 1).cuda() y = nn.Conv2d(1, 1, 1).cuda()(x) @@ -426,24 +434,25 @@ def test_init(self): self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) # Test with dynamic. - amp_optim_wrapper = AmpOptimWrapper( - 'dynamic', optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper('dynamic', + optimizer=self.optimizer) self.assertIsNone(amp_optim_wrapper._scale_update_param) self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) # Test with dtype float16 - amp_optim_wrapper = AmpOptimWrapper( - dtype='float16', optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper(dtype='float16', + optimizer=self.optimizer) self.assertIs(amp_optim_wrapper.cast_dtype, torch.float16) # Test with dtype bfloat16 - amp_optim_wrapper = AmpOptimWrapper( - dtype='bfloat16', optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper(dtype='bfloat16', + optimizer=self.optimizer) self.assertIs(amp_optim_wrapper.cast_dtype, torch.bfloat16) # Test with dict loss_scale. - amp_optim_wrapper = AmpOptimWrapper( - dict(init_scale=1, growth_factor=2), optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper(dict(init_scale=1, + growth_factor=2), + optimizer=self.optimizer) self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) self.assertIsNone(amp_optim_wrapper._scale_update_param) with self.assertRaisesRegex(TypeError, @@ -455,8 +464,8 @@ def test_init(self): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_step(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): @@ -478,14 +487,14 @@ def test_step(self, dtype): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_backward(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): raise unittest.SkipTest('bfloat16 not supported by device') - amp_optim_wrapper = AmpOptimWrapper( - optimizer=self.optimizer, dtype=dtype) + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer, + dtype=dtype) loss_scaler = MagicMock() scale_return = MagicMock() scale_fn = MagicMock(return_value=scale_return) @@ -539,14 +548,14 @@ def test_load_state_dict(self): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_optim_context(self, dtype, target_dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): raise unittest.SkipTest('bfloat16 not supported by device') - amp_optim_wrapper = AmpOptimWrapper( - optimizer=self.optimizer, dtype=dtype) + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer, + dtype=dtype) with amp_optim_wrapper.optim_context(self.model): x = torch.randn(1, 1, 1, 1).cuda() y = nn.Conv2d(1, 1, 1).cuda()(x) diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py index 3925a33ac9..990cbad757 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py @@ -18,8 +18,8 @@ def setUp(self) -> None: self.optim2 = SGD(self.model2.parameters(), lr=0.2, momentum=0.9) self.optim_wrapper1 = OptimWrapper(self.optim1) self.optim_wrapper2 = OptimWrapper(self.optim2) - self.optimizers_wrappers = dict( - optim1=self.optim_wrapper1, optim2=self.optim_wrapper2) + self.optimizers_wrappers = dict(optim1=self.optim_wrapper1, + optim2=self.optim_wrapper2) def test_init(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) @@ -111,8 +111,8 @@ def test_load_state_dict(self): optim_wrapper_load2 = OptimWrapper(optim2) optim_wrapper_dict_save = OptimWrapperDict(**self.optimizers_wrappers) - optim_wrapper_dict_load = OptimWrapperDict( - optim1=optim_wrapper_load1, optim2=optim_wrapper_load2) + optim_wrapper_dict_load = OptimWrapperDict(optim1=optim_wrapper_load1, + optim2=optim_wrapper_load2) state_dict = optim_wrapper_dict_save.state_dict() optim_wrapper_dict_load.load_state_dict(state_dict) @@ -121,21 +121,18 @@ def test_load_state_dict(self): def test_items(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertListEqual( - list(optim_wrapper_dict.items()), - list(self.optimizers_wrappers.items())) + self.assertListEqual(list(optim_wrapper_dict.items()), + list(self.optimizers_wrappers.items())) def test_values(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertListEqual( - list(optim_wrapper_dict.values()), - list(self.optimizers_wrappers.values())) + self.assertListEqual(list(optim_wrapper_dict.values()), + list(self.optimizers_wrappers.values())) def test_keys(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertListEqual( - list(optim_wrapper_dict.keys()), - list(self.optimizers_wrappers.keys())) + self.assertListEqual(list(optim_wrapper_dict.keys()), + list(self.optimizers_wrappers.keys())) def test_getitem(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py index 22787e4709..4d3380b3cf 100644 --- a/tests/test_optim/test_scheduler/test_lr_scheduler.py +++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py @@ -118,8 +118,10 @@ def call_sch_before_optim(): group['initial_lr'] = 0.01 def call_sch_before_optim_resume(): - scheduler = StepLR( - self.optimizer, gamma=0.1, step_size=3, last_step=10) + scheduler = StepLR(self.optimizer, + gamma=0.1, + step_size=3, + last_step=10) scheduler.step() self.optimizer.step() @@ -179,17 +181,16 @@ def test_effective_interval(self): interpolation = [ start_factor + i * (1 - start_factor) / iters for i in range(iters) ] - single_targets = [0.05] * begin + [x * 0.05 - for x in interpolation] + [0.05] * ( - epochs - iters - begin) + single_targets = [0.05] * begin + [ + x * 0.05 for x in interpolation + ] + [0.05] * (epochs - iters - begin) targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = LinearLR( - self.optimizer, - start_factor=start_factor, - begin=begin, - end=begin + iters + 1) + scheduler = LinearLR(self.optimizer, + start_factor=start_factor, + begin=begin, + end=begin + iters + 1) self._test_scheduler_value(scheduler, targets, epochs) def _test_scheduler_value(self, @@ -233,8 +234,10 @@ def test_step_scheduler(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = StepLR( - self.optimizer, gamma=0.1, step_size=3, verbose=True) + scheduler = StepLR(self.optimizer, + gamma=0.1, + step_size=3, + verbose=True) self._test_scheduler_value(scheduler, targets, epochs) def test_multi_step_scheduler(self): @@ -248,8 +251,9 @@ def test_multi_step_scheduler(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = MultiStepLR( - self.optimizer, gamma=0.1, milestones=[2, 5, 9]) + scheduler = MultiStepLR(self.optimizer, + gamma=0.1, + milestones=[2, 5, 9]) self._test_scheduler_value(scheduler, targets, epochs) def test_constant_scheduler(self): @@ -287,13 +291,14 @@ def test_linear_scheduler(self): interpolation = [ start_factor + i * (1 - start_factor) / iters for i in range(iters) ] - single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( - epochs - iters) + single_targets = [x * 0.05 + for x in interpolation] + [0.05] * (epochs - iters) targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = LinearLR( - self.optimizer, start_factor=start_factor, end=iters + 1) + scheduler = LinearLR(self.optimizer, + start_factor=start_factor, + end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs) def test_exp_scheduler(self): @@ -320,8 +325,10 @@ def test_cos_anneal_scheduler(self): self._test_scheduler_value(scheduler, targets, epochs) # Test default `T_max` - scheduler = CosineAnnealingLR( - self.optimizer, begin=5, end=100, eta_min=eta_min) + scheduler = CosineAnnealingLR(self.optimizer, + begin=5, + end=100, + eta_min=eta_min) self.assertEqual(scheduler.T_max, 100 - 5) def test_poly_scheduler(self): @@ -332,32 +339,30 @@ def test_poly_scheduler(self): targets_layer1 = [ min_lr + (0.05 - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) targets_layer2 = [ min_lr + (0.05 * self.layer2_mult - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) targets = [targets_layer1, targets_layer2] - scheduler = PolyLR( - self.optimizer, power=power, eta_min=min_lr, end=iters + 1) + scheduler = PolyLR(self.optimizer, + power=power, + eta_min=min_lr, + end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs=10) def test_cosine_restart_scheduler(self): with self.assertRaises(AssertionError): - CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0, - eta_min_ratio=0.1) + CosineRestartLR(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0, + eta_min_ratio=0.1) with self.assertRaises(AssertionError): - CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5, 0.0], - eta_min=0) + CosineRestartLR(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5, 0.0], + eta_min=0) single_targets = [ 0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271, 0.0086372, 0.0023872, 0.0023872 @@ -365,11 +370,10 @@ def test_cosine_restart_scheduler(self): targets = [ single_targets, [t * self.layer2_mult for t in single_targets] ] - scheduler = CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0) + scheduler = CosineRestartLR(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0) self._test_scheduler_value(scheduler, targets, epochs=10) def test_reduce_on_plateau_scheduler(self): @@ -429,8 +433,10 @@ def _test_value(epochs, targets, metrics_list, monitor, rule, factor, cooldown=cooldown, min_value=min_value, ) - self._test_scheduler_value( - scheduler, targets, epochs=epochs, step_kwargs=metrics_list) + self._test_scheduler_value(scheduler, + targets, + epochs=epochs, + step_kwargs=metrics_list) # reset the state of optimizers self.optimizer = optim.SGD([{ @@ -559,9 +565,8 @@ def test_step_scheduler_state_dict(self): def test_multi_step_scheduler_state_dict(self): self._check_scheduler_state_dict( lambda: MultiStepLR( - self.optimizer, gamma=0.1, milestones=[2, 5, 9]), - lambda: MultiStepLR( - self.optimizer, gamma=0.01, milestones=[1, 4, 6])) + self.optimizer, gamma=0.1, milestones=[2, 5, 9]), lambda: + MultiStepLR(self.optimizer, gamma=0.01, milestones=[1, 4, 6])) def test_exp_scheduler_state_dict(self): self._check_scheduler_state_dict( @@ -593,52 +598,50 @@ def test_poly_scheduler_state_dict(self): def test_cosine_restart_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0), - lambda: CosineRestartLR( - self.optimizer, - periods=[4, 6], - restart_weights=[1, 0.5], - eta_min=0), + lambda: CosineRestartLR(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0), + lambda: CosineRestartLR(self.optimizer, + periods=[4, 6], + restart_weights=[1, 0.5], + eta_min=0), epochs=10) def test_reduce_on_plateau_scheduler_state_dict(self): epochs = 10 metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)] self._check_scheduler_state_dict( - lambda: ReduceOnPlateauLR( - self.optimizer, - monitor='loss', - rule='less', - factor=0.01, - patience=5, - threshold=1e-4, - threshold_rule='rel', - cooldown=0, - min_value=0.0, - eps=1e-8), - lambda: ReduceOnPlateauLR( - self.optimizer, - monitor='loss_foo', - rule='greater', - factor=0.05, - patience=10, - threshold=1e-5, - threshold_rule='abs', - cooldown=5, - min_value=0.1, - eps=1e-9), + lambda: ReduceOnPlateauLR(self.optimizer, + monitor='loss', + rule='less', + factor=0.01, + patience=5, + threshold=1e-4, + threshold_rule='rel', + cooldown=0, + min_value=0.0, + eps=1e-8), + lambda: ReduceOnPlateauLR(self.optimizer, + monitor='loss_foo', + rule='greater', + factor=0.05, + patience=10, + threshold=1e-5, + threshold_rule='abs', + cooldown=5, + min_value=0.1, + eps=1e-9), epochs=epochs, step_kwargs=metrics_list) def test_step_scheduler_convert_iterbased(self): # invalid epoch_length with self.assertRaises(AssertionError): - scheduler = StepLR.build_iter_from_epoch( - self.optimizer, gamma=0.1, step_size=2, epoch_length=-1) + scheduler = StepLR.build_iter_from_epoch(self.optimizer, + gamma=0.1, + step_size=2, + epoch_length=-1) # lr = 0.05 if epoch < 2 # lr = 0.005 if 2 <= epoch < 4 @@ -648,10 +651,14 @@ def test_step_scheduler_convert_iterbased(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = StepLR.build_iter_from_epoch( - self.optimizer, gamma=0.1, step_size=2, epoch_length=epoch_length) - self._test_scheduler_value( - scheduler, targets, epochs * epoch_length, param_name='lr') + scheduler = StepLR.build_iter_from_epoch(self.optimizer, + gamma=0.1, + step_size=2, + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, + targets, + epochs * epoch_length, + param_name='lr') def test_multi_step_scheduler_convert_iterbased(self): # lr = 0.05 if epoch < 2 @@ -684,8 +691,10 @@ def test_constant_scheduler_convert_iterbased(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = ConstantLR.build_iter_from_epoch( - self.optimizer, factor=1.0 / 2, end=5, epoch_length=epoch_length) + scheduler = ConstantLR.build_iter_from_epoch(self.optimizer, + factor=1.0 / 2, + end=5, + epoch_length=epoch_length) self._test_scheduler_value(scheduler, targets, epochs * epoch_length) def test_linear_scheduler_convert_iterbased(self): @@ -698,16 +707,15 @@ def test_linear_scheduler_convert_iterbased(self): interpolation = [ start_factor + i * (1 - start_factor) / iters for i in range(iters) ] - single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( - epochs * epoch_length - iters) + single_targets = [x * 0.05 for x in interpolation + ] + [0.05] * (epochs * epoch_length - iters) targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = LinearLR.build_iter_from_epoch( - self.optimizer, - start_factor=start_factor, - end=end, - epoch_length=epoch_length) + scheduler = LinearLR.build_iter_from_epoch(self.optimizer, + start_factor=start_factor, + end=end, + epoch_length=epoch_length) self._test_scheduler_value(scheduler, targets, epochs) def test_exp_scheduler_convert_iterbased(self): @@ -755,20 +763,17 @@ def test_poly_scheduler_convert_iterbased(self): targets_layer1 = [ min_lr + (0.05 - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) targets_layer2 = [ min_lr + (0.05 * self.layer2_mult - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) targets = [targets_layer1, targets_layer2] - scheduler = PolyLR.build_iter_from_epoch( - self.optimizer, - power=power, - eta_min=min_lr, - end=end, - epoch_length=epoch_length) + scheduler = PolyLR.build_iter_from_epoch(self.optimizer, + power=power, + eta_min=min_lr, + end=end, + epoch_length=epoch_length) self._test_scheduler_value(scheduler, targets, epochs=10) def test_multi_scheduler_without_overlap_linear_multi_step(self): @@ -779,10 +784,15 @@ def test_multi_scheduler_without_overlap_linear_multi_step(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler1 = LinearLR( - self.optimizer, start_factor=1 / 2, begin=0, end=5) - scheduler2 = MultiStepLR( - self.optimizer, gamma=0.1, milestones=[3, 6], begin=5, end=12) + scheduler1 = LinearLR(self.optimizer, + start_factor=1 / 2, + begin=0, + end=5) + scheduler2 = MultiStepLR(self.optimizer, + gamma=0.1, + milestones=[3, 6], + begin=5, + end=12) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_without_overlap_exp_cosine(self): @@ -800,8 +810,11 @@ def test_multi_scheduler_without_overlap_exp_cosine(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler2 = CosineAnnealingLR( - self.optimizer, T_max=5, eta_min=eta_min, begin=5, end=10) + scheduler2 = CosineAnnealingLR(self.optimizer, + T_max=5, + eta_min=eta_min, + begin=5, + end=10) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) @@ -813,10 +826,13 @@ def test_multi_scheduler_with_overlap(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler1 = LinearLR( - self.optimizer, start_factor=1 / 2, begin=0, end=5) - scheduler2 = MultiStepLR( - self.optimizer, gamma=0.1, milestones=[3, 6, 9]) + scheduler1 = LinearLR(self.optimizer, + start_factor=1 / 2, + begin=0, + end=5) + scheduler2 = MultiStepLR(self.optimizer, + gamma=0.1, + milestones=[3, 6, 9]) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_with_gap(self): @@ -836,32 +852,33 @@ def test_multi_scheduler_with_gap(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler2 = CosineAnnealingLR( - self.optimizer, T_max=5, eta_min=eta_min, begin=10, end=15) + scheduler2 = CosineAnnealingLR(self.optimizer, + T_max=5, + eta_min=eta_min, + begin=10, + end=15) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_onecycle_lr(self): # test linear annealing target = [1., 13., 25., 21.5, 18., 14.5, 11., 7.5, 4., 0.5] - scheduler = OneCycleLR( - self.optimizer, - eta_max=25, - final_div_factor=2, - total_steps=10, - anneal_strategy='linear') + scheduler = OneCycleLR(self.optimizer, + eta_max=25, + final_div_factor=2, + total_steps=10, + anneal_strategy='linear') self._test_scheduler_value(scheduler, [target], 10) # test linear annealing three phase target = [1., 9., 17., 25., 17., 9., 1., 0.75, 0.5, 0.25] - scheduler = OneCycleLR( - self.optimizer, - eta_max=25, - div_factor=25, - total_steps=10, - anneal_strategy='linear', - pct_start=0.4, - final_div_factor=4, - three_phase=True) + scheduler = OneCycleLR(self.optimizer, + eta_max=25, + div_factor=25, + total_steps=10, + anneal_strategy='linear', + pct_start=0.4, + final_div_factor=4, + three_phase=True) self._test_scheduler_value(scheduler, [target], 10) # test cosine annealing @@ -878,6 +895,8 @@ def annealing_cos(start, end, pct): annealing_cos(25, 0.5, 5 / 7.0), annealing_cos(25, 0.5, 6 / 7.0), 0.5 ] - scheduler = OneCycleLR( - self.optimizer, eta_max=25, final_div_factor=2, total_steps=10) + scheduler = OneCycleLR(self.optimizer, + eta_max=25, + final_div_factor=2, + total_steps=10) self._test_scheduler_value(scheduler, [target], 10) diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py index 60a9713ee2..171d8f0977 100644 --- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py +++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py @@ -104,8 +104,9 @@ def test_resume(self): if epoch == 4: break scheduler.step() - scheduler2 = ExponentialMomentum( - self.optimizer, gamma=0.9, last_step=4) + scheduler2 = ExponentialMomentum(self.optimizer, + gamma=0.9, + last_step=4) for epoch in range(6): results.append(self.optimizer.param_groups[0]['momentum']) scheduler2.step() @@ -136,8 +137,10 @@ def call_sch_before_optim(): group['initial_momentum'] = 0.01 def call_sch_before_optim_resume(): - scheduler = StepMomentum( - self.optimizer, gamma=0.1, step_size=3, last_step=10) + scheduler = StepMomentum(self.optimizer, + gamma=0.1, + step_size=3, + last_step=10) scheduler.step() self.optimizer.step() @@ -182,8 +185,11 @@ def test_effective_interval(self): # check invalid begin end with self.assertRaisesRegex(ValueError, 'end should be larger than begin'): - StepMomentum( - self.optimizer, gamma=0.1, step_size=3, begin=10, end=5) + StepMomentum(self.optimizer, + gamma=0.1, + step_size=3, + begin=10, + end=5) # momentum = 0.05 if epoch == 0 # momentum = 0.025 if epoch == 1 @@ -198,17 +204,16 @@ def test_effective_interval(self): interpolation = [ start_factor + i * (1 - start_factor) / iters for i in range(iters) ] - single_targets = [0.05] * begin + [x * 0.05 - for x in interpolation] + [0.05] * ( - epochs - iters - begin) + single_targets = [0.05] * begin + [ + x * 0.05 for x in interpolation + ] + [0.05] * (epochs - iters - begin) targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = LinearMomentum( - self.optimizer, - start_factor=start_factor, - begin=begin, - end=begin + iters + 1) + scheduler = LinearMomentum(self.optimizer, + start_factor=start_factor, + begin=begin, + end=begin + iters + 1) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) def _test_scheduler_value(self, @@ -261,12 +266,16 @@ def test_step_scheduler(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = StepMomentum( - self.optimizer, gamma=0.1, step_size=3, verbose=True) + scheduler = StepMomentum(self.optimizer, + gamma=0.1, + step_size=3, + verbose=True) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) - scheduler = StepMomentum( - self.optimizer_with_betas, gamma=0.1, step_size=3, verbose=True) + scheduler = StepMomentum(self.optimizer_with_betas, + gamma=0.1, + step_size=3, + verbose=True) self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs) @@ -281,12 +290,14 @@ def test_multi_step_scheduler(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = MultiStepMomentum( - self.optimizer, gamma=0.1, milestones=[2, 5, 9]) + scheduler = MultiStepMomentum(self.optimizer, + gamma=0.1, + milestones=[2, 5, 9]) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) - scheduler = MultiStepMomentum( - self.optimizer_with_betas, gamma=0.1, milestones=[2, 5, 9]) + scheduler = MultiStepMomentum(self.optimizer_with_betas, + gamma=0.1, + milestones=[2, 5, 9]) self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs) @@ -305,8 +316,9 @@ def test_constant_scheduler(self): scheduler = ConstantMomentum(self.optimizer, factor=1.0 / 2, end=5) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) - scheduler = ConstantMomentum( - self.optimizer_with_betas, factor=1.0 / 2, end=5) + scheduler = ConstantMomentum(self.optimizer_with_betas, + factor=1.0 / 2, + end=5) self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs) @@ -330,19 +342,19 @@ def test_linear_scheduler(self): interpolation = [ start_factor + i * (1 - start_factor) / iters for i in range(iters) ] - single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( - epochs - iters) + single_targets = [x * 0.05 + for x in interpolation] + [0.05] * (epochs - iters) targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = LinearMomentum( - self.optimizer, start_factor=start_factor, end=iters + 1) + scheduler = LinearMomentum(self.optimizer, + start_factor=start_factor, + end=iters + 1) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) - scheduler = LinearMomentum( - self.optimizer_with_betas, - start_factor=start_factor, - end=iters + 1) + scheduler = LinearMomentum(self.optimizer_with_betas, + start_factor=start_factor, + end=iters + 1) self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs) @@ -370,18 +382,22 @@ def test_cos_anneal_scheduler(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = CosineAnnealingMomentum( - self.optimizer, T_max=t, eta_min=eta_min) + scheduler = CosineAnnealingMomentum(self.optimizer, + T_max=t, + eta_min=eta_min) self._test_scheduler_value(self.optimizer, scheduler, targets, epochs) - scheduler = CosineAnnealingMomentum( - self.optimizer_with_betas, T_max=t, eta_min=eta_min) + scheduler = CosineAnnealingMomentum(self.optimizer_with_betas, + T_max=t, + eta_min=eta_min) self._test_scheduler_value(self.optimizer_with_betas, scheduler, targets, epochs) # Test default `T_max` - scheduler = CosineAnnealingMomentum( - self.optimizer, begin=5, end=100, eta_min=eta_min) + scheduler = CosineAnnealingMomentum(self.optimizer, + begin=5, + end=100, + eta_min=eta_min) self.assertEqual(scheduler.T_max, 100 - 5) def test_poly_scheduler(self): @@ -392,41 +408,42 @@ def test_poly_scheduler(self): layer1_targets = [ min_lr + (0.05 - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) layer2_targets = [ min_lr + (0.05 * self.layer2_mult - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) targets = [layer1_targets, layer2_targets] - scheduler = PolyMomentum( - self.optimizer, power=power, eta_min=min_lr, end=iters + 1) - self._test_scheduler_value( - self.optimizer, scheduler, targets, epochs=10) - - scheduler = PolyMomentum( - self.optimizer_with_betas, - power=power, - eta_min=min_lr, - end=iters + 1) - self._test_scheduler_value( - self.optimizer_with_betas, scheduler, targets, epochs=10) + scheduler = PolyMomentum(self.optimizer, + power=power, + eta_min=min_lr, + end=iters + 1) + self._test_scheduler_value(self.optimizer, + scheduler, + targets, + epochs=10) + + scheduler = PolyMomentum(self.optimizer_with_betas, + power=power, + eta_min=min_lr, + end=iters + 1) + self._test_scheduler_value(self.optimizer_with_betas, + scheduler, + targets, + epochs=10) def test_cosine_restart_scheduler(self): with self.assertRaises(AssertionError): - CosineRestartMomentum( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0, - eta_min_ratio=0.1) + CosineRestartMomentum(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0, + eta_min_ratio=0.1) with self.assertRaises(AssertionError): - CosineRestartMomentum( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5, 0.0], - eta_min=0) + CosineRestartMomentum(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5, 0.0], + eta_min=0) single_targets = [ 0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271, 0.0086372, 0.0023872, 0.0023872 @@ -434,21 +451,23 @@ def test_cosine_restart_scheduler(self): targets = [ single_targets, [t * self.layer2_mult for t in single_targets] ] - scheduler = CosineRestartMomentum( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0) - self._test_scheduler_value( - self.optimizer, scheduler, targets, epochs=10) - - scheduler = CosineRestartMomentum( - self.optimizer_with_betas, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0) - self._test_scheduler_value( - self.optimizer_with_betas, scheduler, targets, epochs=10) + scheduler = CosineRestartMomentum(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0) + self._test_scheduler_value(self.optimizer, + scheduler, + targets, + epochs=10) + + scheduler = CosineRestartMomentum(self.optimizer_with_betas, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0) + self._test_scheduler_value(self.optimizer_with_betas, + scheduler, + targets, + epochs=10) def test_reduce_on_plateau_scheduler(self): # inherit _ParamScheduler but not call super().__init__(), @@ -474,8 +493,8 @@ def test_reduce_on_plateau_scheduler(self): ReduceOnPlateauMomentum(self.optimizer, factor=2.0) ReduceOnPlateauMomentum(self.optimizer, min_value=[0.1, 0.1]) with self.assertRaises(ValueError): - ReduceOnPlateauMomentum( - self.optimizer, min_value=[0.1, 0.1, 0.1, 0.1]) + ReduceOnPlateauMomentum(self.optimizer, + min_value=[0.1, 0.1, 0.1, 0.1]) with self.assertRaises(ValueError): ReduceOnPlateauMomentum(self.optimizer, threshold=-1.0) with self.assertRaises(ValueError): @@ -512,12 +531,11 @@ def _test_value(epochs, targets, metrics_list, optimizer, monitor, cooldown=cooldown, min_value=min_value, ) - self._test_scheduler_value( - optimizer, - scheduler, - targets, - epochs=epochs, - step_kwargs=metrics_list) + self._test_scheduler_value(optimizer, + scheduler, + targets, + epochs=epochs, + step_kwargs=metrics_list) # reset the state of optimizers self.optimizer = optim.SGD([{ @@ -700,44 +718,40 @@ def test_poly_scheduler_state_dict(self): def test_cosine_restart_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: CosineRestartMomentum( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0), - lambda: CosineRestartMomentum( - self.optimizer, - periods=[4, 6], - restart_weights=[1, 0.5], - eta_min=0), + lambda: CosineRestartMomentum(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0), + lambda: CosineRestartMomentum(self.optimizer, + periods=[4, 6], + restart_weights=[1, 0.5], + eta_min=0), epochs=10) def test_reduce_on_plateau_scheduler_state_dict(self): epochs = 10 metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)] self._check_scheduler_state_dict( - lambda: ReduceOnPlateauMomentum( - self.optimizer, - monitor='loss', - rule='less', - factor=0.01, - patience=5, - threshold=1e-4, - threshold_rule='rel', - cooldown=0, - min_value=0.0, - eps=1e-8), - lambda: ReduceOnPlateauMomentum( - self.optimizer, - monitor='loss_foo', - rule='greater', - factor=0.05, - patience=10, - threshold=1e-5, - threshold_rule='abs', - cooldown=5, - min_value=0.1, - eps=1e-9), + lambda: ReduceOnPlateauMomentum(self.optimizer, + monitor='loss', + rule='less', + factor=0.01, + patience=5, + threshold=1e-4, + threshold_rule='rel', + cooldown=0, + min_value=0.0, + eps=1e-8), + lambda: ReduceOnPlateauMomentum(self.optimizer, + monitor='loss_foo', + rule='greater', + factor=0.05, + patience=10, + threshold=1e-5, + threshold_rule='abs', + cooldown=5, + min_value=0.1, + eps=1e-9), epochs=epochs, step_kwargs=metrics_list) @@ -749,10 +763,15 @@ def test_multi_scheduler_without_overlap_linear_multi_step(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler1 = LinearMomentum( - self.optimizer, start_factor=1 / 2, begin=0, end=5) - scheduler2 = MultiStepMomentum( - self.optimizer, gamma=0.1, milestones=[3, 6], begin=5, end=12) + scheduler1 = LinearMomentum(self.optimizer, + start_factor=1 / 2, + begin=0, + end=5) + scheduler2 = MultiStepMomentum(self.optimizer, + gamma=0.1, + milestones=[3, 6], + begin=5, + end=12) self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], targets, epochs) @@ -760,8 +779,10 @@ def test_multi_scheduler_without_overlap_exp_cosine(self): # use Exp in the first 5 epochs and then use Cosine epochs = 10 single_targets1 = [0.05 * (0.9**x) for x in range(5)] - scheduler1 = ExponentialMomentum( - self.optimizer, gamma=0.9, begin=0, end=5) + scheduler1 = ExponentialMomentum(self.optimizer, + gamma=0.9, + begin=0, + end=5) eta_min = 1e-10 single_targets2 = [ @@ -772,8 +793,11 @@ def test_multi_scheduler_without_overlap_exp_cosine(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler2 = CosineAnnealingMomentum( - self.optimizer, T_max=5, eta_min=eta_min, begin=5, end=10) + scheduler2 = CosineAnnealingMomentum(self.optimizer, + T_max=5, + eta_min=eta_min, + begin=5, + end=10) self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], targets, epochs) @@ -786,10 +810,13 @@ def test_multi_scheduler_with_overlap(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler1 = LinearMomentum( - self.optimizer, start_factor=1 / 2, begin=0, end=5) - scheduler2 = MultiStepMomentum( - self.optimizer, gamma=0.1, milestones=[3, 6, 9]) + scheduler1 = LinearMomentum(self.optimizer, + start_factor=1 / 2, + begin=0, + end=5) + scheduler2 = MultiStepMomentum(self.optimizer, + gamma=0.1, + milestones=[3, 6, 9]) self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], targets, epochs) @@ -798,8 +825,10 @@ def test_multi_scheduler_with_gap(self): # no scheduler in the middle 5 epochs epochs = 15 single_targets1 = [0.05 * (0.9**x) for x in range(5)] - scheduler1 = ExponentialMomentum( - self.optimizer, gamma=0.9, begin=0, end=5) + scheduler1 = ExponentialMomentum(self.optimizer, + gamma=0.9, + begin=0, + end=5) eta_min = 1e-10 single_targets2 = [ @@ -811,8 +840,11 @@ def test_multi_scheduler_with_gap(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler2 = CosineAnnealingMomentum( - self.optimizer, T_max=5, eta_min=eta_min, begin=10, end=15) + scheduler2 = CosineAnnealingMomentum(self.optimizer, + T_max=5, + eta_min=eta_min, + begin=10, + end=15) self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2], targets, epochs) diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index a13072dc6e..f9e39596b6 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -68,13 +68,15 @@ def test_base_scheduler_step(self): def test_invalid_optimizer(self): with self.assertRaisesRegex(TypeError, 'should be an Optimizer'): - StepParamScheduler( - 'invalid_optimizer', step_size=1, param_name='lr') + StepParamScheduler('invalid_optimizer', + step_size=1, + param_name='lr') def test_overwrite_optimzer_step(self): # raise warning if the counter in optimizer.step() is overwritten - scheduler = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9) + scheduler = ExponentialParamScheduler(self.optimizer, + param_name='lr', + gamma=0.9) def overwrite_fun(): pass @@ -88,18 +90,18 @@ def test_resume(self): # test invalid case: optimizer and scheduler are not both resumed with self.assertRaisesRegex(KeyError, "param 'initial_lr' is not specified"): - StepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - step_size=3, - last_step=10) + StepParamScheduler(self.optimizer, + param_name='lr', + gamma=0.1, + step_size=3, + last_step=10) # test manually resume with ``last_step`` instead of load_state_dict epochs = 10 targets = [0.05 * (0.9**x) for x in range(epochs)] - scheduler = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9) + scheduler = ExponentialParamScheduler(self.optimizer, + param_name='lr', + gamma=0.9) results = [] for epoch in range(5): @@ -111,8 +113,10 @@ def test_resume(self): if epoch == 4: break scheduler.step() - scheduler2 = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9, last_step=4) + scheduler2 = ExponentialParamScheduler(self.optimizer, + param_name='lr', + gamma=0.9, + last_step=4) for epoch in range(6): results.append(self.optimizer.param_groups[0]['lr']) scheduler2.step() @@ -130,8 +134,10 @@ def test_scheduler_before_optim_warning(self): """Warns if scheduler is used before optimizer.""" def call_sch_before_optim(): - scheduler = StepParamScheduler( - self.optimizer, param_name='lr', gamma=0.1, step_size=3) + scheduler = StepParamScheduler(self.optimizer, + param_name='lr', + gamma=0.1, + step_size=3) scheduler.step() self.optimizer.step() @@ -144,12 +150,11 @@ def call_sch_before_optim(): group['initial_lr'] = 0.01 def call_sch_before_optim_resume(): - scheduler = StepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - step_size=3, - last_step=10) + scheduler = StepParamScheduler(self.optimizer, + param_name='lr', + gamma=0.1, + step_size=3, + last_step=10) scheduler.step() self.optimizer.step() @@ -163,8 +168,10 @@ def test_get_last_value(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = StepParamScheduler( - self.optimizer, param_name='lr', step_size=3, gamma=0.1) + scheduler = StepParamScheduler(self.optimizer, + param_name='lr', + step_size=3, + gamma=0.1) for epoch in range(epochs): result = scheduler.get_last_value() if isinstance(scheduler.optimizer, OptimWrapper) \ @@ -184,8 +191,10 @@ def test_get_last_value(self): def test_scheduler_step_count(self): iteration = 10 - scheduler = StepParamScheduler( - self.optimizer, param_name='lr', gamma=0.1, step_size=3) + scheduler = StepParamScheduler(self.optimizer, + param_name='lr', + gamma=0.1, + step_size=3) self.assertEqual(scheduler.last_step, 0) target = [i + 1 for i in range(iteration)] step_counts = [] @@ -199,13 +208,12 @@ def test_effective_interval(self): # check invalid begin end with self.assertRaisesRegex(ValueError, 'end should be larger than begin'): - StepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - step_size=3, - begin=10, - end=5) + StepParamScheduler(self.optimizer, + param_name='lr', + gamma=0.1, + step_size=3, + begin=10, + end=5) # lr = 0.05 if epoch == 0 # lr = 0.025 if epoch == 1 @@ -220,24 +228,24 @@ def test_effective_interval(self): interpolation = [ start_factor + i * (1 - start_factor) / iters for i in range(iters) ] - single_targets = [0.05] * begin + [x * 0.05 - for x in interpolation] + [0.05] * ( - epochs - iters - begin) + single_targets = [0.05] * begin + [ + x * 0.05 for x in interpolation + ] + [0.05] * (epochs - iters - begin) targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = LinearParamScheduler( - self.optimizer, - param_name='lr', - start_factor=start_factor, - begin=begin, - end=begin + iters + 1) + scheduler = LinearParamScheduler(self.optimizer, + param_name='lr', + start_factor=start_factor, + begin=begin, + end=begin + iters + 1) self._test_scheduler_value(scheduler, targets, epochs) def test_param_name(self): with self.assertRaises(KeyError): - StepParamScheduler( - self.optimizer, param_name='invalid_name', step_size=10) + StepParamScheduler(self.optimizer, + param_name='invalid_name', + step_size=10) def _test_scheduler_value(self, schedulers, @@ -280,12 +288,11 @@ def test_step_scheduler(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = StepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - step_size=3, - verbose=True) + scheduler = StepParamScheduler(self.optimizer, + param_name='lr', + gamma=0.1, + step_size=3, + verbose=True) self._test_scheduler_value(scheduler, targets, epochs) # momentum = 0.01 if epoch < 2 @@ -295,10 +302,14 @@ def test_step_scheduler(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = StepParamScheduler( - self.optimizer, param_name='momentum', gamma=0.1, step_size=2) - self._test_scheduler_value( - scheduler, targets, epochs, param_name='momentum') + scheduler = StepParamScheduler(self.optimizer, + param_name='momentum', + gamma=0.1, + step_size=2) + self._test_scheduler_value(scheduler, + targets, + epochs, + param_name='momentum') def test_multi_step_scheduler(self): # lr = 0.05 if epoch < 2 @@ -311,8 +322,10 @@ def test_multi_step_scheduler(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = MultiStepParamScheduler( - self.optimizer, param_name='lr', gamma=0.1, milestones=[2, 5, 9]) + scheduler = MultiStepParamScheduler(self.optimizer, + param_name='lr', + gamma=0.1, + milestones=[2, 5, 9]) self._test_scheduler_value(scheduler, targets, epochs) def test_constant_scheduler(self): @@ -327,23 +340,33 @@ def test_constant_scheduler(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = ConstantParamScheduler( - self.optimizer, param_name='lr', factor=1.0 / 2, end=5) + scheduler = ConstantParamScheduler(self.optimizer, + param_name='lr', + factor=1.0 / 2, + end=5) self._test_scheduler_value(scheduler, targets, epochs) def test_linear_scheduler(self): with self.assertRaises(ValueError): - LinearParamScheduler( - self.optimizer, param_name='lr', start_factor=10, end=900) + LinearParamScheduler(self.optimizer, + param_name='lr', + start_factor=10, + end=900) with self.assertRaises(ValueError): - LinearParamScheduler( - self.optimizer, param_name='lr', start_factor=-1, end=900) + LinearParamScheduler(self.optimizer, + param_name='lr', + start_factor=-1, + end=900) with self.assertRaises(ValueError): - LinearParamScheduler( - self.optimizer, param_name='lr', end_factor=1.001, end=900) + LinearParamScheduler(self.optimizer, + param_name='lr', + end_factor=1.001, + end=900) with self.assertRaises(ValueError): - LinearParamScheduler( - self.optimizer, param_name='lr', end_factor=-0.00001, end=900) + LinearParamScheduler(self.optimizer, + param_name='lr', + end_factor=-0.00001, + end=900) # lr = 0.025 if epoch == 0 # lr = 0.03125 if epoch == 1 # lr = 0.0375 if epoch == 2 @@ -355,16 +378,15 @@ def test_linear_scheduler(self): interpolation = [ start_factor + i * (1 - start_factor) / iters for i in range(iters) ] - single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( - epochs - iters) + single_targets = [x * 0.05 + for x in interpolation] + [0.05] * (epochs - iters) targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = LinearParamScheduler( - self.optimizer, - param_name='lr', - start_factor=start_factor, - end=iters + 1) + scheduler = LinearParamScheduler(self.optimizer, + param_name='lr', + start_factor=start_factor, + end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs) def test_exp_scheduler(self): @@ -373,18 +395,18 @@ def test_exp_scheduler(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9) + scheduler = ExponentialParamScheduler(self.optimizer, + param_name='lr', + gamma=0.9) self._test_scheduler_value(scheduler, targets, epochs) def test_cos_anneal_scheduler(self): with self.assertRaises(AssertionError): - CosineAnnealingParamScheduler( - self.optimizer, - param_name='lr', - T_max=10, - eta_min=0, - eta_min_ratio=0.1) + CosineAnnealingParamScheduler(self.optimizer, + param_name='lr', + T_max=10, + eta_min=0, + eta_min_ratio=0.1) epochs = 12 t = 10 eta_min = 5e-3 @@ -397,8 +419,10 @@ def test_cos_anneal_scheduler(self): for x in range(epochs) ] targets = [targets1, targets2] - scheduler = CosineAnnealingParamScheduler( - self.optimizer, param_name='lr', T_max=t, eta_min=eta_min) + scheduler = CosineAnnealingParamScheduler(self.optimizer, + param_name='lr', + T_max=t, + eta_min=eta_min) self._test_scheduler_value(scheduler, targets, epochs) # Test `eta_min_ratio` @@ -413,16 +437,18 @@ def test_cos_anneal_scheduler(self): (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs) ] targets = [targets1, targets2] - scheduler = CosineAnnealingParamScheduler( - self.optimizer, - param_name='lr', - T_max=t, - eta_min_ratio=eta_min_ratio) + scheduler = CosineAnnealingParamScheduler(self.optimizer, + param_name='lr', + T_max=t, + eta_min_ratio=eta_min_ratio) self._test_scheduler_value(scheduler, targets, epochs) # Test default `T_max` - scheduler = CosineAnnealingParamScheduler( - self.optimizer, param_name='lr', begin=5, end=100, eta_min=eta_min) + scheduler = CosineAnnealingParamScheduler(self.optimizer, + param_name='lr', + begin=5, + end=100, + eta_min=eta_min) self.assertEqual(scheduler.T_max, 100 - 5) def test_poly_scheduler(self): @@ -433,38 +459,33 @@ def test_poly_scheduler(self): targets_layer1 = [ min_lr + (0.05 - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) targets_layer2 = [ min_lr + (0.05 * self.layer2_mult - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) targets = [targets_layer1, targets_layer2] - scheduler = PolyParamScheduler( - self.optimizer, - param_name='lr', - power=power, - eta_min=min_lr, - end=iters + 1) + scheduler = PolyParamScheduler(self.optimizer, + param_name='lr', + power=power, + eta_min=min_lr, + end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs=10) def test_cosine_restart_scheduler(self): with self.assertRaises(AssertionError): - CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0, - eta_min_ratio=0.1) + CosineRestartParamScheduler(self.optimizer, + param_name='lr', + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0, + eta_min_ratio=0.1) with self.assertRaises(AssertionError): - CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[4, 5], - restart_weights=[1, 0.5, 0.0], - eta_min=0) + CosineRestartParamScheduler(self.optimizer, + param_name='lr', + periods=[4, 5], + restart_weights=[1, 0.5, 0.0], + eta_min=0) single_targets = [ 0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271, 0.0086372, 0.0023872, 0.0023872 @@ -474,12 +495,11 @@ def test_cosine_restart_scheduler(self): ] # Test with non-zero eta-min. - scheduler = CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0) + scheduler = CosineRestartParamScheduler(self.optimizer, + param_name='lr', + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0) self._test_scheduler_value(scheduler, targets, epochs=10) epochs = 10 @@ -494,12 +514,11 @@ def test_cosine_restart_scheduler(self): for x in range(epochs) ] targets = [targets1, targets2] - scheduler = CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[t], - restart_weights=[1], - eta_min=eta_min) + scheduler = CosineRestartParamScheduler(self.optimizer, + param_name='lr', + periods=[t], + restart_weights=[1], + eta_min=eta_min) self._test_scheduler_value(scheduler, targets, epochs=10) def test_reduce_on_plateau_scheduler(self): @@ -510,34 +529,41 @@ def test_reduce_on_plateau_scheduler(self): with self.assertRaises(TypeError): ReduceOnPlateauParamScheduler('invalid_optimizer', param_name='lr') with self.assertRaises(ValueError): - ReduceOnPlateauParamScheduler( - self.optimizer, 'lr', begin=10, end=5) + ReduceOnPlateauParamScheduler(self.optimizer, + 'lr', + begin=10, + end=5) with self.assertRaises(AssertionError): ReduceOnPlateauParamScheduler(self.optimizer, 'lr', by_epoch=False) for last_step in (1.5, -2): with self.assertRaises(AssertionError): - ReduceOnPlateauParamScheduler( - self.optimizer, 'lr', last_step=last_step) + ReduceOnPlateauParamScheduler(self.optimizer, + 'lr', + last_step=last_step) with self.assertRaises(ValueError): ReduceOnPlateauParamScheduler(self.optimizer, 'lr', factor=2.0) - ReduceOnPlateauParamScheduler( - self.optimizer, 'lr', min_value=[0.1, 0.1]) + ReduceOnPlateauParamScheduler(self.optimizer, + 'lr', + min_value=[0.1, 0.1]) with self.assertRaises(ValueError): - ReduceOnPlateauParamScheduler( - self.optimizer, 'lr', min_value=[0.1, 0.1, 0.1, 0.1]) + ReduceOnPlateauParamScheduler(self.optimizer, + 'lr', + min_value=[0.1, 0.1, 0.1, 0.1]) with self.assertRaises(ValueError): ReduceOnPlateauParamScheduler(self.optimizer, 'lr', threshold=-1.0) with self.assertRaises(ValueError): ReduceOnPlateauParamScheduler(self.optimizer, 'lr', rule='foo') with self.assertRaises(ValueError): - ReduceOnPlateauParamScheduler( - self.optimizer, 'lr', threshold_rule='foo') + ReduceOnPlateauParamScheduler(self.optimizer, + 'lr', + threshold_rule='foo') # Test error in step method - scheduler = ReduceOnPlateauParamScheduler( - self.optimizer, param_name='lr', monitor='loss') + scheduler = ReduceOnPlateauParamScheduler(self.optimizer, + param_name='lr', + monitor='loss') assert scheduler.step() is None with self.assertRaises(TypeError): @@ -566,8 +592,10 @@ def _test_value(epochs, targets, metrics_list, monitor, rule, factor, cooldown=cooldown, min_value=min_value, ) - self._test_scheduler_value( - scheduler, targets, epochs=epochs, step_kwargs=metrics_list) + self._test_scheduler_value(scheduler, + targets, + epochs=epochs, + step_kwargs=metrics_list) # reset the state of optimizers self.optimizer = optim.SGD( @@ -703,15 +731,14 @@ def test_step_scheduler_state_dict(self): def test_multi_step_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: MultiStepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - milestones=[2, 5, 9]), lambda: MultiStepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.01, - milestones=[1, 4, 6])) + lambda: MultiStepParamScheduler(self.optimizer, + param_name='lr', + gamma=0.1, + milestones=[2, 5, 9]), + lambda: MultiStepParamScheduler(self.optimizer, + param_name='lr', + gamma=0.01, + milestones=[1, 4, 6])) def test_exp_scheduler_state_dict(self): self._check_scheduler_state_dict( @@ -723,27 +750,24 @@ def test_exp_scheduler_state_dict(self): def test_cosine_scheduler_state_dict(self): epochs = 10 eta_min = 1e-10 - self._check_scheduler_state_dict( - lambda: CosineAnnealingParamScheduler( - self.optimizer, param_name='lr', T_max=epochs, eta_min=eta_min - ), - lambda: CosineAnnealingParamScheduler( - self.optimizer, - param_name='lr', - T_max=epochs // 2, - eta_min=eta_min / 2), - epochs=epochs) + self._check_scheduler_state_dict(lambda: CosineAnnealingParamScheduler( + self.optimizer, param_name='lr', T_max=epochs, eta_min=eta_min), + lambda: CosineAnnealingParamScheduler( + self.optimizer, + param_name='lr', + T_max=epochs // 2, + eta_min=eta_min / 2), + epochs=epochs) def test_linear_scheduler_state_dict(self): epochs = 10 self._check_scheduler_state_dict( lambda: LinearParamScheduler( self.optimizer, param_name='lr', start_factor=1 / 3), - lambda: LinearParamScheduler( - self.optimizer, - param_name='lr', - start_factor=0, - end_factor=0.3), + lambda: LinearParamScheduler(self.optimizer, + param_name='lr', + start_factor=0, + end_factor=0.3), epochs=epochs) def test_poly_scheduler_state_dict(self): @@ -756,48 +780,44 @@ def test_poly_scheduler_state_dict(self): def test_cosine_restart_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0), - lambda: CosineRestartParamScheduler( - self.optimizer, - param_name='lr', - periods=[4, 6], - restart_weights=[1, 0.5], - eta_min=0), + lambda: CosineRestartParamScheduler(self.optimizer, + param_name='lr', + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0), + lambda: CosineRestartParamScheduler(self.optimizer, + param_name='lr', + periods=[4, 6], + restart_weights=[1, 0.5], + eta_min=0), epochs=10) def test_reduce_on_plateau_scheduler_state_dict(self): epochs = 10 metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)] self._check_scheduler_state_dict( - lambda: ReduceOnPlateauParamScheduler( - self.optimizer, - param_name='lr', - monitor='loss', - rule='less', - factor=0.01, - patience=5, - threshold=1e-4, - threshold_rule='rel', - cooldown=0, - min_value=0.0, - eps=1e-8), - lambda: ReduceOnPlateauParamScheduler( - self.optimizer, - param_name='lr', - monitor='loss_foo', - rule='greater', - factor=0.05, - patience=10, - threshold=1e-5, - threshold_rule='abs', - cooldown=5, - min_value=0.1, - eps=1e-9), + lambda: ReduceOnPlateauParamScheduler(self.optimizer, + param_name='lr', + monitor='loss', + rule='less', + factor=0.01, + patience=5, + threshold=1e-4, + threshold_rule='rel', + cooldown=0, + min_value=0.0, + eps=1e-8), + lambda: ReduceOnPlateauParamScheduler(self.optimizer, + param_name='lr', + monitor='loss_foo', + rule='greater', + factor=0.05, + patience=10, + threshold=1e-5, + threshold_rule='abs', + cooldown=5, + min_value=0.1, + eps=1e-9), epochs=epochs, step_kwargs=metrics_list) @@ -825,8 +845,10 @@ def test_step_scheduler_convert_iterbased(self): gamma=0.1, step_size=2, epoch_length=epoch_length) - self._test_scheduler_value( - scheduler, targets, epochs * epoch_length, param_name='momentum') + self._test_scheduler_value(scheduler, + targets, + epochs * epoch_length, + param_name='momentum') def test_multi_step_scheduler_convert_iterbased(self): # lr = 0.05 if epoch < 2 @@ -878,8 +900,8 @@ def test_linear_scheduler_convert_iterbased(self): interpolation = [ start_factor + i * (1 - start_factor) / iters for i in range(iters) ] - single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( - epochs * epoch_length - iters) + single_targets = [x * 0.05 for x in interpolation + ] + [0.05] * (epochs * epoch_length - iters) targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] @@ -940,13 +962,11 @@ def test_poly_scheduler_convert_iterbased(self): targets_layer1 = [ min_lr + (0.05 - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) targets_layer2 = [ min_lr + (0.05 * self.layer2_mult - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) targets = [targets_layer1, targets_layer2] scheduler = PolyParamScheduler.build_iter_from_epoch( self.optimizer, @@ -965,27 +985,28 @@ def test_multi_scheduler_without_overlap_linear_multi_step(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler1 = LinearParamScheduler( - self.optimizer, - param_name='lr', - start_factor=1 / 2, - begin=0, - end=5) - scheduler2 = MultiStepParamScheduler( - self.optimizer, - param_name='lr', - gamma=0.1, - milestones=[3, 6], - begin=5, - end=12) + scheduler1 = LinearParamScheduler(self.optimizer, + param_name='lr', + start_factor=1 / 2, + begin=0, + end=5) + scheduler2 = MultiStepParamScheduler(self.optimizer, + param_name='lr', + gamma=0.1, + milestones=[3, 6], + begin=5, + end=12) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_without_overlap_exp_cosine(self): # use Exp in the first 5 epochs and then use Cosine epochs = 10 single_targets1 = [0.05 * (0.9**x) for x in range(5)] - scheduler1 = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9, begin=0, end=5) + scheduler1 = ExponentialParamScheduler(self.optimizer, + param_name='lr', + gamma=0.9, + begin=0, + end=5) eta_min = 1e-10 single_targets2 = [ @@ -996,13 +1017,12 @@ def test_multi_scheduler_without_overlap_exp_cosine(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler2 = CosineAnnealingParamScheduler( - self.optimizer, - param_name='lr', - T_max=5, - eta_min=eta_min, - begin=5, - end=10) + scheduler2 = CosineAnnealingParamScheduler(self.optimizer, + param_name='lr', + T_max=5, + eta_min=eta_min, + begin=5, + end=10) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) @@ -1014,14 +1034,15 @@ def test_multi_scheduler_with_overlap(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler1 = LinearParamScheduler( - self.optimizer, - param_name='lr', - start_factor=1 / 2, - begin=0, - end=5) - scheduler2 = MultiStepParamScheduler( - self.optimizer, param_name='lr', gamma=0.1, milestones=[3, 6, 9]) + scheduler1 = LinearParamScheduler(self.optimizer, + param_name='lr', + start_factor=1 / 2, + begin=0, + end=5) + scheduler2 = MultiStepParamScheduler(self.optimizer, + param_name='lr', + gamma=0.1, + milestones=[3, 6, 9]) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_multi_scheduler_with_gap(self): @@ -1029,8 +1050,11 @@ def test_multi_scheduler_with_gap(self): # no scheduler in the middle 5 epochs epochs = 15 single_targets1 = [0.05 * (0.9**x) for x in range(5)] - scheduler1 = ExponentialParamScheduler( - self.optimizer, param_name='lr', gamma=0.9, begin=0, end=5) + scheduler1 = ExponentialParamScheduler(self.optimizer, + param_name='lr', + gamma=0.9, + begin=0, + end=5) eta_min = 1e-10 single_targets2 = [ @@ -1042,32 +1066,33 @@ def test_multi_scheduler_with_gap(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler2 = CosineAnnealingParamScheduler( - self.optimizer, - param_name='lr', - T_max=5, - eta_min=eta_min, - begin=10, - end=15) + scheduler2 = CosineAnnealingParamScheduler(self.optimizer, + param_name='lr', + T_max=5, + eta_min=eta_min, + begin=10, + end=15) self._test_scheduler_value([scheduler1, scheduler2], targets, epochs) def test_onecycle_scheduler(self): # test invalid total steps with self.assertRaises(ValueError): - OneCycleParamScheduler( - self.optimizer, param_name='lr', total_steps=-1) + OneCycleParamScheduler(self.optimizer, + param_name='lr', + total_steps=-1) # test invalid pct_start with self.assertRaises(ValueError): - OneCycleParamScheduler( - self.optimizer, param_name='lr', total_steps=10, pct_start=-1) + OneCycleParamScheduler(self.optimizer, + param_name='lr', + total_steps=10, + pct_start=-1) # test invalid anneal_strategy with self.assertRaises(ValueError): - OneCycleParamScheduler( - self.optimizer, - param_name='lr', - total_steps=10, - anneal_strategy='a') + OneCycleParamScheduler(self.optimizer, + param_name='lr', + total_steps=10, + anneal_strategy='a') class TestParameterSchedulerOptimWrapper(TestParameterScheduler): diff --git a/tests/test_registry/test_build_functions.py b/tests/test_registry/test_build_functions.py index 80094ae107..b570b89fa8 100644 --- a/tests/test_registry/test_build_functions.py +++ b/tests/test_registry/test_build_functions.py @@ -90,19 +90,22 @@ def __init__(self, depth, stages=4): # cfg or default_args should contain the key "type" with pytest.raises(KeyError, match='must contain the key "type"'): cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(stages=4))) + model = build_from_cfg(cfg, + BACKBONES, + default_args=cfg_type(dict(stages=4))) # "type" defined using default_args cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet'))) + model = build_from_cfg(cfg, + BACKBONES, + default_args=cfg_type(dict(type='ResNet'))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet))) + model = build_from_cfg(cfg, + BACKBONES, + default_args=cfg_type(dict(type=ResNet))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 @@ -197,24 +200,22 @@ def test_build_scheduler_from_cfg(): from torch.optim import SGD model = nn.Conv2d(1, 1, 1) optimizer = SGD(model.parameters(), lr=0.1) - cfg = dict( - type='LinearParamScheduler', - optimizer=optimizer, - param_name='lr', - begin=0, - end=100) + cfg = dict(type='LinearParamScheduler', + optimizer=optimizer, + param_name='lr', + begin=0, + end=100) scheduler = PARAM_SCHEDULERS.build(cfg) assert scheduler.begin == 0 assert scheduler.end == 100 - cfg = dict( - type='LinearParamScheduler', - convert_to_iter_based=True, - optimizer=optimizer, - param_name='lr', - begin=0, - end=100, - epoch_length=10) + cfg = dict(type='LinearParamScheduler', + convert_to_iter_based=True, + optimizer=optimizer, + param_name='lr', + begin=0, + end=100, + epoch_length=10) scheduler = PARAM_SCHEDULERS.build(cfg) assert scheduler.begin == 0 diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index eb99b3dc8e..f339d37420 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -134,10 +134,9 @@ class BritishShorthair: # test `module` parameter, which is either None or a class # when the `register_module`` is called as a method rather than a # decorator, which must be a class - with pytest.raises( - TypeError, - match='module must be Callable,' - " but got "): + with pytest.raises(TypeError, + match='module must be Callable,' + " but got "): CATS.register_module(module='string') class SphynxCat: @@ -183,15 +182,17 @@ def _build_registry(self): registries.append(DOGS) HOUNDS = Registry('hounds', parent=DOGS, scope='hound') registries.append(HOUNDS) - LITTLE_HOUNDS = Registry( - 'little hounds', parent=HOUNDS, scope='little_hound') + LITTLE_HOUNDS = Registry('little hounds', + parent=HOUNDS, + scope='little_hound') registries.append(LITTLE_HOUNDS) MID_HOUNDS = Registry('mid hounds', parent=HOUNDS, scope='mid_hound') registries.append(MID_HOUNDS) SAMOYEDS = Registry('samoyeds', parent=DOGS, scope='samoyed') registries.append(SAMOYEDS) - LITTLE_SAMOYEDS = Registry( - 'little samoyeds', parent=SAMOYEDS, scope='little_samoyed') + LITTLE_SAMOYEDS = Registry('little samoyeds', + parent=SAMOYEDS, + scope='little_samoyed') registries.append(LITTLE_SAMOYEDS) return registries @@ -408,14 +409,14 @@ class Beagle: # test `default_scope` # switch the current registry to another registry - DefaultScope.get_instance( - f'test-{time.time()}', scope_name='mid_hound') + DefaultScope.get_instance(f'test-{time.time()}', + scope_name='mid_hound') dog = LITTLE_HOUNDS.build(b_cfg) assert isinstance(dog, Beagle) # `default_scope` can not be found - DefaultScope.get_instance( - f'test2-{time.time()}', scope_name='scope-not-found') + DefaultScope.get_instance(f'test2-{time.time()}', + scope_name='scope-not-found') dog = MID_HOUNDS.build(b_cfg) assert isinstance(dog, Beagle) @@ -431,20 +432,18 @@ class YourSamoyed: pass s_cfg = cfg_type( - dict( - _scope_='samoyed', - type='MySamoyed', - friend=dict(type='hound.BloodHound'))) + dict(_scope_='samoyed', + type='MySamoyed', + friend=dict(type='hound.BloodHound'))) dog = DOGS.build(s_cfg) assert isinstance(dog, MySamoyed) assert isinstance(dog.friend, BloodHound) assert DefaultScope.get_current_instance().scope_name != 'samoyed' s_cfg = cfg_type( - dict( - _scope_='samoyed', - type='MySamoyed', - friend=dict(type='YourSamoyed'))) + dict(_scope_='samoyed', + type='MySamoyed', + friend=dict(type='YourSamoyed'))) dog = DOGS.build(s_cfg) assert isinstance(dog, MySamoyed) assert isinstance(dog.friend, YourSamoyed) @@ -456,9 +455,9 @@ class YourSamoyed: lambda_cfg = cfg_type(dict(type='lambda_dog', name='unknown')) assert DOGS.build(lambda_cfg) == 'unknown' - DOGS.register_module( - name='patial dog', - module=functools.partial(lambda_dog, name='patial')) + DOGS.register_module(name='patial dog', + module=functools.partial(lambda_dog, + name='patial')) unknown_cfg = cfg_type(dict(type='patial dog')) assert DOGS.build(unknown_cfg) == 'patial' @@ -474,8 +473,8 @@ def test_switch_scope_and_registry(self): # | | | # HOUNDS (hound) SAMOYEDS (samoyed) CHIHUAHUA (chihuahua) - DefaultScope.get_instance( - f'scope_{time.time()}', scope_name='chihuahua') + DefaultScope.get_instance(f'scope_{time.time()}', + scope_name='chihuahua') assert DefaultScope.get_current_instance().scope_name == 'chihuahua' # Test switch scope and get target registry. @@ -597,19 +596,22 @@ def __init__(self, depth, stages=4): # cfg or default_args should contain the key "type" with pytest.raises(KeyError, match='must contain the key "type"'): cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(stages=4))) + model = build_from_cfg(cfg, + BACKBONES, + default_args=cfg_type(dict(stages=4))) # "type" defined using default_args cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet'))) + model = build_from_cfg(cfg, + BACKBONES, + default_args=cfg_type(dict(type='ResNet'))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet))) + model = build_from_cfg(cfg, + BACKBONES, + default_args=cfg_type(dict(type=ResNet))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_runner/test_checkpoint.py index 4655a4c5da..51efc7db2f 100644 --- a/tests/test_runner/test_checkpoint.py +++ b/tests/test_runner/test_checkpoint.py @@ -211,14 +211,16 @@ def __init__(self): # add prefix torch.save(model.state_dict(), checkpoint_path) - state_dict = load_checkpoint( - pmodel, checkpoint_path, revise_keys=[(r'^', 'backbone.')]) + state_dict = load_checkpoint(pmodel, + checkpoint_path, + revise_keys=[(r'^', 'backbone.')]) for key in pmodel.backbone.state_dict().keys(): assert torch.equal(pmodel.backbone.state_dict()[key], state_dict[key]) # strip prefix torch.save(pmodel.state_dict(), checkpoint_path) - state_dict = load_checkpoint( - model, checkpoint_path, revise_keys=[(r'^backbone\.', '')]) + state_dict = load_checkpoint(model, + checkpoint_path, + revise_keys=[(r'^backbone\.', '')]) for key in state_dict.keys(): key_stripped = re.sub(r'^backbone\.', '', key) @@ -354,6 +356,7 @@ def load_from_abc(filename, map_location): assert loader.__name__ == 'load_from_abc' +@patch.dict(sys.modules, {'petrel_client': MagicMock()}) def test_save_checkpoint(tmp_path): model = Model() optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) @@ -366,17 +369,19 @@ def test_save_checkpoint(tmp_path): save_checkpoint(model.state_dict(), filename) filename = str(tmp_path / 'checkpoint2.pth') - checkpoint = dict( - model=model.state_dict(), optimizer=optimizer.state_dict()) + checkpoint = dict(model=model.state_dict(), + optimizer=optimizer.state_dict()) save_checkpoint(checkpoint, filename) filename = str(tmp_path / 'checkpoint3.pth') - save_checkpoint( - model.state_dict(), filename, backend_args={'backend': 'local'}) + save_checkpoint(model.state_dict(), + filename, + backend_args={'backend': 'local'}) filename = str(tmp_path / 'checkpoint4.pth') - save_checkpoint( - model.state_dict(), filename, file_client_args={'backend': 'disk'}) + save_checkpoint(model.state_dict(), + filename, + file_client_args={'backend': 'disk'}) # 2. save to petrel oss with patch.object(PetrelBackend, 'put') as mock_method: @@ -386,10 +391,9 @@ def test_save_checkpoint(tmp_path): with patch.object(PetrelBackend, 'put') as mock_method: filename = 's3://path//of/your/checkpoint2.pth' - save_checkpoint( - model.state_dict(), - filename, - file_client_args={'backend': 'petrel'}) + save_checkpoint(model.state_dict(), + filename, + file_client_args={'backend': 'petrel'}) mock_method.assert_called() diff --git a/tests/test_runner/test_log_processor.py b/tests/test_runner/test_log_processor.py index d7fae5722a..b48b218c9e 100644 --- a/tests/test_runner/test_log_processor.py +++ b/tests/test_runner/test_log_processor.py @@ -16,8 +16,9 @@ class TestLogProcessor(RunnerTestCase): def test_init(self): - log_processor = LogProcessor( - window_size=10, by_epoch=True, custom_cfg=None) + log_processor = LogProcessor(window_size=10, + by_epoch=True, + custom_cfg=None) assert log_processor.by_epoch assert log_processor.window_size == 10 assert log_processor.custom_cfg == [] @@ -81,8 +82,8 @@ def test_parse_windows_size(self): # yapf: enable def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy): # Prepare LoggerHook - log_processor = LogProcessor( - by_epoch=by_epoch, log_with_hierarchy=log_with_hierarchy) + log_processor = LogProcessor(by_epoch=by_epoch, + log_with_hierarchy=log_with_hierarchy) log_processor._get_max_memory = MagicMock(return_value='100') eta = 40 self.runner.message_hub.update_info('eta', eta) @@ -157,15 +158,15 @@ def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy): [False, 'val', False], [True, 'test', True], [False, 'test', False])) def test_log_val(self, by_epoch, mode, log_with_hierarchy): # Prepare LoggerHook - log_processor = LogProcessor( - by_epoch=by_epoch, log_with_hierarchy=log_with_hierarchy) + log_processor = LogProcessor(by_epoch=by_epoch, + log_with_hierarchy=log_with_hierarchy) # Prepare validation information. scalar_logs = dict(accuracy=0.9, data_time=1.0) - non_scalar_logs = dict( - recall={ - 'cat': 1, - 'dog': 0 - }, cm=torch.tensor([1, 2, 3])) + non_scalar_logs = dict(recall={ + 'cat': 1, + 'dog': 0 + }, + cm=torch.tensor([1, 2, 3])) log_processor._collect_scalars = MagicMock(return_value=scalar_logs) log_processor._collect_non_scalars = MagicMock( return_value=non_scalar_logs) @@ -207,8 +208,9 @@ def test_collect_scalars(self): 'val/metric': history_metric_buffer } self.runner.message_hub._log_scalars = log_scalars - tag = log_processor._collect_scalars( - copy.deepcopy(custom_cfg), self.runner, mode='train') + tag = log_processor._collect_scalars(copy.deepcopy(custom_cfg), + self.runner, + mode='train') # Training key in tag. assert list(tag.keys()) == ['time', 'loss_cls', 'time_max'] # Test statistics lr with `current`, loss and time with 'mean' @@ -217,17 +219,17 @@ def test_collect_scalars(self): assert tag['loss_cls'] == loss_cls_scalars[-10:].mean() # Validation key in tag - tag = log_processor._collect_scalars( - copy.deepcopy(custom_cfg), self.runner, mode='val') + tag = log_processor._collect_scalars(copy.deepcopy(custom_cfg), + self.runner, + mode='val') assert list(tag.keys()) == ['metric'] assert tag['metric'] == metric_scalars[-1] # reserve_prefix=True - tag = log_processor._collect_scalars( - copy.deepcopy(custom_cfg), - self.runner, - mode='train', - reserve_prefix=True) + tag = log_processor._collect_scalars(copy.deepcopy(custom_cfg), + self.runner, + mode='train', + reserve_prefix=True) assert list( tag.keys()) == ['train/time', 'train/loss_cls', 'train/time_max'] # Test statistics lr with `current`, loss and time with 'mean' @@ -315,31 +317,27 @@ def setUp(self): def test_with_runner(self): cfg = self.epoch_based_cfg.copy() - cfg.log_processor = dict( - custom_cfg=[ - dict( - data_src='time', - window_size='epoch', - log_name='iter_time', - method_name='mean') - ], - log_with_hierarchy=True) + cfg.log_processor = dict(custom_cfg=[ + dict(data_src='time', + window_size='epoch', + log_name='iter_time', + method_name='mean') + ], + log_with_hierarchy=True) runner = self.build_runner(cfg) runner.train() runner.val() runner.test() cfg = self.iter_based_cfg.copy() - cfg.log_processor = dict( - by_epoch=False, - custom_cfg=[ - dict( - data_src='time', - window_size=100, - log_name='iter_time', - method_name='mean') - ], - log_with_hierarchy=True) + cfg.log_processor = dict(by_epoch=False, + custom_cfg=[ + dict(data_src='time', + window_size=100, + log_name='iter_time', + method_name='mean') + ], + log_with_hierarchy=True) runner = self.build_runner(cfg) runner.train() runner.val() diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index e7668054bb..6d0f3da0f1 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -389,40 +389,42 @@ def setUp(self): epoch_based_cfg = dict( model=dict(type='ToyModel'), work_dir=self.temp_dir, - train_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=3, - num_workers=0), - val_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), - test_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), + train_dataloader=dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', + shuffle=True), + batch_size=3, + num_workers=0), + val_dataloader=dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', + shuffle=False), + batch_size=3, + num_workers=0), + test_dataloader=dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', + shuffle=False), + batch_size=3, + num_workers=0), auto_scale_lr=dict(base_batch_size=16, enable=False), - optim_wrapper=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), + optim_wrapper=dict(type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01)), param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]), val_evaluator=dict(type='ToyMetric1'), test_evaluator=dict(type='ToyMetric1'), - train_cfg=dict( - by_epoch=True, max_epochs=3, val_interval=1, val_begin=1), + train_cfg=dict(by_epoch=True, + max_epochs=3, + val_interval=1, + val_begin=1), val_cfg=dict(), test_cfg=dict(), custom_hooks=[], - default_hooks=dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict( - type='CheckpointHook', interval=1, by_epoch=True), - sampler_seed=dict(type='DistSamplerSeedHook')), + default_hooks=dict(runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook'), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', + interval=1, + by_epoch=True), + sampler_seed=dict(type='DistSamplerSeedHook')), data_preprocessor=None, launcher='none', env_cfg=dict(dist_cfg=dict(backend='nccl')), @@ -620,23 +622,25 @@ def test_init(self): train_dataloader = DataLoader(ToyDataset(), collate_fn=collate_fn) val_dataloader = DataLoader(ToyDataset(), collate_fn=collate_fn) test_dataloader = DataLoader(ToyDataset(), collate_fn=collate_fn) - runner = Runner( - model=model, - work_dir=self.temp_dir, - train_cfg=dict( - by_epoch=True, max_epochs=3, val_interval=1, val_begin=1), - train_dataloader=train_dataloader, - optim_wrapper=optim_wrapper, - param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]), - val_cfg=dict(), - val_dataloader=val_dataloader, - val_evaluator=[ToyMetric1()], - test_cfg=dict(), - test_dataloader=test_dataloader, - test_evaluator=[ToyMetric1()], - default_hooks=dict(param_scheduler=toy_hook), - custom_hooks=[toy_hook2], - experiment_name='test_init14') + runner = Runner(model=model, + work_dir=self.temp_dir, + train_cfg=dict(by_epoch=True, + max_epochs=3, + val_interval=1, + val_begin=1), + train_dataloader=train_dataloader, + optim_wrapper=optim_wrapper, + param_scheduler=MultiStepLR(optim_wrapper, + milestones=[1, 2]), + val_cfg=dict(), + val_dataloader=val_dataloader, + val_evaluator=[ToyMetric1()], + test_cfg=dict(), + test_dataloader=test_dataloader, + test_evaluator=[ToyMetric1()], + default_hooks=dict(param_scheduler=toy_hook), + custom_hooks=[toy_hook2], + experiment_name='test_init14') runner.train() runner.test() @@ -693,8 +697,8 @@ def test_init(self): # 6.6 Test initializing with `_ParameterScheduler`. optimizer = SGD(nn.Linear(1, 1).parameters(), lr=0.1) - cfg.param_scheduler = MultiStepLR( - milestones=[1, 2], optimizer=optimizer) + cfg.param_scheduler = MultiStepLR(milestones=[1, 2], + optimizer=optimizer) cfg.experiment_name = 'test_init22' Runner(**cfg) @@ -706,9 +710,10 @@ def test_init(self): Runner(**cfg) # 6.8 Test initializing with 2 `_ParameterScheduler` for 2 optimizers. - cfg.param_scheduler = dict( - linear1=MultiStepLR(milestones=[1, 2], optimizer=optimizer), - linear2=MultiStepLR(milestones=[1, 2], optimizer=optimizer)) + cfg.param_scheduler = dict(linear1=MultiStepLR(milestones=[1, 2], + optimizer=optimizer), + linear2=MultiStepLR(milestones=[1, 2], + optimizer=optimizer)) cfg.experiment_name = 'test_init24' Runner(**cfg) @@ -747,9 +752,8 @@ def test_dump_config(self): temp_config_file = tempfile.NamedTemporaryFile( dir=temp_config_dir, suffix='.py', delete=False) temp_config_file.close() - file_cfg = Config( - self.epoch_based_cfg._cfg_dict, - filename=temp_config_file.name) + file_cfg = Config(self.epoch_based_cfg._cfg_dict, + filename=temp_config_file.name) file_cfg.experiment_name = f'test_dump2{idx}' runner = Runner.from_cfg(cfg=file_cfg) assert osp.exists( @@ -813,9 +817,8 @@ def test_build_visualizer(self): runner.visualizer.instance_name) # input is a Visualizer object - self.assertEqual( - id(runner.build_visualizer(runner.visualizer)), - id(runner.visualizer)) + self.assertEqual(id(runner.build_visualizer(runner.visualizer)), + id(runner.visualizer)) # input is a dict visualizer_cfg = dict(type='Visualizer', name='test_build_visualizer2') @@ -835,8 +838,9 @@ def test_build_visualizer(self): runner.build_visualizer('invalid-type') def test_default_scope(self): - TOY_SCHEDULERS = Registry( - 'parameter scheduler', parent=PARAM_SCHEDULERS, scope='toy') + TOY_SCHEDULERS = Registry('parameter scheduler', + parent=PARAM_SCHEDULERS, + scope='toy') @TOY_SCHEDULERS.register_module(force=True) class ToyScheduler(MultiStepLR): @@ -844,8 +848,8 @@ class ToyScheduler(MultiStepLR): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.epoch_based_cfg.param_scheduler = dict( - type='ToyScheduler', milestones=[1, 2]) + self.epoch_based_cfg.param_scheduler = dict(type='ToyScheduler', + milestones=[1, 2]) self.epoch_based_cfg.default_scope = 'toy' cfg = copy.deepcopy(self.epoch_based_cfg) @@ -1019,20 +1023,21 @@ def test_build_optim_wrapper(self): # "constructor" are not in optimizer optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01) optim_wrapper1 = OptimWrapper(optimizer1) - optim_wrapper2 = dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.01)) + optim_wrapper2 = dict(type='OptimWrapper', + optimizer=dict(type='Adam', lr=0.01)) optim_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2) with self.assertRaisesRegex(ValueError, 'each item mush be an optimizer object'): runner.build_optim_wrapper(optim_cfg) # 2.3 input is a dict which contains multiple configs - optim_wrapper_cfg = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + optim_wrapper_cfg = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=0.01)), + linear2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', + lr=0.02)), + constructor='ToyMultipleOptimizerConstructor') optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) self.assertIsInstance(optim_wrapper, OptimWrapperDict) self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD) @@ -1049,8 +1054,9 @@ def test_build_optim_wrapper(self): # Specify the type of optimizer wrapper model = nn.Linear(1, 1) optimizer = SGD(model.parameters(), lr=0.1) - optim_wrapper_cfg = dict( - optimizer=optimizer, type='ToyOptimWrapper', accumulative_counts=2) + optim_wrapper_cfg = dict(optimizer=optimizer, + type='ToyOptimWrapper', + accumulative_counts=2) optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) self.assertIsInstance(optim_wrapper, ToyOptimWrapper) self.assertIs(optim_wrapper.optimizer, optimizer) @@ -1065,10 +1071,10 @@ def test_build_param_scheduler(self): # `build_param_scheduler` cfg = dict(type='MultiStepLR', milestones=[1, 2]) runner.optim_wrapper = dict( - key1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - key2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), + key1=dict(type='OptimWrapper', optimizer=dict(type='SGD', + lr=0.01)), + key2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', lr=0.02)), ) with self.assertRaisesRegex(AssertionError, 'should be called before'): runner.build_param_scheduler(cfg) @@ -1129,12 +1135,11 @@ def test_build_param_scheduler(self): self.assertEqual(len(param_schedulers['key2']), 2) # 4. test multiple optimizers and multiple parameter shceduers - cfg = dict( - key1=dict(type='MultiStepLR', milestones=[1, 2]), - key2=[ - dict(type='MultiStepLR', milestones=[1, 2]), - dict(type='StepLR', step_size=1) - ]) + cfg = dict(key1=dict(type='MultiStepLR', milestones=[1, 2]), + key2=[ + dict(type='MultiStepLR', milestones=[1, 2]), + dict(type='StepLR', step_size=1) + ]) param_schedulers = runner.build_param_scheduler(cfg) self.assertIsInstance(param_schedulers, dict) self.assertEqual(len(param_schedulers), 2) @@ -1146,16 +1151,16 @@ def test_build_param_scheduler(self): dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01))) # 5.1 train loop should be built before converting scheduler - cfg = dict( - type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True) + cfg = dict(type='MultiStepLR', + milestones=[1, 2], + convert_to_iter_based=True) # 5.2 convert epoch-based to iter-based scheduler - cfg = dict( - type='MultiStepLR', - milestones=[1, 2], - begin=1, - end=7, - convert_to_iter_based=True) + cfg = dict(type='MultiStepLR', + milestones=[1, 2], + begin=1, + end=7, + convert_to_iter_based=True) runner._train_loop = runner.build_train_loop(runner.train_loop) param_schedulers = runner.build_param_scheduler(cfg) self.assertFalse(param_schedulers[0].by_epoch) @@ -1170,11 +1175,10 @@ def test_build_param_scheduler(self): # runner.max_epochs = 3 self.assertEqual(param_schedulers[0].end, 3) - cfg = dict( - type='MultiStepLR', - milestones=[1, 2], - begin=1, - convert_to_iter_based=True) + cfg = dict(type='MultiStepLR', + milestones=[1, 2], + begin=1, + convert_to_iter_based=True) param_schedulers = runner.build_param_scheduler(cfg) self.assertFalse(param_schedulers[0].by_epoch) self.assertEqual(param_schedulers[0].begin, 4) @@ -1217,12 +1221,11 @@ def test_build_evaluator(self): self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu') # test build a customize evaluator - evaluator = dict( - type='ToyEvaluator', - metrics=[ - dict(type='ToyMetric1', collect_device='cpu'), - dict(type='ToyMetric2', collect_device='gpu') - ]) + evaluator = dict(type='ToyEvaluator', + metrics=[ + dict(type='ToyMetric1', collect_device='cpu'), + dict(type='ToyMetric2', collect_device='gpu') + ]) _evaluator = runner.build_evaluator(evaluator) self.assertIsInstance(runner.build_evaluator(evaluator), ToyEvaluator) self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu') @@ -1237,11 +1240,10 @@ def test_build_dataloader(self): cfg.experiment_name = 'test_build_dataloader' runner = Runner.from_cfg(cfg) - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=1, - num_workers=0) + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=1, + num_workers=0) seed = np.random.randint(2**31) dataloader = runner.build_dataloader(cfg, seed=seed) self.assertIsInstance(dataloader, DataLoader) @@ -1250,28 +1252,27 @@ def test_build_dataloader(self): self.assertEqual(dataloader.sampler.seed, seed) # diff_rank_seed is True - dataloader = runner.build_dataloader( - cfg, seed=seed, diff_rank_seed=True) + dataloader = runner.build_dataloader(cfg, + seed=seed, + diff_rank_seed=True) self.assertNotEqual(dataloader.sampler.seed, seed) # custom worker_init_fn - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), - batch_size=1, - num_workers=2) + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + worker_init_fn=dict(type='custom_worker_init'), + batch_size=1, + num_workers=2) dataloader = runner.build_dataloader(cfg) self.assertIs(dataloader.worker_init_fn.func, custom_worker_init) # collate_fn is a dict - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), - batch_size=1, - num_workers=2, - collate_fn=dict(type='pseudo_collate')) + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + worker_init_fn=dict(type='custom_worker_init'), + batch_size=1, + num_workers=2, + collate_fn=dict(type='pseudo_collate')) dataloader = runner.build_dataloader(cfg) self.assertIsInstance(dataloader.collate_fn, partial) @@ -1279,36 +1280,33 @@ def test_build_dataloader(self): def custom_collate(data_batch): return data_batch - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), - batch_size=1, - num_workers=2, - collate_fn=custom_collate) + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + worker_init_fn=dict(type='custom_worker_init'), + batch_size=1, + num_workers=2, + collate_fn=custom_collate) dataloader = runner.build_dataloader(cfg) self.assertIs(dataloader.collate_fn, custom_collate) # collate_fn is a invalid value with self.assertRaisesRegex( TypeError, 'collate_fn should be a dict or callable object'): - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), - batch_size=1, - num_workers=2, - collate_fn='collate_fn') + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + worker_init_fn=dict(type='custom_worker_init'), + batch_size=1, + num_workers=2, + collate_fn='collate_fn') dataloader = runner.build_dataloader(cfg) self.assertIsInstance(dataloader.collate_fn, partial) # num_batch_per_epoch is not None - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate'), - batch_size=3, - num_workers=2, - num_batch_per_epoch=2) + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + batch_size=3, + num_workers=2, + num_batch_per_epoch=2) dataloader = runner.build_dataloader(cfg) self.assertEqual(len(dataloader.dataset), 6) @@ -1432,8 +1430,8 @@ def test_build_log_processor(self): self.assertIsInstance(log_processor, LogProcessor) # input is a LogProcessor object - self.assertEqual( - id(runner.build_log_processor(log_processor)), id(log_processor)) + self.assertEqual(id(runner.build_log_processor(log_processor)), + id(log_processor)) # test custom validation log_processor cfg = dict(type='CustomLogProcessor') @@ -1525,8 +1523,10 @@ def before_val_iter(self, runner, batch_idx, data_batch=None): cfg = copy.deepcopy(self.iter_based_cfg) cfg.experiment_name = 'test_train3' cfg.custom_hooks = [dict(type='TestIterHook', priority=50)] - cfg.train_cfg = dict( - by_epoch=False, max_iters=12, val_interval=4, val_begin=4) + cfg.train_cfg = dict(by_epoch=False, + max_iters=12, + val_interval=4, + val_begin=4) runner = Runner.from_cfg(cfg) runner.train() @@ -1562,11 +1562,13 @@ def before_val_iter(self, runner, batch_idx, data_batch=None): cfg = copy.deepcopy(self.iter_based_cfg) cfg.experiment_name = 'test_train4' - cfg.train_dataloader.sampler = dict( - type='DefaultSampler', shuffle=True) + cfg.train_dataloader.sampler = dict(type='DefaultSampler', + shuffle=True) cfg.custom_hooks = [dict(type='TestIterHook', priority=50)] - cfg.train_cfg = dict( - by_epoch=False, max_iters=12, val_interval=4, val_begin=4) + cfg.train_cfg = dict(by_epoch=False, + max_iters=12, + val_interval=4, + val_begin=4) runner = Runner.from_cfg(cfg) # Warning should be raised since the sampler is not InfiniteSampler. with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): @@ -1610,16 +1612,15 @@ def before_train_iter(self, runner, batch_idx, data_batch=None): cfg = copy.deepcopy(self.iter_based_cfg) cfg.experiment_name = 'test_train5' - cfg.train_dataloader.sampler = dict( - type='DefaultSampler', shuffle=True) + cfg.train_dataloader.sampler = dict(type='DefaultSampler', + shuffle=True) cfg.custom_hooks = [ dict(type='TestIterDynamicIntervalHook', priority=50) ] - cfg.train_cfg = dict( - by_epoch=False, - max_iters=max_iters, - val_interval=interval, - dynamic_intervals=dynamic_intervals) + cfg.train_cfg = dict(by_epoch=False, + max_iters=max_iters, + val_interval=interval, + dynamic_intervals=dynamic_intervals) runner = Runner.from_cfg(cfg) runner.train() for result, target, in zip(iter_results, iter_targets): @@ -1647,16 +1648,15 @@ def before_train_epoch(self, runner): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_train6' - cfg.train_dataloader.sampler = dict( - type='DefaultSampler', shuffle=True) + cfg.train_dataloader.sampler = dict(type='DefaultSampler', + shuffle=True) cfg.custom_hooks = [ dict(type='TestEpochDynamicIntervalHook', priority=50) ] - cfg.train_cfg = dict( - by_epoch=True, - max_epochs=max_epochs, - val_interval=interval, - dynamic_intervals=dynamic_intervals) + cfg.train_cfg = dict(by_epoch=True, + max_epochs=max_epochs, + val_interval=interval, + dynamic_intervals=dynamic_intervals) runner = Runner.from_cfg(cfg) runner.train() for result, target, in zip(epoch_results, epoch_targets): @@ -1687,12 +1687,13 @@ def init_weights(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_train8' cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2]) - cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + cfg.optim_wrapper = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=0.01)), + linear2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', + lr=0.02)), + constructor='ToyMultipleOptimizerConstructor') cfg.model = dict(type='ToyGANModel') runner = runner.from_cfg(cfg) runner.train() @@ -1701,12 +1702,13 @@ def init_weights(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_train8.1.1' cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2]) - cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + cfg.optim_wrapper = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=0.01)), + linear2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', + lr=0.02)), + constructor='ToyMultipleOptimizerConstructor') cfg.model = dict(type='ToyGANModel') runner = runner.from_cfg(cfg) runner.train() @@ -1759,8 +1761,8 @@ def init_weights(self): # 10.3 Test build dataloader with custom worker_init function cfg = copy.deepcopy(self.iter_based_cfg) cfg.experiment_name = 'test_train10.3' - cfg.train_dataloader.update( - worker_init_fn=dict(type='custom_worker_init')) + cfg.train_dataloader.update(worker_init_fn=dict( + type='custom_worker_init')) runner = Runner.from_cfg(cfg) runner.train() @@ -1835,9 +1837,8 @@ def train_step(self, *args, **kwargs): runner.train() self.assertEqual(runner.iter, 3 * 2) - @skipIf( - SKIP_TEST_COMPILE, - reason='torch.compile is not valid, please install PyTorch>=2.0.0') + @skipIf(SKIP_TEST_COMPILE, + reason='torch.compile is not valid, please install PyTorch>=2.0.0') def test_train_with_compile(self): # 1. test with simple configuration cfg = copy.deepcopy(self.epoch_based_cfg) @@ -1947,9 +1948,8 @@ def after_val_iter(self, runner.val() self.assertEqual(val_result, 2) - @skipIf( - SKIP_TEST_COMPILE, - reason='torch.compile is not valid, please install PyTorch>=2.0.0') + @skipIf(SKIP_TEST_COMPILE, + reason='torch.compile is not valid, please install PyTorch>=2.0.0') def test_val_with_compile(self): # 1. test with simple configuration cfg = copy.deepcopy(self.epoch_based_cfg) @@ -2052,9 +2052,8 @@ def after_test_iter(self, runner.test() self.assertEqual(test_result, 2) - @skipIf( - SKIP_TEST_COMPILE, - reason='torch.compile is not valid, please install PyTorch>=2.0.0') + @skipIf(SKIP_TEST_COMPILE, + reason='torch.compile is not valid, please install PyTorch>=2.0.0') def test_test_with_compile(self): # 1. test with simple configuration cfg = copy.deepcopy(self.epoch_based_cfg) @@ -2088,8 +2087,8 @@ def test_register_hook(self): self.assertEqual(len(runner._hooks), 1) self.assertTrue(isinstance(runner._hooks[0], IterTimerHook)) # default priority of `IterTimerHook` is 'NORMAL' - self.assertEqual( - get_priority(runner._hooks[0].priority), get_priority('NORMAL')) + self.assertEqual(get_priority(runner._hooks[0].priority), + get_priority('NORMAL')) runner._hooks = [] # 1.2.1 `hook` is a dict and contains `priority` field @@ -2098,9 +2097,8 @@ def test_register_hook(self): runner.register_hook(timer_cfg) self.assertEqual(len(runner._hooks), 1) self.assertTrue(isinstance(runner._hooks[0], IterTimerHook)) - self.assertEqual( - get_priority(runner._hooks[0].priority), - get_priority('BELOW_NORMAL')) + self.assertEqual(get_priority(runner._hooks[0].priority), + get_priority('BELOW_NORMAL')) # 1.3 `hook` is a hook object runtime_info_hook = RuntimeInfoHook() @@ -2110,8 +2108,8 @@ def test_register_hook(self): # `IterTimerHook`, so the first item of `_hooks` should be # `runtime_info_hook` self.assertTrue(isinstance(runner._hooks[0], RuntimeInfoHook)) - self.assertEqual( - get_priority(runner._hooks[0].priority), get_priority('VERY_HIGH')) + self.assertEqual(get_priority(runner._hooks[0].priority), + get_priority('VERY_HIGH')) # 2. test `priority` parameter # `priority` argument is not None and it will be set as priority of @@ -2120,16 +2118,16 @@ def test_register_hook(self): runner.register_hook(param_scheduler_cfg, priority='VERY_LOW') self.assertEqual(len(runner._hooks), 3) self.assertTrue(isinstance(runner._hooks[2], ParamSchedulerHook)) - self.assertEqual( - get_priority(runner._hooks[2].priority), get_priority('VERY_LOW')) + self.assertEqual(get_priority(runner._hooks[2].priority), + get_priority('VERY_LOW')) # `priority` is Priority logger_cfg = dict(type='LoggerHook', priority='BELOW_NORMAL') runner.register_hook(logger_cfg, priority=Priority.VERY_LOW) self.assertEqual(len(runner._hooks), 4) self.assertTrue(isinstance(runner._hooks[3], LoggerHook)) - self.assertEqual( - get_priority(runner._hooks[3].priority), get_priority('VERY_LOW')) + self.assertEqual(get_priority(runner._hooks[3].priority), + get_priority('VERY_LOW')) def test_default_hooks(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -2189,8 +2187,9 @@ class CustomTrainLoop2(IterBasedTrainLoop): def __init__(self, runner, dataloader, max_iters, warmup_loader, max_warmup_iters): - super().__init__( - runner=runner, dataloader=dataloader, max_iters=max_iters) + super().__init__(runner=runner, + dataloader=dataloader, + max_iters=max_iters) self.warmup_loader = self.runner.build_dataloader( warmup_loader) self.max_warmup_iters = max_warmup_iters @@ -2213,13 +2212,13 @@ def run(self): self.runner.call_hook('after_train') def warmup_iter(self, data_batch): - self.runner.call_hook( - 'before_warmup_iter', data_batch=data_batch) + self.runner.call_hook('before_warmup_iter', + data_batch=data_batch) train_logs = self.runner.model.train_step( data_batch, self.runner.optim_wrapper) self.runner.message_hub.update_info('train_logs', train_logs) - self.runner.call_hook( - 'after_warmup_iter', data_batch=data_batch) + self.runner.call_hook('after_warmup_iter', + data_batch=data_batch) before_warmup_iter_results = [] after_warmup_iter_results = [] @@ -2237,11 +2236,11 @@ def after_warmup_iter(self, runner, data_batch=None, outputs=None): self.iter_based_cfg.train_cfg = dict( type='CustomTrainLoop2', max_iters=10, - warmup_loader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='InfiniteSampler', shuffle=True), - batch_size=1, - num_workers=0), + warmup_loader=dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='InfiniteSampler', + shuffle=True), + batch_size=1, + num_workers=0), max_warmup_iters=5) self.iter_based_cfg.custom_hooks = [ dict(type='TestWarmupHook', priority=50) @@ -2272,7 +2271,7 @@ def test_checkpoint(self): self.assertTrue(osp.exists(path)) self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth'))) - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=False) self.assertEqual(ckpt['meta']['epoch'], 3) self.assertEqual(ckpt['meta']['iter'], 12) self.assertEqual(ckpt['meta']['experiment_name'], @@ -2304,8 +2303,8 @@ def test_checkpoint(self): # 1.3.1 test `resume` cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_checkpoint3' - cfg.optim_wrapper = dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2)) + cfg.optim_wrapper = dict(type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.2)) cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3]) runner = Runner.from_cfg(cfg) runner.resume(path) @@ -2380,12 +2379,13 @@ def test_checkpoint(self): # 1.6 multiple optimizers cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_checkpoint6' - cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + cfg.optim_wrapper = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=0.01)), + linear2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', + lr=0.02)), + constructor='ToyMultipleOptimizerConstructor') cfg.model = dict(type='ToyGANModel') # disable OptimizerHook because it only works with one optimizer runner = Runner.from_cfg(cfg) @@ -2400,12 +2400,13 @@ def test_checkpoint(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_checkpoint7' - cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.03)), - constructor='ToyMultipleOptimizerConstructor') + cfg.optim_wrapper = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=0.2)), + linear2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', + lr=0.03)), + constructor='ToyMultipleOptimizerConstructor') cfg.model = dict(type='ToyGANModel') cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3]) runner = Runner.from_cfg(cfg) @@ -2444,7 +2445,7 @@ def test_checkpoint(self): self.assertTrue(osp.exists(path)) self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth'))) - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=False) self.assertEqual(ckpt['meta']['epoch'], 0) self.assertEqual(ckpt['meta']['iter'], 12) assert isinstance(ckpt['optimizer'], dict) @@ -2455,7 +2456,7 @@ def test_checkpoint(self): self.assertEqual(message_hub.get_info('iter'), 11) # 2.1.2 check class attribute _statistic_methods can be saved HistoryBuffer._statistics_methods.clear() - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=False) self.assertIn('min', HistoryBuffer._statistics_methods) # 2.2 test `load_checkpoint` @@ -2518,12 +2519,11 @@ def test_checkpoint(self): # 2.7.1 test `resume` 2 optimizers and 1 scheduler list. path = osp.join(self.temp_dir, 'epoch_3.pth') - optim_cfg = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + optim_cfg = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01)), + linear2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', lr=0.02)), + constructor='ToyMultipleOptimizerConstructor') cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_checkpoint14' cfg.optim_wrapper = optim_cfg @@ -2546,9 +2546,11 @@ def test_checkpoint(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_checkpoint16' cfg.optim_wrapper = optim_cfg - cfg.param_scheduler = dict( - linear1=dict(type='MultiStepLR', milestones=[1, 2, 3]), - linear2=dict(type='StepLR', gamma=0.1, step_size=3)) + cfg.param_scheduler = dict(linear1=dict(type='MultiStepLR', + milestones=[1, 2, 3]), + linear2=dict(type='StepLR', + gamma=0.1, + step_size=3)) cfg.model = dict(type='ToyGANModel') resumed_cfg = copy.deepcopy(cfg) runner = Runner.from_cfg(cfg) diff --git a/tests/test_strategies/test_fsdp.py b/tests/test_strategies/test_fsdp.py index 64b900d2f8..545651b5da 100644 --- a/tests/test_strategies/test_fsdp.py +++ b/tests/test_strategies/test_fsdp.py @@ -59,33 +59,29 @@ def test_init(self): strategy = FSDPStrategy(state_dict_cfg='full') self._assert_full(strategy) - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=StateDictType.LOCAL_STATE_DICT)) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=StateDictType.LOCAL_STATE_DICT)) self._assert_local(strategy) - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=FullStateDictConfig(), - optim_state_dict_config=FullOptimStateDictConfig(), - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(), + optim_state_dict_config=FullOptimStateDictConfig(), + )) self._assert_full(strategy) - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type='FULL_STATE_DICT', - state_dict_config=dict(type='FullStateDictConfig'), - optim_state_dict_config=dict(type='FullOptimStateDictConfig'), - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type='FULL_STATE_DICT', + state_dict_config=dict(type='FullStateDictConfig'), + optim_state_dict_config=dict(type='FullOptimStateDictConfig'), + )) self._assert_full(strategy) - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=dict(type=FullStateDictConfig), - optim_state_dict_config=dict(type=FullOptimStateDictConfig), - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=dict(type=FullStateDictConfig), + optim_state_dict_config=dict(type=FullOptimStateDictConfig), + )) self._assert_full(strategy) with self.assertRaises(ValueError): @@ -97,33 +93,28 @@ def test_init(self): # state_dict_type must be a str or a enumerate of StateDictType with self.assertRaises(TypeError): - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=[], - state_dict_config=dict(type=FullStateDictConfig), - optim_state_dict_config=dict( - type=FullOptimStateDictConfig), - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=[], + state_dict_config=dict(type=FullStateDictConfig), + optim_state_dict_config=dict(type=FullOptimStateDictConfig), + )) # state_dict_config should be a dict or a subclass of StateDictConfig with self.assertRaises(TypeError): - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=[], - optim_state_dict_config=dict( - type=FullOptimStateDictConfig), - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=[], + optim_state_dict_config=dict(type=FullOptimStateDictConfig), + )) # optim_state_dict_config should be a dict or a subclass of # OptimStateDictConfig with self.assertRaises(TypeError): - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=dict(type=FullStateDictConfig), - optim_state_dict_config=[], - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=dict(type=FullStateDictConfig), + optim_state_dict_config=[], + )) def run_strategy(self): # Strategy can run with the built model, optimizer and schedulers. @@ -168,8 +159,8 @@ def run_strategy(self): # optimizer with multiple param_groups can be reconstructed. model = ToyModel() - strategy = FSDPStrategy( - model_wrapper=dict(auto_wrap_policy=linear_wrap_policy)) + strategy = FSDPStrategy(model_wrapper=dict( + auto_wrap_policy=linear_wrap_policy)) param_groups = [] for param in model.parameters(): param_groups.append(dict(params=[param], lr=0.1)) @@ -204,10 +195,9 @@ def _worker(cls, rank, func): self.tearDown() def test_run_strategy(self): - start_processes( - TestStrategy._worker, - args=('run_strategy', ), - nprocs=self.world_size) + start_processes(TestStrategy._worker, + args=('run_strategy', ), + nprocs=self.world_size) def test_build_model(self): ... diff --git a/tests/test_structures/test_data_element.py b/tests/test_structures/test_data_element.py index 1cb7cd1745..d1c1e9afb9 100644 --- a/tests/test_structures/test_data_element.py +++ b/tests/test_structures/test_data_element.py @@ -29,8 +29,9 @@ def gt_instances(self): @gt_instances.setter def gt_instances(self, value): - self.set_field( - value=value, name='_gt_instances', dtype=BaseDataElement) + self.set_field(value=value, + name='_gt_instances', + dtype=BaseDataElement) @gt_instances.deleter def gt_instances(self): @@ -42,8 +43,9 @@ def pred_instances(self): @pred_instances.setter def pred_instances(self, value): - self.set_field( - value=value, name='_pred_instances', dtype=BaseDataElement) + self.set_field(value=value, + name='_pred_instances', + dtype=BaseDataElement) @pred_instances.deleter def pred_instances(self): @@ -53,13 +55,13 @@ def pred_instances(self): class TestBaseDataElement(TestCase): def setup_data(self): - metainfo = dict( - img_id=random.randint(0, 100), - img_shape=(random.randint(400, 600), random.randint(400, 600))) - gt_instances = BaseDataElement( - bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))) - pred_instances = BaseDataElement( - bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))) + metainfo = dict(img_id=random.randint(0, 100), + img_shape=(random.randint(400, 600), + random.randint(400, 600))) + gt_instances = BaseDataElement(bboxes=torch.rand((5, 4)), + labels=torch.rand((5, ))) + pred_instances = BaseDataElement(bboxes=torch.rand((5, 4)), + scores=torch.rand((5, ))) data = dict(gt_instances=gt_instances, pred_instances=pred_instances) return metainfo, data @@ -232,8 +234,8 @@ def test_set_data(self): def test_update(self): metainfo, data = self.setup_data() instances = BaseDataElement(metainfo=metainfo, **data) - proposals = BaseDataElement( - bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))) + proposals = BaseDataElement(bboxes=torch.rand((5, 4)), + scores=torch.rand((5, ))) new_instances = BaseDataElement(proposals=proposals) instances.update(new_instances) self.check_key_value(instances, metainfo, @@ -267,8 +269,8 @@ def test_delete_modify(self): del instances.gt_instances del instances.img_id - assert not self.is_equal( - instances.pop('pred_instances', None), data['pred_instances']) + assert not self.is_equal(instances.pop('pred_instances', None), + data['pred_instances']) with self.assertRaises(AttributeError): del instances.pred_instances @@ -293,8 +295,8 @@ def test_delete_modify(self): with self.assertRaises(AttributeError): del instances._data_fields - @pytest.mark.skipif( - not torch.cuda.is_available(), reason='GPU is required!') + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='GPU is required!') def test_cuda(self): metainfo, data = self.setup_data() instances = BaseDataElement(metainfo=metainfo, **data) @@ -338,8 +340,9 @@ def test_detach(self): def test_repr(self): metainfo = dict(img_shape=(800, 1196, 3)) - gt_instances = BaseDataElement( - metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) + gt_instances = BaseDataElement(metainfo=metainfo, + det_labels=torch.LongTensor( + [0, 1, 2, 3])) sample = BaseDataElement(metainfo=metainfo, gt_instances=gt_instances) address = hex(id(sample)) address_gt_instances = hex(id(sample.gt_instances)) diff --git a/tests/test_structures/test_instance_data.py b/tests/test_structures/test_instance_data.py index fe4a1b2603..20009741ef 100644 --- a/tests/test_structures/test_instance_data.py +++ b/tests/test_structures/test_instance_data.py @@ -73,9 +73,9 @@ def __repr__(self): class TestInstanceData(TestCase): def setup_data(self): - metainfo = dict( - img_id=random.randint(0, 100), - img_shape=(random.randint(400, 600), random.randint(400, 600))) + metainfo = dict(img_id=random.randint(0, 100), + img_shape=(random.randint(400, 600), + random.randint(400, 600))) instances_infos = [1] * 5 bboxes = torch.rand((5, 4)) labels = np.random.rand(5) @@ -83,15 +83,14 @@ def setup_data(self): ids = (1, 2, 3, 4, 5) name_ids = '12345' polygons = TmpObject(np.arange(25).reshape((5, -1)).tolist()) - instance_data = InstanceData( - metainfo=metainfo, - bboxes=bboxes, - labels=labels, - polygons=polygons, - kps=kps, - ids=ids, - name_ids=name_ids, - instances_infos=instances_infos) + instance_data = InstanceData(metainfo=metainfo, + bboxes=bboxes, + labels=labels, + polygons=polygons, + kps=kps, + ids=ids, + name_ids=name_ids, + instances_infos=instances_infos) return instance_data def test_set_data(self): @@ -189,8 +188,8 @@ def test_cat(self): assert len(cat_instance_data) == 10 # All inputs must be InstanceData - instance_data_2 = BaseDataElement( - bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))) + instance_data_2 = BaseDataElement(bboxes=torch.rand((5, 4)), + labels=torch.rand((5, ))) with self.assertRaises(AssertionError): InstanceData.cat([instance_data_1, instance_data_2]) @@ -208,11 +207,10 @@ def test_cat(self): instance_data_1.polygons = TmpObjectWithoutCat( np.arange(25).reshape((5, -1)).tolist()) instance_data_2 = instance_data_1.clone() - with pytest.raises( - ValueError, - match=('The type of `polygons` is ' - f'`{type(instance_data_1.polygons)}` ' - 'which has no attribute of `cat`')): + with pytest.raises(ValueError, + match=('The type of `polygons` is ' + f'`{type(instance_data_1.polygons)}` ' + 'which has no attribute of `cat`')): cat_instance_data = InstanceData.cat( [instance_data_1, instance_data_2]) diff --git a/tests/test_structures/test_label_data.py b/tests/test_structures/test_label_data.py index 8c73bca767..7cb771019f 100644 --- a/tests/test_structures/test_label_data.py +++ b/tests/test_structures/test_label_data.py @@ -21,10 +21,11 @@ def test_label_to_onehot(self): # item'max bigger than num_classes with self.assertRaises(AssertionError): - LabelData.label_to_onehot( - torch.tensor([11], dtype=torch.int64), num_classes) - onehot = LabelData.label_to_onehot( - label=torch.tensor([], dtype=torch.int64), num_classes=num_classes) + LabelData.label_to_onehot(torch.tensor([11], dtype=torch.int64), + num_classes) + onehot = LabelData.label_to_onehot(label=torch.tensor( + [], dtype=torch.int64), + num_classes=num_classes) assert (onehot == torch.zeros((num_classes, ), dtype=torch.int64)).all() @@ -50,8 +51,8 @@ def test_onehot_to_label(self): assert label == item assert label.device == item.device - @pytest.mark.skipif( - not torch.cuda.is_available(), reason='GPU is required!') + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='GPU is required!') def test_cuda(self): item = torch.arange(0, 9).cuda() onehot = LabelData.label_to_onehot(item, num_classes=10) diff --git a/tests/test_structures/test_pixel_data.py b/tests/test_structures/test_pixel_data.py index 1ca80373af..34fcc249b8 100644 --- a/tests/test_structures/test_pixel_data.py +++ b/tests/test_structures/test_pixel_data.py @@ -12,9 +12,9 @@ class TestPixelData(TestCase): def setup_data(self): - metainfo = dict( - img_id=random.randint(0, 100), - img_shape=(random.randint(400, 600), random.randint(400, 600))) + metainfo = dict(img_id=random.randint(0, 100), + img_shape=(random.randint(400, 600), + random.randint(400, 600))) image = np.random.randint(0, 255, (4, 20, 40)) featmap = torch.randint(0, 255, (10, 20, 40)) pixel_data = PixelData(metainfo=metainfo, image=image, featmap=featmap) diff --git a/tests/test_testing/test_runner_test_case.py b/tests/test_testing/test_runner_test_case.py index 5d41c03531..be93e74ee6 100644 --- a/tests/test_testing/test_runner_test_case.py +++ b/tests/test_testing/test_runner_test_case.py @@ -46,8 +46,8 @@ def test_experiment_name(self): def test_init_dist(self): self.setup_dist_env() - self.assertEqual( - str(self.dist_cfg['MASTER_PORT']), os.environ['MASTER_PORT']) + self.assertEqual(str(self.dist_cfg['MASTER_PORT']), + os.environ['MASTER_PORT']) self.assertEqual(self.dist_cfg['MASTER_ADDR'], os.environ['MASTER_ADDR']) self.assertEqual(self.dist_cfg['RANK'], os.environ['RANK']) diff --git a/tests/test_utils/test_dl_utils/test_setup_env.py b/tests/test_utils/test_dl_utils/test_setup_env.py index 9ca98b4311..74c6881233 100644 --- a/tests/test_utils/test_dl_utils/test_setup_env.py +++ b/tests/test_utils/test_dl_utils/test_setup_env.py @@ -38,8 +38,9 @@ def test_setup_multi_processes(): assert os.getenv('OMP_NUM_THREADS') == '4' # test manually set opencv threads and mp start method - config = dict( - mp_start_method='spawn', opencv_num_threads=4, distributed=True) + config = dict(mp_start_method='spawn', + opencv_num_threads=4, + distributed=True) set_multi_processing(**config) assert cv2.getNumThreads() == 4 assert mp.get_start_method() == 'spawn' diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index 7c43d04853..71af2e7c6a 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -158,8 +158,8 @@ def test_import_modules_from_strings(): with pytest.raises(ImportError): import_modules_from_strings('_not_implemented_module') with pytest.warns(UserWarning): - imported = import_modules_from_strings( - '_not_implemented_module', allow_failed_imports=True) + imported = import_modules_from_strings('_not_implemented_module', + allow_failed_imports=True) assert imported is None with pytest.warns(UserWarning): imported = import_modules_from_strings(['os.path', '_not_implemented'], diff --git a/tests/test_utils/test_package_utils.py b/tests/test_utils/test_package_utils.py index bed91b6c18..e271e9d314 100644 --- a/tests/test_utils/test_package_utils.py +++ b/tests/test_utils/test_package_utils.py @@ -2,9 +2,13 @@ import os.path as osp import sys -import pkg_resources import pytest +try: + from importlib.metadata import PackageNotFoundError +except ImportError: + from importlib_metadata import PackageNotFoundError # type: ignore[import-untyped, no-redef, import-not-found] # noqa: E501 + from mmengine.utils import get_installed_path, is_installed @@ -33,5 +37,5 @@ def test_get_install_path(): assert get_installed_path('optim') == osp.join(PYTHONPATH, 'optim') sys.path.pop() - with pytest.raises(pkg_resources.DistributionNotFound): + with pytest.raises(PackageNotFoundError): get_installed_path('unknown') diff --git a/tests/test_utils/test_progressbar.py b/tests/test_utils/test_progressbar.py index 0636e25e1d..c2635f2d6c 100644 --- a/tests/test_utils/test_progressbar.py +++ b/tests/test_utils/test_progressbar.py @@ -23,8 +23,9 @@ def test_start(self): prog_bar = mmengine.ProgressBar(bar_width=bar_width, file=out) assert out.getvalue() == 'completed: 0, elapsed: 0s' reset_string_io(out) - prog_bar = mmengine.ProgressBar( - bar_width=bar_width, start=False, file=out) + prog_bar = mmengine.ProgressBar(bar_width=bar_width, + start=False, + file=out) assert out.getvalue() == '' reset_string_io(out) prog_bar.start() @@ -34,16 +35,17 @@ def test_start(self): prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out) assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' reset_string_io(out) - prog_bar = mmengine.ProgressBar( - 10, bar_width=bar_width, start=False, file=out) + prog_bar = mmengine.ProgressBar(10, + bar_width=bar_width, + start=False, + file=out) assert out.getvalue() == '' reset_string_io(out) prog_bar.start() assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' - @skipIf( - platform.system() != 'Linux', - reason='Only test `TestProgressBar.test_update` in Linux') + @skipIf(platform.system() != 'Linux', + reason='Only test `TestProgressBar.test_update` in Linux') def test_update(self): out = StringIO() bar_width = 20 @@ -62,9 +64,8 @@ def test_update(self): assert out.getvalue() == f'\r[{">" * 2 + " " * 18}] 1/10, 1.0 ' \ 'task/s, elapsed: 1s, ETA: 9s' - @skipIf( - platform.system() != 'Linux', - reason='Only test `TestProgressBar.test_adaptive_length` in Linux') + @skipIf(platform.system() != 'Linux', + reason='Only test `TestProgressBar.test_adaptive_length` in Linux') def test_adaptive_length(self): with patch.dict('os.environ', {'COLUMNS': '80'}): out = StringIO() @@ -108,13 +109,16 @@ def test_track_progress(): assert ret == [1, 2, 3] # tasks is an iterable object - ret = mmengine.track_progress( - return_itself, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out) + ret = mmengine.track_progress(return_itself, ((i for i in [1, 2, 3]), 3), + bar_width=3, + file=out) assert ret == [1, 2, 3] # tasks is a range object - ret = mmengine.track_progress( - return_itself, range(1, 4), bar_width=3, file=out) + ret = mmengine.track_progress(return_itself, + range(1, 4), + bar_width=3, + file=out) assert ret == [1, 2, 3] @@ -143,19 +147,24 @@ def test_track_iter_progress(): def test_track_parallel_progress(): # tasks is a list out = StringIO() - ret = mmengine.track_parallel_progress( - return_itself, [1, 2, 3, 4], 2, bar_width=4, file=out) + ret = mmengine.track_parallel_progress(return_itself, [1, 2, 3, 4], + 2, + bar_width=4, + file=out) assert ret == [1, 2, 3, 4] # tasks is an iterable object - ret = mmengine.track_parallel_progress( - return_itself, ((i for i in [1, 2, 3, 4]), 4), - 2, - bar_width=4, - file=out) + ret = mmengine.track_parallel_progress(return_itself, + ((i for i in [1, 2, 3, 4]), 4), + 2, + bar_width=4, + file=out) assert ret == [1, 2, 3, 4] # tasks is a range object - ret = mmengine.track_parallel_progress( - return_itself, range(1, 5), 2, bar_width=4, file=out) + ret = mmengine.track_parallel_progress(return_itself, + range(1, 5), + 2, + bar_width=4, + file=out) assert ret == [1, 2, 3, 4] diff --git a/tests/test_utils/test_timer.py b/tests/test_utils/test_timer.py index 570f7ea380..de83d17527 100644 --- a/tests/test_utils/test_timer.py +++ b/tests/test_utils/test_timer.py @@ -7,8 +7,8 @@ import mmengine -@pytest.mark.skipif( - platform.system() != 'Linux', reason='Only test `Timer` in linux!') +@pytest.mark.skipif(platform.system() != 'Linux', + reason='Only test `Timer` in linux!') def test_timer_init(): timer = mmengine.Timer(start=False) assert not timer.is_running @@ -18,8 +18,8 @@ def test_timer_init(): assert timer.is_running -@pytest.mark.skipif( - platform.system() != 'Linux', reason='Only test `Timer` in linux!') +@pytest.mark.skipif(platform.system() != 'Linux', + reason='Only test `Timer` in linux!') def test_timer_run(): timer = mmengine.Timer() time.sleep(1) @@ -36,8 +36,8 @@ def test_timer_run(): timer.since_last_check() -@pytest.mark.skipif( - platform.system() != 'Linux', reason='Only test `Timer` in linux!') +@pytest.mark.skipif(platform.system() != 'Linux', + reason='Only test `Timer` in linux!') def test_timer_context(capsys): with mmengine.Timer(): time.sleep(1) diff --git a/tests/test_visualizer/test_vis_backend.py b/tests/test_visualizer/test_vis_backend.py index c991462ef9..b04e24a7fd 100644 --- a/tests/test_visualizer/test_vis_backend.py +++ b/tests/test_visualizer/test_vis_backend.py @@ -156,8 +156,9 @@ def test_add_scalar(self): tensorboard_vis_backend.add_scalar('map', np.array(9), step=0) tensorboard_vis_backend.add_scalar('map', np.array(95), step=1) tensorboard_vis_backend.add_scalar('map', np.array([9])[0], step=0) - tensorboard_vis_backend.add_scalar( - 'map', np.array([95])[0], step=1) + tensorboard_vis_backend.add_scalar('map', + np.array([95])[0], + step=1) assert len(record) == 0 # test with tensor tensorboard_vis_backend.add_scalar('map', torch.tensor(0.9), step=0) @@ -266,8 +267,8 @@ def test_define_metric_cfg(self): wandb_vis_backend = WandbVisBackend( 'temp_dir', define_metric_cfg=define_metric_cfg) wandb_vis_backend._init_env() - wandb_vis_backend._wandb.define_metric.assert_any_call( - 'test3', summary='max') + wandb_vis_backend._wandb.define_metric.assert_any_call('test3', + summary='max') shutil.rmtree('temp_dir') @@ -284,11 +285,11 @@ def test_experiment(self): def test_create_experiment(self): with patch('mlflow.create_experiment') as mock_create_experiment: - MLflowVisBackend( - 'temp_dir', exp_name='test', - artifact_location='foo')._init_env() - mock_create_experiment.assert_any_call( - 'test', artifact_location='foo') + MLflowVisBackend('temp_dir', + exp_name='test', + artifact_location='foo')._init_env() + mock_create_experiment.assert_any_call('test', + artifact_location='foo') def test_add_config(self): cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) @@ -366,8 +367,8 @@ def test_close(self): clearml_vis_backend.close() -@pytest.mark.skipif( - not is_installed('neptune'), reason='Neptune is not installed.') +@pytest.mark.skipif(not is_installed('neptune'), + reason='Neptune is not installed.') class TestNeptuneVisBackend: def test_init(self): @@ -457,9 +458,8 @@ def test_close(self): shutil.rmtree('temp_dir') -@pytest.mark.skipif( - platform.system() == 'Windows', - reason='Aim does not support Windows for now.') +@pytest.mark.skipif(platform.system() == 'Windows', + reason='Aim does not support Windows for now.') class TestAimVisBackend: def test_init(self): diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py index e4ababc637..f7d9a06f1d 100644 --- a/tests/test_visualizer/test_visualizer.py +++ b/tests/test_visualizer/test_visualizer.py @@ -57,8 +57,8 @@ def setUp(self): TestCase calls functions in this order: setUp() -> testMethod() -> tearDown() -> cleanUp() """ - self.image = np.random.randint( - 0, 256, size=(10, 10, 3)).astype('uint8') + self.image = np.random.randint(0, 256, + size=(10, 10, 3)).astype('uint8') self.vis_backend_cfg = [ dict(type='MockVisBackend', name='mock1'), dict(type='MockVisBackend', name='mock2') @@ -72,35 +72,33 @@ def test_init(self): visualizer = Visualizer( vis_backends=copy.deepcopy(self.vis_backend_cfg)) - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') assert isinstance(visualizer.get_backend('mock1'), MockVisBackend) assert len(visualizer._vis_backends) == 2 # The name fields cannot be the same with pytest.raises(RuntimeError): - Visualizer( - vis_backends=[ - dict(type='MockVisBackend'), - dict(type='MockVisBackend') - ], - save_dir='temp_dir') + Visualizer(vis_backends=[ + dict(type='MockVisBackend'), + dict(type='MockVisBackend') + ], + save_dir='temp_dir') with pytest.raises(RuntimeError): - Visualizer( - vis_backends=[ - dict(type='MockVisBackend', name='mock1'), - dict(type='MockVisBackend', name='mock1') - ], - save_dir='temp_dir') + Visualizer(vis_backends=[ + dict(type='MockVisBackend', name='mock1'), + dict(type='MockVisBackend', name='mock1') + ], + save_dir='temp_dir') # test global init instance_name = 'visualizer' + str(time.time()) - visualizer = Visualizer.get_instance( - instance_name, - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer.get_instance(instance_name, + vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') assert len(visualizer._vis_backends) == 2 visualizer_any = Visualizer.get_instance(instance_name) assert visualizer_any == visualizer @@ -120,9 +118,10 @@ def __init__(self, save_dir: str) -> None: VISBACKENDS.module_dict.pop('CustomLocalVisBackend') - visualizer = Visualizer.get_instance( - 'test_save_dir', - vis_backends=dict(type='CustomLocalVisBackend', save_dir='tmp')) + visualizer = Visualizer.get_instance('test_save_dir', + vis_backends=dict( + type='CustomLocalVisBackend', + save_dir='tmp')) visualizer = Visualizer.get_instance( 'test_save_dir', vis_backends=[CustomLocalVisBackend('tmp')]) @@ -148,8 +147,10 @@ def test_draw_bboxes(self): # valid bbox visualizer.draw_bboxes(torch.tensor([1, 1, 1, 2])) bboxes = torch.tensor([[1, 1, 2, 2], [1, 2, 2, 2.5]]) - visualizer.draw_bboxes( - bboxes, alpha=0.5, edge_colors=(255, 0, 0), line_styles='-') + visualizer.draw_bboxes(bboxes, + alpha=0.5, + edge_colors=(255, 0, 0), + line_styles='-') bboxes = bboxes.numpy() visualizer.draw_bboxes(bboxes) @@ -159,10 +160,9 @@ def test_draw_bboxes(self): visualizer.draw_bboxes(torch.tensor([5, 1, 2, 2])) # test out of bounds - with pytest.warns( - UserWarning, - match='Warning: The bbox is out of bounds,' - ' the drawn bbox may not be in the image'): + with pytest.warns(UserWarning, + match='Warning: The bbox is out of bounds,' + ' the drawn bbox may not be in the image'): visualizer.draw_bboxes(torch.tensor([1, 1, 20, 2])) # test incorrect bbox format @@ -170,10 +170,10 @@ def test_draw_bboxes(self): visualizer.draw_bboxes([1, 1, 2, 2]) def test_close(self): - visualizer = Visualizer( - image=self.image, - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(image=self.image, + vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') for name in ['mock1', 'mock2']: assert visualizer.get_backend(name)._close is False @@ -189,36 +189,33 @@ def test_draw_points(self): with pytest.raises(AssertionError): visualizer.draw_points(positions=np.array([1, 2, 3], dtype=object)) # test color - visualizer.draw_points( - positions=torch.tensor([[1, 1], [3, 3]]), - colors=['g', (255, 255, 0)]) - visualizer.draw_points( - positions=torch.tensor([[1, 1], [3, 3]]), - colors=['g', (255, 255, 0)], - marker='.', - sizes=[1, 5]) + visualizer.draw_points(positions=torch.tensor([[1, 1], [3, 3]]), + colors=['g', (255, 255, 0)]) + visualizer.draw_points(positions=torch.tensor([[1, 1], [3, 3]]), + colors=['g', (255, 255, 0)], + marker='.', + sizes=[1, 5]) def test_draw_texts(self): visualizer = Visualizer(image=self.image) # only support tensor and numpy - visualizer.draw_texts( - 'text1', positions=torch.tensor([5, 5]), colors=(0, 255, 0)) + visualizer.draw_texts('text1', + positions=torch.tensor([5, 5]), + colors=(0, 255, 0)) visualizer.draw_texts(['text1', 'text2'], positions=torch.tensor([[5, 5], [3, 3]]), colors=[(255, 0, 0), (255, 0, 0)]) visualizer.draw_texts('text1', positions=np.array([5, 5])) visualizer.draw_texts(['text1', 'text2'], positions=np.array([[5, 5], [3, 3]])) - visualizer.draw_texts( - 'text1', - positions=torch.tensor([5, 5]), - bboxes=dict(facecolor='r', alpha=0.6)) + visualizer.draw_texts('text1', + positions=torch.tensor([5, 5]), + bboxes=dict(facecolor='r', alpha=0.6)) # test out of bounds - with pytest.warns( - UserWarning, - match='Warning: The text is out of bounds,' - ' the drawn text may not be in the image'): + with pytest.warns(UserWarning, + match='Warning: The text is out of bounds,' + ' the drawn text may not be in the image'): visualizer.draw_texts('text1', positions=torch.tensor([15, 5])) # test incorrect format @@ -230,8 +227,8 @@ def test_draw_texts(self): visualizer.draw_texts(['text1', 'text2'], positions=torch.tensor([5, 5])) with pytest.raises(AssertionError): - visualizer.draw_texts( - 'text1', positions=torch.tensor([[5, 5], [3, 3]])) + visualizer.draw_texts('text1', + positions=torch.tensor([[5, 5], [3, 3]])) with pytest.raises(AssertionError): visualizer.draw_texts(['text1', 'test2'], positions=torch.tensor([[5, 5], [3, 3]]), @@ -259,24 +256,21 @@ def test_draw_lines(self): visualizer = Visualizer(image=self.image) # only support tensor and numpy - visualizer.draw_lines( - x_datas=torch.tensor([1, 5]), y_datas=torch.tensor([2, 6])) - visualizer.draw_lines( - x_datas=np.array([[1, 5], [2, 4]]), - y_datas=np.array([[2, 6], [4, 7]])) - visualizer.draw_lines( - x_datas=np.array([[1, 5], [2, 4]]), - y_datas=np.array([[2, 6], [4, 7]]), - colors='r', - line_styles=['-', '-.'], - line_widths=[1, 2]) + visualizer.draw_lines(x_datas=torch.tensor([1, 5]), + y_datas=torch.tensor([2, 6])) + visualizer.draw_lines(x_datas=np.array([[1, 5], [2, 4]]), + y_datas=np.array([[2, 6], [4, 7]])) + visualizer.draw_lines(x_datas=np.array([[1, 5], [2, 4]]), + y_datas=np.array([[2, 6], [4, 7]]), + colors='r', + line_styles=['-', '-.'], + line_widths=[1, 2]) # test out of bounds - with pytest.warns( - UserWarning, - match='Warning: The line is out of bounds,' - ' the drawn line may not be in the image'): - visualizer.draw_lines( - x_datas=torch.tensor([12, 5]), y_datas=torch.tensor([2, 6])) + with pytest.warns(UserWarning, + match='Warning: The line is out of bounds,' + ' the drawn line may not be in the image'): + visualizer.draw_lines(x_datas=torch.tensor([12, 5]), + y_datas=torch.tensor([2, 6])) # test incorrect format with pytest.raises(TypeError): @@ -286,9 +280,8 @@ def test_draw_lines(self): # test length mismatch with pytest.raises(AssertionError): - visualizer.draw_lines( - x_datas=torch.tensor([1, 5]), - y_datas=torch.tensor([[2, 6], [4, 7]])) + visualizer.draw_lines(x_datas=torch.tensor([1, 5]), + y_datas=torch.tensor([[2, 6], [4, 7]])) def test_draw_circles(self): visualizer = Visualizer(image=self.image) @@ -296,33 +289,30 @@ def test_draw_circles(self): # only support tensor and numpy visualizer.draw_circles(torch.tensor([1, 5]), torch.tensor([1])) visualizer.draw_circles(np.array([1, 5]), np.array([1])) - visualizer.draw_circles( - torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2])) + visualizer.draw_circles(torch.tensor([[1, 5], [2, 6]]), + radius=torch.tensor([1, 2])) # test face_colors - visualizer.draw_circles( - torch.tensor([[1, 5], [2, 6]]), - radius=torch.tensor([1, 2]), - face_colors=(255, 0, 0), - edge_colors=(255, 0, 0)) + visualizer.draw_circles(torch.tensor([[1, 5], [2, 6]]), + radius=torch.tensor([1, 2]), + face_colors=(255, 0, 0), + edge_colors=(255, 0, 0)) # test config - visualizer.draw_circles( - torch.tensor([[1, 5], [2, 6]]), - radius=torch.tensor([1, 2]), - edge_colors=['g', 'r'], - line_styles=['-', '-.'], - line_widths=[1, 2]) + visualizer.draw_circles(torch.tensor([[1, 5], [2, 6]]), + radius=torch.tensor([1, 2]), + edge_colors=['g', 'r'], + line_styles=['-', '-.'], + line_widths=[1, 2]) # test out of bounds - with pytest.warns( - UserWarning, - match='Warning: The circle is out of bounds,' - ' the drawn circle may not be in the image'): - visualizer.draw_circles( - torch.tensor([12, 5]), radius=torch.tensor([1])) - visualizer.draw_circles( - torch.tensor([1, 5]), radius=torch.tensor([10])) + with pytest.warns(UserWarning, + match='Warning: The circle is out of bounds,' + ' the drawn circle may not be in the image'): + visualizer.draw_circles(torch.tensor([12, 5]), + radius=torch.tensor([1])) + visualizer.draw_circles(torch.tensor([1, 5]), + radius=torch.tensor([10])) # test incorrect format with pytest.raises(TypeError): @@ -332,8 +322,8 @@ def test_draw_circles(self): # test length mismatch with pytest.raises(AssertionError): - visualizer.draw_circles( - torch.tensor([[1, 5]]), radius=torch.tensor([1, 2])) + visualizer.draw_circles(torch.tensor([[1, 5]]), + radius=torch.tensor([1, 2])) def test_draw_polygons(self): visualizer = Visualizer(image=self.image) @@ -344,27 +334,24 @@ def test_draw_polygons(self): np.array([[1, 1], [2, 2], [3, 4]]), torch.tensor([[1, 1], [2, 2], [3, 4]]) ]) - visualizer.draw_polygons( - polygons=[ - np.array([[1, 1], [2, 2], [3, 4]]), - torch.tensor([[1, 1], [2, 2], [3, 4]]) - ], - face_colors=(255, 0, 0), - edge_colors=(255, 0, 0)) - visualizer.draw_polygons( - polygons=[ - np.array([[1, 1], [2, 2], [3, 4]]), - torch.tensor([[1, 1], [2, 2], [3, 4]]) - ], - edge_colors=['r', 'g'], - line_styles='-', - line_widths=[2, 1]) + visualizer.draw_polygons(polygons=[ + np.array([[1, 1], [2, 2], [3, 4]]), + torch.tensor([[1, 1], [2, 2], [3, 4]]) + ], + face_colors=(255, 0, 0), + edge_colors=(255, 0, 0)) + visualizer.draw_polygons(polygons=[ + np.array([[1, 1], [2, 2], [3, 4]]), + torch.tensor([[1, 1], [2, 2], [3, 4]]) + ], + edge_colors=['r', 'g'], + line_styles='-', + line_widths=[2, 1]) # test out of bounds - with pytest.warns( - UserWarning, - match='Warning: The polygon is out of bounds,' - ' the drawn polygon may not be in the image'): + with pytest.warns(UserWarning, + match='Warning: The polygon is out of bounds,' + ' the drawn polygon may not be in the image'): visualizer.draw_polygons(torch.tensor([[1, 1], [2, 2], [16, 4]])) def test_draw_binary_masks(self): @@ -388,8 +375,8 @@ def test_draw_binary_masks(self): # test color dim with pytest.raises(AssertionError): - visualizer.draw_binary_masks( - binary_mask, colors=np.array([1, 22, 4, 45])) + visualizer.draw_binary_masks(binary_mask, + colors=np.array([1, 22, 4, 45])) binary_mask = np.random.randint(0, 2, size=(10, 10)) with pytest.raises(AssertionError): visualizer.draw_binary_masks(binary_mask) @@ -399,15 +386,14 @@ def test_draw_featmap(self): image = np.random.randint(0, 256, size=(3, 3, 3), dtype='uint8') # must be Tensor - with pytest.raises( - AssertionError, - match='`featmap` should be torch.Tensor, but got ' - ""): + with pytest.raises(AssertionError, + match='`featmap` should be torch.Tensor, but got ' + ""): visualizer.draw_featmap(np.ones((3, 3, 3))) # test tensor format - with pytest.raises( - AssertionError, match='Input dimension must be 3, but got 4'): + with pytest.raises(AssertionError, + match='Input dimension must be 3, but got 4'): visualizer.draw_featmap(torch.randn(1, 1, 3, 3)) # test overlaid_image shape @@ -415,29 +401,29 @@ def test_draw_featmap(self): visualizer.draw_featmap(torch.randn(1, 4, 3), overlaid_image=image) # test resize_shape - featmap = visualizer.draw_featmap( - torch.randn(1, 4, 3), resize_shape=(6, 7)) + featmap = visualizer.draw_featmap(torch.randn(1, 4, 3), + resize_shape=(6, 7)) assert featmap.shape[:2] == (6, 7) - featmap = visualizer.draw_featmap( - torch.randn(1, 4, 3), overlaid_image=image, resize_shape=(6, 7)) + featmap = visualizer.draw_featmap(torch.randn(1, 4, 3), + overlaid_image=image, + resize_shape=(6, 7)) assert featmap.shape[:2] == (6, 7) # test channel_reduction parameter # mode only supports 'squeeze_mean' and 'select_max' with pytest.raises(AssertionError): - visualizer.draw_featmap( - torch.randn(2, 3, 3), channel_reduction='xx') + visualizer.draw_featmap(torch.randn(2, 3, 3), + channel_reduction='xx') - featmap = visualizer.draw_featmap( - torch.randn(2, 3, 3), channel_reduction='squeeze_mean') + featmap = visualizer.draw_featmap(torch.randn(2, 3, 3), + channel_reduction='squeeze_mean') assert featmap.shape[:2] == (3, 3) - featmap = visualizer.draw_featmap( - torch.randn(2, 3, 3), channel_reduction='select_max') + featmap = visualizer.draw_featmap(torch.randn(2, 3, 3), + channel_reduction='select_max') assert featmap.shape[:2] == (3, 3) - featmap = visualizer.draw_featmap( - torch.randn(2, 4, 3), - overlaid_image=image, - channel_reduction='select_max') + featmap = visualizer.draw_featmap(torch.randn(2, 4, 3), + overlaid_image=image, + channel_reduction='select_max') assert featmap.shape[:2] == (3, 3) # test topk parameter @@ -448,53 +434,54 @@ def test_draw_featmap(self): 'dimension you input is 6, you can use the ' 'channel_reduction parameter or set topk ' 'greater than 0 to solve the error'): - visualizer.draw_featmap( - torch.randn(6, 3, 3), channel_reduction=None, topk=0) + visualizer.draw_featmap(torch.randn(6, 3, 3), + channel_reduction=None, + topk=0) - featmap = visualizer.draw_featmap( - torch.randn(6, 3, 3), channel_reduction='select_max', topk=10) + featmap = visualizer.draw_featmap(torch.randn(6, 3, 3), + channel_reduction='select_max', + topk=10) assert featmap.shape[:2] == (3, 3) - featmap = visualizer.draw_featmap( - torch.randn(1, 4, 3), channel_reduction=None, topk=-1) + featmap = visualizer.draw_featmap(torch.randn(1, 4, 3), + channel_reduction=None, + topk=-1) assert featmap.shape[:2] == (4, 3) - featmap = visualizer.draw_featmap( - torch.randn(3, 4, 3), - overlaid_image=image, - channel_reduction=None, - topk=-1) + featmap = visualizer.draw_featmap(torch.randn(3, 4, 3), + overlaid_image=image, + channel_reduction=None, + topk=-1) assert featmap.shape[:2] == (3, 3) - featmap = visualizer.draw_featmap( - torch.randn(6, 3, 3), - channel_reduction=None, - topk=4, - arrangement=(2, 2)) + featmap = visualizer.draw_featmap(torch.randn(6, 3, 3), + channel_reduction=None, + topk=4, + arrangement=(2, 2)) assert featmap.shape[:2] == (6, 6) - featmap = visualizer.draw_featmap( - torch.randn(6, 3, 3), - channel_reduction=None, - topk=4, - arrangement=(1, 4)) + featmap = visualizer.draw_featmap(torch.randn(6, 3, 3), + channel_reduction=None, + topk=4, + arrangement=(1, 4)) assert featmap.shape[:2] == (3, 12) with pytest.raises( AssertionError, match='The product of row and col in the `arrangement` ' 'is less than topk, please set ' 'the `arrangement` correctly'): - visualizer.draw_featmap( - torch.randn(6, 3, 3), - channel_reduction=None, - topk=4, - arrangement=(1, 2)) + visualizer.draw_featmap(torch.randn(6, 3, 3), + channel_reduction=None, + topk=4, + arrangement=(1, 2)) # test gray - featmap = visualizer.draw_featmap( - torch.randn(6, 3, 3), - overlaid_image=np.random.randint( - 0, 256, size=(3, 3), dtype='uint8'), - channel_reduction=None, - topk=4, - arrangement=(2, 2)) + featmap = visualizer.draw_featmap(torch.randn(6, 3, 3), + overlaid_image=np.random.randint( + 0, + 256, + size=(3, 3), + dtype='uint8'), + channel_reduction=None, + topk=4, + arrangement=(2, 2)) assert featmap.shape[:2] == (6, 6) def test_chain_call(self): @@ -509,17 +496,17 @@ def test_chain_call(self): draw_binary_masks(binary_mask) def test_get_backend(self): - visualizer = Visualizer( - image=self.image, - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(image=self.image, + vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') for name in ['mock1', 'mock2']: assert isinstance(visualizer.get_backend(name), MockVisBackend) def test_add_config(self): - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) visualizer.add_config(cfg) @@ -527,9 +514,9 @@ def test_add_config(self): assert visualizer.get_backend(name)._add_config is True def test_add_graph(self): - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') class Model(nn.Module): @@ -546,26 +533,26 @@ def forward(self, x, y=None): def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') visualizer.add_image('img', image) for name in ['mock1', 'mock2']: assert visualizer.get_backend(name)._add_image is True def test_add_scalar(self): - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') visualizer.add_scalar('map', 0.9, step=0) for name in ['mock1', 'mock2']: assert visualizer.get_backend(name)._add_scalar is True def test_add_scalars(self): - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') input_dict = {'map': 0.7, 'acc': 0.9} visualizer.add_scalars(input_dict) for name in ['mock1', 'mock2']: @@ -597,51 +584,44 @@ def test_show(self): patch('mmengine.visualization.visualizer.wait_continue', wait_continue): # test default backend - visualizer.show( - drawn_img=img, - win_name='test_show', - wait_time=0, - backend='matplotlib') + visualizer.show(drawn_img=img, + win_name='test_show', + wait_time=0, + backend='matplotlib') assert hasattr(visualizer, 'manager') calls = [ - call( - visualizer.manager.canvas.figure, - timeout=0, - continue_key=' ') + call(visualizer.manager.canvas.figure, + timeout=0, + continue_key=' ') ] wait_continue.assert_has_calls(calls) # matplotlib backend - visualizer.show( - drawn_img=img, - win_name='test_show', - wait_time=0, - backend='matplotlib') + visualizer.show(drawn_img=img, + win_name='test_show', + wait_time=0, + backend='matplotlib') assert hasattr(visualizer, 'manager') calls = [ - call( - visualizer.manager.canvas.figure, - timeout=0, - continue_key=' '), - call( - visualizer.manager.canvas.figure, - timeout=0, - continue_key=' ') + call(visualizer.manager.canvas.figure, + timeout=0, + continue_key=' '), + call(visualizer.manager.canvas.figure, + timeout=0, + continue_key=' ') ] wait_continue.assert_has_calls(calls) # cv2 backend - visualizer.show( - drawn_img=img, - win_name='test_show', - wait_time=0, - backend='cv2') + visualizer.show(drawn_img=img, + win_name='test_show', + wait_time=0, + backend='cv2') cv2.imshow.assert_called_once_with(str(id(visualizer)), img) # unknown backend with pytest.raises(ValueError): - visualizer.show( - drawn_img=img, - win_name='test_show', - wait_time=0, - backend='unknown') + visualizer.show(drawn_img=img, + win_name='test_show', + wait_time=0, + backend='unknown')