diff --git a/README.md b/README.md
index af3bd2f..c077427 100644
--- a/README.md
+++ b/README.md
@@ -143,6 +143,36 @@ img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
```
+## CaiT
+
+This paper also notes difficulty in training vision transformers at greater depths and proposes two solutions. First it proposes to do per-channel multiplication of the output of the residual block. Second, it proposes to have the patches attend to one another, and only allow the CLS token to attend to the patches in the last few layers.
+
+They also add Talking Heads, noting improvements
+
+You can use this scheme as follows
+
+```python
+import torch
+from vit_pytorch.cait import CaiT
+
+v = CaiT(
+ image_size = 256,
+ patch_size = 32,
+ num_classes = 1000,
+ dim = 1024,
+ depth = 12, # depth of transformer for patch to patch attention only
+ cls_depth = 2, # depth of cross attention of CLS tokens to patch
+ heads = 16,
+ mlp_dim = 2048,
+ dropout = 0.1,
+ emb_dropout = 0.1
+)
+
+img = torch.randn(1, 3, 256, 256)
+
+preds = v(img) # (1, 1000)
+```
+
## Token-to-Token ViT
@@ -164,7 +194,8 @@ v = T2TViT(
)
img = torch.randn(1, 3, 224, 224)
-v(img) # (1, 1000)
+
+preds = v(img) # (1, 1000)
```
## Cross ViT
@@ -177,7 +208,7 @@ v(img) # (1, 1000)
import torch
from vit_pytorch.cross_vit import CrossViT
-model = CrossViT(
+v = CrossViT(
image_size = 256,
num_classes = 1000,
depth = 4, # number of multi-scale encoding blocks
@@ -199,7 +230,7 @@ model = CrossViT(
img = torch.randn(1, 3, 256, 256)
-pred = model(img) # (1, 1000)
+pred = v(img) # (1, 1000)
```
## PiT
@@ -212,7 +243,7 @@ pred = model(img) # (1, 1000)
import torch
from vit_pytorch.pit import PiT
-p = PiT(
+v = PiT(
image_size = 224,
patch_size = 14,
dim = 256,
@@ -228,7 +259,7 @@ p = PiT(
img = torch.randn(1, 3, 224, 224)
-preds = p(img) # (1, 1000)
+preds = v(img) # (1, 1000)
```
## CvT
@@ -241,7 +272,7 @@ preds = p(img) # (1, 1000)
import torch
from vit_pytorch.cvt import CvT
-model = CvT(
+v = CvT(
num_classes = 1000,
s1_emb_dim = 64, # stage 1 - dimension
s1_emb_kernel = 7, # stage 1 - conv kernel
@@ -272,7 +303,7 @@ model = CvT(
img = torch.randn(1, 3, 224, 224)
-pred = model(img) # (1, 1000)
+pred = v(img) # (1, 1000)
```
## Masked Patch Prediction
diff --git a/images/cait.png b/images/cait.png
new file mode 100644
index 0000000..9914a55
Binary files /dev/null and b/images/cait.png differ
diff --git a/setup.py b/setup.py
index 5607589..3209ba4 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
- version = '0.12.0',
+ version = '0.14.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
diff --git a/vit_pytorch/cait.py b/vit_pytorch/cait.py
new file mode 100644
index 0000000..bfab514
--- /dev/null
+++ b/vit_pytorch/cait.py
@@ -0,0 +1,148 @@
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+# classes
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, fn, init_eps = 0.1):
+ super().__init__()
+ scale = torch.zeros(1, 1, dim).fill_(init_eps)
+ self.scale = nn.Parameter(scale)
+ self.fn = fn
+ def forward(self, x, **kwargs):
+ return self.fn(x, **kwargs) * self.scale
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+ def forward(self, x, **kwargs):
+ return self.fn(self.norm(x), **kwargs)
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout = 0.):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim, dim),
+ nn.Dropout(dropout)
+ )
+ def forward(self, x):
+ return self.net(x)
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+
+ self.attend = nn.Softmax(dim = -1)
+
+ self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
+ self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context = None):
+ b, n, _, h = *x.shape, self.heads
+
+ context = x if not exists(context) else torch.cat((x, context), dim = 1)
+
+ qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+ dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
+ attn = self.attend(dots)
+ attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+class Transformer(nn.Module):
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
+ LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
+ ]))
+ def forward(self, x, context = None):
+ for attn, ff in self.layers:
+ x = attn(x, context = context) + x
+ x = ff(x) + x
+ return x
+
+class CaiT(nn.Module):
+ def __init__(
+ self,
+ *,
+ image_size,
+ patch_size,
+ num_classes,
+ dim,
+ depth,
+ cls_depth,
+ heads,
+ mlp_dim,
+ dim_head = 64,
+ dropout = 0.,
+ emb_dropout = 0.
+ ):
+ super().__init__()
+ assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
+ num_patches = (image_size // patch_size) ** 2
+ patch_dim = 3 * patch_size ** 2
+
+ self.to_patch_embedding = nn.Sequential(
+ Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+ nn.Linear(patch_dim, dim),
+ )
+
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
+ self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+ self.dropout = nn.Dropout(emb_dropout)
+
+ self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+ self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout)
+
+ self.mlp_head = nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, num_classes)
+ )
+
+ def forward(self, img):
+ x = self.to_patch_embedding(img)
+ b, n, _ = x.shape
+
+ x += self.pos_embedding[:, :n]
+ x = self.dropout(x)
+
+ x = self.patch_transformer(x)
+
+ cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+ x = self.cls_transformer(cls_tokens, context = x)
+
+ return self.mlp_head(x[:, 0])