Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in PyTorch. Its significance is further explained in Yannic Kilcher's video. There's really not much to code here, but I may as well lay it out for everyone so we expedite the attention revolution.
For a PyTorch implementation with pretrained models, please see Ross Wightman's repository here.
The official JAX repository is here.
A TensorFlow 2 translation also exists here, created by research scientist Junho Kim! 🙏
Flax translation by Enrico Shippole!
$ pip install vit-pytorch
import torch
from vit_pytorch import ViT
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
- image_size: int. Image size. If you have rectangular images, make sure your image size is the maximum of the width and height.
- patch_size: int. Size of patches. image_size must be divisible by patch_size. The number of patches is n = (image_size // patch_size) ** 2, and n must be greater than 16.
- num_classes: int. Number of classes to classify.
- dim: int. Last dimension of the output tensor after the linear transformation nn.Linear(..., dim).
- depth: int. Number of Transformer blocks.
- heads: int. Number of heads in the multi-head attention layer.
- mlp_dim: int. Dimension of the MLP (FeedForward) layer.
- channels: int, default 3. Number of image channels.
- dropout: float in [0, 1], default 0. Dropout rate.
- emb_dropout: float in [0, 1], default 0. Embedding dropout rate.
- pool: string, either cls (CLS token pooling) or mean (mean pooling).
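The constraints above can be checked before building a model. The sketch below is plain Python and does not use vit_pytorch. vit_patch_info and pool_tokens are hypothetical helper names. The patch_dim value assumes the usual ViT patch embedding, where each patch is flattened to channels * patch_size ** 2 values before nn.Linear(..., dim). The pooling helper assumes the CLS token sits at position 0 of the token sequence.

```python
# Hypothetical helpers (not part of vit_pytorch) illustrating the parameter rules above.

def vit_patch_info(image_size, patch_size, channels=3):
    """Check the image/patch constraints and report patch and token counts."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    n = (image_size // patch_size) ** 2          # number of patches
    if n <= 16:
        raise ValueError("number of patches n must be greater than 16")
    return {
        "num_patches": n,
        # values in one flattened patch (assumed input size of the patch projection)
        "patch_dim": channels * patch_size ** 2,
        "seq_len": n + 1,                        # +1 for the CLS token
    }


def pool_tokens(tokens, pool="cls"):
    """Reduce a (seq_len x dim) list of token vectors to a single vector.

    'cls'  -> take the first token (assumed to be the CLS token)
    'mean' -> average over every token position
    """
    if pool == "cls":
        return list(tokens[0])
    if pool == "mean":
        count = len(tokens)
        return [sum(column) / count for column in zip(*tokens)]
    raise ValueError("pool must be either 'cls' or 'mean'")


info = vit_patch_info(image_size=256, patch_size=32)
print(info)  # {'num_patches': 64, 'patch_dim': 3072, 'seq_len': 65}
```

For the quick-start model above, 256 // 32 = 8 patches per side gives 64 patches, and the extra CLS token brings the sequence to 65 tokens.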
An update from some of the same authors of the original paper proposes simplifications to ViT that allow it to train faster and better.
These simplifications include 2D sinusoidal positional embedding, global average pooling (no CLS token), no dropout, batch sizes of 1024 rather than 4096, and use of RandAugment and MixUp augmentations. They also show that a simple linear layer at the end is not significantly worse than the original MLP head.
You can use it by importing SimpleViT, as shown below:
import torch
from vit_pytorch import SimpleViT
v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
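To make the 2D sinusoidal positional embedding concrete, here is a minimal plain-Python sketch. The function name posemb_sincos_2d, the frequency schedule (frequencies spaced geometrically from 1 down to 1 / temperature, with temperature = 10000), and the order of the four sin/cos parts are assumptions for illustration. This is not a copy of the SimpleViT source, which operates on tensors.

```python
import math


def posemb_sincos_2d(h, w, dim, temperature=10000.0):
    """Return an (h * w) x dim list of fixed 2D sin-cos position embeddings.

    The feature dimension is split into four equal parts:
    sin(x * omega), cos(x * omega), sin(y * omega), cos(y * omega),
    where (y, x) is the row/column index of a patch in the h x w grid.
    """
    if dim % 4 != 0 or dim < 8:
        raise ValueError("dim must be a multiple of 4 and at least 8")
    quarter = dim // 4
    # frequencies from 1 (i = 0) down to 1 / temperature (i = quarter - 1)
    omegas = [1.0 / temperature ** (i / (quarter - 1)) for i in range(quarter)]

    embeddings = []
    for y in range(h):          # row-major order, matching a flattened patch grid
        for x in range(w):
            row = [math.sin(x * o) for o in omegas]
            row += [math.cos(x * o) for o in omegas]
            row += [math.sin(y * o) for o in omegas]
            row += [math.cos(y * o) for o in omegas]
            embeddings.append(row)
    return embeddings


# 256 // 32 = 8 patches per side -> 64 positions, each a 1024-dim vector
pe = posemb_sincos_2d(h=8, w=8, dim=1024)
print(len(pe), len(pe[0]))  # 64 1024
```

Because these embeddings are fixed rather than learned, there are no positional parameters to train. Each patch position gets a distinct vector built only from its row and column index.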
By popular request, I will start extending a few of the architectures in this repository to 3D ViTs, for use with video, medical imaging, etc.
You will need to pass in two additional hyperparameters: (1) the number of frames (frames) and (2) the patch size along the frame dimension (frame_patch_size).
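With these two extra hyperparameters, the input is cut into patches that each span frame_patch_size frames and image_patch_size x image_patch_size pixels. The plain-Python sketch below uses vit3d_patch_info, a hypothetical helper, to count the resulting patches for the 3D ViT example. It assumes frames must be divisible by frame_patch_size, in the same way that image_size must be divisible by the patch size.

```python
def vit3d_patch_info(image_size, frames, image_patch_size, frame_patch_size, channels=3):
    """Hypothetical helper: patch counts for a video/volume split into 3D patches."""
    if image_size % image_patch_size != 0:
        raise ValueError("image_size must be divisible by image_patch_size")
    if frames % frame_patch_size != 0:
        raise ValueError("frames must be divisible by frame_patch_size")
    spatial = (image_size // image_patch_size) ** 2   # patches within one frame group
    temporal = frames // frame_patch_size             # frame groups along the time axis
    return {
        "num_patches": spatial * temporal,
        # values in one flattened (frame_patch_size x p x p) patch across all channels
        "patch_dim": channels * frame_patch_size * image_patch_size ** 2,
    }


info = vit3d_patch_info(image_size=128, frames=16, image_patch_size=16, frame_patch_size=2)
print(info)  # {'num_patches': 512, 'patch_dim': 1536}
```

Halving frame_patch_size doubles the number of tokens, so it has a direct effect on attention cost.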
For starters, 3D ViT
import torch
from vit_pytorch.vit_3d import ViT
v = ViT(
    image_size = 128,          # image size
    frames = 16,               # number of frames
    image_patch_size = 16,     # image patch size
    frame_patch_size = 2,      # frame patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
preds = v(video) # (4, 1000)
3D Simple ViT
import torch
from vit_pytorch.simple_vit_3d import SimpleViT
v = SimpleViT(
    image_size = 128,          # image size
    frames = 16,               # number of frames
    image_patch_size = 16,     # image patch size
    frame_patch_size = 2,      # frame patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
preds = v(video) # (4, 1000)
3D version of CCT
import torch
from vit_pytorch.cct_3d import CCT
cct = CCT(
    img_size = 224,
    num_frames = 8,
    embedding_dim = 384,
    n_conv_layers = 2,
    frame_kernel_size = 3,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_layers = 14,
    num_heads = 6,
    mlp_ratio = 3.,
    num_classes = 1000,
    positional_embedding = 'learnable'
)
video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, frames, height, width)
pred = cct(video)
If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below:
import torch
from vit_pytorch.vit import ViT
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
# import Recorder and wrap the ViT
from vit_pytorch.recorder import Recorder
v = Recorder(v)
# forward pass now returns predictions and the attention maps
img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)
# there is one extra patch due to the CLS token
attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)
To clean up the class and the hooks once you have collected enough data:
v = v.eject() # wrapper is discarded and original ViT instance is returned
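The 65 in the attention shape is the 64 patches from (256 // 32) ** 2 plus the CLS token. The plain-Python sketch below reproduces that shape arithmetic and shows one common way to turn a layer's attention into a per-patch map: average the CLS token's attention row over heads, then reshape it into the patch grid. expected_attn_shape and cls_attention_grid are hypothetical helpers, not part of vit_pytorch. The sketch assumes the CLS token is at index 0 and that you have converted a slice of attns to nested lists, for example with attns[0, layer].tolist().

```python
def expected_attn_shape(batch, depth, heads, image_size, patch_size):
    """Shape of the stacked attention maps: (batch, layers, heads, tokens, tokens)."""
    tokens = (image_size // patch_size) ** 2 + 1  # +1 for the CLS token
    return (batch, depth, heads, tokens, tokens)


def cls_attention_grid(attn_heads, grid_size):
    """Average the CLS row over heads and lay it out as a grid_size x grid_size map.

    attn_heads: nested lists shaped (heads, tokens, tokens) for ONE image and ONE layer.
    """
    num_heads = len(attn_heads)
    cls_rows = [head[0] for head in attn_heads]            # attention from CLS to every token
    mean_row = [sum(vals) / num_heads for vals in zip(*cls_rows)]
    patch_scores = mean_row[1:]                            # drop the CLS -> CLS entry
    if len(patch_scores) != grid_size * grid_size:
        raise ValueError("token count does not match grid_size ** 2")
    return [patch_scores[r * grid_size:(r + 1) * grid_size] for r in range(grid_size)]


print(expected_attn_shape(batch=1, depth=6, heads=16, image_size=256, patch_size=32))
# (1, 6, 16, 65, 65)
```

For the model above, grid_size would be 8, giving an 8 x 8 map that can be upsampled to 256 x 256 and overlaid on the input image.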