diff --git a/configs/vit/README.md b/configs/vit/README.md index e5d743fbb..f424e57e4 100644 --- a/configs/vit/README.md +++ b/configs/vit/README.md @@ -36,9 +36,9 @@ Our reproduced model performance on ImageNet-1K is reported as follows. | Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download | |--------------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------| -| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-7553218f.ckpt) | -| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-f02b2487.ckpt) | -| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-3a961018.ckpt) | +| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt) | +| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt) | +| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt) | diff --git a/mindcv/models/__init__.py b/mindcv/models/__init__.py index 6f28ba6f1..f3395796c 100644 --- a/mindcv/models/__init__.py +++ b/mindcv/models/__init__.py @@ -18,6 +18,7 @@ inceptionv3, inceptionv4, layers, + mae, mixnet, mlpmixer, mnasnet, @@ -74,6 +75,7 @@ from .inceptionv3 import * from .inceptionv4 import * from .layers import * +from .mae import * from .mixnet import * from .mlpmixer import * from .mnasnet import * @@ -132,6 +134,7 @@ __all__.extend(["InceptionV3", "inception_v3"]) __all__.extend(["InceptionV4", "inception_v4"]) __all__.extend(layers.__all__) +__all__.extend(mae.__all__) __all__.extend(mixnet.__all__) __all__.extend(mlpmixer.__all__) __all__.extend(mnasnet.__all__) diff --git a/mindcv/models/layers/__init__.py b/mindcv/models/layers/__init__.py index 2810dbca1..c3e4de210 100644 --- a/mindcv/models/layers/__init__.py +++ b/mindcv/models/layers/__init__.py @@ -1,10 +1,24 @@ """layers init""" -from . import activation, conv_norm_act, drop_path, identity, pooling, selective_kernel, squeeze_excite +from . 
import (
+    activation,
+    conv_norm_act,
+    drop_path,
+    format,
+    identity,
+    patch_dropout,
+    pooling,
+    pos_embed,
+    selective_kernel,
+    squeeze_excite,
+)
 from .activation import *
 from .conv_norm_act import *
 from .drop_path import *
+from .format import *
 from .identity import *
+from .patch_dropout import *
 from .pooling import *
+from .pos_embed import *
 from .selective_kernel import *
 from .squeeze_excite import *
diff --git a/mindcv/models/layers/format.py b/mindcv/models/layers/format.py
new file mode 100644
index 000000000..058a74517
--- /dev/null
+++ b/mindcv/models/layers/format.py
@@ -0,0 +1,34 @@
+from enum import Enum
+from typing import Union
+
+import mindspore
+
+
+class Format(str, Enum):
+    NCHW = 'NCHW'
+    NHWC = 'NHWC'
+    NCL = 'NCL'
+    NLC = 'NLC'
+
+
+FormatT = Union[str, Format]
+
+
+def nchw_to(x: mindspore.Tensor, fmt: Format):
+    if fmt == Format.NHWC:
+        x = x.permute(0, 2, 3, 1)
+    elif fmt == Format.NLC:
+        x = x.flatten(start_dim=2).transpose((0, 2, 1))
+    elif fmt == Format.NCL:
+        x = x.flatten(start_dim=2)
+    return x
+
+
+def nhwc_to(x: mindspore.Tensor, fmt: Format):
+    if fmt == Format.NCHW:
+        x = x.permute(0, 3, 1, 2)
+    elif fmt == Format.NLC:
+        x = x.flatten(start_dim=1, end_dim=2)
+    elif fmt == Format.NCL:
+        x = x.flatten(start_dim=1, end_dim=2).transpose((0, 2, 1))
+    return x
diff --git a/mindcv/models/layers/patch_dropout.py b/mindcv/models/layers/patch_dropout.py
new file mode 100644
index 000000000..ad854dbfc
--- /dev/null
+++ b/mindcv/models/layers/patch_dropout.py
@@ -0,0 +1,54 @@
+import numpy as np
+
+import mindspore as ms
+from mindspore import nn, ops
+
+
+class PatchDropout(nn.Cell):
+    """
+    https://arxiv.org/abs/2212.00794
+    """
+    def __init__(
+        self,
+        prob: float = 0.5,
+        num_prefix_tokens: int = 1,
+        ordered: bool = False,
+        return_indices: bool = False,
+    ):
+        super().__init__()
+        assert 0 <= prob < 1.
+        self.prob = prob
+        self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
+        self.ordered = ordered
+        self.return_indices = return_indices
+        self.sort = ops.Sort()
+
+    def construct(self, x):
+        if not self.training or self.prob == 0.:
+            if self.return_indices:
+                return x, None
+            return x
+
+        if self.num_prefix_tokens:
+            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
+        else:
+            prefix_tokens = None
+
+        B = x.shape[0]
+        L = x.shape[1]
+        num_keep = max(1, int(L * (1. - self.prob)))
+        _, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32))
+        keep_indices = indices[:, :num_keep]
+        if self.ordered:
+            # NOTE does not need to maintain patch order in typical transformer use,
+            # but possibly useful for debug / visualization
+            keep_indices, _ = self.sort(keep_indices)
+        keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2]))
+        x = ops.gather_elements(x, dim=1, index=keep_indices)
+
+        if prefix_tokens is not None:
+            x = ops.concat((prefix_tokens, x), axis=1)
+
+        if self.return_indices:
+            return x, keep_indices
+        return x
diff --git a/mindcv/models/layers/patch_embed.py b/mindcv/models/layers/patch_embed.py
index d2ca684f1..661e07890 100644
--- a/mindcv/models/layers/patch_embed.py
+++ b/mindcv/models/layers/patch_embed.py
@@ -4,6 +4,7 @@
 from mindspore import Tensor, nn, ops
+from .format import Format, nchw_to
 from .helpers import to_2tuple
@@ -17,29 +18,45 @@ class PatchEmbed(nn.Cell):
         embed_dim (int): Number of linear projection output channels. Default: 96.
         norm_layer (nn.Cell, optional): Normalization layer.
Default: None """ + output_fmt: Format def __init__( self, - image_size: int = 224, + image_size: Optional[int] = 224, patch_size: int = 4, in_chans: int = 3, embed_dim: int = 96, norm_layer: Optional[nn.Cell] = None, + flatten: bool = True, + output_fmt: Optional[str] = None, + bias: bool = True, + strict_img_size: bool = True, + dynamic_img_pad: bool = False, ) -> None: super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]] - self.image_size = image_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans + self.patch_size = to_2tuple(patch_size) + if image_size is not None: + self.image_size = to_2tuple(image_size) + self.patches_resolution = tuple([s // p for s, p in zip(self.image_size, self.patch_size)]) + self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + else: + self.image_size = None + self.patches_resolution = None + self.num_patches = None + + if output_fmt is not None: + self.flatten = False + self.output_fmt = Format(output_fmt) + else: + self.flatten = flatten + self.output_fmt = Format.NCHW + + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad self.embed_dim = embed_dim self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, - pad_mode='pad', has_bias=True, weight_init="TruncatedNormal") + pad_mode='pad', has_bias=bias, weight_init="TruncatedNormal") if norm_layer is not None: if isinstance(embed_dim, int): @@ -50,11 +67,29 @@ def __init__( def construct(self, x: Tensor) -> Tensor: """docstring""" - B = x.shape[0] - # FIXME look at relaxing size constraints - x = ops.Reshape()(self.proj(x), (B, self.embed_dim, -1)) # B Ph*Pw C - x = ops.Transpose()(x, (0, 2, 1)) + B, C, H, W = x.shape + if self.image_size is not None: + if self.strict_img_size: + if (H, W) != (self.image_size[0], self.image_size[1]): + raise ValueError(f"Input height and width ({H},{W}) doesn't match model ({self.image_size[0]}," + f"{self.image_size[1]}).") + elif not self.dynamic_img_pad: + if H % self.patch_size[0] != 0: + raise ValueError(f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).") + if W % self.patch_size[1] != 0: + raise ValueError(f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).") + if self.dynamic_img_pad: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = ops.pad(x, (0, pad_w, 0, pad_h)) + # FIXME look at relaxing size constraints + x = self.proj(x) + if self.flatten: + x = ops.Reshape()(x, (B, self.embed_dim, -1)) # B Ph*Pw C + x = ops.Transpose()(x, (0, 2, 1)) + elif self.output_fmt != "NCHW": + x = nchw_to(x, self.output_fmt) if self.norm is not None: x = self.norm(x) return x diff --git a/mindcv/models/layers/pos_embed.py b/mindcv/models/layers/pos_embed.py new file mode 100644 index 000000000..ba4548580 --- /dev/null +++ b/mindcv/models/layers/pos_embed.py @@ -0,0 +1,93 @@ +"""positional embedding""" +import math +from typing import List, Optional, Tuple + +import numpy as np + +import mindspore as ms +from mindspore import Parameter, Tensor, nn, ops + +from .compatibility import Interpolate + + +def resample_abs_pos_embed( + posemb, + new_size: List[int], + old_size: 
Optional[List[int]] = None,
+    num_prefix_tokens: int = 1,
+    interpolation: str = 'nearest',
+):
+    # sort out sizes, assume square if old size not provided
+    num_pos_tokens = posemb.shape[1]
+    num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
+
+    if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
+        return posemb
+
+    if old_size is None:
+        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
+        old_size = hw, hw
+
+    if num_prefix_tokens:
+        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
+    else:
+        posemb_prefix, posemb = None, posemb
+
+    # do the interpolation
+    embed_dim = posemb.shape[-1]
+    orig_dtype = posemb.dtype
+    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
+    interpolate = Interpolate(mode=interpolation, align_corners=True)
+    posemb = interpolate(posemb, size=new_size)
+    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
+    posemb = posemb.astype(orig_dtype)
+
+    # add back extra (class, etc) prefix tokens
+    if posemb_prefix is not None:
+        posemb = ops.concat((posemb_prefix, posemb), axis=1)
+
+    return posemb
+
+
+class RelativePositionBiasWithCLS(nn.Cell):
+    def __init__(
+        self,
+        window_size: Tuple[int],
+        num_heads: int
+    ):
+        super(RelativePositionBiasWithCLS, self).__init__()
+        self.window_size = window_size
+        self.num_tokens = window_size[0] * window_size[1]
+
+        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        # 3: cls to token, token to cls, cls to cls
+        self.relative_position_bias_table = Parameter(
+            Tensor(np.zeros((num_relative_distance, num_heads)), dtype=ms.float16)
+        )
+        coords_h = np.arange(window_size[0]).reshape(window_size[0], 1).repeat(window_size[1], 1).reshape(1, -1)
+        coords_w = np.arange(window_size[1]).reshape(1, window_size[1]).repeat(window_size[0], 0).reshape(1, -1)
+        coords_flatten = np.concatenate([coords_h, coords_w], axis=0)  # [2, Wh * Ww]
+
+        relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :]  # [2, Wh * Ww, Wh * Ww]
+        relative_coords = relative_coords.transpose(1, 2, 0)  # [Wh * Ww, Wh * Ww, 2]
+        relative_coords[:, :, 0] += window_size[0] - 1
+        relative_coords[:, :, 1] += window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * window_size[0] - 1
+
+        relative_position_index = np.zeros((self.num_tokens + 1, self.num_tokens + 1),
+                                           dtype=relative_coords.dtype)  # [Wh * Ww + 1, Wh * Ww + 1]
+        relative_position_index[1:, 1:] = relative_coords.sum(-1)
+        relative_position_index[0, 0:] = num_relative_distance - 3
+        relative_position_index[0:, 0] = num_relative_distance - 2
+        relative_position_index[0, 0] = num_relative_distance - 1
+        relative_position_index = Tensor(relative_position_index.reshape(-1))
+
+        self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16)
+        self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False)
+
+    def construct(self):
+        out = ops.matmul(self.relative_position_index, self.relative_position_bias_table)
+        out = ops.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1))
+        out = ops.transpose(out, (2, 0, 1))
+        out = ops.expand_dims(out, 0)
+        return out
diff --git a/mindcv/models/mae.py b/mindcv/models/mae.py
new file mode 100644
index 000000000..4a5cf887e
--- /dev/null
+++ b/mindcv/models/mae.py
@@ -0,0 +1,381 @@
+from functools import partial
+from typing import Callable, Optional
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import Parameter, Tensor, nn, ops
+from
mindspore.common.initializer import Normal, initializer
+
+from .helpers import load_pretrained
+from .layers.mlp import Mlp
+from .layers.patch_embed import PatchEmbed
+from .registry import register_model
+from .vit import Block, VisionTransformer
+
+__all__ = [
+    "mae_b_16_224_pretrain",
+    "mae_l_16_224_pretrain",
+    "mae_h_16_224_pretrain",
+    "mae_b_16_224_finetune",
+    "mae_l_16_224_finetune",
+    "mae_h_14_224_finetune"
+]
+
+
+def _cfg(url="", **kwargs):
+    return {
+        "url": url,
+        "num_classes": 1000,
+        "input_size": (3, 224, 224),
+        "first_conv": "patch_embed.proj",
+        "classifier": "head",
+        **kwargs,
+    }
+
+
+default_cfgs = {
+    "mae_b_16_224_pretrain": _cfg(url=""),
+    "mae_l_16_224_pretrain": _cfg(url=""),
+    "mae_h_16_224_pretrain": _cfg(url=""),
+    "mae_b_16_224_finetune": _cfg(
+        url="https://download.mindspore.cn/toolkits/mindcv/mae/mae_b_16_224_finetune-cc05b899.ckpt"
+    ),
+    "mae_l_16_224_finetune": _cfg(url=""),
+    "mae_h_14_224_finetune": _cfg(url=""),
+}
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float32)
+    omega /= embed_dim / 2.
+    omega = 1.
/ 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class MAEForPretrain(nn.Cell): + def __init__( + self, + image_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dim: int = 1024, + depth: int = 24, + num_heads: int = 16, + mlp_ratio: float = 4., + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + init_values: Optional[float] = None, + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + mlp_layer: Callable = Mlp, + norm_pix_loss: bool = True, + mask_ratio: float = 0.75, + **kwargs, + ): + super(MAEForPretrain, self).__init__() + self.patch_embed = PatchEmbed(image_size=image_size, patch_size=patch_size, + in_chans=in_channels, embed_dim=embed_dim) + self.num_patches = self.patch_embed.num_patches + dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] + self.blocks = nn.CellList([ + Block( + dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, + mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, + act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer, + ) for i in range(depth) + ]) + + self.cls_token = Parameter(initializer(Normal(sigma=0.02), (1, 1, embed_dim))) + + self.unmask_len = int(np.floor(self.num_patches * (1 - mask_ratio))) + + encoder_pos_emb = Tensor(get_2d_sincos_pos_embed( + embed_dim, int(self.num_patches ** 0.5), cls_token=True), ms.float32 + ) + encoder_pos_emb = ops.expand_dims(encoder_pos_emb, axis=0) + self.pos_embed = Parameter(encoder_pos_emb, requires_grad=False) + self.norm = norm_layer((embed_dim,)) + + self.decoder_embed = nn.Dense(embed_dim, decoder_embed_dim) + self.mask_token = Parameter(initializer(Normal(sigma=0.02), (1, 1, decoder_embed_dim))) + + decoder_pos_emb = Tensor(get_2d_sincos_pos_embed( + decoder_embed_dim, int(self.num_patches ** 0.5), cls_token=True), ms.float32 + ) + decoder_pos_emb = ops.expand_dims(decoder_pos_emb, axis=0) + self.decoder_pos_embed = Parameter(decoder_pos_emb, requires_grad=False) + + self.decoder_blocks = nn.CellList([ + Block( + dim=decoder_embed_dim, num_heads=decoder_num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, + mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, + act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer, + ) for i in range(decoder_depth) + ]) + self.decoder_norm = norm_layer((decoder_embed_dim,)) + self.decoder_pred = nn.Dense(decoder_embed_dim, patch_size ** 2 * in_channels) + + self.sort = ops.Sort() + + self.norm_pix_loss = norm_pix_loss + self._init_weights() + + def _init_weights(self): + for name, cell in self.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.weight.set_data( + initializer("xavier_uniform", cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + + elif isinstance(cell, nn.LayerNorm): + cell.gamma.set_data( + initializer('ones', cell.gamma.shape, cell.gamma.dtype) + ) + cell.beta.set_data( + initializer('zeros', cell.beta.shape, cell.beta.dtype) + ) + if name 
== "patch_embed.proj": + cell.weight.set_data( + initializer("xavier_uniform", cell.weight.shape, cell.weight.dtype) + ) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size ** 2 * 3) + """ + N, _, H, W = imgs.shape + p = self.patch_embed.patch_size[0] + assert H == W and H % p == 0 + h = w = H // p + + x = ops.reshape(imgs, (N, 3, h, p, w, p)) + x = ops.transpose(x, (0, 2, 4, 3, 5, 1)) + x = ops.reshape(x, (N, h * w, p ** 2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size ** 2 * 3) + imgs: (N, 3, H, W) + """ + N, L, _ = x.shape + p = self.patch_embed.patch_size[0] + h = w = int(L ** 0.5) + assert h * w == L + + imgs = ops.reshape(x, (N, h, w, p, p, 3)) + imgs = ops.transpose(imgs, (0, 5, 1, 3, 2, 4)) + imgs = ops.reshape(imgs, (N, 3, h * p, w * p)) + return imgs + + def apply_masking(self, x, mask): + D = x.shape[2] + _, ids_shuffle = self.sort(mask.astype(ms.float32)) + _, ids_restore = self.sort(ids_shuffle.astype(ms.float32)) + + ids_keep = ids_shuffle[:, :self.unmask_len] + ids_keep = ops.broadcast_to(ops.expand_dims(ids_keep, axis=-1), (-1, -1, D)) + x_unmasked = ops.gather_elements(x, dim=1, index=ids_keep) + + return x_unmasked, ids_restore + + def forward_features(self, x, mask): + x = self.patch_embed(x) + bsz = x.shape[0] + + x = x + self.pos_embed[:, 1:, :] + x, ids_restore = self.apply_masking(x, mask) + + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_token = ops.broadcast_to(cls_token, (bsz, -1, -1)) + cls_token = cls_token.astype(x.dtype) + x = ops.concat((cls_token, x), axis=1) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x, ids_restore + + def forward_decoder(self, x, ids_restore): + x = self.decoder_embed(x) + bsz, L, D = x.shape + + mask_len = self.num_patches + 1 - L + mask_tokens = ops.broadcast_to(self.mask_token, (bsz, mask_len, -1)) + mask_tokens = mask_tokens.astype(x.dtype) + + x_ = ops.concat((x[:, 1:, :], mask_tokens), axis=1) + ids_restore = ops.broadcast_to(ops.expand_dims(ids_restore, axis=-1), (-1, -1, D)) + x_ = ops.gather_elements(x_, dim=1, index=ids_restore) + x = ops.concat((x[:, :1, :], x_), axis=1) + + x = x + self.decoder_pos_embed + + for blk in self.decoder_blocks: + x = blk(x) + + x = self.decoder_norm(x) + x = self.decoder_pred(x) + + return x[:, 1:, :] + + def forward_loss(self, imgs, pred, mask): + target = self.patchify(imgs) + if self.norm_pix_loss: + mean = target.mean(axis=-1, keep_dims=True) + std = target.std(axis=-1, keepdims=True) + target = (target - mean) / std + + loss = (pred - target) ** 2 + loss = loss.mean(axis=-1) + + mask = mask.astype(loss.dtype) + loss = (loss * mask).sum() / mask.sum() + return loss + + def construct(self, imgs, mask): + bsz = imgs.shape[0] + mask = ops.reshape(mask, (bsz, -1)) + features, ids_restore = self.forward_features(imgs, mask) + pred = self.forward_decoder(features, ids_restore) + loss = self.forward_loss(imgs, pred, mask) + return loss + + def get_num_layers(self): + return len(self.blocks) + + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + +@register_model +def mae_b_16_224_pretrain(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_b_16_224_pretrain"] + model = MAEForPretrain( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, act_layer=partial(nn.GELU, approximate=False), + norm_layer=partial(nn.LayerNorm, 
epsilon=1e-6), **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + return model + + +@register_model +def mae_l_16_224_pretrain(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_l_16_224_pretrain"] + model = MAEForPretrain( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, act_layer=partial(nn.GELU, approximate=False), + norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + return model + + +@register_model +def mae_h_16_224_pretrain(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_h_16_224_pretrain"] + model = MAEForPretrain( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1280, depth=32, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, act_layer=partial(nn.GELU, approximate=False), + norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + return model + + +@register_model +def mae_b_16_224_finetune(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_b_16_224_finetune"] + model = VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + global_pool='avg', norm_layer=partial(nn.LayerNorm, epsilon=1e-6), + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + return model + + +@register_model +def mae_l_16_224_finetune(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_l_16_224_finetune"] + model = VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + global_pool='avg', norm_layer=partial(nn.LayerNorm, epsilon=1e-6), + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + return model + + +@register_model +def mae_h_14_224_finetune(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["mae_h_14_224_finetune"] + model = VisionTransformer( + image_size=224, patch_size=14, in_channels=in_channels, embed_dim=1280, depth=32, num_heads=16, + global_pool='avg', norm_layer=partial(nn.LayerNorm, epsilon=1e-6), + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) + return model diff --git a/mindcv/models/vit.py b/mindcv/models/vit.py index ac2c4c4c7..5a679df72 100644 --- a/mindcv/models/vit.py +++ b/mindcv/models/vit.py @@ -1,28 +1,30 @@ """ViT""" -from typing import List, Optional, Union +from typing import Callable, Optional import numpy as np import mindspore as ms -from mindspore import Tensor, nn -from mindspore import ops -from mindspore import ops as P -from mindspore.common.initializer import Normal, initializer -from mindspore.common.parameter import Parameter +from mindspore import Parameter, Tensor, nn, ops +from 
mindspore.common.initializer import HeUniform, TruncatedNormal, initializer -from .helpers import ConfigDict, load_pretrained +from .helpers import load_pretrained from .layers.compatibility import Dropout +from .layers.drop_path import DropPath +from .layers.mlp import Mlp +from .layers.patch_dropout import PatchDropout +from .layers.patch_embed import PatchEmbed +from .layers.pos_embed import resample_abs_pos_embed from .registry import register_model __all__ = [ - "ViT", + "VisionTransformer", "vit_b_16_224", "vit_b_16_384", - "vit_l_16_224", # train + "vit_l_16_224", # with pretrained weights "vit_l_16_384", - "vit_b_32_224", # train + "vit_b_32_224", # with pretrained weights "vit_b_32_384", - "vit_l_32_224", # train + "vit_l_32_224", # with pretrained weights ] @@ -32,7 +34,7 @@ def _cfg(url="", **kwargs): "num_classes": 1000, "input_size": (3, 224, 224), "first_conv": "patch_embed.proj", - "classifier": "classifier", + "classifier": "head.classifier", **kwargs, } @@ -42,62 +44,19 @@ def _cfg(url="", **kwargs): "vit_b_16_384": _cfg( url="", input_size=(3, 384, 384) ), - "vit_l_16_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-f02b2487.ckpt"), + "vit_l_16_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt"), "vit_l_16_384": _cfg( url="", input_size=(3, 384, 384) ), - "vit_b_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-7553218f.ckpt"), + "vit_b_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt"), "vit_b_32_384": _cfg( url="", input_size=(3, 384, 384) ), - "vit_l_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-3a961018.ckpt"), + "vit_l_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt"), } -class PatchEmbedding(nn.Cell): - """ - Path embedding layer for ViT. First rearrange b c (h p) (w p) -> b (h w) (p p c). - - Args: - image_size (int): Input image size. Default: 224. - patch_size (int): Patch size of image. Default: 16. - embed_dim (int): The dimension of embedding. Default: 768. - input_channels (int): The number of input channel. Default: 3. - - Returns: - Tensor, output tensor. - - Examples: - >>> ops = PathEmbedding(224, 16, 768, 3) - """ - - MIN_NUM_PATCHES = 4 - - def __init__( - self, - image_size: int = 224, - patch_size: int = 16, - embed_dim: int = 768, - input_channels: int = 3, - ): - super().__init__() - self.image_size = image_size - self.patch_size = patch_size - self.num_patches = (image_size // patch_size) ** 2 - self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True) - self.reshape = ops.Reshape() - self.transpose = ops.Transpose() - - def construct(self, x): - """Path Embedding construct.""" - x = self.conv(x) - b, c, h, w = x.shape - x = self.reshape(x, (b, c, h * w)) - x = self.transpose(x, (0, 2, 1)) - - return x - - +# TODO: Flash Attention class Attention(nn.Cell): """ Attention layer implementation, Rearrange Input -> B x N x hidden size. @@ -105,8 +64,10 @@ class Attention(nn.Cell): Args: dim (int): The dimension of input features. num_heads (int): The number of attention heads. Default: 8. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. - attention_keep_prob (float): The keep rate for attention. Default: 1.0. + qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. 
+ qk_norm (bool): Specifies whether to do normalization to q and k. + attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. + proj_drop (float): The drop rate of output, greater than 0 and less equal than 1. Default: 0.0. Returns: Tensor, output tensor. @@ -114,23 +75,29 @@ class Attention(nn.Cell): Examples: >>> ops = Attention(768, 12) """ - def __init__( self, dim: int, num_heads: int = 8, - keep_prob: float = 1.0, - attention_keep_prob: float = 1.0, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Cell = nn.LayerNorm, ): - super().__init__() + super(Attention, self).__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = Tensor(head_dim**-0.5) + self.head_dim = dim // num_heads + self.scale = Tensor(self.head_dim ** -0.5) - self.qkv = nn.Dense(dim, dim * 3) - self.attn_drop = Dropout(p=1.0-attention_keep_prob) - self.out = nn.Dense(dim, dim) - self.out_drop = Dropout(p=1.0-keep_prob) + self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias) + self.q_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity() + self.k_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity() + + self.attn_drop = Dropout(attn_drop) + self.proj = nn.Dense(dim, dim) + self.proj_drop = Dropout(proj_drop) self.mul = ops.Mul() self.reshape = ops.Reshape() @@ -138,142 +105,74 @@ def __init__( self.unstack = ops.Unstack(axis=0) self.attn_matmul_v = ops.BatchMatMul() self.q_matmul_k = ops.BatchMatMul(transpose_b=True) - self.softmax = nn.Softmax(axis=-1) def construct(self, x): - """Attention construct.""" b, n, c = x.shape qkv = self.qkv(x) - qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads)) + qkv = self.reshape(qkv, (b, n, 3, self.num_heads, self.head_dim)) qkv = self.transpose(qkv, (2, 0, 3, 1, 4)) q, k, v = self.unstack(qkv) + q, k = self.q_norm(q), self.k_norm(k) attn = self.q_matmul_k(q, k) attn = self.mul(attn, self.scale) - attn = self.softmax(attn) + + attn = attn.astype(ms.float32) + attn = ops.softmax(attn, axis=-1) attn = self.attn_drop(attn) out = self.attn_matmul_v(attn, v) out = self.transpose(out, (0, 2, 1, 3)) out = self.reshape(out, (b, n, c)) - out = self.out(out) - out = self.out_drop(out) + out = self.proj(out) + out = self.proj_drop(out) return out -class FeedForward(nn.Cell): +class LayerScale(nn.Cell): """ - Feed Forward layer implementation. + Layer scale, help ViT improve the training dynamic, allowing for the training + of deeper high-capacity image transformers that benefit from depth Args: - in_features (int): The dimension of input features. - hidden_features (int): The dimension of hidden features. Default: None. - out_features (int): The dimension of output features. Default: None - activation (nn.Cell): Activation function which will be stacked on top of the - normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. + dim (int): The output dimension of attnetion layer or mlp layer. + init_values (float): The scale factor. Default: 1e-5. Returns: Tensor, output tensor. 
Examples: - >>> ops = FeedForward(768, 3072) + >>> ops = LayerScale(768, 0.01) """ - def __init__( self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - activation: nn.Cell = nn.GELU, - keep_prob: float = 1.0, + dim: int, + init_values: float = 1e-5 ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.dense1 = nn.Dense(in_features, hidden_features) - self.activation = activation() - self.dense2 = nn.Dense(hidden_features, out_features) - self.dropout = Dropout(p=1.0-keep_prob) - - def construct(self, x): - """Feed Forward construct.""" - x = self.dense1(x) - x = self.activation(x) - x = self.dropout(x) - x = self.dense2(x) - x = self.dropout(x) - - return x - - -class ResidualCell(nn.Cell): - """ - Cell which implements Residual function: - - $$output = x + f(x)$$ - - Args: - cell (Cell): Cell needed to add residual block. - - Returns: - Tensor, output tensor. - - Examples: - >>> ops = ResidualCell(nn.Dense(3,4)) - """ - - def __init__(self, cell): - super().__init__() - self.cell = cell - - def construct(self, x): - """ResidualCell construct.""" - return self.cell(x) + x - - -class DropPath(nn.Cell): - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - """ - - def __init__(self, keep_prob=None, seed=0): - super().__init__() - self.keep_prob = 1 - keep_prob - seed = min(seed, 0) - self.rand = P.UniformReal(seed=seed) - self.shape = P.Shape() - self.floor = P.Floor() + super(LayerScale, self).__init__() + self.gamma = Parameter(initializer(init_values, dim)) def construct(self, x): - if self.training: - x_shape = self.shape(x) - random_tensor = self.rand((x_shape[0], 1, 1)) - random_tensor = random_tensor + self.keep_prob - random_tensor = self.floor(random_tensor) - x = x / self.keep_prob - x = x * random_tensor - - return x + return self.gamma * x -class TransformerEncoder(nn.Cell): +class Block(nn.Cell): """ - TransformerEncoder implementation. + Transformer block implementation. Args: dim (int): The dimension of embedding. - num_layers (int): The depth of transformer. num_heads (int): The number of attention heads. - mlp_dim (int): The dimension of MLP hidden layer. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. - attention_keep_prob (float): The keep rate for attention. Default: 1.0. - drop_path_keep_prob (float): The keep rate for drop path. Default: 1.0. - activation (nn.Cell): Activation function which will be stacked on top of the - normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. - norm (nn.Cell, optional): Norm layer that will be stacked on top of the convolution - layer. Default: nn.LayerNorm. + qkv_bias (bool): Specifies whether the linear layer uses a bias vector. Default: True. + attn_drop (float): The drop rate of attention, greater than 0 and less equal than 1. Default: 0.0. + proj_drop (float): The drop rate of dense layer output, greater than 0 and less equal than 1. Default: 0.0. + mlp_ratio (float): The ratio used to scale the input dimensions to obtain the dimensions of the hidden layer. + drop_path (float): The drop rate for drop path. Default: 0.0. + act_layer (nn.Cell): Activation function which will be stacked on top of the + normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. + norm_layer (nn.Cell): Norm layer that will be stacked on top of the convolution + layer. 
Default: nn.LayerNorm. Returns: Tensor, output tensor. @@ -281,644 +180,322 @@ class TransformerEncoder(nn.Cell): Examples: >>> ops = TransformerEncoder(768, 12, 12, 3072) """ - def __init__( self, dim: int, - num_layers: int, - num_heads: int, - mlp_dim: int, - keep_prob: float = 1.0, - attention_keep_prob: float = 1.0, - drop_path_keep_prob: float = 1.0, - activation: nn.Cell = nn.GELU, - norm: nn.Cell = nn.LayerNorm, + num_heads: int = 8, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + init_values: Optional[float] = None, + drop_path: float = 0., + act_layer: nn.Cell = nn.GELU, + norm_layer: nn.Cell = nn.LayerNorm, + mlp_layer: Callable = Mlp, ): - super().__init__() - drop_path_rate = 1 - drop_path_keep_prob - dpr = [i.item() for i in np.linspace(0, drop_path_rate, num_layers)] - attn_seeds = [np.random.randint(1024) for _ in range(num_layers)] - mlp_seeds = [np.random.randint(1024) for _ in range(num_layers)] - - layers = [] - for i in range(num_layers): - normalization1 = norm((dim,)) - normalization2 = norm((dim,)) - attention = Attention(dim=dim, - num_heads=num_heads, - keep_prob=keep_prob, - attention_keep_prob=attention_keep_prob) - - feedforward = FeedForward(in_features=dim, - hidden_features=mlp_dim, - activation=activation, - keep_prob=keep_prob) - - if drop_path_rate > 0: - layers.append( - nn.SequentialCell([ - ResidualCell(nn.SequentialCell([normalization1, - attention, - DropPath(dpr[i], attn_seeds[i])])), - ResidualCell(nn.SequentialCell([normalization2, - feedforward, - DropPath(dpr[i], mlp_seeds[i])]))])) - else: - layers.append( - nn.SequentialCell([ - ResidualCell(nn.SequentialCell([normalization1, - attention])), - ResidualCell(nn.SequentialCell([normalization2, - feedforward])) - ]) - ) - self.layers = nn.SequentialCell(layers) - - def construct(self, x): - """Transformer construct.""" - return self.layers(x) - - -class DenseHead(nn.Cell): - """ - LinearClsHead architecture. - - Args: - input_channel (int): The number of input channel. - num_classes (int): Number of classes. - has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. - activation (Union[str, Cell, Primitive]): activate function applied to the output. Eg. `ReLU`. Default: None. - keep_prob (float): Dropout keeping rate, between [0, 1]. E.g. rate=0.9, means dropping out 10% of input. - Default: 1.0. - - Returns: - Tensor, output tensor. - """ - - def __init__( - self, - input_channel: int, - num_classes: int, - has_bias: bool = True, - activation: Optional[Union[str, nn.Cell]] = None, - keep_prob: float = 1.0, - ) -> None: - super().__init__() - - self.dropout = Dropout(p=1.0-keep_prob) - self.classifier = nn.Dense(input_channel, num_classes, has_bias=has_bias, activation=activation) + super(Block, self).__init__() + self.norm1 = norm_layer((dim,)) + self.attn = Attention( + dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer((dim,)) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop + ) + self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def construct(self, x): - if self.training: - x = self.dropout(x) - x = self.classifier(x) + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x -class MultilayerDenseHead(nn.Cell): - """ - MultilayerDenseHead architecture. - - Args: - input_channel (int): The number of input channel. - num_classes (int): Number of classes. - mid_channel (list): Number of channels in the hidden fc layers. - keep_prob (list): Dropout keeping rate, between [0, 1]. E.g. rate=0.9, means dropping out 10% of - input. - activation (list): activate function applied to the output. Eg. `ReLU`. - - Returns: - Tensor, output tensor. - """ - +class VisionTransformer(nn.Cell): + ''' + ViT encoder, which returns the feature encoded by transformer encoder. + ''' def __init__( self, - input_channel: int, - num_classes: int, - mid_channel: List[int], - keep_prob: List[float], - activation: List[Optional[Union[str, nn.Cell]]], - ) -> None: - super().__init__() - mid_channel.append(num_classes) - assert len(mid_channel) == len(activation) == len(keep_prob), "The length of the list should be the same." - - length = len(activation) - head = [] - - for i in range(length): - linear = DenseHead( - input_channel, - mid_channel[i], - activation=activation[i], - keep_prob=keep_prob[i], + image_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + global_pool: str = 'token', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_norm: bool = False, + drop_rate: float = 0., + pos_drop_rate: float = 0., + patch_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + weight_init: bool = True, + init_values: Optional[float] = None, + no_embed_class: bool = False, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + act_layer: nn.Cell = nn.GELU, + embed_layer: Callable = PatchEmbed, + norm_layer: nn.Cell = nn.LayerNorm, + mlp_layer: Callable = Mlp, + class_token: bool = True, + block_fn: Callable = Block, + num_classes: int = 1000, + ): + super(VisionTransformer, self).__init__() + assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + + self.global_pool = global_pool + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class + self.dynamic_img_size = dynamic_img_size + self.dynamic_img_pad = dynamic_img_pad + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) + elif dynamic_img_pad: + embed_args.update(dict(output_fmt='NHWC')) + + self.patch_embed = embed_layer( + image_size=image_size, + patch_size=patch_size, + in_chans=in_channels, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = Parameter(initializer(TruncatedNormal(0.02), (1, embed_len, embed_dim))) + self.pos_drop = Dropout(pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, ) - head.append(linear) - input_channel = mid_channel[i] - - self.classifier = nn.SequentialCell(head) - - def construct(self, x): - x = self.classifier(x) - - return x - - -class BaseClassifier(nn.Cell): - """ - generate classifier to combine the backbone and head - """ - - def __init__(self, backbone, neck=None, head=None): - super().__init__() - self.backbone = backbone - if neck: - self.neck = neck - self.with_neck = True else: - self.with_neck = False - if head: - self.head = head - self.with_head = True + self.patch_drop = nn.Identity() + + self.norm_pre = norm_layer((embed_dim,)) if pre_norm else nn.Identity() + dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)] + self.blocks = nn.CellList([ + block_fn( + dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, + mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values, + act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer, + ) for i in range(depth) + ]) + + self.norm = norm_layer((embed_dim,)) if not use_fc_norm else nn.Identity() + self.fc_norm = norm_layer((embed_dim,)) if use_fc_norm else nn.Identity() + self.head_drop = Dropout(drop_rate) + self.head = nn.Dense(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init: + self._init_weights() + + def get_num_layers(self): + return len(self.blocks) + + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def _init_weights(self): + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.weight.set_data( + initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer('zeros', cell.bias.shape, cell.bias.dtype) + ) + elif isinstance(cell, nn.LayerNorm): + cell.gamma.set_data( + initializer('ones', cell.gamma.shape, cell.gamma.dtype) + ) + cell.beta.set_data( + initializer('zeros', cell.beta.shape, cell.beta.dtype) + ) + elif isinstance(cell, nn.Conv2d): + cell.weight.set_data( + initializer(HeUniform(), cell.weight.shape, cell.weight.dtype) + ) + if cell.bias is not None: + cell.bias.set_data( + initializer("zeros", cell.bias.shape, cell.bias.dtype) + ) + + def _pos_embed(self, x): + if self.dynamic_img_size or self.dynamic_img_pad: + # bhwc format + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = ops.reshape(x, (B, -1, C)) else: - self.with_head = False - - def forward_features(self, x: Tensor) -> Tensor: - x = self.backbone(x) + pos_embed = self.pos_embed + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if self.cls_token is not None: + cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1)) + cls_tokens = cls_tokens.astype(x.dtype) + x = 
ops.concat((cls_tokens, x), axis=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1)) + cls_tokens = cls_tokens.astype(x.dtype) + x = ops.concat((cls_tokens, x), axis=1) + x = x + pos_embed + + return self.pos_drop(x) + + def forward_features(self, x): + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) return x - def forward_head(self, x: Tensor) -> Tensor: + def forward_head(self, x): + if self.global_pool: + x = x[:, self.num_prefix_tokens:].mean(axis=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + x = self.head_drop(x) x = self.head(x) return x def construct(self, x): x = self.forward_features(x) - if self.with_neck: - x = self.neck(x) - if self.with_head: - x = self.forward_head(x) + x = self.forward_head(x) return x -def init(init_type, shape, dtype, name, requires_grad): - initial = initializer(init_type, shape, dtype).init_data() - return Parameter(initial, name=name, requires_grad=requires_grad) - - -class ViT(nn.Cell): - """ - Vision Transformer architecture implementation. - - Args: - image_size (int): Input image size. Default: 224. - input_channels (int): The number of input channel. Default: 3. - patch_size (int): Patch size of image. Default: 16. - embed_dim (int): The dimension of embedding. Default: 768. - num_layers (int): The depth of transformer. Default: 12. - num_heads (int): The number of attention heads. Default: 12. - mlp_dim (int): The dimension of MLP hidden layer. Default: 3072. - keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0. - attention_keep_prob (float): The keep rate for attention layer. Default: 1.0. - drop_path_keep_prob (float): The keep rate for drop path. Default: 1.0. - activation (nn.Cell): Activation function which will be stacked on top of the - normalization layer (if not None), otherwise on top of the conv layer. Default: nn.GELU. - norm (nn.Cell, optional): Norm layer that will be stacked on top of the convolution - layer. Default: nn.LayerNorm. - pool (str): The method of pooling. Default: 'cls'. - - Inputs: - - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - - Outputs: - Tensor of shape :math:`(N, 768)` - - Raises: - ValueError: If `split` is not 'train', 'test' or 'infer'. - - Supported Platforms: - ``GPU`` - - Examples: - >>> net = ViT() - >>> x = ms.Tensor(np.ones([1, 3, 224, 224]), ms.float32) - >>> output = net(x) - >>> print(output.shape) - (1, 768) - - About ViT: - - Vision Transformer (ViT) shows that a pure transformer applied directly to sequences of image - patches can perform very well on image classification tasks. When pre-trained on large amounts - of data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, - CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent results compared to state-of-the-art - convolutional networks while requiring substantially fewer computational resources to train. - - Citation: - - .. code-block:: - - @article{2020An, - title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale}, - author={Dosovitskiy, A. and Beyer, L. and Kolesnikov, A. and Weissenborn, D. 
and Houlsby, N.}, - year={2020}, - } - """ - - def __init__( - self, - image_size: int = 224, - input_channels: int = 3, - patch_size: int = 16, - embed_dim: int = 768, - num_layers: int = 12, - num_heads: int = 12, - mlp_dim: int = 3072, - keep_prob: float = 1.0, - attention_keep_prob: float = 1.0, - drop_path_keep_prob: float = 1.0, - activation: nn.Cell = nn.GELU, - norm: Optional[nn.Cell] = nn.LayerNorm, - pool: str = "cls", - ) -> None: - super().__init__() - - self.patch_embedding = PatchEmbedding(image_size=image_size, - patch_size=patch_size, - embed_dim=embed_dim, - input_channels=input_channels) - num_patches = self.patch_embedding.num_patches - - if pool == "cls": - self.cls_token = init(init_type=Normal(sigma=1.0), - shape=(1, 1, embed_dim), - dtype=ms.float32, - name="cls", - requires_grad=True) - self.pos_embedding = init(init_type=Normal(sigma=1.0), - shape=(1, num_patches + 1, embed_dim), - dtype=ms.float32, - name="pos_embedding", - requires_grad=True) - self.concat = ops.Concat(axis=1) - else: - self.pos_embedding = init(init_type=Normal(sigma=1.0), - shape=(1, num_patches, embed_dim), - dtype=ms.float32, - name="pos_embedding", - requires_grad=True) - self.mean = ops.ReduceMean(keep_dims=False) - - self.pool = pool - self.pos_dropout = Dropout(p=1.0-keep_prob) - self.norm = norm((embed_dim,)) - self.tile = ops.Tile() - self.transformer = TransformerEncoder( - dim=embed_dim, - num_layers=num_layers, - num_heads=num_heads, - mlp_dim=mlp_dim, - keep_prob=keep_prob, - attention_keep_prob=attention_keep_prob, - drop_path_keep_prob=drop_path_keep_prob, - activation=activation, - norm=norm, - ) - - def construct(self, x): - """ViT construct.""" - x = self.patch_embedding(x) - - if self.pool == "cls": - cls_tokens = self.tile(self.cls_token, (x.shape[0], 1, 1)) - x = self.concat((cls_tokens, x)) - x += self.pos_embedding - else: - x += self.pos_embedding - x = self.pos_dropout(x) - x = self.transformer(x) - x = self.norm(x) +@register_model +def vit_b_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_b_16_224"] + model = VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - if self.pool == "cls": - x = x[:, 0] - else: - x = self.mean(x, (1, )) # (1,) or (1,2) - return x + return model -def vit( - image_size: int, - input_channels: int, - patch_size: int, - embed_dim: int, - num_layers: int, - num_heads: int, - num_classes: int, - mlp_dim: int, - dropout: float = 0.0, - attention_dropout: float = 0.0, - drop_path_rate: float = 0.0, - activation: nn.Cell = nn.GELU, - norm: nn.Cell = nn.LayerNorm, - pool: str = "cls", - representation_size: Optional[int] = None, - pretrained: bool = False, - url_cfg: dict = None, -) -> ViT: - """Vision Transformer architecture.""" - backbone = ViT( - image_size=image_size, - input_channels=input_channels, - patch_size=patch_size, - embed_dim=embed_dim, - num_layers=num_layers, - num_heads=num_heads, - mlp_dim=mlp_dim, - keep_prob=1.0 - dropout, - attention_keep_prob=1.0 - attention_dropout, - drop_path_keep_prob=1.0 - drop_path_rate, - activation=activation, - norm=norm, - pool=pool, +@register_model +def vit_b_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_b_16_384"] + model = 
VisionTransformer( + image_size=384, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12, + num_classes=num_classes, **kwargs ) - if representation_size: - head = MultilayerDenseHead( - input_channel=embed_dim, - num_classes=num_classes, - mid_channel=[representation_size], - activation=["tanh", None], - keep_prob=[1.0, 1.0], - ) - else: - head = DenseHead(input_channel=embed_dim, num_classes=num_classes) - - model = BaseClassifier(backbone=backbone, head=head) - if pretrained: - # Download the pre-trained checkpoint file from url, and load ckpt file. - load_pretrained(model, url_cfg, num_classes=num_classes, in_channels=input_channels) + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model -def vit_b_16_224( - pretrained: bool = False, - num_classes: int = 1000, - in_channels: int = 3, - image_size: int = 224, - has_logits: bool = False, - drop_rate: float = 0.0, - # attention-dropout: float = 0.0, - drop_path_rate: float = 0.0, -) -> ViT: - """ - Constructs a vit_b_16 architecture from - `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. - - Args: - pretrained (bool): Whether to download and load the pre-trained model. Default: False. - num_classes (int): The number of classification. Default: 1000. - in_channels (int): The number of input channels. Default: 3. - image_size (int): The input image size. Default: 224 for ImageNet. - has_logits (bool): Whether has logits or not. Default: False. - drop_rate (float): The drop out rate. Default: 0.0.s - drop_path_rate (float): The stochastic depth rate. Default: 0.0. - - Returns: - ViT network, MindSpore.nn.Cell +def vit_l_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs): + default_cfg = default_cfgs["vit_l_16_224"] + model = VisionTransformer( + image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16, + num_classes=num_classes, **kwargs + ) + if pretrained: + load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - Inputs: - - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. 
+    return model
-    Examples:
-        >>> net = vit_b_16_224()
-        >>> x = ms.Tensor(np.ones([1, 3, 224, 224]), ms.float32)
-        >>> output = net(x)
-        >>> print(output.shape)
-        (1, 1000)
-    Outputs:
-        Tensor of shape :math:`(N, CLASSES_{out})`
+@register_model
+def vit_l_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
+    default_cfg = default_cfgs["vit_l_16_384"]
+    model = VisionTransformer(
+        image_size=384, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16,
+        num_classes=num_classes, **kwargs
+    )
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
-    Supported Platforms:
-        ``GPU``
-    """
-    config = ConfigDict()
-    config.image_size = image_size
-    config.num_classes = num_classes
-    config.patch_size = 16
-    config.embed_dim = 768
-    config.mlp_dim = 3072
-    config.num_heads = 12
-    config.num_layers = 12
-    config.dropout = drop_rate
-    config.attention_dropout = drop_rate  # attention-dropout
-    config.drop_path_rate = drop_path_rate
-    config.pretrained = pretrained
-    config.input_channels = in_channels
-    config.pool = "cls"
-    config.representation_size = 768 if has_logits else None
-
-    config.url_cfg = default_cfgs["vit_b_16_224"]
-
-    return vit(**config)
+    return model
 @register_model
-def vit_b_16_384(
-    pretrained: bool = False,
-    num_classes: int = 1000,
-    in_channels: int = 3,
-    image_size: int = 384,
-    has_logits: bool = False,
-    drop_rate: float = 0.0,
-    # attention-dropout: float = 0.0,
-    drop_path_rate: float = 0.0,
-) -> ViT:
-    """construct and return a ViT network"""
-    config = ConfigDict()
-    config.image_size = image_size
-    config.num_classes = num_classes
-    config.patch_size = 16
-    config.embed_dim = 768
-    config.mlp_dim = 3072
-    config.num_heads = 12
-    config.num_layers = 12
-    config.dropout = drop_rate
-    config.attention_dropout = drop_rate  # attention-dropout
-    config.drop_path_rate = drop_path_rate
-    config.pretrained = pretrained
-    config.input_channels = in_channels
-    config.pool = "cls"
-    config.representation_size = 768 if has_logits else None
-
-    config.url_cfg = default_cfgs["vit_b_16_384"]
-
-    return vit(**config)
-
+def vit_b_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
+    default_cfg = default_cfgs["vit_b_32_224"]
+    model = VisionTransformer(
+        image_size=224, patch_size=32, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12,
+        num_classes=num_classes, **kwargs
+    )
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
-@register_model
-def vit_l_16_224(
-    pretrained: bool = False,
-    num_classes: int = 1000,
-    in_channels: int = 3,
-    image_size: int = 224,
-    has_logits: bool = False,
-    drop_rate: float = 0.0,
-    # attention-dropout: float = 0.0,
-    drop_path_rate: float = 0.0,
-) -> ViT:
-    """construct and return a ViT network"""
-
-    config = ConfigDict()
-    config.image_size = image_size
-    config.num_classes = num_classes
-    config.patch_size = 16
-    config.embed_dim = 1024
-    config.mlp_dim = 4096
-    config.num_heads = 16
-    config.num_layers = 24
-    config.dropout = drop_rate
-    config.attention_dropout = drop_rate  # attention-dropout
-    config.drop_path_rate = drop_path_rate
-    config.input_channels = in_channels
-    config.pool = "cls"
-    config.pretrained = pretrained
-    config.representation_size = 1024 if has_logits else None
-
-    config.url_cfg = default_cfgs["vit_l_16_224"]
-
-    return vit(**config)
+    return model
 @register_model
-def vit_l_16_384(
-    pretrained: bool = False,
-    num_classes: int = 1000,
-    in_channels: int = 3,
-    image_size: int = 384,
-    has_logits: bool = False,
-    drop_rate: float = 0.0,
-    # attention-dropout: float = 0.0,
-    drop_path_rate: float = 0.0,
-) -> ViT:
-    """construct and return a ViT network"""
-
-    config = ConfigDict()
-    config.image_size = image_size
-    config.num_classes = num_classes
-    config.patch_size = 16
-    config.embed_dim = 1024
-    config.mlp_dim = 4096
-    config.num_heads = 16
-    config.num_layers = 24
-    config.dropout = drop_rate
-    config.attention_dropout = drop_rate  # attention-dropout
-    config.drop_path_rate = drop_path_rate
-    config.input_channels = in_channels
-    config.pool = "cls"
-    config.pretrained = pretrained
-    config.representation_size = 1024 if has_logits else None
-
-    config.url_cfg = default_cfgs["vit_l_16_384"]
-
-    return vit(**config)
-
+def vit_b_32_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
+    default_cfg = default_cfgs["vit_b_32_384"]
+    model = VisionTransformer(
+        image_size=384, patch_size=32, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12,
+        num_classes=num_classes, **kwargs
+    )
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
-@register_model
-def vit_b_32_224(
-    pretrained: bool = False,
-    num_classes: int = 1000,
-    in_channels: int = 3,
-    image_size: int = 224,
-    has_logits: bool = False,
-    drop_rate: float = 0.0,
-    # attention-dropout: float = 0.0,
-    drop_path_rate: float = 0.0,
-) -> ViT:
-    """construct and return a ViT network"""
-    config = ConfigDict()
-    config.image_size = image_size
-    config.num_classes = num_classes
-    config.patch_size = 32
-    config.embed_dim = 768
-    config.mlp_dim = 3072
-    config.num_heads = 12
-    config.num_layers = 12
-    config.dropout = drop_rate
-    config.attention_dropout = drop_rate  # attention-dropout
-    config.drop_path_rate = drop_path_rate
-    config.pretrained = pretrained
-    config.input_channels = in_channels
-    config.pool = "cls"
-    config.representation_size = 768 if has_logits else None
-
-    config.url_cfg = default_cfgs["vit_b_32_224"]
-
-    return vit(**config)
+    return model
 @register_model
-def vit_b_32_384(
-    pretrained: bool = False,
-    num_classes: int = 1000,
-    in_channels: int = 3,
-    image_size: int = 384,
-    has_logits: bool = False,
-    drop_rate: float = 0.0,
-    # attention_dropout: float = 0.0,
-    drop_path_rate: float = 0.0,
-) -> ViT:
-    """construct and return a ViT network"""
-    config = ConfigDict()
-    config.image_size = image_size
-    config.num_classes = num_classes
-    config.patch_size = 32
-    config.embed_dim = 768
-    config.mlp_dim = 3072
-    config.num_heads = 12
-    config.num_layers = 12
-    config.dropout = drop_rate
-    config.attention_dropout = drop_rate  # attention_dropout
-    config.drop_path_rate = drop_path_rate
-    config.pretrained = pretrained
-    config.input_channels = in_channels
-    config.pool = "cls"
-    config.representation_size = 768 if has_logits else None
-
-    config.url_cfg = default_cfgs["vit_b_32_384"]
-
-    return vit(**config)
-
+def vit_l_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
+    default_cfg = default_cfgs["vit_l_32_224"]
+    model = VisionTransformer(
+        image_size=224, patch_size=32, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16,
+        num_classes=num_classes, **kwargs
+    )
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
-@register_model
-def vit_l_32_224(
-    pretrained: bool = False,
-    num_classes: int = 1000,
-    in_channels: int = 3,
-    image_size: int = 224,
-    has_logits: bool = False,
-    drop_rate: float = 0.0,
-    # attention-dropout: float = 0.0,
-    drop_path_rate: float = 0.0,
-) -> ViT:
-    """construct and return a ViT network"""
-    config = ConfigDict()
-    config.image_size = image_size
-    config.num_classes = num_classes
-    config.patch_size = 32
-    config.embed_dim = 1024
-    config.mlp_dim = 4096
-    config.num_heads = 16
-    config.num_layers = 24
-    config.dropout = drop_rate
-    config.attention_dropout = drop_rate  # attention-dropout
-    config.drop_path_rate = drop_path_rate
-    config.pretrained = pretrained
-    config.input_channels = in_channels
-    config.pool = "cls"
-    config.representation_size = 1024 if has_logits else None
-
-    config.url_cfg = default_cfgs["vit_l_32_224"]
-
-    return vit(**config)
+    return model
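
The refactored entry points keep the calling pattern shown in the doctest removed above. A minimal usage sketch follows; it assumes `vit_b_32_224` stays importable from `mindcv.models` (consistent with the registry-style imports elsewhere in this patch) and is illustrative only, not part of the diff.

```python
import numpy as np
import mindspore as ms

from mindcv.models import vit_b_32_224  # registered constructor added in this hunk (assumed export path)

# Build the ViT-B/32 backbone without pretrained weights and run a dummy batch.
net = vit_b_32_224(pretrained=False, num_classes=1000, in_channels=3)
net.set_train(False)

x = ms.Tensor(np.ones([1, 3, 224, 224]), ms.float32)
y = net(x)
print(y.shape)  # expected: (1, 1000)
```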