Keras implementation of Compact Transformers from "Escaping the Big Data Paradigm with Compact Transformers".
The official PyTorch implementation can be found here: https://github.com/SHI-Labs/Compact-Transformers
Compact Convolutional Transformer (CCT) makes three main changes to ViT:
- Convolutional Tokenizer, instead of the direct image patching of ViT
- Sequence Pooling instead of the Class Token (see the sketch after these lists)
- Learnable Positional Embedding instead of Sinusoidal Embedding
CCT naturally inherits other components of ViT, such as:
- Multi-Head Self Attention
- Feed Forward Network (MLP Block)
- Dropouts and Stochastic Depth
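As a rough illustration of the sequence pooling step, the final token sequence is reduced to a single vector by a learned, softmax-weighted average instead of a class token; this is what the dense (None, 196, 1) -> tf.linalg.matmul -> flatten layers near the end of the model summary below correspond to. The following is a minimal sketch, not the exact code in CCT_keras or the official repository:

import tensorflow as tf

def sequence_pool(x):
    # x: (batch, seq_len, dim), the output of the last transformer block
    attn_logits = tf.keras.layers.Dense(1)(x)           # one logit per token: (batch, seq_len, 1)
    attn_weights = tf.nn.softmax(attn_logits, axis=1)    # softmax over the sequence axis
    # weighted average of the tokens: (batch, 1, seq_len) @ (batch, seq_len, dim)
    pooled = tf.matmul(attn_weights, x, transpose_a=True)
    return tf.keras.layers.Flatten()(pooled)             # (batch, dim)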
!pip install git+https://github.com/johnypark/CCT-keras
from CCT_keras import CCT
model = CCT(num_classes = 1000, input_shape = (224, 224, 3))
The default CCT() corresponds to CCT-14/7x2 in the paper, which the authors used to train on ImageNet from scratch.
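Since CCT() returns a standard tf.keras Model, it can be compiled and trained with the usual Keras API. A minimal sketch (the optimizer, loss, and hyperparameters below are illustrative placeholders, not the paper's training recipe):

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4),   # placeholder optimizer / learning rate
    loss=tf.keras.losses.CategoricalCrossentropy(),           # use from_logits=True if the head returns logits
    metrics=['accuracy'],
)
# model.fit(train_ds, validation_data=val_ds, epochs=300)     # train_ds / val_ds are hypothetical tf.data pipelines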
model.summary()
...
layer_normalization_26 (LayerNormalization)            (None, 196, 384)   768      ['add_25[0][0]']
multi_head_self_attention_13 (MultiHeadSelfAttention)  (None, None, 384)   591360   ['layer_normalization_26[0][0]']
drop_path_26 (DropPath)                                (None, None, 384)   0        ['multi_head_self_attention_13[0][0]']
add_26 (Add)                                           (None, 196, 384)   0        ['add_25[0][0]', 'drop_path_26[0][0]']
layer_normalization_27 (LayerNormalization)            (None, 196, 384)   768      ['add_26[0][0]']
feed_forward_network_13 (FeedForwardNetwork)           (None, 196, 384)   886272   ['layer_normalization_27[0][0]']
drop_path_27 (DropPath)                                (None, 196, 384)   0        ['feed_forward_network_13[0][0]']
add_27 (Add)                                           (None, 196, 384)   0        ['add_26[0][0]', 'drop_path_27[0][0]']
layer_normalization_28 (LayerNormalization)            (None, 196, 384)   768      ['add_27[0][0]']
dense (Dense)                                          (None, 196, 1)     385      ['layer_normalization_28[0][0]']
tf.linalg.matmul (TFOpLambda)                          (None, 1, 384)     0        ['dense[0][0]', 'layer_normalization_28[0][0]']
flatten (Flatten)                                      (None, 384)        0        ['tf.linalg.matmul[0][0]']
dropout_1 (Dropout)                                    (None, 384)        0        ['flatten[0][0]']
dense_1 (Dense)                                        (None, 1000)       385000   ['dropout_1[0][0]']
==================================================================================================
Total params: 24,735,401
Trainable params: 24,735,401
Non-trainable params: 0
__________________________________________________________________________________________________
# Index the flat list of model weights by variable name: name -> (position, dtype, shape)
model_weights_dict = {w.name: (idx, w.dtype, w.shape) for idx, w in enumerate(model.weights)}
# Collect the Dense kernels/biases and their positions in model.weights
names_dense = [name for name in model_weights_dict if 'dense' in name]
idx_dense = [model_weights_dict[name][0] for name in names_dense]
>>> model_weights_dict
{'conv2d/kernel:0': (0, tf.float32, TensorShape([7, 7, 3, 192])),
 'conv2d_1/kernel:0': (1, tf.float32, TensorShape([7, 7, 192, 384])),
 'layer_normalization/gamma:0': (2, tf.float32, TensorShape([384])),
 'layer_normalization/beta:0': (3, tf.float32, TensorShape([384])),
 'multi_head_self_attention/dense_query/kernel:0': (4, tf.float32, TensorShape([384, 384])),
 'multi_head_self_attention/dense_query/bias:0': (5, tf.float32, TensorShape([384])),
 'multi_head_self_attention/dense_key/kernel:0': (6, tf.float32, TensorShape([384, 384])),
 'multi_head_self_attention/dense_key/bias:0': (7, tf.float32, TensorShape([384])),
 'multi_head_self_attention/dense_value/kernel:0': (8, tf.float32, TensorShape([384, 384])),
 'multi_head_self_attention/dense_value/bias:0': (9, tf.float32, TensorShape([384])),
 'multi_head_self_attention/dense_out/kernel:0': (10, tf.float32, TensorShape([384, 384])),
 'multi_head_self_attention/dense_out/bias:0': (11, tf.float32, TensorShape([384])),
 'layer_normalization_1/gamma:0': (12, tf.float32, TensorShape([384])),
 'layer_normalization_1/beta:0': (13, tf.float32, TensorShape([384])),
 'feed_forward_network/dense_hidden/kernel:0': (14, tf.float32, TensorShape([384, 1152])),
 'feed_forward_network/dense_hidden/bias:0': (15, tf.float32, TensorShape([1152])),
 'feed_forward_network/dense_out/kernel:0': (16, tf.float32, TensorShape([1152, 384])),
 'feed_forward_network/dense_out/bias:0': (17, tf.float32, TensorShape([384])),
 'layer_normalization_2/gamma:0': (18, tf.float32, TensorShape([384])),
 'layer_normalization_2/beta:0': (19, tf.float32, TensorShape([384])),
 ...
Results and weights are adopted directly from the official PyTorch implementation (https://github.com/SHI-Labs/Compact-Transformers). I plan to gradually port the PyTorch weights to TensorFlow and will keep things posted here.
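As a rough sketch of how the weight index built above could be used for porting (the PyTorch parameter name and torch_state_dict below are hypothetical placeholders; the only assumed fact is the usual layout difference between PyTorch and Keras dense kernels):

import numpy as np

def to_keras_dense_kernel(torch_weight):
    # PyTorch nn.Linear stores kernels as (out_features, in_features);
    # Keras Dense expects (in_features, out_features), so transpose.
    return np.ascontiguousarray(torch_weight.T)

# Hypothetical example of copying one converted tensor into the Keras model:
# idx, _, shape = model_weights_dict['multi_head_self_attention/dense_query/kernel:0']
# kernel = to_keras_dense_kernel(torch_state_dict['...'].numpy())
# assert tuple(kernel.shape) == tuple(shape)
# model.weights[idx].assign(kernel)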
Model type is read in the format L/PxC, where L is the number of transformer layers, P is the patch/convolution size, and C (CCT only) is the number of convolutional layers.
| Model | Pretraining | Epochs | PE | Source | CIFAR-10 | CIFAR-100 |
|---|---|---|---|---|---|---|
| CCT-7/3x1 | None | 300 | Learnable | Official PyTorch | 96.53% | 80.92% |
| CCT-7/3x1 | None | 300 | Learnable | CCT-keras | TBD | TBD |
| CCT-7/3x1 | None | 1500 | Sinusoidal | Official PyTorch | 97.48% | 82.72% |
| CCT-7/3x1 | None | 1500 | Sinusoidal | CCT-keras | TBD | TBD |
| CCT-7/3x1 | None | 5000 | Sinusoidal | Official PyTorch | 98.00% | 82.87% |
| CCT-7/3x1 | None | 5000 | Sinusoidal | CCT-keras | TBD | TBD |
| Model | Pre-training | PE | Image Size | Source | Accuracy |
|---|---|---|---|---|---|
| CCT-7/7x2 | None | Sinusoidal | 224x224 | Official PyTorch | 97.19% |
| CCT-7/7x2 | None | Sinusoidal | 224x224 | CCT-keras | TBD |
| CCT-14/7x2 | ImageNet-1k | Learnable | 384x384 | Official PyTorch | 99.76% |
| CCT-14/7x2 | ImageNet-1k | Learnable | 384x384 | CCT-keras | TBD |
| Model | Type | Resolution | Epochs | # Params | MACs | Source | Top-1 Accuracy |
|---|---|---|---|---|---|---|---|
| ViT | 12/16 | 384 | 300 | 86.8M | 17.6G | Official PyTorch | 77.91% |
| CCT | 14/7x2 | 224 | 310 | 22.36M | 5.11G | Official PyTorch | 80.67% |
| CCT | 14/7x2 | 224 | 310 | 22.36M | 5.11G | CCT-keras | TBD |
| CCT | 14/7x2 | 384 | 310 + 30 | 22.51M | 15.02G | Official PyTorch | 82.71% |
| CCT | 14/7x2 | 384 | 310 + 30 | 22.51M | 15.02G | CCT-keras | TBD |