[Refactor] BEiT refactor (open-mmlab#1705)
* [Refactor] BEiT refactor

* [Fix] Fix arch zoo

* [Fix] Fix arch zoo

* [Fix] Fix freeze stages

* [Fix] Fix freeze ln2

* [Fix] Fix freezing vit ln2
fanqiNO1 authored Jul 11, 2023
1 parent 78d0ddc commit 5c43d3e
Showing 2 changed files with 189 additions and 8 deletions.
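For orientation, a minimal usage sketch of the refactored backbone. The keyword names (`arch`, `out_type`, `frozen_stages`) are taken from the constructor arguments visible in the diff below; the import path and defaults are assumptions, not something this commit shows.

import torch
from mmpretrain.models.backbones import BEiTViT  # import path assumed

# 'base' resolves to embed_dims=768, num_layers=12 via the arch_zoo below.
backbone = BEiTViT(arch='base', out_type='avg_featmap', frozen_stages=12)
backbone.init_weights()
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))  # a tuple of feature tensors, one per out_index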
183 changes: 179 additions & 4 deletions mmpretrain/models/backbones/beit.py
@@ -7,11 +7,13 @@
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from ..utils import (BEiTAttention, build_norm_layer, resize_pos_embed,
resize_relative_position_bias_table, to_2tuple)
from .vision_transformer import TransformerEncoderLayer, VisionTransformer
from .base_backbone import BaseBackbone
from .vision_transformer import TransformerEncoderLayer


class RelativePositionBias(BaseModule):
@@ -212,7 +214,7 @@ def forward(self, x: torch.Tensor,


@MODELS.register_module()
class BEiTViT(VisionTransformer):
class BEiTViT(BaseBackbone):
"""Backbone for BEiT.
A PyTorch implementation of `BEiT: BERT Pre-Training of Image Transformers
@@ -282,6 +284,62 @@ class tokens with shape (B, L, C).
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(
['s', 'small'], {
'embed_dims': 768,
'num_layers': 8,
'num_heads': 8,
'feedforward_channels': 768 * 3,
}),
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 3072
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': 4096
}),
**dict.fromkeys(
['eva-g', 'eva-giant'],
{
# The implementation in EVA
# <https://arxiv.org/abs/2211.07636>
'embed_dims': 1408,
'num_layers': 40,
'num_heads': 16,
'feedforward_channels': 6144
}),
**dict.fromkeys(
['deit-t', 'deit-tiny'], {
'embed_dims': 192,
'num_layers': 12,
'num_heads': 3,
'feedforward_channels': 192 * 4
}),
**dict.fromkeys(
['deit-s', 'deit-small'], {
'embed_dims': 384,
'num_layers': 12,
'num_heads': 6,
'feedforward_channels': 384 * 4
}),
**dict.fromkeys(
['deit-b', 'deit-base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 768 * 4
}),
}
num_extra_tokens = 1 # class token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

def __init__(self,
arch='base',
@@ -300,12 +358,12 @@ def __init__(self,
use_abs_pos_emb=False,
use_rel_pos_bias=True,
use_shared_rel_pos_bias=False,
layer_scale_init_value=0.1,
interpolate_mode='bicubic',
layer_scale_init_value=0.1,
patch_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None):
super(VisionTransformer, self).__init__(init_cfg)
super(BEiTViT, self).__init__(init_cfg)

if isinstance(arch, str):
arch = arch.lower()
@@ -345,6 +403,7 @@ def __init__(self,
self.out_type = out_type

# Set cls token
self.with_cls_token = with_cls_token
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
self.num_extra_tokens = 1
@@ -426,6 +485,87 @@ def __init__(self,
if self.frozen_stages > 0:
self._freeze_stages()

@property
def norm1(self):
return self.ln1

@property
def norm2(self):
return self.ln2

def init_weights(self):
super(BEiTViT, self).init_weights()

if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=0.02)

def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return

ckpt_pos_embed_shape = state_dict[name].shape
if (not self.with_cls_token
and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1):
# Remove cls token from state dict if it's not used.
state_dict[name] = state_dict[name][:, 1:]
ckpt_pos_embed_shape = state_dict[name].shape

if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.')

ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size

state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)

@staticmethod
def resize_pos_embed(*args, **kwargs):
"""Interface for backward-compatibility."""
return resize_pos_embed(*args, **kwargs)

def _freeze_stages(self):
# freeze position embedding
if self.pos_embed is not None:
self.pos_embed.requires_grad = False
# set dropout to eval mode
self.drop_after_pos.eval()
# freeze patch embedding
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
# freeze cls_token
if self.with_cls_token:
self.cls_token.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages + 1):
m = self.layers[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze the last layer norm
if self.frozen_stages == len(self.layers):
if self.final_norm:
self.ln1.eval()
for param in self.ln1.parameters():
param.requires_grad = False

if self.out_type == 'avg_featmap':
self.ln2.eval()
for param in self.ln2.parameters():
param.requires_grad = False

def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
@@ -520,3 +660,38 @@ def _prepare_relative_position_bias_table(self, state_dict, prefix, *args,
index_buffer = ckpt_key.replace('bias_table', 'index')
if index_buffer in state_dict:
del state_dict[index_buffer]

def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the number of layers.
Note:
The first depth is the stem module (``layer_depth=0``), and the
last depth is the subsequent module (``layer_depth=num_layers-1``).
"""
num_layers = self.num_layers + 2

if not param_name.startswith(prefix):
# For subsequent module like head
return num_layers - 1, num_layers

param_name = param_name[len(prefix):]

if param_name in ('cls_token', 'pos_embed'):
layer_depth = 0
elif param_name.startswith('patch_embed'):
layer_depth = 0
elif param_name.startswith('layers'):
layer_id = int(param_name.split('.')[1])
layer_depth = layer_id + 1
else:
layer_depth = num_layers - 1

return layer_depth, num_layers
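A quick check of the mapping that get_layer_depth implements, restated as a standalone helper so the examples can be verified without building a model. num_blocks=12 mirrors the 'base' arch; the concrete parameter names are illustrative, not taken from a real checkpoint.

def layer_depth(param_name, num_blocks=12, prefix='backbone.'):
    # Mirrors BEiTViT.get_layer_depth: stem -> 0, block i -> i + 1,
    # anything after the blocks (or outside the backbone) -> num_layers - 1.
    num_layers = num_blocks + 2
    if not param_name.startswith(prefix):
        return num_layers - 1, num_layers
    name = param_name[len(prefix):]
    if name in ('cls_token', 'pos_embed') or name.startswith('patch_embed'):
        return 0, num_layers
    if name.startswith('layers'):
        return int(name.split('.')[1]) + 1, num_layers
    return num_layers - 1, num_layers

assert layer_depth('backbone.pos_embed') == (0, 14)
assert layer_depth('backbone.layers.0.attn.qkv.weight') == (1, 14)
assert layer_depth('backbone.layers.11.ffn.layers.1.weight') == (12, 14)
assert layer_depth('head.fc.weight') == (13, 14)  # e.g. a classification head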
14 changes: 10 additions & 4 deletions mmpretrain/models/backbones/vision_transformer.py
@@ -444,10 +444,16 @@ def _freeze_stages(self):
for param in m.parameters():
param.requires_grad = False
# freeze the last layer norm
if self.frozen_stages == len(self.layers) and self.final_norm:
self.ln1.eval()
for param in self.ln1.parameters():
param.requires_grad = False
if self.frozen_stages == len(self.layers):
if self.final_norm:
self.ln1.eval()
for param in self.ln1.parameters():
param.requires_grad = False

if self.out_type == 'avg_featmap':
self.ln2.eval()
for param in self.ln2.parameters():
param.requires_grad = False

def forward(self, x):
B = x.shape[0]
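Both freeze paths above (BEiTViT._freeze_stages and the patched VisionTransformer._freeze_stages) rely on the same PyTorch idiom: put the module into eval mode and turn off gradients for its parameters. A self-contained sketch with toy modules, not mmpretrain classes:

import torch.nn as nn

def freeze(module):
    # eval() disables dropout and stops norm-statistic updates;
    # requires_grad=False keeps the parameters out of the optimizer step.
    module.eval()
    for param in module.parameters():
        param.requires_grad = False

block = nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True)
ln2 = nn.LayerNorm(768)  # stands in for the norm frozen when out_type == 'avg_featmap'

freeze(block)
freeze(ln2)
assert not ln2.training and all(not p.requires_grad for p in block.parameters())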
