From 034919d0326da74e601a14bb7e3b6074e0d14fb9 Mon Sep 17 00:00:00 2001 From: zzc98 <40905160+zzc98@users.noreply.github.com> Date: Sat, 6 May 2023 19:28:31 +0800 Subject: [PATCH] [Feature] add eva02 backbone (#1450) * [CI] Add test mim CI. (#879) * [CI] Add test mim CI. (#879) * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * update * update ci * rebase * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * update * update readme and configs * update readme and configs * refactore eva02 * [CI] Add test mim CI. (#879) * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * update * update ci * rebase * feat: add eva02 backbone * feat: add eva02 backbone * feat: add eva02 backbone * update * update readme and configs * refactore eva02 * update readme and metafile * update readme and metafile * update readme and metafile * update * rename eva02 * rename eva02 * fix uts * rename configs --------- Co-authored-by: Ma Zerun Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> --- .../_base_/datasets/imagenet_bs16_eva_448.py | 62 ++++ configs/eva02/README.md | 109 ++++++ configs/eva02/eva02-base-p14_headless.py | 21 ++ configs/eva02/eva02-base-p14_in1k.py | 32 ++ configs/eva02/eva02-large-p14_headless.py | 21 ++ configs/eva02/eva02-large-p14_in1k.py | 32 ++ configs/eva02/eva02-small-p14_headless.py | 20 + configs/eva02/eva02-small-p14_in1k.py | 31 ++ configs/eva02/eva02-tiny-p14_headless.py | 20 + configs/eva02/eva02-tiny-p14_in1k.py | 31 ++ configs/eva02/metafile.yml | 199 ++++++++++ docs/en/api/models.rst | 1 + mmpretrain/models/backbones/__init__.py | 2 + mmpretrain/models/backbones/vit_eva02.py | 350 ++++++++++++++++++ mmpretrain/models/utils/__init__.py | 3 +- mmpretrain/models/utils/position_encoding.py | 75 ++++ model-index.yml | 1 + .../test_models/test_backbones/test_eva02.py | 143 +++++++ .../test_utils/test_position_encoding.py | 13 +- tools/model_converters/eva02_to_mmpretrain.py | 153 ++++++++ 20 files changed, 1317 insertions(+), 2 deletions(-) create mode 100644 configs/_base_/datasets/imagenet_bs16_eva_448.py create mode 100644 configs/eva02/README.md create mode 100644 configs/eva02/eva02-base-p14_headless.py create mode 100644 configs/eva02/eva02-base-p14_in1k.py create mode 100644 configs/eva02/eva02-large-p14_headless.py create mode 100644 configs/eva02/eva02-large-p14_in1k.py create mode 100644 configs/eva02/eva02-small-p14_headless.py create mode 100644 configs/eva02/eva02-small-p14_in1k.py create mode 100644 configs/eva02/eva02-tiny-p14_headless.py create mode 100644 configs/eva02/eva02-tiny-p14_in1k.py create mode 100644 configs/eva02/metafile.yml create mode 100644 mmpretrain/models/backbones/vit_eva02.py create mode 100644 tests/test_models/test_backbones/test_eva02.py create mode 100644 tools/model_converters/eva02_to_mmpretrain.py diff --git a/configs/_base_/datasets/imagenet_bs16_eva_448.py b/configs/_base_/datasets/imagenet_bs16_eva_448.py new file mode 100644 index 00000000000..b90bba14eef --- /dev/null +++ b/configs/_base_/datasets/imagenet_bs16_eva_448.py @@ -0,0 
+1,62 @@ +# dataset settings +dataset_type = 'ImageNet' +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + scale=448, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='PackInputs'), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeEdge', + scale=448, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=448), + dict(type='PackInputs'), +] + +train_dataloader = dict( + batch_size=16, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/train.txt', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type='DefaultSampler', shuffle=True), +) + +val_dataloader = dict( + batch_size=8, + num_workers=5, + dataset=dict( + type=dataset_type, + data_root='data/imagenet', + ann_file='meta/val.txt', + data_prefix='val', + pipeline=test_pipeline), + sampler=dict(type='DefaultSampler', shuffle=False), +) +val_evaluator = dict(type='Accuracy', topk=(1, 5)) + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader +test_evaluator = val_evaluator diff --git a/configs/eva02/README.md b/configs/eva02/README.md new file mode 100644 index 00000000000..bf0cea780fd --- /dev/null +++ b/configs/eva02/README.md @@ -0,0 +1,109 @@ +# EVA-02 + +> [EVA-02: A Visual Representation for Neon Genesis](https://arxiv.org/abs/2303.11331) + + + +## Abstract + +We launch EVA-02, a next-generation Transformer-based visual representation pre-trained to reconstruct strong and robust language-aligned vision features via masked image modeling. With an updated plain Transformer architecture as well as extensive pre-training from an open & accessible giant CLIP vision encoder, EVA-02 demonstrates superior performance compared to prior state-of-the-art approaches across various representative vision tasks, while utilizing significantly fewer parameters and compute budgets. Notably, using exclusively publicly accessible training data, EVA-02 with only 304M parameters achieves a phenomenal 90.0 fine-tuning top-1 accuracy on ImageNet-1K val set. Additionally, our EVA-02-CLIP can reach up to 80.4 zero-shot top-1 on ImageNet-1K, outperforming the previous largest & best open-sourced CLIP with only ~1/6 parameters and ~1/6 image-text training data. We offer four EVA-02 variants in various model sizes, ranging from 6M to 304M parameters, all with impressive performance. To facilitate open accessand open research, we release the complete suite of EVA-02 to the community. + +
+TrV builds upon the original plain ViT architecture and includes several enhancements: SwiGLU FFN, sub-LN, 2D RoPE, and JAX weight initialization. To keep the parameters & FLOPs consistent with the baseline, the FFN hidden dim of SwiGLU is 2/3× of the typical MLP counterpart. +
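+
+As a quick illustration of the note above, the 2/3 scaling matches the `feedforward_channels` values used by the configs in this PR (a minimal arithmetic sketch, not official code; `embed_dims=192` for the tiny variant is taken from the arch table in `vit_eva02.py`):
+
+```python
+embed_dims = 192                         # EVA-02 tiny
+mlp_hidden = embed_dims * 4              # hidden dim of a standard 4x MLP FFN
+swiglu_hidden = int(mlp_hidden * 2 / 3)  # 512, used as `feedforward_channels`
+```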
+ +## How to use it? + + + +**Predict image** + +```python +from mmpretrain import inference_model + +predict = inference_model('vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px', 'demo/bird.JPEG') +print(predict['pred_class']) +print(predict['pred_score']) +``` + +**Use the model** + +```python +import torch +from mmpretrain import get_model + +model = get_model('vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px', pretrained=True) +inputs = torch.rand(1, 3, 336, 336) +out = model(inputs) +print(type(out)) +# To extract features. +feats = model.extract_feat(inputs) +print(type(feats)) +``` + +**Train/Test Command** + +Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset). + +Train: + +```shell +python tools/train.py configs/eva02/eva02-tiny-p14_in1k.py +``` + +Test: + +```shell +python tools/test.py configs/eva02/eva02-tiny-p14_in1k.py /path/to/eva02-tiny-p14_in1k.pth +``` + + + +## Models and results + +### Pretrained models + +| Model | Params (M) | Flops (G) | Config | Download | +| :-------------------------------- | :--------: | :-------: | :-----------------------------------: | :-----------------------------------------------------------------------------------------------------------: | +| `vit-tiny-p14_eva02-pre_in21k`\* | 5.50 | 1.70 | [config](eva02-tiny-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_pre_in21k_20230505-d703e7b1.pth) | +| `vit-small-p14_eva02-pre_in21k`\* | 21.62 | 6.14 | [config](eva02-small-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_pre_in21k_20230505-3175f463.pth) | +| `vit-base-p14_eva02-pre_in21k`\* | 85.77 | 23.22 | [config](eva02-base-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_pre_in21k_20230505-2f2d4d3c.pth) | +| `vit-large-p14_eva02-pre_in21k`\* | 303.29 | 81.15 | [config](eva02-large-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_in21k_20230505-9072de5d.pth) | +| `vit-large-p14_eva02-pre_m38m`\* | 303.29 | 81.15 | [config](eva02-large-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_m38m_20230505-b8a1a261.pth) | + +- The input size / patch size of MIM pre-trained EVA-02 is `224x224` / `14x14`. 
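+- The pre-trained checkpoints above have no classification head; a hedged loading sketch is shown below (it reuses the `get_model`/`extract_feat` API from the usage section above, and the output shape is not verified here).
+
+```python
+import torch
+from mmpretrain import get_model
+
+# Headless MIM pre-trained backbone, 224x224 input as noted above.
+model = get_model('vit-tiny-p14_eva02-pre_in21k', pretrained=True)
+feats = model.extract_feat(torch.rand(1, 3, 224, 224))
+print(type(feats))
+```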
+ +*Models with * are converted from the [official repo](https://github.com/baaivision/EVA).* + +### Image Classification on ImageNet-1k + +#### (*w/o* IN-21K intermediate fine-tuning) + +| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download | +| :---------------------------------------------------- | :----------------: | :--------: | :-------: | :-------: | :-------: | :---------------------------------: | :-------------------------------------------------------: | +| `vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px`\* | EVA02 ImageNet-21k | 5.76 | 4.68 | 80.69 | 95.54 | [config](./eva02-tiny-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_in21k-pre_3rdparty_in1k-336px_20230505-a4e8708a.pth) | +| `vit-small-p14_eva02-in21k-pre_3rdparty_in1k-336px`\* | EVA02 ImageNet-21k | 22.13 | 15.48 | 85.78 | 97.60 | [config](./eva02-small-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_in21k-pre_3rdparty_in1k-336px_20230505-9c5b0e85.pth) | +| `vit-base-p14_eva02-in21k-pre_3rdparty_in1k-448px`\* | EVA02 ImageNet-21k | 87.13 | 107.11 | 88.29 | 98.53 | [config](./eva02-base-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_3rdparty_in1k-448px_20230505-8ad211c5.pth) | + +*Models with * are converted from the [official repo](https://github.com/baaivision/EVA/tree/master/EVA-02). The config files of these models are only for inference. We haven't reprodcue the training results.* + +#### (*w* IN-21K intermediate fine-tuning) + +| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download | +| :---------------------------------------------------- | :----------------: | :--------: | :-------: | :-------: | :-------: | :---------------------------------: | :-------------------------------------------------------: | +| `vit-base-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px`\* | EVA02 ImageNet-21k | 87.13 | 107.11 | 88.47 | 98.62 | [config](./eva02-base-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-5cd4d87f.pth) | +| `vit-large-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px`\* | EVA02 ImageNet-21k | 305.08 | 362.33 | 89.65 | 98.95 | [config](./eva02-large-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-926d1599.pth) | +| `vit-large-p14_eva02_m38m-pre_in21k-medft_3rdparty_in1k-448px`\* | EVA02 Merged-38M | 305.10 | 362.33 | 89.83 | 99.00 | [config](./eva02-large-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_m38m-pre_in21k-medft_3rdparty_in1k-448px_20230505-150dc5ed.pth) | + +*Models with * are converted from the [official repo](https://github.com/baaivision/EVA/tree/master/EVA-02). The config files of these models are only for inference. 
We haven't reprodcue the training results.* + +## Citation + +```bibtex +@article{EVA-02, + title={EVA-02: A Visual Representation for Neon Genesis}, + author={Yuxin Fang and Quan Sun and Xinggang Wang and Tiejun Huang and Xinlong Wang and Yue Cao}, + journal={arXiv preprint arXiv:2303.11331}, + year={2023} +} +``` diff --git a/configs/eva02/eva02-base-p14_headless.py b/configs/eva02/eva02-base-p14_headless.py new file mode 100644 index 00000000000..27aa8f8a502 --- /dev/null +++ b/configs/eva02/eva02-base-p14_headless.py @@ -0,0 +1,21 @@ +model = dict( + type='ImageClassifier', + backbone=dict( + type='ViTEVA02', + arch='b', + img_size=224, + patch_size=14, + sub_ln=True, + final_norm=False, + out_type='avg_featmap'), + neck=None, + head=None, +) + +data_preprocessor = dict( + # RGB format normalization parameters + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + # convert image from BGR to RGB + to_rgb=True, +) diff --git a/configs/eva02/eva02-base-p14_in1k.py b/configs/eva02/eva02-base-p14_in1k.py new file mode 100644 index 00000000000..c8400d38542 --- /dev/null +++ b/configs/eva02/eva02-base-p14_in1k.py @@ -0,0 +1,32 @@ +_base_ = [ + '../_base_/datasets/imagenet_bs16_eva_448.py', + '../_base_/schedules/imagenet_bs2048_AdamW.py', + '../_base_/default_runtime.py' +] + +model = dict( + type='ImageClassifier', + backbone=dict( + type='ViTEVA02', + arch='b', + img_size=448, + patch_size=14, + sub_ln=True, + final_norm=False, + out_type='avg_featmap'), + neck=None, + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=768, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + ), + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=.02), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.), + ], + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) + ])) diff --git a/configs/eva02/eva02-large-p14_headless.py b/configs/eva02/eva02-large-p14_headless.py new file mode 100644 index 00000000000..e101ac977c8 --- /dev/null +++ b/configs/eva02/eva02-large-p14_headless.py @@ -0,0 +1,21 @@ +model = dict( + type='ImageClassifier', + backbone=dict( + type='ViTEVA02', + arch='l', + img_size=224, + patch_size=14, + sub_ln=True, + final_norm=False, + out_type='avg_featmap'), + neck=None, + head=None, +) + +data_preprocessor = dict( + # RGB format normalization parameters + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + # convert image from BGR to RGB + to_rgb=True, +) diff --git a/configs/eva02/eva02-large-p14_in1k.py b/configs/eva02/eva02-large-p14_in1k.py new file mode 100644 index 00000000000..91a42776daf --- /dev/null +++ b/configs/eva02/eva02-large-p14_in1k.py @@ -0,0 +1,32 @@ +_base_ = [ + '../_base_/datasets/imagenet_bs16_eva_448.py', + '../_base_/schedules/imagenet_bs2048_AdamW.py', + '../_base_/default_runtime.py' +] + +model = dict( + type='ImageClassifier', + backbone=dict( + type='ViTEVA02', + arch='l', + img_size=448, + patch_size=14, + sub_ln=True, + final_norm=False, + out_type='avg_featmap'), + neck=None, + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1024, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + ), + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=.02), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.), + ], + train_cfg=dict(augments=[ + dict(type='Mixup', 
alpha=0.8), + dict(type='CutMix', alpha=1.0) + ])) diff --git a/configs/eva02/eva02-small-p14_headless.py b/configs/eva02/eva02-small-p14_headless.py new file mode 100644 index 00000000000..a969819308e --- /dev/null +++ b/configs/eva02/eva02-small-p14_headless.py @@ -0,0 +1,20 @@ +model = dict( + type='ImageClassifier', + backbone=dict( + type='ViTEVA02', + arch='s', + img_size=224, + patch_size=14, + final_norm=False, + out_type='avg_featmap'), + neck=None, + head=None, +) + +data_preprocessor = dict( + # RGB format normalization parameters + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + # convert image from BGR to RGB + to_rgb=True, +) diff --git a/configs/eva02/eva02-small-p14_in1k.py b/configs/eva02/eva02-small-p14_in1k.py new file mode 100644 index 00000000000..4a16d92456e --- /dev/null +++ b/configs/eva02/eva02-small-p14_in1k.py @@ -0,0 +1,31 @@ +_base_ = [ + '../_base_/datasets/imagenet_bs16_eva_336.py', + '../_base_/schedules/imagenet_bs2048_AdamW.py', + '../_base_/default_runtime.py' +] + +model = dict( + type='ImageClassifier', + backbone=dict( + type='ViTEVA02', + arch='s', + img_size=336, + patch_size=14, + final_norm=False, + out_type='avg_featmap'), + neck=None, + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=384, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + ), + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=.02), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.), + ], + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) + ])) diff --git a/configs/eva02/eva02-tiny-p14_headless.py b/configs/eva02/eva02-tiny-p14_headless.py new file mode 100644 index 00000000000..783d0ea2ebf --- /dev/null +++ b/configs/eva02/eva02-tiny-p14_headless.py @@ -0,0 +1,20 @@ +model = dict( + type='ImageClassifier', + backbone=dict( + type='ViTEVA02', + arch='t', + img_size=224, + patch_size=14, + final_norm=False, + out_type='avg_featmap'), + neck=None, + head=None, +) + +data_preprocessor = dict( + # RGB format normalization parameters + mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255], + std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255], + # convert image from BGR to RGB + to_rgb=True, +) diff --git a/configs/eva02/eva02-tiny-p14_in1k.py b/configs/eva02/eva02-tiny-p14_in1k.py new file mode 100644 index 00000000000..84e68d7edd9 --- /dev/null +++ b/configs/eva02/eva02-tiny-p14_in1k.py @@ -0,0 +1,31 @@ +_base_ = [ + '../_base_/datasets/imagenet_bs16_eva_336.py', + '../_base_/schedules/imagenet_bs2048_AdamW.py', + '../_base_/default_runtime.py' +] + +model = dict( + type='ImageClassifier', + backbone=dict( + type='ViTEVA02', + arch='t', + img_size=336, + patch_size=14, + final_norm=False, + out_type='avg_featmap'), + neck=None, + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=192, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + ), + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=.02), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.), + ], + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) + ])) diff --git a/configs/eva02/metafile.yml b/configs/eva02/metafile.yml new file mode 100644 index 00000000000..80acf904fb4 --- /dev/null +++ b/configs/eva02/metafile.yml @@ -0,0 +1,199 @@ +Collections: + - Name: EVA02 + Metadata: + Architecture: + - Rotary Position Embedding + - Sub 
Layer Normalization + - SwiGLU + Paper: + Title: 'EVA-02: A Visual Representation for Neon Genesis' + URL: https://arxiv.org/abs/2303.11331 + README: configs/eva02/README.md + +Models: + - Name: vit-tiny-p14_eva02-pre_in21k + Metadata: + FLOPs: 1703439360 + Parameters: 5504064 + Training Data: + - ImageNet-21k + In Collection: EVA02 + Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_pre_in21k_20230505-d703e7b1.pth + Config: configs/eva02/eva02-tiny-p14_headless.py + Converted From: + Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_Ti_pt_in21k_p14.pt + Code: https://github.com/baaivision/EVA/tree/master/EVA-02 + Downstream: + - vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px + - Name: vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px + Metadata: + FLOPs: 4675416000 + Parameters: 5758888 + Training Data: + - ImageNet-21k + - ImageNet-1k + In Collection: EVA02 + Results: + - Dataset: ImageNet-1k + Task: Image Classification + Metrics: + Top 1 Accuracy: 80.69 + Top 5 Accuracy: 95.54 + Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_in21k-pre_3rdparty_in1k-336px_20230505-a4e8708a.pth + Config: configs/eva02/eva02-tiny-p14_in1k.py + Converted From: + Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt + Code: https://github.com/baaivision/EVA/tree/master/EVA-02 + - Name: vit-small-p14_eva02-pre_in21k + Metadata: + FLOPs: 6135404544 + Parameters: 21624960 + Training Data: + - ImageNet-21k + In Collection: EVA02 + Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_pre_in21k_20230505-3175f463.pth + Config: configs/eva02/eva02-small-p14_headless.py + Converted From: + Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_S_pt_in21k_p14.pt + Code: https://github.com/baaivision/EVA/tree/master/EVA-02 + Downstream: + - vit-small-p14_eva02-in21k-pre_3rdparty_in1k-336px + - Name: vit-small-p14_eva02-in21k-pre_3rdparty_in1k-336px + Metadata: + FLOPs: 15476744064 + Parameters: 22133608 + Training Data: + - ImageNet-21k + - ImageNet-1k + In Collection: EVA02 + Results: + - Dataset: ImageNet-1k + Task: Image Classification + Metrics: + Top 1 Accuracy: 85.78 + Top 5 Accuracy: 97.60 + Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_in21k-pre_3rdparty_in1k-336px_20230505-9c5b0e85.pth + Config: configs/eva02/eva02-small-p14_in1k.py + Converted From: + Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt + Code: https://github.com/baaivision/EVA/tree/master/EVA-02 + - Name: vit-base-p14_eva02-pre_in21k + Metadata: + FLOPs: 23216492544 + Parameters: 85766400 + Training Data: + - ImageNet-21k + In Collection: EVA02 + Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_pre_in21k_20230505-2f2d4d3c.pth + Config: configs/eva02/eva02-base-p14_headless.py + Converted From: + Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_B_pt_in21k_p14.pt + Code: https://github.com/baaivision/EVA/tree/master/EVA-02 + Downstream: + - vit-base-p14_eva02-in21k-pre_3rdparty_in1k-448px + - vit-base-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px + - Name: vit-base-p14_eva02-in21k-pre_3rdparty_in1k-448px + Metadata: + FLOPs: 107105984256 + Parameters: 87126760 + Training Data: + - ImageNet-21k + - ImageNet-1k + In Collection: EVA02 + Results: + - Dataset: ImageNet-1k + Task: Image Classification + Metrics: + Top 1 Accuracy: 88.29 
+ Top 5 Accuracy: 98.53 + Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_3rdparty_in1k-448px_20230505-8ad211c5.pth + Config: configs/eva02/eva02-base-p14_in1k.py + Converted From: + Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt + Code: https://github.com/baaivision/EVA/tree/master/EVA-02 + - Name: vit-base-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px + Metadata: + FLOPs: 107105984256 + Parameters: 87126760 + Training Data: + - ImageNet-21k + - ImageNet-1k + In Collection: EVA02 + Results: + - Dataset: ImageNet-1k + Task: Image Classification + Metrics: + Top 1 Accuracy: 88.47 + Top 5 Accuracy: 98.62 + Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-5cd4d87f.pth + Config: configs/eva02/eva02-base-p14_in1k.py + Converted From: + Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt + Code: https://github.com/baaivision/EVA/tree/master/EVA-02 + - Name: vit-large-p14_eva02-pre_in21k + Metadata: + FLOPs: 81146703792 + Parameters: 303291328 + Training Data: + - ImageNet-21k + In Collection: EVA02 + Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_in21k_20230505-9072de5d.pth + Config: configs/eva02/eva02-large-p14_headless.py + Converted From: + Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_L_pt_in21k_p14.pt + Code: https://github.com/baaivision/EVA/tree/master/EVA-02 + Downstream: + - vit-large-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px + - Name: vit-large-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px + Metadata: + FLOPs: 362333836208 + Parameters: 305104808 + Training Data: + - ImageNet-21k + - ImageNet-1k + In Collection: EVA02 + Results: + - Dataset: ImageNet-1k + Task: Image Classification + Metrics: + Top 1 Accuracy: 89.65 + Top 5 Accuracy: 98.95 + Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-926d1599.pth + Config: configs/eva02/eva02-large-p14_in1k.py + Converted From: + Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt + Code: https://github.com/baaivision/EVA/tree/master/EVA-02 + - Name: vit-large-p14_eva02-pre_m38m + Metadata: + FLOPs: 81146703792 + Parameters: 303291328 + Training Data: + - Merged-38M + In Collection: EVA02 + Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_m38m_20230505-b8a1a261.pth + Config: configs/eva02/eva02-large-p14_headless.py + Converted From: + Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_L_pt_m38m_p14.pt + Code: https://github.com/baaivision/EVA/tree/master/EVA-02 + Downstream: + - vit-large-p14_eva02_m38m-pre_in21k-medft_3rdparty_in1k-448px + - Name: vit-large-p14_eva02_m38m-pre_in21k-medft_3rdparty_in1k-448px + Metadata: + FLOPs: 362333836208 + Parameters: 305104808 + Training Data: + - Merged-38M + - ImageNet-21k + - ImageNet-1k + In Collection: EVA02 + Results: + - Dataset: ImageNet-1k + Task: Image Classification + Metrics: + Top 1 Accuracy: 89.83 + Top 5 Accuracy: 99.00 + Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_m38m-pre_in21k-medft_3rdparty_in1k-448px_20230505-150dc5ed.pth + Config: configs/eva02/eva02-large-p14_in1k.py + Converted From: + Weights: 
https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt + Code: https://github.com/baaivision/EVA/tree/master/EVA-02 diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst index 4275786270a..7b6d607a450 100644 --- a/docs/en/api/models.rst +++ b/docs/en/api/models.rst @@ -189,6 +189,7 @@ Backbones VisionTransformer ViTSAM XCiT + ViTEVA02 .. module:: mmpretrain.models.necks diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py index ab77dd65590..d9830f12c0b 100644 --- a/mmpretrain/models/backbones/__init__.py +++ b/mmpretrain/models/backbones/__init__.py @@ -52,6 +52,7 @@ from .vgg import VGG from .vig import PyramidVig, Vig from .vision_transformer import VisionTransformer +from .vit_eva02 import ViTEVA02 from .vit_sam import ViTSAM from .xcit import XCiT @@ -118,4 +119,5 @@ 'PyramidVig', 'XCiT', 'ViTSAM', + 'ViTEVA02', ] diff --git a/mmpretrain/models/backbones/vit_eva02.py b/mmpretrain/models/backbones/vit_eva02.py new file mode 100644 index 00000000000..20ec4b247bb --- /dev/null +++ b/mmpretrain/models/backbones/vit_eva02.py @@ -0,0 +1,350 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn.bricks.drop import build_dropout +from mmengine.model import BaseModule, ModuleList + +from mmpretrain.registry import MODELS +from ..utils import (RotaryEmbeddingFast, SwiGLUFFN, build_norm_layer, + resize_pos_embed) +from .vision_transformer import VisionTransformer + + +class AttentionWithRoPE(BaseModule): + """Multi-head Attention Module with 2D sincos position embedding (RoPE). + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + qkv_bias (bool): If True, add a learnable bias to q and v. Note + that we follows the official implementation where ``k_bias`` + is 0. Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool) If True, add a learnable bias to output projection. + Defaults to True. + rope (:obj:`torch.nn.Module`, optional): If it is an object of the + ``RotaryEmbedding``, the rotation of the token position will be + performed before the softmax. Defaults to None. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
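+
+    Example:
+        A minimal usage sketch (``AttentionWithRoPE`` is module-internal, so
+        the import path below is illustrative and shapes are not asserted)::
+
+            import torch
+            from mmpretrain.models.utils import RotaryEmbeddingFast
+            from mmpretrain.models.backbones.vit_eva02 import AttentionWithRoPE
+
+            rope = RotaryEmbeddingFast(embed_dims=64, patch_resolution=16)
+            attn = AttentionWithRoPE(embed_dims=192, num_heads=3, rope=rope)
+            x = torch.rand(1, 1 + 16 * 16, 192)  # [CLS] + 16x16 patch tokens
+            out = attn(x, (16, 16))  # same shape as the input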
+ """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + qkv_bias=True, + qk_scale=None, + proj_bias=True, + rope=None, + with_cls_token=True, + init_cfg=None): + super(AttentionWithRoPE, self).__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.num_heads = num_heads + self.head_dims = embed_dims // num_heads + self.scale = qk_scale or self.head_dims**-0.5 + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.with_cls_token = with_cls_token + + self.rope = rope + + def forward(self, x, patch_resolution): + B, N, _ = x.shape + + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(dim=0) + + if self.rope: + if self.with_cls_token: + q_t = q[:, :, 1:, :] + ro_q_t = self.rope(q_t, patch_resolution) + q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) + + k_t = k[:, :, 1:, :] if self.with_cls_token else k + ro_k_t = self.rope(k_t, patch_resolution) + k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) + else: + q = self.rope(q, patch_resolution) + k = self.rope(k, patch_resolution) + + q = q * self.scale + + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1).type_as(x) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class EVA02EndcoderLayer(BaseModule): + """Implements one encoder EVA02EndcoderLayer in EVA02. + + Args: + embed_dims (int): The feature dimension + num_heads (int): Parallel attention heads + feedforward_channels (int): The hidden dimension of FFNs. + sub_ln (bool): Whether to add the sub layer normalization + in the attention module. Defaults to False. + attn_drop (float): Dropout rate of the dropout layer after the + attention calculation of query and key. Defaults to 0. + proj_drop (float): Dropout rate of the dropout layer after the + output projection. Defaults to 0. + qkv_bias (bool): enable bias for qkv if True. Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + proj_bias (bool): enable bias for projection in the attention module + if True. Defaults to True. + rope (:obj:`torch.nn.Module`, optional): RotaryEmbedding object + in the attention module. Defaults to None. + drop_rate (float): Dropout rate in the mlp module. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + sub_ln=False, + attn_drop=0., + proj_drop=0., + qkv_bias=False, + qk_scale=None, + proj_bias=True, + rope=None, + with_cls_token=True, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + init_cfg=None): + super(EVA02EndcoderLayer, self).__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims) + + self.attn = AttentionWithRoPE( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + proj_bias=proj_bias, + rope=rope, + with_cls_token=with_cls_token) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate)) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims) + + if drop_rate > 0: + dropout_layer = dict(type='Dropout', drop_prob=drop_rate) + else: + dropout_layer = None + + if sub_ln: + ffn_norm = norm_cfg + else: + ffn_norm = None + + self.mlp = SwiGLUFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + dropout_layer=dropout_layer, + norm_cfg=ffn_norm, + add_identity=False, + ) + + def forward(self, x, patch_resolution): + inputs = x + x = self.norm1(x) + x = self.attn(x, patch_resolution) + x = self.drop_path(x) + x = inputs + x + + inputs = x + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = inputs + x + + return x + + +@MODELS.register_module() +class ViTEVA02(VisionTransformer): + """EVA02 Vision Transformer. + + A PyTorch implement of : `EVA-02: A Visual Representation for Neon Genesis + `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'tiny', 'small', 'base', 'large'. If use dict, + it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **mlp_ratio** (float): The ratio of the mlp module. + + Defaults to 'tiny'. + + sub_ln (bool): Whether to add the sub layer normalization in swiglu. + Defaults to False. + drop_rate (float): Probability of an element to be zeroed in the + mlp module. Defaults to 0. + attn_drop_rate (float): Probability of an element to be zeroed after + the softmax in the attention. Defaults to 0. + proj_drop_rate (float): Probability of an element to be zeroed after + projection in the attention. Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + **kwargs(dict, optional): Other args for Vision Transformer. 
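+
+    Example:
+        A minimal usage sketch (the output type follows the ``out_type``
+        handling inherited from ``VisionTransformer``; the shape is not
+        asserted here)::
+
+            import torch
+            from mmpretrain.models import ViTEVA02
+
+            model = ViTEVA02(arch='tiny', img_size=224, patch_size=14)
+            outs = model(torch.rand(1, 3, 224, 224))
+            feat = outs[-1]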
+ """ + arch_zoo = { + **dict.fromkeys( + ['t', 'ti', 'tiny'], { + 'embed_dims': 192, + 'num_layers': 12, + 'num_heads': 3, + 'feedforward_channels': int(192 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['s', 'small'], { + 'embed_dims': 384, + 'num_layers': 12, + 'num_heads': 6, + 'feedforward_channels': int(384 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'num_layers': 12, + 'num_heads': 12, + 'feedforward_channels': int(768 * 4 * 2 / 3) + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': int(1024 * 4 * 2 / 3) + }) + } + num_extra_tokens = 1 # class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} + + def __init__(self, + arch='tiny', + sub_ln=False, + drop_rate=0., + attn_drop_rate=0., + proj_drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + norm_cfg=dict(type='LN'), + with_cls_token=True, + layer_cfgs=dict(), + **kwargs): + # set essential args for Vision Transformer + kwargs.update( + arch=arch, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + with_cls_token=with_cls_token) + super(ViTEVA02, self).__init__(**kwargs) + + self.num_heads = self.arch_settings['num_heads'] + + # Set RoPE + head_dim = self.embed_dims // self.num_heads + self.rope = RotaryEmbeddingFast( + embed_dims=head_dim, patch_resolution=self.patch_resolution) + + # stochastic depth decay rule + dpr = np.linspace(0, drop_path_rate, self.num_layers) + self.layers = ModuleList() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self. + arch_settings['feedforward_channels'], + sub_ln=sub_ln, + norm_cfg=norm_cfg, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_rate=drop_rate, + qkv_bias=qkv_bias, + rope=self.rope, + with_cls_token=with_cls_token, + drop_path_rate=dpr[i]) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(EVA02EndcoderLayer(**_layer_cfg)) + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + x = self.pre_norm(x) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, patch_resolution) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) diff --git a/mmpretrain/models/utils/__init__.py b/mmpretrain/models/utils/__init__.py index 904de6b7967..b7df9e415eb 100644 --- a/mmpretrain/models/utils/__init__.py +++ b/mmpretrain/models/utils/__init__.py @@ -18,7 +18,7 @@ from .make_divisible import make_divisible from .norm import GRN, LayerNorm2d, build_norm_layer from .position_encoding import (ConditionalPositionEncoding, - PositionEncodingFourier, + PositionEncodingFourier, RotaryEmbeddingFast, build_2d_sincos_position_embedding) from .res_layer_extra_norm import ResLayerExtraNorm from .se_layer import SELayer @@ -72,4 +72,5 @@ 'ResLayerExtraNorm', 'SwiGLUFFN', 'SwiGLUFFNFused', + 'RotaryEmbeddingFast', ] diff --git 
a/mmpretrain/models/utils/position_encoding.py b/mmpretrain/models/utils/position_encoding.py index a200c06629a..07a3c486a25 100644 --- a/mmpretrain/models/utils/position_encoding.py +++ b/mmpretrain/models/utils/position_encoding.py @@ -8,6 +8,8 @@ from mmengine.model import BaseModule from mmengine.utils import digit_version +from ..utils import to_2tuple + # After pytorch v1.10.0, use torch.meshgrid without indexing # will raise extra warning. For more details, # refers to https://github.com/pytorch/pytorch/issues/50276 @@ -170,3 +172,76 @@ def build_2d_sincos_position_embedding( pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1) return pos_emb + + +class RotaryEmbeddingFast(BaseModule): + """Implements 2D rotary embedding (RoPE) for image tokens. Position + encoding is implemented with sin and cos functions, + + .. math:: + Pos_{cos} = cos(\frac{t}{\theta^{\frac{2i}{d}}} \\ + Pos_{sin} = sin(\frac{t}{\theta^{\frac{2i}{d}}} + Args: + embed_dims (int): The feature dimension for each head. + patch_resolution (int | tuple): The resolution of the + image, in format (H, W). + theta (float): The hyperparameter for position coding. + Defaults to 10000. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims, + patch_resolution, + theta=10000., + init_cfg=None): + super(RotaryEmbeddingFast, self).__init__(init_cfg=init_cfg) + + self.half_dim = embed_dims // 2 + self.patch_resolution = to_2tuple(patch_resolution) + self.theta = theta + + freqs_cos, freqs_sin = self.compute_position_embedding() + self.register_buffer('freqs_cos', freqs_cos) + self.register_buffer('freqs_sin', freqs_sin) + + def compute_position_embedding(self): + frequency = self.theta**( + torch.arange(0, self.half_dim, 2).float() / self.half_dim) + frequency = 1. 
/ frequency + + h, w = self.patch_resolution + th = torch.arange(h) / h * self.half_dim + tw = torch.arange(w) / w * self.half_dim + + position_h = (th[:, None] @ frequency[None, :]).repeat(1, 2) + position_w = (tw[:, None] @ frequency[None, :]).repeat(1, 2) + + height = position_h[:, None, :].expand(h, w, self.half_dim) + width = position_w[None, :, :].expand(h, w, self.half_dim) + position = torch.cat((height, width), dim=-1) + + freqs_cos = position.cos().view(-1, position.shape[-1]) + freqs_sin = position.sin().view(-1, position.shape[-1]) + + return freqs_cos, freqs_sin + + def forward(self, x, patch_resolution): + # Check whether the patch resolution is the predefined size + patch_resolution = to_2tuple(patch_resolution) + if patch_resolution != self.patch_resolution: + self.patch_resolution = patch_resolution + freqs_cos, freqs_sin = self.compute_position_embedding() + self.register_buffer('freqs_cos', freqs_cos.to(x.device)) + self.register_buffer('freqs_sin', freqs_sin.to(x.device)) + + batch, num_heads, num_patches, dim = x.shape + + inputs = x + x = x.reshape(batch, num_heads, num_patches, -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + x = x.reshape(batch, num_heads, num_patches, dim) + + return inputs * self.freqs_cos + x * self.freqs_sin diff --git a/model-index.yml b/model-index.yml index 8df1d3d3f31..c960b360a27 100644 --- a/model-index.yml +++ b/model-index.yml @@ -69,4 +69,5 @@ Import: - configs/riformer/metafile.yml - configs/sam/metafile.yml - configs/glip/metafile.yml + - configs/eva02/metafile.yml - configs/dinov2/metafile.yml diff --git a/tests/test_models/test_backbones/test_eva02.py b/tests/test_models/test_backbones/test_eva02.py new file mode 100644 index 00000000000..0672754223c --- /dev/null +++ b/tests/test_models/test_backbones/test_eva02.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from copy import deepcopy +from unittest import TestCase + +import torch + +from mmpretrain.models.backbones import ViTEVA02 + + +class TestEVA02(TestCase): + + def setUp(self): + self.cfg = dict( + arch='t', + img_size=336, + patch_size=14, + drop_path_rate=0.1, + drop_rate=0.1, + attn_drop_rate=0.2, + proj_drop_rate=0.3, + ) + + def test_structure(self): + # Test invalid default arch + with self.assertRaisesRegex(AssertionError, 'not in default archs'): + cfg = deepcopy(self.cfg) + cfg['arch'] = 'unknown' + ViTEVA02(**cfg) + + # Test invalid custom arch + with self.assertRaisesRegex(AssertionError, 'Custom arch needs'): + cfg = deepcopy(self.cfg) + cfg['arch'] = { + 'num_layers': 24, + 'num_heads': 16, + 'feedforward_channels': int(24 * 4 * 2 / 3) + } + ViTEVA02(**cfg) + + # Test custom arch + cfg = deepcopy(self.cfg) + cfg['arch'] = { + 'embed_dims': 128, + 'num_layers': 6, + 'num_heads': 16, + 'feedforward_channels': int(128 * 4 * 2 / 3) + } + model = ViTEVA02(**cfg) + self.assertEqual(model.embed_dims, 128) + self.assertEqual(model.num_layers, 6) + for layer in model.layers: + self.assertEqual(layer.attn.num_heads, 16) + + # Test out_indices + cfg = deepcopy(self.cfg) + cfg['out_indices'] = {1: 1} + with self.assertRaisesRegex(AssertionError, "get "): + ViTEVA02(**cfg) + cfg['out_indices'] = [0, 13] + with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'): + ViTEVA02(**cfg) + + # Test model structure + cfg = deepcopy(self.cfg) + model = ViTEVA02(**cfg) + self.assertEqual(len(model.layers), 12) + self.assertEqual(model.cls_token.shape, (1, 1, 192)) + self.assertEqual(model.pos_embed.shape, (1, 577, 192)) + dpr_inc = 0.1 / (12 - 1) + dpr = 0 + for layer in model.layers: + self.assertEqual(layer.attn.embed_dims, 192) + self.assertEqual(layer.attn.num_heads, 3) + self.assertAlmostEqual(layer.drop_path.drop_prob, dpr) + self.assertAlmostEqual(layer.mlp.dropout_layer.p, 0.1) + self.assertAlmostEqual(layer.attn.attn_drop.p, 0.2) + self.assertAlmostEqual(layer.attn.proj_drop.p, 0.3) + dpr += dpr_inc + + # Test model structure: final_norm + cfg = deepcopy(self.cfg) + cfg['final_norm'] = True + model = ViTEVA02(**cfg) + self.assertNotEqual(model.norm1.__class__, torch.nn.Identity) + + def test_forward(self): + imgs = torch.randn(1, 3, 336, 336) + + # test with_cls_token=False + cfg = deepcopy(self.cfg) + cfg['with_cls_token'] = False + cfg['out_type'] = 'cls_token' + with self.assertRaisesRegex(ValueError, 'must be True'): + ViTEVA02(**cfg) + + cfg = deepcopy(self.cfg) + cfg['with_cls_token'] = False + cfg['out_type'] = 'raw' + model = ViTEVA02(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token = outs[-1] + self.assertEqual(patch_token.shape, (1, 24 * 24, 192)) + + cfg = deepcopy(self.cfg) + cfg['with_cls_token'] = False + cfg['out_type'] = 'featmap' + model = ViTEVA02(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token = outs[-1] + self.assertEqual(patch_token.shape, (1, 192, 24, 24)) + + cfg = deepcopy(self.cfg) + cfg['with_cls_token'] = False + cfg['out_type'] = 'avg_featmap' + model = ViTEVA02(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + patch_token = outs[-1] + self.assertEqual(patch_token.shape, (1, 192)) + + # test with output cls_token + cfg = deepcopy(self.cfg) + model = ViTEVA02(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + cls_token = 
outs[-1] + self.assertEqual(cls_token.shape, (1, 192)) + + # Test forward with multi out indices + cfg = deepcopy(self.cfg) + cfg['out_indices'] = [-3, -2, -1] + model = ViTEVA02(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 3) + for out in outs: + self.assertEqual(out.shape, (1, 192)) diff --git a/tests/test_models/test_utils/test_position_encoding.py b/tests/test_models/test_utils/test_position_encoding.py index 221a20df126..7d80023cba8 100644 --- a/tests/test_models/test_utils/test_position_encoding.py +++ b/tests/test_models/test_utils/test_position_encoding.py @@ -1,10 +1,21 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmpretrain.models.utils import ConditionalPositionEncoding +from mmpretrain.models.utils import (ConditionalPositionEncoding, + RotaryEmbeddingFast) def test_conditional_position_encoding_module(): CPE = ConditionalPositionEncoding(in_channels=32, embed_dims=32, stride=2) outs = CPE(torch.randn(1, 3136, 32), (56, 56)) assert outs.shape == torch.Size([1, 784, 32]) + + +def test_rotary_embedding_fast_module(): + RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=24) + outs = RoPE(torch.randn(1, 2, 24 * 24, 64), (24, 24)) + assert outs.shape == torch.Size([1, 2, 24 * 24, 64]) + + RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=(14, 20)) + outs = RoPE(torch.randn(1, 2, 14 * 20, 64), (14, 20)) + assert outs.shape == torch.Size([1, 2, 14 * 20, 64]) diff --git a/tools/model_converters/eva02_to_mmpretrain.py b/tools/model_converters/eva02_to_mmpretrain.py new file mode 100644 index 00000000000..e5a8682f0f0 --- /dev/null +++ b/tools/model_converters/eva02_to_mmpretrain.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_eva02(ckpt): + + new_ckpt = OrderedDict() + qkv_proj = {} + qkv_bias = {} + w12_weight = {} + w12_bias = {} + + banned = { + 'mask_token', + 'lm_head.weight', + 'lm_head.bias', + 'norm.weight', + 'norm.bias', + } + + for k, v in list(ckpt.items()): + + if k in banned: + continue + + if k.startswith('head'): + new_k = k.replace('head.', 'head.fc.') + new_ckpt[new_k] = v + else: + if k.startswith('patch_embed'): + new_k = k.replace('proj.', 'projection.') + + elif k.startswith('fc_norm') or k.startswith('norm'): + new_k = k.replace('norm.', 'ln2.') + new_k = k.replace('fc_norm.', 'ln2.') + + elif k.startswith('blocks'): + new_k = k.replace('blocks.', 'layers.') + + if 'mlp' in new_k: + if 'w1.' in new_k or 'w2.' in new_k: + # For base and large version, mlp is implemented with + # 2 linears, where w1 and w2 are required to integrate + # into w12. + s = new_k.split('.') # e.g. 
layers.0.mlp.w1.weight + idx = s[1] + if 'weight' in new_k: + # w1.weight or w2.weight + if idx not in w12_weight: + w12_weight[idx] = {} + w12_weight[idx][s[-2]] = v + else: + # w1.bias or w2.bias + if idx not in w12_bias: + w12_bias[idx] = {} + w12_bias[idx][s[-2]] = v + continue + + if 'ffn_ln' in new_k: + new_k = new_k.replace('ffn_ln.', 'norm.') + + elif 'attn' in new_k: + if 'q_proj.weight' in new_k or \ + 'k_proj.weight' in new_k or \ + 'v_proj.weight' in new_k: + # For base and large version, qkv projection is + # implemented with three linear layers, + s = new_k.split('.') + idx = s[1] + if idx not in qkv_proj: + qkv_proj[idx] = {} + qkv_proj[idx][s[-2]] = v + continue + + if 'q_bias' in new_k or 'v_bias' in new_k: + # k_bias is 0 + s = new_k.split('.') + idx = s[1] + if idx not in qkv_bias: + qkv_bias[idx] = {} + qkv_bias[idx][s[-1]] = v + continue + + else: + new_k = k + + new_k = 'backbone.' + new_k + new_ckpt[new_k] = v + + for idx in qkv_proj: + q_proj = qkv_proj[idx]['q_proj'] + k_proj = qkv_proj[idx]['k_proj'] + v_proj = qkv_proj[idx]['v_proj'] + weight = torch.cat((q_proj, k_proj, v_proj)) + new_k = f'backbone.layers.{idx}.attn.qkv.weight' + new_ckpt[new_k] = weight + + for idx in qkv_bias: + q_bias = qkv_bias[idx]['q_bias'] + k_bias = torch.zeros_like(q_bias) + v_bias = qkv_bias[idx]['v_bias'] + weight = torch.cat((q_bias, k_bias, v_bias)) + new_k = f'backbone.layers.{idx}.attn.qkv.bias' + new_ckpt[new_k] = weight + + for idx in w12_weight: + w1 = w12_weight[idx]['w1'] + w2 = w12_weight[idx]['w2'] + weight = torch.cat((w1, w2)) + new_k = f'backbone.layers.{idx}.mlp.w12.weight' + new_ckpt[new_k] = weight + + for idx in w12_bias: + w1 = w12_bias[idx]['w1'] + w2 = w12_bias[idx]['w2'] + weight = torch.cat((w1, w2)) + new_k = f'backbone.layers.{idx}.mlp.w12.bias' + new_ckpt[new_k] = weight + + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in pretrained eva02 ' + 'models to mmpretrain style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + + if 'module' in checkpoint: + state_dict = checkpoint['module'] + else: + state_dict = checkpoint + + weight = convert_eva02(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + print('Done!!') + + +if __name__ == '__main__': + main()
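+
+# Usage sketch (the paths below are hypothetical examples, not files shipped
+# with this PR): convert an official EVA-02 checkpoint before loading it with
+# the configs above, e.g.
+#   python tools/model_converters/eva02_to_mmpretrain.py \
+#       eva02_Ti_pt_in21k_p14.pt converted/eva02-tiny-p14_pre_in21k.pth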