forked from open-mmlab/mmpretrain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Refactor] Add self-supervised backbones and target generators. (open…
…-mmlab#1379) * add heads * add losses * fix * remove mim head * add modified backbones and target generators * add unittest * refactor caevit * add window_size check * fix lint * apply new DataSample * fix ut error * update ut * fix ut * fix lint * Update base modules. --------- Co-authored-by: mzr1996 <[email protected]>
- Loading branch information
1 parent
63d9f27
commit e453a45
Showing
30 changed files
with
2,776 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .base import BaseSelfSupervisor | ||
from .beit import VQKD, BEiTPretrainViT | ||
from .cae import CAEViT, Encoder | ||
from .mae import MAEViT | ||
from .maskfeat import HOGGenerator, MaskFeatViT | ||
from .milan import CLIPGenerator, MILANViT | ||
from .mixmim import MixMIMPretrainTransformer | ||
from .mocov3 import MoCoV3ViT | ||
from .simmim import SimMIMSwinTransformer | ||
|
||
__all__ = [ | ||
'BaseSelfSupervisor', | ||
'BEiTPretrainViT', | ||
'VQKD', | ||
'CAEViT', | ||
'Encoder', | ||
'MAEViT', | ||
'HOGGenerator', | ||
'MaskFeatViT', | ||
'CLIPGenerator', | ||
'MILANViT', | ||
'MixMIMPretrainTransformer', | ||
'MoCoV3ViT', | ||
'SimMIMSwinTransformer', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from abc import ABCMeta, abstractmethod | ||
from typing import List, Optional, Union | ||
|
||
import torch | ||
from mmengine.model import BaseModel | ||
from torch import nn | ||
|
||
from mmpretrain.registry import MODELS | ||
from mmpretrain.structures import DataSample | ||
|
||
|
||
class BaseSelfSupervisor(BaseModel, metaclass=ABCMeta): | ||
"""BaseModel for Self-Supervised Learning. | ||
All self-supervised algorithms should inherit this module. | ||
Args: | ||
backbone (dict): The backbone module. See | ||
:mod:`mmpretrain.models.backbones`. | ||
neck (dict, optional): The neck module to process features from | ||
backbone. See :mod:`mmpretrain.models.necks`. Defaults to None. | ||
head (dict, optional): The head module to do prediction and calculate | ||
loss from processed features. See :mod:`mmpretrain.models.heads`. | ||
Notice that if the head is not set, almost all methods cannot be | ||
used except :meth:`extract_feat`. Defaults to None. | ||
target_generator: (dict, optional): The target_generator module to | ||
generate targets for self-supervised learning optimization, such as | ||
HOG, extracted features from other modules(DALL-E, CLIP), etc. | ||
pretrained (str, optional): The pretrained checkpoint path, support | ||
local path and remote path. Defaults to None. | ||
data_preprocessor (Union[dict, nn.Module], optional): The config for | ||
preprocessing input data. If None or no specified type, it will use | ||
"SelfSupDataPreprocessor" as type. | ||
See :class:`SelfSupDataPreprocessor` for more details. | ||
Defaults to None. | ||
init_cfg (dict, optional): the config to control the initialization. | ||
Defaults to None. | ||
""" | ||
|
||
def __init__(self, | ||
backbone: dict, | ||
neck: Optional[dict] = None, | ||
head: Optional[dict] = None, | ||
target_generator: Optional[dict] = None, | ||
pretrained: Optional[str] = None, | ||
data_preprocessor: Optional[Union[dict, nn.Module]] = None, | ||
init_cfg: Optional[dict] = None): | ||
if pretrained is not None: | ||
init_cfg = dict(type='Pretrained', checkpoint=pretrained) | ||
|
||
data_preprocessor = data_preprocessor or {} | ||
if isinstance(data_preprocessor, dict): | ||
data_preprocessor.setdefault('type', 'SelfSupDataPreprocessor') | ||
data_preprocessor = MODELS.build(data_preprocessor) | ||
elif not isinstance(data_preprocessor, nn.Module): | ||
raise TypeError('data_preprocessor should be a `dict` or ' | ||
f'`nn.Module` instance, but got ' | ||
f'{type(data_preprocessor)}') | ||
|
||
super().__init__( | ||
init_cfg=init_cfg, data_preprocessor=data_preprocessor) | ||
|
||
if not isinstance(backbone, nn.Module): | ||
backbone = MODELS.build(backbone) | ||
if neck is not None and not isinstance(neck, nn.Module): | ||
neck = MODELS.build(neck) | ||
if head is not None and not isinstance(head, nn.Module): | ||
head = MODELS.build(head) | ||
if target_generator is not None and not isinstance( | ||
target_generator, nn.Module): | ||
target_generator = MODELS.build(target_generator) | ||
|
||
self.backbone = backbone | ||
self.neck = neck | ||
self.head = head | ||
self.target_generator = target_generator | ||
|
||
@property | ||
def with_neck(self) -> bool: | ||
"""Check if the model has a neck module.""" | ||
return hasattr(self, 'neck') and self.neck is not None | ||
|
||
@property | ||
def with_head(self) -> bool: | ||
"""Check if the model has a head module.""" | ||
return hasattr(self, 'head') and self.head is not None | ||
|
||
@property | ||
def with_target_generator(self) -> bool: | ||
"""Check if the model has a target_generator module.""" | ||
return hasattr( | ||
self, 'target_generator') and self.target_generator is not None | ||
|
||
def forward(self, | ||
inputs: torch.Tensor, | ||
data_samples: Optional[List[DataSample]] = None, | ||
mode: str = 'tensor'): | ||
"""The unified entry for a forward process in both training and test. | ||
The method should accept three modes: "tensor", "predict" and "loss": | ||
- "tensor": Forward the backbone network and return the feature | ||
tensor(s) tensor without any post-processing, same as a common | ||
PyTorch Module. | ||
- "loss": Forward and return a dict of losses according to the given | ||
inputs and data samples. | ||
Args: | ||
inputs (torch.Tensor): The input tensor with shape | ||
(N, C, ...) in general. | ||
data_samples (List[DataSample], optional): The other data of | ||
every samples. It's required for some algorithms | ||
if ``mode="loss"``. Defaults to None. | ||
mode (str): Return what kind of value. Defaults to 'tensor'. | ||
Returns: | ||
The return type depends on ``mode``. | ||
- If ``mode="tensor"``, return a tensor or a tuple of tensor. | ||
- If ``mode="loss"``, return a dict of tensor. | ||
""" | ||
if mode == 'tensor': | ||
feats = self.extract_feat(inputs) | ||
return feats | ||
elif mode == 'loss': | ||
return self.loss(inputs, data_samples) | ||
else: | ||
raise RuntimeError(f'Invalid mode "{mode}".') | ||
|
||
@abstractmethod | ||
def extract_feat(self, inputs: torch.Tensor): | ||
"""Extract features from the input tensor with shape (N, C, ...). | ||
The sub-classes are recommended to implement this method to extract | ||
features from backbone and neck. | ||
Args: | ||
inputs (Tensor): A batch of inputs. The shape of it should be | ||
``(num_samples, num_channels, *img_shape)``. | ||
Returns: | ||
tuple | Tensor: The output feature tensor(s). | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def loss(self, inputs: torch.Tensor, | ||
data_samples: List[DataSample]) -> dict: | ||
"""Calculate losses from a batch of inputs and data samples. | ||
This is a abstract method, and subclass should overwrite this methods | ||
if needed. | ||
Args: | ||
inputs (torch.Tensor): The input tensor with shape | ||
(N, C, ...) in general. | ||
data_samples (List[DataSample]): The annotation data of | ||
every samples. | ||
Returns: | ||
dict[str, Tensor]: A dictionary of loss components. | ||
""" | ||
raise NotImplementedError |
Oops, something went wrong.