Replace multi head attention in decoder #3

Open
Mareeta26 opened this issue Jul 22, 2022 · 9 comments
Comments

@Mareeta26

Hi,
May I know whether I can use SimA instead of multi-head attention in the decoder, to reduce complexity?

Thanks!

@soroush-abbasi
Collaborator

soroush-abbasi commented Jul 22, 2022

Hi!

You can replace any self-attention module with SimA. It may require some parameter tuning for specific models. We tried it on CvT, ViT, and XCiT. It also works with the DINO loss (self-supervised). I plan to try it with MAE in the future. When you say decoder, which model are you referring to (e.g., the decoder of DETR or MAE)?

Thanks! Have a good day!

@Mareeta26
Author

Mareeta26 commented Jul 22, 2022

@soroush-abbasi Thanks for the reply. I meant the decoder of ConvTransformer, which incorporates convolutions into the transformer.

@Mareeta26
Author

@soroush-abbasi Also, is it possible to share the code for ViT with SimA? Thank you in advance!

@soroush-abbasi
Collaborator

soroush-abbasi commented Jul 23, 2022

I guess it should work with the decoder of ConvTransformer. You can simply replace self-attention with SimA attention (the SimA class below). To run with the ViT/DeiT architecture, please replace these classes in sima.py as below (removing the LPI layer and the class-attention layer):

class SimA(nn.Module):

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        # SimA: L1-normalize each channel of q and k across the token dimension (N)
        k = F.normalize(k, p=1.0, dim=-2)
        q = F.normalize(q, p=1.0, dim=-2)
        # Pick the cheaper multiplication order: (q k^T) v costs O(N^2 d), q (k^T v) costs O(N d^2)
        if (N / (C//self.num_heads)) < 1:
            x = ((q @ k.transpose(-2, -1)) @ v).transpose(1, 2).reshape(B, N, C)
        else:
            x = (q @ (k.transpose(-2, -1) @ v)).transpose(1, 2).reshape(B, N, C)


        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}




class SimABlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 num_tokens=196, eta=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SimA(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
                       drop=drop)


        # LayerScale parameters (eta is the init value, e.g. 1.0; it must be passed in)
        self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
        self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        return x



class SimAVisionTransformer(nn.Module):


    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
                 cls_attn_layers=2, use_pos=True, patch_proj='linear', eta=None, tokens_norm=False):
        
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = ConvPatchEmbed(img_size=img_size, embed_dim=embed_dim,
                                          patch_size=patch_size)

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [drop_path_rate for i in range(depth)]
        self.blocks = nn.ModuleList([
            SimABlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, num_tokens=num_patches, eta=eta)
            for i in range(depth)])

        
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)
        self.use_pos = use_pos

        # Initialize class token and weights
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def forward_features(self, x):
        B, C, H, W = x.shape

        x, (Hp, Wp) = self.patch_embed(x)

        if self.use_pos:
            pos_encoding = self.pos_embeder(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            x = x + pos_encoding

        x = self.pos_drop(x)
        
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for blk in self.blocks:
            x = blk(x, Hp, Wp)

        x = self.norm(x)[:, 0]
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)

        if self.training:
            return x, x
        else:
            return x
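
For a quick shape check, the SimA module above can be exercised on its own, for example:

import torch

attn = SimA(dim=192, num_heads=4)
x = torch.randn(2, 197, 192)     # (batch, tokens, dim); 197 = 196 patches + 1 cls token
print(attn(x).shape)             # torch.Size([2, 197, 192])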

@Mareeta26
Author

Mareeta26 commented Jul 23, 2022

@soroush-abbasi Thank you! So, don't we need the SimABlock class for ConvTransformer? What is its purpose? Can you please explain?
My input to the self-attention module is a 5D tensor, e.g., (8, 19, 128, 16, 16).
How should I modify the SimA class for such an input?

@soroush-abbasi
Collaborator

soroush-abbasi commented Jul 23, 2022

SimABlock is a regular transformer block which has both the self-attention (SimA) layer and the MLP layer. As long as you replace the self-attention in your code with SimA, you should be fine, I guess. You need to figure out which dimension of your input is the sequence (N) and which is the token dimension (D). Or, if your features come after the split into multiple heads, you need to find the ordering of B (batch size), H (heads), D (dimension after splitting), and N (sequence length / number of tokens). Sometimes tokens are not flattened: for example, one can view image feature maps as a set of tokens with a 2D shape. If you have a 512x16x16 feature map, you can flatten the last two dimensions to get 512x256 tokens (D=512, N=256). I guess the last two dimensions are the image feature maps in your case, but I'm not sure. Unfortunately, I'm not familiar with ConvTransformer.
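
For illustration only, here is a minimal sketch of flattening such a 5D input into tokens before calling SimA, assuming the layout is (B, T, C, H, W) and that the last two dimensions are the spatial feature map; the layout and variable names are guesses, not ConvTransformer's actual code:

import torch

# Hypothetical 5D input: (B, T, C, H, W) = (8, 19, 128, 16, 16) -- layout assumed
x = torch.randn(8, 19, 128, 16, 16)
B, T, C, H, W = x.shape

# Fold time into the batch and flatten the spatial map into tokens:
# D = C = 128 channels per token, N = H*W = 256 tokens
tokens = x.reshape(B * T, C, H * W).transpose(1, 2)   # (B*T, N, D)

attn = SimA(dim=C, num_heads=8)                       # SimA class from the snippet above
out = attn(tokens)                                    # (B*T, N, D)

out = out.transpose(1, 2).reshape(B, T, C, H, W)      # back to the original layout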

Thanks!

@Mareeta26
Author

Sure, thanks for the reply!!

@Mareeta26 Mareeta26 reopened this Jul 25, 2022
@Mareeta26
Author

@soroush-abbasi Can we use SimA if it's a masked self-attention?

@soroush-abbasi
Collaborator

Hi,

It's a little complicated. We L1-normalize each channel of the queries and keys across the token dimension before the Q/K/V dot products. Because that normalization is computed over all tokens, each token has an effect on the other tokens. Therefore, if you want to mask tokens, you need to apply the masking before the L1 normalization. Please let me know if you have more questions.
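
For example, a rough sketch of how the forward pass of the SimA class above could take a key-padding mask (shape (B, N), True for tokens to ignore). This is my own illustration, not tested code from the repo, and it covers padding-style masking only:

    def forward(self, x, key_padding_mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]               # each (B, heads, N, head_dim)

        if key_padding_mask is not None:
            # Zero out masked tokens *before* the L1 normalization, so they do not
            # contribute to the per-channel norms or to the k^T v summation.
            keep = (~key_padding_mask)[:, None, :, None].to(x.dtype)   # (B, 1, N, 1)
            q = q * keep
            k = k * keep
            v = v * keep

        k = F.normalize(k, p=1.0, dim=-2)
        q = F.normalize(q, p=1.0, dim=-2)
        if (N / (C // self.num_heads)) < 1:
            x = ((q @ k.transpose(-2, -1)) @ v).transpose(1, 2).reshape(B, N, C)
        else:
            x = (q @ (k.transpose(-2, -1) @ v)).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

Note that causal (autoregressive) masking is a different story: with this linearized form it would likely need a cumulative-sum style computation rather than a simple pre-normalization mask.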

Thanks! Have a great day!
