diff --git a/projects/starnet_backbone/README.md b/projects/starnet_backbone/README.md
new file mode 100644
index 0000000000..e622a7cd0f
--- /dev/null
+++ b/projects/starnet_backbone/README.md
@@ -0,0 +1,118 @@
+# Example Project
+
+This is an example README for community `projects/`. You can write your own README for your project. The
+sections below are recommended so that others can understand and use your project; copy or modify them as
+needed.
+
+## Usage
+
+### Setup Environment
+
+Please refer to [Get Started](https://mmpretrain.readthedocs.io/en/latest/get_started.html) to install
+MMPreTrain.
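+
+For example, one common route (see the linked guide for the authoritative steps and version requirements) is
+to install through MIM, which is also needed for the `mim train`/`mim test` commands below:
+
+```shell
+pip install -U openmim
+mim install mmpretrain
+```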
+
+First, add the current folder to `PYTHONPATH` so that Python can find your code. Run the following command
+in the current directory to add it.
+
+> Please run it every time you open a new shell.
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Data Preparation
+
+Prepare the ImageNet-2012 dataset according to the [instruction](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#imagenet).
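+
+The exact layout is described in the linked document; as a rough sketch, the base ImageNet configs look for
+the data under `data/imagenet` (the path and the annotation-file variant may differ in your setup), e.g.:
+
+```text
+data/imagenet/
+├── train/          # one sub-folder per class, or images listed in meta/train.txt
+├── val/
+└── meta/
+    ├── train.txt
+    └── val.txt
+```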
+
+### Training commands
+
+**To train with a single GPU:**
+
+```bash
+mim train mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py
+```
+
+**To train with multiple GPUs:**
+
+```bash
+mim train mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py --launcher pytorch --gpus 8
+```
+
+**To train with multiple GPUs by slurm:**
+
+```bash
+mim train mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py --launcher slurm \
+    --gpus 16 --gpus-per-node 8 --partition $PARTITION
+```
+
+### Testing commands
+
+**To test with a single GPU:**
+
+```bash
+mim test mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py --checkpoint $CHECKPOINT
+```
+
+**To test with multiple GPUs:**
+
+```bash
+mim test mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py --checkpoint $CHECKPOINT --launcher pytorch --gpus 8
+```
+
+**To test with multiple GPUs by slurm:**
+
+```bash
+mim test mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py --checkpoint $CHECKPOINT --launcher slurm \
+    --gpus 16 --gpus-per-node 8 --partition $PARTITION
+```
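+
+### Quick check (optional)
+
+The snippet below is only an illustrative sanity check, not an official API example. It assumes the `StarNet`
+implementation in `mmpretrain/models/backbones/starnet.py` of this project is importable in your environment
+(adjust the import path to match how you integrate the module), and it simply verifies output shapes on
+random input.
+
+```python
+import torch
+
+# Import path inside this project; adjust it if your local layout differs.
+from mmpretrain.models.backbones.starnet import StarNet
+
+# Build the smallest variant and take features from the last stage only.
+backbone = StarNet(arch='s1', out_indices=(3, ))
+backbone.eval()
+
+x = torch.randn(1, 3, 224, 224)
+with torch.no_grad():
+    feats = backbone(x)
+
+# One output per entry in `out_indices`. For 's1', the last stage has 192 channels,
+# and a 224x224 input gives a 7x7 feature map (stem + 4 stages, each with stride 2).
+print([f.shape for f in feats])  # [torch.Size([1, 192, 7, 7])]
+```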
+
+## Citation
+
+```BibTeX
+@misc{2023mmpretrain,
+    title={OpenMMLab's Pre-training Toolbox and Benchmark},
+    author={MMPreTrain Contributors},
+    howpublished = {\url{https://github.com/open-mmlab/mmpretrain}},
+    year={2023}
+}
+```
+
+## Checklist
+
+Here is a checklist of this project's progress. You can ignore this part if you don't plan to contribute
+to MMPreTrain projects.
+
+- [ ] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+  - [ ] Finish the code
+
+  - [ ] Basic docstrings & proper citation
+
+  - [ ] Converted checkpoint and results (Only for reproduction)
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+  - [ ] Training results
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+  - [ ] Unit tests
+
+  - [ ] Code style
+
+  - [ ] `metafile.yml` and `README.md`
diff --git a/projects/starnet_backbone/configs/models/starnet/starnet_s1.py b/projects/starnet_backbone/configs/models/starnet/starnet_s1.py
new file mode 100644
index 0000000000..68826a714e
--- /dev/null
+++ b/projects/starnet_backbone/configs/models/starnet/starnet_s1.py
@@ -0,0 +1,11 @@
+# Model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(type='StarNet', arch='s1'),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=192,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+    ))
diff --git a/projects/starnet_backbone/configs/starnet/startnet_s1_32xb32_in1k.py b/projects/starnet_backbone/configs/starnet/startnet_s1_32xb32_in1k.py
new file mode 100644
index 0000000000..17406da71b
--- /dev/null
+++ b/projects/starnet_backbone/configs/starnet/startnet_s1_32xb32_in1k.py
@@ -0,0 +1,7 @@
+_base_ = [
+    '../models/starnet/starnet_s1.py',
+    '../_base_/datasets/imagenet_bs32.py',
+    '../_base_/schedules/imagenet_bs256.py',
+    '../_base_/default_runtime.py',
+]
+train_dataloader = dict(batch_size=24)
+val_dataloader = dict(batch_size=24)
diff --git a/projects/starnet_backbone/mmpretrain/models/backbones/starnet.py b/projects/starnet_backbone/mmpretrain/models/backbones/starnet.py
new file mode 100644
index 0000000000..cd65a22ebd
--- /dev/null
+++ b/projects/starnet_backbone/mmpretrain/models/backbones/starnet.py
@@ -0,0 +1,315 @@
+from typing import Optional, Sequence
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
+                             build_norm_layer)
+from mmengine.model import BaseModule, Sequential
+
+from mmpretrain.models.backbones.base_backbone import BaseBackbone
+from mmpretrain.registry import MODELS
+
+
+class Block(BaseModule):
+    """StarNet Block.
+
+    Args:
+        in_channels (int): The number of input channels.
+        mlp_ratio (float): The expansion ratio of both pointwise convolutions.
+            Defaults to 3.
+        drop_path (float): Stochastic depth rate. Defaults to 0.
+        conv_cfg (dict): Config dict for convolution layer.
+            Defaults to ``dict(type='Conv2d')``.
+        norm_cfg (dict): The config dict for norm layers.
+            Defaults to ``dict(type='BN')``.
+        act_cfg (dict): The config dict for the activation between the
+            pointwise convolutions. Defaults to ``dict(type='ReLU6')``.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        mlp_ratio: float = 3.,
+        drop_path: float = 0.,
+        conv_cfg: Optional[dict] = dict(type='Conv2d'),
+        norm_cfg: Optional[dict] = dict(type='BN'),
+        act_cfg: Optional[dict] = dict(type='ReLU6'),
+        init_cfg: Optional[dict] = None,
+    ) -> None:
+        super().__init__(init_cfg=init_cfg)
+        # 7x7 depth-wise convolution with normalization.
+        self.dwconv = ConvModule(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            kernel_size=7,
+            stride=1,
+            padding=(7 - 1) // 2,
+            groups=in_channels,
+            bias=True,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+        )
+        # Two parallel pointwise convolutions that expand the channels.
+        self.fc1 = ConvModule(
+            in_channels=in_channels,
+            out_channels=int(mlp_ratio * in_channels),
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=None,
+        )
+        self.fc2 = ConvModule(
+            in_channels=in_channels,
+            out_channels=int(mlp_ratio * in_channels),
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=None,
+        )
+        # Pointwise convolution that projects back to the input channels.
+        self.g = ConvModule(
+            in_channels=int(mlp_ratio * in_channels),
+            out_channels=in_channels,
+            kernel_size=1,
+            bias=True,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+        )
+        # Second 7x7 depth-wise convolution without normalization.
+        self.dwconv2 = ConvModule(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            kernel_size=7,
+            stride=1,
+            padding=(7 - 1) // 2,
+            groups=in_channels,
+            bias=True,
+            conv_cfg=conv_cfg,
+            norm_cfg=None,
+        )
+        self.act = build_activation_layer(act_cfg)
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        identity = x
+        x = self.dwconv(x)
+        # The "star" operation: element-wise product of the two expanded
+        # branches, with the activation applied to one of them.
+        x1, x2 = self.fc1(x), self.fc2(x)
+        x = self.act(x1) * x2
+        x = self.dwconv2(self.g(x))
+        x = identity + self.drop_path(x)
+        return x
+
+
+@MODELS.register_module()
+class StarNet(BaseBackbone):
+    """StarNet.
+
+    A PyTorch implementation of StarNet introduced by:
+    `Rewrite the Stars <https://arxiv.org/abs/2403.19967>`_
+
+    Modified from the `official repo
+    <https://github.com/ma-xu/Rewrite-the-Stars>`_.
+
+    Args:
+        arch (str | dict): The model's architecture. If a dict, it should
+            include the following two keys:
+
+            - layers (list[int]): Number of blocks at each stage.
+            - embed_dims (list[int]): The number of channels at each stage.
+
+            Defaults to 's1'.
+        in_channels (int): Number of input image channels. Defaults to 3.
+        out_channels (int): Output channels of the stem layer. Defaults to 32.
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, which means the last stage.
+        frozen_stages (int): Stages to be frozen (all parameters fixed).
+            Defaults to 0, which means not freezing any parameters.
+        mlp_ratio (float): The expansion ratio of the pointwise convolutions
+            in each block. Defaults to 4.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        conv_cfg (dict): Config dict for convolution layer.
+            Defaults to ``dict(type='Conv2d')``.
+        norm_cfg (dict): The config dict for norm layers.
+            Defaults to ``dict(type='BN')``.
+        act_cfg (dict): The config dict for the activation in the stem layer.
+            Defaults to ``dict(type='ReLU6')``.
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    arch_settings = {
+        's1': {
+            'layers': [2, 2, 8, 3],
+            'embed_dims': [24, 48, 96, 192],
+        },
+        's2': {
+            'layers': [1, 2, 6, 2],
+            'embed_dims': [32, 64, 128, 256],
+        },
+        's3': {
+            'layers': [2, 2, 8, 4],
+            'embed_dims': [32, 64, 128, 256],
+        },
+        's4': {
+            'layers': [3, 3, 12, 5],
+            'embed_dims': [32, 64, 128, 256],
+        },
+        's050': {
+            'layers': [1, 1, 3, 1],
+            'embed_dims': [16, 32, 64, 128],
+        },
+        's100': {
+            'layers': [1, 2, 4, 1],
+            'embed_dims': [20, 40, 80, 160],
+        },
+        's150': {
+            'layers': [1, 2, 4, 2],
+            'embed_dims': [24, 48, 96, 192],
+        },
+    }
+
+    def __init__(
+        self,
+        arch='s1',
+        in_channels: int = 3,
+        out_channels: int = 32,
+        out_indices=-1,
+        frozen_stages=0,
+        mlp_ratio: float = 4.,
+        drop_path_rate: float = 0.,
+        conv_cfg: Optional[dict] = dict(type='Conv2d'),
+        norm_cfg: Optional[dict] = dict(type='BN'),
+        act_cfg: Optional[dict] = dict(type='ReLU6'),
+        init_cfg=[
+            dict(type='Kaiming', layer=['Conv2d']),
+            dict(type='Constant', val=1, layer=['_BatchNorm'])
+        ]
+    ) -> None:
+        super().__init__(init_cfg=init_cfg)
+
+        if isinstance(arch, str):
+            assert arch in self.arch_settings, \
+                f'Unavailable arch, please choose from ' \
+                f'({set(self.arch_settings)}) or pass a dict.'
+            arch = self.arch_settings[arch]
+        elif isinstance(arch, dict):
+            assert 'layers' in arch and 'embed_dims' in arch, \
+                f'The arch dict must have "layers" and "embed_dims", ' \
+                f'but got {list(arch.keys())}.'
+
+        self.layers = arch['layers']
+        self.embed_dims = arch['embed_dims']
+        self.num_stages = len(self.layers)
+        self.mlp_ratio = mlp_ratio
+        self.drop_path_rate = drop_path_rate
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        # Stem: a single 3x3 convolution that halves the resolution.
+        self.stem = ConvModule(
+            in_channels=self.in_channels,
+            out_channels=self.out_channels,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            bias=True,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+        # Stochastic depth decay rule: one rate per block across all stages.
+        dpr = [
+            x.item()
+            for x in torch.linspace(0, drop_path_rate, sum(self.layers))
+        ]
+
+        self.stages = []
+        cur = 0
+        for i in range(self.num_stages):
+            stage = self._make_stage(
+                planes=self.out_channels,
+                num_blocks=self.layers[i],
+                cur=cur,
+                dpr=dpr,
+                stages_num=i)
+            self.out_channels = self.embed_dims[i]
+            cur += self.layers[i]
+            stage_name = f'stage{i}'
+            self.add_module(stage_name, stage)
+            self.stages.append(stage_name)
+
+        if isinstance(out_indices, int):
+            out_indices = [out_indices]
+        assert isinstance(out_indices, Sequence), \
+            f'"out_indices" must be a sequence or int, ' \
+            f'got {type(out_indices)} instead.'
+        out_indices = list(out_indices)
+        for i, index in enumerate(out_indices):
+            if index < 0:
+                out_indices[i] = self.num_stages + index
+            assert 0 <= out_indices[i] < self.num_stages, \
+                f'Invalid out_indices {index}.'
+        self.out_indices = out_indices
+
+        # Add a norm layer after each output stage.
+        if self.out_indices:
+            for i_layer in self.out_indices:
+                layer = build_norm_layer(norm_cfg,
+                                         self.embed_dims[i_layer])[1]
+                layer_name = f'norm{i_layer}'
+                self.add_module(layer_name, layer)
+
+        self.frozen_stages = frozen_stages
+        self._freeze_stages()
+
+    def _make_stage(self, planes, num_blocks, cur, dpr, stages_num):
+        # Each stage starts with a strided 3x3 convolution for downsampling,
+        # followed by `num_blocks` StarNet blocks.
+        down_sampler = ConvModule(
+            in_channels=planes,
+            out_channels=self.embed_dims[stages_num],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            bias=True,
+            conv_cfg=None,
+            norm_cfg=dict(type='BN'),
+        )
+
+        blocks = [
+            Block(
+                in_channels=self.embed_dims[stages_num],
+                mlp_ratio=self.mlp_ratio,
+                drop_path=dpr[cur + i],
+            ) for i in range(num_blocks)
+        ]
+
+        return Sequential(down_sampler, *blocks)
+
+    def forward(self, x):
+        x = self.stem(x)
+
+        outs = []
+        for i, stage_name in enumerate(self.stages):
+            stage = getattr(self, stage_name)
+            x = stage(x)
+
+            if i in self.out_indices:
+                norm_layer = getattr(self, f'norm{i}')
+                x_out = norm_layer(x)
+                outs.append(x_out)
+
+        return tuple(outs)
+
+    def _freeze_stages(self):
+        if self.frozen_stages > 0:
+            # Freeze the stem together with the first `frozen_stages` stages.
+            self.stem.eval()
+            for param in self.stem.parameters():
+                param.requires_grad = False
+
+        for i in range(self.frozen_stages):
+            stage_layer = getattr(self, f'stage{i}')
+            stage_layer.eval()
+            for param in stage_layer.parameters():
+                param.requires_grad = False
+            if i in self.out_indices:
+                norm_layer = getattr(self, f'norm{i}')
+                norm_layer.eval()
+                for param in norm_layer.parameters():
+                    param.requires_grad = False
+
+    def train(self, mode=True):
+        super(StarNet, self).train(mode)
+        self._freeze_stages()