forked from open-mmlab/mmpretrain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support Flickr30k Retrieval dataset (open-mmlab#1625)
* format * remove abs path * init add flickr30k caption * remove abs dir * update blip readme * add convert sscripts * minor * minor
- Loading branch information
1 parent
a1cfe88
commit 6d7fe91
Showing
9 changed files
with
612 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# data settings | ||
|
||
data_preprocessor = dict( | ||
type='MultiModalDataPreprocessor', | ||
mean=[122.770938, 116.7460125, 104.09373615], | ||
std=[68.5005327, 66.6321579, 70.32316305], | ||
to_rgb=True, | ||
) | ||
|
||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='RandomResizedCrop', | ||
scale=384, | ||
interpolation='bicubic', | ||
backend='pillow'), | ||
dict(type='RandomFlip', prob=0.5, direction='horizontal'), | ||
dict(type='CleanCaption', keys='gt_caption'), | ||
dict( | ||
type='PackInputs', | ||
algorithm_keys=['gt_caption'], | ||
meta_keys=['image_id'], | ||
), | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='Resize', | ||
scale=(384, 384), | ||
interpolation='bicubic', | ||
backend='pillow'), | ||
dict(type='PackInputs', meta_keys=['image_id']), | ||
] | ||
|
||
train_dataloader = dict( | ||
batch_size=32, | ||
num_workers=5, | ||
dataset=dict( | ||
type='Flickr30kCaption', | ||
data_root='data/flickr30k', | ||
ann_file='annotations/dataset_flickr30k.json', | ||
data_prefix='images', | ||
split='train', | ||
pipeline=train_pipeline), | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
persistent_workers=True, | ||
drop_last=True, | ||
) | ||
|
||
val_dataloader = dict( | ||
batch_size=16, | ||
num_workers=5, | ||
dataset=dict( | ||
type='Flickr30kCaption', | ||
data_root='data/flickr30k', | ||
ann_file='annotations/dataset_flickr30k.json', | ||
data_prefix='images', | ||
split='val', | ||
pipeline=test_pipeline, | ||
), | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
persistent_workers=True, | ||
) | ||
|
||
# refer tools/dataset_converters/convert_flickr30k_ann.py | ||
val_evaluator = dict( | ||
type='COCOCaption', | ||
ann_file='data/flickr30k_val_gt.json', | ||
) | ||
|
||
# # If you want standard test, please manually configure the test dataset | ||
test_dataloader = dict( | ||
batch_size=16, | ||
num_workers=5, | ||
dataset=dict( | ||
type='Flickr30kCaption', | ||
data_root='data/flickr30k', | ||
ann_file='annotations/dataset_flickr30k.json', | ||
data_prefix='images', | ||
split='test', | ||
pipeline=test_pipeline, | ||
), | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
persistent_workers=True, | ||
) | ||
|
||
# refer tools/dataset_converters/convert_flickr30k_ann.py | ||
test_evaluator = dict( | ||
type='COCOCaption', | ||
ann_file='data/flickr30k_test_gt.json', | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# data settings | ||
data_preprocessor = dict( | ||
type='MultiModalDataPreprocessor', | ||
mean=[122.770938, 116.7460125, 104.09373615], | ||
std=[68.5005327, 66.6321579, 70.32316305], | ||
to_rgb=True, | ||
) | ||
|
||
rand_increasing_policies = [ | ||
dict(type='AutoContrast'), | ||
dict(type='Equalize'), | ||
dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)), | ||
dict( | ||
type='Brightness', magnitude_key='magnitude', | ||
magnitude_range=(0, 0.0)), | ||
dict(type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0)), | ||
dict( | ||
type='Shear', | ||
magnitude_key='magnitude', | ||
magnitude_range=(0, 0.3), | ||
direction='horizontal'), | ||
dict( | ||
type='Shear', | ||
magnitude_key='magnitude', | ||
magnitude_range=(0, 0.3), | ||
direction='vertical'), | ||
] | ||
|
||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='RandomResizedCrop', | ||
scale=384, | ||
crop_ratio_range=(0.5, 1.0), | ||
interpolation='bicubic'), | ||
dict(type='RandomFlip', prob=0.5, direction='horizontal'), | ||
dict( | ||
type='RandAugment', | ||
policies=rand_increasing_policies, | ||
num_policies=2, | ||
magnitude_level=5), | ||
dict(type='CleanCaption', keys='text'), | ||
dict( | ||
type='PackInputs', | ||
algorithm_keys=['text', 'is_matched'], | ||
meta_keys=['image_id']), | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='Resize', | ||
scale=(384, 384), | ||
interpolation='bicubic', | ||
backend='pillow'), | ||
dict(type='CleanCaption', keys='text'), | ||
dict( | ||
type='PackInputs', | ||
algorithm_keys=['text', 'gt_text_id', 'gt_image_id'], | ||
meta_keys=['image_id']), | ||
] | ||
|
||
train_dataloader = dict( | ||
batch_size=32, | ||
num_workers=16, | ||
dataset=dict( | ||
type='Flickr30kRetrieval', | ||
data_root='data/flickr30k', | ||
ann_file='annotations/dataset_flickr30k.json', | ||
data_prefix='images', | ||
split='train', | ||
pipeline=train_pipeline), | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
persistent_workers=True, | ||
drop_last=True, | ||
) | ||
|
||
val_dataloader = dict( | ||
batch_size=64, | ||
num_workers=16, | ||
dataset=dict( | ||
type='Flickr30kRetrieval', | ||
data_root='data/flickr30k', | ||
ann_file='annotations/dataset_flickr30k.json', | ||
data_prefix='images', | ||
split='val', | ||
pipeline=test_pipeline, | ||
test_mode=True, # This is required for evaluation | ||
), | ||
sampler=dict(type='SequentialSampler', subsample_type='sequential'), | ||
persistent_workers=True, | ||
) | ||
|
||
val_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10)) | ||
|
||
# If you want standard test, please manually configure the test dataset | ||
test_dataloader = dict( | ||
batch_size=64, | ||
num_workers=16, | ||
dataset=dict( | ||
type='Flickr30kRetrieval', | ||
data_root='data/flickr30k', | ||
ann_file='annotations/dataset_flickr30k.json', | ||
data_prefix='images', | ||
split='test', | ||
pipeline=test_pipeline, | ||
test_mode=True, # This is required for evaluation | ||
), | ||
sampler=dict(type='SequentialSampler', subsample_type='sequential'), | ||
persistent_workers=True, | ||
) | ||
test_evaluator = val_evaluator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
_base_ = [ | ||
'../_base_/datasets/flickr30k_caption.py', | ||
'../_base_/default_runtime.py', | ||
] | ||
|
||
# model settings | ||
model = dict( | ||
type='BlipCaption', | ||
vision_encoder=dict( | ||
type='VisionTransformer', | ||
arch='b', | ||
img_size=384, | ||
patch_size=16, | ||
out_type='raw', | ||
), | ||
tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'), | ||
decoder_head=dict( | ||
type='SeqGenerationHead', | ||
decoder=dict( | ||
type='XBertLMHeadDecoder', | ||
med_config=dict( | ||
architectures=['BertModel'], | ||
attention_probs_dropout_prob=0.1, | ||
hidden_act='gelu', | ||
hidden_dropout_prob=0.1, | ||
hidden_size=768, | ||
initializer_range=0.02, | ||
intermediate_size=3072, | ||
layer_norm_eps=1e-12, | ||
max_position_embeddings=512, | ||
model_type='bert', | ||
num_attention_heads=12, | ||
num_hidden_layers=12, | ||
pad_token_id=0, | ||
add_type_embeddings=False, | ||
vocab_size=30524, | ||
encoder_width=768, | ||
add_cross_attention=True), | ||
), | ||
), | ||
prompt='a picture of ', | ||
max_txt_len=20, | ||
) | ||
|
||
# schedule settings | ||
optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05)) | ||
|
||
param_scheduler = [ | ||
dict( | ||
type='CosineAnnealingLR', | ||
by_epoch=True, | ||
begin=0, | ||
end=10, | ||
) | ||
] | ||
|
||
train_cfg = dict(max_epochs=10) | ||
val_cfg = dict() | ||
test_cfg = dict() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
_base_ = [ | ||
'../_base_/datasets/flickr30k_retrieval.py', | ||
'../_base_/default_runtime.py', | ||
] | ||
|
||
# model settings | ||
model = dict( | ||
type='BlipRetrieval', | ||
tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'), | ||
vision_backbone=dict( | ||
type='VisionTransformer', | ||
arch='b', | ||
img_size=384, | ||
patch_size=16, | ||
out_type='raw', | ||
), | ||
text_backbone=dict( | ||
type='XBertEncoder', | ||
med_config=dict( | ||
architectures=['BertModel'], | ||
attention_probs_dropout_prob=0.1, | ||
hidden_act='gelu', | ||
hidden_dropout_prob=0.1, | ||
hidden_size=768, | ||
initializer_range=0.02, | ||
intermediate_size=3072, | ||
layer_norm_eps=1e-12, | ||
max_position_embeddings=512, | ||
model_type='bert', | ||
num_attention_heads=12, | ||
num_hidden_layers=12, | ||
pad_token_id=0, | ||
add_type_embeddings=False, | ||
vocab_size=30524, | ||
encoder_width=768, | ||
add_cross_attention=True), | ||
), | ||
vision_neck=dict( | ||
type='Linear', | ||
in_features=768, | ||
out_features=256, | ||
), | ||
text_neck=dict( | ||
type='Linear', | ||
in_features=768, | ||
out_features=256, | ||
), | ||
head=dict( | ||
type='ITCHead', | ||
embed_dim=256, | ||
), | ||
multimodal_head=dict( | ||
type='ITMHead', | ||
hidden_size=768, | ||
with_pooler=False, | ||
), | ||
topk=256, | ||
max_txt_len=35, | ||
) | ||
|
||
# optimizer | ||
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=0.04) | ||
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer) | ||
|
||
# learning rate scheduler | ||
param_scheduler = [dict(type='CosineAnnealingLR', by_epoch=True)] | ||
|
||
# runtime settings | ||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=6) | ||
val_cfg = dict(type='RetrievalValLoop') | ||
test_cfg = dict(type='RetrievalTestLoop') | ||
|
||
randomness = dict(seed=42) | ||
|
||
default_hooks = dict(logger=dict(interval=1)) | ||
|
||
custom_hooks = [ | ||
dict( | ||
type='WarmupParamHook', | ||
param_name='alpha', | ||
module_name='head', | ||
warmup_epochs=2) | ||
] |
Oops, something went wrong.