diff --git a/mindcv/models/layers/__init__.py b/mindcv/models/layers/__init__.py
index 2810dbca1..e12ae441b 100644
--- a/mindcv/models/layers/__init__.py
+++ b/mindcv/models/layers/__init__.py
@@ -1,9 +1,21 @@
 """layers init"""
-from . import activation, conv_norm_act, drop_path, identity, pooling, selective_kernel, squeeze_excite
+from . import (
+    activation,
+    conv_norm_act,
+    drop_path,
+    format,
+    identity,
+    patch_dropout,
+    pooling,
+    selective_kernel,
+    squeeze_excite,
+)
 from .activation import *
 from .conv_norm_act import *
 from .drop_path import *
+from .format import *
 from .identity import *
+from .patch_dropout import *
 from .pooling import *
 from .selective_kernel import *
 from .squeeze_excite import *
diff --git a/mindcv/models/layers/format.py b/mindcv/models/layers/format.py
new file mode 100644
index 000000000..058a74517
--- /dev/null
+++ b/mindcv/models/layers/format.py
@@ -0,0 +1,34 @@
+from enum import Enum
+from typing import Union
+
+import mindspore
+
+
+class Format(str, Enum):
+    NCHW = 'NCHW'
+    NHWC = 'NHWC'
+    NCL = 'NCL'
+    NLC = 'NLC'
+
+
+FormatT = Union[str, Format]
+
+
+def nchw_to(x: mindspore.Tensor, fmt: Format):
+    if fmt == Format.NHWC:
+        x = x.permute(0, 2, 3, 1)
+    elif fmt == Format.NLC:
+        x = x.flatten(start_dim=2).transpose((0, 2, 1))
+    elif fmt == Format.NCL:
+        x = x.flatten(start_dim=2)
+    return x
+
+
+def nhwc_to(x: mindspore.Tensor, fmt: Format):
+    if fmt == Format.NCHW:
+        x = x.permute(0, 3, 1, 2)
+    elif fmt == Format.NLC:
+        x = x.flatten(start_dim=1, end_dim=2)
+    elif fmt == Format.NCL:
+        x = x.flatten(start_dim=1, end_dim=2).transpose((0, 2, 1))
+    return x
diff --git a/mindcv/models/layers/patch_dropout.py b/mindcv/models/layers/patch_dropout.py
new file mode 100644
index 000000000..ad854dbfc
--- /dev/null
+++ b/mindcv/models/layers/patch_dropout.py
@@ -0,0 +1,54 @@
+import numpy as np
+
+import mindspore as ms
+from mindspore import nn, ops
+
+
+class PatchDropout(nn.Cell):
+    """
+    Randomly drops patch tokens during training (https://arxiv.org/abs/2212.00794).
+    """
+    def __init__(
+        self,
+        prob: float = 0.5,
+        num_prefix_tokens: int = 1,
+        ordered: bool = False,
+        return_indices: bool = False,
+    ):
+        super().__init__()
+        assert 0 <= prob < 1.
+        self.prob = prob
+        self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
+        self.ordered = ordered
+        self.return_indices = return_indices
+        self.sort = ops.Sort()
+
+    def construct(self, x):
+        if not self.training or self.prob == 0.:
+            if self.return_indices:
+                return x, None
+            return x
+
+        if self.num_prefix_tokens:
+            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
+        else:
+            prefix_tokens = None
+
+        B = x.shape[0]
+        L = x.shape[1]
+        num_keep = max(1, int(L * (1.
- self.prob))) + _, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32)) + keep_indices = indices[:, :num_keep] + if self.ordered: + # NOTE does not need to maintain patch order in typical transformer use, + # but possibly useful for debug / visualization + keep_indices, _ = self.sort(keep_indices) + keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2])) + x = ops.gather_elements(x, dim=1, index=keep_indices) + + if prefix_tokens is not None: + x = ops.concat((prefix_tokens, x), axis=1) + + if self.return_indices: + return x, keep_indices + return x diff --git a/mindcv/models/layers/patch_embed.py b/mindcv/models/layers/patch_embed.py index d2ca684f1..661e07890 100644 --- a/mindcv/models/layers/patch_embed.py +++ b/mindcv/models/layers/patch_embed.py @@ -4,6 +4,7 @@ from mindspore import Tensor, nn, ops +from .format import Format, nchw_to from .helpers import to_2tuple @@ -17,29 +18,45 @@ class PatchEmbed(nn.Cell): embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Cell, optional): Normalization layer. Default: None """ + output_fmt: Format def __init__( self, - image_size: int = 224, + image_size: Optional[int] = 224, patch_size: int = 4, in_chans: int = 3, embed_dim: int = 96, norm_layer: Optional[nn.Cell] = None, + flatten: bool = True, + output_fmt: Optional[str] = None, + bias: bool = True, + strict_img_size: bool = True, + dynamic_img_pad: bool = False, ) -> None: super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]] - self.image_size = image_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans + self.patch_size = to_2tuple(patch_size) + if image_size is not None: + self.image_size = to_2tuple(image_size) + self.patches_resolution = tuple([s // p for s, p in zip(self.image_size, self.patch_size)]) + self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + else: + self.image_size = None + self.patches_resolution = None + self.num_patches = None + + if output_fmt is not None: + self.flatten = False + self.output_fmt = Format(output_fmt) + else: + self.flatten = flatten + self.output_fmt = Format.NCHW + + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad self.embed_dim = embed_dim self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, - pad_mode='pad', has_bias=True, weight_init="TruncatedNormal") + pad_mode='pad', has_bias=bias, weight_init="TruncatedNormal") if norm_layer is not None: if isinstance(embed_dim, int): @@ -50,11 +67,29 @@ def __init__( def construct(self, x: Tensor) -> Tensor: """docstring""" - B = x.shape[0] - # FIXME look at relaxing size constraints - x = ops.Reshape()(self.proj(x), (B, self.embed_dim, -1)) # B Ph*Pw C - x = ops.Transpose()(x, (0, 2, 1)) + B, C, H, W = x.shape + if self.image_size is not None: + if self.strict_img_size: + if (H, W) != (self.image_size[0], self.image_size[1]): + raise ValueError(f"Input height and width ({H},{W}) doesn't match model ({self.image_size[0]}," + f"{self.image_size[1]}).") + elif not self.dynamic_img_pad: + if H % self.patch_size[0] != 0: + raise ValueError(f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).") + if W % self.patch_size[1] != 0: + 
raise ValueError(f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).") + if self.dynamic_img_pad: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = ops.pad(x, (0, pad_w, 0, pad_h)) + # FIXME look at relaxing size constraints + x = self.proj(x) + if self.flatten: + x = ops.Reshape()(x, (B, self.embed_dim, -1)) # B Ph*Pw C + x = ops.Transpose()(x, (0, 2, 1)) + elif self.output_fmt != "NCHW": + x = nchw_to(x, self.output_fmt) if self.norm is not None: x = self.norm(x) return x diff --git a/mindcv/models/layers/pos_embed.py b/mindcv/models/layers/pos_embed.py index c570c5c1b..ba4548580 100644 --- a/mindcv/models/layers/pos_embed.py +++ b/mindcv/models/layers/pos_embed.py @@ -1,11 +1,53 @@ """positional embedding""" -from typing import Tuple +import math +from typing import List, Optional, Tuple import numpy as np import mindspore as ms from mindspore import Parameter, Tensor, nn, ops +from .compatibility import Interpolate + + +def resample_abs_pos_embed( + posemb, + new_size: List[int], + old_size: Optional[List[int]] = None, + num_prefix_tokens: int = 1, + interpolation: str = 'nearest', +): + # sort out sizes, assume square if old size not provided + num_pos_tokens = posemb.shape[1] + num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens + + if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]: + return posemb + + if old_size is None: + hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens)) + old_size = hw, hw + + if num_prefix_tokens: + posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:] + else: + posemb_prefix, posemb = None, posemb + + # do the interpolation + embed_dim = posemb.shape[-1] + orig_dtype = posemb.dtype + posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2) + interpolate = Interpolate(mode=interpolation, align_corners=True) + posemb = interpolate(posemb, size=new_size) + posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim) + posemb = posemb.astype(orig_dtype) + + # add back extra (class, etc) prefix tokens + if posemb_prefix is not None: + posemb = ops.concatcat((posemb_prefix, posemb), axis=1) + + return posemb + class RelativePositionBiasWithCLS(nn.Cell): def __init__( diff --git a/mindcv/models/mae.py b/mindcv/models/mae.py index f50346958..4a5cf887e 100644 --- a/mindcv/models/mae.py +++ b/mindcv/models/mae.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Optional +from typing import Callable, Optional import numpy as np @@ -8,8 +8,10 @@ from mindspore.common.initializer import Normal, initializer from .helpers import load_pretrained +from .layers.mlp import Mlp +from .layers.patch_embed import PatchEmbed from .registry import register_model -from .vit_encoder import Block, VisionTransformerEncoder +from .vit import Block, VisionTransformer __all__ = [ "mae_b_16_224_pretrain", @@ -91,12 +93,12 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): return emb -class MAEForPretrain(VisionTransformerEncoder): +class MAEForPretrain(nn.Cell): def __init__( self, - img_size: int = 224, + image_size: int = 224, patch_size: int = 16, - in_chans: int = 3, + in_channels: int = 3, embed_dim: int = 1024, depth: int = 24, num_heads: int = 16, @@ -104,28 +106,33 @@ def __init__( decoder_embed_dim: int = 512, decoder_depth: int = 8, decoder_num_heads: int = 16, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop_rate: float = 0., + 
attn_drop_rate: float = 0., + drop_path_rate: float = 0., + init_values: Optional[float] = None, act_layer: nn.Cell = nn.GELU, norm_layer: nn.Cell = nn.LayerNorm, + mlp_layer: Callable = Mlp, norm_pix_loss: bool = True, mask_ratio: float = 0.75, - **kwargs + **kwargs, ): - super(MAEForPretrain, self).__init__( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - depth=depth, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - init_values=None, - act_layer=act_layer, - norm_layer=norm_layer, - use_abs_pos_emb=True, - use_rel_pos_bias=False, - use_shared_rel_pos_bias=False, - **kwargs - ) + super(MAEForPretrain, self).__init__() + self.patch_embed = PatchEmbed(image_size=image_size, patch_size=patch_size, + in_chans=in_channels, embed_dim=embed_dim) + self.num_patches = self.patch_embed.num_patches + dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] + self.blocks = nn.CellList([ + Block( + dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, + mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, + act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer, + ) for i in range(depth) + ]) + self.cls_token = Parameter(initializer(Normal(sigma=0.02), (1, 1, embed_dim))) self.unmask_len = int(np.floor(self.num_patches * (1 - mask_ratio))) @@ -148,12 +155,14 @@ def __init__( self.decoder_blocks = nn.CellList([ Block( - dim=decoder_embed_dim, num_heads=decoder_num_heads, qkv_bias=True, - mlp_ratio=mlp_ratio, init_values=None, act_layer=act_layer, norm_layer=norm_layer, - ) for _ in range(decoder_depth) + dim=decoder_embed_dim, num_heads=decoder_num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, + mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, + act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer, + ) for i in range(decoder_depth) ]) self.decoder_norm = norm_layer((decoder_embed_dim,)) - self.decoder_pred = nn.Dense(decoder_embed_dim, patch_size ** 2 * in_chans) + self.decoder_pred = nn.Dense(decoder_embed_dim, patch_size ** 2 * in_channels) self.sort = ops.Sort() @@ -178,7 +187,6 @@ def _init_weights(self): cell.beta.set_data( initializer('zeros', cell.beta.shape, cell.beta.dtype) ) - if name == "patch_embed.proj": cell.weight.set_data( initializer("xavier_uniform", cell.weight.shape, cell.weight.dtype) @@ -288,145 +296,86 @@ def construct(self, imgs, mask): loss = self.forward_loss(imgs, pred, mask) return loss + def get_num_layers(self): + return len(self.blocks) -class MAEForFinetune(VisionTransformerEncoder): - def __init__( - self, - img_size: int = 224, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - pos_drop_rate: float = 0., - proj_drop_rate: float = 0., - attn_drop_rate: float = 0., - drop_path_rate: float = 0., - act_layer: nn.Cell = nn.GELU, - norm_layer: nn.Cell = nn.LayerNorm, - num_classes: int = 1000, - use_mean_pooling: bool = True, - **kwargs - ): - super(MAEForFinetune, self).__init__( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - depth=depth, - num_heads=num_heads, - attn_head_dim=attn_head_dim, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - pos_drop_rate=pos_drop_rate, - proj_drop_rate=proj_drop_rate, - 
attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - init_values=None, - act_layer=act_layer, - norm_layer=norm_layer, - use_abs_pos_emb=True, - use_rel_pos_bias=False, - use_shared_rel_pos_bias=False, - **kwargs - ) - self.use_mean_pooling = use_mean_pooling - if self.use_mean_pooling: - self.fc_norm = norm_layer((embed_dim,)) - else: - self.norm = norm_layer((embed_dim,)) - self.head = nn.Dense(embed_dim, num_classes, weight_init='TruncatedNormal') - - self._init_weights() - self._fix_init_weights() - - def construct(self, x): - x = self.forward_features(x) - if self.use_mean_pooling: - x = x[:, 1:].mean(axis=1) - x = self.fc_norm(x) - else: - x = self.norm(x) - x = x[:, 0] - x = self.head(x) - return x + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} @register_model -def mae_b_16_224_pretrain(pretrained=False, **kwargs): +def mae_b_16_224_pretrain(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_b_16_224_pretrain"] model = MAEForPretrain( - patch_size=16, embed_dim=768, depth=12, num_heads=12, - decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, - act_layer=partial(nn.GELU, approximate=False), + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, act_layer=partial(nn.GELU, approximate=False), norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs ) if pretrained: - pass + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def mae_l_16_224_pretrain(pretrained=False, **kwargs): +def mae_l_16_224_pretrain(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_l_16_224_pretrain"] model = MAEForPretrain( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, - decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, - act_layer=partial(nn.GELU, approximate=False), + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, act_layer=partial(nn.GELU, approximate=False), norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs ) if pretrained: - pass + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def mae_h_16_224_pretrain(pretrained=False, **kwargs): +def mae_h_16_224_pretrain(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_h_16_224_pretrain"] model = MAEForPretrain( - patch_size=16, embed_dim=1280, depth=32, num_heads=16, - decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, - act_layer=partial(nn.GELU, approximate=False), + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1280, depth=32, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, act_layer=partial(nn.GELU, approximate=False), norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs ) if pretrained: - pass + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def mae_b_16_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): +def mae_b_16_224_finetune(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): default_cfg = default_cfgs["mae_b_16_224_finetune"] - model = MAEForFinetune( - 
patch_size=16, in_chans=in_chans, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + model = VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + global_pool='avg', norm_layer=partial(nn.LayerNorm, epsilon=1e-6), + num_classes=num_classes, **kwargs ) if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def mae_l_16_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): +def mae_l_16_224_finetune(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): default_cfg = default_cfgs["mae_l_16_224_finetune"] - model = MAEForFinetune( - patch_size=16, in_chans=in_chans, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + model = VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + global_pool='avg', norm_layer=partial(nn.LayerNorm, epsilon=1e-6), + num_classes=num_classes, **kwargs ) if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def mae_h_14_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): +def mae_h_14_224_finetune(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): default_cfg = default_cfgs["mae_h_14_224_finetune"] - model = MAEForFinetune( - patch_size=14, in_chans=in_chans, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + model = VisionTransformer( + image_size=224, patch_size=14, in_channels=in_channels, embed_dim=1280, depth=32, num_heads=16, + global_pool='avg', norm_layer=partial(nn.LayerNorm, epsilon=1e-6), + num_classes=num_classes, **kwargs ) if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model diff --git a/mindcv/models/vit.py b/mindcv/models/vit.py index b26712f91..c792b5708 100644 --- a/mindcv/models/vit.py +++ b/mindcv/models/vit.py @@ -1,23 +1,23 @@ """ViT""" -import math -from typing import List, Optional, Union +from typing import Callable, Optional import numpy as np +import mindspore as ms from mindspore import Parameter, Tensor, nn, ops -from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.common.initializer import HeUniform, TruncatedNormal, initializer -from .helpers import ConfigDict, load_pretrained +from .helpers import load_pretrained from .layers.compatibility import Dropout from .layers.drop_path import DropPath from .layers.mlp import Mlp +from .layers.patch_dropout import PatchDropout from .layers.patch_embed import PatchEmbed -from .layers.pos_embed import RelativePositionBiasWithCLS +from .layers.pos_embed import resample_abs_pos_embed from .registry import register_model __all__ = [ - "VisionTransformerEncoder", - "ViT", + "VisionTransformer", "vit_b_16_224", "vit_b_16_384", "vit_l_16_224", # with pretrained weights @@ -65,10 
+65,9 @@ class Attention(nn.Cell): dim (int): The dimension of input features. num_heads (int): The number of attention heads. Default: 8. qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. - qk_scale: (float): The user-defined factor to scale the product of q and k. Default: None. + qk_norm (bool): Specifies whether to do normalization to q and k. attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. proj_drop (float): The drop rate of output, greater than 0 and less equal than 1. Default: 0.0. - attn_head_dim (int): The user-defined dimension of attention head features. Default: None. Returns: Tensor, output tensor. @@ -81,27 +80,23 @@ def __init__( dim: int, num_heads: int = 8, qkv_bias: bool = True, - qk_scale: Optional[float] = None, + qk_norm: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0, - attn_head_dim: Optional[int] = None, + norm_layer: nn.Cell = nn.LayerNorm, ): super(Attention, self).__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads - head_dim = dim // num_heads - if attn_head_dim is not None: - head_dim = attn_head_dim - all_head_dim = head_dim * num_heads - - if qk_scale: - self.scale = Tensor(qk_scale) - else: - self.scale = Tensor(head_dim ** -0.5) + self.head_dim = dim // num_heads + self.scale = Tensor(self.head_dim ** -0.5) - self.qkv = nn.Dense(dim, all_head_dim * 3, has_bias=qkv_bias) + self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias) + self.q_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity() + self.k_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity() self.attn_drop = Dropout(attn_drop) - self.proj = nn.Dense(all_head_dim, dim) + self.proj = nn.Dense(dim, dim) self.proj_drop = Dropout(proj_drop) self.mul = ops.Mul() @@ -110,22 +105,20 @@ def __init__( self.unstack = ops.Unstack(axis=0) self.attn_matmul_v = ops.BatchMatMul() self.q_matmul_k = ops.BatchMatMul(transpose_b=True) - self.softmax = nn.Softmax(axis=-1) - def construct(self, x, rel_pos_bias=None): + def construct(self, x): b, n, c = x.shape qkv = self.qkv(x) - qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) + qkv = self.reshape(qkv, (b, n, 3, self.num_heads, self.head_dim)) qkv = self.transpose(qkv, (2, 0, 3, 1, 4)) q, k, v = self.unstack(qkv) + q, k = self.q_norm(q), self.k_norm(k) attn = self.q_matmul_k(q, k) attn = self.mul(attn, self.scale) - if rel_pos_bias is not None: - attn = attn + rel_pos_bias - - attn = self.softmax(attn) + attn = attn.astype(ms.float32) + attn = ops.softmax(attn, axis=-1) attn = self.attn_drop(attn) out = self.attn_matmul_v(attn, v) @@ -164,7 +157,7 @@ def construct(self, x): return self.gamma * x -class TransformerBlock(nn.Cell): +class Block(nn.Cell): """ Transformer block implementation. @@ -172,10 +165,8 @@ class TransformerBlock(nn.Cell): dim (int): The dimension of embedding. num_heads (int): The number of attention heads. qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. - qk_scale: (float): The user-defined factor to scale the product of q and k. Default: None. attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. proj_drop (float): The drop rate of dense layer output, greater than 0 and less equal than 1. Default: 0.0. - attn_head_dim (int): The user-defined dimension of attention head features. Default: None. 
mlp_ratio (float): The ratio used to scale the input dimensions to obtain the dimensions of the hidden layer. drop_path (float): The drop rate for drop path. Default: 0.0. act_layer (nn.Cell): Activation function which will be stacked on top of the @@ -193,41 +184,48 @@ def __init__( self, dim: int, num_heads: int = 8, + mlp_ratio: float = 4., qkv_bias: bool = False, - qk_scale: Optional[float] = None, - attn_drop: float = 0., + qk_norm: bool = False, proj_drop: float = 0., - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - drop_path: float = 0., + attn_drop: float = 0., init_values: Optional[float] = None, + drop_path: float = 0., act_layer: nn.Cell = nn.GELU, norm_layer: nn.Cell = nn.LayerNorm, + mlp_layer: Callable = Mlp, ): - super(TransformerBlock, self).__init__() + super(Block, self).__init__() self.norm1 = norm_layer((dim,)) self.attn = Attention( - dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, + dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, ) self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer((dim,)) - self.mlp = Mlp( - in_features=dim, hidden_features=int(dim * mlp_ratio), - act_layer=act_layer, drop=proj_drop + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop ) self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def construct(self, x, rel_pos_bias=None): - x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), rel_pos_bias))) + def construct(self, x): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x -class VisionTransformerEncoder(nn.Cell): +class VisionTransformer(nn.Cell): ''' ViT encoder, which returns the feature encoded by transformer encoder. 
''' @@ -236,62 +234,93 @@ def __init__( image_size: int = 224, patch_size: int = 16, in_channels: int = 3, + global_pool: str = 'token', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, - attn_head_dim: Optional[int] = None, mlp_ratio: float = 4., qkv_bias: bool = True, - qk_scale: Optional[float] = None, + qk_norm: bool = False, + drop_rate: float = 0., pos_drop_rate: float = 0., + patch_drop_rate: float = 0., proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0., - init_values: Optional[float] = 0.1, + weight_init: bool = True, + init_values: Optional[float] = None, + no_embed_class: bool = False, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, act_layer: nn.Cell = nn.GELU, + embed_layer: Callable = PatchEmbed, norm_layer: nn.Cell = nn.LayerNorm, - use_rel_pos_emb: bool = False, - use_rel_pos_bias: bool = False, - use_shared_rel_pos_bias: bool = True, - **kwargs + mlp_layer: Callable = Mlp, + class_token: bool = True, + block_fn: Callable = Block, + num_classes: int = 1000, ): - super(VisionTransformerEncoder, self).__init__() - self.embed_dim = embed_dim - self.patch_embed = PatchEmbed(image_size=image_size, patch_size=patch_size, - in_chans=in_channels, embed_dim=embed_dim) - self.num_patches = self.patch_embed.num_patches - - self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) + super(VisionTransformer, self).__init__() + assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + + self.global_pool = global_pool + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class + self.dynamic_img_size = dynamic_img_size + self.dynamic_img_pad = dynamic_img_pad + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) + elif dynamic_img_pad: + embed_args.update(dict(output_fmt='NHWC')) + + self.patch_embed = embed_layer( + image_size=image_size, + patch_size=patch_size, + in_chans=in_channels, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches - self.pos_embed = Parameter(initializer(TruncatedNormal(0.02), - (1, self.num_patches + 1, embed_dim))) if not use_rel_pos_emb else None + self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = Parameter(initializer(TruncatedNormal(0.02), (1, embed_len, embed_dim))) self.pos_drop = Dropout(pos_drop_rate) - - if use_shared_rel_pos_bias: - self.rel_pos_bias = RelativePositionBiasWithCLS( - window_size=self.patch_embed.patches_resolution, - num_heads=num_heads, - ) - elif use_rel_pos_bias: - self.rel_pos_bias = nn.CellList([ - RelativePositionBiasWithCLS(window_size=self.patch_embed.patches_resolution, - num_heads=num_heads) for _ in range(depth) - ]) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) else: - self.rel_pos_bias = None + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] self.blocks = nn.CellList([ - TransformerBlock( - dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, attn_head_dim=attn_head_dim, + block_fn( + dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, - act_layer=act_layer, norm_layer=norm_layer + act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer, ) for i in range(depth) ]) - self._init_weights() - self._fix_init_weights() + self.norm = norm_layer((embed_dim,)) if not use_fc_norm else nn.Identity() + self.fc_norm = norm_layer((embed_dim,)) if use_fc_norm else nn.Identity() + self.head_drop = Dropout(drop_rate) + self.head = nn.Dense(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init: + self._init_weights() def get_num_layers(self): return len(self.blocks) @@ -318,474 +347,155 @@ def _init_weights(self): ) elif isinstance(cell, nn.Conv2d): cell.weight.set_data( - initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + initializer(HeUniform(), cell.weight.shape, cell.weight.dtype) ) if cell.bias is not None: cell.bias.set_data( - initializer('zeros', cell.bias.shape, cell.bias.dtype) + initializer("zeros", cell.bias.shape, cell.bias.dtype) ) - def _fix_init_weights(self): - for i, block in enumerate(self.blocks): - block.attn.proj.weight.set_data( - ops.div(block.attn.proj.weight, math.sqrt(2.0 * (i + 1))) - ) - block.mlp.fc2.weight.set_data( - ops.div(block.mlp.fc2.weight, math.sqrt(2.0 * (i + 1))) + def _pos_embed(self, x): + if self.dynamic_img_size or self.dynamic_img_pad: + # bhwc format + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, ) + x = ops.reshape(x, (B, -1, C)) + else: + pos_embed = self.pos_embed + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if self.cls_token is not None: + cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1)) + cls_tokens = cls_tokens.astype(x.dtype) + x = 
ops.concat((cls_tokens, x), axis=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1)) + cls_tokens = cls_tokens.astype(x.dtype) + x = ops.concat((cls_tokens, x), axis=1) + x = x + pos_embed + + return self.pos_drop(x) def forward_features(self, x): x = self.patch_embed(x) - bsz = x.shape[0] - - cls_tokens = ops.broadcast_to(self.cls_token, (bsz, -1, -1)) - cls_tokens = cls_tokens.astype(x.dtype) - x = ops.concat((cls_tokens, x), axis=1) - - if self.pos_embed is not None: - x = x + self.pos_embed - x = self.pos_drop(x) - - if isinstance(self.rel_pos_bias, nn.CellList): - for i, blk in enumerate(self.blocks): - rel_pos_bias = self.rel_pos_bias[i]() - x = blk(x, rel_pos_bias) - else: - rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None - for blk in self.blocks: - x = blk(x, rel_pos_bias) - + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) return x - def construct(self, x): - x = self.forward_features(x) + def forward_head(self, x): + if self.global_pool: + x = x[:, self.num_prefix_tokens:].mean(axis=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + x = self.head_drop(x) + x = self.head(x) return x - -class DenseHead(nn.Cell): - """ - LinearClsHead architecture. - - Args: - input_channel (int): The number of input channel. - num_classes (int): Number of classes. - has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. - activation (Union[str, Cell, Primitive]): activate function applied to the output. Eg. `ReLU`. Default: None. - keep_prob (float): Dropout keeping rate, between [0, 1]. E.g. rate=0.9, means dropping out 10% of input. - Default: 1.0. - - Returns: - Tensor, output tensor. - """ - - def __init__( - self, - input_channel: int, - num_classes: int, - has_bias: bool = True, - activation: Optional[Union[str, nn.Cell]] = None, - keep_prob: float = 1.0, - ) -> None: - super().__init__() - - self.dropout = Dropout(p=1.0-keep_prob) - self.classifier = nn.Dense(input_channel, num_classes, has_bias=has_bias, activation=activation) - def construct(self, x): - if self.training: - x = self.dropout(x) - x = self.classifier(x) + x = self.forward_features(x) + x = self.forward_head(x) return x -class MultilayerDenseHead(nn.Cell): - """ - MultilayerDenseHead architecture. - - Args: - input_channel (int): The number of input channel. - num_classes (int): Number of classes. - mid_channel (list): Number of channels in the hidden fc layers. - keep_prob (list): Dropout keeping rate, between [0, 1]. E.g. rate=0.9, means dropping out 10% of - input. - activation (list): activate function applied to the output. Eg. `ReLU`. - - Returns: - Tensor, output tensor. - """ - - def __init__( - self, - input_channel: int, - num_classes: int, - mid_channel: List[int], - keep_prob: List[float], - activation: List[Optional[Union[str, nn.Cell]]], - ) -> None: - super().__init__() - mid_channel.append(num_classes) - assert len(mid_channel) == len(activation) == len(keep_prob), "The length of the list should be the same." 
- - length = len(activation) - head = [] - - for i in range(length): - linear = DenseHead( - input_channel, - mid_channel[i], - activation=activation[i], - keep_prob=keep_prob[i], - ) - head.append(linear) - input_channel = mid_channel[i] - - self.classifier = nn.SequentialCell(head) - - def construct(self, x): - x = self.classifier(x) - - return x - +@register_model +def vit_b_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_b_16_224"] + model = VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) -class ViT(VisionTransformerEncoder): - def __init__( - self, - image_size: int = 224, - patch_size: int = 16, - in_channels: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - pos_drop_rate: float = 0., - proj_drop_rate: float = 0., - attn_drop_rate: float = 0., - drop_path_rate: float = 0., - init_values: Optional[float] = 0.1, - act_layer: nn.Cell = nn.GELU, - norm_layer: nn.Cell = nn.LayerNorm, - use_rel_pos_emb: bool = False, - use_rel_pos_bias: bool = False, - use_shared_rel_pos_bias: bool = True, - use_cls: bool = True, - representation_size: Optional[int] = None, - num_classes: int = 1000, - **kwargs - ): - super(ViT, self).__init__( - image_size=image_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=embed_dim, - depth=depth, - num_heads=num_heads, - attn_head_dim=attn_head_dim, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - pos_drop_rate=pos_drop_rate, - proj_drop_rate=proj_drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - init_values=init_values, - act_layer=act_layer, - norm_layer=norm_layer, - use_rel_pos_emb=use_rel_pos_emb, - use_rel_pos_bias=use_rel_pos_bias, - use_shared_rel_pos_bias=use_shared_rel_pos_bias, - **kwargs - ) - self.use_cls = use_cls - self.norm = norm_layer((embed_dim,)) - - if representation_size: - self.head = MultilayerDenseHead( - input_channel=embed_dim, - num_classes=num_classes, - mid_channel=[representation_size], - activation=["tanh", None], - keep_prob=[1.0, 1.0], - ) - else: - self.head = DenseHead(input_channel=embed_dim, num_classes=num_classes) + return model - def construct(self, x): - x = self.forward_features(x) - x = self.norm(x) - if self.use_cls: - x = x[:, 0] - else: - x = x[:, 1:].mean(axis=1) +@register_model +def vit_b_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_b_16_384"] + model = VisionTransformer( + image_size=384, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - x = self.head(x) - return x + return model -def vit( - image_size: int = 224, - patch_size: int = 16, - in_channels: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - pos_drop_rate: float = 0., - proj_drop_rate: float = 0., - attn_drop_rate: float = 0., - drop_path_rate: float = 
0., - init_values: Optional[float] = None, - act_layer: nn.Cell = nn.GELU, - norm_layer: nn.Cell = nn.LayerNorm, - use_rel_pos_emb: bool = False, - use_rel_pos_bias: bool = False, - use_shared_rel_pos_bias: bool = False, - use_cls: bool = True, - representation_size: Optional[int] = None, - num_classes: int = 1000, - pretrained: bool = False, - url_cfg: dict = None, -) -> ViT: - - """Vision Transformer architecture.""" - - model = ViT( - image_size=image_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=embed_dim, - depth=depth, - num_heads=num_heads, - attn_head_dim=attn_head_dim, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - pos_drop_rate=pos_drop_rate, - proj_drop_rate=proj_drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - init_values=init_values, - act_layer=act_layer, - norm_layer=norm_layer, - use_rel_pos_emb=use_rel_pos_emb, - use_rel_pos_bias=use_rel_pos_bias, - use_shared_rel_pos_bias=use_shared_rel_pos_bias, - use_cls=use_cls, - representation_size=representation_size, - num_classes=num_classes +@register_model +def vit_l_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_l_16_224"] + model = VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + num_classes=num_classes, **kwargs ) - if pretrained: - # Download the pre-trained checkpoint file from url, and load ckpt file. - load_pretrained(model, url_cfg, num_classes=num_classes, in_channels=in_channels) + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def vit_b_16_224( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 224 - config.patch_size = 16 - config.in_channels = in_channels - config.embed_dim = 768 - config.depth = 12 - config.num_heads = 12 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 768 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_b_16_224"] - - return vit(**config) - +def vit_l_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_l_16_384"] + model = VisionTransformer( + image_size=384, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) -@register_model -def vit_b_16_384( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 384 - config.patch_size = 16 - config.in_channels = in_channels - config.embed_dim = 768 - config.depth = 12 - config.num_heads = 12 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 768 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = 
default_cfgs["vit_b_16_384"] - - return vit(**config) + return model @register_model -def vit_l_16_224( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 224 - config.patch_size = 16 - config.in_channels = in_channels - config.embed_dim = 1024 - config.depth = 24 - config.num_heads = 16 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 1024 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_l_16_224"] - - return vit(**config) - +def vit_b_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_b_32_224"] + model = VisionTransformer( + image_size=224, patch_size=32, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) -@register_model -def vit_l_16_384( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 384 - config.patch_size = 16 - config.in_channels = in_channels - config.embed_dim = 1024 - config.depth = 24 - config.num_heads = 16 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 1024 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_l_16_384"] - - return vit(**config) + return model @register_model -def vit_b_32_224( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 224 - config.patch_size = 32 - config.in_channels = in_channels - config.embed_dim = 768 - config.depth = 12 - config.num_heads = 12 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 768 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_b_32_224"] - - return vit(**config) - +def vit_b_32_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_b_32_384"] + model = VisionTransformer( + image_size=384, patch_size=32, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) -@register_model -def vit_b_32_384( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 384 - config.patch_size = 32 - config.in_channels = in_channels - config.embed_dim = 768 - config.depth = 12 - config.num_heads = 12 - config.pos_drop_rate = drop_rate - 
config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 768 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_b_32_384"] - - return vit(**config) + return model @register_model -def vit_l_32_224( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 224 - config.patch_size = 32 - config.in_channels = in_channels - config.embed_dim = 1024 - config.depth = 24 - config.num_heads = 16 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 1024 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_l_32_224"] - - return vit(**config) +def vit_l_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_l_32_224"] + model = VisionTransformer( + image_size=224, patch_size=32, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + + return model diff --git a/mindcv/models/vit_encoder.py b/mindcv/models/vit_encoder.py deleted file mode 100644 index 010065273..000000000 --- a/mindcv/models/vit_encoder.py +++ /dev/null @@ -1,296 +0,0 @@ -import math -from typing import Optional, Tuple - -import numpy as np - -import mindspore as ms -from mindspore import Parameter, Tensor, nn, ops -from mindspore.common.initializer import TruncatedNormal, initializer - -from .layers.drop_path import DropPath -from .layers.mlp import Mlp -from .layers.patch_embed import PatchEmbed - - -class RelativePositionBiasWithCLS(nn.Cell): - def __init__( - self, - window_size: Tuple[int], - num_heads: int - ): - super(RelativePositionBiasWithCLS, self).__init__() - self.window_size = window_size - self.num_tokens = window_size[0] * window_size[1] - - num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 - # 3: cls to token, token to cls, cls to cls - self.relative_position_bias_table = Parameter( - Tensor(np.zeros((num_relative_distance, num_heads)), dtype=ms.float16) - ) - coords_h = np.arange(window_size[0]).reshape(window_size[0], 1).repeat(window_size[1], 1).reshape(1, -1) - coords_w = np.arange(window_size[1]).reshape(1, window_size[1]).repeat(window_size[0], 0).reshape(1, -1) - coords_flatten = np.concatenate([coords_h, coords_w], axis=0) # [2, Wh * Ww] - - relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :] # [2, Wh * Ww, Wh * Ww] - relative_coords = relative_coords.transpose(1, 2, 0) # [Wh * Ww, Wh * Ww, 2] - relative_coords[:, :, 0] += window_size[0] - 1 - relative_coords[:, :, 1] += window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * window_size[0] - 1 - - relative_position_index = np.zeros((self.num_tokens + 1, self.num_tokens + 1), - dtype=relative_coords.dtype) # [Wh * Ww + 1, Wh * Ww + 1] - relative_position_index[1:, 1:] = relative_coords.sum(-1) - relative_position_index[0, 0:] = num_relative_distance - 3 - relative_position_index[0:, 0] = num_relative_distance - 2 - relative_position_index[0, 0] = 
num_relative_distance - 1 - relative_position_index = Tensor(relative_position_index.reshape(-1)) - - self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16) - self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False) - - def construct(self): - out = ops.matmul(self.relative_position_index, self.relative_position_bias_table) - out = ops.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1)) - out = ops.transpose(out, (2, 0, 1)) - out = ops.expand_dims(out, 0) - return out - - -class Attention(nn.Cell): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_scale: Optional[float] = None, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - attn_head_dim: Optional[int] = None, - ): - super(Attention, self).__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - if attn_head_dim is not None: - head_dim = attn_head_dim - all_head_dim = head_dim * num_heads - - if qk_scale: - self.scale = Tensor(qk_scale) - else: - self.scale = Tensor(head_dim ** -0.5) - - self.qkv = nn.Dense(dim, all_head_dim * 3, has_bias=qkv_bias) - - self.attn_drop = nn.Dropout(1 - attn_drop) - self.proj = nn.Dense(all_head_dim, dim) - self.proj_drop = nn.Dropout(1 - proj_drop) - - self.mul = ops.Mul() - self.reshape = ops.Reshape() - self.transpose = ops.Transpose() - self.unstack = ops.Unstack(axis=0) - self.attn_matmul_v = ops.BatchMatMul() - self.q_matmul_k = ops.BatchMatMul(transpose_b=True) - - def construct(self, x, rel_pos_bias=None): - b, n, c = x.shape - qkv = self.qkv(x) - qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) - qkv = self.transpose(qkv, (2, 0, 3, 1, 4)) - q, k, v = self.unstack(qkv) - - attn = self.q_matmul_k(q, k) - attn = self.mul(attn, self.scale) - - if rel_pos_bias is not None: - attn = attn + rel_pos_bias - - attn = attn.astype(ms.float32) - attn = ops.softmax(attn, axis=-1) - attn = self.attn_drop(attn) - - out = self.attn_matmul_v(attn, v) - out = self.transpose(out, (0, 2, 1, 3)) - out = self.reshape(out, (b, n, c)) - out = self.proj(out) - out = self.proj_drop(out) - - return out - - -class LayerScale(nn.Cell): - def __init__( - self, - dim: int, - init_values: float = 1e-5 - ): - super(LayerScale, self).__init__() - self.gamma = Parameter(initializer(init_values, dim)) - - def construct(self, x): - return self.gamma * x - - -class Block(nn.Cell): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_scale: Optional[float] = None, - attn_drop: float = 0., - proj_drop: float = 0., - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - drop_path: float = 0., - init_values: Optional[float] = None, - act_layer: nn.Cell = nn.GELU, - norm_layer: nn.Cell = nn.LayerNorm, - ): - super(Block, self).__init__() - self.norm1 = norm_layer((dim,)) - self.attn = Attention( - dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, - ) - self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - self.norm2 = norm_layer((dim,)) - self.mlp = Mlp( - in_features=dim, hidden_features=int(dim * mlp_ratio), - act_layer=act_layer, drop=proj_drop - ) - self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - - def construct(self, x, rel_pos_bias=None): - x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), rel_pos_bias))) - x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) - return x - - -class VisionTransformerEncoder(nn.Cell): - def __init__( - self, - img_size: int = 224, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - pos_drop_rate: float = 0., - proj_drop_rate: float = 0., - attn_drop_rate: float = 0., - drop_path_rate: float = 0., - init_values: Optional[float] = 0.1, - act_layer: nn.Cell = nn.GELU, - norm_layer: nn.Cell = nn.LayerNorm, - use_abs_pos_emb: bool = False, - use_rel_pos_bias: bool = False, - use_shared_rel_pos_bias: bool = True, - **kwargs - ): - super(VisionTransformerEncoder, self).__init__() - self.embed_dim = embed_dim - self.patch_embed = PatchEmbed(image_size=img_size, patch_size=patch_size, - in_chans=in_chans, embed_dim=embed_dim) - self.num_patches = self.patch_embed.num_patches - - self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) - - self.pos_embed = Parameter( - initializer(TruncatedNormal(0.02), (1, self.num_patches + 1, embed_dim))) if use_abs_pos_emb else None - self.pos_drop = nn.Dropout(1 - pos_drop_rate) - - if use_shared_rel_pos_bias: - self.rel_pos_bias = RelativePositionBiasWithCLS( - window_size=self.patch_embed.patches_resolution, num_heads=num_heads) - elif use_rel_pos_bias: - self.rel_pos_bias = nn.CellList([ - RelativePositionBiasWithCLS(window_size=self.patch_embed.patches_resolution, - num_heads=num_heads) for _ in range(depth) - ]) - else: - self.rel_pos_bias = None - - dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] - self.blocks = nn.CellList([ - Block( - dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, attn_head_dim=attn_head_dim, - mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, - act_layer=act_layer, norm_layer=norm_layer - ) for i in range(depth) - ]) - - def get_num_layers(self): - return len(self.blocks) - - def no_weight_decay(self): - return {'pos_embed', 'cls_token'} - - def _init_weights(self): - for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Dense): - cell.weight.set_data( - initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) - ) - if cell.bias is not None: - cell.bias.set_data( - initializer('zeros', cell.bias.shape, cell.bias.dtype) - ) - elif isinstance(cell, nn.LayerNorm): - cell.gamma.set_data( - initializer('ones', cell.gamma.shape, cell.gamma.dtype) - ) - cell.beta.set_data( - initializer('zeros', cell.beta.shape, cell.beta.dtype) - ) - elif isinstance(cell, nn.Conv2d): - cell.weight.set_data( - initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) - ) - if cell.bias is not None: - cell.bias.set_data( - initializer('zeros', cell.bias.shape, cell.bias.dtype) - ) - - def _fix_init_weights(self): - for i, block in enumerate(self.blocks): - block.attn.proj.weight.set_data( - ops.div(block.attn.proj.weight, math.sqrt(2.0 * (i + 1))) - ) - block.mlp.fc2.weight.set_data( - ops.div(block.mlp.fc2.weight, math.sqrt(2.0 * (i + 1))) - ) - - def forward_features(self, x): - x = self.patch_embed(x) - bsz = x.shape[0] - - cls_token = ops.broadcast_to(self.cls_token, (bsz, -1, -1)) - cls_token = cls_token.astype(x.dtype) - x 
= ops.concat((cls_token, x), axis=1) - - if self.pos_embed is not None: - x = x + self.pos_embed - x = self.pos_drop(x) - - if isinstance(self.rel_pos_bias, nn.CellList): - for i, blk in enumerate(self.blocks): - rel_pos_bias = self.rel_pos_bias[i]() - x = blk(x, rel_pos_bias) - else: - rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None - for blk in self.blocks: - x = blk(x, rel_pos_bias) - - return x - - def construct(self, x): - return self.forward_features(x)
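
Below is a minimal usage sketch of the two new building blocks, assuming they are importable as mindcv.models.layers.format and mindcv.models.layers.patch_dropout per the file paths added above; the tensor shapes and the 0.25 drop probability are illustrative only.

import numpy as np

import mindspore as ms
from mindspore import ops

from mindcv.models.layers.format import Format, nchw_to
from mindcv.models.layers.patch_dropout import PatchDropout

# NCHW feature map from a patch projection -> NLC token sequence,
# the same conversion PatchEmbed applies when output_fmt='NLC'.
feat = ms.Tensor(np.random.randn(2, 768, 14, 14), ms.float32)
tokens = nchw_to(feat, Format.NLC)              # (2, 196, 768)

# Prepend a CLS token, then keep a random ~75% of the patch tokens while training.
cls_tok = ops.zeros((2, 1, 768), ms.float32)
tokens = ops.concat((cls_tok, tokens), axis=1)  # (2, 197, 768)

patch_drop = PatchDropout(prob=0.25, num_prefix_tokens=1)
patch_drop.set_train(True)   # PatchDropout is an identity mapping in eval mode
out = patch_drop(tokens)
print(out.shape)             # (2, 148, 768): 1 CLS token + 147 kept patch tokens

PatchDropout only takes effect in training mode, which is why VisionTransformer wires it in behind patch_drop_rate and falls back to nn.Identity when the rate is 0.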