diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ee7c2a0c616..fc7d59f2eec 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,7 +47,7 @@ repos: rev: v0.4.0 hooks: - id: check-copyright - args: ["mmcls", "tests", "demo", "tools", "--excludes", "mmcls/.mim/", "--ignore-file-not-found-error"] + args: ["mmpretrain", "tests", "demo", "tools", "--excludes", "mmpretrain/.mim/", "--ignore-file-not-found-error"] - repo: local hooks: - id: metafile diff --git a/mmpretrain/evaluation/functional/__init__.py b/mmpretrain/evaluation/functional/__init__.py index e69de29bb2d..ef101fec61e 100644 --- a/mmpretrain/evaluation/functional/__init__.py +++ b/mmpretrain/evaluation/functional/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmpretrain/models/__init__.py b/mmpretrain/models/__init__.py index 5d7c95ae12f..767bbaa2a00 100644 --- a/mmpretrain/models/__init__.py +++ b/mmpretrain/models/__init__.py @@ -9,7 +9,6 @@ from .necks import * # noqa: F401,F403 from .retrievers import * # noqa: F401,F403 from .selfsup import * # noqa: F401,F403 -from .target_generators import * # noqa: F401,F403 from .tta import * # noqa: F401,F403 from .utils import * # noqa: F401,F403 diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py index ca21dba92fb..00c2118e4b3 100644 --- a/mmpretrain/models/backbones/__init__.py +++ b/mmpretrain/models/backbones/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .alexnet import AlexNet -from .beit import BEiT +from .beit import BEiTViT from .conformer import Conformer from .convmixer import ConvMixer from .convnext import ConvNeXt @@ -106,7 +106,7 @@ 'HorNet', 'MobileViT', 'DaViT', - 'BEiT', + 'BEiTViT', 'RevVisionTransformer', 'MixMIMTransformer', 'TinyViT', diff --git a/mmpretrain/models/backbones/beit.py b/mmpretrain/models/backbones/beit.py index 0301126536b..439efa2bd9f 100644 --- a/mmpretrain/models/backbones/beit.py +++ b/mmpretrain/models/backbones/beit.py @@ -155,7 +155,6 @@ def __init__(self, drop_path_rate=0., drop_rate=0., num_fcs=num_fcs, - qkv_bias=bias, act_cfg=act_cfg, norm_cfg=norm_cfg, init_cfg=init_cfg) @@ -214,7 +213,7 @@ def forward(self, x: torch.Tensor, @MODELS.register_module() -class BEiT(VisionTransformer): +class BEiTViT(VisionTransformer): """Backbone for BEiT. A PyTorch implement of : `BEiT: BERT Pre-Training of Image Transformers @@ -244,8 +243,10 @@ class BEiT(VisionTransformer): drop_rate (float): Probability of an element to be zeroed. Defaults to 0. drop_path_rate (float): stochastic depth rate. Defaults to 0. - qkv_bias (bool): Whether to add bias for qkv in attention modules. - Defaults to True. + bias (bool | str): The option to add leanable bias for q, k, v. If bias + is True, it will add leanable bias. If bias is 'qv_bias', it will + only add leanable bias for q, v. If bias is False, it will not add + bias for q, k, v. Default to 'qv_bias'. norm_cfg (dict): Config dict for normalization layer. Defaults to ``dict(type='LN')``. 
final_norm (bool): Whether to add a additional layer to normalize @@ -285,6 +286,7 @@ def __init__(self, out_indices=-1, drop_rate=0, drop_path_rate=0, + bias='qv_bias', norm_cfg=dict(type='LN', eps=1e-6), final_norm=False, with_cls_token=True, @@ -395,6 +397,7 @@ def __init__(self, use_rel_pos_bias=use_rel_pos_bias, drop_rate=drop_rate, drop_path_rate=dpr[i], + bias=bias, norm_cfg=norm_cfg) _layer_cfg.update(layer_cfgs[i]) self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg)) diff --git a/mmpretrain/models/classifiers/image.py b/mmpretrain/models/classifiers/image.py index ffb49fb9158..7b70af0486a 100644 --- a/mmpretrain/models/classifiers/image.py +++ b/mmpretrain/models/classifiers/image.py @@ -30,6 +30,8 @@ class ImageClassifier(BaseClassifier): - augments (List[dict]): The batch augmentation methods to use. More details can be found in :mod:`mmpretrain.model.utils.augment`. + - probs (List[float], optional): The probability of every batch + augmentation methods. If None, choose evenly. Defaults to None. Defaults to None. data_preprocessor (dict, optional): The config for preprocessing input @@ -51,14 +53,16 @@ def __init__(self, if pretrained is not None: init_cfg = dict(type='Pretrained', checkpoint=pretrained) - if data_preprocessor is None: - data_preprocessor = {} - # The build process is in MMEngine, so we need to add scope here. - data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor') + data_preprocessor = data_preprocessor or {} - if train_cfg is not None and 'augments' in train_cfg: - # Set batch augmentations by `train_cfg` - data_preprocessor['batch_augments'] = train_cfg + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'ClsDataPreprocessor') + data_preprocessor.setdefault('batch_augments', train_cfg) + data_preprocessor = MODELS.build(data_preprocessor) + elif not isinstance(data_preprocessor, nn.Module): + raise TypeError('data_preprocessor should be a `dict` or ' + f'`nn.Module` instance, but got ' + f'{type(data_preprocessor)}') super(ImageClassifier, self).__init__( init_cfg=init_cfg, data_preprocessor=data_preprocessor) @@ -82,16 +86,13 @@ def forward(self, The method should accept three modes: "tensor", "predict" and "loss": - - "tensor": Forward the whole network and return tensor or tuple of - tensor without any post-processing, same as a common nn.Module. + - "tensor": Forward the whole network and return tensor(s) without any + post-processing, same as a common PyTorch Module. - "predict": Forward and return the predictions, which are fully processed to a list of :obj:`DataSample`. - "loss": Forward and return a dict of losses according to the given inputs and data samples. - Note that this method doesn't handle neither back propagation nor - optimizer updating, which are done in the :meth:`train_step`. - Args: inputs (torch.Tensor): The input tensor with shape (N, C, ...) in general. diff --git a/mmpretrain/models/selfsup/__init__.py b/mmpretrain/models/selfsup/__init__.py index e69de29bb2d..3357c29e9f1 100644 --- a/mmpretrain/models/selfsup/__init__.py +++ b/mmpretrain/models/selfsup/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .base import BaseSelfSupervisor +from .beit import VQKD, BEiTPretrainViT +from .cae import CAEViT, Encoder +from .mae import MAEViT +from .maskfeat import HOGGenerator, MaskFeatViT +from .milan import CLIPGenerator, MILANViT +from .mixmim import MixMIMPretrainTransformer +from .mocov3 import MoCoV3ViT +from .simmim import SimMIMSwinTransformer + +__all__ = [ + 'BaseSelfSupervisor', + 'BEiTPretrainViT', + 'VQKD', + 'CAEViT', + 'Encoder', + 'MAEViT', + 'HOGGenerator', + 'MaskFeatViT', + 'CLIPGenerator', + 'MILANViT', + 'MixMIMPretrainTransformer', + 'MoCoV3ViT', + 'SimMIMSwinTransformer', +] diff --git a/mmpretrain/models/selfsup/base.py b/mmpretrain/models/selfsup/base.py new file mode 100644 index 00000000000..d9a169af173 --- /dev/null +++ b/mmpretrain/models/selfsup/base.py @@ -0,0 +1,164 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Union + +import torch +from mmengine.model import BaseModel +from torch import nn + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +class BaseSelfSupervisor(BaseModel, metaclass=ABCMeta): + """BaseModel for Self-Supervised Learning. + + All self-supervised algorithms should inherit this module. + + Args: + backbone (dict): The backbone module. See + :mod:`mmpretrain.models.backbones`. + neck (dict, optional): The neck module to process features from + backbone. See :mod:`mmpretrain.models.necks`. Defaults to None. + head (dict, optional): The head module to do prediction and calculate + loss from processed features. See :mod:`mmpretrain.models.heads`. + Notice that if the head is not set, almost all methods cannot be + used except :meth:`extract_feat`. Defaults to None. + target_generator: (dict, optional): The target_generator module to + generate targets for self-supervised learning optimization, such as + HOG, extracted features from other modules(DALL-E, CLIP), etc. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. + Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. 
+ """ + + def __init__(self, + backbone: dict, + neck: Optional[dict] = None, + head: Optional[dict] = None, + target_generator: Optional[dict] = None, + pretrained: Optional[str] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): + if pretrained is not None: + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + data_preprocessor = data_preprocessor or {} + if isinstance(data_preprocessor, dict): + data_preprocessor.setdefault('type', 'SelfSupDataPreprocessor') + data_preprocessor = MODELS.build(data_preprocessor) + elif not isinstance(data_preprocessor, nn.Module): + raise TypeError('data_preprocessor should be a `dict` or ' + f'`nn.Module` instance, but got ' + f'{type(data_preprocessor)}') + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if not isinstance(backbone, nn.Module): + backbone = MODELS.build(backbone) + if neck is not None and not isinstance(neck, nn.Module): + neck = MODELS.build(neck) + if head is not None and not isinstance(head, nn.Module): + head = MODELS.build(head) + if target_generator is not None and not isinstance( + target_generator, nn.Module): + target_generator = MODELS.build(target_generator) + + self.backbone = backbone + self.neck = neck + self.head = head + self.target_generator = target_generator + + @property + def with_neck(self) -> bool: + """Check if the model has a neck module.""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_head(self) -> bool: + """Check if the model has a head module.""" + return hasattr(self, 'head') and self.head is not None + + @property + def with_target_generator(self) -> bool: + """Check if the model has a target_generator module.""" + return hasattr( + self, 'target_generator') and self.target_generator is not None + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[DataSample]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the backbone network and return the feature + tensor(s) tensor without any post-processing, same as a common + PyTorch Module. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The other data of + every samples. It's required for some algorithms + if ``mode="loss"``. Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'tensor': + feats = self.extract_feat(inputs) + return feats + elif mode == 'loss': + return self.loss(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + @abstractmethod + def extract_feat(self, inputs: torch.Tensor): + """Extract features from the input tensor with shape (N, C, ...). + + The sub-classes are recommended to implement this method to extract + features from backbone and neck. + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + + Returns: + tuple | Tensor: The output feature tensor(s). 
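To make this contract concrete, a minimal sketch of a subclass follows; the ToyMIM name, its loss computation and the use of self.head.loss are hypothetical and only show where extract_feat and loss fit, not an algorithm added by this change.

import torch
from mmpretrain.registry import MODELS
from mmpretrain.models.selfsup import BaseSelfSupervisor


@MODELS.register_module()
class ToyMIM(BaseSelfSupervisor):
    """Hypothetical algorithm, shown only to illustrate the base class."""

    def extract_feat(self, inputs: torch.Tensor):
        # Backbone (plus optional neck) produces the feature tuple.
        feats = self.backbone(inputs)
        return self.neck(feats) if self.with_neck else feats

    def loss(self, inputs: torch.Tensor, data_samples) -> dict:
        feats = self.extract_feat(inputs)
        # A real algorithm would build targets here (e.g. with
        # self.target_generator) and let the head compute the loss.
        loss = self.head.loss(feats, data_samples)
        return dict(loss=loss)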
+ """ + raise NotImplementedError + + @abstractmethod + def loss(self, inputs: torch.Tensor, + data_samples: List[DataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. + + This is an abstract method, and subclasses should overwrite this method + if needed. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample]): The annotation data of + every sample. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + raise NotImplementedError diff --git a/mmpretrain/models/selfsup/beit.py b/mmpretrain/models/selfsup/beit.py new file mode 100644 index 00000000000..57138ece836 --- /dev/null +++ b/mmpretrain/models/selfsup/beit.py @@ -0,0 +1,280 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Tuple, Union + +import torch +from einops import rearrange +from mmengine.model import BaseModule +from mmengine.model.weight_init import trunc_normal_ +from torch import nn + +from mmpretrain.models import BEiTViT +from mmpretrain.models.utils import NormEMAVectorQuantizer, resize_pos_embed +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class VQKD(BaseModule): + """Vector-Quantized Knowledge Distillation. + + The module only contains the encoder and VectorQuantizer parts. + Modified from https://github.com/microsoft/unilm/blob/master/beit2/modeling_vqkd.py + + Args: + encoder_config (dict): The config of the encoder. + decoder_config (dict, optional): The config of the decoder. Currently, + VQKD only supports building the encoder. Defaults to None. + num_embed (int): Number of embedding vectors in the codebook. Defaults + to 8192. + embed_dims (int): The dimension of embedding vectors in the codebook. + Defaults to 32. + decay (float): The decay parameter of EMA. Defaults to 0.99. + beta (float): The multiplier for the VectorQuantizer loss. Defaults to 1. + quantize_kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None.
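As a usage sketch for the tokenizer above: given a VQKD instance already built from the BEiT v2 tokenizer config and loaded with its pretrained weights (neither shown here, so treat the setup as assumed), the pre-training targets come from get_tokens defined below.

import torch
from mmpretrain.models.selfsup import VQKD


def make_beit_targets(vqkd: VQKD, imgs: torch.Tensor) -> torch.Tensor:
    """Return the discrete codebook index of every patch in ``imgs``."""
    with torch.no_grad():
        out = vqkd.get_tokens(imgs)
    return out['token']  # shape (B, num_patches), integer codebook ids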
+ """ # noqa: E501 + + def __init__(self, + encoder_config: dict, + decoder_config: Optional[dict] = None, + num_embed: int = 8192, + embed_dims: int = 32, + decay: float = 0.99, + beta: float = 1.0, + quantize_kmeans_init: bool = True, + init_cfg: Optional[dict] = None) -> None: + super().__init__(init_cfg=init_cfg) + + self.encoder = BEiTViT(**encoder_config) + if decoder_config is not None: + self.decoder = BEiTViT(**decoder_config) + + self.quantize = NormEMAVectorQuantizer( + num_embed=num_embed, + embed_dims=embed_dims, + beta=beta, + decay=decay, + kmeans_init=quantize_kmeans_init, + ) + + # task layer + self.encode_task_layer = nn.Sequential( + nn.Linear(self.encoder.arch_settings['embed_dims'], + self.encoder.arch_settings['embed_dims']), nn.Tanh(), + nn.Linear(self.encoder.arch_settings['embed_dims'], embed_dims)) + + def get_tokens(self, x: torch.Tensor) -> dict: + """Get tokens for beit pre-training.""" + _, embed_ind, _ = self.encode(x) + output = {} + output['token'] = embed_ind.view(x.shape[0], -1) + output['input_img'] = x + + return output + + def encode( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode the input images and get corresponding results.""" + encoder_features = self.encoder(x)[0] + B, C, N1, N2 = encoder_features.shape + encoder_features = encoder_features.permute(0, 2, 3, + 1).reshape(B, N1 * N2, C) + + with torch.cuda.amp.autocast(enabled=False): + to_quantizer_features = self.encode_task_layer( + encoder_features.type_as(self.encode_task_layer[-1].weight)) + + N = to_quantizer_features.shape[1] + h, w = int(math.sqrt(N)), int(math.sqrt(N)) + + to_quantizer_features = rearrange( + to_quantizer_features, 'b (h w) c -> b c h w', h=h, + w=w) # reshape for quantizer + quantize, loss, embed_ind = self.quantize(to_quantizer_features) + + return quantize, embed_ind, loss + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The forward function. + + Currently, only support to get tokens. + """ + return self.get_tokens(x)['token'] + + +@MODELS.register_module() +class BEiTPretrainViT(BEiTViT): + """Vision Transformer for BEiT pre-training. + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. 
+ with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + avg_token (bool): Whether or not to use the mean patch token for + classification. If True, the model will only take the average + of all patch tokens. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + output_cls_token (bool): Whether output the cls_token. If set True, + ``with_cls_token`` must be True. Defaults to True. + use_abs_pos_emb (bool): Whether or not use absolute position embedding. + Defaults to False. + use_rel_pos_bias (bool): Whether or not use relative position bias. + Defaults to False. + use_shared_rel_pos_bias (bool): Whether or not use shared relative + position bias. Defaults to True. + layer_scale_init_value (float): The initialization value for + the learnable scaling of attention and FFN. Defaults to 0.1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: str = 'base', + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + out_indices: int = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + avg_token: bool = False, + frozen_stages: int = -1, + output_cls_token: bool = True, + use_abs_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = True, + layer_scale_init_value: int = 0.1, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(padding=0), + layer_cfgs: dict = dict(), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + avg_token=avg_token, + frozen_stages=frozen_stages, + output_cls_token=output_cls_token, + use_abs_pos_emb=use_abs_pos_emb, + use_shared_rel_pos_bias=use_shared_rel_pos_bias, + use_rel_pos_bias=use_rel_pos_bias, + layer_scale_init_value=layer_scale_init_value, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + trunc_normal_(self.cls_token, std=0.02) + trunc_normal_(self.mask_token, std=0.02) + self.rescale_init_weight() + + def rescale_init_weight(self) -> None: + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.layers): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, layer_id + 1) + + def forward(self, x: torch.Tensor, + mask: torch.Tensor) -> Tuple[torch.Tensor]: + """The BEiT style forward function. 
+ + Args: + x (torch.Tensor): Input images, which is of shape (B x C x H x W). + mask (torch.Tensor): Mask for input, which is of shape + (B x patch_resolution[0] x patch_resolution[1]). + + Returns: + Tuple[torch.Tensor]: Hidden features. + """ + x, patch_resolution = self.patch_embed(x) + + # replace the masked visual tokens by mask_token + B, L, _ = x.shape + mask_token = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) + x = x * (1. - w) + mask_token * w + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + self.shared_rel_pos_bias = self.rel_pos_bias().to( + mask.device) if self.rel_pos_bias is not None else None + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, rel_pos_bias=self.shared_rel_pos_bias) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.norm1(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmpretrain/models/selfsup/cae.py b/mmpretrain/models/selfsup/cae.py new file mode 100644 index 00000000000..301dcdca6f0 --- /dev/null +++ b/mmpretrain/models/selfsup/cae.py @@ -0,0 +1,348 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Part of code is modified from BEiT +# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py +import math +from collections import OrderedDict +from functools import partial +from typing import List, Optional, Union + +import attr +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models import VisionTransformer +from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer +from mmpretrain.registry import MODELS +from ..utils import build_2d_sincos_position_embedding + + +@attr.s(eq=False) +class Conv2d(nn.Module): + n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) + n_out: int = attr.ib(validator=lambda i, a, x: x >= 1) + kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1) + + use_float16: bool = attr.ib(default=True) + device: torch.device = attr.ib(default=torch.device('cpu')) + requires_grad: bool = attr.ib(default=False) + + def __attrs_post_init__(self) -> None: + super().__init__() + + w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), + dtype=torch.float32, + device=self.device, + requires_grad=self.requires_grad) + w.normal_(std=1 / math.sqrt(self.n_in * self.kw**2)) + + b = torch.zeros((self.n_out, ), + dtype=torch.float32, + device=self.device, + requires_grad=self.requires_grad) + self.w, self.b = nn.Parameter(w), nn.Parameter(b) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_float16 and 'cuda' in self.w.device.type: + if x.dtype != torch.float16: + x = x.half() + + w, b = self.w.half(), self.b.half() + else: + if x.dtype != torch.float32: + x = x.float() + + w, b = self.w, self.b + + return F.conv2d(x, w, b, padding=(self.kw - 1) // 2) + + +@attr.s(eq=False, repr=False) +class EncoderBlock(nn.Module): + n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) + n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0) + n_layers: int = attr.ib(validator=lambda i, a, x: x 
>= 1) + + device: torch.device = attr.ib(default=None) + requires_grad: bool = attr.ib(default=False) + + def __attrs_post_init__(self) -> None: + super().__init__() + self.n_hid = self.n_out // 4 + self.post_gain = 1 / (self.n_layers**2) + + make_conv = partial( + Conv2d, device=self.device, requires_grad=self.requires_grad) + self.id_path = make_conv( + self.n_in, self.n_out, + 1) if self.n_in != self.n_out else nn.Identity() + self.res_path = nn.Sequential( + OrderedDict([ + ('relu_1', nn.ReLU()), + ('conv_1', make_conv(self.n_in, self.n_hid, 3)), + ('relu_2', nn.ReLU()), + ('conv_2', make_conv(self.n_hid, self.n_hid, 3)), + ('relu_3', nn.ReLU()), + ('conv_3', make_conv(self.n_hid, self.n_hid, 3)), + ('relu_4', nn.ReLU()), + ('conv_4', make_conv(self.n_hid, self.n_out, 1)), + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.id_path(x) + self.post_gain * self.res_path(x) + + +@attr.s(eq=False, repr=False) +@MODELS.register_module(name='DALL-E') +class Encoder(BaseModule): + group_count: int = 4 + n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64) + n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1) + input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1) + vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512) + + device: torch.device = attr.ib(default=torch.device('cpu')) + requires_grad: bool = attr.ib(default=False) + use_mixed_precision: bool = attr.ib(default=True) + init_cfg: Optional[Union[dict, List[dict]]] = attr.ib(default=None) + + def __attrs_post_init__(self) -> None: + super().__init__(init_cfg=self.init_cfg) + + blk_range = range(self.n_blk_per_group) + n_layers = self.group_count * self.n_blk_per_group + make_conv = partial( + Conv2d, device=self.device, requires_grad=self.requires_grad) + make_blk = partial( + EncoderBlock, + n_layers=n_layers, + device=self.device, + requires_grad=self.requires_grad) + + self.blocks = nn.Sequential( + OrderedDict([ + ('input', make_conv(self.input_channels, 1 * self.n_hid, 7)), + ('group_1', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk(1 * self.n_hid, 1 * self.n_hid)) + for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_2', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk( + 1 * self.n_hid if i == 0 else 2 * self.n_hid, + 2 * self.n_hid)) for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_3', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk( + 2 * self.n_hid if i == 0 else 4 * self.n_hid, + 4 * self.n_hid)) for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_4', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk( + 4 * self.n_hid if i == 0 else 8 * self.n_hid, + 8 * self.n_hid)) for i in blk_range], + ]))), + ('output', + nn.Sequential( + OrderedDict([ + ('relu', nn.ReLU()), + ('conv', + make_conv( + 8 * self.n_hid, + self.vocab_size, + 1, + use_float16=False)), + ]))), + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.float() + if len(x.shape) != 4: + raise ValueError(f'input shape {x.shape} is not 4d') + if x.shape[1] != self.input_channels: + raise ValueError(f'input has {x.shape[1]} channels but model \ + built for {self.input_channels}') + if x.dtype != torch.float32: + raise ValueError('input must have dtype torch.float32') + + return self.blocks(x) + + +@MODELS.register_module() +class CAEViT(VisionTransformer): + """Vision Transformer for CAE 
pre-training. + + Rewritten version of: `An Image is Worth 16x16 Words: Transformers + for Image Recognition at Scale `_ + + Args: + arch (str | dict): Vision Transformer architecture. Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + bias (bool | str): The option to add leanable bias for q, k, v. If bias + is True, it will add leanable bias. If bias is 'qv_bias', it will + only add leanable bias for q, v. If bias is False, it will not add + bias for q, k, v. Default to 'qv_bias'. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + output_cls_token (bool): Whether output the cls_token. If set True, + `with_cls_token` must be True. Defaults to True. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + layer_scale_init_value (float, optional): The init value of gamma in + BEiTTransformerEncoderLayer. + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: str = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: int = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + bias: bool = 'qv_bias', + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + output_cls_token: bool = True, + interpolate_mode: str = 'bicubic', + layer_scale_init_value: float = None, + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: dict = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + output_cls_token=output_cls_token, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + self.pos_embed.requires_grad = False + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + dpr = np.linspace(0, drop_path_rate, self.num_layers) + + # Replace original TransformerEncoderLayer with + # BEiTTransformerEncoderLayer + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. 
+ arch_settings['feedforward_channels'], + layer_scale_init_value=layer_scale_init_value, + window_size=None, + # setting `use_rel_pos_bias` to False ignores the `window_size` + use_rel_pos_bias=False, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + bias=bias, + norm_cfg=norm_cfg) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg)) + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # initialize position embedding in backbone + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=True) + self.pos_embed.data.copy_(pos_embed.float()) + + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m) -> None: + """Initialize the weights.""" + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate features for masked images. + + This function generates mask images and get the hidden features for + visible patches. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (torch.Tensor): Mask for input, which is of shape B x L. + + Returns: + torch.Tensor: hidden features. + """ + x, _ = self.patch_embed(img) + batch_size, _, dim = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + + # NOTE: unmasked embeddings + x_unmasked = x[~mask].reshape(batch_size, -1, dim) + x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1) + + pos_embed = self.pos_embed.expand(batch_size, self.num_patches + 1, + dim) + pos_embed_unmasked = pos_embed[:, + 1:][~mask].reshape(batch_size, -1, dim) + pos_embed_unmasked = torch.cat((pos_embed[:, :1], pos_embed_unmasked), + dim=1) + x_unmasked = x_unmasked + pos_embed_unmasked + + x_unmasked = self.drop_after_pos(x_unmasked) + + for i, layer in enumerate(self.layers): + x_unmasked = layer(x=x_unmasked, rel_pos_bias=None) + + if i == len(self.layers) - 1 and self.final_norm: + x_unmasked = self.norm1(x_unmasked) + + return x_unmasked diff --git a/mmpretrain/models/selfsup/mae.py b/mmpretrain/models/selfsup/mae.py new file mode 100644 index 00000000000..506c1a94099 --- /dev/null +++ b/mmpretrain/models/selfsup/mae.py @@ -0,0 +1,178 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple, Union + +import torch + +from mmpretrain.models import VisionTransformer +from mmpretrain.registry import MODELS +from ..utils import build_2d_sincos_position_embedding + + +@MODELS.register_module() +class MAEViT(VisionTransformer): + """Vision Transformer for MAE pre-training. + + A PyTorch implement of: `An Image is Worth 16x16 Words: Transformers + for Image Recognition at Scale `_. + This module implements the patch masking in MAE and initialize the + position embedding with sine-cosine position embedding. 
+ + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + output_cls_token (bool): Whether output the cls_token. If set True, + `with_cls_token` must be True. Defaults to True. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: Union[Sequence, int] = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + output_cls_token: bool = True, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + mask_ratio: float = 0.75, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + output_cls_token=output_cls_token, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + # position embedding is not learnable during pretraining + self.pos_embed.requires_grad = False + self.mask_ratio = mask_ratio + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.pos_embed.shape[-1], + cls_token=True) + self.pos_embed.data.copy_(pos_embed.float()) + + w = self.patch_embed.projection.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + torch.nn.init.normal_(self.cls_token, std=.02) + + def random_masking( + self, + x: torch.Tensor, + mask_ratio: float = 0.75 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the mask for MAE Pre-training. + + Args: + x (torch.Tensor): Image with data augmentation applied, which is + of shape B x L x C. + mask_ratio (float): The mask ratio of total patches. + Defaults to 0.75. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + masked image, mask and the ids to restore original image. + - x_masked (torch.Tensor): masked image. + - mask (torch.Tensor): mask used to mask image. + - ids_restore (torch.Tensor): ids to restore original image. 
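To illustrate the returned shapes listed above, a small sketch with the standard ViT-B settings (196 patches of dimension 768); the concrete numbers are only examples.

import torch
from mmpretrain.models.selfsup import MAEViT

vit = MAEViT(arch='b', patch_size=16, mask_ratio=0.75)
x = torch.randn(2, 196, 768)               # (B, L, C) patch embeddings
x_masked, mask, ids_restore = vit.random_masking(x, mask_ratio=0.75)
# x_masked:    (2, 49, 768) -> only the 25% of patches that are kept
# mask:        (2, 196)     -> 0 = kept, 1 = removed, in the original order
# ids_restore: (2, 196)     -> indices that undo the random shuffle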
+ """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather( + x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + This function generates mask and masks some patches randomly and get + the hidden features for visible patches. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + Hidden features, mask and the ids to restore original image. + + - x (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - mask (torch.Tensor): mask used to mask image. + - ids_restore (torch.Tensor): ids to restore original image. + """ + B = x.shape[0] + x = self.patch_embed(x)[0] + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, self.mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + for _, layer in enumerate(self.layers): + x = layer(x) + # Use final norm + x = self.norm1(x) + + return (x, mask, ids_restore) diff --git a/mmpretrain/models/selfsup/maskfeat.py b/mmpretrain/models/selfsup/maskfeat.py new file mode 100644 index 00000000000..6c799d91359 --- /dev/null +++ b/mmpretrain/models/selfsup/maskfeat.py @@ -0,0 +1,275 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Sequence, Union + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmpretrain.models import VisionTransformer +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class HOGGenerator(BaseModule): + """Generate HOG feature for images. + + This module is used in MaskFeat to generate HOG feature. The code is + modified from file `slowfast/models/operators.py + `_. + Here is the link of `HOG wikipedia + `_. + + Args: + nbins (int): Number of bin. Defaults to 9. + pool (float): Number of cell. Defaults to 8. + gaussian_window (int): Size of gaussian kernel. Defaults to 16. 
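A quick sketch of how the generator produces regression targets: with the defaults above and a 224x224 input, every 16x16 patch ends up described by a 2x2 grid of cells x 3 channels x 9 bins = 108 values (the numbers follow from the defaults and are shown only for orientation).

import torch
from mmpretrain.models.selfsup import HOGGenerator

hog = HOGGenerator(nbins=9, pool=8, gaussian_window=16)
imgs = torch.randn(2, 3, 224, 224)
targets = hog(imgs)           # no gradients are tracked inside forward
print(targets.shape)          # torch.Size([2, 196, 108])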
+ """ + + def __init__(self, + nbins: int = 9, + pool: int = 8, + gaussian_window: int = 16) -> None: + super().__init__() + self.nbins = nbins + self.pool = pool + self.pi = math.pi + weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) + weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous() + weight_y = weight_x.transpose(2, 3).contiguous() + self.register_buffer('weight_x', weight_x) + self.register_buffer('weight_y', weight_y) + + self.gaussian_window = gaussian_window + if gaussian_window: + gaussian_kernel = self.get_gaussian_kernel(gaussian_window, + gaussian_window // 2) + self.register_buffer('gaussian_kernel', gaussian_kernel) + + def get_gaussian_kernel(self, kernlen: int, std: int) -> torch.Tensor: + """Returns a 2D Gaussian kernel array.""" + + def _gaussian_fn(kernlen: int, std: int) -> torch.Tensor: + n = torch.arange(0, kernlen).float() + n -= n.mean() + n /= std + w = torch.exp(-0.5 * n**2) + return w + + kernel_1d = _gaussian_fn(kernlen, std) + kernel_2d = kernel_1d[:, None] * kernel_1d[None, :] + return kernel_2d / kernel_2d.sum() + + def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor: + """Reshape HOG Features for output.""" + hog_feat = hog_feat.flatten(1, 2) + self.unfold_size = hog_feat.shape[-1] // 14 + hog_feat = hog_feat.permute(0, 2, 3, 1) + hog_feat = hog_feat.unfold(1, self.unfold_size, + self.unfold_size).unfold( + 2, self.unfold_size, self.unfold_size) + hog_feat = hog_feat.flatten(1, 2).flatten(2) + return hog_feat + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Generate hog feature for each batch images. + + Args: + x (torch.Tensor): Input images of shape (N, 3, H, W). + + Returns: + torch.Tensor: Hog features. + """ + # input is RGB image with shape [B 3 H W] + self.h, self.w = x.size(-2), x.size(-1) + x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect') + gx_rgb = F.conv2d( + x, self.weight_x, bias=None, stride=1, padding=0, groups=3) + gy_rgb = F.conv2d( + x, self.weight_y, bias=None, stride=1, padding=0, groups=3) + norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1) + phase = torch.atan2(gx_rgb, gy_rgb) + phase = phase / self.pi * self.nbins # [-9, 9] + + b, c, h, w = norm_rgb.shape + out = torch.zeros((b, c, self.nbins, h, w), + dtype=torch.float, + device=x.device) + phase = phase.view(b, c, 1, h, w) + norm_rgb = norm_rgb.view(b, c, 1, h, w) + if self.gaussian_window: + if h != self.gaussian_window: + assert h % self.gaussian_window == 0, 'h {} gw {}'.format( + h, self.gaussian_window) + repeat_rate = h // self.gaussian_window + temp_gaussian_kernel = self.gaussian_kernel.repeat( + [repeat_rate, repeat_rate]) + else: + temp_gaussian_kernel = self.gaussian_kernel + norm_rgb *= temp_gaussian_kernel + + out.scatter_add_(2, phase.floor().long() % self.nbins, norm_rgb) + + out = out.unfold(3, self.pool, self.pool) + out = out.unfold(4, self.pool, self.pool) + out = out.sum(dim=[-1, -2]) + + self.out = F.normalize(out, p=2, dim=2) + + return self._reshape(self.out) + + def generate_hog_image(self, hog_out: torch.Tensor) -> np.ndarray: + """Generate HOG image according to HOG features.""" + assert hog_out.size(0) == 1 and hog_out.size(1) == 3, \ + 'Check the input batch size and the channcel number, only support'\ + '"batch_size = 1".' 
+ hog_image = np.zeros([self.h, self.w]) + cell_gradient = np.array(hog_out.mean(dim=1).squeeze().detach().cpu()) + cell_width = self.pool / 2 + max_mag = np.array(cell_gradient).max() + angle_gap = 360 / self.nbins + + for x in range(cell_gradient.shape[1]): + for y in range(cell_gradient.shape[2]): + cell_grad = cell_gradient[:, x, y] + cell_grad /= max_mag + angle = 0 + for magnitude in cell_grad: + angle_radian = math.radians(angle) + x1 = int(x * self.pool + + magnitude * cell_width * math.cos(angle_radian)) + y1 = int(y * self.pool + + magnitude * cell_width * math.sin(angle_radian)) + x2 = int(x * self.pool - + magnitude * cell_width * math.cos(angle_radian)) + y2 = int(y * self.pool - + magnitude * cell_width * math.sin(angle_radian)) + magnitude = 0 if magnitude < 0 else magnitude + cv2.line(hog_image, (y1, x1), (y2, x2), + int(255 * math.sqrt(magnitude))) + angle += angle_gap + return hog_image + + +@MODELS.register_module() +class MaskFeatViT(VisionTransformer): + """Vision Transformer for MaskFeat pre-training. + + A PyTorch implement of: `Masked Feature Prediction for Self-Supervised + Visual Pre-Training `_. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + output_cls_token (bool): Whether output the cls_token. If set True, + `with_cls_token` must be True. Defaults to True. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
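For orientation, a small sketch of the masked forward pass defined below; the shapes assume the default ViT-B settings and the random boolean mask is only an example.

import torch
from mmpretrain.models.selfsup import HOGGenerator, MaskFeatViT

model = MaskFeatViT(arch='b', patch_size=16)
imgs = torch.randn(2, 3, 224, 224)
mask = torch.rand(2, 14, 14) > 0.6        # True marks a masked 16x16 patch
feats = model(imgs, mask)                 # (2, 197, 768), cls token included
hog_targets = HOGGenerator()(imgs)        # (2, 196, 108) regression targets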
+ """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: Union[Sequence, int] = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + output_cls_token: bool = True, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + output_cls_token=output_cls_token, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + self.mask_token = nn.parameter.Parameter( + torch.zeros(1, 1, self.embed_dims), requires_grad=True) + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + """Initialize position embedding, mask token and cls token.""" + super().init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + + nn.init.trunc_normal_(self.cls_token, std=.02) + nn.init.trunc_normal_(self.mask_token, std=.02) + nn.init.trunc_normal_(self.pos_embed, std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m: torch.nn.Module) -> None: + if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate features for masked images. + + Args: + x (torch.Tensor): Input images. + mask (torch.Tensor): Input masks. + + Returns: + torch.Tensor: Features with cls_tokens. + """ + B = x.shape[0] + x = self.patch_embed(x)[0] + + # masking: length -> length * mask_ratio + B, L, _ = x.shape + mask_tokens = self.mask_token.expand(B, L, -1) + mask = mask.flatten(1).unsqueeze(-1) + x = x * (1 - mask.int()) + mask_tokens * mask + + # append cls token + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.drop_after_pos(x) + + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.norm1(x) + + return x diff --git a/mmpretrain/models/selfsup/milan.py b/mmpretrain/models/selfsup/milan.py new file mode 100644 index 00000000000..585356b7908 --- /dev/null +++ b/mmpretrain/models/selfsup/milan.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from mmengine.model import BaseModule +from mmengine.runner.checkpoint import _load_checkpoint + +from mmpretrain.registry import MODELS +from ..utils import build_clip_model +from .mae import MAEViT + + +@MODELS.register_module() +class CLIPGenerator(BaseModule): + """Get the features and attention from the last layer of CLIP. + + This module is used to generate target features in masked image modeling. + + Args: + tokenizer_path (str): The path of the checkpoint of CLIP. 
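A usage sketch for the target generator; the checkpoint path below is a hypothetical placeholder and has to point at a real CLIP weight file for the call to succeed.

import torch
from mmpretrain.models.selfsup import CLIPGenerator

# Placeholder path -- replace with an actual CLIP checkpoint.
clip_gen = CLIPGenerator(tokenizer_path='path/to/clip_vit_base_16.pth')
imgs = torch.randn(2, 3, 224, 224)
targets = clip_gen(imgs)  # last-layer CLIP features (and attention) used by MILAN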
+ """ + + def __init__(self, tokenizer_path: str) -> None: + super().__init__() + self.tokenizer = build_clip_model( + _load_checkpoint(tokenizer_path), False) + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the features and attention from the last layer of CLIP. + + Args: + x (torch.Tensor): The input image, which is of shape (N, 3, H, W). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + The features and attention from the last layer of CLIP, + which are of shape (N, L, C) and (N, L, L), respectively. + """ + # use the visual branch of CLIP to get the features + clip_features = self.tokenizer.encode_image(x) + return clip_features + + +@MODELS.register_module() +class MILANViT(MAEViT): + """Vision Transformer for MILAN pre-training. + + Implementation of the encoder for `MILAN: Masked Image Pretraining on + Language Assisted Representation `_. + + This module inherits from MAEViT and only overrides the forward function + and replace random masking with attention masking. + """ + + def attention_masking( + self, x: torch.Tensor, mask_ratio: float, importance: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate attention mask for MILAN. + + This is what is different from MAEViT, which uses random masking. + Attention masking generates attention mask for MILAN, according to + importance. The higher the importance, the more likely the patch is + kept. + + Args: + x (torch.Tensor): Input images, which is of shape B x L x C. + mask_ratio (float): The ratio of patches to be masked. + importance (torch.Tensor): Importance of each patch, which is of + shape B x L. + + Returns: + Tuple[torch.Tensor, ...]: + - ``x_masked``: masked image + - ``ids_restore``: the ids to restore original image + - ``ids_keep``: ids of the kept patches + - ``ids_dump``: ids of the removed patches + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = importance.to(x.device) # large is keep, small is remove + + # sort noise for each sample + ids_shuffle = torch.multinomial(noise, L, replacement=False) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + ids_dump = ids_shuffle[:, len_keep:] + x_masked = torch.gather( + x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, ids_restore, ids_keep, ids_dump + + def forward( + self, x: torch.Tensor, importance: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + This function generates mask and masks some patches randomly and get + the hidden features for visible patches. The mask is generated by + importance. The higher the importance, the more likely the patch is + kept. The importance is calculated by CLIP. The higher the CLIP score, + the more likely the patch is kept. The CLIP score is calculated by + by cross attention between the class token and all other tokens from + the last layer. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + importance (torch.Tensor): Importance of each patch, which is of + shape B x L. 
+ + Returns: + Tuple[torch.Tensor, ...]: + masked image, the ids to restore original image, ids of the + kept patches, ids of the removed patches. + - x (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ids_restore (torch.Tensor): ids to restore original image. + - ids_keep (torch.Tensor): ids of the kept patches. + - ids_dump (torch.Tensor): ids of the removed patches. + """ + B = x.shape[0] + x = self.patch_embed(x)[0] + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, ids_restore, ids_keep, ids_dump = self.attention_masking( + x, self.mask_ratio, importance) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + for _, layer in enumerate(self.layers): + x = layer(x) + # Use final norm + x = self.norm1(x) + + return x, ids_restore, ids_keep, ids_dump diff --git a/mmpretrain/models/selfsup/mixmim.py b/mmpretrain/models/selfsup/mixmim.py new file mode 100644 index 00000000000..943fcbd20e0 --- /dev/null +++ b/mmpretrain/models/selfsup/mixmim.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from typing import List, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from mmpretrain.models.backbones import MixMIMTransformer +from mmpretrain.registry import MODELS +from ..utils import build_2d_sincos_position_embedding + + +@MODELS.register_module() +class MixMIMPretrainTransformer(MixMIMTransformer): + """MixMIM backbone for MixMIM pre-training. + + A PyTorch implement of : ` MixMIM: Mixed and Masked Image + Modeling for Efficient Visual Representation Learning + `_ + + Args: + arch (str | dict): MixMIM architecture. If use string, + choose from 'base','large' and 'huge'. + If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + + Defaults to 'base'. + mlp_ratio (int): The mlp ratio in FFN. Defaults to 4. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to mlp_ratio + the most common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + window_size (list): The height and width of the window. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + attn_drop_rate (float): Attention drop rate. Defaults to 0. + use_checkpoint (bool): Whether use the checkpoint to + reduce GPU memory cost + range_mask_ratio (float): The range of mask ratio. + Defaults to 0. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
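A rough usage sketch; the output sizes assume the default 'base' architecture (roughly 7x7 tokens of width 1024 for a 224 input) and should be read as illustrative rather than exact for other settings.

import torch
from mmpretrain.models.selfsup import MixMIMPretrainTransformer

model = MixMIMPretrainTransformer(arch='base', range_mask_ratio=0.1)
imgs = torch.randn(2, 3, 224, 224)
x, mask_s4 = model(imgs, mask_ratio=0.5)
# x:       features of the mixed images at the last stage, about (2, 49, 1024)
# mask_s4: shared mask at the final stride, about (1, 49, 1)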
+ """ + + def __init__(self, + arch: Union[str, dict] = 'base', + mlp_ratio: float = 4, + img_size: int = 224, + patch_size: int = 4, + in_channels: int = 3, + window_size: List = [14, 14, 14, 7], + qkv_bias: bool = True, + patch_cfg: dict = dict(), + norm_cfg: dict = dict(type='LN'), + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + attn_drop_rate: float = 0.0, + use_checkpoint: bool = False, + range_mask_ratio: float = 0.0, + init_cfg: Optional[dict] = None) -> None: + + super().__init__( + arch=arch, + mlp_ratio=mlp_ratio, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + window_size=window_size, + qkv_bias=qkv_bias, + patch_cfg=patch_cfg, + norm_cfg=norm_cfg, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + attn_drop_rate=attn_drop_rate, + use_checkpoint=use_checkpoint, + init_cfg=init_cfg) + + self.range_mask_ratio = range_mask_ratio + + def init_weights(self): + """Initialize position embedding, patch embedding.""" + super(MixMIMTransformer, self).init_weights() + + pos_embed = build_2d_sincos_position_embedding( + int(self.num_patches**.5), + self.absolute_pos_embed.shape[-1], + cls_token=False) + self.absolute_pos_embed.data.copy_(pos_embed.float()) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def random_masking(self, x: torch.Tensor, mask_ratio: float = 0.5): + """Generate the mask for MixMIM Pretraining. + + Args: + x (torch.Tensor): Image with data augmentation applied, which is + of shape B x L x C. + mask_ratio (float): The mask ratio of total patches. + Defaults to 0.5. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + - mask_s1 (torch.Tensor): mask with stride of + self.encoder_stride // 8. + - mask_s2 (torch.Tensor): mask with stride of + self.encoder_stride // 4. + - mask_s3 (torch.Tensor): mask with stride of + self.encoder_stride // 2. + - mask (torch.Tensor): mask with stride of + self.encoder_stride. + """ + + B, C, H, W = x.shape + out_H = H // self.encoder_stride + out_W = W // self.encoder_stride + s3_H, s3_W = out_H * 2, out_W * 2 + s2_H, s2_W = out_H * 4, out_W * 4 + s1_H, s1_W = out_H * 8, out_W * 8 + + seq_l = out_H * out_W + # use a shared mask for a batch images + mask = torch.zeros([1, 1, seq_l], device=x.device) + + mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio) + noise = torch.rand(1, 1, seq_l, device=x.device) # noise in [0, 1] + # ascend: small is keep, large is removed + mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)] + mask.scatter_(2, mask_idx, 1) + mask = mask.reshape(1, 1, out_H, out_W) + mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest') + mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest') + mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest') + + mask = mask.reshape(1, out_H * out_W, 1).contiguous() + mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous() + mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous() + mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous() + + return mask_s1, mask_s2, mask_s3, mask + + def forward(self, x: torch.Tensor, mask_ratio=0.5): + """Generate features for masked images. 
+ + This function generates mask and masks some patches randomly and get + the hidden features for visible patches. + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - x (torch.Tensor): hidden features, which is of shape + B x L x C. + - mask_s4 (torch.Tensor): the mask tensor for the last layer. + """ + + mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking(x, mask_ratio) + + x, _ = self.patch_embed(x) + + x = x * (1. - mask_s1) + x.flip(0) * mask_s1 + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + for idx, layer in enumerate(self.layers): + if idx == 0: + x = layer(x, attn_mask=mask_s1) + elif idx == 1: + x = layer(x, attn_mask=mask_s2) + elif idx == 2: + x = layer(x, attn_mask=mask_s3) + elif idx == 3: + x = layer(x, attn_mask=mask_s4) + + x = self.norm(x) + + return x, mask_s4 diff --git a/mmpretrain/models/selfsup/mocov3.py b/mmpretrain/models/selfsup/mocov3.py new file mode 100644 index 00000000000..839a5a5be5e --- /dev/null +++ b/mmpretrain/models/selfsup/mocov3.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from functools import reduce +from operator import mul +from typing import List, Optional, Union + +import torch.nn as nn +from mmcv.cnn.bricks.transformer import PatchEmbed +from torch.nn.modules.batchnorm import _BatchNorm + +from mmpretrain.models.backbones import VisionTransformer +from mmpretrain.models.utils import (build_2d_sincos_position_embedding, + to_2tuple) +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class MoCoV3ViT(VisionTransformer): + """Vision Transformer for MoCoV3 pre-training. + + A pytorch implement of: `An Images is Worth 16x16 Words: Transformers for + Image Recognition at Scale `_. + + Part of the code is modified from: + ``_. + + Args: + stop_grad_conv1 (bool): whether to stop the gradient of + convolution layer in `PatchEmbed`. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + stop_grad_conv1: bool = False, + frozen_stages: int = -1, + norm_eval: bool = False, + init_cfg: Optional[Union[dict, List[dict]]] = None, + **kwargs) -> None: + + # add MoCoV3 ViT-small arch + self.arch_zoo.update( + dict.fromkeys( + ['mocov3-s', 'mocov3-small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': 1536, + })) + + super().__init__(init_cfg=init_cfg, **kwargs) + self.patch_size = kwargs['patch_size'] + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + self.init_cfg = init_cfg + + if isinstance(self.patch_embed, PatchEmbed): + if stop_grad_conv1: + self.patch_embed.projection.weight.requires_grad = False + self.patch_embed.projection.bias.requires_grad = False + + self._freeze_stages() + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding, qkv layers and cls + token.""" + super().init_weights() + + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + + # Use fixed 2D sin-cos position embedding + pos_emb = build_2d_sincos_position_embedding( + patches_resolution=self.patch_resolution, + embed_dims=self.embed_dims, + cls_token=True) + self.pos_embed.data.copy_(pos_emb) + self.pos_embed.requires_grad = False + + # xavier_uniform initialization for PatchEmbed + if isinstance(self.patch_embed, PatchEmbed): + val = math.sqrt( + 6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) + + self.embed_dims)) + nn.init.uniform_(self.patch_embed.projection.weight, -val, val) + nn.init.zeros_(self.patch_embed.projection.bias) + + # initialization for linear layers + for name, m in self.named_modules(): + if isinstance(m, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt( + 6. / + float(m.weight.shape[0] // 3 + m.weight.shape[1])) + nn.init.uniform_(m.weight, -val, val) + else: + nn.init.xavier_uniform_(m.weight) + nn.init.zeros_(m.bias) + nn.init.normal_(self.cls_token, std=1e-6) + + def _freeze_stages(self) -> None: + """Freeze patch_embed layer, some parameters and stages.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + self.cls_token.requires_grad = False + self.pos_embed.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if i == (self.num_layers) and self.final_norm: + for param in getattr(self, 'norm1').parameters(): + param.requires_grad = False + + def train(self, mode: bool = True) -> None: + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmpretrain/models/selfsup/simmim.py b/mmpretrain/models/selfsup/simmim.py new file mode 100644 index 00000000000..cc7b096e558 --- /dev/null +++ b/mmpretrain/models/selfsup/simmim.py @@ -0,0 +1,154 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.models import SwinTransformer +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SimMIMSwinTransformer(SwinTransformer): + """Swin Transformer for SimMIM pre-training. + + Args: + Args: + arch (str | dict): Swin Transformer architecture + Defaults to 'T'. 
+        img_size (int | tuple): The size of input image.
+            Defaults to 224.
+        in_channels (int): The num of input channels.
+            Defaults to 3.
+        drop_rate (float): Dropout rate after embedding.
+            Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate.
+            Defaults to 0.1.
+        out_indices (tuple): Layers to be output. Defaults to (3, ).
+        use_abs_pos_embed (bool): If True, add absolute position embedding to
+            the patch embedding. Defaults to False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+            Defaults to False.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Defaults to -1.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Defaults to False.
+        norm_cfg (dict): Config dict for normalization layer at end
+            of backbone. Defaults to ``dict(type='LN')``.
+        stage_cfgs (Sequence | dict): Extra config dict for each
+            stage. Defaults to empty dict.
+        patch_cfg (dict): Extra config dict for patch embedding.
+            Defaults to empty dict.
+        pad_small_map (bool): If True, pad the small feature map to the window
+            size, which is commonly used in detection and segmentation. If
+            False, avoid shifting window and shrink the window size to the
+            size of feature map, which is commonly used in classification.
+            Defaults to False.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 arch: Union[str, dict] = 'T',
+                 img_size: Union[Tuple[int, int], int] = 224,
+                 in_channels: int = 3,
+                 drop_rate: float = 0.,
+                 drop_path_rate: float = 0.1,
+                 out_indices: tuple = (3, ),
+                 use_abs_pos_embed: bool = False,
+                 with_cp: bool = False,
+                 frozen_stages: int = -1,
+                 norm_eval: bool = False,
+                 norm_cfg: dict = dict(type='LN'),
+                 stage_cfgs: Union[Sequence, dict] = dict(),
+                 patch_cfg: dict = dict(),
+                 pad_small_map: bool = False,
+                 init_cfg: Optional[dict] = None) -> None:
+        super().__init__(
+            arch=arch,
+            img_size=img_size,
+            in_channels=in_channels,
+            drop_rate=drop_rate,
+            drop_path_rate=drop_path_rate,
+            out_indices=out_indices,
+            use_abs_pos_embed=use_abs_pos_embed,
+            with_cp=with_cp,
+            frozen_stages=frozen_stages,
+            norm_eval=norm_eval,
+            norm_cfg=norm_cfg,
+            stage_cfgs=stage_cfgs,
+            patch_cfg=patch_cfg,
+            pad_small_map=pad_small_map,
+            init_cfg=init_cfg)
+
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+
+    def init_weights(self) -> None:
+        """Initialize weights."""
+        super().init_weights()
+
+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            # Suppress default init if use pretrained model.
+            return
+
+        if self.use_abs_pos_embed:
+            trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+        trunc_normal_(self.mask_token, mean=0, std=.02)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        """Initialize weights."""
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def forward(self, x: torch.Tensor,
+                mask: torch.Tensor) -> Sequence[torch.Tensor]:
+        """Generate features for masked images.
+
+        This function generates masked images and gets their hidden
+        features.
+
+        Args:
+            x (torch.Tensor): Input images.
+ mask (torch.Tensor): Masks used to construct masked images. + + Returns: + tuple: A tuple containing features from multi-stages. + """ + x, hw_shape = self.patch_embed(x) + + assert mask is not None + B, L, _ = x.shape + + mask_token = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) + x = x * (1. - w) + mask_token * w + + if self.use_abs_pos_embed: + x = x + self.absolute_pos_embed + + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + stage.out_channels).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return tuple(outs) diff --git a/mmpretrain/models/target_generators/__init__.py b/mmpretrain/models/target_generators/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/mmpretrain/models/utils/__init__.py b/mmpretrain/models/utils/__init__.py index 61cd664852f..b5391a72a5f 100644 --- a/mmpretrain/models/utils/__init__.py +++ b/mmpretrain/models/utils/__init__.py @@ -5,6 +5,7 @@ ShiftWindowMSA, WindowMSA, WindowMSAV2) from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix from .channel_shuffle import channel_shuffle +from .clip_generator_helper import build_clip_model from .data_preprocessor import ClsDataPreprocessor from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed, resize_relative_position_bias_table) @@ -17,6 +18,7 @@ PositionEncodingFourier, build_2d_sincos_position_embedding) from .se_layer import SELayer +from .vector_quantizer import NormEMAVectorQuantizer __all__ = [ 'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer', @@ -28,5 +30,6 @@ 'LayerScale', 'WindowMSA', 'WindowMSAV2', 'ChannelMultiheadAttention', 'PositionEncodingFourier', 'LeAttention', 'GRN', 'LayerNorm2d', 'build_norm_layer', 'CrossMultiheadAttention', - 'build_2d_sincos_position_embedding', 'PromptMultiheadAttention' + 'build_2d_sincos_position_embedding', 'PromptMultiheadAttention', + 'NormEMAVectorQuantizer', 'build_clip_model' ] diff --git a/mmpretrain/models/utils/attention.py b/mmpretrain/models/utils/attention.py index cdfabffb70f..5b1e4b1af54 100644 --- a/mmpretrain/models/utils/attention.py +++ b/mmpretrain/models/utils/attention.py @@ -567,7 +567,7 @@ class BEiTAttention(BaseModule): Args: embed_dims (int): Number of input channels. num_heads (int): Number of attention heads. - window_size (tuple[int]): The height and width of the window. + window_size (tuple[int, int]): The height and width of the window. use_rel_pos_bias (bool): Whether to use unique relative position bias, if False, use shared relative position bias defined in backbone. bias (str): The option to add leanable bias for q, k, v. If bias is @@ -606,6 +606,10 @@ def __init__(self, self._init_qv_bias() qkv_bias = False + if window_size is None: + assert not use_rel_pos_bias + else: + assert isinstance(window_size, tuple) self.window_size = window_size self.use_rel_pos_bias = use_rel_pos_bias self._init_rel_pos_embedding() diff --git a/mmpretrain/models/utils/clip_generator_helper.py b/mmpretrain/models/utils/clip_generator_helper.py new file mode 100644 index 00000000000..80cc74a79eb --- /dev/null +++ b/mmpretrain/models/utils/clip_generator_helper.py @@ -0,0 +1,391 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
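+# This module rebuilds the CLIP image encoder from an official CLIP state
+# dict; ``build_clip_model`` at the bottom of the file returns the model in
+# eval mode so it can serve as a frozen feature extractor (e.g. as a target
+# generator during pre-training).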
+# Modified from https://github.com/zejiangh/MILAN +from collections import OrderedDict +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.logging import MMLogger +from torch import nn + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + """A faster version of GELU.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + """Residual Attention Block (RAB). + + This module implements the same function as the MultiheadAttention in + MMClassification, but with a different interface, which is mainly used + in CLIP. + + Args: + d_model (int): The feature dimension. + n_head (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. + Defaults to None. + """ + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: Optional[torch.Tensor] = None, + return_attention: bool = False) -> None: + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + self.return_attention = return_attention + + def attention(self, x: torch.Tensor) -> torch.Tensor: + """Attention function.""" + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + if self.return_attention: + return self.attn( + x, + x, + x, + need_weights=self.return_attention, + attn_mask=self.attn_mask) + else: + return self.attn( + x, + x, + x, + need_weights=self.return_attention, + attn_mask=self.attn_mask)[0] + + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Forward function.""" + if self.return_attention: + x_, attention = self.attention(self.ln_1(x)) + x = x + x_ + x = x + self.mlp(self.ln_2(x)) + return x, attention + else: + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + """Transformer. + + Both visual and text branches use this transformer. + + Args: + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + attn_mask (torch.Tensor, optional): The attention mask. 
+ """ + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: Optional[torch.Tensor] = None) -> None: + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList() + for _ in range(layers - 1): + self.resblocks.append( + ResidualAttentionBlock(width, heads, attn_mask)) + self.resblocks.append( + ResidualAttentionBlock( + width, heads, attn_mask, return_attention=True)) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward function.""" + z = [] + for idx, blk in enumerate(self.resblocks): + if idx < self.layers - 1: + x = blk(x) + z.append(x.permute(1, 0, 2)) + else: + x, attention = blk(x) + z.append(x.permute(1, 0, 2)) + return x, attention, z + + +class VisionTransformer(nn.Module): + """Vision Transformer for CLIP. + + Args: + input_resolution (int): The image size. + patch_size (int): The patch size. + width (int): The feature dimension. + layers (int): The number of layers. + heads (int): The number of attention heads. + out_dim (int): The output dimension. + fineturn (bool): Whether to fineturn the model. + average_target (bool): Whether to average the target. + """ + + def __init__(self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + finetune=False, + average_targets: int = 1) -> None: + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.finetune = finetune + if finetune is False: + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.average_targets = average_targets + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function.""" + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([ + self.class_embedding.to(x.dtype) + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x + ], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x, attention, z = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x) + if self.proj is not None: + x = x @ self.proj + + return x, attention + + +class CLIP(nn.Module): + """CLIP. + + Args: + embed_dim (int): The embedding dimension. + image_resolution (int): The image size. + vision_layers (int): The number of layers in the vision transformer. + vision_width (int): The feature dimension in the vision transformer. + vision_patch_size (int): The patch size in the vision transformer. + context_length (int): The context length. + vocab_size (int): The vocabulary size. + transformer_width (int): The feature dimension in the text transformer. + transformer_heads (int): The number of attention heads in the + text transformer. + transformer_layers (int): The number of layers in the text transformer. 
+ fineturn (bool): Whether to fineturn the model. + average_target (bool): Whether to average the target. + """ + + def __init__( + self, + embed_dim: int, + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + finetune: bool = False, + average_targets: int = 1, + ) -> None: + super().__init__() + + self.context_length = context_length + + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + finetune=finetune, + average_targets=average_targets, + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self) -> None: + """Initialize the parameters. + + The pretrained weight will override the initialized parameters by this + function. + """ + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self) -> torch.Tensor: + """Build the attention mask.""" + # lazily create causal attention mask, with full attention between the + # vision tokens pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self) -> torch.dtype: + """Get the dtype.""" + return self.visual.conv1.weight.dtype + + def encode_image(self, + image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode the image. + + Get the feature and attention mask from the last layer of the visual + branch of CLIP. + + Args: + image (torch.Tensor): The image tensor with shape NCHW. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The feature and attention mask. + """ + return self.visual(image.type(self.dtype)) + + +def build_clip_model(state_dict: dict, + finetune: bool = False, + average_targets: int = 1) -> nn.Module: + """Build the CLIP model. + + Args: + state_dict (dict): The pretrained state dict. + finetune (bool): Whether to fineturn the model. + average_targets (bool): Whether to average the target. + + Returns: + nn.Module: The CLIP model. 
+ """ + vit = 'visual.proj' in state_dict + + if vit: + vision_width = state_dict['visual.conv1.weight'].shape[0] + vision_layers = len([ + k for k in state_dict.keys() + if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') + ]) + vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] + grid_size = round( + (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) + image_resolution = vision_patch_size * grid_size + + embed_dim = state_dict['text_projection'].shape[1] + context_length = state_dict['positional_embedding'].shape[0] + vocab_size = state_dict['token_embedding.weight'].shape[0] + transformer_width = state_dict['ln_final.weight'].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split('.')[2] for k in state_dict + if k.startswith('transformer.resblocks'))) + + model = CLIP( + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + finetune, + average_targets, + ) + + for key in ['input_resolution', 'context_length', 'vocab_size']: + if key in state_dict: + del state_dict[key] + + msg = model.load_state_dict(state_dict, strict=False) + MMLogger.get_current_instance().info(f'Load CLIP model: {msg}') + return model.eval() diff --git a/mmpretrain/models/utils/data_preprocessor.py b/mmpretrain/models/utils/data_preprocessor.py index bd1e6e7c140..9aea8aaad03 100644 --- a/mmpretrain/models/utils/data_preprocessor.py +++ b/mmpretrain/models/utils/data_preprocessor.py @@ -81,7 +81,7 @@ def __init__(self, else: self._enable_normalize = False - if batch_augments is not None: + if batch_augments: self.batch_augments = RandomBatchAugment(**batch_augments) if not self.to_onehot: from mmengine.logging import MMLogger diff --git a/mmpretrain/models/utils/vector_quantizer.py b/mmpretrain/models/utils/vector_quantizer.py new file mode 100644 index 00000000000..7c2ea89339e --- /dev/null +++ b/mmpretrain/models/utils/vector_quantizer.py @@ -0,0 +1,232 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Copyright (c) 2022 Microsoft +# Modified from +# https://github.com/microsoft/unilm/blob/master/beit2/norm_ema_quantizer.py +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from mmengine.dist import all_reduce + + +def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, + decay: torch.Tensor) -> None: + """Update moving average.""" + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def norm_ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, + decay: torch.Tensor) -> None: + """Update moving average with norm data.""" + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + moving_avg.data.copy_(F.normalize(moving_avg.data, p=2, dim=-1)) + + +def sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor: + """Sample vectors according to the given number.""" + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num, ), device=device) + + return samples[indices] + + +def kmeans(samples: torch.Tensor, + num_clusters: int, + num_iters: int = 10, + use_cosine_sim: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + """Run k-means algorithm.""" + dim, dtype, _ = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = F.normalize(new_means, p=2, dim=-1) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EmbeddingEMA(nn.Module): + """The codebook of embedding vectors. + + Args: + num_tokens (int): Number of embedding vectors in the codebook. + codebook_dim (int) : The dimension of embedding vectors in the + codebook. + kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + codebook_init_path (str): The initialization checkpoint for codebook. + Defaults to None. 
+ """ + + def __init__(self, + num_tokens: int, + codebook_dim: int, + kmeans_init: bool = True, + codebook_init_path: Optional[str] = None): + super().__init__() + self.num_tokens = num_tokens + self.codebook_dim = codebook_dim + if codebook_init_path is None: + if not kmeans_init: + weight = torch.randn(num_tokens, codebook_dim) + weight = F.normalize(weight, p=2, dim=-1) + else: + weight = torch.zeros(num_tokens, codebook_dim) + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + else: + print(f'load init codebook weight from {codebook_init_path}') + codebook_ckpt_weight = torch.load( + codebook_init_path, map_location='cpu') + weight = codebook_ckpt_weight.clone() + self.register_buffer('initted', torch.Tensor([True])) + + self.weight = nn.Parameter(weight, requires_grad=False) + self.update = True + + @torch.jit.ignore + def init_embed_(self, data: torch.Tensor) -> None: + """Initialize embedding vectors of codebook.""" + if self.initted: + return + print('Performing K-means init for codebook') + embed, _ = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) + self.weight.data.copy_(embed) + self.initted.data.copy_(torch.Tensor([True])) + + def forward(self, embed_id: torch.Tensor) -> torch.Tensor: + """Get embedding vectors.""" + return F.embedding(embed_id, self.weight) + + +class NormEMAVectorQuantizer(nn.Module): + """Normed EMA vector quantizer module. + + Args: + num_embed (int): Number of embedding vectors in the codebook. Defaults + to 8192. + embed_dims (int) : The dimension of embedding vectors in the codebook. + Defaults to 32. + beta (float): The mutiplier for VectorQuantizer embedding loss. + Defaults to 1. + decay (float): The decay parameter of EMA. Defaults to 0.99. + statistic_code_usage (bool): Whether to use cluster_size to record + statistic. Defaults to True. + kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + codebook_init_path (str): The initialization checkpoint for codebook. + Defaults to None. 
+ """ + + def __init__(self, + num_embed: int, + embed_dims: int, + beta: float, + decay: float = 0.99, + statistic_code_usage: bool = True, + kmeans_init: bool = True, + codebook_init_path: Optional[str] = None) -> None: + super().__init__() + self.codebook_dim = embed_dims + self.num_tokens = num_embed + self.beta = beta + self.decay = decay + + # learnable = True if orthogonal_reg_weight > 0 else False + self.embedding = EmbeddingEMA( + num_tokens=self.num_tokens, + codebook_dim=self.codebook_dim, + kmeans_init=kmeans_init, + codebook_init_path=codebook_init_path) + + self.statistic_code_usage = statistic_code_usage + if statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(num_embed)) + + def reset_cluster_size(self, device): + + if self.statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) + self.cluster_size = self.cluster_size.to(device) + + def forward(self, z): + """Forward function.""" + # reshape z -> (batch, height, width, channel) + z = rearrange(z, 'b c h w -> b h w c') + z = F.normalize(z, p=2, dim=-1) + z_flattened = z.reshape(-1, self.codebook_dim) + + self.embedding.init_embed_(z_flattened) + + # 'n d -> d n' + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + + if not self.training: + with torch.no_grad(): + cluster_size = encodings.sum(0) + all_reduce(cluster_size) + ema_inplace(self.cluster_size, cluster_size, self.decay) + + if self.training and self.embedding.update: + # update cluster size with EMA + bins = encodings.sum(0) + all_reduce(bins) + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) 
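+            # Bins of unused codes are clamped to 1 so the per-code average
+            # below never divides by zero; those codes keep their previous
+            # weight via the torch.where(zero_mask, ...) fallback.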
+ + embed_sum = z_flattened.t() @ encodings + all_reduce(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = F.normalize(embed_normalized, p=2, dim=-1) + embed_normalized = torch.where(zero_mask[..., None], + self.embedding.weight, + embed_normalized) + + # Update embedding vectors with EMA + norm_ema_inplace(self.embedding.weight, embed_normalized, + self.decay) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, encoding_indices diff --git a/tests/test_models/test_backbones/test_beit.py b/tests/test_models/test_backbones/test_beit.py index b10adb9e10b..5ed7287f3f9 100644 --- a/tests/test_models/test_backbones/test_beit.py +++ b/tests/test_models/test_backbones/test_beit.py @@ -4,7 +4,7 @@ import torch -from mmpretrain.models.backbones import BEiT +from mmpretrain.models.backbones import BEiTViT class TestBEiT(TestCase): @@ -18,7 +18,7 @@ def test_structure(self): with self.assertRaisesRegex(AssertionError, 'not in default archs'): cfg = deepcopy(self.cfg) cfg['arch'] = 'unknown' - BEiT(**cfg) + BEiTViT(**cfg) # Test invalid custom arch with self.assertRaisesRegex(AssertionError, 'Custom arch needs'): @@ -28,7 +28,7 @@ def test_structure(self): 'num_heads': 16, 'feedforward_channels': 4096 } - BEiT(**cfg) + BEiTViT(**cfg) # Test custom arch cfg = deepcopy(self.cfg) @@ -38,7 +38,7 @@ def test_structure(self): 'num_heads': 16, 'feedforward_channels': 1024 } - model = BEiT(**cfg) + model = BEiTViT(**cfg) self.assertEqual(model.embed_dims, 128) self.assertEqual(model.num_layers, 24) self.assertIsNone(model.pos_embed) @@ -51,21 +51,21 @@ def test_structure(self): cfg = deepcopy(self.cfg) cfg['out_indices'] = {1: 1} with self.assertRaisesRegex(AssertionError, "get "): - BEiT(**cfg) + BEiTViT(**cfg) cfg['out_indices'] = [0, 13] with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'): - BEiT(**cfg) + BEiTViT(**cfg) # Test pos_embed cfg = deepcopy(self.cfg) cfg['use_abs_pos_emb'] = True - model = BEiT(**cfg) + model = BEiTViT(**cfg) self.assertEqual(model.pos_embed.shape, (1, 197, 768)) # Test model structure cfg = deepcopy(self.cfg) cfg['drop_path_rate'] = 0.1 - model = BEiT(**cfg) + model = BEiTViT(**cfg) self.assertEqual(len(model.layers), 12) dpr_inc = 0.1 / (12 - 1) dpr = 0 @@ -85,7 +85,7 @@ def test_forward(self): # test with output_cls_token cfg = deepcopy(self.cfg) cfg['output_cls_token'] = True - model = BEiT(**cfg) + model = BEiTViT(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) @@ -95,7 +95,7 @@ def test_forward(self): # test without output_cls_token cfg = deepcopy(self.cfg) - model = BEiT(**cfg) + model = BEiTViT(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) @@ -105,7 +105,7 @@ def test_forward(self): # test without average cfg = deepcopy(self.cfg) cfg['avg_token'] = False - model = BEiT(**cfg) + model = BEiTViT(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) @@ -115,7 +115,7 @@ def test_forward(self): # Test forward with multi out indices cfg = deepcopy(self.cfg) cfg['out_indices'] = [-3, -2, -1] - model = BEiT(**cfg) + model = BEiTViT(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 3) diff --git a/tests/test_models/test_selfsup/test_beit.py 
b/tests/test_models/test_selfsup/test_beit.py new file mode 100644 index 00000000000..235e26f19e2 --- /dev/null +++ b/tests/test_models/test_selfsup/test_beit.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch + +from mmpretrain.models import BEiTPretrainViT + +backbone = dict( + arch='base', + patch_size=16, + drop_path_rate=0.1, + final_norm=True, + layer_scale_init_value=0.1, +) + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_beit_pretrain_vit(): + beit_backbone = BEiTPretrainViT(**backbone) + beit_backbone.init_weights() + + fake_inputs = torch.randn((2, 3, 224, 224)) + fake_mask = torch.zeros((2, 196)) + fake_mask[:, 75:150] = 1 + fake_outputs = beit_backbone(fake_inputs, fake_mask) + + assert list(fake_outputs[0].shape) == [2, 197, 768] diff --git a/tests/test_models/test_selfsup/test_cae.py b/tests/test_models/test_selfsup/test_cae.py new file mode 100644 index 00000000000..3979009adeb --- /dev/null +++ b/tests/test_models/test_selfsup/test_cae.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch + +from mmpretrain.models import CAEViT + +backbone = dict(arch='b', patch_size=16, layer_scale_init_value=0.1) + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_cae_vit(): + cae_backbone = CAEViT(**backbone) + cae_backbone.init_weights() + fake_inputs = torch.randn((2, 3, 224, 224)) + fake_mask = torch.zeros((2, 196)).bool() + fake_mask[:, 75:150] = 1 + fake_outputs = cae_backbone(fake_inputs, fake_mask) + + assert list(fake_outputs.shape) == [2, 122, 768] diff --git a/tests/test_models/test_selfsup/test_mae.py b/tests/test_models/test_selfsup/test_mae.py new file mode 100644 index 00000000000..457006eca6e --- /dev/null +++ b/tests/test_models/test_selfsup/test_mae.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch + +from mmpretrain.models import MAEViT + +backbone = dict(arch='b', patch_size=16, mask_ratio=0.75) + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_mae_vit(): + mae_backbone = MAEViT(**backbone) + mae_backbone.init_weights() + fake_inputs = torch.randn((2, 3, 224, 224)) + fake_outputs = mae_backbone(fake_inputs)[0] + + assert list(fake_outputs.shape) == [2, 50, 768] diff --git a/tests/test_models/test_selfsup/test_maskfeat.py b/tests/test_models/test_selfsup/test_maskfeat.py new file mode 100644 index 00000000000..f49ba515e3d --- /dev/null +++ b/tests/test_models/test_selfsup/test_maskfeat.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch + +from mmpretrain.models import MaskFeatViT + +backbone = dict(arch='b', patch_size=16) + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_maskfeat_vit(): + maskfeat_backbone = MaskFeatViT(**backbone) + maskfeat_backbone.init_weights() + fake_inputs = torch.randn((2, 3, 224, 224)) + fake_mask = torch.randn((2, 14, 14)) + fake_outputs = maskfeat_backbone(fake_inputs, fake_mask) + + assert list(fake_outputs.shape) == [2, 197, 768] diff --git a/tests/test_models/test_selfsup/test_mocov3.py b/tests/test_models/test_selfsup/test_mocov3.py new file mode 100644 index 00000000000..8b1ecf90e95 --- /dev/null +++ b/tests/test_models/test_selfsup/test_mocov3.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. 
All rights reserved.
+import platform
+
+import pytest
+
+from mmpretrain.models import MoCoV3ViT
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
+def test_vision_transformer():
+    vit = MoCoV3ViT(
+        arch='mocov3-small', patch_size=16, frozen_stages=12, norm_eval=True)
+    vit.init_weights()
+    vit.train()
+
+    for p in vit.parameters():
+        assert p.requires_grad is False
diff --git a/tests/test_models/test_selfsup/test_simmim.py b/tests/test_models/test_selfsup/test_simmim.py
new file mode 100644
index 00000000000..5ef062914a7
--- /dev/null
+++ b/tests/test_models/test_selfsup/test_simmim.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
+import torch
+
+from mmpretrain.models import SimMIMSwinTransformer
+
+backbone = dict(
+    arch='B', img_size=192, stage_cfgs=dict(block_cfgs=dict(window_size=6)))
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
+def test_simmim_swin_transformer():
+    simmim_backbone = SimMIMSwinTransformer(**backbone)
+    simmim_backbone.init_weights()
+    fake_inputs = torch.randn((2, 3, 192, 192))
+    fake_mask = torch.rand((2, 48, 48))
+    fake_outputs = simmim_backbone(fake_inputs, fake_mask)[0]
+
+    assert list(fake_outputs.shape) == [2, 1024, 6, 6]
diff --git a/tests/test_models/test_selfsup/test_target_generators.py b/tests/test_models/test_selfsup/test_target_generators.py
new file mode 100644
index 00000000000..dd71a87309a
--- /dev/null
+++ b/tests/test_models/test_selfsup/test_target_generators.py
@@ -0,0 +1,73 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+from unittest import TestCase
+
+import pytest
+import torch
+
+from mmpretrain.models import VQKD, Encoder, HOGGenerator
+
+
+class TestDALLE(TestCase):
+
+    @pytest.mark.skipif(
+        platform.system() == 'Windows', reason='Windows mem limit')
+    def test_dalle(self):
+        model = Encoder()
+        fake_inputs = torch.rand((2, 3, 112, 112))
+        fake_outputs = model(fake_inputs)
+
+        assert list(fake_outputs.shape) == [2, 8192, 14, 14]
+
+
+class TestHOGGenerator(TestCase):
+
+    def test_hog_generator(self):
+        hog_generator = HOGGenerator()
+
+        fake_input = torch.randn((2, 3, 224, 224))
+        fake_output = hog_generator(fake_input)
+        assert list(fake_output.shape) == [2, 196, 108]
+
+        fake_hog_out = hog_generator.out[0].unsqueeze(0)
+        fake_hog_img = hog_generator.generate_hog_image(fake_hog_out)
+        assert fake_hog_img.shape == (224, 224)
+
+        with pytest.raises(AssertionError):
+            fake_hog_img = hog_generator.generate_hog_image(
+                hog_generator.out[0])
+
+
+class TestVQKD(TestCase):
+
+    ENCODER_CFG = dict(
+        arch='base',
+        img_size=224,
+        patch_size=16,
+        in_channels=3,
+        out_indices=-1,
+        drop_rate=0.,
+        drop_path_rate=0.,
+        norm_cfg=dict(type='LN', eps=1e-6),
+        final_norm=True,
+        with_cls_token=True,
+        avg_token=False,
+        frozen_stages=-1,
+        output_cls_token=False,
+        use_abs_pos_emb=True,
+        use_rel_pos_bias=False,
+        use_shared_rel_pos_bias=False,
+        layer_scale_init_value=0.,
+        interpolate_mode='bicubic',
+        patch_cfg=dict(),
+        layer_cfgs=dict(),
+        init_cfg=None)
+
+    @pytest.mark.skipif(
+        platform.system() == 'Windows', reason='Windows mem limit')
+    def test_vqkd(self):
+        model = VQKD(encoder_config=self.ENCODER_CFG)
+        fake_inputs = torch.rand((2, 3, 224, 224))
+        fake_outputs = model(fake_inputs)
+
+        assert list(fake_outputs.shape) == [2, 196]
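
For orientation, a minimal sketch of how two of the masked-image-modeling backbones added in this patch are driven. Constructor arguments and shapes follow the unit tests above; importing MixMIMPretrainTransformer from its module path is an assumption, since no test for it appears in this diff.

import torch

from mmpretrain.models import SimMIMSwinTransformer
from mmpretrain.models.selfsup.mixmim import MixMIMPretrainTransformer  # assumed path

# SimMIM: the mask is produced by the data pipeline and passed in explicitly.
simmim = SimMIMSwinTransformer(
    arch='B', img_size=192, stage_cfgs=dict(block_cfgs=dict(window_size=6)))
feats = simmim(torch.randn(2, 3, 192, 192), torch.rand(2, 48, 48))
assert list(feats[0].shape) == [2, 1024, 6, 6]

# MixMIM: the backbone samples its own random mask inside forward().
mixmim = MixMIMPretrainTransformer(arch='base')
latent, mask = mixmim(torch.randn(2, 3, 224, 224), mask_ratio=0.5)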