Skip to content

Commit

Permalink
extend vit and mae
Browse files Browse the repository at this point in the history
  • Loading branch information
sageyou committed Sep 15, 2023
1 parent e006a24 commit 8d1e233
Show file tree
Hide file tree
Showing 8 changed files with 485 additions and 945 deletions.
14 changes: 13 additions & 1 deletion mindcv/models/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
"""layers init"""
from . import activation, conv_norm_act, drop_path, identity, pooling, selective_kernel, squeeze_excite
from . import (
activation,
conv_norm_act,
drop_path,
format,
identity,
patch_dropout,
pooling,
selective_kernel,
squeeze_excite,
)
from .activation import *
from .conv_norm_act import *
from .drop_path import *
from .format import *
from .identity import *
from .patch_dropout import *
from .pooling import *
from .selective_kernel import *
from .squeeze_excite import *
Expand Down
34 changes: 34 additions & 0 deletions mindcv/models/layers/format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from enum import Enum
from typing import Union

import mindspore


class Format(str, Enum):
NCHW = 'NCHW'
NHWC = 'NHWC'
NCL = 'NCL'
NLC = 'NLC'


FormatT = Union[str, Format]


def nchw_to(x: mindspore.Tensor, fmt: Format):
if fmt == Format.NHWC:
x = x.permute(0, 2, 3, 1)
elif fmt == Format.NLC:
x = x.flatten(start_dim=2).transpose((0, 2, 1))
elif fmt == Format.NCL:
x = x.flatten(start_dim=2)
return x


def nhwc_to(x: mindspore.Tensor, fmt: Format):
if fmt == Format.NCHW:
x = x.permute(0, 3, 1, 2)
elif fmt == Format.NLC:
x = x.flatten(start_dim=1, end_dim=2)
elif fmt == Format.NCL:
x = x.flatten(start_dim=1, end_dim=2).transpose((0, 2, 1))
return x
54 changes: 54 additions & 0 deletions mindcv/models/layers/patch_dropout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np

import mindspore as ms
from mindspore import nn, ops


class PatchDropout(nn.Cell):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(
self,
prob: float = 0.5,
num_prefix_tokens: int = 1,
ordered: bool = False,
return_indices: bool = False,
):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
self.ordered = ordered
self.return_indices = return_indices
self.sort = ops.Sort()

def forward(self, x):
if not self.training or self.prob == 0.:
if self.return_indices:
return x, None
return x

if self.num_prefix_tokens:
prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
else:
prefix_tokens = None

B = x.shape[0]
L = x.shape[1]
num_keep = max(1, int(L * (1. - self.prob)))
_, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32))
keep_indices = indices[:, :num_keep]
if self.ordered:
# NOTE does not need to maintain patch order in typical transformer use,
# but possibly useful for debug / visualization
keep_indices, _ = self.sort(keep_indices)
keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2]))
x = ops.gather_elements(x, dim=1, index=keep_indices)

if prefix_tokens is not None:
x = ops.concat((prefix_tokens, x), axis=1)

if self.return_indices:
return x, keep_indices
return x
65 changes: 50 additions & 15 deletions mindcv/models/layers/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from mindspore import Tensor, nn, ops

from .format import Format, nchw_to
from .helpers import to_2tuple


Expand All @@ -17,29 +18,45 @@ class PatchEmbed(nn.Cell):
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Cell, optional): Normalization layer. Default: None
"""
output_fmt: Format

def __init__(
self,
image_size: int = 224,
image_size: Optional[int] = 224,
patch_size: int = 4,
in_chans: int = 3,
embed_dim: int = 96,
norm_layer: Optional[nn.Cell] = None,
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
strict_img_size: bool = True,
dynamic_img_pad: bool = False,
) -> None:
super().__init__()
image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]]
self.image_size = image_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]

self.in_chans = in_chans
self.patch_size = to_2tuple(patch_size)
if image_size is not None:
self.image_size = to_2tuple(image_size)
self.patches_resolution = tuple([s // p for s, p in zip(self.image_size, self.patch_size)])
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
else:
self.image_size = None
self.patches_resolution = None
self.num_patches = None

if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
else:
self.flatten = flatten
self.output_fmt = Format.NCHW

self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.embed_dim = embed_dim

self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size,
pad_mode='pad', has_bias=True, weight_init="TruncatedNormal")
pad_mode='pad', has_bias=bias, weight_init="TruncatedNormal")

if norm_layer is not None:
if isinstance(embed_dim, int):
Expand All @@ -50,11 +67,29 @@ def __init__(

def construct(self, x: Tensor) -> Tensor:
"""docstring"""
B = x.shape[0]
# FIXME look at relaxing size constraints
x = ops.Reshape()(self.proj(x), (B, self.embed_dim, -1)) # B Ph*Pw C
x = ops.Transpose()(x, (0, 2, 1))
B, C, H, W = x.shape
if self.image_size is not None:
if self.strict_img_size:
if (H, W) != (self.image_size[0], self.image_size[1]):
raise ValueError(f"Input height and width ({H},{W}) doesn't match model ({self.image_size[0]},"
f"{self.image_size[1]}).")
elif not self.dynamic_img_pad:
if H % self.patch_size[0] != 0:
raise ValueError(f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).")
if W % self.patch_size[1] != 0:
raise ValueError(f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).")
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = ops.pad(x, (0, pad_w, 0, pad_h))

# FIXME look at relaxing size constraints
x = self.proj(x)
if self.flatten:
x = ops.Reshape()(x, (B, self.embed_dim, -1)) # B Ph*Pw C
x = ops.Transpose()(x, (0, 2, 1))
elif self.output_fmt != "NCHW":
x = nchw_to(x, self.output_fmt)
if self.norm is not None:
x = self.norm(x)
return x
44 changes: 43 additions & 1 deletion mindcv/models/layers/pos_embed.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,53 @@
"""positional embedding"""
from typing import Tuple
import math
from typing import List, Optional, Tuple

import numpy as np

import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops

from .compatibility import Interpolate


def resample_abs_pos_embed(
posemb,
new_size: List[int],
old_size: Optional[List[int]] = None,
num_prefix_tokens: int = 1,
interpolation: str = 'nearest',
):
# sort out sizes, assume square if old size not provided
num_pos_tokens = posemb.shape[1]
num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens

if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
return posemb

if old_size is None:
hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
old_size = hw, hw

if num_prefix_tokens:
posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
else:
posemb_prefix, posemb = None, posemb

# do the interpolation
embed_dim = posemb.shape[-1]
orig_dtype = posemb.dtype
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
interpolate = Interpolate(mode=interpolation, align_corners=True)
posemb = interpolate(posemb, size=new_size)
posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
posemb = posemb.astype(orig_dtype)

# add back extra (class, etc) prefix tokens
if posemb_prefix is not None:
posemb = ops.concatcat((posemb_prefix, posemb), axis=1)

return posemb


class RelativePositionBiasWithCLS(nn.Cell):
def __init__(
Expand Down
Loading

0 comments on commit 8d1e233

Please sign in to comment.