From 34d835d622b0c18a4873023ca03def07b2d7cbcd Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Thu, 31 Aug 2023 23:47:04 +0800 Subject: [PATCH 1/5] Refractor ViT to support relative positional embedding and layer scale; Checkpoint updated --- configs/vit/README.md | 6 +- mindcv/models/layers/attention.py | 87 +++ mindcv/models/layers/pos_embed.py | 49 ++ mindcv/models/vit.py | 939 ++++++++++++------------------ 4 files changed, 502 insertions(+), 579 deletions(-) create mode 100644 mindcv/models/layers/attention.py create mode 100644 mindcv/models/layers/pos_embed.py diff --git a/configs/vit/README.md b/configs/vit/README.md index e5d743fbb..f424e57e4 100644 --- a/configs/vit/README.md +++ b/configs/vit/README.md @@ -36,9 +36,9 @@ Our reproduced model performance on ImageNet-1K is reported as follows. | Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download | |--------------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------| -| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-7553218f.ckpt) | -| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-f02b2487.ckpt) | -| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-3a961018.ckpt) | +| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt) | +| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt) | +| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt) | diff --git a/mindcv/models/layers/attention.py b/mindcv/models/layers/attention.py new file mode 100644 index 000000000..5ac033e6b --- /dev/null +++ b/mindcv/models/layers/attention.py @@ -0,0 +1,87 @@ +"""attention layers""" +#TODO: add Flash Attention + +from typing import Optional, Tuple, Union, List +from mindspore import nn, ops, Tensor, Parameter +from .compatibility import Dropout + +class Attention(nn.Cell): + """ + Attention layer implementation, Rearrange Input -> B x N x hidden size. + + Args: + dim (int): The dimension of input features. + num_heads (int): The number of attention heads. Default: 8. + qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. + qk_scale: (float): The user-defined factor to scale the product of q and k. Default: None. + attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. 
+ proj_drop (float): The drop rate of output, greater than 0 and less equal than 1. Default: 0.0. + attn_head_dim (int): The user-defined dimension of attention head features. Default: None. + + Returns: + Tensor, output tensor. + + Examples: + >>> ops = Attention(768, 12) + """ + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + attn_head_dim: Optional[int] = None, + ): + super(Attention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * num_heads + + if qk_scale: + self.scale = Tensor(qk_scale) + else: + self.scale = Tensor(head_dim ** -0.5) + + self.qkv = nn.Dense(dim, all_head_dim * 3, has_bias=qkv_bias) + + self.attn_drop = Dropout(attn_drop) + self.proj = nn.Dense(all_head_dim, dim) + self.proj_drop = Dropout(proj_drop) + + self.mul = ops.Mul() + self.reshape = ops.Reshape() + self.transpose = ops.Transpose() + self.unstack = ops.Unstack(axis=0) + self.attn_matmul_v = ops.BatchMatMul() + self.q_matmul_k = ops.BatchMatMul(transpose_b=True) + self.softmax = nn.Softmax(axis=-1) + + def construct(self, x, rel_pos_bias=None): + b, n, c = x.shape + qkv = self.qkv(x) + qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) + qkv = self.transpose(qkv, (2, 0, 3, 1, 4)) + q, k, v = self.unstack(qkv) + + attn = self.q_matmul_k(q, k) + attn = self.mul(attn, self.scale) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + out = self.attn_matmul_v(attn, v) + out = self.transpose(out, (0, 2, 1, 3)) + out = self.reshape(out, (b, n, c)) + out = self.proj(out) + out = self.proj_drop(out) + + return out + + diff --git a/mindcv/models/layers/pos_embed.py b/mindcv/models/layers/pos_embed.py new file mode 100644 index 000000000..66052947f --- /dev/null +++ b/mindcv/models/layers/pos_embed.py @@ -0,0 +1,49 @@ +"""positional embedding""" +from mindspore import nn, ops, Tensor, Parameter +from typing import Optional, Tuple, Union, List + + +class RelativePositionBiasWithCLS(nn.Cell): + def __init__( + self, + window_size: Tuple[int], + num_heads: int + ): + super(RelativePositionBiasWithCLS, self).__init__() + self.window_size = window_size + self.num_tokens = window_size[0] * window_size[1] + + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + # 3: cls to token, token to cls, cls to cls + self.relative_position_bias_table = Parameter( + Tensor(np.zeros((num_relative_distance, num_heads)), dtype=ms.float16) + ) + coords_h = np.arange(window_size[0]).reshape(window_size[0], 1).repeat(window_size[1], 1).reshape(1, -1) + coords_w = np.arange(window_size[1]).reshape(1, window_size[1]).repeat(window_size[0], 0).reshape(1, -1) + coords_flatten = np.concatenate([coords_h, coords_w], axis=0) # [2, Wh * Ww] + + relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :] # [2, Wh * Ww, Wh * Ww] + relative_coords = relative_coords.transpose(1, 2, 0) # [Wh * Ww, Wh * Ww, 2] + relative_coords[:, :, 0] += window_size[0] - 1 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[0] - 1 + + relative_position_index = np.zeros((self.num_tokens + 1, self.num_tokens + 1), + dtype=relative_coords.dtype) # [Wh * Ww + 1, Wh * Ww + 1] + relative_position_index[1:, 1:] = relative_coords.sum(-1) + 
relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + relative_position_index = Tensor(relative_position_index.reshape(-1)) + + self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16) + self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False) + + def construct(self): + out = ops.matmul(self.relative_position_index, self.relative_position_bias_table) + out = ops.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1)) + out = ops.transpose(out, (2, 0, 1)) + out = ops.expand_dims(out, 0) + return out + + diff --git a/mindcv/models/vit.py b/mindcv/models/vit.py index ac2c4c4c7..7fac9e2c7 100644 --- a/mindcv/models/vit.py +++ b/mindcv/models/vit.py @@ -1,28 +1,31 @@ """ViT""" +import math from typing import List, Optional, Union import numpy as np -import mindspore as ms -from mindspore import Tensor, nn -from mindspore import ops -from mindspore import ops as P -from mindspore.common.initializer import Normal, initializer -from mindspore.common.parameter import Parameter +from mindspore import Parameter, nn, ops +from mindspore.common.initializer import TruncatedNormal, initializer from .helpers import ConfigDict, load_pretrained +from .layers.attention import Attention from .layers.compatibility import Dropout +from .layers.drop_path import DropPath +from .layers.mlp import Mlp +from .layers.patch_embed import PatchEmbed +from .layers.pos_embed import RelativePositionBiasWithCLS from .registry import register_model __all__ = [ + "VisionTransformerEncoder", "ViT", "vit_b_16_224", "vit_b_16_384", - "vit_l_16_224", # train + "vit_l_16_224", # with pretrained weights "vit_l_16_384", - "vit_b_32_224", # train + "vit_b_32_224", # with pretrained weights "vit_b_32_384", - "vit_l_32_224", # train + "vit_l_32_224", # with pretrained weights ] @@ -32,7 +35,7 @@ def _cfg(url="", **kwargs): "num_classes": 1000, "input_size": (3, 224, 224), "first_conv": "patch_embed.proj", - "classifier": "classifier", + "classifier": "head.classifier", **kwargs, } @@ -42,301 +45,241 @@ def _cfg(url="", **kwargs): "vit_b_16_384": _cfg( url="", input_size=(3, 384, 384) ), - "vit_l_16_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-f02b2487.ckpt"), + "vit_l_16_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt"), "vit_l_16_384": _cfg( url="", input_size=(3, 384, 384) ), - "vit_b_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-7553218f.ckpt"), + "vit_b_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt"), "vit_b_32_384": _cfg( url="", input_size=(3, 384, 384) ), - "vit_l_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-3a961018.ckpt"), + "vit_l_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt"), } -class PatchEmbedding(nn.Cell): +class LayerScale(nn.Cell): """ - Path embedding layer for ViT. First rearrange b c (h p) (w p) -> b (h w) (p p c). + Layer scale, help ViT improve the training dynamic, allowing for the training + of deeper high-capacity image transformers that benefit from depth Args: - image_size (int): Input image size. Default: 224. - patch_size (int): Patch size of image. Default: 16. - embed_dim (int): The dimension of embedding. Default: 768. 
- input_channels (int): The number of input channel. Default: 3. + dim (int): The output dimension of attnetion layer or mlp layer. + init_values (float): The scale factor. Default: 1e-5. Returns: Tensor, output tensor. Examples: - >>> ops = PathEmbedding(224, 16, 768, 3) + >>> ops = LayerScale(768, 0.01) """ - - MIN_NUM_PATCHES = 4 - def __init__( self, - image_size: int = 224, - patch_size: int = 16, - embed_dim: int = 768, - input_channels: int = 3, + dim: int, + init_values: float = 1e-5 ): - super().__init__() - self.image_size = image_size - self.patch_size = patch_size - self.num_patches = (image_size // patch_size) ** 2 - self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True) - self.reshape = ops.Reshape() - self.transpose = ops.Transpose() + super(LayerScale, self).__init__() + self.gamma = Parameter(initializer(init_values, dim)) def construct(self, x): - """Path Embedding construct.""" - x = self.conv(x) - b, c, h, w = x.shape - x = self.reshape(x, (b, c, h * w)) - x = self.transpose(x, (0, 2, 1)) - - return x + return self.gamma * x -class Attention(nn.Cell): +class TransformerBlock(nn.Cell): """ - Attention layer implementation, Rearrange Input -> B x N x hidden size. + Transformer block implementation. Args: - dim (int): The dimension of input features. - num_heads (int): The number of attention heads. Default: 8. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. - attention_keep_prob (float): The keep rate for attention. Default: 1.0. + dim (int): The dimension of embedding. + num_heads (int): The number of attention heads. + qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. + qk_scale: (float): The user-defined factor to scale the product of q and k. Default: None. + attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. + proj_drop (float): The drop rate of dense layer output, greater than 0 and less equal than 1. Default: 0.0. + attn_head_dim (int): The user-defined dimension of attention head features. Default: None. + mlp_ratio (float): The ratio used to scale the input dimensions to obtain the dimensions of the hidden layer. + drop_path (float): The drop rate for drop path. Default: 0.0. + act_layer (nn.Cell): Activation function which will be stacked on top of the + normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. + norm_layer (nn.Cell): Norm layer that will be stacked on top of the convolution + layer. Default: nn.LayerNorm. Returns: Tensor, output tensor. 
Examples: - >>> ops = Attention(768, 12) + >>> ops = TransformerEncoder(768, 12, 12, 3072) """ - def __init__( self, dim: int, num_heads: int = 8, - keep_prob: float = 1.0, - attention_keep_prob: float = 1.0, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_drop: float = 0., + proj_drop: float = 0., + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + drop_path: float = 0., + init_values: Optional[float] = None, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, ): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = Tensor(head_dim**-0.5) - - self.qkv = nn.Dense(dim, dim * 3) - self.attn_drop = Dropout(p=1.0-attention_keep_prob) - self.out = nn.Dense(dim, dim) - self.out_drop = Dropout(p=1.0-keep_prob) - - self.mul = ops.Mul() - self.reshape = ops.Reshape() - self.transpose = ops.Transpose() - self.unstack = ops.Unstack(axis=0) - self.attn_matmul_v = ops.BatchMatMul() - self.q_matmul_k = ops.BatchMatMul(transpose_b=True) - self.softmax = nn.Softmax(axis=-1) - - def construct(self, x): - """Attention construct.""" - b, n, c = x.shape - qkv = self.qkv(x) - qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) - qkv = self.transpose(qkv, (2, 0, 3, 1, 4)) - q, k, v = self.unstack(qkv) - - attn = self.q_matmul_k(q, k) - attn = self.mul(attn, self.scale) - attn = self.softmax(attn) - attn = self.attn_drop(attn) - - out = self.attn_matmul_v(attn, v) - out = self.transpose(out, (0, 2, 1, 3)) - out = self.reshape(out, (b, n, c)) - out = self.out(out) - out = self.out_drop(out) - - return out - - -class FeedForward(nn.Cell): - """ - Feed Forward layer implementation. + super(TransformerBlock, self).__init__() + self.norm1 = norm_layer((dim,)) + self.attn = Attention( + dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, + ) + self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - Args: - in_features (int): The dimension of input features. - hidden_features (int): The dimension of hidden features. Default: None. - out_features (int): The dimension of output features. Default: None - activation (nn.Cell): Activation function which will be stacked on top of the - normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. + self.norm2 = norm_layer((dim,)) + self.mlp = Mlp( + in_features=dim, hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, drop=proj_drop + ) + self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - Returns: - Tensor, output tensor. + def construct(self, x, rel_pos_bias=None): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), rel_pos_bias))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x - Examples: - >>> ops = FeedForward(768, 3072) - """ +class VisionTransformerEncoder(nn.Cell): + ''' + ViT encoder, which returns the feature encoded by transformer encoder. 
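+
+    A minimal, illustrative example (the argument values and shapes below are assumptions for a
+    224x224 input with 16x16 patches, not values taken from elsewhere in this patch):
+
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> encoder = VisionTransformerEncoder(image_size=224, patch_size=16, embed_dim=768, depth=12,
+        ...                                    num_heads=12, use_shared_rel_pos_bias=False)
+        >>> feats = encoder(ms.Tensor(np.ones((1, 3, 224, 224)), ms.float32))
+        >>> print(feats.shape)  # one [CLS] token plus 14*14 patch tokens
+        (1, 197, 768)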
+ ''' def __init__( self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - activation: nn.Cell = nn.GELU, - keep_prob: float = 1.0, + image_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + init_values: Optional[float] = 0.1, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + use_rel_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = True, + **kwargs ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.dense1 = nn.Dense(in_features, hidden_features) - self.activation = activation() - self.dense2 = nn.Dense(hidden_features, out_features) - self.dropout = Dropout(p=1.0-keep_prob) - - def construct(self, x): - """Feed Forward construct.""" - x = self.dense1(x) - x = self.activation(x) - x = self.dropout(x) - x = self.dense2(x) - x = self.dropout(x) - - return x - - -class ResidualCell(nn.Cell): - """ - Cell which implements Residual function: - - $$output = x + f(x)$$ - - Args: - cell (Cell): Cell needed to add residual block. - - Returns: - Tensor, output tensor. - - Examples: - >>> ops = ResidualCell(nn.Dense(3,4)) - """ - - def __init__(self, cell): - super().__init__() - self.cell = cell - - def construct(self, x): - """ResidualCell construct.""" - return self.cell(x) + x + super(VisionTransformerEncoder, self).__init__() + self.embed_dim = embed_dim + self.patch_embed = PatchEmbed(image_size=image_size, patch_size=patch_size, + in_chans=in_channels, embed_dim=embed_dim) + self.num_patches = self.patch_embed.num_patches + + self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) + + self.pos_embed = Parameter(initializer(TruncatedNormal(0.02), + (1, self.num_patches + 1, embed_dim))) if not use_rel_pos_emb else None + self.pos_drop = Dropout(pos_drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBiasWithCLS( + window_size=self.patch_embed.patches_resolution, + num_heads=num_heads, + ) + elif use_rel_pos_bias: + self.rel_pos_bias = nn.CellList([ + RelativePositionBiasWithCLS(window_size=self.patch_embed.patches_resolution, + num_heads=num_heads) for _ in range(depth) + ]) + else: + self.rel_pos_bias = None + + dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] + self.blocks = nn.CellList([ + TransformerBlock( + dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, attn_head_dim=attn_head_dim, + mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, + act_layer=act_layer, norm_layer=norm_layer + ) for i in range(depth) + ]) + + self._init_weights() + self._fix_init_weights() + + def get_num_layers(self): + return len(self.blocks) + + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def _init_weights(self): + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.weight.set_data( + initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + elif isinstance(cell, 
nn.LayerNorm): + cell.gamma.set_data( + initializer('ones', cell.gamma.shape, cell.gamma.dtype) + ) + cell.beta.set_data( + initializer('zeros', cell.beta.shape, cell.beta.dtype) + ) + elif isinstance(cell, nn.Conv2d): + cell.weight.set_data( + initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + + def _fix_init_weights(self): + for i, block in enumerate(self.blocks): + block.attn.proj.weight.set_data( + ops.div(block.attn.proj.weight, math.sqrt(2.0 * (i + 1))) + ) + block.mlp.fc2.weight.set_data( + ops.div(block.mlp.fc2.weight, math.sqrt(2.0 * (i + 1))) + ) + def forward_features(self, x): + x = self.patch_embed(x) + bsz = x.shape[0] -class DropPath(nn.Cell): - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - """ + cls_tokens = ops.broadcast_to(self.cls_token, (bsz, -1, -1)) + cls_tokens = cls_tokens.astype(x.dtype) + x = ops.concat((cls_tokens, x), axis=1) - def __init__(self, keep_prob=None, seed=0): - super().__init__() - self.keep_prob = 1 - keep_prob - seed = min(seed, 0) - self.rand = P.UniformReal(seed=seed) - self.shape = P.Shape() - self.floor = P.Floor() + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) - def construct(self, x): - if self.training: - x_shape = self.shape(x) - random_tensor = self.rand((x_shape[0], 1, 1)) - random_tensor = random_tensor + self.keep_prob - random_tensor = self.floor(random_tensor) - x = x / self.keep_prob - x = x * random_tensor + if isinstance(self.rel_pos_bias, nn.CellList): + for i, blk in enumerate(self.blocks): + rel_pos_bias = self.rel_pos_bias[i]() + x = blk(x, rel_pos_bias) + else: + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias) return x - -class TransformerEncoder(nn.Cell): - """ - TransformerEncoder implementation. - - Args: - dim (int): The dimension of embedding. - num_layers (int): The depth of transformer. - num_heads (int): The number of attention heads. - mlp_dim (int): The dimension of MLP hidden layer. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. - attention_keep_prob (float): The keep rate for attention. Default: 1.0. - drop_path_keep_prob (float): The keep rate for drop path. Default: 1.0. - activation (nn.Cell): Activation function which will be stacked on top of the - normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. - norm (nn.Cell, optional): Norm layer that will be stacked on top of the convolution - layer. Default: nn.LayerNorm. - - Returns: - Tensor, output tensor. 
- - Examples: - >>> ops = TransformerEncoder(768, 12, 12, 3072) - """ - - def __init__( - self, - dim: int, - num_layers: int, - num_heads: int, - mlp_dim: int, - keep_prob: float = 1.0, - attention_keep_prob: float = 1.0, - drop_path_keep_prob: float = 1.0, - activation: nn.Cell = nn.GELU, - norm: nn.Cell = nn.LayerNorm, - ): - super().__init__() - drop_path_rate = 1 - drop_path_keep_prob - dpr = [i.item() for i in np.linspace(0, drop_path_rate, num_layers)] - attn_seeds = [np.random.randint(1024) for _ in range(num_layers)] - mlp_seeds = [np.random.randint(1024) for _ in range(num_layers)] - - layers = [] - for i in range(num_layers): - normalization1 = norm((dim,)) - normalization2 = norm((dim,)) - attention = Attention(dim=dim, - num_heads=num_heads, - keep_prob=keep_prob, - attention_keep_prob=attention_keep_prob) - - feedforward = FeedForward(in_features=dim, - hidden_features=mlp_dim, - activation=activation, - keep_prob=keep_prob) - - if drop_path_rate > 0: - layers.append( - nn.SequentialCell([ - ResidualCell(nn.SequentialCell([normalization1, - attention, - DropPath(dpr[i], attn_seeds[i])])), - ResidualCell(nn.SequentialCell([normalization2, - feedforward, - DropPath(dpr[i], mlp_seeds[i])]))])) - else: - layers.append( - nn.SequentialCell([ - ResidualCell(nn.SequentialCell([normalization1, - attention])), - ResidualCell(nn.SequentialCell([normalization2, - feedforward])) - ]) - ) - self.layers = nn.SequentialCell(layers) - def construct(self, x): - """Transformer construct.""" - return self.layers(x) + x = self.forward_features(x) + return x class DenseHead(nn.Cell): @@ -424,238 +367,143 @@ def construct(self, x): return x -class BaseClassifier(nn.Cell): - """ - generate classifier to combine the backbone and head - """ - - def __init__(self, backbone, neck=None, head=None): - super().__init__() - self.backbone = backbone - if neck: - self.neck = neck - self.with_neck = True - else: - self.with_neck = False - if head: - self.head = head - self.with_head = True - else: - self.with_head = False - - def forward_features(self, x: Tensor) -> Tensor: - x = self.backbone(x) - return x - - def forward_head(self, x: Tensor) -> Tensor: - x = self.head(x) - return x - - def construct(self, x): - x = self.forward_features(x) - if self.with_neck: - x = self.neck(x) - if self.with_head: - x = self.forward_head(x) - return x - - -def init(init_type, shape, dtype, name, requires_grad): - initial = initializer(init_type, shape, dtype).init_data() - return Parameter(initial, name=name, requires_grad=requires_grad) - - -class ViT(nn.Cell): - """ - Vision Transformer architecture implementation. - - Args: - image_size (int): Input image size. Default: 224. - input_channels (int): The number of input channel. Default: 3. - patch_size (int): Patch size of image. Default: 16. - embed_dim (int): The dimension of embedding. Default: 768. - num_layers (int): The depth of transformer. Default: 12. - num_heads (int): The number of attention heads. Default: 12. - mlp_dim (int): The dimension of MLP hidden layer. Default: 3072. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. - attention_keep_prob (float): The keep rate for attention layer. Default: 1.0. - drop_path_keep_prob (float): The keep rate for drop path. Default: 1.0. - activation (nn.Cell): Activation function which will be stacked on top of the - normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. 
- norm (nn.Cell, optional): Norm layer that will be stacked on top of the convolution - layer. Default: nn.LayerNorm. - pool (str): The method of pooling. Default: 'cls'. - - Inputs: - - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - - Outputs: - Tensor of shape :math:`(N, 768)` - - Raises: - ValueError: If `split` is not 'train', 'test' or 'infer'. - - Supported Platforms: - ``GPU`` - - Examples: - >>> net = ViT() - >>> x = ms.Tensor(np.ones([1, 3, 224, 224]), ms.float32) - >>> output = net(x) - >>> print(output.shape) - (1, 768) - - About ViT: - - Vision Transformer (ViT) shows that a pure transformer applied directly to sequences of image - patches can perform very well on image classification tasks. When pre-trained on large amounts - of data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, - CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent results compared to state-of-the-art - convolutional networks while requiring substantially fewer computational resources to train. - - Citation: - - .. code-block:: - - @article{2020An, - title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale}, - author={Dosovitskiy, A. and Beyer, L. and Kolesnikov, A. and Weissenborn, D. and Houlsby, N.}, - year={2020}, - } - """ - +class ViT(VisionTransformerEncoder): def __init__( self, image_size: int = 224, - input_channels: int = 3, patch_size: int = 16, + in_channels: int = 3, embed_dim: int = 768, - num_layers: int = 12, + depth: int = 12, num_heads: int = 12, - mlp_dim: int = 3072, - keep_prob: float = 1.0, - attention_keep_prob: float = 1.0, - drop_path_keep_prob: float = 1.0, - activation: nn.Cell = nn.GELU, - norm: Optional[nn.Cell] = nn.LayerNorm, - pool: str = "cls", - ) -> None: - super().__init__() - - self.patch_embedding = PatchEmbedding(image_size=image_size, - patch_size=patch_size, - embed_dim=embed_dim, - input_channels=input_channels) - num_patches = self.patch_embedding.num_patches - - if pool == "cls": - self.cls_token = init(init_type=Normal(sigma=1.0), - shape=(1, 1, embed_dim), - dtype=ms.float32, - name="cls", - requires_grad=True) - self.pos_embedding = init(init_type=Normal(sigma=1.0), - shape=(1, num_patches + 1, embed_dim), - dtype=ms.float32, - name="pos_embedding", - requires_grad=True) - self.concat = ops.Concat(axis=1) - else: - self.pos_embedding = init(init_type=Normal(sigma=1.0), - shape=(1, num_patches, embed_dim), - dtype=ms.float32, - name="pos_embedding", - requires_grad=True) - self.mean = ops.ReduceMean(keep_dims=False) - - self.pool = pool - self.pos_dropout = Dropout(p=1.0-keep_prob) - self.norm = norm((embed_dim,)) - self.tile = ops.Tile() - self.transformer = TransformerEncoder( - dim=embed_dim, - num_layers=num_layers, + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + init_values: Optional[float] = 0.1, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + use_rel_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = True, + use_cls: bool = True, + representation_size: Optional[int] = None, + num_classes: int = 1000, + **kwargs + ): + super(ViT, self).__init__( + image_size=image_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + depth=depth, num_heads=num_heads, - mlp_dim=mlp_dim, 
- keep_prob=keep_prob, - attention_keep_prob=attention_keep_prob, - drop_path_keep_prob=drop_path_keep_prob, - activation=activation, - norm=norm, + attn_head_dim=attn_head_dim, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + pos_drop_rate=pos_drop_rate, + proj_drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + init_values=init_values, + act_layer=act_layer, + norm_layer=norm_layer, + use_rel_pos_emb=use_rel_pos_emb, + use_rel_pos_bias=use_rel_pos_bias, + use_shared_rel_pos_bias=use_shared_rel_pos_bias, + **kwargs ) + self.use_cls = use_cls + self.norm = norm_layer((embed_dim,)) + + if representation_size: + self.head = MultilayerDenseHead( + input_channel=embed_dim, + num_classes=num_classes, + mid_channel=[representation_size], + activation=["tanh", None], + keep_prob=[1.0, 1.0], + ) + else: + self.head = DenseHead(input_channel=embed_dim, num_classes=num_classes) def construct(self, x): - """ViT construct.""" - x = self.patch_embedding(x) - - if self.pool == "cls": - cls_tokens = self.tile(self.cls_token, (x.shape[0], 1, 1)) - x = self.concat((cls_tokens, x)) - x += self.pos_embedding - else: - x += self.pos_embedding - x = self.pos_dropout(x) - x = self.transformer(x) + x = self.forward_features(x) x = self.norm(x) - if self.pool == "cls": + if self.use_cls: x = x[:, 0] else: - x = self.mean(x, (1, )) # (1,) or (1,2) + x = x[:, 1:].mean(axis=1) + + x = self.head(x) return x def vit( - image_size: int, - input_channels: int, - patch_size: int, - embed_dim: int, - num_layers: int, - num_heads: int, - num_classes: int, - mlp_dim: int, - dropout: float = 0.0, - attention_dropout: float = 0.0, - drop_path_rate: float = 0.0, - activation: nn.Cell = nn.GELU, - norm: nn.Cell = nn.LayerNorm, - pool: str = "cls", + image_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + init_values: Optional[float] = None, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + use_rel_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = False, + use_cls: bool = True, representation_size: Optional[int] = None, + num_classes: int = 1000, pretrained: bool = False, url_cfg: dict = None, ) -> ViT: + """Vision Transformer architecture.""" - backbone = ViT( + + model = ViT( image_size=image_size, - input_channels=input_channels, patch_size=patch_size, + in_channels=in_channels, embed_dim=embed_dim, - num_layers=num_layers, + depth=depth, num_heads=num_heads, - mlp_dim=mlp_dim, - keep_prob=1.0 - dropout, - attention_keep_prob=1.0 - attention_dropout, - drop_path_keep_prob=1.0 - drop_path_rate, - activation=activation, - norm=norm, - pool=pool, + attn_head_dim=attn_head_dim, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + pos_drop_rate=pos_drop_rate, + proj_drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + init_values=init_values, + act_layer=act_layer, + norm_layer=norm_layer, + use_rel_pos_emb=use_rel_pos_emb, + use_rel_pos_bias=use_rel_pos_bias, + use_shared_rel_pos_bias=use_shared_rel_pos_bias, + use_cls=use_cls, + representation_size=representation_size, + num_classes=num_classes ) - if representation_size: - head = 
MultilayerDenseHead( - input_channel=embed_dim, - num_classes=num_classes, - mid_channel=[representation_size], - activation=["tanh", None], - keep_prob=[1.0, 1.0], - ) - else: - head = DenseHead(input_channel=embed_dim, num_classes=num_classes) - - model = BaseClassifier(backbone=backbone, head=head) if pretrained: # Download the pre-trained checkpoint file from url, and load ckpt file. - load_pretrained(model, url_cfg, num_classes=num_classes, in_channels=input_channels) + load_pretrained(model, url_cfg, num_classes=num_classes, in_channels=in_channels) return model @@ -665,60 +513,25 @@ def vit_b_16_224( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 224, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """ - Constructs a vit_b_16 architecture from - `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. - - Args: - pretrained (bool): Whether to download and load the pre-trained model. Default: False. - num_classes (int): The number of classification. Default: 1000. - in_channels (int): The number of input channels. Default: 3. - image_size (int): The input image size. Default: 224 for ImageNet. - has_logits (bool): Whether has logits or not. Default: False. - drop_rate (float): The drop out rate. Default: 0.0.s - drop_path_rate (float): The stochastic depth rate. Default: 0.0. - - Returns: - ViT network, MindSpore.nn.Cell - - Inputs: - - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - - Examples: - >>> net = vit_b_16_224() - >>> x = ms.Tensor(np.ones([1, 3, 224, 224]), ms.float32) - >>> output = net(x) - >>> print(output.shape) - (1, 1000) - - Outputs: - Tensor of shape :math:`(N, CLASSES_{out})` - - Supported Platforms: - ``GPU`` - """ +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 224 config.patch_size = 16 + config.in_channels = in_channels config.embed_dim = 768 - config.mlp_dim = 3072 + config.depth = 12 config.num_heads = 12 - config.num_layers = 12 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.pretrained = pretrained - config.input_channels = in_channels - config.pool = "cls" config.representation_size = 768 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_b_16_224"] return vit(**config) @@ -729,29 +542,25 @@ def vit_b_16_384( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 384, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 384 config.patch_size = 16 + config.in_channels = in_channels config.embed_dim = 768 - config.mlp_dim = 3072 + config.depth = 12 config.num_heads = 12 - config.num_layers = 12 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.pretrained = pretrained - config.input_channels = 
in_channels - config.pool = "cls" config.representation_size = 768 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_b_16_384"] return vit(**config) @@ -762,30 +571,25 @@ def vit_l_16_224( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 224, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" - +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 224 config.patch_size = 16 + config.in_channels = in_channels config.embed_dim = 1024 - config.mlp_dim = 4096 + config.depth = 24 config.num_heads = 16 - config.num_layers = 24 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.input_channels = in_channels - config.pool = "cls" - config.pretrained = pretrained config.representation_size = 1024 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_l_16_224"] return vit(**config) @@ -796,30 +600,25 @@ def vit_l_16_384( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 384, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" - +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 384 config.patch_size = 16 + config.in_channels = in_channels config.embed_dim = 1024 - config.mlp_dim = 4096 + config.depth = 24 config.num_heads = 16 - config.num_layers = 24 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.input_channels = in_channels - config.pool = "cls" - config.pretrained = pretrained config.representation_size = 1024 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_l_16_384"] return vit(**config) @@ -830,29 +629,25 @@ def vit_b_32_224( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 224, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 224 config.patch_size = 32 + config.in_channels = in_channels config.embed_dim = 768 - config.mlp_dim = 3072 + config.depth = 12 config.num_heads = 12 - config.num_layers = 12 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.pretrained = pretrained - config.input_channels = in_channels - config.pool = "cls" config.representation_size = 768 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = 
default_cfgs["vit_b_32_224"] return vit(**config) @@ -863,29 +658,25 @@ def vit_b_32_384( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 384, has_logits: bool = False, drop_rate: float = 0.0, - # attention_dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 384 config.patch_size = 32 + config.in_channels = in_channels config.embed_dim = 768 - config.mlp_dim = 3072 + config.depth = 12 config.num_heads = 12 - config.num_layers = 12 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention_dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.pretrained = pretrained - config.input_channels = in_channels - config.pool = "cls" config.representation_size = 768 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_b_32_384"] return vit(**config) @@ -896,29 +687,25 @@ def vit_l_32_224( pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, - image_size: int = 224, has_logits: bool = False, drop_rate: float = 0.0, - # attention-dropout: float = 0.0, drop_path_rate: float = 0.0, -) -> ViT: - """construct and return a ViT network""" +): config = ConfigDict() - config.image_size = image_size - config.num_classes = num_classes + config.image_size = 224 config.patch_size = 32 + config.in_channels = in_channels config.embed_dim = 1024 - config.mlp_dim = 4096 + config.depth = 24 config.num_heads = 16 - config.num_layers = 24 - config.dropout = drop_rate - config.attention_dropout = drop_rate # attention-dropout + config.pos_drop_rate = drop_rate + config.proj_drop_rate = drop_rate + config.attn_drop_rate = drop_rate config.drop_path_rate = drop_path_rate - config.pretrained = pretrained - config.input_channels = in_channels - config.pool = "cls" config.representation_size = 1024 if has_logits else None + config.num_classes = num_classes + config.pretrained = pretrained config.url_cfg = default_cfgs["vit_l_32_224"] return vit(**config) From e3170e2c88235953cf5629591a4c0857ee538e63 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Thu, 31 Aug 2023 23:55:06 +0800 Subject: [PATCH 2/5] fix format --- mindcv/models/layers/attention.py | 16 +++++++++------- mindcv/models/layers/pos_embed.py | 20 +++++++++++--------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/mindcv/models/layers/attention.py b/mindcv/models/layers/attention.py index 5ac033e6b..e62f363bc 100644 --- a/mindcv/models/layers/attention.py +++ b/mindcv/models/layers/attention.py @@ -1,10 +1,14 @@ -"""attention layers""" -#TODO: add Flash Attention +"""attention layers +TODO: add Flash Attention +""" + +from typing import Optional + +from mindspore import Tensor, nn, ops -from typing import Optional, Tuple, Union, List -from mindspore import nn, ops, Tensor, Parameter from .compatibility import Dropout + class Attention(nn.Cell): """ Attention layer implementation, Rearrange Input -> B x N x hidden size. 
@@ -71,7 +75,7 @@ def construct(self, x, rel_pos_bias=None): attn = self.mul(attn, self.scale) if rel_pos_bias is not None: - attn = attn + rel_pos_bias + attn = attn + rel_pos_bias attn = self.softmax(attn) attn = self.attn_drop(attn) @@ -83,5 +87,3 @@ def construct(self, x, rel_pos_bias=None): out = self.proj_drop(out) return out - - diff --git a/mindcv/models/layers/pos_embed.py b/mindcv/models/layers/pos_embed.py index 66052947f..c570c5c1b 100644 --- a/mindcv/models/layers/pos_embed.py +++ b/mindcv/models/layers/pos_embed.py @@ -1,6 +1,10 @@ """positional embedding""" -from mindspore import nn, ops, Tensor, Parameter -from typing import Optional, Tuple, Union, List +from typing import Tuple + +import numpy as np + +import mindspore as ms +from mindspore import Parameter, Tensor, nn, ops class RelativePositionBiasWithCLS(nn.Cell): @@ -20,21 +24,21 @@ def __init__( ) coords_h = np.arange(window_size[0]).reshape(window_size[0], 1).repeat(window_size[1], 1).reshape(1, -1) coords_w = np.arange(window_size[1]).reshape(1, window_size[1]).repeat(window_size[0], 0).reshape(1, -1) - coords_flatten = np.concatenate([coords_h, coords_w], axis=0) # [2, Wh * Ww] + coords_flatten = np.concatenate([coords_h, coords_w], axis=0) # [2, Wh * Ww] - relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :] # [2, Wh * Ww, Wh * Ww] - relative_coords = relative_coords.transpose(1, 2, 0) # [Wh * Ww, Wh * Ww, 2] + relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :] # [2, Wh * Ww, Wh * Ww] + relative_coords = relative_coords.transpose(1, 2, 0) # [Wh * Ww, Wh * Ww, 2] relative_coords[:, :, 0] += window_size[0] - 1 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[0] - 1 relative_position_index = np.zeros((self.num_tokens + 1, self.num_tokens + 1), - dtype=relative_coords.dtype) # [Wh * Ww + 1, Wh * Ww + 1] + dtype=relative_coords.dtype) # [Wh * Ww + 1, Wh * Ww + 1] relative_position_index[1:, 1:] = relative_coords.sum(-1) relative_position_index[0, 0:] = num_relative_distance - 3 relative_position_index[0:, 0] = num_relative_distance - 2 relative_position_index[0, 0] = num_relative_distance - 1 - relative_position_index = Tensor(relative_position_index.reshape(-1)) + relative_position_index = Tensor(relative_position_index.reshape(-1)) self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16) self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False) @@ -45,5 +49,3 @@ def construct(self): out = ops.transpose(out, (2, 0, 1)) out = ops.expand_dims(out, 0) return out - - From 696f09b4fda31518c36e75c2652f994f41433511 Mon Sep 17 00:00:00 2001 From: samithuang <285365963@qq.com> Date: Fri, 1 Sep 2023 17:38:29 +0800 Subject: [PATCH 3/5] undo attention --- mindcv/models/layers/attention.py | 89 ------------------------------- mindcv/models/vit.py | 84 ++++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 91 deletions(-) delete mode 100644 mindcv/models/layers/attention.py diff --git a/mindcv/models/layers/attention.py b/mindcv/models/layers/attention.py deleted file mode 100644 index e62f363bc..000000000 --- a/mindcv/models/layers/attention.py +++ /dev/null @@ -1,89 +0,0 @@ -"""attention layers -TODO: add Flash Attention -""" - -from typing import Optional - -from mindspore import Tensor, nn, ops - -from .compatibility import Dropout - - -class Attention(nn.Cell): - """ - Attention layer implementation, Rearrange Input -> B x N x hidden size. 
- - Args: - dim (int): The dimension of input features. - num_heads (int): The number of attention heads. Default: 8. - qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. - qk_scale: (float): The user-defined factor to scale the product of q and k. Default: None. - attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. - proj_drop (float): The drop rate of output, greater than 0 and less equal than 1. Default: 0.0. - attn_head_dim (int): The user-defined dimension of attention head features. Default: None. - - Returns: - Tensor, output tensor. - - Examples: - >>> ops = Attention(768, 12) - """ - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - attn_head_dim: Optional[int] = None, - ): - super(Attention, self).__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - if attn_head_dim is not None: - head_dim = attn_head_dim - all_head_dim = head_dim * num_heads - - if qk_scale: - self.scale = Tensor(qk_scale) - else: - self.scale = Tensor(head_dim ** -0.5) - - self.qkv = nn.Dense(dim, all_head_dim * 3, has_bias=qkv_bias) - - self.attn_drop = Dropout(attn_drop) - self.proj = nn.Dense(all_head_dim, dim) - self.proj_drop = Dropout(proj_drop) - - self.mul = ops.Mul() - self.reshape = ops.Reshape() - self.transpose = ops.Transpose() - self.unstack = ops.Unstack(axis=0) - self.attn_matmul_v = ops.BatchMatMul() - self.q_matmul_k = ops.BatchMatMul(transpose_b=True) - self.softmax = nn.Softmax(axis=-1) - - def construct(self, x, rel_pos_bias=None): - b, n, c = x.shape - qkv = self.qkv(x) - qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) - qkv = self.transpose(qkv, (2, 0, 3, 1, 4)) - q, k, v = self.unstack(qkv) - - attn = self.q_matmul_k(q, k) - attn = self.mul(attn, self.scale) - - if rel_pos_bias is not None: - attn = attn + rel_pos_bias - - attn = self.softmax(attn) - attn = self.attn_drop(attn) - - out = self.attn_matmul_v(attn, v) - out = self.transpose(out, (0, 2, 1, 3)) - out = self.reshape(out, (b, n, c)) - out = self.proj(out) - out = self.proj_drop(out) - - return out diff --git a/mindcv/models/vit.py b/mindcv/models/vit.py index 7fac9e2c7..b26712f91 100644 --- a/mindcv/models/vit.py +++ b/mindcv/models/vit.py @@ -4,11 +4,10 @@ import numpy as np -from mindspore import Parameter, nn, ops +from mindspore import Parameter, Tensor, nn, ops from mindspore.common.initializer import TruncatedNormal, initializer from .helpers import ConfigDict, load_pretrained -from .layers.attention import Attention from .layers.compatibility import Dropout from .layers.drop_path import DropPath from .layers.mlp import Mlp @@ -57,6 +56,87 @@ def _cfg(url="", **kwargs): } +# TODO: Flash Attention +class Attention(nn.Cell): + """ + Attention layer implementation, Rearrange Input -> B x N x hidden size. + + Args: + dim (int): The dimension of input features. + num_heads (int): The number of attention heads. Default: 8. + qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. + qk_scale: (float): The user-defined factor to scale the product of q and k. Default: None. + attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. + proj_drop (float): The drop rate of output, greater than 0 and less equal than 1. Default: 0.0. + attn_head_dim (int): The user-defined dimension of attention head features. 
Default: None. + + Returns: + Tensor, output tensor. + + Examples: + >>> ops = Attention(768, 12) + """ + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + attn_head_dim: Optional[int] = None, + ): + super(Attention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * num_heads + + if qk_scale: + self.scale = Tensor(qk_scale) + else: + self.scale = Tensor(head_dim ** -0.5) + + self.qkv = nn.Dense(dim, all_head_dim * 3, has_bias=qkv_bias) + + self.attn_drop = Dropout(attn_drop) + self.proj = nn.Dense(all_head_dim, dim) + self.proj_drop = Dropout(proj_drop) + + self.mul = ops.Mul() + self.reshape = ops.Reshape() + self.transpose = ops.Transpose() + self.unstack = ops.Unstack(axis=0) + self.attn_matmul_v = ops.BatchMatMul() + self.q_matmul_k = ops.BatchMatMul(transpose_b=True) + self.softmax = nn.Softmax(axis=-1) + + def construct(self, x, rel_pos_bias=None): + b, n, c = x.shape + qkv = self.qkv(x) + qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) + qkv = self.transpose(qkv, (2, 0, 3, 1, 4)) + q, k, v = self.unstack(qkv) + + attn = self.q_matmul_k(q, k) + attn = self.mul(attn, self.scale) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + out = self.attn_matmul_v(attn, v) + out = self.transpose(out, (0, 2, 1, 3)) + out = self.reshape(out, (b, n, c)) + out = self.proj(out) + out = self.proj_drop(out) + + return out + + class LayerScale(nn.Cell): """ Layer scale, help ViT improve the training dynamic, allowing for the training From e006a249ec0a4c465ac596a2dddbb1c62fa5b132 Mon Sep 17 00:00:00 2001 From: hanhuiyu1996 Date: Mon, 7 Aug 2023 14:15:18 +0800 Subject: [PATCH 4/5] add model mae and fintune checkpoint file --- mindcv/models/__init__.py | 3 + mindcv/models/mae.py | 432 +++++++++++++++++++++++++++++++++++ mindcv/models/vit_encoder.py | 296 ++++++++++++++++++++++++ 3 files changed, 731 insertions(+) create mode 100644 mindcv/models/mae.py create mode 100644 mindcv/models/vit_encoder.py diff --git a/mindcv/models/__init__.py b/mindcv/models/__init__.py index 6f28ba6f1..f3395796c 100644 --- a/mindcv/models/__init__.py +++ b/mindcv/models/__init__.py @@ -18,6 +18,7 @@ inceptionv3, inceptionv4, layers, + mae, mixnet, mlpmixer, mnasnet, @@ -74,6 +75,7 @@ from .inceptionv3 import * from .inceptionv4 import * from .layers import * +from .mae import * from .mixnet import * from .mlpmixer import * from .mnasnet import * @@ -132,6 +134,7 @@ __all__.extend(["InceptionV3", "inception_v3"]) __all__.extend(["InceptionV4", "inception_v4"]) __all__.extend(layers.__all__) +__all__.extend(mae.__all__) __all__.extend(mixnet.__all__) __all__.extend(mlpmixer.__all__) __all__.extend(mnasnet.__all__) diff --git a/mindcv/models/mae.py b/mindcv/models/mae.py new file mode 100644 index 000000000..f50346958 --- /dev/null +++ b/mindcv/models/mae.py @@ -0,0 +1,432 @@ +from functools import partial +from typing import Optional + +import numpy as np + +import mindspore as ms +from mindspore import Parameter, Tensor, nn, ops +from mindspore.common.initializer import Normal, initializer + +from .helpers import load_pretrained +from .registry import register_model +from .vit_encoder import Block, VisionTransformerEncoder + +__all__ = [ + "mae_b_16_224_pretrain", + "mae_l_16_224_pretrain", + 
"mae_h_16_224_pretrain", + "mae_b_16_224_finetune", + "mae_l_16_224_finetune", + "mae_h_14_224_finetune" +] + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 224, 224), + "first_conv": "patch_embed.proj", + "classifier": "head", + **kwargs, + } + + +default_cfgs = { + "mae_b_16_224_finetune": _cfg( + url="https://download.mindspore.cn/toolkits/mindcv/mae/mae_b_16_224_finetune-cc05b899.ckpt" + ), + "mae_l_16_224_finetune": _cfg(url=""), + "mae_h_14_224_finetune": _cfg(url=""), +} + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. 
/ 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class MAEForPretrain(VisionTransformerEncoder): + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 1024, + depth: int = 24, + num_heads: int = 16, + mlp_ratio: float = 4., + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + norm_pix_loss: bool = True, + mask_ratio: float = 0.75, + **kwargs + ): + super(MAEForPretrain, self).__init__( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + init_values=None, + act_layer=act_layer, + norm_layer=norm_layer, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + **kwargs + ) + self.cls_token = Parameter(initializer(Normal(sigma=0.02), (1, 1, embed_dim))) + + self.unmask_len = int(np.floor(self.num_patches * (1 - mask_ratio))) + + encoder_pos_emb = Tensor(get_2d_sincos_pos_embed( + embed_dim, int(self.num_patches ** 0.5), cls_token=True), ms.float32 + ) + encoder_pos_emb = ops.expand_dims(encoder_pos_emb, axis=0) + self.pos_embed = Parameter(encoder_pos_emb, requires_grad=False) + self.norm = norm_layer((embed_dim,)) + + self.decoder_embed = nn.Dense(embed_dim, decoder_embed_dim) + self.mask_token = Parameter(initializer(Normal(sigma=0.02), (1, 1, decoder_embed_dim))) + + decoder_pos_emb = Tensor(get_2d_sincos_pos_embed( + decoder_embed_dim, int(self.num_patches ** 0.5), cls_token=True), ms.float32 + ) + decoder_pos_emb = ops.expand_dims(decoder_pos_emb, axis=0) + self.decoder_pos_embed = Parameter(decoder_pos_emb, requires_grad=False) + + self.decoder_blocks = nn.CellList([ + Block( + dim=decoder_embed_dim, num_heads=decoder_num_heads, qkv_bias=True, + mlp_ratio=mlp_ratio, init_values=None, act_layer=act_layer, norm_layer=norm_layer, + ) for _ in range(decoder_depth) + ]) + self.decoder_norm = norm_layer((decoder_embed_dim,)) + self.decoder_pred = nn.Dense(decoder_embed_dim, patch_size ** 2 * in_chans) + + self.sort = ops.Sort() + + self.norm_pix_loss = norm_pix_loss + self._init_weights() + + def _init_weights(self): + for name, cell in self.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.weight.set_data( + initializer("xavier_uniform", cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + + elif isinstance(cell, nn.LayerNorm): + cell.gamma.set_data( + initializer('ones', cell.gamma.shape, cell.gamma.dtype) + ) + cell.beta.set_data( + initializer('zeros', cell.beta.shape, cell.beta.dtype) + ) + + if name == "patch_embed.proj": + cell.weight.set_data( + initializer("xavier_uniform", cell.weight.shape, cell.weight.dtype) + ) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size ** 2 * 3) + """ + N, _, H, W = imgs.shape + p = self.patch_embed.patch_size[0] + assert H == W and H % p == 0 + h = w = H // p + + x = ops.reshape(imgs, (N, 3, h, p, w, p)) + x = ops.transpose(x, (0, 2, 4, 3, 5, 1)) + x = ops.reshape(x, (N, h * w, p ** 2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size ** 2 * 3) + imgs: (N, 3, H, W) + """ + N, L, _ 
= x.shape + p = self.patch_embed.patch_size[0] + h = w = int(L ** 0.5) + assert h * w == L + + imgs = ops.reshape(x, (N, h, w, p, p, 3)) + imgs = ops.transpose(imgs, (0, 5, 1, 3, 2, 4)) + imgs = ops.reshape(imgs, (N, 3, h * p, w * p)) + return imgs + + def apply_masking(self, x, mask): + D = x.shape[2] + _, ids_shuffle = self.sort(mask.astype(ms.float32)) + _, ids_restore = self.sort(ids_shuffle.astype(ms.float32)) + + ids_keep = ids_shuffle[:, :self.unmask_len] + ids_keep = ops.broadcast_to(ops.expand_dims(ids_keep, axis=-1), (-1, -1, D)) + x_unmasked = ops.gather_elements(x, dim=1, index=ids_keep) + + return x_unmasked, ids_restore + + def forward_features(self, x, mask): + x = self.patch_embed(x) + bsz = x.shape[0] + + x = x + self.pos_embed[:, 1:, :] + x, ids_restore = self.apply_masking(x, mask) + + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_token = ops.broadcast_to(cls_token, (bsz, -1, -1)) + cls_token = cls_token.astype(x.dtype) + x = ops.concat((cls_token, x), axis=1) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x, ids_restore + + def forward_decoder(self, x, ids_restore): + x = self.decoder_embed(x) + bsz, L, D = x.shape + + mask_len = self.num_patches + 1 - L + mask_tokens = ops.broadcast_to(self.mask_token, (bsz, mask_len, -1)) + mask_tokens = mask_tokens.astype(x.dtype) + + x_ = ops.concat((x[:, 1:, :], mask_tokens), axis=1) + ids_restore = ops.broadcast_to(ops.expand_dims(ids_restore, axis=-1), (-1, -1, D)) + x_ = ops.gather_elements(x_, dim=1, index=ids_restore) + x = ops.concat((x[:, :1, :], x_), axis=1) + + x = x + self.decoder_pos_embed + + for blk in self.decoder_blocks: + x = blk(x) + + x = self.decoder_norm(x) + x = self.decoder_pred(x) + + return x[:, 1:, :] + + def forward_loss(self, imgs, pred, mask): + target = self.patchify(imgs) + if self.norm_pix_loss: + mean = target.mean(axis=-1, keep_dims=True) + std = target.std(axis=-1, keepdims=True) + target = (target - mean) / std + + loss = (pred - target) ** 2 + loss = loss.mean(axis=-1) + + mask = mask.astype(loss.dtype) + loss = (loss * mask).sum() / mask.sum() + return loss + + def construct(self, imgs, mask): + bsz = imgs.shape[0] + mask = ops.reshape(mask, (bsz, -1)) + features, ids_restore = self.forward_features(imgs, mask) + pred = self.forward_decoder(features, ids_restore) + loss = self.forward_loss(imgs, pred, mask) + return loss + + +class MAEForFinetune(VisionTransformerEncoder): + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + num_classes: int = 1000, + use_mean_pooling: bool = True, + **kwargs + ): + super(MAEForFinetune, self).__init__( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + attn_head_dim=attn_head_dim, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + pos_drop_rate=pos_drop_rate, + proj_drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + init_values=None, + act_layer=act_layer, + norm_layer=norm_layer, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + **kwargs + 
) + self.use_mean_pooling = use_mean_pooling + if self.use_mean_pooling: + self.fc_norm = norm_layer((embed_dim,)) + else: + self.norm = norm_layer((embed_dim,)) + self.head = nn.Dense(embed_dim, num_classes, weight_init='TruncatedNormal') + + self._init_weights() + self._fix_init_weights() + + def construct(self, x): + x = self.forward_features(x) + if self.use_mean_pooling: + x = x[:, 1:].mean(axis=1) + x = self.fc_norm(x) + else: + x = self.norm(x) + x = x[:, 0] + x = self.head(x) + return x + + +@register_model +def mae_b_16_224_pretrain(pretrained=False, **kwargs): + model = MAEForPretrain( + patch_size=16, embed_dim=768, depth=12, num_heads=12, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + act_layer=partial(nn.GELU, approximate=False), + norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs + ) + if pretrained: + pass + return model + + +@register_model +def mae_l_16_224_pretrain(pretrained=False, **kwargs): + model = MAEForPretrain( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + act_layer=partial(nn.GELU, approximate=False), + norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs + ) + if pretrained: + pass + return model + + +@register_model +def mae_h_16_224_pretrain(pretrained=False, **kwargs): + model = MAEForPretrain( + patch_size=16, embed_dim=1280, depth=32, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + act_layer=partial(nn.GELU, approximate=False), + norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs + ) + if pretrained: + pass + return model + + +@register_model +def mae_b_16_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): + default_cfg = default_cfgs["mae_b_16_224_finetune"] + model = MAEForFinetune( + patch_size=16, in_chans=in_chans, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + return model + + +@register_model +def mae_l_16_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): + default_cfg = default_cfgs["mae_l_16_224_finetune"] + model = MAEForFinetune( + patch_size=16, in_chans=in_chans, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + return model + + +@register_model +def mae_h_14_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): + default_cfg = default_cfgs["mae_h_14_224_finetune"] + model = MAEForFinetune( + patch_size=14, in_chans=in_chans, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + return model diff --git a/mindcv/models/vit_encoder.py b/mindcv/models/vit_encoder.py new file mode 100644 index 000000000..010065273 --- /dev/null +++ b/mindcv/models/vit_encoder.py @@ -0,0 +1,296 @@ +import math +from typing import Optional, Tuple + +import numpy as np + +import mindspore as ms +from mindspore import Parameter, Tensor, nn, ops +from mindspore.common.initializer import TruncatedNormal, initializer + +from .layers.drop_path import 
DropPath +from .layers.mlp import Mlp +from .layers.patch_embed import PatchEmbed + + +class RelativePositionBiasWithCLS(nn.Cell): + def __init__( + self, + window_size: Tuple[int], + num_heads: int + ): + super(RelativePositionBiasWithCLS, self).__init__() + self.window_size = window_size + self.num_tokens = window_size[0] * window_size[1] + + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + # 3: cls to token, token to cls, cls to cls + self.relative_position_bias_table = Parameter( + Tensor(np.zeros((num_relative_distance, num_heads)), dtype=ms.float16) + ) + coords_h = np.arange(window_size[0]).reshape(window_size[0], 1).repeat(window_size[1], 1).reshape(1, -1) + coords_w = np.arange(window_size[1]).reshape(1, window_size[1]).repeat(window_size[0], 0).reshape(1, -1) + coords_flatten = np.concatenate([coords_h, coords_w], axis=0) # [2, Wh * Ww] + + relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :] # [2, Wh * Ww, Wh * Ww] + relative_coords = relative_coords.transpose(1, 2, 0) # [Wh * Ww, Wh * Ww, 2] + relative_coords[:, :, 0] += window_size[0] - 1 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[0] - 1 + + relative_position_index = np.zeros((self.num_tokens + 1, self.num_tokens + 1), + dtype=relative_coords.dtype) # [Wh * Ww + 1, Wh * Ww + 1] + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + relative_position_index = Tensor(relative_position_index.reshape(-1)) + + self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16) + self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False) + + def construct(self): + out = ops.matmul(self.relative_position_index, self.relative_position_bias_table) + out = ops.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1)) + out = ops.transpose(out, (2, 0, 1)) + out = ops.expand_dims(out, 0) + return out + + +class Attention(nn.Cell): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + attn_head_dim: Optional[int] = None, + ): + super(Attention, self).__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * num_heads + + if qk_scale: + self.scale = Tensor(qk_scale) + else: + self.scale = Tensor(head_dim ** -0.5) + + self.qkv = nn.Dense(dim, all_head_dim * 3, has_bias=qkv_bias) + + self.attn_drop = nn.Dropout(1 - attn_drop) + self.proj = nn.Dense(all_head_dim, dim) + self.proj_drop = nn.Dropout(1 - proj_drop) + + self.mul = ops.Mul() + self.reshape = ops.Reshape() + self.transpose = ops.Transpose() + self.unstack = ops.Unstack(axis=0) + self.attn_matmul_v = ops.BatchMatMul() + self.q_matmul_k = ops.BatchMatMul(transpose_b=True) + + def construct(self, x, rel_pos_bias=None): + b, n, c = x.shape + qkv = self.qkv(x) + qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) + qkv = self.transpose(qkv, (2, 0, 3, 1, 4)) + q, k, v = self.unstack(qkv) + + attn = self.q_matmul_k(q, k) + attn = self.mul(attn, self.scale) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.astype(ms.float32) + attn = ops.softmax(attn, axis=-1) + attn = 
self.attn_drop(attn) + + out = self.attn_matmul_v(attn, v) + out = self.transpose(out, (0, 2, 1, 3)) + out = self.reshape(out, (b, n, c)) + out = self.proj(out) + out = self.proj_drop(out) + + return out + + +class LayerScale(nn.Cell): + def __init__( + self, + dim: int, + init_values: float = 1e-5 + ): + super(LayerScale, self).__init__() + self.gamma = Parameter(initializer(init_values, dim)) + + def construct(self, x): + return self.gamma * x + + +class Block(nn.Cell): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_drop: float = 0., + proj_drop: float = 0., + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + drop_path: float = 0., + init_values: Optional[float] = None, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + ): + super(Block, self).__init__() + self.norm1 = norm_layer((dim,)) + self.attn = Attention( + dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, + ) + self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer((dim,)) + self.mlp = Mlp( + in_features=dim, hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, drop=proj_drop + ) + self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def construct(self, x, rel_pos_bias=None): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), rel_pos_bias))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class VisionTransformerEncoder(nn.Cell): + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + attn_head_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + init_values: Optional[float] = 0.1, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + use_abs_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = True, + **kwargs + ): + super(VisionTransformerEncoder, self).__init__() + self.embed_dim = embed_dim + self.patch_embed = PatchEmbed(image_size=img_size, patch_size=patch_size, + in_chans=in_chans, embed_dim=embed_dim) + self.num_patches = self.patch_embed.num_patches + + self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) + + self.pos_embed = Parameter( + initializer(TruncatedNormal(0.02), (1, self.num_patches + 1, embed_dim))) if use_abs_pos_emb else None + self.pos_drop = nn.Dropout(1 - pos_drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBiasWithCLS( + window_size=self.patch_embed.patches_resolution, num_heads=num_heads) + elif use_rel_pos_bias: + self.rel_pos_bias = nn.CellList([ + RelativePositionBiasWithCLS(window_size=self.patch_embed.patches_resolution, + num_heads=num_heads) for _ in range(depth) + ]) + else: + self.rel_pos_bias = None + + dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] + self.blocks = nn.CellList([ + Block( + dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + 
attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, attn_head_dim=attn_head_dim, + mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, + act_layer=act_layer, norm_layer=norm_layer + ) for i in range(depth) + ]) + + def get_num_layers(self): + return len(self.blocks) + + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def _init_weights(self): + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.weight.set_data( + initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + elif isinstance(cell, nn.LayerNorm): + cell.gamma.set_data( + initializer('ones', cell.gamma.shape, cell.gamma.dtype) + ) + cell.beta.set_data( + initializer('zeros', cell.beta.shape, cell.beta.dtype) + ) + elif isinstance(cell, nn.Conv2d): + cell.weight.set_data( + initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + + def _fix_init_weights(self): + for i, block in enumerate(self.blocks): + block.attn.proj.weight.set_data( + ops.div(block.attn.proj.weight, math.sqrt(2.0 * (i + 1))) + ) + block.mlp.fc2.weight.set_data( + ops.div(block.mlp.fc2.weight, math.sqrt(2.0 * (i + 1))) + ) + + def forward_features(self, x): + x = self.patch_embed(x) + bsz = x.shape[0] + + cls_token = ops.broadcast_to(self.cls_token, (bsz, -1, -1)) + cls_token = cls_token.astype(x.dtype) + x = ops.concat((cls_token, x), axis=1) + + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + if isinstance(self.rel_pos_bias, nn.CellList): + for i, blk in enumerate(self.blocks): + rel_pos_bias = self.rel_pos_bias[i]() + x = blk(x, rel_pos_bias) + else: + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias) + + return x + + def construct(self, x): + return self.forward_features(x) From eef1dfbe313e91507ce5c2ae0975b8a8c3b268e6 Mon Sep 17 00:00:00 2001 From: hanhuiyu1996 Date: Fri, 15 Sep 2023 10:40:07 +0800 Subject: [PATCH 5/5] extend vit and mae --- mindcv/models/layers/__init__.py | 16 +- mindcv/models/layers/format.py | 34 ++ mindcv/models/layers/patch_dropout.py | 54 ++ mindcv/models/layers/patch_embed.py | 65 ++- mindcv/models/layers/pos_embed.py | 44 +- mindcv/models/mae.py | 195 +++---- mindcv/models/vit.py | 728 ++++++++------------------ mindcv/models/vit_encoder.py | 296 ----------- 8 files changed, 487 insertions(+), 945 deletions(-) create mode 100644 mindcv/models/layers/format.py create mode 100644 mindcv/models/layers/patch_dropout.py delete mode 100644 mindcv/models/vit_encoder.py diff --git a/mindcv/models/layers/__init__.py b/mindcv/models/layers/__init__.py index 2810dbca1..c3e4de210 100644 --- a/mindcv/models/layers/__init__.py +++ b/mindcv/models/layers/__init__.py @@ -1,10 +1,24 @@ """layers init""" -from . import activation, conv_norm_act, drop_path, identity, pooling, selective_kernel, squeeze_excite +from . 
import (
+    activation,
+    conv_norm_act,
+    drop_path,
+    format,
+    identity,
+    patch_dropout,
+    pooling,
+    pos_embed,
+    selective_kernel,
+    squeeze_excite,
+)
 from .activation import *
 from .conv_norm_act import *
 from .drop_path import *
+from .format import *
 from .identity import *
+from .patch_dropout import *
 from .pooling import *
+from .pos_embed import *
 from .selective_kernel import *
 from .squeeze_excite import *
 
diff --git a/mindcv/models/layers/format.py b/mindcv/models/layers/format.py
new file mode 100644
index 000000000..058a74517
--- /dev/null
+++ b/mindcv/models/layers/format.py
@@ -0,0 +1,34 @@
+from enum import Enum
+from typing import Union
+
+import mindspore
+
+
+class Format(str, Enum):
+    NCHW = 'NCHW'
+    NHWC = 'NHWC'
+    NCL = 'NCL'
+    NLC = 'NLC'
+
+
+FormatT = Union[str, Format]
+
+
+def nchw_to(x: mindspore.Tensor, fmt: Format):
+    if fmt == Format.NHWC:
+        x = x.permute(0, 2, 3, 1)
+    elif fmt == Format.NLC:
+        x = x.flatten(start_dim=2).transpose((0, 2, 1))
+    elif fmt == Format.NCL:
+        x = x.flatten(start_dim=2)
+    return x
+
+
+def nhwc_to(x: mindspore.Tensor, fmt: Format):
+    if fmt == Format.NCHW:
+        x = x.permute(0, 3, 1, 2)
+    elif fmt == Format.NLC:
+        x = x.flatten(start_dim=1, end_dim=2)
+    elif fmt == Format.NCL:
+        x = x.flatten(start_dim=1, end_dim=2).transpose((0, 2, 1))
+    return x
diff --git a/mindcv/models/layers/patch_dropout.py b/mindcv/models/layers/patch_dropout.py
new file mode 100644
index 000000000..ad854dbfc
--- /dev/null
+++ b/mindcv/models/layers/patch_dropout.py
@@ -0,0 +1,54 @@
+import numpy as np
+
+import mindspore as ms
+from mindspore import nn, ops
+
+
+class PatchDropout(nn.Cell):
+    """
+    https://arxiv.org/abs/2212.00794
+    """
+    def __init__(
+        self,
+        prob: float = 0.5,
+        num_prefix_tokens: int = 1,
+        ordered: bool = False,
+        return_indices: bool = False,
+    ):
+        super().__init__()
+        assert 0 <= prob < 1.
+        self.prob = prob
+        self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
+        self.ordered = ordered
+        self.return_indices = return_indices
+        self.sort = ops.Sort()
+
+    def construct(self, x):
+        if not self.training or self.prob == 0.:
+            if self.return_indices:
+                return x, None
+            return x
+
+        if self.num_prefix_tokens:
+            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
+        else:
+            prefix_tokens = None
+
+        B = x.shape[0]
+        L = x.shape[1]
+        num_keep = max(1, int(L * (1. - self.prob)))
+        _, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32))
+        keep_indices = indices[:, :num_keep]
+        if self.ordered:
+            # NOTE does not need to maintain patch order in typical transformer use,
+            # but possibly useful for debug / visualization
+            keep_indices, _ = self.sort(keep_indices)
+        keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2]))
+        x = ops.gather_elements(x, dim=1, index=keep_indices)
+
+        if prefix_tokens is not None:
+            x = ops.concat((prefix_tokens, x), axis=1)
+
+        if self.return_indices:
+            return x, keep_indices
+        return x
diff --git a/mindcv/models/layers/patch_embed.py b/mindcv/models/layers/patch_embed.py
index d2ca684f1..661e07890 100644
--- a/mindcv/models/layers/patch_embed.py
+++ b/mindcv/models/layers/patch_embed.py
@@ -4,6 +4,7 @@
 
 from mindspore import Tensor, nn, ops
 
+from .format import Format, nchw_to
 from .helpers import to_2tuple
 
 
@@ -17,29 +18,45 @@ class PatchEmbed(nn.Cell):
         embed_dim (int): Number of linear projection output channels. Default: 96.
         norm_layer (nn.Cell, optional): Normalization layer. 
Default: None """ + output_fmt: Format def __init__( self, - image_size: int = 224, + image_size: Optional[int] = 224, patch_size: int = 4, in_chans: int = 3, embed_dim: int = 96, norm_layer: Optional[nn.Cell] = None, + flatten: bool = True, + output_fmt: Optional[str] = None, + bias: bool = True, + strict_img_size: bool = True, + dynamic_img_pad: bool = False, ) -> None: super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]] - self.image_size = image_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans + self.patch_size = to_2tuple(patch_size) + if image_size is not None: + self.image_size = to_2tuple(image_size) + self.patches_resolution = tuple([s // p for s, p in zip(self.image_size, self.patch_size)]) + self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + else: + self.image_size = None + self.patches_resolution = None + self.num_patches = None + + if output_fmt is not None: + self.flatten = False + self.output_fmt = Format(output_fmt) + else: + self.flatten = flatten + self.output_fmt = Format.NCHW + + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad self.embed_dim = embed_dim self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, - pad_mode='pad', has_bias=True, weight_init="TruncatedNormal") + pad_mode='pad', has_bias=bias, weight_init="TruncatedNormal") if norm_layer is not None: if isinstance(embed_dim, int): @@ -50,11 +67,29 @@ def __init__( def construct(self, x: Tensor) -> Tensor: """docstring""" - B = x.shape[0] - # FIXME look at relaxing size constraints - x = ops.Reshape()(self.proj(x), (B, self.embed_dim, -1)) # B Ph*Pw C - x = ops.Transpose()(x, (0, 2, 1)) + B, C, H, W = x.shape + if self.image_size is not None: + if self.strict_img_size: + if (H, W) != (self.image_size[0], self.image_size[1]): + raise ValueError(f"Input height and width ({H},{W}) doesn't match model ({self.image_size[0]}," + f"{self.image_size[1]}).") + elif not self.dynamic_img_pad: + if H % self.patch_size[0] != 0: + raise ValueError(f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).") + if W % self.patch_size[1] != 0: + raise ValueError(f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).") + if self.dynamic_img_pad: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = ops.pad(x, (0, pad_w, 0, pad_h)) + # FIXME look at relaxing size constraints + x = self.proj(x) + if self.flatten: + x = ops.Reshape()(x, (B, self.embed_dim, -1)) # B Ph*Pw C + x = ops.Transpose()(x, (0, 2, 1)) + elif self.output_fmt != "NCHW": + x = nchw_to(x, self.output_fmt) if self.norm is not None: x = self.norm(x) return x diff --git a/mindcv/models/layers/pos_embed.py b/mindcv/models/layers/pos_embed.py index c570c5c1b..ba4548580 100644 --- a/mindcv/models/layers/pos_embed.py +++ b/mindcv/models/layers/pos_embed.py @@ -1,11 +1,53 @@ """positional embedding""" -from typing import Tuple +import math +from typing import List, Optional, Tuple import numpy as np import mindspore as ms from mindspore import Parameter, Tensor, nn, ops +from .compatibility import Interpolate + + +def resample_abs_pos_embed( + posemb, + 
new_size: List[int],
+    old_size: Optional[List[int]] = None,
+    num_prefix_tokens: int = 1,
+    interpolation: str = 'nearest',
+):
+    # sort out sizes, assume square if old size not provided
+    num_pos_tokens = posemb.shape[1]
+    num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
+
+    if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
+        return posemb
+
+    if old_size is None:
+        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
+        old_size = hw, hw
+
+    if num_prefix_tokens:
+        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
+    else:
+        posemb_prefix, posemb = None, posemb
+
+    # do the interpolation
+    embed_dim = posemb.shape[-1]
+    orig_dtype = posemb.dtype
+    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
+    interpolate = Interpolate(mode=interpolation, align_corners=True)
+    posemb = interpolate(posemb, size=new_size)
+    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
+    posemb = posemb.astype(orig_dtype)
+
+    # add back extra (class, etc) prefix tokens
+    if posemb_prefix is not None:
+        posemb = ops.concat((posemb_prefix, posemb), axis=1)
+
+    return posemb
+
 
 class RelativePositionBiasWithCLS(nn.Cell):
     def __init__(
diff --git a/mindcv/models/mae.py b/mindcv/models/mae.py
index f50346958..4a5cf887e 100644
--- a/mindcv/models/mae.py
+++ b/mindcv/models/mae.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Optional
+from typing import Callable, Optional
 
 import numpy as np
 
@@ -8,8 +8,10 @@
 from mindspore.common.initializer import Normal, initializer
 
 from .helpers import load_pretrained
+from .layers.mlp import Mlp
+from .layers.patch_embed import PatchEmbed
 from .registry import register_model
-from .vit_encoder import Block, VisionTransformerEncoder
+from .vit import Block, VisionTransformer
 
 __all__ = [
     "mae_b_16_224_pretrain",
@@ -91,12 +93,12 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     return emb
 
 
-class MAEForPretrain(VisionTransformerEncoder):
+class MAEForPretrain(nn.Cell):
     def __init__(
         self,
-        img_size: int = 224,
+        image_size: int = 224,
         patch_size: int = 16,
-        in_chans: int = 3,
+        in_channels: int = 3,
         embed_dim: int = 1024,
         depth: int = 24,
         num_heads: int = 16,
@@ -104,28 +106,33 @@ def __init__(
         decoder_embed_dim: int = 512,
         decoder_depth: int = 8,
         decoder_num_heads: int = 16,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        proj_drop_rate: float = 0.,
+        attn_drop_rate: float = 0.,
+        drop_path_rate: float = 0.,
+        init_values: Optional[float] = None,
         act_layer: nn.Cell = nn.GELU,
         norm_layer: nn.Cell = nn.LayerNorm,
+        mlp_layer: Callable = Mlp,
         norm_pix_loss: bool = True,
         mask_ratio: float = 0.75,
-        **kwargs
+        **kwargs,
    ):
-        super(MAEForPretrain, self).__init__(
-            img_size=img_size,
-            patch_size=patch_size,
-            in_chans=in_chans,
-            embed_dim=embed_dim,
-            depth=depth,
-            num_heads=num_heads,
-            mlp_ratio=mlp_ratio,
-            init_values=None,
-            act_layer=act_layer,
-            norm_layer=norm_layer,
-            use_abs_pos_emb=True,
-            use_rel_pos_bias=False,
-            use_shared_rel_pos_bias=False,
-            **kwargs
-        )
+        super(MAEForPretrain, self).__init__()
+        self.patch_embed = PatchEmbed(image_size=image_size, patch_size=patch_size,
+                                      in_chans=in_channels, embed_dim=embed_dim)
+        self.num_patches = self.patch_embed.num_patches
+        dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)]
+        self.blocks = nn.CellList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
+                attn_drop=attn_drop_rate, proj_drop=proj_drop_rate,
+                mlp_ratio=mlp_ratio, 
drop_path=dpr[i], init_values=init_values, + act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer, + ) for i in range(depth) + ]) + self.cls_token = Parameter(initializer(Normal(sigma=0.02), (1, 1, embed_dim))) self.unmask_len = int(np.floor(self.num_patches * (1 - mask_ratio))) @@ -148,12 +155,14 @@ def __init__( self.decoder_blocks = nn.CellList([ Block( - dim=decoder_embed_dim, num_heads=decoder_num_heads, qkv_bias=True, - mlp_ratio=mlp_ratio, init_values=None, act_layer=act_layer, norm_layer=norm_layer, - ) for _ in range(decoder_depth) + dim=decoder_embed_dim, num_heads=decoder_num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, + mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, + act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer, + ) for i in range(decoder_depth) ]) self.decoder_norm = norm_layer((decoder_embed_dim,)) - self.decoder_pred = nn.Dense(decoder_embed_dim, patch_size ** 2 * in_chans) + self.decoder_pred = nn.Dense(decoder_embed_dim, patch_size ** 2 * in_channels) self.sort = ops.Sort() @@ -178,7 +187,6 @@ def _init_weights(self): cell.beta.set_data( initializer('zeros', cell.beta.shape, cell.beta.dtype) ) - if name == "patch_embed.proj": cell.weight.set_data( initializer("xavier_uniform", cell.weight.shape, cell.weight.dtype) @@ -288,145 +296,86 @@ def construct(self, imgs, mask): loss = self.forward_loss(imgs, pred, mask) return loss + def get_num_layers(self): + return len(self.blocks) -class MAEForFinetune(VisionTransformerEncoder): - def __init__( - self, - img_size: int = 224, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - pos_drop_rate: float = 0., - proj_drop_rate: float = 0., - attn_drop_rate: float = 0., - drop_path_rate: float = 0., - act_layer: nn.Cell = nn.GELU, - norm_layer: nn.Cell = nn.LayerNorm, - num_classes: int = 1000, - use_mean_pooling: bool = True, - **kwargs - ): - super(MAEForFinetune, self).__init__( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - depth=depth, - num_heads=num_heads, - attn_head_dim=attn_head_dim, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - pos_drop_rate=pos_drop_rate, - proj_drop_rate=proj_drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - init_values=None, - act_layer=act_layer, - norm_layer=norm_layer, - use_abs_pos_emb=True, - use_rel_pos_bias=False, - use_shared_rel_pos_bias=False, - **kwargs - ) - self.use_mean_pooling = use_mean_pooling - if self.use_mean_pooling: - self.fc_norm = norm_layer((embed_dim,)) - else: - self.norm = norm_layer((embed_dim,)) - self.head = nn.Dense(embed_dim, num_classes, weight_init='TruncatedNormal') - - self._init_weights() - self._fix_init_weights() - - def construct(self, x): - x = self.forward_features(x) - if self.use_mean_pooling: - x = x[:, 1:].mean(axis=1) - x = self.fc_norm(x) - else: - x = self.norm(x) - x = x[:, 0] - x = self.head(x) - return x + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} @register_model -def mae_b_16_224_pretrain(pretrained=False, **kwargs): +def mae_b_16_224_pretrain(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_b_16_224_pretrain"] model = MAEForPretrain( - patch_size=16, embed_dim=768, depth=12, num_heads=12, 
- decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, - act_layer=partial(nn.GELU, approximate=False), + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, act_layer=partial(nn.GELU, approximate=False), norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs ) if pretrained: - pass + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def mae_l_16_224_pretrain(pretrained=False, **kwargs): +def mae_l_16_224_pretrain(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_l_16_224_pretrain"] model = MAEForPretrain( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, - decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, - act_layer=partial(nn.GELU, approximate=False), + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, act_layer=partial(nn.GELU, approximate=False), norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs ) if pretrained: - pass + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def mae_h_16_224_pretrain(pretrained=False, **kwargs): +def mae_h_16_224_pretrain(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_h_16_224_pretrain"] model = MAEForPretrain( - patch_size=16, embed_dim=1280, depth=32, num_heads=16, - decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, - act_layer=partial(nn.GELU, approximate=False), + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1280, depth=32, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, act_layer=partial(nn.GELU, approximate=False), norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs ) if pretrained: - pass + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def mae_b_16_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): +def mae_b_16_224_finetune(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): default_cfg = default_cfgs["mae_b_16_224_finetune"] - model = MAEForFinetune( - patch_size=16, in_chans=in_chans, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + model = VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + global_pool='avg', norm_layer=partial(nn.LayerNorm, epsilon=1e-6), + num_classes=num_classes, **kwargs ) if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def mae_l_16_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): +def mae_l_16_224_finetune(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): default_cfg = default_cfgs["mae_l_16_224_finetune"] - model = MAEForFinetune( - patch_size=16, in_chans=in_chans, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + model = 
VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + global_pool='avg', norm_layer=partial(nn.LayerNorm, epsilon=1e-6), + num_classes=num_classes, **kwargs ) if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def mae_h_14_224_finetune(pretrained=True, in_chans=3, num_classes=1000, **kwargs): +def mae_h_14_224_finetune(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): default_cfg = default_cfgs["mae_h_14_224_finetune"] - model = MAEForFinetune( - patch_size=14, in_chans=in_chans, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), num_classes=num_classes, **kwargs + model = VisionTransformer( + image_size=224, patch_size=14, in_channels=in_channels, embed_dim=1280, depth=32, num_heads=16, + global_pool='avg', norm_layer=partial(nn.LayerNorm, epsilon=1e-6), + num_classes=num_classes, **kwargs ) if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_chans) + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model diff --git a/mindcv/models/vit.py b/mindcv/models/vit.py index b26712f91..5a679df72 100644 --- a/mindcv/models/vit.py +++ b/mindcv/models/vit.py @@ -1,23 +1,23 @@ """ViT""" -import math -from typing import List, Optional, Union +from typing import Callable, Optional import numpy as np +import mindspore as ms from mindspore import Parameter, Tensor, nn, ops -from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.common.initializer import HeUniform, TruncatedNormal, initializer -from .helpers import ConfigDict, load_pretrained +from .helpers import load_pretrained from .layers.compatibility import Dropout from .layers.drop_path import DropPath from .layers.mlp import Mlp +from .layers.patch_dropout import PatchDropout from .layers.patch_embed import PatchEmbed -from .layers.pos_embed import RelativePositionBiasWithCLS +from .layers.pos_embed import resample_abs_pos_embed from .registry import register_model __all__ = [ - "VisionTransformerEncoder", - "ViT", + "VisionTransformer", "vit_b_16_224", "vit_b_16_384", "vit_l_16_224", # with pretrained weights @@ -65,10 +65,9 @@ class Attention(nn.Cell): dim (int): The dimension of input features. num_heads (int): The number of attention heads. Default: 8. qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. - qk_scale: (float): The user-defined factor to scale the product of q and k. Default: None. + qk_norm (bool): Specifies whether to do normalization to q and k. attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. proj_drop (float): The drop rate of output, greater than 0 and less equal than 1. Default: 0.0. - attn_head_dim (int): The user-defined dimension of attention head features. Default: None. Returns: Tensor, output tensor. 
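
A minimal usage sketch of the refactored Attention covered by the hunks below. The import path and the toy shapes here are assumptions for illustration only and are not part of this patch. With qk_norm=True the layer applies LayerNorm to the per-head q and k features before the scaled dot product, and the output keeps the input shape.

    import numpy as np
    import mindspore as ms
    from mindcv.models.vit import Attention  # class introduced by this diff

    # 12 heads over dim=768 gives head_dim=64; qk_norm adds LayerNorm((64,)) on q and k
    attn = Attention(dim=768, num_heads=12, qkv_bias=True, qk_norm=True)
    tokens = ms.Tensor(np.random.randn(2, 197, 768), ms.float32)  # (batch, num_tokens, dim)
    out = attn(tokens)  # same shape as the input: (2, 197, 768)
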
@@ -81,27 +80,23 @@ def __init__( dim: int, num_heads: int = 8, qkv_bias: bool = True, - qk_scale: Optional[float] = None, + qk_norm: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0, - attn_head_dim: Optional[int] = None, + norm_layer: nn.Cell = nn.LayerNorm, ): super(Attention, self).__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads - head_dim = dim // num_heads - if attn_head_dim is not None: - head_dim = attn_head_dim - all_head_dim = head_dim * num_heads - - if qk_scale: - self.scale = Tensor(qk_scale) - else: - self.scale = Tensor(head_dim ** -0.5) + self.head_dim = dim // num_heads + self.scale = Tensor(self.head_dim ** -0.5) - self.qkv = nn.Dense(dim, all_head_dim * 3, has_bias=qkv_bias) + self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias) + self.q_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity() + self.k_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity() self.attn_drop = Dropout(attn_drop) - self.proj = nn.Dense(all_head_dim, dim) + self.proj = nn.Dense(dim, dim) self.proj_drop = Dropout(proj_drop) self.mul = ops.Mul() @@ -110,22 +105,20 @@ def __init__( self.unstack = ops.Unstack(axis=0) self.attn_matmul_v = ops.BatchMatMul() self.q_matmul_k = ops.BatchMatMul(transpose_b=True) - self.softmax = nn.Softmax(axis=-1) - def construct(self, x, rel_pos_bias=None): + def construct(self, x): b, n, c = x.shape qkv = self.qkv(x) - qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) + qkv = self.reshape(qkv, (b, n, 3, self.num_heads, self.head_dim)) qkv = self.transpose(qkv, (2, 0, 3, 1, 4)) q, k, v = self.unstack(qkv) + q, k = self.q_norm(q), self.k_norm(k) attn = self.q_matmul_k(q, k) attn = self.mul(attn, self.scale) - if rel_pos_bias is not None: - attn = attn + rel_pos_bias - - attn = self.softmax(attn) + attn = attn.astype(ms.float32) + attn = ops.softmax(attn, axis=-1) attn = self.attn_drop(attn) out = self.attn_matmul_v(attn, v) @@ -164,7 +157,7 @@ def construct(self, x): return self.gamma * x -class TransformerBlock(nn.Cell): +class Block(nn.Cell): """ Transformer block implementation. @@ -172,10 +165,8 @@ class TransformerBlock(nn.Cell): dim (int): The dimension of embedding. num_heads (int): The number of attention heads. qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. - qk_scale: (float): The user-defined factor to scale the product of q and k. Default: None. attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. proj_drop (float): The drop rate of dense layer output, greater than 0 and less equal than 1. Default: 0.0. - attn_head_dim (int): The user-defined dimension of attention head features. Default: None. mlp_ratio (float): The ratio used to scale the input dimensions to obtain the dimensions of the hidden layer. drop_path (float): The drop rate for drop path. Default: 0.0. 
act_layer (nn.Cell): Activation function which will be stacked on top of the @@ -193,41 +184,48 @@ def __init__( self, dim: int, num_heads: int = 8, + mlp_ratio: float = 4., qkv_bias: bool = False, - qk_scale: Optional[float] = None, - attn_drop: float = 0., + qk_norm: bool = False, proj_drop: float = 0., - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - drop_path: float = 0., + attn_drop: float = 0., init_values: Optional[float] = None, + drop_path: float = 0., act_layer: nn.Cell = nn.GELU, norm_layer: nn.Cell = nn.LayerNorm, + mlp_layer: Callable = Mlp, ): - super(TransformerBlock, self).__init__() + super(Block, self).__init__() self.norm1 = norm_layer((dim,)) self.attn = Attention( - dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, + dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, ) self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer((dim,)) - self.mlp = Mlp( - in_features=dim, hidden_features=int(dim * mlp_ratio), - act_layer=act_layer, drop=proj_drop + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop ) self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def construct(self, x, rel_pos_bias=None): - x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), rel_pos_bias))) + def construct(self, x): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x -class VisionTransformerEncoder(nn.Cell): +class VisionTransformer(nn.Cell): ''' ViT encoder, which returns the feature encoded by transformer encoder. 
''' @@ -236,62 +234,93 @@ def __init__( image_size: int = 224, patch_size: int = 16, in_channels: int = 3, + global_pool: str = 'token', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, - attn_head_dim: Optional[int] = None, mlp_ratio: float = 4., qkv_bias: bool = True, - qk_scale: Optional[float] = None, + qk_norm: bool = False, + drop_rate: float = 0., pos_drop_rate: float = 0., + patch_drop_rate: float = 0., proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0., - init_values: Optional[float] = 0.1, + weight_init: bool = True, + init_values: Optional[float] = None, + no_embed_class: bool = False, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, act_layer: nn.Cell = nn.GELU, + embed_layer: Callable = PatchEmbed, norm_layer: nn.Cell = nn.LayerNorm, - use_rel_pos_emb: bool = False, - use_rel_pos_bias: bool = False, - use_shared_rel_pos_bias: bool = True, - **kwargs + mlp_layer: Callable = Mlp, + class_token: bool = True, + block_fn: Callable = Block, + num_classes: int = 1000, ): - super(VisionTransformerEncoder, self).__init__() - self.embed_dim = embed_dim - self.patch_embed = PatchEmbed(image_size=image_size, patch_size=patch_size, - in_chans=in_channels, embed_dim=embed_dim) - self.num_patches = self.patch_embed.num_patches - - self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) + super(VisionTransformer, self).__init__() + assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + + self.global_pool = global_pool + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class + self.dynamic_img_size = dynamic_img_size + self.dynamic_img_pad = dynamic_img_pad + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) + elif dynamic_img_pad: + embed_args.update(dict(output_fmt='NHWC')) + + self.patch_embed = embed_layer( + image_size=image_size, + patch_size=patch_size, + in_chans=in_channels, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches - self.pos_embed = Parameter(initializer(TruncatedNormal(0.02), - (1, self.num_patches + 1, embed_dim))) if not use_rel_pos_emb else None + self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = Parameter(initializer(TruncatedNormal(0.02), (1, embed_len, embed_dim))) self.pos_drop = Dropout(pos_drop_rate) - - if use_shared_rel_pos_bias: - self.rel_pos_bias = RelativePositionBiasWithCLS( - window_size=self.patch_embed.patches_resolution, - num_heads=num_heads, - ) - elif use_rel_pos_bias: - self.rel_pos_bias = nn.CellList([ - RelativePositionBiasWithCLS(window_size=self.patch_embed.patches_resolution, - num_heads=num_heads) for _ in range(depth) - ]) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) else: - self.rel_pos_bias = None + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer((embed_dim,)) if pre_norm else nn.Identity() dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] self.blocks = nn.CellList([ - TransformerBlock( - dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, attn_head_dim=attn_head_dim, + block_fn( + dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, - act_layer=act_layer, norm_layer=norm_layer + act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer, ) for i in range(depth) ]) - self._init_weights() - self._fix_init_weights() + self.norm = norm_layer((embed_dim,)) if not use_fc_norm else nn.Identity() + self.fc_norm = norm_layer((embed_dim,)) if use_fc_norm else nn.Identity() + self.head_drop = Dropout(drop_rate) + self.head = nn.Dense(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init: + self._init_weights() def get_num_layers(self): return len(self.blocks) @@ -318,474 +347,155 @@ def _init_weights(self): ) elif isinstance(cell, nn.Conv2d): cell.weight.set_data( - initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + initializer(HeUniform(), cell.weight.shape, cell.weight.dtype) ) if cell.bias is not None: cell.bias.set_data( - initializer('zeros', cell.bias.shape, cell.bias.dtype) + initializer("zeros", cell.bias.shape, cell.bias.dtype) ) - def _fix_init_weights(self): - for i, block in enumerate(self.blocks): - block.attn.proj.weight.set_data( - ops.div(block.attn.proj.weight, math.sqrt(2.0 * (i + 1))) - ) - block.mlp.fc2.weight.set_data( - ops.div(block.mlp.fc2.weight, math.sqrt(2.0 * (i + 1))) + def _pos_embed(self, x): + if self.dynamic_img_size or self.dynamic_img_pad: + # bhwc format + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, ) + x = ops.reshape(x, (B, -1, C)) + else: + pos_embed = self.pos_embed + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if self.cls_token is not None: + cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1)) + cls_tokens = cls_tokens.astype(x.dtype) + x = 
ops.concat((cls_tokens, x), axis=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1)) + cls_tokens = cls_tokens.astype(x.dtype) + x = ops.concat((cls_tokens, x), axis=1) + x = x + pos_embed + + return self.pos_drop(x) def forward_features(self, x): x = self.patch_embed(x) - bsz = x.shape[0] - - cls_tokens = ops.broadcast_to(self.cls_token, (bsz, -1, -1)) - cls_tokens = cls_tokens.astype(x.dtype) - x = ops.concat((cls_tokens, x), axis=1) - - if self.pos_embed is not None: - x = x + self.pos_embed - x = self.pos_drop(x) - - if isinstance(self.rel_pos_bias, nn.CellList): - for i, blk in enumerate(self.blocks): - rel_pos_bias = self.rel_pos_bias[i]() - x = blk(x, rel_pos_bias) - else: - rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None - for blk in self.blocks: - x = blk(x, rel_pos_bias) - + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) return x - def construct(self, x): - x = self.forward_features(x) + def forward_head(self, x): + if self.global_pool: + x = x[:, self.num_prefix_tokens:].mean(axis=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + x = self.head_drop(x) + x = self.head(x) return x - -class DenseHead(nn.Cell): - """ - LinearClsHead architecture. - - Args: - input_channel (int): The number of input channel. - num_classes (int): Number of classes. - has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. - activation (Union[str, Cell, Primitive]): activate function applied to the output. Eg. `ReLU`. Default: None. - keep_prob (float): Dropout keeping rate, between [0, 1]. E.g. rate=0.9, means dropping out 10% of input. - Default: 1.0. - - Returns: - Tensor, output tensor. - """ - - def __init__( - self, - input_channel: int, - num_classes: int, - has_bias: bool = True, - activation: Optional[Union[str, nn.Cell]] = None, - keep_prob: float = 1.0, - ) -> None: - super().__init__() - - self.dropout = Dropout(p=1.0-keep_prob) - self.classifier = nn.Dense(input_channel, num_classes, has_bias=has_bias, activation=activation) - def construct(self, x): - if self.training: - x = self.dropout(x) - x = self.classifier(x) + x = self.forward_features(x) + x = self.forward_head(x) return x -class MultilayerDenseHead(nn.Cell): - """ - MultilayerDenseHead architecture. - - Args: - input_channel (int): The number of input channel. - num_classes (int): Number of classes. - mid_channel (list): Number of channels in the hidden fc layers. - keep_prob (list): Dropout keeping rate, between [0, 1]. E.g. rate=0.9, means dropping out 10% of - input. - activation (list): activate function applied to the output. Eg. `ReLU`. - - Returns: - Tensor, output tensor. - """ - - def __init__( - self, - input_channel: int, - num_classes: int, - mid_channel: List[int], - keep_prob: List[float], - activation: List[Optional[Union[str, nn.Cell]]], - ) -> None: - super().__init__() - mid_channel.append(num_classes) - assert len(mid_channel) == len(activation) == len(keep_prob), "The length of the list should be the same." 
- - length = len(activation) - head = [] - - for i in range(length): - linear = DenseHead( - input_channel, - mid_channel[i], - activation=activation[i], - keep_prob=keep_prob[i], - ) - head.append(linear) - input_channel = mid_channel[i] - - self.classifier = nn.SequentialCell(head) - - def construct(self, x): - x = self.classifier(x) - - return x - +@register_model +def vit_b_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_b_16_224"] + model = VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) -class ViT(VisionTransformerEncoder): - def __init__( - self, - image_size: int = 224, - patch_size: int = 16, - in_channels: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - pos_drop_rate: float = 0., - proj_drop_rate: float = 0., - attn_drop_rate: float = 0., - drop_path_rate: float = 0., - init_values: Optional[float] = 0.1, - act_layer: nn.Cell = nn.GELU, - norm_layer: nn.Cell = nn.LayerNorm, - use_rel_pos_emb: bool = False, - use_rel_pos_bias: bool = False, - use_shared_rel_pos_bias: bool = True, - use_cls: bool = True, - representation_size: Optional[int] = None, - num_classes: int = 1000, - **kwargs - ): - super(ViT, self).__init__( - image_size=image_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=embed_dim, - depth=depth, - num_heads=num_heads, - attn_head_dim=attn_head_dim, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - pos_drop_rate=pos_drop_rate, - proj_drop_rate=proj_drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - init_values=init_values, - act_layer=act_layer, - norm_layer=norm_layer, - use_rel_pos_emb=use_rel_pos_emb, - use_rel_pos_bias=use_rel_pos_bias, - use_shared_rel_pos_bias=use_shared_rel_pos_bias, - **kwargs - ) - self.use_cls = use_cls - self.norm = norm_layer((embed_dim,)) - - if representation_size: - self.head = MultilayerDenseHead( - input_channel=embed_dim, - num_classes=num_classes, - mid_channel=[representation_size], - activation=["tanh", None], - keep_prob=[1.0, 1.0], - ) - else: - self.head = DenseHead(input_channel=embed_dim, num_classes=num_classes) + return model - def construct(self, x): - x = self.forward_features(x) - x = self.norm(x) - if self.use_cls: - x = x[:, 0] - else: - x = x[:, 1:].mean(axis=1) +@register_model +def vit_b_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_b_16_384"] + model = VisionTransformer( + image_size=384, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - x = self.head(x) - return x + return model -def vit( - image_size: int = 224, - patch_size: int = 16, - in_channels: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - pos_drop_rate: float = 0., - proj_drop_rate: float = 0., - attn_drop_rate: float = 0., - drop_path_rate: float = 
0., - init_values: Optional[float] = None, - act_layer: nn.Cell = nn.GELU, - norm_layer: nn.Cell = nn.LayerNorm, - use_rel_pos_emb: bool = False, - use_rel_pos_bias: bool = False, - use_shared_rel_pos_bias: bool = False, - use_cls: bool = True, - representation_size: Optional[int] = None, - num_classes: int = 1000, - pretrained: bool = False, - url_cfg: dict = None, -) -> ViT: - - """Vision Transformer architecture.""" - - model = ViT( - image_size=image_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=embed_dim, - depth=depth, - num_heads=num_heads, - attn_head_dim=attn_head_dim, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - pos_drop_rate=pos_drop_rate, - proj_drop_rate=proj_drop_rate, - attn_drop_rate=attn_drop_rate, - drop_path_rate=drop_path_rate, - init_values=init_values, - act_layer=act_layer, - norm_layer=norm_layer, - use_rel_pos_emb=use_rel_pos_emb, - use_rel_pos_bias=use_rel_pos_bias, - use_shared_rel_pos_bias=use_shared_rel_pos_bias, - use_cls=use_cls, - representation_size=representation_size, - num_classes=num_classes +@register_model +def vit_l_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_l_16_224"] + model = VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + num_classes=num_classes, **kwargs ) - if pretrained: - # Download the pre-trained checkpoint file from url, and load ckpt file. - load_pretrained(model, url_cfg, num_classes=num_classes, in_channels=in_channels) + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def vit_b_16_224( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 224 - config.patch_size = 16 - config.in_channels = in_channels - config.embed_dim = 768 - config.depth = 12 - config.num_heads = 12 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 768 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_b_16_224"] - - return vit(**config) - +def vit_l_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_l_16_384"] + model = VisionTransformer( + image_size=384, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) -@register_model -def vit_b_16_384( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 384 - config.patch_size = 16 - config.in_channels = in_channels - config.embed_dim = 768 - config.depth = 12 - config.num_heads = 12 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 768 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = 
default_cfgs["vit_b_16_384"] - - return vit(**config) + return model @register_model -def vit_l_16_224( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 224 - config.patch_size = 16 - config.in_channels = in_channels - config.embed_dim = 1024 - config.depth = 24 - config.num_heads = 16 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 1024 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_l_16_224"] - - return vit(**config) - +def vit_b_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_b_32_224"] + model = VisionTransformer( + image_size=224, patch_size=32, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) -@register_model -def vit_l_16_384( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 384 - config.patch_size = 16 - config.in_channels = in_channels - config.embed_dim = 1024 - config.depth = 24 - config.num_heads = 16 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 1024 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_l_16_384"] - - return vit(**config) + return model @register_model -def vit_b_32_224( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 224 - config.patch_size = 32 - config.in_channels = in_channels - config.embed_dim = 768 - config.depth = 12 - config.num_heads = 12 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 768 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_b_32_224"] - - return vit(**config) - +def vit_b_32_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_b_32_384"] + model = VisionTransformer( + image_size=384, patch_size=32, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) -@register_model -def vit_b_32_384( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 384 - config.patch_size = 32 - config.in_channels = in_channels - config.embed_dim = 768 - config.depth = 12 - config.num_heads = 12 - config.pos_drop_rate = drop_rate - 
config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 768 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_b_32_384"] - - return vit(**config) + return model @register_model -def vit_l_32_224( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - has_logits: bool = False, - drop_rate: float = 0.0, - drop_path_rate: float = 0.0, -): - config = ConfigDict() - config.image_size = 224 - config.patch_size = 32 - config.in_channels = in_channels - config.embed_dim = 1024 - config.depth = 24 - config.num_heads = 16 - config.pos_drop_rate = drop_rate - config.proj_drop_rate = drop_rate - config.attn_drop_rate = drop_rate - config.drop_path_rate = drop_path_rate - config.representation_size = 1024 if has_logits else None - config.num_classes = num_classes - - config.pretrained = pretrained - config.url_cfg = default_cfgs["vit_l_32_224"] - - return vit(**config) +def vit_l_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_l_32_224"] + model = VisionTransformer( + image_size=224, patch_size=32, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + + return model diff --git a/mindcv/models/vit_encoder.py b/mindcv/models/vit_encoder.py deleted file mode 100644 index 010065273..000000000 --- a/mindcv/models/vit_encoder.py +++ /dev/null @@ -1,296 +0,0 @@ -import math -from typing import Optional, Tuple - -import numpy as np - -import mindspore as ms -from mindspore import Parameter, Tensor, nn, ops -from mindspore.common.initializer import TruncatedNormal, initializer - -from .layers.drop_path import DropPath -from .layers.mlp import Mlp -from .layers.patch_embed import PatchEmbed - - -class RelativePositionBiasWithCLS(nn.Cell): - def __init__( - self, - window_size: Tuple[int], - num_heads: int - ): - super(RelativePositionBiasWithCLS, self).__init__() - self.window_size = window_size - self.num_tokens = window_size[0] * window_size[1] - - num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 - # 3: cls to token, token to cls, cls to cls - self.relative_position_bias_table = Parameter( - Tensor(np.zeros((num_relative_distance, num_heads)), dtype=ms.float16) - ) - coords_h = np.arange(window_size[0]).reshape(window_size[0], 1).repeat(window_size[1], 1).reshape(1, -1) - coords_w = np.arange(window_size[1]).reshape(1, window_size[1]).repeat(window_size[0], 0).reshape(1, -1) - coords_flatten = np.concatenate([coords_h, coords_w], axis=0) # [2, Wh * Ww] - - relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :] # [2, Wh * Ww, Wh * Ww] - relative_coords = relative_coords.transpose(1, 2, 0) # [Wh * Ww, Wh * Ww, 2] - relative_coords[:, :, 0] += window_size[0] - 1 - relative_coords[:, :, 1] += window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * window_size[0] - 1 - - relative_position_index = np.zeros((self.num_tokens + 1, self.num_tokens + 1), - dtype=relative_coords.dtype) # [Wh * Ww + 1, Wh * Ww + 1] - relative_position_index[1:, 1:] = relative_coords.sum(-1) - relative_position_index[0, 0:] = num_relative_distance - 3 - relative_position_index[0:, 0] = num_relative_distance - 2 - relative_position_index[0, 0] = 
num_relative_distance - 1 - relative_position_index = Tensor(relative_position_index.reshape(-1)) - - self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16) - self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False) - - def construct(self): - out = ops.matmul(self.relative_position_index, self.relative_position_bias_table) - out = ops.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1)) - out = ops.transpose(out, (2, 0, 1)) - out = ops.expand_dims(out, 0) - return out - - -class Attention(nn.Cell): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_scale: Optional[float] = None, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - attn_head_dim: Optional[int] = None, - ): - super(Attention, self).__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - if attn_head_dim is not None: - head_dim = attn_head_dim - all_head_dim = head_dim * num_heads - - if qk_scale: - self.scale = Tensor(qk_scale) - else: - self.scale = Tensor(head_dim ** -0.5) - - self.qkv = nn.Dense(dim, all_head_dim * 3, has_bias=qkv_bias) - - self.attn_drop = nn.Dropout(1 - attn_drop) - self.proj = nn.Dense(all_head_dim, dim) - self.proj_drop = nn.Dropout(1 - proj_drop) - - self.mul = ops.Mul() - self.reshape = ops.Reshape() - self.transpose = ops.Transpose() - self.unstack = ops.Unstack(axis=0) - self.attn_matmul_v = ops.BatchMatMul() - self.q_matmul_k = ops.BatchMatMul(transpose_b=True) - - def construct(self, x, rel_pos_bias=None): - b, n, c = x.shape - qkv = self.qkv(x) - qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) - qkv = self.transpose(qkv, (2, 0, 3, 1, 4)) - q, k, v = self.unstack(qkv) - - attn = self.q_matmul_k(q, k) - attn = self.mul(attn, self.scale) - - if rel_pos_bias is not None: - attn = attn + rel_pos_bias - - attn = attn.astype(ms.float32) - attn = ops.softmax(attn, axis=-1) - attn = self.attn_drop(attn) - - out = self.attn_matmul_v(attn, v) - out = self.transpose(out, (0, 2, 1, 3)) - out = self.reshape(out, (b, n, c)) - out = self.proj(out) - out = self.proj_drop(out) - - return out - - -class LayerScale(nn.Cell): - def __init__( - self, - dim: int, - init_values: float = 1e-5 - ): - super(LayerScale, self).__init__() - self.gamma = Parameter(initializer(init_values, dim)) - - def construct(self, x): - return self.gamma * x - - -class Block(nn.Cell): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_scale: Optional[float] = None, - attn_drop: float = 0., - proj_drop: float = 0., - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - drop_path: float = 0., - init_values: Optional[float] = None, - act_layer: nn.Cell = nn.GELU, - norm_layer: nn.Cell = nn.LayerNorm, - ): - super(Block, self).__init__() - self.norm1 = norm_layer((dim,)) - self.attn = Attention( - dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=proj_drop, attn_head_dim=attn_head_dim, - ) - self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - self.norm2 = norm_layer((dim,)) - self.mlp = Mlp( - in_features=dim, hidden_features=int(dim * mlp_ratio), - act_layer=act_layer, drop=proj_drop - ) - self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - - def construct(self, x, rel_pos_bias=None): - x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), rel_pos_bias))) - x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) - return x - - -class VisionTransformerEncoder(nn.Cell): - def __init__( - self, - img_size: int = 224, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, - depth: int = 12, - num_heads: int = 12, - attn_head_dim: Optional[int] = None, - mlp_ratio: float = 4., - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - pos_drop_rate: float = 0., - proj_drop_rate: float = 0., - attn_drop_rate: float = 0., - drop_path_rate: float = 0., - init_values: Optional[float] = 0.1, - act_layer: nn.Cell = nn.GELU, - norm_layer: nn.Cell = nn.LayerNorm, - use_abs_pos_emb: bool = False, - use_rel_pos_bias: bool = False, - use_shared_rel_pos_bias: bool = True, - **kwargs - ): - super(VisionTransformerEncoder, self).__init__() - self.embed_dim = embed_dim - self.patch_embed = PatchEmbed(image_size=img_size, patch_size=patch_size, - in_chans=in_chans, embed_dim=embed_dim) - self.num_patches = self.patch_embed.num_patches - - self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) - - self.pos_embed = Parameter( - initializer(TruncatedNormal(0.02), (1, self.num_patches + 1, embed_dim))) if use_abs_pos_emb else None - self.pos_drop = nn.Dropout(1 - pos_drop_rate) - - if use_shared_rel_pos_bias: - self.rel_pos_bias = RelativePositionBiasWithCLS( - window_size=self.patch_embed.patches_resolution, num_heads=num_heads) - elif use_rel_pos_bias: - self.rel_pos_bias = nn.CellList([ - RelativePositionBiasWithCLS(window_size=self.patch_embed.patches_resolution, - num_heads=num_heads) for _ in range(depth) - ]) - else: - self.rel_pos_bias = None - - dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] - self.blocks = nn.CellList([ - Block( - dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, attn_head_dim=attn_head_dim, - mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, - act_layer=act_layer, norm_layer=norm_layer - ) for i in range(depth) - ]) - - def get_num_layers(self): - return len(self.blocks) - - def no_weight_decay(self): - return {'pos_embed', 'cls_token'} - - def _init_weights(self): - for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Dense): - cell.weight.set_data( - initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) - ) - if cell.bias is not None: - cell.bias.set_data( - initializer('zeros', cell.bias.shape, cell.bias.dtype) - ) - elif isinstance(cell, nn.LayerNorm): - cell.gamma.set_data( - initializer('ones', cell.gamma.shape, cell.gamma.dtype) - ) - cell.beta.set_data( - initializer('zeros', cell.beta.shape, cell.beta.dtype) - ) - elif isinstance(cell, nn.Conv2d): - cell.weight.set_data( - initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) - ) - if cell.bias is not None: - cell.bias.set_data( - initializer('zeros', cell.bias.shape, cell.bias.dtype) - ) - - def _fix_init_weights(self): - for i, block in enumerate(self.blocks): - block.attn.proj.weight.set_data( - ops.div(block.attn.proj.weight, math.sqrt(2.0 * (i + 1))) - ) - block.mlp.fc2.weight.set_data( - ops.div(block.mlp.fc2.weight, math.sqrt(2.0 * (i + 1))) - ) - - def forward_features(self, x): - x = self.patch_embed(x) - bsz = x.shape[0] - - cls_token = ops.broadcast_to(self.cls_token, (bsz, -1, -1)) - cls_token = cls_token.astype(x.dtype) - x 
= ops.concat((cls_token, x), axis=1) - - if self.pos_embed is not None: - x = x + self.pos_embed - x = self.pos_drop(x) - - if isinstance(self.rel_pos_bias, nn.CellList): - for i, blk in enumerate(self.blocks): - rel_pos_bias = self.rel_pos_bias[i]() - x = blk(x, rel_pos_bias) - else: - rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None - for blk in self.blocks: - x = blk(x, rel_pos_bias) - - return x - - def construct(self, x): - return self.forward_features(x)
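
The hunks above replace the ConfigDict-driven `vit()` builder with thin `@register_model` factories that forward keyword arguments straight to `VisionTransformer`, and remove the standalone `vit_encoder.py` module. Not part of the patch itself: a minimal usage sketch of the new entry points, assuming the patch is applied to mindcv so that the factories are importable from `mindcv.models.vit`; the `create_model` registry helper mentioned in the comments is assumed to be the existing MindCV one, and the dummy input follows the `vit_b_32_224` configuration (3 x 224 x 224).

```python
# Usage sketch (assumption: mindcv with this patch applied).
import numpy as np
import mindspore as ms

from mindcv.models.vit import vit_b_32_224
# Alternatively, via the registry helper (assumed unchanged by this patch):
#   from mindcv.models import create_model
#   model = create_model("vit_b_32_224", pretrained=False, num_classes=1000)

# Extra keyword arguments travel through **kwargs directly to VisionTransformer,
# so the per-model ConfigDict objects removed above are no longer needed.
model = vit_b_32_224(pretrained=False, num_classes=1000)
model.set_train(False)

dummy = ms.Tensor(np.random.randn(1, 3, 224, 224), ms.float32)
logits = model(dummy)
print(logits.shape)  # expected: (1, 1000)
```

Setting `pretrained=True` routes through `load_pretrained(model, default_cfg, ...)`, exactly as written in the factory bodies above.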