Skip to content

Commit

Permalink
[Refactor] Refactor _prepare_pos_embed in ViT (open-mmlab#1656)
Browse files Browse the repository at this point in the history
* deal with cls_token

* Update implementation

---------

Co-authored-by: mzr1996 <[email protected]>
  • Loading branch information
fangyixiao18 and mzr1996 authored Jun 20, 2023
1 parent d4a6dfa commit 70ff2ab
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mmpretrain/models/backbones/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def __init__(self,
self.out_type = out_type

# Set cls token
self.with_cls_token = with_cls_token
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
elif out_type != 'cls_token':
Expand Down Expand Up @@ -404,6 +405,11 @@ def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size

if (not self.with_cls_token and ckpt_pos_embed_shape[1]
== self.pos_embed.shape[1] + 1):
# Remove cls token from state dict if it's not used.
state_dict[name] = state_dict[name][:, 1:]

state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
Expand Down

0 comments on commit 70ff2ab

Please sign in to comment.