From 0d3d0269cf2342f3238bf3ed9955a095087a13cb Mon Sep 17 00:00:00 2001
From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Date: Thu, 17 Aug 2023 19:10:25 +0800
Subject: [PATCH 1/3] add vpt

---
 mmpretrain/models/backbones/vpt.py | 325 +++++++++++++++++++++++++++++
 1 file changed, 325 insertions(+)
 create mode 100644 mmpretrain/models/backbones/vpt.py

diff --git a/mmpretrain/models/backbones/vpt.py b/mmpretrain/models/backbones/vpt.py
new file mode 100644
index 00000000000..969fe7d8058
--- /dev/null
+++ b/mmpretrain/models/backbones/vpt.py
@@ -0,0 +1,325 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import patch
+
+import torch
+import torch.nn as nn
+from mmpretrain.registry import MODELS
+from mmpretrain.models.backbones import VisionTransformer
+from mmpretrain.models.backbones import ViTEVA02
+from mmpretrain.models.utils import build_norm_layer
+import torch
+import torch.nn as nn
+
+from mmpretrain.models.utils import resize_pos_embed
+
+
+def init_prompt(prompt_init, prompt):
+    if prompt_init == 'uniform':
+        nn.init.uniform_(prompt, -0.08, 0.08)
+    elif prompt_init == 'zero':
+        nn.init.zeros_(prompt)
+    elif prompt_init == 'kaiming':
+        nn.init.kaiming_normal_(prompt)
+    elif prompt_init == 'token':
+        nn.init.zeros_(prompt)
+    else:
+        nn.init.normal_(prompt, std=0.02)
+
+@MODELS.register_module()
+class PromptedViT(VisionTransformer):
+    '''Vision Transformer with Prompt.
+
+    A PyTorch implementation of: `Visual Prompt Tuning`_
+
+    Args:
+        prompt_length (int): The length of the prompt parameters. Defaults to 1.
+        deep_prompt (bool): Whether to use deep prompts. Defaults to True.
+        prompt_init (str): The initialization method. Defaults to 'normal'.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor including patch tokens and
+              class tokens with shape (B, L, C).
+            - ``"avg_all"``: The global averaged feature map & cls_token
+              & prompt tensor with shape (B, C).
+            - ``"avg_prompt"``: The global averaged prompt tensor with
+              shape (B, C).
+            - ``"avg_prompt_clstoken"``: The global averaged cls_token
+              & prompt tensor with shape (B, C).
+
+            Defaults to ``"avg_all"``.
+        *args(list, optional): Other args for VisionTransformer.
+        **kwargs(dict, optional): Other args for VisionTransformer.
+    '''
+
+    num_extra_tokens = 1  # class token
+    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}
+    def __init__(self,
+                 prompt_length: int = 1,
+                 deep_prompt: bool = True,
+                 out_type: str ='avg_all',
+                 prompt_init: str = 'normal',
+                 norm_cfg: dict =dict(type='LN'),
+                 *args,
+                 **kwargs):
+        super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs)
+
+        self.prompt_layers = len(self.layers) if deep_prompt else 1
+        prompt = torch.empty(
+            self.prompt_layers, prompt_length, self.embed_dims)
+        init_prompt(prompt_init, prompt)
+        self.prompt_initialized = False if prompt_init == 'token' else True
+        self.prompt = nn.Parameter(prompt, requires_grad=True)
+
+        self.prompt_length = prompt_length
+        self.deep_prompt = deep_prompt
+        self.num_extra_tokens = self.num_extra_tokens + prompt_length
+
+        if self.out_type in {'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}:
+            self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
+
+        # freeze stages
+        self.frozen_stages = len(self.layers)
+        self._freeze_stages()
+
+    def forward(self, x):
+        B = x.shape[0]
+        x, patch_resolution = self.patch_embed(x)
+
+        if self.cls_token is not None:
+            # stole cls_tokens impl from Phil Wang, thanks
+            cls_token = self.cls_token.expand(B, -1, -1)
+            x = torch.cat((cls_token, x), dim=1)
+
+        x = x + resize_pos_embed(
+            self.pos_embed,
+            self.patch_resolution,
+            patch_resolution,
+            mode=self.interpolate_mode,
+            num_extra_tokens=self.num_extra_tokens)
+        x = self.drop_after_pos(x)
+
+        x = self.pre_norm(x)
+
+        # reshape to [layers, batch, tokens, embed_dims]
+        prompt = self.prompt.unsqueeze(1).expand(-1, x.shape[0], -1, -1)
+        x = torch.cat(
+            [x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]],
+            dim=1)
+
+        outs = []
+        for i, layer in enumerate(self.layers):
+            x = layer(x)
+
+            if self.deep_prompt and i != len(self.layers) - 1:
+                x = torch.cat(
+                    [x[:, :1, :], prompt[i, :, :, :], x[:, self.prompt_length + 1:, :]],
+                    dim=1)
+
+            # final_norm should be False here
+            if i == len(self.layers) - 1 and self.final_norm:
+                x = self.ln1(x)
+
+            if i in self.out_indices:
+                outs.append(self._format_output(x, patch_resolution))
+
+        return tuple(outs)
+
+    def _format_output(self, x, hw):
+        if self.out_type == 'raw':
+            return x
+        if self.out_type == 'cls_token':
+            return x[:, 0]
+
+        patch_token = x[:, self.num_extra_tokens:]
+        if self.out_type == 'featmap':
+            B = x.size(0)
+            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
+            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
+        if self.out_type == 'avg_featmap':
+            return self.ln2(x[:, self.prompt_length+1:].mean(dim=1))
+        if self.out_type == 'avg_all':
+            return self.ln2(x.mean(dim=1))
+        if self.out_type == 'avg_prompt':
+            return self.ln2(x[:, 1:self.prompt_length+1].mean(dim=1))
+        if self.out_type == 'avg_prompt_clstoken':
+            return self.ln2(x[:, :self.prompt_length+1].mean(dim=1))
+
+
+def new_AttentionWithRoPE_forward_fn(self, x, patch_resolution):
+    B, N, _ = x.shape
+    H, W = patch_resolution
+    extra_token_num = N - H * W
+
+    qkv = self.qkv(x)
+    qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+    q, k, v = qkv.unbind(dim=0)
+
+    if self.rope:
+        if extra_token_num > 0:
+            q_t = q[:, :, extra_token_num:, :]
+            ro_q_t = self.rope(q_t, patch_resolution)
+            q = torch.cat((q[:, :, :extra_token_num, :], ro_q_t), -2).type_as(v)
+
+            k_t = k[:, :, extra_token_num:, :]
+            ro_k_t = self.rope(k_t, patch_resolution)
+            k = torch.cat((k[:, :, :extra_token_num , :], ro_k_t), -2).type_as(v)
+        else:
+            q = self.rope(q, patch_resolution)
+            k = self.rope(k, patch_resolution)
+
+    q = q * self.scale
+
+    attn = (q @ k.transpose(-2, -1))
+    attn = attn.softmax(dim=-1).type_as(x)
+    attn = self.attn_drop(attn)
+
+    x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+
+    x = self.proj(x)
+    x = self.proj_drop(x)
+
+    return x
+
+
+@MODELS.register_module()
+class PromptedViTEVA02(ViTEVA02):
+    '''EVA02 Vision Transformer with Prompt.
+
+    A PyTorch implementation of: `Visual Prompt Tuning`_
+
+    Args:
+        prompt_length (int): The length of the prompt parameters. Defaults to 1.
+        deep_prompt (bool): Whether to use deep prompts. Defaults to True.
+        prompt_init (str): The initialization method. Defaults to 'normal'.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor including patch tokens and
+              class tokens with shape (B, L, C).
+            - ``"avg_all"``: The global averaged feature map & cls_token
+              & prompt tensor with shape (B, C).
+            - ``"avg_prompt"``: The global averaged prompt tensor with
+              shape (B, C).
+            - ``"avg_prompt_clstoken"``: The global averaged cls_token
+              & prompt tensor with shape (B, C).
+
+            Defaults to ``"avg_all"``.
+        *args(list, optional): Other args for ViTEVA02.
+        **kwargs(dict, optional): Other args for ViTEVA02.
+    '''
+
+    num_extra_tokens = 1  # class token
+    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}
+    # 'avg_all' : avg of 'prompt' & 'cls_token' & 'featmap'
+    # 'avg_prompt' avg of 'prompt'
+    # 'avg_prompt_clstoken' avg of 'cls_token' and 'prompt'
+    def __init__(self,
+                 prompt_length = 1,
+                 deep_prompt = True,
+                 out_type='avg_all',
+                 prompt_init: str = 'normal',
+                 norm_cfg=dict(type='LN'),
+                 *args,
+                 **kwargs):
+        super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs)
+
+        self.prompt_layers = len(self.layers) if deep_prompt else 1
+        prompt = torch.empty(
+            self.prompt_layers, prompt_length, self.embed_dims)
+        if prompt_init == 'uniform':
+            nn.init.uniform_(prompt, -0.08, 0.08)
+        elif prompt_init == 'zero':
+            nn.init.zeros_(prompt)
+        elif prompt_init == 'kaiming':
+            nn.init.kaiming_normal_(prompt)
+        elif prompt_init == 'token':
+            nn.init.zeros_(prompt)
+            self.prompt_initialized = False
+        else:
+            nn.init.normal_(prompt, std=0.02)
+        self.prompt = nn.Parameter(prompt, requires_grad=True)
+        self.prompt_length = prompt_length
+        self.deep_prompt = deep_prompt
+
+        if self.out_type in {'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}:
+            self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
+
+        # freeze stages
+        self.frozen_stages = len(self.layers)
+        self._freeze_stages()
+
+    @patch(
+        'mmpretrain.models.backbones.vit_eva02.AttentionWithRoPE.forward',
+        new_AttentionWithRoPE_forward_fn)
+    def forward(self, x):
+        B = x.shape[0]
+        x, patch_resolution = self.patch_embed(x)
+
+        if self.cls_token is not None:
+            # stole cls_tokens impl from Phil Wang, thanks
+            cls_tokens = self.cls_token.expand(B, -1, -1)
+            x = torch.cat((cls_tokens, x), dim=1)
+
+        x = x + resize_pos_embed(
+            self.pos_embed,
+            self.patch_resolution,
+            patch_resolution,
+            mode=self.interpolate_mode,
+            num_extra_tokens=self.num_extra_tokens)
+        x = self.drop_after_pos(x)
+
+        x = self.pre_norm(x)
+
+        # reshape to [layers, batch, tokens, embed_dims]
+        prompt = self.prompt.unsqueeze(1).expand(-1, x.shape[0], -1, -1)
+        x = torch.cat(
+            [x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]],
+            dim=1)
+
+        outs = []
+        for i, layer in enumerate(self.layers):
+            x = layer(x, patch_resolution)
+
+            if self.deep_prompt and i != len(self.layers) - 1:
+                x = torch.cat(
+                    [x[:, :1, :], prompt[i, :, :, :], x[:, self.prompt_length + 1:, :]],
+                    dim=1)
+
+            if i == len(self.layers) - 1 and self.final_norm:
+                x = self.ln1(x)
+
+            if i in self.out_indices:
+                outs.append(self._format_output(x, patch_resolution))
+
+        return tuple(outs)
+
+
+    def _format_output(self, x, hw):
+        if self.out_type == 'raw':
+            return x
+        if self.out_type == 'cls_token':
+            return x[:, 0]
+
+        patch_token = x[:, self.num_extra_tokens:]
+        if self.out_type == 'featmap':
+            B = x.size(0)
+            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
+            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
+        if self.out_type == 'avg_featmap':
+            return self.ln2(x[:, self.prompt_length+1:].mean(dim=1))
+        if self.out_type == 'avg_all':
+            return self.ln2(x.mean(dim=1))
+        if self.out_type == 'avg_prompt':
+            return self.ln2(x[:, 1:self.prompt_length+1].mean(dim=1))
+        if self.out_type == 'avg_prompt_clstoken':
+            return self.ln2(x[:, :self.prompt_length+1].mean(dim=1))
+

From 857cf7ad74747801475c23208e9a6d9c4b8efab8 Mon Sep 17 00:00:00 2001
From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Date: Thu, 17 Aug 2023 19:12:49 +0800
Subject: [PATCH 2/3] lint

---
 mmpretrain/models/backbones/vpt.py | 140 +++++++++++++++--------------
 1 file changed, 75 insertions(+), 65 deletions(-)

diff --git a/mmpretrain/models/backbones/vpt.py b/mmpretrain/models/backbones/vpt.py
index 969fe7d8058..abd794668b2 100644
--- a/mmpretrain/models/backbones/vpt.py
+++ b/mmpretrain/models/backbones/vpt.py
@@ -1,16 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from unittest.mock import patch
 
-import torch
-import torch.nn as nn
-from mmpretrain.registry import MODELS
-from mmpretrain.models.backbones import VisionTransformer
-from mmpretrain.models.backbones import ViTEVA02
-from mmpretrain.models.utils import build_norm_layer
 import torch
 import torch.nn as nn
 
-from mmpretrain.models.utils import resize_pos_embed
+from mmpretrain.models.backbones import VisionTransformer, ViTEVA02
+from mmpretrain.models.utils import build_norm_layer, resize_pos_embed
+from mmpretrain.registry import MODELS
 
 
 def init_prompt(prompt_init, prompt):
@@ -25,11 +21,13 @@ def init_prompt(prompt_init, prompt):
     else:
         nn.init.normal_(prompt, std=0.02)
 
+
 @MODELS.register_module()
 class PromptedViT(VisionTransformer):
-    '''Vision Transformer with Prompt.
-
-    A PyTorch implementation of: `Visual Prompt Tuning`_
+    """Vision Transformer with Prompt.
+
+    A PyTorch implementation of: `Visual Prompt Tuning
+    <https://arxiv.org/abs/2203.12119>`_
 
     Args:
         prompt_length (int): The length of the prompt parameters. Defaults to 1.
         deep_prompt (bool): Whether to use deep prompts. Defaults to True.
         prompt_init (str): The initialization method. Defaults to 'normal'.
         out_type (str): The type of output features. Please choose from
@@ -48,41 +46,47 @@ class tokens with shape (B, L, C).
               & prompt tensor with shape (B, C).
             - ``"avg_prompt"``: The global averaged prompt tensor with
               shape (B, C).
-            - ``"avg_prompt_clstoken"``: The global averaged cls_token
+            - ``"avg_prompt_clstoken"``: The global averaged cls_token
               & prompt tensor with shape (B, C).
 
             Defaults to ``"avg_all"``.
         *args(list, optional): Other args for VisionTransformer.
         **kwargs(dict, optional): Other args for VisionTransformer.
-    '''
+    """
 
     num_extra_tokens = 1  # class token
-    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}
+    OUT_TYPES = {
+        'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt',
+        'avg_prompt_clstoken'
+    }
+
     def __init__(self,
                  prompt_length: int = 1,
                  deep_prompt: bool = True,
-                 out_type: str ='avg_all',
+                 out_type: str = 'avg_all',
                  prompt_init: str = 'normal',
-                 norm_cfg: dict =dict(type='LN'),
+                 norm_cfg: dict = dict(type='LN'),
                  *args,
                  **kwargs):
-        super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs)
+        super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs)
 
         self.prompt_layers = len(self.layers) if deep_prompt else 1
-        prompt = torch.empty(
-            self.prompt_layers, prompt_length, self.embed_dims)
+        prompt = torch.empty(self.prompt_layers, prompt_length,
+                             self.embed_dims)
         init_prompt(prompt_init, prompt)
         self.prompt_initialized = False if prompt_init == 'token' else True
         self.prompt = nn.Parameter(prompt, requires_grad=True)
 
         self.prompt_length = prompt_length
         self.deep_prompt = deep_prompt
-        self.num_extra_tokens = self.num_extra_tokens + prompt_length
+        self.num_extra_tokens = self.num_extra_tokens + prompt_length
 
-        if self.out_type in {'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}:
+        if self.out_type in {
+                'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'
+        }:
             self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
-
-        # freeze stages
+
+        # freeze stages
         self.frozen_stages = len(self.layers)
         self._freeze_stages()
@@ -107,18 +111,18 @@ def forward(self, x):
 
         # reshape to [layers, batch, tokens, embed_dims]
         prompt = self.prompt.unsqueeze(1).expand(-1, x.shape[0], -1, -1)
-        x = torch.cat(
-            [x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]],
-            dim=1)
+        x = torch.cat([x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]], dim=1)
 
         outs = []
         for i, layer in enumerate(self.layers):
             x = layer(x)
-
+
             if self.deep_prompt and i != len(self.layers) - 1:
-                x = torch.cat(
-                    [x[:, :1, :], prompt[i, :, :, :], x[:, self.prompt_length + 1:, :]],
-                    dim=1)
+                x = torch.cat([
+                    x[:, :1, :], prompt[i, :, :, :],
+                    x[:, self.prompt_length + 1:, :]
+                ],
+                              dim=1)
 
             # final_norm should be False here
             if i == len(self.layers) - 1 and self.final_norm:
@@ -141,13 +145,13 @@ def _format_output(self, x, hw):
             # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
             return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
         if self.out_type == 'avg_featmap':
-            return self.ln2(x[:, self.prompt_length+1:].mean(dim=1))
+            return self.ln2(x[:, self.prompt_length + 1:].mean(dim=1))
         if self.out_type == 'avg_all':
-            return self.ln2(x.mean(dim=1))
+            return self.ln2(x.mean(dim=1))
         if self.out_type == 'avg_prompt':
-            return self.ln2(x[:, 1:self.prompt_length+1].mean(dim=1))
+            return self.ln2(x[:, 1:self.prompt_length + 1].mean(dim=1))
         if self.out_type == 'avg_prompt_clstoken':
-            return self.ln2(x[:, :self.prompt_length+1].mean(dim=1))
+            return self.ln2(x[:, :self.prompt_length + 1].mean(dim=1))
 
 
 def new_AttentionWithRoPE_forward_fn(self, x, patch_resolution):
@@ -163,11 +167,13 @@ def new_AttentionWithRoPE_forward_fn(self, x, patch_resolution):
         if extra_token_num > 0:
             q_t = q[:, :, extra_token_num:, :]
             ro_q_t = self.rope(q_t, patch_resolution)
-            q = torch.cat((q[:, :, :extra_token_num, :], ro_q_t), -2).type_as(v)
+            q = torch.cat((q[:, :, :extra_token_num, :], ro_q_t),
+                          -2).type_as(v)
 
             k_t = k[:, :, extra_token_num:, :]
             ro_k_t = self.rope(k_t, patch_resolution)
-            k = torch.cat((k[:, :, :extra_token_num , :], ro_k_t), -2).type_as(v)
+            k = torch.cat((k[:, :, :extra_token_num, :], ro_k_t),
+                          -2).type_as(v)
         else:
             q = self.rope(q, patch_resolution)
             k = self.rope(k, patch_resolution)
@@ -188,9 +194,10 @@ def new_AttentionWithRoPE_forward_fn(self, x, patch_resolution):
 
 @MODELS.register_module()
 class PromptedViTEVA02(ViTEVA02):
-    '''EVA02 Vision Transformer with Prompt.
-
-    A PyTorch implementation of: `Visual Prompt Tuning`_
+    """EVA02 Vision Transformer with Prompt.
+
+    A PyTorch implementation of: `Visual Prompt Tuning
+    <https://arxiv.org/abs/2203.12119>`_
 
     Args:
         prompt_length (int): The length of the prompt parameters. Defaults to 1.
         deep_prompt (bool): Whether to use deep prompts. Defaults to True.
         prompt_init (str): The initialization method. Defaults to 'normal'.
         out_type (str): The type of output features. Please choose from
@@ -209,32 +216,36 @@ class tokens with shape (B, L, C).
               & prompt tensor with shape (B, C).
             - ``"avg_prompt"``: The global averaged prompt tensor with
               shape (B, C).
-            - ``"avg_prompt_clstoken"``: The global averaged cls_token
+            - ``"avg_prompt_clstoken"``: The global averaged cls_token
               & prompt tensor with shape (B, C).
 
             Defaults to ``"avg_all"``.
         *args(list, optional): Other args for ViTEVA02.
         **kwargs(dict, optional): Other args for ViTEVA02.
-    '''
+    """
 
     num_extra_tokens = 1  # class token
-    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}
+    OUT_TYPES = {
+        'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt',
+        'avg_prompt_clstoken'
+    }
+
     # 'avg_all' : avg of 'prompt' & 'cls_token' & 'featmap'
     # 'avg_prompt' avg of 'prompt'
     # 'avg_prompt_clstoken' avg of 'cls_token' and 'prompt'
     def __init__(self,
-                 prompt_length = 1,
-                 deep_prompt = True,
+                 prompt_length=1,
+                 deep_prompt=True,
                  out_type='avg_all',
                  prompt_init: str = 'normal',
                  norm_cfg=dict(type='LN'),
                  *args,
                  **kwargs):
-        super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs)
+        super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs)
 
         self.prompt_layers = len(self.layers) if deep_prompt else 1
-        prompt = torch.empty(
-            self.prompt_layers, prompt_length, self.embed_dims)
+        prompt = torch.empty(self.prompt_layers, prompt_length,
+                             self.embed_dims)
         if prompt_init == 'uniform':
             nn.init.uniform_(prompt, -0.08, 0.08)
         elif prompt_init == 'zero':
@@ -250,16 +261,17 @@ def __init__(self,
         self.prompt_length = prompt_length
         self.deep_prompt = deep_prompt
 
-        if self.out_type in {'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}:
+        if self.out_type in {
+                'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'
+        }:
             self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
-
-        # freeze stages
+
+        # freeze stages
         self.frozen_stages = len(self.layers)
        self._freeze_stages()
-
-    @patch(
-        'mmpretrain.models.backbones.vit_eva02.AttentionWithRoPE.forward',
-        new_AttentionWithRoPE_forward_fn)
+
+    @patch('mmpretrain.models.backbones.vit_eva02.AttentionWithRoPE.forward',
+           new_AttentionWithRoPE_forward_fn)
     def forward(self, x):
         B = x.shape[0]
         x, patch_resolution = self.patch_embed(x)
@@ -281,18 +293,18 @@ def forward(self, x):
 
         # reshape to [layers, batch, tokens, embed_dims]
         prompt = self.prompt.unsqueeze(1).expand(-1, x.shape[0], -1, -1)
-        x = torch.cat(
-            [x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]],
-            dim=1)
+        x = torch.cat([x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]], dim=1)
 
         outs = []
         for i, layer in enumerate(self.layers):
             x = layer(x, patch_resolution)
 
             if self.deep_prompt and i != len(self.layers) - 1:
-                x = torch.cat(
-                    [x[:, :1, :], prompt[i, :, :, :], x[:, self.prompt_length + 1:, :]],
-                    dim=1)
+                x = torch.cat([
+                    x[:, :1, :], prompt[i, :, :, :],
+                    x[:, self.prompt_length + 1:, :]
+                ],
+                              dim=1)
 
             if i == len(self.layers) - 1 and self.final_norm:
                 x = self.ln1(x)
@@ -302,7 +314,6 @@ def forward(self, x):
 
         return tuple(outs)
 
-
     def _format_output(self, x, hw):
         if self.out_type == 'raw':
             return x
@@ -315,11 +326,10 @@ def _format_output(self, x, hw):
             # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
             return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
         if self.out_type == 'avg_featmap':
-            return self.ln2(x[:, self.prompt_length+1:].mean(dim=1))
+            return self.ln2(x[:, self.prompt_length + 1:].mean(dim=1))
         if self.out_type == 'avg_all':
-            return self.ln2(x.mean(dim=1))
+            return self.ln2(x.mean(dim=1))
         if self.out_type == 'avg_prompt':
-            return self.ln2(x[:, 1:self.prompt_length+1].mean(dim=1))
+            return self.ln2(x[:, 1:self.prompt_length + 1].mean(dim=1))
         if self.out_type == 'avg_prompt_clstoken':
-            return self.ln2(x[:, :self.prompt_length+1].mean(dim=1))
-
+            return self.ln2(x[:, :self.prompt_length + 1].mean(dim=1))

From 5ba8782aef510778127c5778f4ab6125c0a60bb7 Mon Sep 17 00:00:00 2001
From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Date: Thu, 17 Aug 2023 19:23:05 +0800
Subject: [PATCH 3/3] add to __init__ and api doc

---
 docs/en/api/models.rst                  |  2 +
 mmpretrain/models/backbones/__init__.py | 80 +++++--------------------
 2 files changed, 16 insertions(+), 66 deletions(-)

diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst
index 93e3e8416ad..5d27cae2739 100644
--- a/docs/en/api/models.rst
+++ b/docs/en/api/models.rst
@@ -220,6 +220,8 @@ Backbones
    ViTSAM
    XCiT
    ViTEVA02
+   PromptedViT
+   PromptedViTEVA02
 
 .. module:: mmpretrain.models.necks

diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py
index 60e37fb7b6e..de102e32124 100644
--- a/mmpretrain/models/backbones/__init__.py
+++ b/mmpretrain/models/backbones/__init__.py
@@ -57,73 +57,21 @@
 from .vision_transformer import VisionTransformer
 from .vit_eva02 import ViTEVA02
 from .vit_sam import ViTSAM
+from .vpt import PromptedViT, PromptedViTEVA02
 from .xcit import XCiT
 
 __all__ = [
-    'LeNet5',
-    'AlexNet',
-    'VGG',
-    'RegNet',
-    'ResNet',
-    'ResNeXt',
-    'ResNetV1d',
-    'ResNeSt',
-    'ResNet_CIFAR',
-    'SEResNet',
-    'SEResNeXt',
-    'ShuffleNetV1',
-    'ShuffleNetV2',
-    'MobileNetV2',
-    'MobileNetV3',
-    'VisionTransformer',
-    'SwinTransformer',
-    'TNT',
-    'TIMMBackbone',
-    'T2T_ViT',
-    'Res2Net',
-    'RepVGG',
-    'Conformer',
-    'MlpMixer',
-    'DistilledVisionTransformer',
-    'PCPVT',
-    'SVT',
-    'EfficientNet',
-    'EfficientNetV2',
-    'ConvNeXt',
-    'HRNet',
-    'ResNetV1c',
-    'ConvMixer',
-    'EdgeNeXt',
-    'CSPDarkNet',
-    'CSPResNet',
-    'CSPResNeXt',
-    'CSPNet',
-    'RepLKNet',
-    'RepMLPNet',
-    'PoolFormer',
-    'RIFormer',
-    'DenseNet',
-    'VAN',
-    'InceptionV3',
-    'MobileOne',
-    'EfficientFormer',
-    'SwinTransformerV2',
-    'MViT',
-    'DeiT3',
-    'HorNet',
-    'MobileViT',
-    'DaViT',
-    'BEiTViT',
-    'RevVisionTransformer',
-    'MixMIMTransformer',
-    'TinyViT',
-    'LeViT',
-    'Vig',
-    'PyramidVig',
-    'XCiT',
-    'ViTSAM',
-    'ViTEVA02',
-    'HiViT',
-    'SparseResNet',
-    'SparseConvNeXt',
+    'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
+    'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
+    'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
+    'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
+    'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
+    'EfficientNet', 'EfficientNetV2', 'ConvNeXt', 'HRNet', 'ResNetV1c',
+    'ConvMixer', 'EdgeNeXt', 'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet',
+    'RepLKNet', 'RepMLPNet', 'PoolFormer', 'RIFormer', 'DenseNet', 'VAN',
+    'InceptionV3', 'MobileOne', 'EfficientFormer', 'SwinTransformerV2', 'MViT',
+    'DeiT3', 'HorNet', 'MobileViT', 'DaViT', 'BEiTViT', 'RevVisionTransformer',
+    'MixMIMTransformer', 'TinyViT', 'LeViT', 'Vig', 'PyramidVig', 'XCiT',
+    'ViTSAM', 'ViTEVA02', 'HiViT', 'SparseResNet', 'SparseConvNeXt',
+    'PromptedViT', 'PromptedViTEVA02'
 ]
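Usage note (editor's sketch, not part of the patch series): once the three patches are applied, the new backbone can be smoke-tested through the registry roughly as below. The prompt-related arguments (prompt_length, deep_prompt, prompt_init, out_type) come from this PR; the generic ViT arguments (arch, img_size, patch_size) and their values are illustrative assumptions carried over from the existing VisionTransformer API, not something this PR defines.

# Hedged smoke test for PromptedViT. 'arch'/'img_size'/'patch_size' are
# assumed from the existing VisionTransformer signature; only the
# prompt-related kwargs are introduced by this PR.
import torch

from mmpretrain.registry import MODELS

cfg = dict(
    type='PromptedViT',
    arch='base',
    img_size=224,
    patch_size=16,
    prompt_length=10,     # prompt tokens inserted right after the cls_token
    deep_prompt=True,     # re-insert fresh prompt tokens before every layer
    prompt_init='normal',
    out_type='avg_all',   # average over cls_token, prompts and patch tokens
)
model = MODELS.build(cfg)

# __init__ sets frozen_stages = len(self.layers) and calls _freeze_stages(),
# so the pretrained ViT weights are frozen; only the newly added parameters
# (the prompt, plus ln2 for the averaged out_types) should stay trainable.
print([name for p_name, p in model.named_parameters()
       if p.requires_grad for name in [p_name]])

x = torch.randn(2, 3, 224, 224)
feats = model(x)
print(feats[-1].shape)  # (2, embed_dims), e.g. (2, 768) for the 'base' arch

Because 'avg_all' is validated against the subclass's OUT_TYPES set rather than VisionTransformer's, overriding OUT_TYPES in PromptedViT is what lets the extra output types pass the parent's out_type check; the frozen backbone plus trainable prompts matches the parameter-efficient setup of the VPT paper.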