[Feature] Support Flickr30k Retrieval dataset (open-mmlab#1625)
* format

* remove abs path

* init add flickr30k caption

* remove abs dir

* update blip readme

* add convert scripts

* minor

* minor
InvincibleWyq authored Jun 19, 2023
1 parent a1cfe88 commit 6d7fe91
Showing 9 changed files with 612 additions and 2 deletions.
92 changes: 92 additions & 0 deletions configs/_base_/datasets/flickr30k_caption.py
@@ -0,0 +1,92 @@
# data settings

data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='CleanCaption', keys='gt_caption'),
    dict(
        type='PackInputs',
        algorithm_keys=['gt_caption'],
        meta_keys=['image_id'],
    ),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(384, 384),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='PackInputs', meta_keys=['image_id']),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type='Flickr30kCaption',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
    drop_last=True,
)

val_dataloader = dict(
    batch_size=16,
    num_workers=5,
    dataset=dict(
        type='Flickr30kCaption',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='val',
        pipeline=test_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

# refer to tools/dataset_converters/convert_flickr30k_ann.py
val_evaluator = dict(
    type='COCOCaption',
    ann_file='data/flickr30k_val_gt.json',
)

# If you want standard test, please manually configure the test dataset
test_dataloader = dict(
    batch_size=16,
    num_workers=5,
    dataset=dict(
        type='Flickr30kCaption',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='test',
        pipeline=test_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

# refer to tools/dataset_converters/convert_flickr30k_ann.py
test_evaluator = dict(
    type='COCOCaption',
    ann_file='data/flickr30k_test_gt.json',
)
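
The evaluators above point at ground-truth files produced by the converter script mentioned in the comments, which is not shown in this diff. A minimal sketch of what such a conversion could look like, assuming the Karpathy-split layout of `dataset_flickr30k.json` (`images[*].split`, `images[*].sentences[*].raw`) and the minimal COCO-style caption ground-truth schema; key names and output paths here are illustrative, not the commit's actual `tools/dataset_converters/convert_flickr30k_ann.py`:

```python
# Hypothetical converter sketch: Karpathy-split JSON -> COCO-caption GT.
# NOT the actual script from this commit; field names follow the public
# Karpathy annotation format, and the output schema is the minimal one
# that COCO-style caption evaluation accepts.
import json


def convert_split(karpathy_json, out_json, split):
    with open(karpathy_json) as f:
        data = json.load(f)

    images, annotations = [], []
    for img in data['images']:
        if img['split'] != split:
            continue
        images.append({'id': img['imgid']})
        for sent in img['sentences']:
            annotations.append({
                'image_id': img['imgid'],
                'id': sent['sentid'],
                'caption': sent['raw'],
            })

    with open(out_json, 'w') as f:
        json.dump({'images': images, 'annotations': annotations}, f)


if __name__ == '__main__':
    ann = 'data/flickr30k/annotations/dataset_flickr30k.json'
    convert_split(ann, 'data/flickr30k_val_gt.json', 'val')
    convert_split(ann, 'data/flickr30k_test_gt.json', 'test')
```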
112 changes: 112 additions & 0 deletions configs/_base_/datasets/flickr30k_retrieval.py
@@ -0,0 +1,112 @@
# data settings
data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

rand_increasing_policies = [
    dict(type='AutoContrast'),
    dict(type='Equalize'),
    dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
    dict(
        type='Brightness', magnitude_key='magnitude',
        magnitude_range=(0, 0.0)),
    dict(type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0)),
    dict(
        type='Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        direction='horizontal'),
    dict(
        type='Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        direction='vertical'),
]

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        crop_ratio_range=(0.5, 1.0),
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies=rand_increasing_policies,
        num_policies=2,
        magnitude_level=5),
    dict(type='CleanCaption', keys='text'),
    dict(
        type='PackInputs',
        algorithm_keys=['text', 'is_matched'],
        meta_keys=['image_id']),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(384, 384),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CleanCaption', keys='text'),
    dict(
        type='PackInputs',
        algorithm_keys=['text', 'gt_text_id', 'gt_image_id'],
        meta_keys=['image_id']),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=16,
    dataset=dict(
        type='Flickr30kRetrieval',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
    drop_last=True,
)

val_dataloader = dict(
    batch_size=64,
    num_workers=16,
    dataset=dict(
        type='Flickr30kRetrieval',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='val',
        pipeline=test_pipeline,
        test_mode=True,  # This is required for evaluation
    ),
    sampler=dict(type='SequentialSampler', subsample_type='sequential'),
    persistent_workers=True,
)

val_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10))

# If you want standard test, please manually configure the test dataset
test_dataloader = dict(
    batch_size=64,
    num_workers=16,
    dataset=dict(
        type='Flickr30kRetrieval',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='test',
        pipeline=test_pipeline,
        test_mode=True,  # This is required for evaluation
    ),
    sampler=dict(type='SequentialSampler', subsample_type='sequential'),
    persistent_workers=True,
)
test_evaluator = val_evaluator
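
A quick way to sanity-check that the new dataset class and this config wire together is to build the val dataset straight from the config. A sketch, assuming mmpretrain is installed and `data/flickr30k/` is populated with the images and annotation file as the config expects:

```python
# Smoke test for the new retrieval config (assumes mmpretrain is
# installed and data/flickr30k is populated as the config expects).
from mmengine.config import Config
from mmpretrain.registry import DATASETS

cfg = Config.fromfile('configs/_base_/datasets/flickr30k_retrieval.py')
val_set = DATASETS.build(cfg.val_dataloader.dataset)
print(len(val_set))       # number of val samples
print(val_set[0].keys())  # e.g. dict_keys(['inputs', 'data_samples'])
```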
18 changes: 18 additions & 0 deletions configs/blip/README.md
@@ -52,6 +52,12 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op
| :----------------------------- | :--------: | :---: | :----: | :-----------------------------------: | :--------------------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_caption`\* | 223.97 | 14.69 | 109.12 | [config](./blip-base_8xb32_nocaps.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-caption_20230419-a5b71af3.pth) |

### Image Caption on Flickr30k

| Model | Params (M) | SPICE | CIDEr | Config | Download |
| :----------------------------- | :--------: | :---: | :---: | :----------------------------------------------: | :----------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_caption`\* | 223.97 | 15.58 | 68.89 | [config](./blip-base_8xb32_caption_flickr30k.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-caption_20230419-a5b71af3.pth) |

### Visual Grounding on RefCOCO

| Model | Params (M) | Accuracy (testA) | Accuracy (testB) | Config | Download |
@@ -88,6 +94,18 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op
| :------------------------------- | :--------: | :------: | :------: | :--------------------------------------: | :----------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_retrieval`\* | 447.49 | 64.82 | 86.28 | [config](./blip-base_8xb32_retrieval.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |

### Image-To-Text Retrieval on Flickr30k

| Model | Params (M) | Recall@1 | Recall@5 | Config | Download |
| :------------------------------- | :--------: | :------: | :------: | :------------------------------------------------: | :------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_retrieval`\* | 447.49 | 95.10# | 99.60# | [config](./blip-base_8xb32_retrieval_flickr30k.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |

### Text-To-Image Retrieval on Flickr30k

| Model | Params (M) | Recall@1 | Recall@5 | Config | Download |
| :------------------------------- | :--------: | :------: | :------: | :------------------------------------------------: | :------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_retrieval`\* | 447.49 | 85.26# | 96.58# | [config](./blip-base_8xb32_retrieval_flickr30k.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |

### NLVR on NLVR2

| Model | Params (M) | Top-1 (%) | Config | Download |
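
Following the `tools/test.py` pattern already shown earlier in this README, the new Flickr30k configs can be evaluated directly with the converted COCO-pretrained checkpoints from the tables above, for example:

```shell
python tools/test.py configs/blip/blip-base_8xb32_retrieval_flickr30k.py https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth
```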
59 changes: 59 additions & 0 deletions configs/blip/blip-base_8xb32_caption_flickr30k.py
@@ -0,0 +1,59 @@
_base_ = [
    '../_base_/datasets/flickr30k_caption.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='BlipCaption',
    vision_encoder=dict(
        type='VisionTransformer',
        arch='b',
        img_size=384,
        patch_size=16,
        out_type='raw',
    ),
    tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
    decoder_head=dict(
        type='SeqGenerationHead',
        decoder=dict(
            type='XBertLMHeadDecoder',
            med_config=dict(
                architectures=['BertModel'],
                attention_probs_dropout_prob=0.1,
                hidden_act='gelu',
                hidden_dropout_prob=0.1,
                hidden_size=768,
                initializer_range=0.02,
                intermediate_size=3072,
                layer_norm_eps=1e-12,
                max_position_embeddings=512,
                model_type='bert',
                num_attention_heads=12,
                num_hidden_layers=12,
                pad_token_id=0,
                add_type_embeddings=False,
                vocab_size=30524,
                encoder_width=768,
                add_cross_attention=True),
        ),
    ),
    prompt='a picture of ',
    max_txt_len=20,
)

# schedule settings
optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05))

param_scheduler = [
    dict(
        type='CosineAnnealingLR',
        by_epoch=True,
        begin=0,
        end=10,
    )
]

train_cfg = dict(max_epochs=10)
val_cfg = dict()
test_cfg = dict()
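
After fine-tuning with this config (or starting from the converted checkpoint listed in the README), the caption model can be exercised through mmpretrain's inferencer API. A sketch; the image path and checkpoint path below are placeholders:

```python
# Sketch: caption a single image with this config. 'demo.jpg' and the
# checkpoint path are placeholders, not files added by this commit.
from mmpretrain import ImageCaptionInferencer

inferencer = ImageCaptionInferencer(
    model='configs/blip/blip-base_8xb32_caption_flickr30k.py',
    pretrained='work_dirs/caption_flickr30k/epoch_10.pth')
print(inferencer('demo.jpg')[0]['pred_caption'])
```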
83 changes: 83 additions & 0 deletions configs/blip/blip-base_8xb32_retrieval_flickr30k.py
@@ -0,0 +1,83 @@
_base_ = [
    '../_base_/datasets/flickr30k_retrieval.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='BlipRetrieval',
    tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
    vision_backbone=dict(
        type='VisionTransformer',
        arch='b',
        img_size=384,
        patch_size=16,
        out_type='raw',
    ),
    text_backbone=dict(
        type='XBertEncoder',
        med_config=dict(
            architectures=['BertModel'],
            attention_probs_dropout_prob=0.1,
            hidden_act='gelu',
            hidden_dropout_prob=0.1,
            hidden_size=768,
            initializer_range=0.02,
            intermediate_size=3072,
            layer_norm_eps=1e-12,
            max_position_embeddings=512,
            model_type='bert',
            num_attention_heads=12,
            num_hidden_layers=12,
            pad_token_id=0,
            add_type_embeddings=False,
            vocab_size=30524,
            encoder_width=768,
            add_cross_attention=True),
    ),
    vision_neck=dict(
        type='Linear',
        in_features=768,
        out_features=256,
    ),
    text_neck=dict(
        type='Linear',
        in_features=768,
        out_features=256,
    ),
    head=dict(
        type='ITCHead',
        embed_dim=256,
    ),
    multimodal_head=dict(
        type='ITMHead',
        hidden_size=768,
        with_pooler=False,
    ),
    topk=256,
    max_txt_len=35,
)

# optimizer
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=0.04)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)

# learning rate scheduler
param_scheduler = [dict(type='CosineAnnealingLR', by_epoch=True)]

# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=6)
val_cfg = dict(type='RetrievalValLoop')
test_cfg = dict(type='RetrievalTestLoop')

randomness = dict(seed=42)

default_hooks = dict(logger=dict(interval=1))

custom_hooks = [
    dict(
        type='WarmupParamHook',
        param_name='alpha',
        module_name='head',
        warmup_epochs=2)
]
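
The `WarmupParamHook` above ramps up the head's `alpha` parameter over the first two epochs. Purely to illustrate the shape of such a hook, here is a hypothetical re-sketch on the mmengine `Hook` API; it is not mmpretrain's actual implementation:

```python
# Hypothetical parameter-warmup hook built on the mmengine Hook API;
# mmpretrain's real WarmupParamHook may differ in details.
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper


class MyWarmupParamHook(Hook):
    """Linearly scale `<module_name>.<param_name>` from 0 to its
    configured value over the first `warmup_epochs` epochs."""

    def __init__(self, param_name, module_name, warmup_epochs):
        self.param_name = param_name
        self.module_name = module_name
        self.warmup_epochs = warmup_epochs
        self._final_value = None

    def _get_module(self, runner):
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        return getattr(model, self.module_name)

    def before_train(self, runner):
        # Remember the configured (final) value before training starts.
        self._final_value = getattr(self._get_module(runner), self.param_name)

    def before_train_iter(self, runner, batch_idx, data_batch=None):
        warmup_iters = self.warmup_epochs * len(runner.train_dataloader)
        factor = min(runner.iter / max(warmup_iters, 1), 1.0)
        setattr(self._get_module(runner), self.param_name,
                self._final_value * factor)
```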