From d2ccc44a2c8e5d49bb26187aff42f2abc90aee28 Mon Sep 17 00:00:00 2001 From: LALBJ <40877073+LALBJ@users.noreply.github.com> Date: Wed, 23 Aug 2023 10:45:18 +0800 Subject: [PATCH 1/9] [CodeCamp2023-584]Support DINO self-supervised learning in project (#1756) * feat: impelemt DINO * chore: delete debug code * chore: impplement pre-commit * fix: fix imported package * chore: pre-commit check --- projects/dino/README.md | 26 +++++ ..._vit-base-p16_8xb64-amp-coslr-100e_in1k.py | 104 ++++++++++++++++++ projects/dino/dataset/__init__.py | 1 + projects/dino/dataset/transform/__init__.py | 3 + projects/dino/dataset/transform/processing.py | 91 +++++++++++++++ projects/dino/engine/__init__.py | 1 + projects/dino/engine/hooks/__init__.py | 3 + .../hooks/dino_teacher_temp_warmup_hook.py | 33 ++++++ projects/dino/models/__init__.py | 3 + projects/dino/models/algorithm/__init__.py | 3 + projects/dino/models/algorithm/dino.py | 82 ++++++++++++++ projects/dino/models/head/__init__.py | 3 + projects/dino/models/head/dino_head.py | 69 ++++++++++++ projects/dino/models/neck/__init__.py | 3 + projects/dino/models/neck/dino_neck.py | 41 +++++++ projects/dino/tools/dist_train.sh | 19 ++++ projects/dino/tools/slurm_train.sh | 23 ++++ projects/dino/tools/train.py | 104 ++++++++++++++++++ 18 files changed, 612 insertions(+) create mode 100644 projects/dino/README.md create mode 100644 projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py create mode 100644 projects/dino/dataset/__init__.py create mode 100644 projects/dino/dataset/transform/__init__.py create mode 100644 projects/dino/dataset/transform/processing.py create mode 100644 projects/dino/engine/__init__.py create mode 100644 projects/dino/engine/hooks/__init__.py create mode 100644 projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py create mode 100644 projects/dino/models/__init__.py create mode 100644 projects/dino/models/algorithm/__init__.py create mode 100644 projects/dino/models/algorithm/dino.py create mode 100644 projects/dino/models/head/__init__.py create mode 100644 projects/dino/models/head/dino_head.py create mode 100644 projects/dino/models/neck/__init__.py create mode 100644 projects/dino/models/neck/dino_neck.py create mode 100644 projects/dino/tools/dist_train.sh create mode 100644 projects/dino/tools/slurm_train.sh create mode 100644 projects/dino/tools/train.py diff --git a/projects/dino/README.md b/projects/dino/README.md new file mode 100644 index 00000000000..3458fa4cdb3 --- /dev/null +++ b/projects/dino/README.md @@ -0,0 +1,26 @@ +# Implementation for DINO + +**NOTE**: We only guarantee correctness of the forward pass, not responsible for full reimplementation. + +First, ensure you are in the root directory of MMPretrain, then you have two choices +to play with DINO in MMPretrain: + +## Slurm + +If you are using a cluster managed by Slurm, you can use the following command to +start your job: + +```shell +GPUS_PER_NODE=8 GPUS=8 CPUS_PER_TASK=16 bash projects/dino/tools/slurm_train.sh mm_model dino projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py --amp +``` + +The above command will pre-train the model on a single node with 8 GPUs. + +## PyTorch + +If you are using a single machine, without any cluster management software, you can use the following command + +```shell +NNODES=1 bash projects/dino/tools/dist_train.sh projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py 8 +--amp +``` diff --git a/projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py b/projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py new file mode 100644 index 00000000000..d4a1c240218 --- /dev/null +++ b/projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py @@ -0,0 +1,104 @@ +model = dict( + type='DINO', + data_preprocessor=dict( + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='mmpretrain.VisionTransformer', arch='b', patch_size=16), + neck=dict( + type='DINONeck', + in_channels=768, + out_channels=65536, + hidden_channels=2048, + bottleneck_channels=256), + head=dict( + type='DINOHead', + out_channels=65536, + num_crops=10, + student_temp=0.1, + center_momentum=0.9)) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='DINOMultiCrop', + global_crops_scale=(0.4, 1.0), + local_crops_scale=(0.05, 0.4), + local_crops_number=8), + dict(type='PackInputs') +] +train_dataloader = dict( + batch_size=32, + num_workers=16, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type='mmpretrain.ImageNet', + data_root='/data/imagenet/', + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline, + )) +optimizer = dict(type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05) +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=dict( + type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05), + paramwise_cfg=dict( + custom_keys=dict( + ln=dict(decay_mult=0.0), + bias=dict(decay_mult=0.0), + pos_embed=dict(decay_mult=0.0), + mask_token=dict(decay_mult=0.0), + cls_token=dict(decay_mult=0.0))), + loss_scale='dynamic') +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-09, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=90, + by_epoch=True, + begin=10, + end=100, + convert_to_iter_based=True) +] +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100) +default_scope = 'mmpretrain' +default_hooks = dict( + runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=100), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1), + sampler_seed=dict(type='DistSamplerSeedHook')) +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl')) +log_processor = dict( + window_size=10, + custom_cfg=[dict(data_src='', method='mean', window_size='global')]) +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='UniversalVisualizer', + vis_backends=[dict(type='LocalVisBackend')], + name='visualizer') +log_level = 'INFO' +load_from = None +resume = True +randomness = dict(seed=2, diff_rank_seed=True) +custom_hooks = [ + dict( + type='DINOTeacherTempWarmupHook', + warmup_teacher_temp=0.04, + teacher_temp=0.04, + teacher_temp_warmup_epochs=0, + max_epochs=100) +] diff --git a/projects/dino/dataset/__init__.py b/projects/dino/dataset/__init__.py new file mode 100644 index 00000000000..da65f2853ad --- /dev/null +++ b/projects/dino/dataset/__init__.py @@ -0,0 +1 @@ +from .transform import * # noqa: F401,F403 diff --git a/projects/dino/dataset/transform/__init__.py b/projects/dino/dataset/transform/__init__.py new file mode 100644 index 00000000000..00dacb3f3c9 --- /dev/null +++ b/projects/dino/dataset/transform/__init__.py @@ -0,0 +1,3 @@ +from .processing import DINOMultiCrop + +__all__ = ['DINOMultiCrop'] diff --git a/projects/dino/dataset/transform/processing.py b/projects/dino/dataset/transform/processing.py new file mode 100644 index 00000000000..df4bf0be9dd --- /dev/null +++ b/projects/dino/dataset/transform/processing.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random + +from mmcv.transforms import RandomApply # noqa: E501 +from mmcv.transforms import BaseTransform, Compose, RandomFlip, RandomGrayscale + +from mmpretrain.datasets.transforms import (ColorJitter, GaussianBlur, + RandomResizedCrop, Solarize) +from mmpretrain.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class DINOMultiCrop(BaseTransform): + """Multi-crop transform for DINO. + + This module applies the multi-crop transform for DINO. + + Args: + global_crops_scale (int): Scale of global crops. + local_crops_scale (int): Scale of local crops. + local_crops_number (int): Number of local crops. + """ + + def __init__(self, global_crops_scale: int, local_crops_scale: int, + local_crops_number: int) -> None: + super().__init__() + self.global_crops_scale = global_crops_scale + self.local_crops_scale = local_crops_scale + + flip_and_color_jitter = Compose([ + RandomFlip(prob=0.5, direction='horizontal'), + RandomApply([ + ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1) + ], + prob=0.8), + RandomGrayscale( + prob=0.2, + keep_channels=True, + channel_weights=(0.114, 0.587, 0.2989), + ) + ]) + + self.global_transform_1 = Compose([ + RandomResizedCrop( + 224, + crop_ratio_range=global_crops_scale, + interpolation='bicubic'), + flip_and_color_jitter, + GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), + ]) + + self.global_transform_2 = Compose([ + RandomResizedCrop( + 224, + crop_ratio_range=global_crops_scale, + interpolation='bicubic'), + flip_and_color_jitter, + GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), + Solarize(thr=128, prob=0.2), + ]) + + self.local_crops_number = local_crops_number + self.local_transform = Compose([ + RandomResizedCrop( + 96, + crop_ratio_range=local_crops_scale, + interpolation='bicubic'), + flip_and_color_jitter, + GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)), + ]) + + def transform(self, results: dict) -> dict: + ori_img = results['img'] + crops = [] + results['img'] = ori_img + crops.append(self.global_transform_1(results)['img']) + results['img'] = ori_img + crops.append(self.global_transform_2(results)['img']) + for _ in range(self.local_crops_number): + results['img'] = ori_img + crops.append(self.local_transform(results)['img']) + results['img'] = crops + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(global_crops_scale = {self.global_crops_scale}, ' + repr_str += f'local_crops_scale = {self.local_crops_scale}, ' + repr_str += f'local_crop_number = {self.local_crops_number})' + return repr_str diff --git a/projects/dino/engine/__init__.py b/projects/dino/engine/__init__.py new file mode 100644 index 00000000000..41422545e61 --- /dev/null +++ b/projects/dino/engine/__init__.py @@ -0,0 +1 @@ +from .hooks import * # noqa diff --git a/projects/dino/engine/hooks/__init__.py b/projects/dino/engine/hooks/__init__.py new file mode 100644 index 00000000000..df43c492e52 --- /dev/null +++ b/projects/dino/engine/hooks/__init__.py @@ -0,0 +1,3 @@ +from .dino_teacher_temp_warmup_hook import DINOTeacherTempWarmupHook + +__all__ = ['DINOTeacherTempWarmupHook'] diff --git a/projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py b/projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py new file mode 100644 index 00000000000..d66b0250e72 --- /dev/null +++ b/projects/dino/engine/hooks/dino_teacher_temp_warmup_hook.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from mmengine.hooks import Hook + +from mmpretrain.registry import HOOKS + + +@HOOKS.register_module() +class DINOTeacherTempWarmupHook(Hook): + """Warmup teacher temperature for DINO. + + This hook warmups the temperature for teacher to stabilize the training + process. + + Args: + warmup_teacher_temp (float): Warmup temperature for teacher. + teacher_temp (float): Temperature for teacher. + teacher_temp_warmup_epochs (int): Warmup epochs for teacher + temperature. + max_epochs (int): Maximum epochs for training. + """ + + def __init__(self, warmup_teacher_temp: float, teacher_temp: float, + teacher_temp_warmup_epochs: int, max_epochs: int) -> None: + super().__init__() + self.teacher_temps = np.concatenate( + (np.linspace(warmup_teacher_temp, teacher_temp, + teacher_temp_warmup_epochs), + np.ones(max_epochs - teacher_temp_warmup_epochs) * teacher_temp)) + + def before_train_epoch(self, runner) -> None: + runner.model.module.head.teacher_temp = self.teacher_temps[ + runner.epoch] diff --git a/projects/dino/models/__init__.py b/projects/dino/models/__init__.py new file mode 100644 index 00000000000..49d014874ad --- /dev/null +++ b/projects/dino/models/__init__.py @@ -0,0 +1,3 @@ +from .algorithm import * # noqa +from .head import * # noqa +from .neck import * # noqa diff --git a/projects/dino/models/algorithm/__init__.py b/projects/dino/models/algorithm/__init__.py new file mode 100644 index 00000000000..1125b63f851 --- /dev/null +++ b/projects/dino/models/algorithm/__init__.py @@ -0,0 +1,3 @@ +from .dino import DINO + +__all__ = ['DINO'] diff --git a/projects/dino/models/algorithm/dino.py b/projects/dino/models/algorithm/dino.py new file mode 100644 index 00000000000..2d78922f1f6 --- /dev/null +++ b/projects/dino/models/algorithm/dino.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from torch import nn + +from mmpretrain.models import BaseSelfSupervisor, CosineEMA +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class DINO(BaseSelfSupervisor): + """Implementation for DINO. + + This module is proposed in `DINO: Emerging Properties in Self-Supervised + Vision Transformers `_. + + Args: + backbone (dict): Config for backbone. + neck (dict): Config for neck. + head (dict): Config for head. + pretrained (str, optional): Path for pretrained model. + Defaults to None. + base_momentum (float, optional): Base momentum for momentum update. + Defaults to 0.99. + data_preprocessor (dict, optional): Config for data preprocessor. + Defaults to None. + init_cfg (list[dict] | dict, optional): Config for initialization. + Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: dict, + head: dict, + pretrained: Optional[str] = None, + base_momentum: float = 0.99, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # create momentum model + self.teacher = CosineEMA( + nn.Sequential(self.backbone, self.neck), momentum=base_momentum) + # weight normalization layer + self.neck.last_layer = nn.utils.weight_norm(self.neck.last_layer) + self.neck.last_layer.weight_g.data.fill_(1) + self.neck.last_layer.weight_g.requires_grad = False + self.teacher.module[1].last_layer = nn.utils.weight_norm( + self.teacher.module[1].last_layer) + self.teacher.module[1].last_layer.weight_g.data.fill_(1) + self.teacher.module[1].last_layer.weight_g.requires_grad = False + + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + global_crops = torch.cat(inputs[:2]) + local_crops = torch.cat(inputs[2:]) + # teacher forward + teacher_output = self.teacher(global_crops) + + # student forward global + student_output_global = self.backbone(global_crops) + student_output_global = self.neck(student_output_global) + + # student forward local + student_output_local = self.backbone(local_crops) + student_output_local = self.neck(student_output_local) + + student_output = torch.cat( + (student_output_global, student_output_local)) + + # compute loss + loss = self.head(student_output, teacher_output) + + return dict(loss=loss) diff --git a/projects/dino/models/head/__init__.py b/projects/dino/models/head/__init__.py new file mode 100644 index 00000000000..fe31e084cd3 --- /dev/null +++ b/projects/dino/models/head/__init__.py @@ -0,0 +1,3 @@ +from .dino_head import DINOHead + +__all__ = ['DINOHead'] diff --git a/projects/dino/models/head/dino_head.py b/projects/dino/models/head/dino_head.py new file mode 100644 index 00000000000..e817bfade38 --- /dev/null +++ b/projects/dino/models/head/dino_head.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from mmengine.dist import all_reduce, get_world_size +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class DINOHead(BaseModule): + """Implementation for DINO head. + + This module is proposed in `DINO: Emerging Properties in Self-Supervised + Vision Transformers `_. + + Args: + out_channels (int): Output channels of the head. + num_crops (int): Number of crops. + student_temp (float): Temperature for student output. + center_momentum (float): Momentum for center update. + """ + + def __init__(self, out_channels: int, num_crops: int, student_temp: float, + center_momentum: float) -> None: + super().__init__() + self.student_temp = student_temp + self.teacher_temp = 0 + self.center_momentum = center_momentum + self.num_crops = num_crops + self.register_buffer('center', torch.zeros(1, out_channels)) + + def forward(self, student_output: torch.Tensor, + teacher_output: torch.Tensor) -> torch.Tensor: + + current_teacher_output = teacher_output + student_output = student_output / self.student_temp + student_output = student_output.chunk(self.num_crops, dim=0) + + # teacher centering and sharpening + teacher_output = F.softmax( + (teacher_output - self.center) / self.teacher_temp, dim=-1) + teacher_output = teacher_output.detach().chunk(2, dim=0) + + total_loss = 0 + n_loss_terms = 0 + + for i in range(len(teacher_output)): + for j in range(len(student_output)): + if i == j: + continue + total_loss += (-teacher_output[i] * + student_output[j].log_softmax(dim=-1)).sum( + dim=-1).mean() + n_loss_terms += 1 + total_loss /= n_loss_terms + self.update_center(current_teacher_output) + return total_loss + + @torch.no_grad() + def update_center(self, teacher_output: torch.Tensor) -> None: + + batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + all_reduce(batch_center) + batch_center = batch_center / (len(teacher_output) * get_world_size()) + + # ema update batch center + self.center = self.center * self.center_momentum + batch_center * ( + 1 - self.center_momentum) diff --git a/projects/dino/models/neck/__init__.py b/projects/dino/models/neck/__init__.py new file mode 100644 index 00000000000..e5f4aadb09d --- /dev/null +++ b/projects/dino/models/neck/__init__.py @@ -0,0 +1,3 @@ +from .dino_neck import DINONeck + +__all__ = ['DINONeck'] diff --git a/projects/dino/models/neck/dino_neck.py b/projects/dino/models/neck/dino_neck.py new file mode 100644 index 00000000000..8d8881ea24a --- /dev/null +++ b/projects/dino/models/neck/dino_neck.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class DINONeck(BaseModule): + """Implementation for DINO neck. + + This module is proposed in `DINO: Emerging Properties in Self-Supervised + Vision Transformers `_. + + Args: + in_channels (int): Input channels. + hidden_channels (int): Hidden channels. + out_channels (int): Output channels. + bottleneck_channels (int): Bottleneck channels. + """ + + def __init__(self, in_channels: int, hidden_channels: int, + out_channels: int, bottleneck_channels: int) -> None: + super().__init__() + self.mlp = nn.Sequential(*[ + nn.Linear(in_channels, hidden_channels), + nn.GELU(), + nn.Linear(hidden_channels, hidden_channels), + nn.GELU(), + nn.Linear(hidden_channels, bottleneck_channels), + ]) + + self.last_layer = nn.Linear( + bottleneck_channels, out_channels, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(x[0]) + x = nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x diff --git a/projects/dino/tools/dist_train.sh b/projects/dino/tools/dist_train.sh new file mode 100644 index 00000000000..3fca7641dec --- /dev/null +++ b/projects/dino/tools/dist_train.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/train.py \ + $CONFIG \ + --launcher pytorch ${@:3} diff --git a/projects/dino/tools/slurm_train.sh b/projects/dino/tools/slurm_train.sh new file mode 100644 index 00000000000..7e2ad297d84 --- /dev/null +++ b/projects/dino/tools/slurm_train.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +GPUS=${GPUS:-8} +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +SRUN_ARGS=${SRUN_ARGS:-""} +PY_ARGS=${@:4} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u projects/dino/tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} diff --git a/projects/dino/tools/train.py b/projects/dino/tools/train.py new file mode 100644 index 00000000000..b9482c3b75a --- /dev/null +++ b/projects/dino/tools/train.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp + +from dataset import * # noqa: F401,F403 +from engine import * # noqa: F401,F403 +from mmengine.config import Config, DictAction +from mmengine.runner import Runner +from models.algorithm import * # noqa: F401,F403 +from models.head import * # noqa: F401,F403 +from models.neck import * # noqa: F401,F403 + +from mmpretrain.utils import register_all_modules + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a model') + parser.add_argument('config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--resume', + nargs='?', + type=str, + const='auto', + help='If specify checkpint path, resume from it, while if not ' + 'specify, try to auto resume from the latest checkpoint ' + 'in the work directory.') + parser.add_argument( + '--amp', + action='store_true', + help='enable automatic-mixed-precision training') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def main(): + args = parse_args() + + # register all modules in mmpretrain into the registries + # do not init the default scope here because it will be init in the runner + register_all_modules(init_default_scope=False) + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + work_type = args.config.split('/')[1] + cfg.work_dir = osp.join('./work_dirs', work_type, + osp.splitext(osp.basename(args.config))[0]) + + # enable automatic-mixed-precision training + if args.amp is True: + optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper') + assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \ + '`--amp` is not supported custom optimizer wrapper type ' \ + f'`{optim_wrapper}.' + cfg.optim_wrapper.type = 'AmpOptimWrapper' + cfg.optim_wrapper.setdefault('loss_scale', 'dynamic') + + # resume training + if args.resume == 'auto': + cfg.resume = True + cfg.load_from = None + elif args.resume is not None: + cfg.resume = True + cfg.load_from = args.resume + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start training + runner.train() + + +if __name__ == '__main__': + main() From 845b462190b5b4d5b785ff16251c28e6713d8550 Mon Sep 17 00:00:00 2001 From: DE009 <57087096+DE009@users.noreply.github.com> Date: Fri, 1 Sep 2023 17:54:18 +0800 Subject: [PATCH 2/9] [CodeCamp2023-340] New Version of config Adapting MobileNet Algorithm (#1774) * add new config adapting MobileNetV2,V3 * add base model config for mobile net v3, modified all training configs of mobile net v3 inherit from the base model config * removed directory _base_/models/mobilenet_v3 --- .../configs/_base_/datasets/cifar10_bs16.py | 52 ++++++++++++ .../_base_/datasets/imagenet_bs128_mbv3.py | 75 ++++++++++++++++ .../datasets/imagenet_bs32_pil_resize.py | 60 +++++++++++++ .../configs/_base_/models/mobilenet_v2_1x.py | 17 ++++ .../_base_/models/mobilenet_v3_small.py | 25 ++++++ .../configs/_base_/schedules/cifar10_bs128.py | 20 +++++ .../schedules/imagenet_bs256_epochstep.py | 20 +++++ .../mobilenet_v2/mobilenet_v2_8xb32_in1k.py | 9 ++ .../mobilenet_v3_large_8xb128_in1k.py | 40 +++++++++ .../mobilenet_v3_small_050_8xb128_in1k.py | 85 +++++++++++++++++++ .../mobilenet_v3_small_075_8xb128_in1k.py | 83 ++++++++++++++++++ .../mobilenet_v3_small_8xb128_in1k.py | 34 ++++++++ .../mobilenet_v3_small_8xb16_cifar10.py | 34 ++++++++ 13 files changed, 554 insertions(+) create mode 100644 mmpretrain/configs/_base_/datasets/cifar10_bs16.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py create mode 100644 mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py create mode 100644 mmpretrain/configs/_base_/models/mobilenet_v2_1x.py create mode 100644 mmpretrain/configs/_base_/models/mobilenet_v3_small.py create mode 100644 mmpretrain/configs/_base_/schedules/cifar10_bs128.py create mode 100644 mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py create mode 100644 mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py create mode 100644 mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py create mode 100644 mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py create mode 100644 mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py create mode 100644 mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py create mode 100644 mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py diff --git a/mmpretrain/configs/_base_/datasets/cifar10_bs16.py b/mmpretrain/configs/_base_/datasets/cifar10_bs16.py new file mode 100644 index 00000000000..3737dbee9a6 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/cifar10_bs16.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import CIFAR10, PackInputs, RandomCrop, RandomFlip +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = CIFAR10 +data_preprocessor = dict( + num_classes=10, + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # loaded images are already RGB format + to_rgb=False) + +train_pipeline = [ + dict(type=RandomCrop, crop_size=32, padding=4), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/cifar10', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=16, + num_workers=2, + dataset=dict( + type=dataset_type, + data_root='data/cifar10/', + split='test', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, )) + +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py b/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py new file mode 100644 index 00000000000..cf0aa629d72 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs128_mbv3.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (AutoAugment, CenterCrop, ImageNet, + LoadImageFromFile, PackInputs, RandomErasing, + RandomFlip, RandomResizedCrop, ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in bgr_mean])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=128, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=128, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py b/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py new file mode 100644 index 00000000000..f911bc20ff6 --- /dev/null +++ b/mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.dataset import DefaultSampler + +from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile, + PackInputs, RandomFlip, RandomResizedCrop, + ResizeEdge) +from mmpretrain.evaluation import Accuracy + +# dataset settings +dataset_type = ImageNet +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=RandomResizedCrop, scale=224, backend='pillow'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='train', + pipeline=train_pipeline), + sampler=dict(type=DefaultSampler, shuffle=True), +) + +val_dataloader = dict( + batch_size=32, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type=DefaultSampler, shuffle=False), +) +val_evaluator = dict(type=Accuracy, topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py b/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py new file mode 100644 index 00000000000..17dbb9fdd88 --- /dev/null +++ b/mmpretrain/configs/_base_/models/mobilenet_v2_1x.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, LinearClsHead, MobileNetV2) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict(type=MobileNetV2, widen_factor=1.0), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=LinearClsHead, + num_classes=1000, + in_channels=1280, + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5), + )) diff --git a/mmpretrain/configs/_base_/models/mobilenet_v3_small.py b/mmpretrain/configs/_base_/models/mobilenet_v3_small.py new file mode 100644 index 00000000000..83edab59206 --- /dev/null +++ b/mmpretrain/configs/_base_/models/mobilenet_v3_small.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.model.weight_init import NormalInit +from torch.nn.modules.activation import Hardswish + +from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling, + ImageClassifier, MobileNetV3, + StackedLinearClsHead) + +# model settings +model = dict( + type=ImageClassifier, + backbone=dict(type=MobileNetV3, arch='small'), + neck=dict(type=GlobalAveragePooling), + head=dict( + type=StackedLinearClsHead, + num_classes=1000, + in_channels=576, + mid_channels=[1024], + dropout_rate=0.2, + act_cfg=dict(type=Hardswish), + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + init_cfg=dict( + type=NormalInit, layer='Linear', mean=0., std=0.01, bias=0.), + topk=(1, 5))) diff --git a/mmpretrain/configs/_base_/schedules/cifar10_bs128.py b/mmpretrain/configs/_base_/schedules/cifar10_bs128.py new file mode 100644 index 00000000000..8ab749e8b64 --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/cifar10_bs128.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import MultiStepLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001)) +# learning policy +param_scheduler = dict( + type=MultiStepLR, by_epoch=True, milestones=[100, 150], gamma=0.1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=128) diff --git a/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py b/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py new file mode 100644 index 00000000000..9d245ebb9c3 --- /dev/null +++ b/mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.optim import StepLR +from torch.optim import SGD + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=SGD, lr=0.045, momentum=0.9, weight_decay=0.00004)) + +# learning policy +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=1, gamma=0.98) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=256) diff --git a/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py b/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py new file mode 100644 index 00000000000..79eec635501 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v2/mobilenet_v2_8xb32_in1k.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.imagenet_bs32_pil_resize import * + from .._base_.default_runtime import * + from .._base_.models.mobilenet_v2_1x import * + from .._base_.schedules.imagenet_bs256_epochstep import * diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py new file mode 100644 index 00000000000..3f1bee1c132 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. + +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict(arch='large'), + head=dict(in_channels=960, mid_channels=[1280]), + )) +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py new file mode 100644 index 00000000000..50e1ffc6709 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_050_8xb128_in1k.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict( + arch='small_050', + norm_cfg=dict(type=BatchNorm2d, eps=1e-5, momentum=0.1)), + head=dict(in_channels=288), + )) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in [103.53, 116.28, 123.675]])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader.merge(dict(dataset=dict(pipeline=train_pipeline))) + +val_dataloader.merge(dict(dataset=dict(pipeline=test_pipeline))) +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py new file mode 100644 index 00000000000..c8c640cd8a0 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_075_8xb128_in1k.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.optim import RMSprop + +# model settings +model.merge( + dict( + backbone=dict( + arch='small_075', + norm_cfg=dict(type=BatchNorm2d, eps=1e-5, momentum=0.1)), + head=dict(in_channels=432), + )) + +train_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=RandomResizedCrop, + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type=RandomFlip, prob=0.5, direction='horizontal'), + dict( + type=AutoAugment, + policies='imagenet', + hparams=dict(pad_val=[round(x) for x in [103.53, 116.28, 123.675]])), + dict( + type=RandomErasing, + erase_prob=0.2, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type=PackInputs), +] + +test_pipeline = [ + dict(type=LoadImageFromFile), + dict( + type=ResizeEdge, + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type=CenterCrop, crop_size=224), + dict(type=PackInputs), +] + +train_dataloader.merge(dict(dataset=dict(pipeline=train_pipeline))) +val_dataloader.merge(dict(dataset=dict(pipeline=test_pipeline))) +test_dataloader = val_dataloader + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py new file mode 100644 index 00000000000..0c220a01d09 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb128_in1k.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification + +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.imagenet_bs128_mbv3 import * + from .._base_.default_runtime import * + +from mmengine.optim import StepLR +from torch.optim import RMSprop + +# schedule settings +optim_wrapper = dict( + optimizer=dict( + type=RMSprop, + lr=0.064, + alpha=0.9, + momentum=0.9, + eps=0.0316, + weight_decay=1e-5)) + +param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973) + +train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. +# base_batch_size = (8 GPUs) x (128 samples per GPU) +auto_scale_lr = dict(base_batch_size=1024) diff --git a/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py new file mode 100644 index 00000000000..0f91ee38243 --- /dev/null +++ b/mmpretrain/configs/mobilenet_v3/mobilenet_v3_small_8xb16_cifar10.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This is a BETA new format config file, and the usage may change recently. +from mmengine.config import read_base + +with read_base(): + from .._base_.models.mobilenet_v3_small import * + from .._base_.datasets.cifar10_bs16 import * + from .._base_.schedules.cifar10_bs128 import * + from .._base_.default_runtime import * + +from mmengine.optim import MultiStepLR + +# model settings +model.merge( + dict( + head=dict( + _delete_=True, + type=StackedLinearClsHead, + num_classes=10, + in_channels=576, + mid_channels=[1280], + act_cfg=dict(type=Hardswish), + loss=dict(type=CrossEntropyLoss, loss_weight=1.0), + topk=(1, 5)))) +# schedule settings +param_scheduler.merge( + dict( + type=MultiStepLR, + by_epoch=True, + milestones=[120, 170], + gamma=0.1, + )) + +train_cfg.merge(dict(by_epoch=True, max_epochs=200)) From bb59c9ad825debef83cd2c6d02f06a5914d2b139 Mon Sep 17 00:00:00 2001 From: Coobiw <48615375+Coobiw@users.noreply.github.com> Date: Mon, 4 Sep 2023 10:30:28 +0800 Subject: [PATCH 3/9] [Feature] Implement of Zero-Shot CLIP Classifier (#1737) * zero-shot CLIP * modify zero-shot clip config * add in1k_sub_prompt(8 prompts) for improvement * add some annotations doc * clip base class & clip_zs sub-class * some modifications of details after review * convert into and use mmpretrain-vit * modify names of some files and directories --- ...clip_vit-base-p16_zeroshot-cls_cifar100.py | 68 ++++ .../clip_vit-base-p16_zeroshot-cls_in1k.py | 69 ++++ ...lip_vit-large-p14_zeroshot-cls_cifar100.py | 68 ++++ .../clip_vit-large-p14_zeroshot-cls_in1k.py | 69 ++++ mmpretrain/datasets/categories.py | 221 +++++++++++ mmpretrain/models/multimodal/__init__.py | 4 +- mmpretrain/models/multimodal/clip/__init__.py | 5 + mmpretrain/models/multimodal/clip/clip.py | 364 ++++++++++++++++++ .../multimodal/clip/clip_transformer.py | 99 +++++ mmpretrain/models/multimodal/clip/utils.py | 115 ++++++ .../openai-clip_to_mmpretrain-clip.py | 77 ++++ 11 files changed, 1158 insertions(+), 1 deletion(-) create mode 100644 configs/clip/clip_vit-base-p16_zeroshot-cls_cifar100.py create mode 100644 configs/clip/clip_vit-base-p16_zeroshot-cls_in1k.py create mode 100644 configs/clip/clip_vit-large-p14_zeroshot-cls_cifar100.py create mode 100644 configs/clip/clip_vit-large-p14_zeroshot-cls_in1k.py create mode 100644 mmpretrain/models/multimodal/clip/__init__.py create mode 100644 mmpretrain/models/multimodal/clip/clip.py create mode 100644 mmpretrain/models/multimodal/clip/clip_transformer.py create mode 100644 mmpretrain/models/multimodal/clip/utils.py create mode 100644 tools/model_converters/openai-clip_to_mmpretrain-clip.py diff --git a/configs/clip/clip_vit-base-p16_zeroshot-cls_cifar100.py b/configs/clip/clip_vit-base-p16_zeroshot-cls_cifar100.py new file mode 100644 index 00000000000..dd684a50a31 --- /dev/null +++ b/configs/clip/clip_vit-base-p16_zeroshot-cls_cifar100.py @@ -0,0 +1,68 @@ +_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, +) + +test_pipeline = [ + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( + type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='CIFAR100', + data_root='data/cifar100', + split='test', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='base', + img_size=224, + patch_size=16, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, + ), + projection=dict(type='CLIPProjection', in_channels=768, out_channels=512), + text_backbone=dict( + type='CLIPTransformer', + width=512, + layers=12, + heads=8, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-base-patch16', + use_fast=False), + vocab_size=49408, + transformer_width=512, + proj_dim=512, + text_prototype='cifar100', + text_prompt='openai_cifar100', + context_length=77, +) diff --git a/configs/clip/clip_vit-base-p16_zeroshot-cls_in1k.py b/configs/clip/clip_vit-base-p16_zeroshot-cls_in1k.py new file mode 100644 index 00000000000..80c4fde82f5 --- /dev/null +++ b/configs/clip/clip_vit-base-p16_zeroshot-cls_in1k.py @@ -0,0 +1,69 @@ +_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=True, +) + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( + type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='ImageNet', + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='base', + img_size=224, + patch_size=16, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, + ), + projection=dict(type='CLIPProjection', in_channels=768, out_channels=512), + text_backbone=dict( + type='CLIPTransformer', + width=512, + layers=12, + heads=8, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-base-patch16', + use_fast=False), + vocab_size=49408, + transformer_width=512, + proj_dim=512, + text_prototype='imagenet', + text_prompt='openai_imagenet_sub', # openai_imagenet, openai_imagenet_sub + context_length=77, +) diff --git a/configs/clip/clip_vit-large-p14_zeroshot-cls_cifar100.py b/configs/clip/clip_vit-large-p14_zeroshot-cls_cifar100.py new file mode 100644 index 00000000000..a6dd7c11412 --- /dev/null +++ b/configs/clip/clip_vit-large-p14_zeroshot-cls_cifar100.py @@ -0,0 +1,68 @@ +_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=False, +) + +test_pipeline = [ + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( + type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='CIFAR100', + data_root='data/cifar100', + split='test', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='large', + img_size=224, + patch_size=14, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, + ), + projection=dict(type='CLIPProjection', in_channels=1024, out_channels=768), + text_backbone=dict( + type='CLIPTransformer', + width=768, + layers=12, + heads=12, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-large-patch14', + use_fast=False), + vocab_size=49408, + transformer_width=768, + proj_dim=768, + text_prototype='cifar100', + text_prompt='openai_cifar100', + context_length=77, +) diff --git a/configs/clip/clip_vit-large-p14_zeroshot-cls_in1k.py b/configs/clip/clip_vit-large-p14_zeroshot-cls_in1k.py new file mode 100644 index 00000000000..10500017a93 --- /dev/null +++ b/configs/clip/clip_vit-large-p14_zeroshot-cls_in1k.py @@ -0,0 +1,69 @@ +_base_ = '../_base_/default_runtime.py' + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + to_rgb=True, +) + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(224, 224), interpolation='bicubic'), + dict( + type='PackInputs', + algorithm_keys=['text'], + meta_keys=['image_id', 'scale_factor'], + ), +] + +train_dataloader = None +test_dataloader = dict( + batch_size=32, + num_workers=8, + dataset=dict( + type='ImageNet', + data_root='data/imagenet', + split='val', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +test_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# schedule settings +train_cfg = None +val_cfg = None +test_cfg = dict() + +# model settings +model = dict( + type='CLIPZeroShot', + vision_backbone=dict( + type='VisionTransformer', + arch='large', + img_size=224, + patch_size=14, + drop_rate=0., + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + pre_norm=True, + ), + projection=dict(type='CLIPProjection', in_channels=1024, out_channels=768), + text_backbone=dict( + type='CLIPTransformer', + width=768, + layers=12, + heads=12, + attn_mask=True, + ), + tokenizer=dict( + type='AutoTokenizer', + name_or_path='openai/clip-vit-large-patch14', + use_fast=False), + vocab_size=49408, + transformer_width=768, + proj_dim=768, + text_prototype='imagenet', + text_prompt='openai_imagenet_sub', # openai_imagenet, openai_imagenet_sub + context_length=77, +) diff --git a/mmpretrain/datasets/categories.py b/mmpretrain/datasets/categories.py index 011ee5c1609..9e75f7953b8 100644 --- a/mmpretrain/datasets/categories.py +++ b/mmpretrain/datasets/categories.py @@ -1438,3 +1438,224 @@ '海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒', '桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼', '柳树', '狼', '女人', '蠕虫') + +IMAGENET_SIMPLE_CATEGORIES = ( + 'tench', 'goldfish', 'great white shark', 'tiger shark', + 'hammerhead shark', 'electric ray', 'stingray', 'rooster', 'hen', + 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', + 'indigo bunting', 'American robin', 'bulbul', 'jay', 'magpie', 'chickadee', + 'American dipper', 'kite (bird of prey)', 'bald eagle', 'vulture', + 'great grey owl', 'fire salamander', 'smooth newt', 'newt', + 'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog', + 'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle', + 'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'green iguana', + 'Carolina anole', 'desert grassland whiptail lizard', 'agama', + 'frilled-necked lizard', 'alligator lizard', 'Gila monster', + 'European green lizard', 'chameleon', 'Komodo dragon', 'Nile crocodile', + 'American alligator', 'triceratops', 'worm snake', 'ring-necked snake', + 'eastern hog-nosed snake', 'smooth green snake', 'kingsnake', + 'garter snake', 'water snake', 'vine snake', 'night snake', + 'boa constrictor', 'African rock python', 'Indian cobra', 'green mamba', + 'sea snake', 'Saharan horned viper', 'eastern diamondback rattlesnake', + 'sidewinder rattlesnake', 'trilobite', 'harvestman', 'scorpion', + 'yellow garden spider', 'barn spider', 'European garden spider', + 'southern black widow', 'tarantula', 'wolf spider', 'tick', 'centipede', + 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peafowl', + 'quail', 'partridge', 'african grey parrot', 'macaw', + 'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater', 'hornbill', + 'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser', + 'goose', 'black swan', 'tusker', 'echidna', 'platypus', 'wallaby', 'koala', + 'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm', + 'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton', + 'chambered nautilus', 'Dungeness crab', 'rock crab', 'fiddler crab', + 'red king crab', 'American lobster', 'spiny lobster', 'crayfish', + 'hermit crab', 'isopod', 'white stork', 'black stork', 'spoonbill', + 'flamingo', 'little blue heron', 'great egret', 'bittern bird', + 'crane bird', 'limpkin', 'common gallinule', 'American coot', 'bustard', + 'ruddy turnstone', 'dunlin', 'common redshank', 'dowitcher', + 'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale', + 'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin', + 'Maltese', 'Pekingese', 'Shih Tzu', 'King Charles Spaniel', 'Papillon', + 'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound', + 'Beagle', 'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound', + 'Treeing Walker Coonhound', 'English foxhound', 'Redbone Coonhound', + 'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet', + 'Ibizan Hound', 'Norwegian Elkhound', 'Otterhound', 'Saluki', + 'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier', + 'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', + 'Kerry Blue Terrier', 'Irish Terrier', 'Norfolk Terrier', + 'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier', + 'Lakeland Terrier', 'Sealyham Terrier', 'Airedale Terrier', + 'Cairn Terrier', 'Australian Terrier', 'Dandie Dinmont Terrier', + 'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer', + 'Standard Schnauzer', 'Scottish Terrier', 'Tibetan Terrier', + 'Australian Silky Terrier', 'Soft-coated Wheaten Terrier', + 'West Highland White Terrier', 'Lhasa Apso', 'Flat-Coated Retriever', + 'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever', + 'Chesapeake Bay Retriever', 'German Shorthaired Pointer', 'Vizsla', + 'English Setter', 'Irish Setter', 'Gordon Setter', 'Brittany dog', + 'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel', + 'Cocker Spaniel', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz', + 'Schipperke', 'Groenendael dog', 'Malinois', 'Briard', 'Australian Kelpie', + 'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie', + 'Border Collie', 'Bouvier des Flandres dog', 'Rottweiler', + 'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher', + 'Greater Swiss Mountain Dog', 'Bernese Mountain Dog', + 'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer', 'Bullmastiff', + 'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. Bernard', 'husky', + 'Alaskan Malamute', 'Siberian Husky', 'Dalmatian', 'Affenpinscher', + 'Basenji', 'pug', 'Leonberger', 'Newfoundland dog', 'Great Pyrenees dog', + 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'brussels griffon', + 'Pembroke Welsh Corgi', 'Cardigan Welsh Corgi', 'Toy Poodle', + 'Miniature Poodle', 'Standard Poodle', + 'Mexican hairless dog (xoloitzcuintli)', 'grey wolf', + 'Alaskan tundra wolf', 'red wolf or maned wolf', 'coyote', 'dingo', + 'dhole', 'African wild dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox', + 'grey fox', 'tabby cat', 'tiger cat', 'Persian cat', 'Siamese cat', + 'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar', + 'lion', 'tiger', 'cheetah', 'brown bear', 'American black bear', + 'polar bear', 'sloth bear', 'mongoose', 'meerkat', 'tiger beetle', + 'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle', + 'dung beetle', 'rhinoceros beetle', 'weevil', 'fly', 'bee', 'ant', + 'grasshopper', 'cricket insect', 'stick insect', 'cockroach', + 'praying mantis', 'cicada', 'leafhopper', 'lacewing', 'dragonfly', + 'damselfly', 'red admiral butterfly', 'ringlet butterfly', + 'monarch butterfly', 'small white butterfly', 'sulphur butterfly', + 'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber', + 'cottontail rabbit', 'hare', 'Angora rabbit', 'hamster', 'porcupine', + 'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'common sorrel horse', + 'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus', 'ox', + 'water buffalo', 'bison', 'ram (adult male sheep)', 'bighorn sheep', + 'Alpine ibex', 'hartebeest', 'impala (antelope)', 'gazelle', + 'arabian camel', 'llama', 'weasel', 'mink', 'European polecat', + 'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', + 'three-toed sloth', 'orangutan', 'gorilla', 'chimpanzee', 'gibbon', + 'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur', + 'black-and-white colobus', 'proboscis monkey', 'marmoset', + 'white-headed capuchin', 'howler monkey', 'titi monkey', + "Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur', + 'indri', 'Asian elephant', 'African bush elephant', 'red panda', + 'giant panda', 'snoek fish', 'eel', 'silver salmon', 'rock beauty fish', + 'clownfish', 'sturgeon', 'gar fish', 'lionfish', 'pufferfish', 'abacus', + 'abaya', 'academic gown', 'accordion', 'acoustic guitar', + 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', + 'amphibious vehicle', 'analog clock', 'apiary', 'apron', 'trash can', + 'assault rifle', 'backpack', 'bakery', 'balance beam', 'balloon', + 'ballpoint pen', 'Band-Aid', 'banjo', 'baluster / handrail', 'barbell', + 'barber chair', 'barbershop', 'barn', 'barometer', 'barrel', 'wheelbarrow', + 'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap', + 'bath towel', 'bathtub', 'station wagon', 'lighthouse', 'beaker', + 'military hat (bearskin or shako)', 'beer bottle', 'beer glass', + 'bell tower', 'baby bib', 'tandem bicycle', 'bikini', 'ring binder', + 'binoculars', 'birdhouse', 'boathouse', 'bobsleigh', 'bolo tie', + 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'hunting bow', + 'bow tie', 'brass memorial plaque', 'bra', 'breakwater', 'breastplate', + 'broom', 'bucket', 'buckle', 'bulletproof vest', 'high-speed train', + 'butcher shop', 'taxicab', 'cauldron', 'candle', 'cannon', 'canoe', + 'can opener', 'cardigan', 'car mirror', 'carousel', 'tool kit', + 'cardboard box / carton', 'car wheel', 'automated teller machine', + 'cassette', 'cassette player', 'castle', 'catamaran', 'CD player', 'cello', + 'mobile phone', 'chain', 'chain-link fence', 'chain mail', 'chainsaw', + 'storage chest', 'chiffonier', 'bell or wind chime', 'china cabinet', + 'Christmas stocking', 'church', 'movie theater', 'cleaver', + 'cliff dwelling', 'cloak', 'clogs', 'cocktail shaker', 'coffee mug', + 'coffeemaker', 'spiral or coil', 'combination lock', 'computer keyboard', + 'candy store', 'container ship', 'convertible', 'corkscrew', 'cornet', + 'cowboy boot', 'cowboy hat', 'cradle', 'construction crane', + 'crash helmet', 'crate', 'infant bed', 'Crock Pot', 'croquet ball', + 'crutch', 'cuirass', 'dam', 'desk', 'desktop computer', + 'rotary dial telephone', 'diaper', 'digital clock', 'digital watch', + 'dining table', 'dishcloth', 'dishwasher', 'disc brake', 'dock', + 'dog sled', 'dome', 'doormat', 'drilling rig', 'drum', 'drumstick', + 'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar', + 'electric locomotive', 'entertainment center', 'envelope', + 'espresso machine', 'face powder', 'feather boa', 'filing cabinet', + 'fireboat', 'fire truck', 'fire screen', 'flagpole', 'flute', + 'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen', + 'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat', + 'garbage truck', 'gas mask or respirator', 'gas pump', 'goblet', 'go-kart', + 'golf ball', 'golf cart', 'gondola', 'gong', 'gown', 'grand piano', + 'greenhouse', 'radiator grille', 'grocery store', 'guillotine', + 'hair clip', 'hair spray', 'half-track', 'hammer', 'hamper', 'hair dryer', + 'hand-held computer', 'handkerchief', 'hard disk drive', 'harmonica', + 'harp', 'combine harvester', 'hatchet', 'holster', 'home theater', + 'honeycomb', 'hook', 'hoop skirt', 'gymnastic horizontal bar', + 'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron', + 'carved pumpkin', 'jeans', 'jeep', 'T-shirt', 'jigsaw puzzle', 'rickshaw', + 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade', + 'laptop computer', 'lawn mower', 'lens cap', 'letter opener', 'library', + 'lifeboat', 'lighter', 'limousine', 'ocean liner', 'lipstick', + 'slip-on shoe', 'lotion', 'music speaker', 'loupe magnifying glass', + 'sawmill', 'magnetic compass', 'messenger bag', 'mailbox', 'tights', + 'one-piece bathing suit', 'manhole cover', 'maraca', 'marimba', 'mask', + 'matchstick', 'maypole', 'maze', 'measuring cup', 'medicine cabinet', + 'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can', + 'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl', + 'mobile home', 'ford model t', 'modem', 'monastery', 'monitor', 'moped', + 'mortar and pestle', 'graduation cap', 'mosque', 'mosquito net', 'vespa', + 'mountain bike', 'tent', 'computer mouse', 'mousetrap', 'moving van', + 'muzzle', 'metal nail', 'neck brace', 'necklace', 'baby pacifier', + 'notebook computer', 'obelisk', 'oboe', 'ocarina', 'odometer', + 'oil filter', 'pipe organ', 'oscilloscope', 'overskirt', 'bullock cart', + 'oxygen mask', 'product packet / packaging', 'paddle', 'paddle wheel', + 'padlock', 'paintbrush', 'pajamas', 'palace', 'pan flute', 'paper towel', + 'parachute', 'parallel bars', 'park bench', 'parking meter', + 'railroad car', 'patio', 'payphone', 'pedestal', 'pencil case', + 'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'plectrum', + 'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank', + 'pill bottle', 'pillow', 'ping-pong ball', 'pinwheel', 'pirate ship', + 'drink pitcher', 'block plane', 'planetarium', 'plastic bag', 'plate rack', + 'farm plow', 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho', + 'pool table', 'soda bottle', 'plant pot', "potter's wheel", 'power drill', + 'prayer rug', 'printer', 'prison', 'missile', 'projector', 'hockey puck', + 'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket', + 'radiator', 'radio', 'radio telescope', 'rain barrel', + 'recreational vehicle', 'fishing casting reel', 'reflex camera', + 'refrigerator', 'remote control', 'restaurant', 'revolver', 'rifle', + 'rocking chair', 'rotisserie', 'eraser', 'rugby ball', + 'ruler measuring stick', 'sneaker', 'safe', 'safety pin', 'salt shaker', + 'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale', + 'school bus', 'schooner', 'scoreboard', 'CRT monitor', 'screw', + 'screwdriver', 'seat belt', 'sewing machine', 'shield', 'shoe store', + 'shoji screen / room divider', 'shopping basket', 'shopping cart', + 'shovel', 'shower cap', 'shower curtain', 'ski', 'balaclava ski mask', + 'sleeping bag', 'slide rule', 'sliding door', 'slot machine', 'snorkel', + 'snowmobile', 'snowplow', 'soap dispenser', 'soccer ball', 'sock', + 'solar thermal collector', 'sombrero', 'soup bowl', 'keyboard space bar', + 'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web', + 'spindle', 'sports car', 'spotlight', 'stage', 'steam locomotive', + 'through arch bridge', 'steel drum', 'stethoscope', 'scarf', 'stone wall', + 'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa', + 'submarine', 'suit', 'sundial', 'sunglasses', 'sunglasses', 'sunscreen', + 'suspension bridge', 'mop', 'sweatshirt', 'swim trunks / shorts', 'swing', + 'electrical switch', 'syringe', 'table lamp', 'tank', 'tape player', + 'teapot', 'teddy bear', 'television', 'tennis ball', 'thatched roof', + 'front curtain', 'thimble', 'threshing machine', 'throne', 'tile roof', + 'toaster', 'tobacco shop', 'toilet seat', 'torch', 'totem pole', + 'tow truck', 'toy store', 'tractor', 'semi-trailer truck', 'tray', + 'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch', + 'trolleybus', 'trombone', 'hot tub', 'turnstile', 'typewriter keyboard', + 'umbrella', 'unicycle', 'upright piano', 'vacuum cleaner', 'vase', + 'vaulted or arched ceiling', 'velvet fabric', 'vending machine', + 'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock', + 'wallet', 'wardrobe', 'military aircraft', 'sink', 'washing machine', + 'water bottle', 'water jug', 'water tower', 'whiskey jug', 'whistle', + 'hair wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle', + 'airplane wing', 'wok', 'wooden spoon', 'wool', 'split-rail fence', + 'shipwreck', 'sailboat', 'yurt', 'website', 'comic book', 'crossword', + 'traffic or street sign', 'traffic light', 'dust jacket', 'menu', 'plate', + 'guacamole', 'consomme', 'hot pot', 'trifle', 'ice cream', 'popsicle', + 'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog', + 'mashed potatoes', 'cabbage', 'broccoli', 'cauliflower', 'zucchini', + 'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber', + 'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith apple', + 'strawberry', 'orange', 'lemon', 'fig', 'pineapple', 'banana', 'jackfruit', + 'cherimoya (custard apple)', 'pomegranate', 'hay', 'carbonara', + 'chocolate syrup', 'dough', 'meatloaf', 'pizza', 'pot pie', 'burrito', + 'red wine', 'espresso', 'tea cup', 'eggnog', 'mountain', 'bubble', 'cliff', + 'coral reef', 'geyser', 'lakeshore', 'promontory', 'sandbar', 'beach', + 'valley', 'volcano', 'baseball player', 'bridegroom', 'scuba diver', + 'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn', 'rose hip', + 'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra', + 'stinkhorn mushroom', 'earth star fungus', 'hen of the woods mushroom', + 'bolete', 'corn cob', 'toilet paper') diff --git a/mmpretrain/models/multimodal/__init__.py b/mmpretrain/models/multimodal/__init__.py index 072c0f84f72..73645f0f5e6 100644 --- a/mmpretrain/models/multimodal/__init__.py +++ b/mmpretrain/models/multimodal/__init__.py @@ -5,6 +5,7 @@ from .blip import * # noqa: F401,F403 from .blip2 import * # noqa: F401,F403 from .chinese_clip import * # noqa: F401, F403 + from .clip import * # noqa: F401, F403 from .flamingo import * # noqa: F401, F403 from .llava import * # noqa: F401, F403 from .minigpt4 import * # noqa: F401, F403 @@ -17,5 +18,6 @@ register_multimodal_placeholder([ 'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption', 'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo', - 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter' + 'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP', + 'CLIPZeroShot' ], MODELS) diff --git a/mmpretrain/models/multimodal/clip/__init__.py b/mmpretrain/models/multimodal/clip/__init__.py new file mode 100644 index 00000000000..f7a117ea7ca --- /dev/null +++ b/mmpretrain/models/multimodal/clip/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..clip.clip import CLIP, CLIPZeroShot +from ..clip.clip_transformer import CLIPProjection, CLIPTransformer + +__all__ = ['CLIP', 'CLIPZeroShot', 'CLIPTransformer', 'CLIPProjection'] diff --git a/mmpretrain/models/multimodal/clip/clip.py b/mmpretrain/models/multimodal/clip/clip.py new file mode 100644 index 00000000000..b509a63b3be --- /dev/null +++ b/mmpretrain/models/multimodal/clip/clip.py @@ -0,0 +1,364 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.model import BaseModel +from torch import nn + +from mmpretrain.datasets.categories import (CIFAR100_CATEGORIES, + IMAGENET_SIMPLE_CATEGORIES) +from mmpretrain.registry import MODELS, TOKENIZER +from mmpretrain.structures import DataSample +from mmpretrain.utils import track_on_main_process +from .utils import (OPENAI_CIFAR100_PROMPT, OPENAI_IMAGENET_PROMPT, + OPENAI_IMAGENET_PROMPT_SUB) + +CIFAR100_CATEGORIES = [' '.join(c.split('_')) for c in CIFAR100_CATEGORIES] +PROTOTYPE_MAP = { + 'imagenet': IMAGENET_SIMPLE_CATEGORIES, + 'cifar100': CIFAR100_CATEGORIES, +} +PROMPT_MAP = { + 'openai_imagenet': OPENAI_IMAGENET_PROMPT, + 'openai_cifar100': OPENAI_CIFAR100_PROMPT, + 'vanilla': [lambda c: f'a photo of a {c}'], + 'openai_imagenet_sub': OPENAI_IMAGENET_PROMPT_SUB +} + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class CLIP(BaseModel): + """The implementation of `CLIP `_. + + Args: + vision_backbone (dict): Config dict for vision backbone. + text_backbone (dict): Config dict for text backbone. + tokenizer (dict): Config dict for text tokenizer. + proj_dim (int): Projection dimension for similarity computation. + text_prototype (str): Text prototype, which can be a key in + `PROTOTYPE_MAP` or list of text. + text_prompt (str): The prompt for text prototype. + Defaults to 'vanilla',which refers to "a photo of {cls}". + context_length (int): The context length to use. Defaults to 77. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "MultiModalDataPreprocessor" as type. + See :class:`MultiModalDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): The config to control the initialization. + Defaults to None. + """ + + def __init__(self, + vision_backbone: dict, + projection: dict, + text_backbone: dict, + tokenizer: dict, + vocab_size: int, + transformer_width: int, + proj_dim: int, + context_length: int = 77, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if data_preprocessor is None: + data_preprocessor = {} + data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.context_length = context_length + + # build the vision transformer + self.visual = MODELS.build(vision_backbone) + + # build the visual projection + self.visual_proj = MODELS.build(projection) + + # build attn_mask for casual-attn + text_backbone['attn_mask'] = self.build_attention_mask() + + # build the text transformer + self.transformer = MODELS.build(text_backbone) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, proj_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + self.tokenizer = TOKENIZER.build(tokenizer) + + self.tokenizer.vocab = self.tokenizer.get_vocab( + ) # CLIPTokenizer has no attribute named 'vocab', so manually + + def initialize_parameters(self) -> None: + """Initialize the parameters. + + The pretrained weight will override the initialized parameters by this + function. + """ + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, + # with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward( + self, + images: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'predict', + **kwargs, + ): + """The unified entry for a forward process in both training and test. + The method accepts the following modes: + + - "predict": Forward and return a list of data samples contain the + predict results. + + Args: + images (torch.Tensor): the preprocessed image tensor of shape + ``(N, C, H, W)``. + data_samples (List[DataSample], optional): The annotation data + of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to 'predict'. + """ + if mode == 'predict': + return self.predict(images, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor: + """The function to extract image latent features.""" + return self.visual_proj(self.visual(images))[0] + + def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor: + """The function to extract text latent features.""" + x = self.token_embedding(texts) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x)[0] + + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding + # (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), + texts.argmax(dim=-1)] @ self.text_projection + + return x + + def extract_feat( + self, images: torch.Tensor, + texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """The function to extract image and text latent features, the input + image or text can not both be None.""" + + assert images is not None or texts is not None, \ + 'text and image cannot both be None!' + if images is None: + return self.extract_text_feat(texts) + elif texts is None: + return self.extract_image_feat(images) + + image_features = self.extract_image_feat(images) + text_features = self.extract_text_feat(texts) + + image_features = image_features / image_features.norm( + dim=-1, keepdim=True) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + return image_features, text_features + + def compute_similarity(self, images, texts): + """Extract images and texts features and compute cosine similarity.""" + image_features, text_features = self.extract_feat( + images=images, texts=texts) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape (N, N) + return logits_per_image, logits_per_text + + @abstractmethod + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + raise NotImplementedError + + def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor: + """Returns the tokenized representation of given input string(s) + + Args: + texts (Union[str, List[str]]): An input string or a list of input + strings to tokenize + context_length (int): The context length to use. Defaults to 52. + + Returns: + torch.Tensor: Resulting tokens. + """ + if isinstance(texts, str): + texts = [texts] + + all_tokens = [] + for text in texts: + # adapt the text to Chinese BERT vocab + # text = text.lower().replace('“', "\"").replace('”', "\"") + + # add special tokens + all_tokens.append( + [self.tokenizer.vocab['<|startoftext|>'] + ] + # <|startoftext|>代表[CLS] token + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize(text))[:self.context_length - 2] + + [self.tokenizer.vocab['<|endoftext|>']]) + + result = torch.zeros( + len(all_tokens), self.context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + assert len(tokens) <= self.context_length + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +@MODELS.register_module() +class CLIPZeroShot(CLIP): + + def __init__( + self, + vision_backbone: dict, + projection: dict, + text_backbone: dict, + tokenizer: dict, + vocab_size: int, + transformer_width: int, + proj_dim: int, + context_length: int = 77, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + text_prototype: Union[str, List[str]] = 'imagenet', + text_prompt: str = 'vanilla', + ): + super(CLIPZeroShot, + self).__init__(vision_backbone, projection, text_backbone, + tokenizer, vocab_size, transformer_width, + proj_dim, context_length, data_preprocessor, + init_cfg) + + # for zero-shot classification + if isinstance(text_prototype, + str) and text_prototype in PROTOTYPE_MAP.keys(): + self.prototype = PROTOTYPE_MAP[text_prototype] + else: + self.prototype = text_prototype + self.text_prototype_embeds = None + + self.prompt = PROMPT_MAP[text_prompt] + + def predict(self, + images: torch.Tensor, + data_samples: DataSample = None) -> DataSample: + """Predict the classes of the input images. + + The prediction is for zero-shot classification and the text prototypes + will be prepared in thisfunction. + + Args: + images (torch.Tensor): The input images. + data_samples (DataSample): The data samples with information from + dataset. + + Returns: + DataSample: The results of prediction. + """ + + if self.text_prototype_embeds is None: + self.prepare_text_prototype(device=images.device) + + image_features = self.extract_image_feat(images=images) + image_features /= image_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_image = image_features @ self.text_prototype_embeds.to( + image_features.device) * self.logit_scale.exp() + + pred_scores = F.softmax(logits_per_image, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = DataSample() + + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples + + def prepare_text_prototype(self, device) -> None: + """The function to prepare text prototypes with prompt.""" + class_embeddings = [] + for classname in track_on_main_process(self.prototype, + 'Prepare text prototype...'): + # format with class + texts = [prompt(classname) for prompt in self.prompt] + tokenized_texts = self.tokenize(texts) + class_features = self.extract_text_feat(tokenized_texts.to(device)) + class_features /= class_features.norm(dim=-1, keepdim=True) + class_feature = class_features.mean(dim=0) + class_feature /= class_feature.norm() + class_embeddings.append(class_feature) + self.text_prototype_embeds = torch.stack( + class_embeddings, dim=1).to(device) diff --git a/mmpretrain/models/multimodal/clip/clip_transformer.py b/mmpretrain/models/multimodal/clip/clip_transformer.py new file mode 100644 index 00000000000..4b5f76661cb --- /dev/null +++ b/mmpretrain/models/multimodal/clip/clip_transformer.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/zejiangh/MILAN +from typing import Optional, Tuple + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmpretrain.models.utils.clip_generator_helper import \ + ResidualAttentionBlock +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CLIPTransformer(nn.Module): + """Transformer. + + Both visual and text branches use this transformer. + + Args: + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. + """ + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: Optional[torch.Tensor] = None) -> None: + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList() + for _ in range(layers - 1): + self.resblocks.append( + ResidualAttentionBlock(width, heads, attn_mask)) + self.resblocks.append( + ResidualAttentionBlock( + width, heads, attn_mask, return_attention=True)) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward function.""" + z = [] + for idx, blk in enumerate(self.resblocks): + if idx < self.layers - 1: + x = blk(x) + z.append(x.permute(1, 0, 2)) + else: + x, attention = blk(x) + z.append(x.permute(1, 0, 2)) + return x, attention, z + + +@MODELS.register_module() +class CLIPProjection(BaseModule): + """Neck with CLIP Projection. + + Args: + in_channels (int): Number of channels in the input. + out_channels (int): Number of channels in the output. + init_cfg (dict | list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + init_cfg: Optional[dict] = None): + super(CLIPProjection, self).__init__(init_cfg=init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + scale = in_channels**-0.5 + self.proj = nn.Parameter(scale * + torch.randn(in_channels, out_channels)) + + def forward(self, inputs: Tuple) -> Tuple[torch.Tensor]: + """forward function. + + Args: + inputs (Tuple): The features extracted from + the backbone. Multiple stage inputs are acceptable but only + the last stage will be used. + Returns: + Tuple(torch.Tensor)): A tuple of reducted features. + """ + if isinstance(inputs, tuple): + inputs = inputs[-1] + out = inputs @ self.proj + elif isinstance(inputs, torch.Tensor): + out = inputs @ self.proj + else: + raise TypeError( + '`CLIPProjection` neck inputs should be tuple or torch.tensor') + return (out, ) diff --git a/mmpretrain/models/multimodal/clip/utils.py b/mmpretrain/models/multimodal/clip/utils.py new file mode 100644 index 00000000000..65239bc37d6 --- /dev/null +++ b/mmpretrain/models/multimodal/clip/utils.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +OPENAI_CIFAR100_PROMPT = [ + lambda c: f'a photo of a {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'a low contrast photo of a {c}.', + lambda c: f'a high contrast photo of a {c}.', + lambda c: f'a bad photo of a {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a photo of a big {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a low contrast photo of the {c}.', + lambda c: f'a high contrast photo of the {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the big {c}.', +] + +OPENAI_IMAGENET_PROMPT_SUB = [ + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +] + +OPENAI_IMAGENET_PROMPT = [ + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +] diff --git a/tools/model_converters/openai-clip_to_mmpretrain-clip.py b/tools/model_converters/openai-clip_to_mmpretrain-clip.py new file mode 100644 index 00000000000..72725502551 --- /dev/null +++ b/tools/model_converters/openai-clip_to_mmpretrain-clip.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_clip(ckpt): + new_ckpt = OrderedDict() + + for k, v in list(ckpt.items()): + new_v = v + if k.startswith('visual.conv1'): + new_k = k.replace('conv1', 'patch_embed.projection') + elif k.startswith('visual.positional_embedding'): + new_k = k.replace('positional_embedding', 'pos_embed') + new_v = v.unsqueeze(dim=0) + elif k.startswith('visual.class_embedding'): + new_k = k.replace('class_embedding', 'cls_token') + new_v = v.unsqueeze(dim=0).unsqueeze(dim=0) + elif k.startswith('visual.ln_pre'): + new_k = k.replace('ln_pre', 'pre_norm') + elif k.startswith('visual.transformer.resblocks'): + new_k = k.replace('transformer.resblocks', 'layers') + if 'ln_1' in k: + new_k = new_k.replace('ln_1', 'ln1') + elif 'ln_2' in k: + new_k = new_k.replace('ln_2', 'ln2') + elif 'mlp.c_fc' in k: + new_k = new_k.replace('mlp.c_fc', 'ffn.layers.0.0') + elif 'mlp.c_proj' in k: + new_k = new_k.replace('mlp.c_proj', 'ffn.layers.1') + elif 'attn.in_proj_weight' in k: + new_k = new_k.replace('in_proj_weight', 'qkv.weight') + elif 'attn.in_proj_bias' in k: + new_k = new_k.replace('in_proj_bias', 'qkv.bias') + elif 'attn.out_proj' in k: + new_k = new_k.replace('out_proj', 'proj') + elif k.startswith('visual.ln_post'): + new_k = k.replace('ln_post', 'ln1') + elif k.startswith('visual.proj'): + new_k = k.replace('visual.proj', 'visual_proj.proj') + else: + new_k = k + + new_ckpt[new_k] = new_v + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in pretrained clip ' + 'models to mmpretrain style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + weight = convert_clip(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + print('Done!!') + + +if __name__ == '__main__': + main() From 06bb586eb715626f19e97dfa8b632f104ba47d2b Mon Sep 17 00:00:00 2001 From: mzr1996 Date: Sun, 8 Oct 2023 15:44:37 +0800 Subject: [PATCH 4/9] [Fix] Fix pipeline bug in image retrieval inferencer --- mmpretrain/apis/image_retrieval.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmpretrain/apis/image_retrieval.py b/mmpretrain/apis/image_retrieval.py index deae1de7975..27919b20f58 100644 --- a/mmpretrain/apis/image_retrieval.py +++ b/mmpretrain/apis/image_retrieval.py @@ -108,6 +108,7 @@ def build_dataloader(dataset): # A config of dataset from mmpretrain.registry import DATASETS test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline] + prototype.setdefault('pipeline', test_pipeline) dataset = DATASETS.build(prototype) dataloader = build_dataloader(dataset) elif isinstance(prototype, DataLoader): From 3bcf7e2d6ed1d4c215dcf5e404dd6da52e8f0e3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A3=9E=E9=A3=9E?= <102729089+ASHORE1225@users.noreply.github.com> Date: Sun, 8 Oct 2023 15:46:47 +0800 Subject: [PATCH 5/9] =?UTF-8?q?[CodeCamp2023-341]=20=E5=A4=9A=E6=A8=A1?= =?UTF-8?q?=E6=80=81=E6=95=B0=E6=8D=AE=E9=9B=86=E6=96=87=E6=A1=A3=E8=A1=A5?= =?UTF-8?q?=E5=85=85-COCO=20Retrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mmpretrain/datasets/coco_retrieval.py | 75 ++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 2 deletions(-) diff --git a/mmpretrain/datasets/coco_retrieval.py b/mmpretrain/datasets/coco_retrieval.py index 60d1586ad86..be8a0bcb864 100644 --- a/mmpretrain/datasets/coco_retrieval.py +++ b/mmpretrain/datasets/coco_retrieval.py @@ -1,18 +1,45 @@ # Copyright (c) OpenMMLab. All rights reserved. import json +import os.path as osp from collections import OrderedDict -from typing import List +from os import PathLike +from typing import List, Sequence, Union from mmengine import get_file_backend -from mmpretrain.registry import DATASETS +from mmpretrain.registry import DATASETS, TRANSFORMS from .base_dataset import BaseDataset +def expanduser(data_prefix): + if isinstance(data_prefix, (str, PathLike)): + return osp.expanduser(data_prefix) + else: + return data_prefix + + @DATASETS.register_module() class COCORetrieval(BaseDataset): """COCO Retrieval dataset. + COCO (Common Objects in Context): The COCO dataset contains more than + 330K images,each of which has approximately 5 descriptive annotations. + This dataset was releasedin collaboration between Microsoft and Carnegie + Mellon University + + COCO_2014 dataset directory: :: + + COCO_2014 + ├── val2014 + ├── train2014 + ├── annotations + ├── instances_train2014.json + ├── instances_val2014.json + ├── person_keypoints_train2014.json + ├── person_keypoints_val2014.json + ├── captions_train2014.json + ├── captions_val2014.json + Args: ann_file (str): Annotation file path. test_mode (bool): Whether dataset is used for evaluation. This will @@ -23,8 +50,52 @@ class COCORetrieval(BaseDataset): data_prefix (str | dict): Prefix for training data. Defaults to ''. pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. **kwargs: Other keyword arguments in :class:`BaseDataset`. + + Examples: + >>> from mmpretrain.datasets import COCORetrieval + >>> train_dataset=COCORetrieval(data_root='coco2014/') + >>> train_dataset + Dataset COCORetrieval + Number of samples: 414113 + Annotation file: /coco2014/annotations/captions_train2014.json + Prefix of images: /coco2014/ + >>> from mmpretrain.datasets import COCORetrieval + >>> val_dataset = COCORetrieval(data_root='coco2014/') + >>> val_dataset + Dataset COCORetrieval + Number of samples: 202654 + Annotation file: /coco2014/annotations/captions_val2014.json + Prefix of images: /coco2014/ """ + def __init__(self, + ann_file: str, + test_mode: bool = False, + data_prefix: Union[str, dict] = '', + data_root: str = '', + pipeline: Sequence = (), + **kwargs): + + if isinstance(data_prefix, str): + data_prefix = dict(img_path=expanduser(data_prefix)) + + ann_file = expanduser(ann_file) + transforms = [] + for transform in pipeline: + if isinstance(transform, dict): + transforms.append(TRANSFORMS.build(transform)) + else: + transforms.append(transform) + + super().__init__( + data_root=data_root, + data_prefix=data_prefix, + test_mode=test_mode, + pipeline=transforms, + ann_file=ann_file, + **kwargs, + ) + def load_data_list(self) -> List[dict]: """Load data list.""" # get file backend From f4efe8b78765185f38de32359953f3870dc72010 Mon Sep 17 00:00:00 2001 From: joaquincabezas <1833086+joaquincabezas@users.noreply.github.com> Date: Mon, 9 Oct 2023 17:24:28 +0200 Subject: [PATCH 6/9] fix torchserve version --- docker/serve/Dockerfile | 5 +- tools/torchserve/mmpretrain2torchserve.py | 88 +++++++++++++---------- 2 files changed, 53 insertions(+), 40 deletions(-) diff --git a/docker/serve/Dockerfile b/docker/serve/Dockerfile index bff871b722d..c8080d0d423 100644 --- a/docker/serve/Dockerfile +++ b/docker/serve/Dockerfile @@ -1,6 +1,7 @@ ARG PYTORCH="1.12.1" ARG CUDA="11.3" ARG CUDNN="8" +ARG TORCHSERVE="0.7.1" FROM pytorch/torchserve:latest-gpu ARG MMPRE="1.0.2" @@ -11,8 +12,8 @@ ENV HOME="/home/model-server" ENV PATH="/opt/conda/bin:$HOME/.local/bin:$PATH" RUN export FORCE_CUDA=1 -# TORCHSEVER -RUN pip install torchserve torch-model-archiver +# TORCHSERVE +RUN pip install torchserve==${TORCHSERVE} torch-model-archiver==${TORCHSERVE} RUN pip install nvgpu # OPEN-MMLAB diff --git a/tools/torchserve/mmpretrain2torchserve.py b/tools/torchserve/mmpretrain2torchserve.py index 8d53bf3f924..008833cc5b9 100644 --- a/tools/torchserve/mmpretrain2torchserve.py +++ b/tools/torchserve/mmpretrain2torchserve.py @@ -10,8 +10,9 @@ from model_archiver.model_packaging_utils import ModelExportUtils except ImportError: raise ImportError( - 'Please run `pip install torchserve torch-model-archiver"` to ' - 'install required third-party libraries.') + 'Please run `pip install torchserve==0.7.1 torch-model-archiver==0.7.1"` to ' + "install required third-party libraries." + ) def mmpretrain2torchserve( @@ -19,7 +20,7 @@ def mmpretrain2torchserve( checkpoint_file: str, output_folder: str, model_name: str, - model_version: str = '1.0', + model_version: str = "1.0", force: bool = False, ): """Converts mmpretrain model (config + checkpoint) to TorchServe `.mar`. @@ -49,64 +50,75 @@ def mmpretrain2torchserve( config = mmengine.Config.fromfile(config_file) with TemporaryDirectory() as tmpdir: - config.dump(f'{tmpdir}/config.py') + config.dump(f"{tmpdir}/config.py") args = Namespace( **{ - 'model_file': f'{tmpdir}/config.py', - 'serialized_file': checkpoint_file, - 'handler': f'{Path(__file__).parent}/mmpretrain_handler.py', - 'model_name': model_name or Path(checkpoint_file).stem, - 'version': model_version, - 'export_path': output_folder, - 'force': force, - 'requirements_file': None, - 'extra_files': None, - 'runtime': 'python', - 'archive_format': 'default' - }) + "model_file": f"{tmpdir}/config.py", + "serialized_file": checkpoint_file, + "handler": f"{Path(__file__).parent}/mmpretrain_handler.py", + "model_name": model_name or Path(checkpoint_file).stem, + "version": model_version, + "export_path": output_folder, + "force": force, + "requirements_file": None, + "extra_files": None, + "runtime": "python", + "archive_format": "default", + } + ) manifest = ModelExportUtils.generate_manifest_json(args) package_model(args, manifest) def parse_args(): parser = ArgumentParser( - description='Convert mmpretrain models to TorchServe `.mar` format.') - parser.add_argument('config', type=str, help='config file path') - parser.add_argument('checkpoint', type=str, help='checkpoint file path') + description="Convert mmpretrain models to TorchServe `.mar` format." + ) + parser.add_argument("config", type=str, help="config file path") + parser.add_argument("checkpoint", type=str, help="checkpoint file path") parser.add_argument( - '--output-folder', + "--output-folder", type=str, required=True, - help='Folder where `{model_name}.mar` will be created.') + help="Folder where `{model_name}.mar` will be created.", + ) parser.add_argument( - '--model-name', + "--model-name", type=str, default=None, - help='If not None, used for naming the `{model_name}.mar`' - 'file that will be created under `output_folder`.' - 'If None, `{Path(checkpoint_file).stem}` will be used.') + help="If not None, used for naming the `{model_name}.mar`" + "file that will be created under `output_folder`." + "If None, `{Path(checkpoint_file).stem}` will be used.", + ) parser.add_argument( - '--model-version', - type=str, - default='1.0', - help='Number used for versioning.') + "--model-version", type=str, default="1.0", help="Number used for versioning." + ) parser.add_argument( - '-f', - '--force', - action='store_true', - help='overwrite the existing `{model_name}.mar`') + "-f", + "--force", + action="store_true", + help="overwrite the existing `{model_name}.mar`", + ) args = parser.parse_args() return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() if package_model is None: - raise ImportError('`torch-model-archiver` is required.' - 'Try: pip install torch-model-archiver') + raise ImportError( + "`torch-model-archiver` is required." + "Try: pip install torch-model-archiver==0.7.1" + ) - mmpretrain2torchserve(args.config, args.checkpoint, args.output_folder, - args.model_name, args.model_version, args.force) + mmpretrain2torchserve( + args.config, + args.checkpoint, + args.output_folder, + args.model_name, + args.model_version, + args.force, + ) From aaf9e4002aacebfd25a664d9b7f681b091584bd0 Mon Sep 17 00:00:00 2001 From: joaquincabezas <1833086+joaquincabezas@users.noreply.github.com> Date: Mon, 9 Oct 2023 17:30:01 +0200 Subject: [PATCH 7/9] undo black formatting --- tools/torchserve/mmpretrain2torchserve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/torchserve/mmpretrain2torchserve.py b/tools/torchserve/mmpretrain2torchserve.py index 008833cc5b9..02a8e177269 100644 --- a/tools/torchserve/mmpretrain2torchserve.py +++ b/tools/torchserve/mmpretrain2torchserve.py @@ -10,7 +10,7 @@ from model_archiver.model_packaging_utils import ModelExportUtils except ImportError: raise ImportError( - 'Please run `pip install torchserve==0.7.1 torch-model-archiver==0.7.1"` to ' + 'Please run `pip install torchserve==0.7.1 torch-model-archiver==0.7.1 "` to ' "install required third-party libraries." ) From dd920d40f5b19f8f180bc54e7bf819e1d48cbec2 Mon Sep 17 00:00:00 2001 From: joaquincabezas <1833086+joaquincabezas@users.noreply.github.com> Date: Mon, 9 Oct 2023 17:31:38 +0200 Subject: [PATCH 8/9] return to mmpretrain style --- tools/torchserve/mmpretrain2torchserve.py | 86 ++++++++++------------- 1 file changed, 37 insertions(+), 49 deletions(-) diff --git a/tools/torchserve/mmpretrain2torchserve.py b/tools/torchserve/mmpretrain2torchserve.py index 02a8e177269..3aba99348d7 100644 --- a/tools/torchserve/mmpretrain2torchserve.py +++ b/tools/torchserve/mmpretrain2torchserve.py @@ -11,8 +11,7 @@ except ImportError: raise ImportError( 'Please run `pip install torchserve==0.7.1 torch-model-archiver==0.7.1 "` to ' - "install required third-party libraries." - ) + 'install required third-party libraries.') def mmpretrain2torchserve( @@ -20,7 +19,7 @@ def mmpretrain2torchserve( checkpoint_file: str, output_folder: str, model_name: str, - model_version: str = "1.0", + model_version: str = '1.0', force: bool = False, ): """Converts mmpretrain model (config + checkpoint) to TorchServe `.mar`. @@ -50,75 +49,64 @@ def mmpretrain2torchserve( config = mmengine.Config.fromfile(config_file) with TemporaryDirectory() as tmpdir: - config.dump(f"{tmpdir}/config.py") + config.dump(f'{tmpdir}/config.py') args = Namespace( **{ - "model_file": f"{tmpdir}/config.py", - "serialized_file": checkpoint_file, - "handler": f"{Path(__file__).parent}/mmpretrain_handler.py", - "model_name": model_name or Path(checkpoint_file).stem, - "version": model_version, - "export_path": output_folder, - "force": force, - "requirements_file": None, - "extra_files": None, - "runtime": "python", - "archive_format": "default", - } - ) + 'model_file': f'{tmpdir}/config.py', + 'serialized_file': checkpoint_file, + 'handler': f'{Path(__file__).parent}/mmpretrain_handler.py', + 'model_name': model_name or Path(checkpoint_file).stem, + 'version': model_version, + 'export_path': output_folder, + 'force': force, + 'requirements_file': None, + 'extra_files': None, + 'runtime': 'python', + 'archive_format': 'default' + }) manifest = ModelExportUtils.generate_manifest_json(args) package_model(args, manifest) def parse_args(): parser = ArgumentParser( - description="Convert mmpretrain models to TorchServe `.mar` format." - ) - parser.add_argument("config", type=str, help="config file path") - parser.add_argument("checkpoint", type=str, help="checkpoint file path") + description='Convert mmpretrain models to TorchServe `.mar` format.') + parser.add_argument('config', type=str, help='config file path') + parser.add_argument('checkpoint', type=str, help='checkpoint file path') parser.add_argument( - "--output-folder", + '--output-folder', type=str, required=True, - help="Folder where `{model_name}.mar` will be created.", - ) + help='Folder where `{model_name}.mar` will be created.') parser.add_argument( - "--model-name", + '--model-name', type=str, default=None, - help="If not None, used for naming the `{model_name}.mar`" - "file that will be created under `output_folder`." - "If None, `{Path(checkpoint_file).stem}` will be used.", - ) + help='If not None, used for naming the `{model_name}.mar`' + 'file that will be created under `output_folder`.' + 'If None, `{Path(checkpoint_file).stem}` will be used.') parser.add_argument( - "--model-version", type=str, default="1.0", help="Number used for versioning." - ) + '--model-version', + type=str, + default='1.0', + help='Number used for versioning.') parser.add_argument( - "-f", - "--force", - action="store_true", - help="overwrite the existing `{model_name}.mar`", - ) + '-f', + '--force', + action='store_true', + help='overwrite the existing `{model_name}.mar`') args = parser.parse_args() return args -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() if package_model is None: - raise ImportError( - "`torch-model-archiver` is required." - "Try: pip install torch-model-archiver==0.7.1" - ) + raise ImportError('`torch-model-archiver` is required.' + 'Try: pip install torch-model-archiver==0.7.1') - mmpretrain2torchserve( - args.config, - args.checkpoint, - args.output_folder, - args.model_name, - args.model_version, - args.force, - ) + mmpretrain2torchserve(args.config, args.checkpoint, args.output_folder, + args.model_name, args.model_version, args.force) From 79efd27a7ec15c84b3dd718644e900ffc31f7409 Mon Sep 17 00:00:00 2001 From: joaquincabezas <1833086+joaquincabezas@users.noreply.github.com> Date: Mon, 9 Oct 2023 18:08:03 +0200 Subject: [PATCH 9/9] linting --- tools/torchserve/mmpretrain2torchserve.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/torchserve/mmpretrain2torchserve.py b/tools/torchserve/mmpretrain2torchserve.py index 3aba99348d7..a0e0384a9ee 100644 --- a/tools/torchserve/mmpretrain2torchserve.py +++ b/tools/torchserve/mmpretrain2torchserve.py @@ -10,7 +10,8 @@ from model_archiver.model_packaging_utils import ModelExportUtils except ImportError: raise ImportError( - 'Please run `pip install torchserve==0.7.1 torch-model-archiver==0.7.1 "` to ' + 'Please run ' + '`pip install torchserve==0.7.1 torch-model-archiver==0.7.1"` to ' 'install required third-party libraries.')