extend vit and add mae model and finetune checkpoint file #707

Merged: 6 commits, Oct 25, 2023
6 changes: 3 additions & 3 deletions configs/vit/README.md
@@ -36,9 +36,9 @@ Our reproduced model performance on ImageNet-1K is reported as follows.

| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
|--------------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------|
| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-7553218f.ckpt) |
| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-f02b2487.ckpt) |
| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-3a961018.ckpt) |
| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt) |
| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt) |
| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt) |

</div>

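For reference, the refreshed checkpoints in the table above are loaded through the usual model factory; a minimal sketch (ImageNet-1K head assumed, model name taken from the table row):

```python
import mindcv

# Sketch: build vit_b_32_224 and download the pretrained ImageNet-1K weights
# referenced in the README table above.
model = mindcv.create_model("vit_b_32_224", num_classes=1000, pretrained=True)
model.set_train(False)  # switch to evaluation mode for inference
```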
3 changes: 3 additions & 0 deletions mindcv/models/__init__.py
@@ -18,6 +18,7 @@
inceptionv3,
inceptionv4,
layers,
mae,
mixnet,
mlpmixer,
mnasnet,
@@ -74,6 +75,7 @@
from .inceptionv3 import *
from .inceptionv4 import *
from .layers import *
from .mae import *
from .mixnet import *
from .mlpmixer import *
from .mnasnet import *
@@ -132,6 +134,7 @@
__all__.extend(["InceptionV3", "inception_v3"])
__all__.extend(["InceptionV4", "inception_v4"])
__all__.extend(layers.__all__)
__all__.extend(mae.__all__)
__all__.extend(mixnet.__all__)
__all__.extend(mlpmixer.__all__)
__all__.extend(mnasnet.__all__)
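Since the new `mae` module is wired into the package exports above, its architectures should show up in the model registry; an illustrative check (the exact names depend on what `mae.py` registers):

```python
import mindcv

# Illustrative: list registered model names contributed by the new mae module.
mae_models = [name for name in mindcv.list_models() if name.startswith("mae")]
print(mae_models)
```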
16 changes: 15 additions & 1 deletion mindcv/models/layers/__init__.py
@@ -1,10 +1,24 @@
"""layers init"""
from . import activation, conv_norm_act, drop_path, identity, pooling, selective_kernel, squeeze_excite
from . import (
activation,
conv_norm_act,
drop_path,
format,
identity,
patch_dropout,
pooling,
pos_embed,
selective_kernel,
squeeze_excite,
)
from .activation import *
from .conv_norm_act import *
from .drop_path import *
from .format import *
from .identity import *
from .patch_dropout import *
from .pooling import *
from .pos_embed import *
from .selective_kernel import *
from .squeeze_excite import *

34 changes: 34 additions & 0 deletions mindcv/models/layers/format.py
@@ -0,0 +1,34 @@
from enum import Enum
from typing import Union

import mindspore


class Format(str, Enum):
NCHW = 'NCHW'
NHWC = 'NHWC'
NCL = 'NCL'
NLC = 'NLC'


FormatT = Union[str, Format]


def nchw_to(x: mindspore.Tensor, fmt: Format):
if fmt == Format.NHWC:
x = x.permute(0, 2, 3, 1)
elif fmt == Format.NLC:
x = x.flatten(start_dim=2).transpose((0, 2, 1))
elif fmt == Format.NCL:
x = x.flatten(start_dim=2)
return x


def nhwc_to(x: mindspore.Tensor, fmt: Format):
if fmt == Format.NCHW:
x = x.permute(0, 3, 1, 2)
elif fmt == Format.NLC:
x = x.flatten(start_dim=1, end_dim=2)
elif fmt == Format.NCL:
x = x.flatten(start_dim=1, end_dim=2).transpose((0, 2, 1))
return x
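A quick illustration of the new layout helpers (shapes arbitrary): converting an NCHW feature map into the NLC token layout consumed by transformer blocks.

```python
import numpy as np
import mindspore as ms
from mindcv.models.layers.format import Format, nchw_to

# Illustrative only: a [B, C, H, W] feature map flattened to [B, H*W, C] tokens.
x = ms.Tensor(np.random.rand(2, 96, 14, 14), ms.float32)
tokens = nchw_to(x, Format.NLC)
print(tokens.shape)  # (2, 196, 96)
```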
54 changes: 54 additions & 0 deletions mindcv/models/layers/patch_dropout.py
@@ -0,0 +1,54 @@
import numpy as np

import mindspore as ms
from mindspore import nn, ops


class PatchDropout(nn.Cell):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(
self,
prob: float = 0.5,
num_prefix_tokens: int = 1,
ordered: bool = False,
return_indices: bool = False,
):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
self.ordered = ordered
self.return_indices = return_indices
self.sort = ops.Sort()

def construct(self, x):
if not self.training or self.prob == 0.:
if self.return_indices:
return x, None
return x

if self.num_prefix_tokens:
prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
else:
prefix_tokens = None

B = x.shape[0]
L = x.shape[1]
num_keep = max(1, int(L * (1. - self.prob)))
_, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32))
keep_indices = indices[:, :num_keep]
if self.ordered:
# NOTE does not need to maintain patch order in typical transformer use,
# but possibly useful for debug / visualization
keep_indices, _ = self.sort(keep_indices)
keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2]))
x = ops.gather_elements(x, dim=1, index=keep_indices)

if prefix_tokens is not None:
x = ops.concat((prefix_tokens, x), axis=1)

if self.return_indices:
return x, keep_indices
return x
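A hedged usage sketch of `PatchDropout` (token counts arbitrary): the CLS token is kept and roughly half of the patch tokens are dropped, but only while the cell is in training mode.

```python
import numpy as np
import mindspore as ms
from mindcv.models.layers.patch_dropout import PatchDropout

# Sketch: [B, 1 + L, D] tokens with a single CLS prefix token.
drop = PatchDropout(prob=0.5, num_prefix_tokens=1)
drop.set_train(True)  # dropout is a no-op outside training mode
tokens = ms.Tensor(np.random.rand(2, 197, 768), ms.float32)
kept = drop(tokens)
print(kept.shape)  # (2, 99, 768): 1 CLS token + 98 of 196 patch tokens kept
```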
65 changes: 50 additions & 15 deletions mindcv/models/layers/patch_embed.py
@@ -4,6 +4,7 @@

from mindspore import Tensor, nn, ops

from .format import Format, nchw_to
from .helpers import to_2tuple


@@ -17,29 +18,45 @@ class PatchEmbed(nn.Cell):
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Cell, optional): Normalization layer. Default: None
"""
output_fmt: Format

def __init__(
self,
image_size: int = 224,
image_size: Optional[int] = 224,
patch_size: int = 4,
in_chans: int = 3,
embed_dim: int = 96,
norm_layer: Optional[nn.Cell] = None,
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
strict_img_size: bool = True,
dynamic_img_pad: bool = False,
) -> None:
super().__init__()
image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]]
self.image_size = image_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]

self.in_chans = in_chans
self.patch_size = to_2tuple(patch_size)
if image_size is not None:
self.image_size = to_2tuple(image_size)
self.patches_resolution = tuple([s // p for s, p in zip(self.image_size, self.patch_size)])
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
else:
self.image_size = None
self.patches_resolution = None
self.num_patches = None

if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
else:
self.flatten = flatten
self.output_fmt = Format.NCHW

self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.embed_dim = embed_dim

self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size,
pad_mode='pad', has_bias=True, weight_init="TruncatedNormal")
pad_mode='pad', has_bias=bias, weight_init="TruncatedNormal")

if norm_layer is not None:
if isinstance(embed_dim, int):
@@ -50,11 +67,29 @@ def __init__(

def construct(self, x: Tensor) -> Tensor:
"""docstring"""
B = x.shape[0]
# FIXME look at relaxing size constraints
x = ops.Reshape()(self.proj(x), (B, self.embed_dim, -1)) # B Ph*Pw C
x = ops.Transpose()(x, (0, 2, 1))
B, C, H, W = x.shape
if self.image_size is not None:
if self.strict_img_size:
if (H, W) != (self.image_size[0], self.image_size[1]):
raise ValueError(f"Input height and width ({H},{W}) doesn't match model ({self.image_size[0]},"
f"{self.image_size[1]}).")
elif not self.dynamic_img_pad:
if H % self.patch_size[0] != 0:
raise ValueError(f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).")
if W % self.patch_size[1] != 0:
raise ValueError(f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).")
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = ops.pad(x, (0, pad_w, 0, pad_h))

# FIXME look at relaxing size constraints
x = self.proj(x)
if self.flatten:
x = ops.Reshape()(x, (B, self.embed_dim, -1)) # B Ph*Pw C
x = ops.Transpose()(x, (0, 2, 1))
elif self.output_fmt != "NCHW":
x = nchw_to(x, self.output_fmt)
if self.norm is not None:
x = self.norm(x)
return x
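An illustrative instantiation of the extended `PatchEmbed` (values arbitrary): with `dynamic_img_pad` the input is padded up to a multiple of the patch size, and `output_fmt` keeps the spatial grid instead of flattening to tokens.

```python
import numpy as np
import mindspore as ms
from mindcv.models.layers.patch_embed import PatchEmbed

# Sketch: a 230x230 input is not divisible by the 16x16 patch size, so it is
# padded to 240x240 before projection; the output stays in NHWC layout.
embed = PatchEmbed(image_size=224, patch_size=16, in_chans=3, embed_dim=768,
                   output_fmt="NHWC", strict_img_size=False, dynamic_img_pad=True)
img = ms.Tensor(np.random.rand(1, 3, 230, 230), ms.float32)
feat = embed(img)
print(feat.shape)  # (1, 15, 15, 768)
```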
93 changes: 93 additions & 0 deletions mindcv/models/layers/pos_embed.py
@@ -0,0 +1,93 @@
"""positional embedding"""
import math
from typing import List, Optional, Tuple

import numpy as np

import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops

from .compatibility import Interpolate


def resample_abs_pos_embed(
posemb,
new_size: List[int],
old_size: Optional[List[int]] = None,
num_prefix_tokens: int = 1,
interpolation: str = 'nearest',
):
# sort out sizes, assume square if old size not provided
num_pos_tokens = posemb.shape[1]
num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens

if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
return posemb

if old_size is None:
hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
old_size = hw, hw

if num_prefix_tokens:
posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
else:
posemb_prefix, posemb = None, posemb

# do the interpolation
embed_dim = posemb.shape[-1]
orig_dtype = posemb.dtype
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
interpolate = Interpolate(mode=interpolation, align_corners=True)
posemb = interpolate(posemb, size=new_size)
posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
posemb = posemb.astype(orig_dtype)

# add back extra (class, etc) prefix tokens
if posemb_prefix is not None:
posemb = ops.concat((posemb_prefix, posemb), axis=1)

return posemb


class RelativePositionBiasWithCLS(nn.Cell):
def __init__(
self,
window_size: Tuple[int],
num_heads: int
):
super(RelativePositionBiasWithCLS, self).__init__()
self.window_size = window_size
self.num_tokens = window_size[0] * window_size[1]

num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
# 3: cls to token, token to cls, cls to cls
self.relative_position_bias_table = Parameter(
Tensor(np.zeros((num_relative_distance, num_heads)), dtype=ms.float16)
)
coords_h = np.arange(window_size[0]).reshape(window_size[0], 1).repeat(window_size[1], 1).reshape(1, -1)
coords_w = np.arange(window_size[1]).reshape(1, window_size[1]).repeat(window_size[0], 0).reshape(1, -1)
coords_flatten = np.concatenate([coords_h, coords_w], axis=0) # [2, Wh * Ww]

relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :] # [2, Wh * Ww, Wh * Ww]
relative_coords = relative_coords.transpose(1, 2, 0) # [Wh * Ww, Wh * Ww, 2]
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[0] - 1

relative_position_index = np.zeros((self.num_tokens + 1, self.num_tokens + 1),
dtype=relative_coords.dtype) # [Wh * Ww + 1, Wh * Ww + 1]
relative_position_index[1:, 1:] = relative_coords.sum(-1)
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
relative_position_index = Tensor(relative_position_index.reshape(-1))

self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16)
self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False)

def construct(self):
out = ops.matmul(self.relative_position_index, self.relative_position_bias_table)
out = ops.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1))
out = ops.transpose(out, (2, 0, 1))
out = ops.expand_dims(out, 0)
return out
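A small usage sketch of the relative-position bias cell (grid size and head count arbitrary): the returned tensor is added to the attention logits of every head, with extra rows and columns for the CLS token.

```python
import mindspore as ms
from mindcv.models.layers.pos_embed import RelativePositionBiasWithCLS

# Sketch: a 14x14 patch grid with 12 heads; the bias spans the 196 patch tokens
# plus the CLS token (float16 parameters, as defined above).
rel_bias = RelativePositionBiasWithCLS(window_size=(14, 14), num_heads=12)
bias = rel_bias()
print(bias.shape)  # (1, 12, 197, 197)
```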