forked from open-mmlab/mmpretrain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support SparK. (open-mmlab#1531)
* add spark configs * fix configs * remove repeat aug * add module codes * support lr layer decay of resnet * update * fix lint * add metafile and readme * fix lint * add models and logs * refactor codes * fix lint * update model rst * update name * add docstring * add ut * fix lint --------- Co-authored-by: Ma Zerun <[email protected]>
- Loading branch information
1 parent
bfd49b0
commit a1cfe88
Showing
27 changed files
with
1,964 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# SparK | ||
|
||
> [Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling](https://arxiv.org/abs/2301.03580) | ||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
We identify and overcome two key obstacles in extending the success of BERT-style pre-training, or the masked image modeling, to convolutional networks (convnets): (i) convolution operation cannot handle irregular, random-masked input images; (ii) the single-scale nature of BERT pre-training is inconsistent with convnet's hierarchical structure. For (i), we treat unmasked pixels as sparse voxels of 3D point clouds and use sparse convolution to encode. This is the first use of sparse convolution for 2D masked modeling. For (ii), we develop a hierarchical decoder to reconstruct images from multi-scale encoded features. Our method called Sparse masKed modeling (SparK) is general: it can be used directly on any convolutional model without backbone modifications. We validate it on both classical (ResNet) and modern (ConvNeXt) models: on three downstream tasks, it surpasses both state-of-the-art contrastive learning and transformer-based masked modeling by similarly large margins (around +1.0%). Improvements on object detection and instance segmentation are more substantial (up to +3.5%), verifying the strong transferability of features learned. We also find its favorable scaling behavior by observing more gains on larger models. All this evidence reveals a promising future of generative pre-training on convnets. Codes and models are released at https://github.com/keyu-tian/SparK. | ||
|
||
<div align=center> | ||
<img src="https://github.com/open-mmlab/mmpretrain/assets/36138628/b93e8d6f-ec1e-4f27-b986-da470fabe7df" width="80%"/> | ||
</div> | ||
|
||
## How to use it? | ||
|
||
<!-- [TABS-BEGIN] --> | ||
|
||
**Predict image** | ||
|
||
```python | ||
from mmpretrain import inference_model | ||
|
||
predict = inference_model('resnet50_spark-pre_300e_in1k', 'demo/bird.JPEG') | ||
print(predict['pred_class']) | ||
print(predict['pred_score']) | ||
``` | ||
|
||
**Use the model** | ||
|
||
```python | ||
import torch | ||
from mmpretrain import get_model | ||
|
||
model = get_model('spark_sparse-resnet50_800e_in1k', pretrained=True) | ||
inputs = torch.rand(1, 3, 224, 224) | ||
out = model(inputs) | ||
print(type(out)) | ||
# To extract features. | ||
feats = model.extract_feat(inputs) | ||
print(type(feats)) | ||
``` | ||
|
||
**Train/Test Command** | ||
|
||
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset). | ||
|
||
Train: | ||
|
||
```shell | ||
python tools/train.py configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py | ||
``` | ||
|
||
Test: | ||
|
||
```shell | ||
python tools/test.py configs/spark/benchmarks/resnet50_8xb256-coslr-300e_in1k.py https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.pth | ||
``` | ||
|
||
<!-- [TABS-END] --> | ||
|
||
## Models and results | ||
|
||
### Pretrained models | ||
|
||
| Model | Params (M) | Flops (G) | Config | Download | | ||
| :--------------------------------------- | :--------: | :-------: | :-------------------------------------------------------------------: | :----------------------------------------------------------------------: | | ||
| `spark_sparse-resnet50_800e_in1k` | 37.97 | 4.10 | [config](spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.json) | | ||
| `spark_sparse-convnextv2-tiny_800e_in1k` | 39.73 | 4.47 | [config](spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.json) | | ||
|
||
### Image Classification on ImageNet-1k | ||
|
||
| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download | | ||
| :------------------------------------ | :----------------------------------------: | :--------: | :-------: | :-------: | :-------: | :---------------------------------------: | :-----------------------------------------: | | ||
| `resnet50_spark-pre_300e_in1k` | [SPARK](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.pth) | 23.52 | 1.31 | 80.10 | 94.90 | [config](benchmarks/resnet50_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.json) | | ||
| `convnextv2-tiny_spark-pre_300e_in1k` | [SPARK](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.pth) | 28.64 | 4.47 | 82.80 | 96.30 | [config](benchmarks/convnextv2-tiny_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark//spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k_20230612-ffc78743.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark//spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k_20230612-ffc78743.json) | | ||
|
||
## Citation | ||
|
||
```bibtex | ||
@Article{tian2023designing, | ||
author = {Keyu Tian and Yi Jiang and Qishuai Diao and Chen Lin and Liwei Wang and Zehuan Yuan}, | ||
title = {Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling}, | ||
journal = {arXiv:2301.03580}, | ||
year = {2023}, | ||
} | ||
``` |
122 changes: 122 additions & 0 deletions
122
configs/spark/benchmarks/convnextv2-tiny_8xb256-coslr-300e_in1k.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
_base_ = [ | ||
'../../_base_/datasets/imagenet_bs64_swin_224.py', | ||
'../../_base_/default_runtime.py', | ||
] | ||
|
||
data_preprocessor = dict( | ||
num_classes=1000, | ||
# RGB format normalization parameters | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
# convert image from BGR to RGB | ||
to_rgb=True, | ||
) | ||
|
||
bgr_mean = data_preprocessor['mean'][::-1] | ||
bgr_std = data_preprocessor['std'][::-1] | ||
|
||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='RandomResizedCrop', | ||
scale=224, | ||
backend='pillow', | ||
interpolation='bicubic'), | ||
dict(type='RandomFlip', prob=0.5, direction='horizontal'), | ||
dict(type='NumpyToPIL', to_rgb=True), | ||
dict( | ||
type='torchvision/TrivialAugmentWide', | ||
num_magnitude_bins=31, | ||
interpolation='bicubic', | ||
fill=None), | ||
dict(type='PILToNumpy', to_bgr=True), | ||
dict( | ||
type='RandomErasing', | ||
erase_prob=0.25, | ||
mode='rand', | ||
min_area_ratio=0.02, | ||
max_area_ratio=1 / 3, | ||
fill_color=bgr_mean, | ||
fill_std=bgr_std), | ||
dict(type='PackInputs'), | ||
] | ||
|
||
train_dataloader = dict( | ||
dataset=dict(pipeline=train_pipeline), | ||
sampler=dict(type='RepeatAugSampler', shuffle=True), | ||
) | ||
|
||
# Model settings | ||
model = dict( | ||
type='ImageClassifier', | ||
backbone=dict( | ||
type='ConvNeXt', | ||
arch='tiny', | ||
drop_path_rate=0.1, | ||
layer_scale_init_value=0., | ||
use_grn=True, | ||
), | ||
head=dict( | ||
type='LinearClsHead', | ||
num_classes=1000, | ||
in_channels=768, | ||
loss=dict( | ||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), | ||
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02, bias=0.), | ||
), | ||
train_cfg=dict(augments=[ | ||
dict(type='Mixup', alpha=0.8), | ||
dict(type='CutMix', alpha=1.0), | ||
]), | ||
) | ||
|
||
custom_hooks = [ | ||
dict( | ||
type='EMAHook', | ||
momentum=1e-4, | ||
evaluate_on_origin=True, | ||
priority='ABOVE_NORMAL') | ||
] | ||
|
||
# schedule settings | ||
# optimizer | ||
optim_wrapper = dict( | ||
optimizer=dict( | ||
type='AdamW', lr=3.2e-3, betas=(0.9, 0.999), weight_decay=0.05), | ||
constructor='LearningRateDecayOptimWrapperConstructor', | ||
paramwise_cfg=dict( | ||
layer_decay_rate=0.7, | ||
norm_decay_mult=0.0, | ||
bias_decay_mult=0.0, | ||
flat_decay_mult=0.0)) | ||
|
||
# learning policy | ||
param_scheduler = [ | ||
# warm up learning rate scheduler | ||
dict( | ||
type='LinearLR', | ||
start_factor=0.0001, | ||
by_epoch=True, | ||
begin=0, | ||
end=20, | ||
convert_to_iter_based=True), | ||
# main learning rate scheduler | ||
dict( | ||
type='CosineAnnealingLR', | ||
T_max=280, | ||
eta_min=1.0e-5, | ||
by_epoch=True, | ||
begin=20, | ||
end=300) | ||
] | ||
train_cfg = dict(by_epoch=True, max_epochs=300) | ||
val_cfg = dict() | ||
test_cfg = dict() | ||
|
||
default_hooks = dict( | ||
# only keeps the latest 2 checkpoints | ||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2)) | ||
|
||
# NOTE: `auto_scale_lr` is for automatically scaling LR, | ||
# based on the actual training batch size. | ||
auto_scale_lr = dict(base_batch_size=2048) |
107 changes: 107 additions & 0 deletions
107
configs/spark/benchmarks/resnet50_8xb256-coslr-300e_in1k.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
_base_ = [ | ||
'../../_base_/models/resnet50.py', | ||
'../../_base_/datasets/imagenet_bs256_rsb_a12.py', | ||
'../../_base_/default_runtime.py' | ||
] | ||
# modification is based on ResNets RSB settings | ||
data_preprocessor = dict( | ||
num_classes=1000, | ||
# RGB format normalization parameters | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
# convert image from BGR to RGB | ||
to_rgb=True, | ||
) | ||
|
||
bgr_mean = data_preprocessor['mean'][::-1] | ||
bgr_std = data_preprocessor['std'][::-1] | ||
|
||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='RandomResizedCrop', | ||
scale=224, | ||
backend='pillow', | ||
interpolation='bicubic'), | ||
dict(type='RandomFlip', prob=0.5, direction='horizontal'), | ||
dict(type='NumpyToPIL', to_rgb=True), | ||
dict( | ||
type='torchvision/TrivialAugmentWide', | ||
num_magnitude_bins=31, | ||
interpolation='bicubic', | ||
fill=None), | ||
dict(type='PILToNumpy', to_bgr=True), | ||
dict( | ||
type='RandomErasing', | ||
erase_prob=0.25, | ||
mode='rand', | ||
min_area_ratio=0.02, | ||
max_area_ratio=1 / 3, | ||
fill_color=bgr_mean, | ||
fill_std=bgr_std), | ||
dict(type='PackInputs'), | ||
] | ||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) | ||
|
||
# model settings | ||
model = dict( | ||
backbone=dict( | ||
norm_cfg=dict(type='SyncBN', requires_grad=True), | ||
drop_path_rate=0.05, | ||
), | ||
head=dict( | ||
loss=dict( | ||
type='LabelSmoothLoss', label_smooth_val=0.1, use_sigmoid=True)), | ||
train_cfg=dict(augments=[ | ||
dict(type='Mixup', alpha=0.1), | ||
dict(type='CutMix', alpha=1.0) | ||
])) | ||
|
||
# schedule settings | ||
# optimizer | ||
optim_wrapper = dict( | ||
optimizer=dict( | ||
type='Lamb', | ||
lr=0.016, | ||
weight_decay=0.02, | ||
), | ||
constructor='LearningRateDecayOptimWrapperConstructor', | ||
paramwise_cfg=dict( | ||
layer_decay_rate=0.7, | ||
norm_decay_mult=0.0, | ||
bias_decay_mult=0.0, | ||
flat_decay_mult=0.0)) | ||
|
||
# learning policy | ||
param_scheduler = [ | ||
# warm up learning rate scheduler | ||
dict( | ||
type='LinearLR', | ||
start_factor=0.0001, | ||
by_epoch=True, | ||
begin=0, | ||
end=5, | ||
# update by iter | ||
convert_to_iter_based=True), | ||
# main learning rate scheduler | ||
dict( | ||
type='CosineAnnealingLR', | ||
T_max=295, | ||
eta_min=1.0e-6, | ||
by_epoch=True, | ||
begin=5, | ||
end=300) | ||
] | ||
train_cfg = dict(by_epoch=True, max_epochs=300) | ||
val_cfg = dict() | ||
test_cfg = dict() | ||
|
||
default_hooks = dict( | ||
# only keeps the latest 2 checkpoints | ||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2)) | ||
# randomness | ||
randomness = dict(seed=0, diff_rank_seed=True) | ||
|
||
# NOTE: `auto_scale_lr` is for automatically scaling LR, | ||
# based on the actual training batch size. | ||
auto_scale_lr = dict(base_batch_size=2048) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
Collections: | ||
- Name: SparK | ||
Metadata: | ||
Architecture: | ||
- Dense Connections | ||
- GELU | ||
- Layer Normalization | ||
- Multi-Head Attention | ||
- Scaled Dot-Product Attention | ||
Paper: | ||
Title: 'Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling' | ||
URL: https://arxiv.org/abs/2301.03580 | ||
README: configs/spark/README.md | ||
Code: | ||
URL: null | ||
Version: null | ||
|
||
Models: | ||
- Name: spark_sparse-resnet50_800e_in1k | ||
Metadata: | ||
FLOPs: 4100000000 | ||
Parameters: 37971000 | ||
Training Data: | ||
- ImageNet-1k | ||
In Collection: SparK | ||
Results: null | ||
Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.pth | ||
Config: configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py | ||
Downstream: | ||
- resnet50_spark-pre_300e_in1k | ||
- Name: resnet50_spark-pre_300e_in1k | ||
Metadata: | ||
FLOPs: 1310000000 | ||
Parameters: 23520000 | ||
Training Data: | ||
- ImageNet-1k | ||
In Collection: SparK | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 80.1 | ||
Top 5 Accuracy: 94.9 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.pth | ||
Config: configs/spark/benchmarks/resnet50_8xb256-coslr-300e_in1k.py | ||
|
||
- Name: spark_sparse-convnextv2-tiny_800e_in1k | ||
Metadata: | ||
FLOPs: 4470000000 | ||
Parameters: 39732000 | ||
Training Data: | ||
- ImageNet-1k | ||
In Collection: SparK | ||
Results: null | ||
Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.pth | ||
Config: configs/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k.py | ||
Downstream: | ||
- convnextv2-tiny_spark-pre_300e_in1k | ||
- Name: convnextv2-tiny_spark-pre_300e_in1k | ||
Metadata: | ||
FLOPs: 4469631744 | ||
Parameters: 28635496 | ||
Training Data: | ||
- ImageNet-1k | ||
In Collection: SparK | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 82.8 | ||
Top 5 Accuracy: 96.3 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark//spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k_20230612-ffc78743.pth | ||
Config: configs/spark/benchmarks/convnextv2-tiny_8xb256-coslr-300e_in1k.py |
Oops, something went wrong.