From 70ff2abbf7dbf5232cb2f9b961d34b6101311699 Mon Sep 17 00:00:00 2001 From: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com> Date: Tue, 20 Jun 2023 17:37:08 +0800 Subject: [PATCH] [Refactor] Refactor _prepare_pos_embed in ViT (#1656) * deal with cls_token * Update implement --------- Co-authored-by: mzr1996 --- mmpretrain/models/backbones/vision_transformer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mmpretrain/models/backbones/vision_transformer.py b/mmpretrain/models/backbones/vision_transformer.py index cd0a70d377f..82e401c3c87 100644 --- a/mmpretrain/models/backbones/vision_transformer.py +++ b/mmpretrain/models/backbones/vision_transformer.py @@ -305,6 +305,7 @@ def __init__(self, self.out_type = out_type # Set cls token + self.with_cls_token = with_cls_token if with_cls_token: self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) elif out_type != 'cls_token': @@ -404,6 +405,11 @@ def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) pos_embed_shape = self.patch_embed.init_out_size + if (not self.with_cls_token and ckpt_pos_embed_shape[1] + == self.pos_embed.shape[1] + 1): + # Remove cls token from state dict if it's not used. + state_dict[name] = state_dict[name][:, 1:] + state_dict[name] = resize_pos_embed(state_dict[name], ckpt_pos_embed_shape, pos_embed_shape,