This project re-implements DeiT and DeiT-III in JAX/Flax for training on TPUs. Since the original repository is written in PyTorch, this project provides an alternative codebase for training ViT variants on TPUs.
We have trained ViTs using both the DeiT and DeiT-III recipes. All experiments were run on a `v4-64` pod slice, and you can see the training details in the wandb logs.
Name | Data | Resolution | Epochs | Time | Reimpl. | Original | Config | Wandb | Model |
---|---|---|---|---|---|---|---|---|---|
T/16 | in1k | 224 | 300 | 2h 40m | 73.1% | 72.2% | config | log | ckpt |
S/16 | in1k | 224 | 300 | 2h 43m | 79.68% | 79.8% | config | log | ckpt |
B/16 | in1k | 224 | 300 | 4h 40m | 81.46% | 81.8% | config | log | ckpt |
Name | Data | Resolution | Epochs | Time | Reimpl. | Original | Config | Wandb | Model |
---|---|---|---|---|---|---|---|---|---|
S/16 | in1k | 224 | 400 | 2h 38m | 80.7% | 80.4% | config | log | ckpt |
S/16 | in1k | 224 | 800 | 5h 19m | 81.44% | 81.4% | config | log | ckpt |
B/16 | in1k | 192 → 224 | 400 | 4h 42m | 83.6% | 83.5% | pt / ft | pt / ft | pt / ft |
B/16 | in1k | 192 → 224 | 800 | 9h 28m | 83.91% | 83.8% | pt / ft | pt / ft | pt / ft |
L/16 | in1k | 192 → 224 | 400 | 14h 10m | 84.62% | 84.5% | pt / ft | pt / ft | pt / ft |
L/16 | in1k | 192 → 224 | 800 | - | - | 84.9% | pt / ft | - | - |
H/14 | in1k | 154 → 224 | 400 | 19h 10m | 85.12% | 85.1% | pt / ft | pt / ft | pt / ft |
H/14 | in1k | 154 → 224 | 800 | - | - | 85.2% | pt / ft | - | - |
Name | Data | Resolution | Epochs | Time | Reimpl. | Original | Config | Wandb | Model |
---|---|---|---|---|---|---|---|---|---|
S/16 | in21k | 224 | 90 | 7h 30m | 83.04% | 82.6% | pt / ft | pt / ft | pt / ft |
S/16 | in21k | 224 | 240 | 20h 6m | 83.39% | 83.1% | pt / ft | pt / ft | pt / ft |
B/16 | in21k | 224 | 90 | 12h 12m | 85.35% | 85.2% | pt / ft | pt / ft | pt / ft |
B/16 | in21k | 224 | 240 | 33h 9m | 85.68% | 85.7% | pt / ft | pt / ft | pt / ft |
L/16 | in21k | 224 | 90 | 37h 13m | 86.83% | 86.8% | pt / ft | pt / ft | pt / ft |
L/16 | in21k | 224 | 240 | - | - | 87% | pt / ft | - | - |
H/14 | in21k | 126 → 224 | 90 | 35h 51m | 86.78% | 87.2% | pt / ft | pt / ft | pt / ft |
H/14 | in21k | 126 → 224 | 240 | - | - | - | pt / ft | - | - |
To begin, create a TPU instance for training ViTs. We have tested on both `v3-8` and `v4-64`. We recommend using the `v4-64` pod slice. If you do not have any TPU quota, visit this link and apply for the TRC program.
$ gcloud compute tpus tpu-vm create tpu-name \
--zone=us-central2-b \
--accelerator-type=v4-64 \
--version=tpu-ubuntu2204-base
Once the TPU instance is created, clone this repository and install the required dependencies. All dependencies and installation steps are specified in the scripts/setup.sh file. Note that you should use the `gcloud` command with `--worker=all` so that the same command runs on all nodes simultaneously. The `v4-64` pod slice contains 8 computing nodes, each with 4 v4 chips.
$ gcloud compute tpus tpu-vm ssh tpu-name \
--zone=us-central2-b \
--worker=all \
--command="git clone https://github.com/affjljoo3581/deit3-jax"
$ gcloud compute tpus tpu-vm ssh tpu-name \
--zone=us-central2-b \
--worker=all \
--command="bash deit3-jax/scripts/setup.sh"
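The node and chip counts quoted above follow from the accelerator name. As a rough sketch, assuming the v3/v4 naming convention in which the trailing number counts TensorCores, each chip has 2 TensorCores, and each host VM holds 4 chips, the topology can be derived as follows. The helper name `tpu_topology` is hypothetical and not part of this repository.

```python
def tpu_topology(accelerator_type, cores_per_chip=2, chips_per_host=4):
    """Derive chip and host counts from a name such as "v4-64".

    Assumes the v3/v4 convention: the suffix counts TensorCores,
    with 2 cores per chip and 4 chips per host VM.
    """
    generation, cores = accelerator_type.split("-")
    if generation not in ("v3", "v4"):
        raise ValueError(f"unsupported TPU generation: {generation}")
    chips = int(cores) // cores_per_chip
    # A single-host slice such as v3-8 still occupies one host.
    hosts = max(1, chips // chips_per_host)
    return {"chips": chips, "hosts": hosts}


if __name__ == "__main__":
    print(tpu_topology("v4-64"))  # 32 chips spread over 8 hosts
    print(tpu_topology("v3-8"))   # 4 chips on a single host
```

This is why every setup command uses `--worker=all`: each of the 8 hosts needs its own copy of the repository and its own environment.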
Additionally, log in to your wandb account using the command below. Replace `$WANDB_API_KEY` with your own API key.
$ gcloud compute tpus tpu-vm ssh tpu-name \
--zone=us-central2-b \
--worker=all \
--command="source ~/miniconda3/bin/activate base; wandb login $WANDB_API_KEY"
`deit3-jax` uses webdataset to load training samples from various sources, such as the huggingface hub and GCS. timm provides webdataset versions of ImageNet-1k and ImageNet-21k on the huggingface hub. We recommend copying the resources to your GCS bucket for faster download speeds. To download both datasets to your bucket, use the following commands:
$ export HF_TOKEN=...
$ export GCS_DATASET_DIR=gs://...
$ bash scripts/prepare-imagenet1k-dataset.sh
$ bash scripts/prepare-imagenet21k-dataset.sh
For example, you can list the tarfiles in your bucket like this:
$ gsutil ls gs://affjljoo3581-tpu-v4-storage/datasets/imagenet-1k-wds/
gs://affjljoo3581-tpu-v4-storage/datasets/imagenet-1k-wds/imagenet1k-train-0000.tar
gs://affjljoo3581-tpu-v4-storage/datasets/imagenet-1k-wds/imagenet1k-train-0001.tar
gs://affjljoo3581-tpu-v4-storage/datasets/imagenet-1k-wds/imagenet1k-train-0002.tar
gs://affjljoo3581-tpu-v4-storage/datasets/imagenet-1k-wds/imagenet1k-train-0003.tar
gs://affjljoo3581-tpu-v4-storage/datasets/imagenet-1k-wds/imagenet1k-train-0004.tar
...
However, GCS is not the only way to use webdataset. Instead of prefetching into your own bucket, it is also possible to directly stream from the huggingface hub while training.
$ export TRAIN_SHARDS="https://huggingface.co/datasets/timm/imagenet-1k-wds/resolve/main/imagenet1k-train-{0000..1023}.tar"
$ export VALID_SHARDS="https://huggingface.co/datasets/timm/imagenet-1k-wds/resolve/main/imagenet1k-validation-{00..63}.tar"
$ python3 src/main.py \
--train-dataset-shards "pipe:curl -s -L $TRAIN_SHARDS -H 'Authorization:Bearer $HF_TOKEN'" \
--valid-dataset-shards "pipe:curl -s -L $VALID_SHARDS -H 'Authorization:Bearer $HF_TOKEN'" \
...
Since intermittent decreases in download performance may occur when streaming from the huggingface hub, we recommend using the GCS bucket for stable download speed and consistent training.
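The shard patterns above use brace ranges such as `{0000..1023}`, which webdataset expands into individual shard URLs on its own. To inspect or count the shards a pattern refers to, a minimal standard-library sketch could look like the following. The function `expand_shards` is a hypothetical helper for illustration, not part of this repository or of webdataset, and it only handles numeric `{start..stop}` ranges.

```python
import re

# Matches a numeric brace range such as {0000..1023}.
_RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_shards(pattern):
    """Expand numeric brace ranges in a shard pattern into a list of names."""
    match = _RANGE.search(pattern)
    if match is None:
        return [pattern]
    start, stop = match.group(1), match.group(2)
    # Keep zero padding when the range start is written with leading zeros.
    width = len(start) if len(start) > 1 and start[0] == "0" else 0
    names = []
    for index in range(int(start), int(stop) + 1):
        expanded = (
            pattern[: match.start()] + str(index).zfill(width) + pattern[match.end():]
        )
        # Recurse so patterns with several ranges are fully expanded.
        names.extend(expand_shards(expanded))
    return names


if __name__ == "__main__":
    train = expand_shards("imagenet1k-train-{0000..1023}.tar")
    valid = expand_shards("imagenet1k-validation-{00..63}.tar")
    print(len(train), train[0], train[-1])
    print(len(valid), valid[0], valid[-1])
```

With the patterns shown earlier, this yields 1024 training shards and 64 validation shards.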
You can now train your ViTs using the command below. Replace `$CONFIG_FILE` with the path to the configuration file you want to use. Alternatively, you can create your own training recipe by adjusting the hyperparameters. The various training presets are available in the config folder.
$ export GCS_MODEL_DIR=gs://...
$ gcloud compute tpus tpu-vm ssh tpu-name \
--zone=us-central2-b \
--worker=all \
--command="source ~/miniconda3/bin/activate base; cd deit3-jax; screen -dmL bash $CONFIG_FILE"
The training results will be saved to `$GCS_MODEL_DIR`. You can specify a local directory path instead of a GCS path to save models locally.
To use the pretrained checkpoints, you can convert `.msgpack` to timm-compatible `.pth` files.
$ python scripts/convert_flax_to_pytorch.py deit3-s16-224-in1k-400ep-best.msgpack
$ ls
deit3-s16-224-in1k-400ep-best.msgpack deit3-s16-224-in1k-400ep-best.pth
After converting `.msgpack` to `.pth`, you can load it with timm:
>>> import torch
>>> import timm
>>> model = timm.create_model("vit_small_patch16_224", init_values=1e-4)
>>> model.load_state_dict(torch.load("deit3-s16-224-in1k-400ep-best.pth"))
<All keys matched successfully>
- `--random-crop`: Type of random cropping. Choose `none` for nothing, `rrc` for RandomResizedCrop, and `src` for SimpleResizedCrop proposed in DeiT-III.
- `--color-jitter`: Factor for color jitter augmentation.
- `--auto-augment`: Name of auto-augment policy used in timm (e.g. `rand-m9-mstd0.5-inc1`).
- `--random-erasing`: Probability of random erasing augmentation.
- `--augment-repeats`: Number of augmentation repetitions.
- `--test-crop-ratio`: Center crop ratio for test preprocessing.
- `--mixup`: Factor (alpha) for Mixup augmentation. Disable by setting to 0.
- `--cutmix`: Factor (alpha) for CutMix augmentation. Disable by setting to 0.
- `--criterion`: Type of classification loss. Choose `ce` for softmax cross entropy and `bce` for sigmoid cross entropy.
- `--label-smoothing`: Factor for label smoothing.
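To make the interaction between `--label-smoothing`, `--mixup`, and `--criterion` more concrete, here is a small standard-library sketch of how soft targets are commonly built and scored. It follows the usual conventions (smoothing mass spread uniformly over all labels, Mixup coefficient drawn from Beta(alpha, alpha)) and is not taken from this repository's code; all function names are hypothetical, and the actual implementation may differ in its details.

```python
import math
import random


def smooth_one_hot(label, num_labels, smoothing):
    """One-hot target with `smoothing` mass spread uniformly over all labels."""
    off_value = smoothing / num_labels
    target = [off_value] * num_labels
    target[label] += 1.0 - smoothing
    return target


def sample_mixup_lambda(alpha, rng):
    """Draw the mixing coefficient; alpha <= 0 disables mixing."""
    if alpha <= 0:
        return 1.0
    return rng.betavariate(alpha, alpha)


def mix_targets(target_a, target_b, lam):
    """Blend two soft targets with the same coefficient used for the images."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(target_a, target_b)]


def softmax_cross_entropy(logits, target):
    """Cross entropy against a soft target, using a stable log-softmax."""
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return -sum(t * (x - log_z) for x, t in zip(logits, target))


def sigmoid_cross_entropy(logits, target):
    """Per-label binary cross entropy, summed over labels."""
    return sum(
        max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
        for x, t in zip(logits, target)
    )


if __name__ == "__main__":
    rng = random.Random(0)
    lam = sample_mixup_lambda(0.8, rng)
    target = mix_targets(smooth_one_hot(2, 4, 0.1), smooth_one_hot(0, 4, 0.1), lam)
    logits = [0.5, -1.0, 2.0, 0.0]
    print(softmax_cross_entropy(logits, target))
    print(sigmoid_cross_entropy(logits, target))
```

Both losses accept the same soft targets, so switching the criterion leaves the augmentation pipeline unchanged in this sketch.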
- `--layers`: Number of layers.
- `--dim`: Number of hidden features.
- `--heads`: Number of attention heads.
- `--labels`: Number of classification labels.
- `--layerscale`: Flag to enable LayerScale.
- `--patch-size`: Patch size in ViT embedding layer.
- `--image-size`: Input image size.
- `--posemb`: Type of positional embeddings in ViT. Choose `learnable` for learnable parameters and `sincos2d` for sinusoidal encoding.
- `--pooling`: Type of pooling strategy. Choose `cls` for using `[CLS]` token and `gap` for global average pooling.
- `--dropout`: Dropout rate.
- `--droppath`: DropPath rate.
- `--grad-ckpt`: Flag to enable gradient checkpointing for reducing memory footprint.
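The `--image-size` and `--patch-size` options together determine the sequence length the transformer processes, which helps explain why the larger models are pretrained at a lower resolution before finetuning at 224. The sketch below shows the arithmetic. It assumes square images, non-overlapping patches, and that a `[CLS]` token is added only with `cls` pooling; the last point is an assumption about this codebase, and `num_tokens` is a hypothetical helper.

```python
def num_tokens(image_size, patch_size, pooling="cls"):
    """Sequence length of a ViT for a square image split into square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image size must be divisible by patch size")
    side = image_size // patch_size
    patches = side * side
    # Assumption: only `cls` pooling prepends an extra [CLS] token.
    return patches + 1 if pooling == "cls" else patches


if __name__ == "__main__":
    # Resolutions taken from the result tables above.
    for name, size, patch in [
        ("B/16 pretrain", 192, 16),
        ("B/16 finetune", 224, 16),
        ("H/14 pretrain", 154, 14),
        ("H/14 finetune", 224, 14),
    ]:
        print(name, num_tokens(size, patch, pooling="gap"))
```

For example, H/14 sees 121 patches at 154 pixels but 256 at 224 pixels, so pretraining at the lower resolution roughly halves the sequence length.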
- `--optimizer`: Type of optimizer. Choose `adamw` for AdamW and `lamb` for LAMB.
- `--learning-rate`: Peak learning rate.
- `--weight-decay`: Decoupled weight decay rate.
- `--adam-b1`: Adam beta1.
- `--adam-b2`: Adam beta2.
- `--adam-eps`: Adam epsilon.
- `--lr-decay`: Layerwise learning rate decay rate.
- `--clip-grad`: Maximum gradient norm.
- `--grad-accum`: Number of gradient accumulation steps.
- `--warmup-steps`: Number of learning rate warmup steps.
- `--training-steps`: Number of total training steps.
- `--log-interval`: Number of logging intervals.
- `--eval-interval`: Number of evaluation intervals.
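As an illustration of how `--learning-rate`, `--warmup-steps`, `--training-steps`, and `--lr-decay` might combine, the sketch below uses a linear warmup followed by cosine decay, and a per-layer multiplier that shrinks geometrically toward the input. Both choices are assumptions based on common ViT training practice; the options above do not state the decay shape or the exact layer indexing used here, and the function names are hypothetical.

```python
import math


def learning_rate(step, peak_lr, warmup_steps, training_steps):
    """Linear warmup to `peak_lr`, then cosine decay to zero (assumed shape)."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, training_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * min(1.0, progress)))


def layerwise_scales(num_layers, decay):
    """Learning rate multipliers for [embedding, block 1..N, head].

    Assumed convention: the head keeps scale 1.0 and each step toward the
    input multiplies the scale by `decay`.
    """
    return [decay ** (num_layers + 1 - index) for index in range(num_layers + 2)]


if __name__ == "__main__":
    for step in (0, 500, 1000, 5000, 10000):
        print(step, learning_rate(step, 1e-3, 1000, 10000))
    print(layerwise_scales(12, 0.75))
```

Setting the decay rate to 1.0 gives every layer the same learning rate, which corresponds to disabling layerwise decay in this sketch.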
- `--init-seed`: Random seed for weight initialization.
- `--mixup-seed`: Random seed for Mixup and CutMix augmentations.
- `--dropout-seed`: Random seed for Dropout regularization.
- `--shuffle-seed`: Random seed for dataset shuffling.
- `--pretrained-ckpt`: Pretrained model path to load from.
- `--label-mapping`: Label mapping file to reuse the pretrained classification head for transfer learning.
@misc{park2024deit3jax,
author = {Jungwoo Park},
title = {deit3-jax},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/affjljoo3581/deit3-jax}}
}
This repository is released under the Apache 2.0 license as found in the LICENSE file.
Thanks to the TPU Research Cloud program for providing resources. All models are trained on the TPU `v4-64` pod slice.