From da950e6d2c3c0d7ff38e228528bd0ca6e0565dec Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 30 Mar 2021 22:15:19 -0700
Subject: [PATCH] add working PiT

---
 setup.py           |   2 +-
 vit_pytorch/pit.py | 180 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 181 insertions(+), 1 deletion(-)
 create mode 100644 vit_pytorch/pit.py

diff --git a/setup.py b/setup.py
index 6052d4c..5607589 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.11.1',
+  version = '0.12.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/pit.py b/vit_pytorch/pit.py
new file mode 100644
index 0000000..0758dfb
--- /dev/null
+++ b/vit_pytorch/pit.py
@@ -0,0 +1,180 @@
+from math import sqrt
+
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
+# helpers
+
+def cast_tuple(val, num):
+    return val if isinstance(val, tuple) else (val,) * num
+
+def conv_output_size(image_size, kernel_size, stride, padding = 0):
+    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
+
+# classes
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.fn = fn
+    def forward(self, x, **kwargs):
+        return self.fn(self.norm(x), **kwargs)
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout = 0.):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout)
+        )
+    def forward(self, x):
+        return self.net(x)
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        project_out = not (heads == 1 and dim_head == dim)
+
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+
+        self.attend = nn.Softmax(dim = -1)
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        ) if project_out else nn.Identity()
+
+    def forward(self, x):
+        b, n, _, h = *x.shape, self.heads
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+
+        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+        attn = self.attend(dots)
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+class Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+            ]))
+    def forward(self, x):
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+        return x
+
+# depthwise convolution, for pooling
+
+class DepthWiseConv2d(nn.Module):
+    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
+            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
+        )
+    def forward(self, x):
+        return self.net(x)
+
+# pooling layer
+
+class Pool(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1)
+        self.cls_ff = nn.Linear(dim, dim * 2)
+
+    def forward(self, x):
+        cls_token, tokens = x[:, :1], x[:, 1:]
+
+        cls_token = self.cls_ff(cls_token)
+
+        tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = int(sqrt(tokens.shape[1])))
+        tokens = self.downsample(tokens)
+        tokens = rearrange(tokens, 'b c h w -> b (h w) c')
+
+        return torch.cat((cls_token, tokens), dim = 1)
+
+# main class
+
+class PiT(nn.Module):
+    def __init__(
+        self,
+        *,
+        image_size,
+        patch_size,
+        num_classes,
+        dim,
+        depth,
+        heads,
+        mlp_dim,
+        dim_head = 64,
+        dropout = 0.,
+        emb_dropout = 0.
+    ):
+        super().__init__()
+        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
+        assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
+        heads = cast_tuple(heads, len(depth))
+
+        patch_dim = 3 * patch_size ** 2
+
+        self.to_patch_embedding = nn.Sequential(
+            nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
+            Rearrange('b c n -> b n c'),
+            nn.Linear(patch_dim, dim)
+        )
+
+        output_size = conv_output_size(image_size, patch_size, patch_size // 2)
+        num_patches = output_size ** 2
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.dropout = nn.Dropout(emb_dropout)
+
+        layers = []
+
+        for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
+            not_last = ind < (len(depth) - 1)
+
+            layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))
+
+            if not_last:
+                layers.append(Pool(dim))
+                dim *= 2
+
+        self.layers = nn.Sequential(
+            *layers,
+            nn.LayerNorm(dim),
+            nn.Linear(dim, num_classes)
+        )
+
+    def forward(self, img):
+        x = self.to_patch_embedding(img)
+        b, n, _ = x.shape
+
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x += self.pos_embedding
+        x = self.dropout(x)
+
+        return self.layers(x)
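
Usage sketch (not part of the patch): the constructor above suggests an invocation like the one below. The hyperparameter values are illustrative assumptions rather than defaults shipped in this commit; depth must be a tuple giving the number of transformer blocks per stage, and a Pool layer (which doubles dim) is inserted between stages. Note that in this revision the final LayerNorm and Linear sit inside the same nn.Sequential as the transformer stages, so the model returns per-token logits of shape (batch, num_tokens, num_classes); index the cls token to get one prediction per image.

# usage sketch -- illustrative hyperparameters, not part of this commit
import torch
from vit_pytorch.pit import PiT

model = PiT(
    image_size = 224,      # 224 % 14 == 0, as required by the assert
    patch_size = 14,       # unfold stride = patch_size // 2 = 7, giving 31 x 31 = 961 patches
    num_classes = 1000,
    dim = 256,             # doubled to 512, then 1024, by the two Pool layers
    depth = (3, 3, 3),     # three stages of 3 transformer blocks each
    heads = 4,             # cast_tuple broadcasts this to (4, 4, 4)
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)
preds = model(img)         # (1, 65, 1000): 8 x 8 = 64 tokens after two poolings, plus the cls token
logits = preds[:, 0]       # (1, 1000): logits taken at the cls token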