[Feature] Add InternImage Classification project (open-mmlab#1569)
* [Feature] add internimage project

* [Feature] add internimage project

* update license

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* update license

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* [Feature] add internimage project

* update internimage configs

* support internimage project

* support internimage project

* support internimage project

* internimage
zzc98 authored Jun 13, 2023
1 parent 8e9e880 commit 3eaf719
Showing 26 changed files with 3,532 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mmpretrain/models/classifiers/image.py
@@ -78,6 +78,12 @@ def __init__(self,
        self.neck = neck
        self.head = head

        # If the model needs to load pretrained weights from a third party,
        # the checkpoint keys can be remapped with this hook.
        if hasattr(self.backbone, '_checkpoint_filter'):
            self._register_load_state_dict_pre_hook(
                self.backbone._checkpoint_filter)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
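
For reference, a hook passed to `_register_load_state_dict_pre_hook` receives the standard pre-hook arguments and mutates `state_dict` in place before key matching. A minimal sketch of such a backbone-side `_checkpoint_filter` is shown below; the `model.` prefix is purely illustrative and is not the mapping used by InternImage.

```python
# Sketch only: a backbone method compatible with the hook registration above.
# The 'model.' -> 'backbone.' remapping is a made-up example of a third-party
# checkpoint layout, not the actual InternImage key mapping.
def _checkpoint_filter(self, state_dict, prefix, local_metadata, strict,
                       missing_keys, unexpected_keys, error_msgs):
    for old_key in list(state_dict):
        if old_key.startswith('model.'):
            new_key = 'backbone.' + old_key[len('model.'):]
            state_dict[new_key] = state_dict.pop(old_key)
```
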
121 changes: 121 additions & 0 deletions projects/internimage_classification/README.md
@@ -0,0 +1,121 @@
# InternImage Classification

## Description

This is the implementation of [InternImage](https://arxiv.org/abs/2211.05778) for image classification.

## Usage

### Setup Environment

Please refer to the [Get Started](https://mmpretrain.readthedocs.io/en/latest/get_started.html) documentation of MMPretrain to finish the installation.

Please install DCNv3 by running the commands below, following the [InternImage official installation instructions](https://github.com/OpenGVLab/InternImage/blob/master/classification/README.md).

```shell
cd ops_dcnv3
sh ./make.sh
```
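
A quick way to confirm that the extension built correctly is to check that it is importable from Python. This is only a sketch and assumes the compiled op is exposed as a module named `DCNv3`, as in the official setup script; adjust the name if your build differs.

```python
# Minimal post-build check (assumption: the extension module is named 'DCNv3').
import importlib.util

found = importlib.util.find_spec('DCNv3') is not None
print('DCNv3 extension importable:', found)
```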

### Training and Test Commands

First, you need to add the current folder to `PYTHONPATH` so that Python can find your model files. In the `projects/internimage_classification/` root directory, run the command below to add it.

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```

#### Training

##### On Local Single GPU

```bash
# train with mim
mim train mmpretrain ${CONFIG} --work-dir ${WORK_DIR}

# a specific command example
mim train mmpretrain configs/internimage-tiny_8xb128_in1k-224.py \
    --work-dir work_dirs/internimage-tiny_8xb128_in1k-224/
```

##### On Multiple GPUs

```bash
# train with mim
mim train mmpretrain ${CONFIG} \
    --work-dir ${WORK_DIR} \
    --launcher pytorch --gpus 8
```

##### On Multiple GPUs with Slurm

```bash
# train with mim
mim train mmpretrain ${CONFIG} \
    --work-dir ${WORK_DIR} \
    --launcher slurm --gpus 16 --gpus-per-node 8 \
    --partition ${PARTITION}
```
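
`mim train` dispatches to the standard MMEngine runner, so a roughly equivalent Python entry point can be handy for debugging. This is a sketch, not part of the project; it assumes you run it from `projects/internimage_classification/` with `PYTHONPATH` set as above so that the `models` package can be imported.

```python
# Python-level sketch of what `mim train` does for this project (assumptions:
# run from the project root, PYTHONPATH includes it, mmpretrain is installed).
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/internimage-tiny_8xb128_in1k-224.py')
cfg.work_dir = 'work_dirs/internimage-tiny_8xb128_in1k-224'

runner = Runner.from_cfg(cfg)
runner.train()
```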

#### Test

Please download the pretrained weights provided by [OpenGVLab](https://github.com/OpenGVLab/) from [here](https://huggingface.co/OpenGVLab/InternImage/tree/main).
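
If you prefer scripting the download, a small helper using `huggingface_hub` works as well. This is optional and assumes the `huggingface_hub` package is installed; it is not a dependency of this project. The repository id and file name below come from the links in the results table.

```python
# Optional download helper (assumption: `pip install huggingface_hub` was run).
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id='OpenGVLab/InternImage',
    filename='internimage_t_1k_224.pth')
print(ckpt_path)  # local path to pass to `mim test ... -C <path>`
```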

##### On Local Single GPU

```bash
# test with mim
mim test mmpretrain ${CONFIG} -C ${CHECKPOINT}

# a specific command example
mim test mmpretrain configs/internimage-tiny_8xb128_in1k-224.py -C /PATH/TO/internimage_t_1k_224.pth
```

##### On Multiple GPUs

```bash
# test with mim
# a specific command example with 8 GPUs
mim test mmpretrain configs/internimage-tiny_8xb128_in1k-224.py \
    -C /PATH/TO/internimage_t_1k_224.pth \
    --launcher pytorch --gpus 8
```

##### On Multiple GPUs with Slurm

```bash
# test with mim
mim test mmpretrain ${CONFIG} \
    -C ${CHECKPOINT} \
    --work-dir ${WORK_DIR} \
    --launcher slurm --gpus 8 --gpus-per-node 8 \
    --partition ${PARTITION} \
    $PY_ARGS
```

Note: `PY_ARGS` denotes other optional arguments.

## Results on ImageNet1K

The accuracy of different models on ImageNet1K:

| name | resolution | acc@1 | acc@5 | config | weight |
| :------------: | :--------: | :-----: | :-----: | :-------------------------------------------------------: | :-----------------------------------------------------------------------------------------------: |
| InternImage-T | 224 | 83.4700 | 96.5340 | [config](./configs/internimage-tiny_8xb128_in1k-224.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_t_1k_224.pth) |
| InternImage-S | 224 | 84.1640 | 96.9320 | [config](./configs/internimage-small_8xb128_in1k-224.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_s_1k_224.pth) |
| InternImage-B | 224 | 84.8660 | 97.1820 | [config](./configs/internimage-base_8xb128_in1k-224.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_b_1k_224.pth) |
| InternImage-L | 384 | 87.7060 | 98.3820 | [config](./configs/internimage-large_8xb128_in1k-384.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_l_22kto1k_384.pth) |
| InternImage-XL | 384        | 88.0460 | 98.5620 | [config](./configs/internimage-xlarge_8xb128_in1k-384.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_xl_22kto1k_384.pth)  |
| InternImage-H | 640 | 89.5500 | 98.8500 | [config](./configs/internimage-huge_8xb128_in1k-640.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_h_22kto1k_640.pth) |
| InternImage-G | 512 | 90.0580 | 98.9700 | [config](./configs/internimage-giant_8xb128_in1k-512.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_g_22kto1k_512.pth) |

## Citation

```bibtex
@article{wang2022internimage,
title={InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions},
author={Wang, Wenhai and Dai, Jifeng and Chen, Zhe and Huang, Zhenhang and Li, Zhiqi and Zhu, Xizhou and Hu, Xiaowei and Lu, Tong and Lu, Lewei and Li, Hongsheng and others},
journal={arXiv preprint arXiv:2211.05778},
year={2022}
}
```
113 changes: 113 additions & 0 deletions projects/internimage_classification/configs/_base_.py
@@ -0,0 +1,113 @@
_base_ = 'mmpretrain::_base_/default_runtime.py'

# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
    num_classes=1000,
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=224,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs'),
]

train_dataloader = dict(
    batch_size=128,
    num_workers=8,
    dataset=dict(
        type=dataset_type,
        data_root='../../data/imagenet',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
    batch_size=128,
    num_workers=8,
    dataset=dict(
        type=dataset_type,
        data_root='../../data/imagenet',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

test_dataloader = val_dataloader
test_evaluator = val_evaluator

# model setting
custom_imports = dict(imports='models')

model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='InternImage',
        stem_channels=64,
        drop_path_rate=0.1,
        stage_blocks=[4, 4, 18, 4],
        groups=[4, 8, 16, 32]),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))

# optimizer
optim_wrapper = dict(
    optimizer=dict(
        type='AdamW',
        lr=1.25e-04,
        eps=1e-8,
        betas=(0.9, 0.999),
        weight_decay=0.05))

# learning policy
param_scheduler = [
    # warm up learning rate scheduler
    dict(
        type='LinearLR',
        by_epoch=True,
        begin=0,
        end=20,
        convert_to_iter_based=True),
    # main learning rate scheduler
    dict(
        type='CosineAnnealingLR',
        T_max=280,
        by_epoch=True,
        begin=20,
        end=300,
        eta_min=1.25e-06)
]

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=128 * 8)
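
The note above refers to MMEngine's linear learning-rate scaling: when automatic scaling is enabled (it is off unless requested, e.g. via the runner's auto-scale-lr option), the configured LR is multiplied by the ratio of the actual total batch size to `base_batch_size`. A sketch of the rule, not MMEngine source:

```python
# Linear LR scaling rule assumed by auto_scale_lr (illustrative numbers).
base_lr = 1.25e-4            # optimizer lr in this config
base_batch_size = 128 * 8    # auto_scale_lr.base_batch_size above
actual_batch_size = 128 * 4  # e.g. training on 4 GPUs instead of 8

scaled_lr = base_lr * actual_batch_size / base_batch_size
print(scaled_lr)             # 6.25e-05
```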
@@ -0,0 +1,13 @@
_base_ = './_base_.py'

model = dict(
    backbone=dict(
        stem_channels=112,
        drop_path_rate=0.5,
        stage_blocks=[4, 4, 21, 4],
        groups=[7, 14, 28, 56],
        layer_scale=1e-5,
        post_norm=True),
    head=dict(in_channels=1344))

optim_wrapper = dict(optimizer=dict(lr=0.0005))
@@ -0,0 +1,55 @@
_base_ = './_base_.py'

model = dict(
    backbone=dict(
        stem_channels=512,
        drop_path_rate=0.4,
        stage_blocks=[2, 2, 48, 4],
        groups=[16, 32, 64, 128],
        dw_kernel_size=5,
        level2_post_norm=True,
        level2_post_norm_block_ids=[5, 11, 17, 23, 29, 35, 41, 47],
        center_feature_scale=True,
        use_clip_projector=True,
    ),
    neck=None,
    head=dict(in_channels=768))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=512,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=512,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=512),
    dict(type='PackInputs'),
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

optim_wrapper = dict(optimizer=dict(lr=5e-6))
param_scheduler = [
    dict(
        type='LinearLR',
        by_epoch=True,
        begin=0,
        end=2,
        convert_to_iter_based=True),
    dict(type='CosineAnnealingLR', T_max=18, by_epoch=True, begin=2, end=20)
]
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
@@ -0,0 +1,55 @@
_base_ = './_base_.py'

model = dict(
    backbone=dict(
        stem_channels=320,
        drop_path_rate=0.1,
        stage_blocks=[6, 6, 32, 6],
        groups=[10, 20, 40, 80],
        dw_kernel_size=5,
        res_post_norm=True,
        level2_post_norm=True,
        level2_post_norm_block_ids=[5, 11, 17, 23, 29],
        center_feature_scale=True,
        use_clip_projector=True,
    ),
    neck=None,
    head=dict(in_channels=768))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=640,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=640,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=640),
    dict(type='PackInputs')
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

optim_wrapper = dict(optimizer=dict(lr=5e-6))
param_scheduler = [
    dict(
        type='LinearLR',
        by_epoch=True,
        begin=0,
        end=2,
        convert_to_iter_based=True),
    dict(type='CosineAnnealingLR', T_max=18, by_epoch=True, begin=2, end=20)
]
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)