Commit e453a45

[Refactor] Add self-supervised backbones and target generators. (open-mmlab#1379)

* add heads

* add losses

* fix

* remove mim head

* add modified backbones and target generators

* add unittest

* refactor caevit

* add window_size check

* fix lint

* apply new DataSample

* fix ut error

* update ut

* fix ut

* fix lint

* Update base modules.

---------

Co-authored-by: mzr1996 <[email protected]>
fangyixiao18 and mzr1996 authored Feb 28, 2023
1 parent 63d9f27 commit e453a45
Showing 30 changed files with 2,776 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -47,7 +47,7 @@ repos:
rev: v0.4.0
hooks:
- id: check-copyright
- args: ["mmcls", "tests", "demo", "tools", "--excludes", "mmcls/.mim/", "--ignore-file-not-found-error"]
+ args: ["mmpretrain", "tests", "demo", "tools", "--excludes", "mmpretrain/.mim/", "--ignore-file-not-found-error"]
- repo: local
hooks:
- id: metafile
1 change: 1 addition & 0 deletions mmpretrain/evaluation/functional/__init__.py
@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
1 change: 0 additions & 1 deletion mmpretrain/models/__init__.py
@@ -9,7 +9,6 @@
from .necks import * # noqa: F401,F403
from .retrievers import * # noqa: F401,F403
from .selfsup import * # noqa: F401,F403
- from .target_generators import * # noqa: F401,F403
from .tta import * # noqa: F401,F403
from .utils import * # noqa: F401,F403

4 changes: 2 additions & 2 deletions mmpretrain/models/backbones/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
- from .beit import BEiT
+ from .beit import BEiTViT
from .conformer import Conformer
from .convmixer import ConvMixer
from .convnext import ConvNeXt
@@ -106,7 +106,7 @@
'HorNet',
'MobileViT',
'DaViT',
- 'BEiT',
+ 'BEiTViT',
'RevVisionTransformer',
'MixMIMTransformer',
'TinyViT',
11 changes: 7 additions & 4 deletions mmpretrain/models/backbones/beit.py
@@ -155,7 +155,6 @@ def __init__(self,
drop_path_rate=0.,
drop_rate=0.,
num_fcs=num_fcs,
- qkv_bias=bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=init_cfg)
@@ -214,7 +213,7 @@ def forward(self, x: torch.Tensor,


@MODELS.register_module()
- class BEiT(VisionTransformer):
+ class BEiTViT(VisionTransformer):
"""Backbone for BEiT.
A PyTorch implementation of: `BEiT: BERT Pre-Training of Image Transformers
@@ -244,8 +243,10 @@ class BEiT(VisionTransformer):
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
- qkv_bias (bool): Whether to add bias for qkv in attention modules.
-     Defaults to True.
+ bias (bool | str): The option to add learnable bias for q, k, v. If bias
+     is True, it will add learnable bias. If bias is 'qv_bias', it will
+     only add learnable bias for q, v. If bias is False, it will not add
+     bias for q, k, v. Defaults to 'qv_bias'.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
@@ -285,6 +286,7 @@ def __init__(self,
out_indices=-1,
drop_rate=0,
drop_path_rate=0,
+ bias='qv_bias',
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=False,
with_cls_token=True,
@@ -395,6 +397,7 @@ def __init__(self,
use_rel_pos_bias=use_rel_pos_bias,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
+ bias=bias,
norm_cfg=norm_cfg)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
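For orientation, a minimal usage sketch of the renamed backbone (not part of this commit): the registry type is now 'BEiTViT' instead of 'BEiT', and the new bias option defaults to 'qv_bias'. The arch and img_size values below are illustrative placeholders following the VisionTransformer interface that BEiTViT inherits.

from mmpretrain.registry import MODELS

# Illustrative config only; 'arch' and 'img_size' are assumed placeholder
# values, while 'type', 'bias' and 'norm_cfg' follow the diff above.
backbone_cfg = dict(
    type='BEiTViT',          # renamed from 'BEiT' in this refactor
    arch='base',
    img_size=224,
    bias='qv_bias',          # learnable bias for q and v only (the new default)
    norm_cfg=dict(type='LN', eps=1e-6),
)
backbone = MODELS.build(backbone_cfg)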
25 changes: 13 additions & 12 deletions mmpretrain/models/classifiers/image.py
@@ -30,6 +30,8 @@ class ImageClassifier(BaseClassifier):
- augments (List[dict]): The batch augmentation methods to use.
More details can be found in
:mod:`mmpretrain.model.utils.augment`.
- probs (List[float], optional): The probability of every batch
augmentation methods. If None, choose evenly. Defaults to None.
Defaults to None.
data_preprocessor (dict, optional): The config for preprocessing input
@@ -51,14 +53,16 @@ def __init__(self,
if pretrained is not None:
init_cfg = dict(type='Pretrained', checkpoint=pretrained)

- if data_preprocessor is None:
-     data_preprocessor = {}
- # The build process is in MMEngine, so we need to add scope here.
- data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor')
+ data_preprocessor = data_preprocessor or {}

- if train_cfg is not None and 'augments' in train_cfg:
-     # Set batch augmentations by `train_cfg`
-     data_preprocessor['batch_augments'] = train_cfg
+ if isinstance(data_preprocessor, dict):
+     data_preprocessor.setdefault('type', 'ClsDataPreprocessor')
+     data_preprocessor.setdefault('batch_augments', train_cfg)
+     data_preprocessor = MODELS.build(data_preprocessor)
+ elif not isinstance(data_preprocessor, nn.Module):
+     raise TypeError('data_preprocessor should be a `dict` or '
+                     f'`nn.Module` instance, but got '
+                     f'{type(data_preprocessor)}')

super(ImageClassifier, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
@@ -82,16 +86,13 @@ def forward(self,
The method should accept three modes: "tensor", "predict" and "loss":
- - "tensor": Forward the whole network and return tensor or tuple of
-   tensor without any post-processing, same as a common nn.Module.
+ - "tensor": Forward the whole network and return tensor(s) without any
+   post-processing, same as a common PyTorch Module.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`DataSample`.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
Note that this method handles neither back propagation nor optimizer
updates; those are done in :meth:`train_step`.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
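A short usage sketch of the refactored constructor (not part of the diff): data_preprocessor may now be passed either as a config dict, which is built through the registry with 'ClsDataPreprocessor' as the default type, or as an already-built nn.Module. The backbone, neck and head configs below are illustrative placeholders, assuming a standard mmpretrain installation.

import torch

from mmpretrain.models import ImageClassifier

# Illustrative sub-module configs; any registered backbone/neck/head would do.
model = ImageClassifier(
    backbone=dict(type='ResNet', depth=18),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(type='LinearClsHead', num_classes=10, in_channels=512),
    # A plain dict: 'type' falls back to 'ClsDataPreprocessor' when omitted.
    data_preprocessor=dict(mean=[123.675, 116.28, 103.53],
                           std=[58.395, 57.12, 57.375]),
)

x = torch.rand(2, 3, 224, 224)
feats = model(x, mode='tensor')   # raw features, no post-processing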
26 changes: 26 additions & 0 deletions mmpretrain/models/selfsup/__init__.py
@@ -0,0 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseSelfSupervisor
from .beit import VQKD, BEiTPretrainViT
from .cae import CAEViT, Encoder
from .mae import MAEViT
from .maskfeat import HOGGenerator, MaskFeatViT
from .milan import CLIPGenerator, MILANViT
from .mixmim import MixMIMPretrainTransformer
from .mocov3 import MoCoV3ViT
from .simmim import SimMIMSwinTransformer

__all__ = [
'BaseSelfSupervisor',
'BEiTPretrainViT',
'VQKD',
'CAEViT',
'Encoder',
'MAEViT',
'HOGGenerator',
'MaskFeatViT',
'CLIPGenerator',
'MILANViT',
'MixMIMPretrainTransformer',
'MoCoV3ViT',
'SimMIMSwinTransformer',
]
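These classes are re-exported from mmpretrain.models.selfsup and, assuming each is registered in MODELS in its own module (the registration code is not shown in this excerpt), they can be imported directly or built from config dicts. A minimal sketch with the HOG target generator, using its default constructor arguments and a dummy batch:

import torch

from mmpretrain.models.selfsup import HOGGenerator  # new import location
from mmpretrain.registry import MODELS

# Sketch only: build the target generator from a config dict and run it on a
# dummy image batch to obtain self-supervision targets.
hog = MODELS.build(dict(type='HOGGenerator'))
targets = hog(torch.rand(2, 3, 224, 224))
print(targets.shape)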
164 changes: 164 additions & 0 deletions mmpretrain/models/selfsup/base.py
@@ -0,0 +1,164 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Union

import torch
from mmengine.model import BaseModel
from torch import nn

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


class BaseSelfSupervisor(BaseModel, metaclass=ABCMeta):
"""BaseModel for Self-Supervised Learning.
All self-supervised algorithms should inherit this module.
Args:
backbone (dict): The backbone module. See
:mod:`mmpretrain.models.backbones`.
neck (dict, optional): The neck module to process features from
backbone. See :mod:`mmpretrain.models.necks`. Defaults to None.
head (dict, optional): The head module to do prediction and calculate
loss from processed features. See :mod:`mmpretrain.models.heads`.
Notice that if the head is not set, almost all methods cannot be
used except :meth:`extract_feat`. Defaults to None.
target_generator (dict, optional): The target_generator module to
generate targets for self-supervised learning optimization, such as
HOG, or features extracted from other modules (DALL-E, CLIP), etc.
Defaults to None.
pretrained (str, optional): The pretrained checkpoint path, support
local path and remote path. Defaults to None.
data_preprocessor (Union[dict, nn.Module], optional): The config for
preprocessing input data. If None, or if no ``type`` is specified, it
will use "SelfSupDataPreprocessor" as the type.
See :class:`SelfSupDataPreprocessor` for more details.
Defaults to None.
init_cfg (dict, optional): the config to control the initialization.
Defaults to None.
"""

def __init__(self,
backbone: dict,
neck: Optional[dict] = None,
head: Optional[dict] = None,
target_generator: Optional[dict] = None,
pretrained: Optional[str] = None,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[dict] = None):
if pretrained is not None:
init_cfg = dict(type='Pretrained', checkpoint=pretrained)

data_preprocessor = data_preprocessor or {}
if isinstance(data_preprocessor, dict):
data_preprocessor.setdefault('type', 'SelfSupDataPreprocessor')
data_preprocessor = MODELS.build(data_preprocessor)
elif not isinstance(data_preprocessor, nn.Module):
raise TypeError('data_preprocessor should be a `dict` or '
f'`nn.Module` instance, but got '
f'{type(data_preprocessor)}')

super().__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)

if not isinstance(backbone, nn.Module):
backbone = MODELS.build(backbone)
if neck is not None and not isinstance(neck, nn.Module):
neck = MODELS.build(neck)
if head is not None and not isinstance(head, nn.Module):
head = MODELS.build(head)
if target_generator is not None and not isinstance(
target_generator, nn.Module):
target_generator = MODELS.build(target_generator)

self.backbone = backbone
self.neck = neck
self.head = head
self.target_generator = target_generator

@property
def with_neck(self) -> bool:
"""Check if the model has a neck module."""
return hasattr(self, 'neck') and self.neck is not None

@property
def with_head(self) -> bool:
"""Check if the model has a head module."""
return hasattr(self, 'head') and self.head is not None

@property
def with_target_generator(self) -> bool:
"""Check if the model has a target_generator module."""
return hasattr(
self, 'target_generator') and self.target_generator is not None

def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List[DataSample]] = None,
mode: str = 'tensor'):
"""The unified entry for a forward process in both training and test.
The method accepts two modes: "tensor" and "loss":
- "tensor": Forward the backbone network and return the feature
tensor(s) without any post-processing, same as a common
PyTorch Module.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[DataSample], optional): The other data of
each sample. It is required by some algorithms
if ``mode="loss"``. Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
Returns:
The return type depends on ``mode``.
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'tensor':
feats = self.extract_feat(inputs)
return feats
elif mode == 'loss':
return self.loss(inputs, data_samples)
else:
raise RuntimeError(f'Invalid mode "{mode}".')

@abstractmethod
def extract_feat(self, inputs: torch.Tensor):
"""Extract features from the input tensor with shape (N, C, ...).
Subclasses are expected to implement this method to extract
features from the backbone and neck.
Args:
inputs (Tensor): A batch of inputs. The shape of it should be
``(num_samples, num_channels, *img_shape)``.
Returns:
tuple | Tensor: The output feature tensor(s).
"""
raise NotImplementedError

@abstractmethod
def loss(self, inputs: torch.Tensor,
data_samples: List[DataSample]) -> dict:
"""Calculate losses from a batch of inputs and data samples.
This is an abstract method, and subclasses must override it.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[DataSample]): The annotation data of
each sample.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
raise NotImplementedError
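To make the contract of the two abstract hooks concrete, here is a deliberately minimal, hypothetical subclass (not part of this commit): the loss it computes is meaningless, but it shows the expected return types of extract_feat and loss and the use of the with_neck helper.

from typing import List

import torch

from mmpretrain.models.selfsup import BaseSelfSupervisor
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


@MODELS.register_module()
class ToySelfSupervisor(BaseSelfSupervisor):
    """Illustrative only: a toy objective on top of backbone features."""

    def extract_feat(self, inputs: torch.Tensor):
        # Delegate to the backbone and apply the neck if one was configured.
        x = self.backbone(inputs)
        if self.with_neck:
            x = self.neck(x)
        return x

    def loss(self, inputs: torch.Tensor,
             data_samples: List[DataSample]) -> dict:
        feats = self.extract_feat(inputs)
        feat = feats[-1] if isinstance(feats, (tuple, list)) else feats
        # Toy objective (pull features towards zero), returned as the
        # expected dict of named loss tensors.
        return dict(loss=feat.pow(2).mean())


# Usage sketch:
#   model = ToySelfSupervisor(backbone=dict(type='ResNet', depth=18))
#   losses = model(torch.rand(2, 3, 224, 224), mode='loss')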