-
Notifications
You must be signed in to change notification settings - Fork 141
/
dino_4scale_uniperceiver_adapter_base_6ep_gqa.py
231 lines (231 loc) · 8.69 KB
/
dino_4scale_uniperceiver_adapter_base_6ep_gqa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
# Config files this config builds on (paths as written in this file).
# NOTE(review): presumably resolved by the mmcv Config loader relative to
# this file's directory -- confirm against the repo layout.
_base_ = [
    '_base_/datasets/grounding_gqa.py',  # dataset settings
    '_base_/default_runtime.py',  # runtime defaults
]
# Local path of the converted Uni-Perceiver base checkpoint handed to the
# backbone below. The file is published at:
# https://github.com/czczup/ViT-Adapter/releases/download/wsdm2023/uni-perceiver-base-L12-H768-224size-torch-pretrained_converted.pth
pretrained = 'pretrained/uni-perceiver-base-L12-H768-224size-torch-pretrained_converted.pth'

# Detector definition: a GroundingDINO model made of a UniPerceiverAdapter
# backbone, a ChannelMapper neck and a DINOHead, followed by the train/test
# settings. Trailing "... for DeformDETR" / "... for DN-DETR" remarks record
# the values the original author noted for those reference configs.
model = {
    'type': 'GroundingDINO',
    'with_aux_loss': True,
    'backbone': {
        'type': 'UniPerceiverAdapter',
        'patch_size': 16,
        'embed_dim': 768,
        'depth': 12,
        'num_heads': 12,
        'mlp_ratio': 4,
        'drop_path_rate': 0.2,
        'conv_inplane': 64,
        'n_points': 4,
        'deform_num_heads': 12,
        'cffn_ratio': 0.25,
        'deform_ratio': 0.5,
        'num_classes': 1,
        'with_cp': True,
        'out_indices': (1, 2, 3),
        'num_cross_attn': 0,
        'interaction_indexes': [[0, 2], [3, 5], [6, 8], [9, 11]],
        # Twelve entries each, the same count as depth=12 -- presumably
        # one per backbone block (confirm against the backbone class).
        'window_attn': [False] * 12,
        'window_size': [None] * 12,
        'pretrained': pretrained,
    },
    'neck': {
        'type': 'ChannelMapper',
        # Three 768-channel inputs (one per entry of out_indices), mapped
        # to 256 channels with num_outs=4.
        'in_channels': [768, 768, 768],
        'kernel_size': 1,
        'out_channels': 256,
        'act_cfg': None,
        'norm_cfg': {'type': 'GN', 'num_groups': 32},
        'num_outs': 4,
    },
    'bbox_head': {
        'type': 'DINOHead',
        'num_query': 100,
        'num_classes': 1,
        'in_channels': 2048,  # TODO
        'sync_cls_avg_factor': True,
        'as_two_stage': True,
        'with_box_refine': True,
        'dn_cfg': {
            'type': 'CdnQueryGenerator',
            # 0.5, 0.4 for DN-DETR
            'noise_scale': {'label': 0.5, 'box': 1.0},
            'group_cfg': {
                'dynamic': True,
                'num_groups': None,
                'num_dn_queries': 100,
            },
        },
        'transformer': {
            'type': 'DinoTransformer',
            'two_stage_num_proposals': 100,
            'encoder': {
                'type': 'DetrTransformerEncoder',
                'num_layers': 6,
                'transformerlayers': {
                    'type': 'BaseTransformerLayer',
                    'attn_cfgs': {
                        'type': 'MultiScaleDeformableAttention',
                        'embed_dims': 256,
                        'dropout': 0.0,  # 0.1 for DeformDETR
                    },
                    'feedforward_channels': 2048,  # 1024 for DeformDETR
                    'ffn_cfgs': {
                        'type': 'FFN',
                        'embed_dims': 256,
                        'feedforward_channels': 2048,
                        'num_fcs': 2,
                        'ffn_drop': 0.0,
                        'use_checkpoint': True,
                        'act_cfg': {'type': 'ReLU', 'inplace': True},
                    },
                    'ffn_dropout': 0.0,  # 0.1 for DeformDETR
                    'operation_order': ('self_attn', 'norm', 'ffn', 'norm'),
                },
            },
            'decoder': {
                'type': 'DinoTransformerDecoder',
                'num_layers': 6,
                'return_intermediate': True,
                'transformerlayers': {
                    'type': 'DetrTransformerDecoderLayer',
                    # Two attention configs, listed in the same order as the
                    # 'self_attn' / 'cross_attn' entries of operation_order.
                    'attn_cfgs': [
                        {
                            'type': 'MultiheadAttention',
                            'embed_dims': 256,
                            'num_heads': 8,
                            'dropout': 0.0,  # 0.1 for DeformDETR
                        },
                        {
                            'type': 'MultiScaleDeformableAttention',
                            'embed_dims': 256,
                            'dropout': 0.0,  # 0.1 for DeformDETR
                        },
                    ],
                    'feedforward_channels': 2048,  # 1024 for DeformDETR
                    'ffn_cfgs': {
                        'type': 'FFN',
                        'embed_dims': 256,
                        'feedforward_channels': 2048,
                        'num_fcs': 2,
                        'ffn_drop': 0.0,
                        'use_checkpoint': True,
                        'act_cfg': {'type': 'ReLU', 'inplace': True},
                    },
                    'ffn_dropout': 0.0,  # 0.1 for DeformDETR
                    'operation_order': (
                        'self_attn', 'norm', 'cross_attn', 'norm',
                        'ffn', 'norm',
                    ),
                },
            },
        },
        'positional_encoding': {
            'type': 'SinePositionalEncoding',
            'num_feats': 128,
            'temperature': 20,
            'normalize': True,
        },
        'loss_cls': {
            'type': 'FocalLoss',
            'use_sigmoid': True,
            'gamma': 2.0,
            'alpha': 0.25,
            'loss_weight': 1.0,  # 2.0 in DeformDETR
        },
        'loss_bbox': {'type': 'L1Loss', 'loss_weight': 5.0},
        'loss_iou': {'type': 'GIoULoss', 'loss_weight': 2.0},
    },
    # training and testing settings
    'train_cfg': {
        'assigner': {
            'type': 'HungarianAssigner',
            'cls_cost': {'type': 'FocalLossCost', 'weight': 2.0},
            'reg_cost': {
                'type': 'BBoxL1Cost',
                'weight': 5.0,
                'box_format': 'xywh',
            },
            'iou_cost': {'type': 'IoUCost', 'iou_mode': 'giou', 'weight': 2.0},
        },
    },
    'test_cfg': {'max_per_img': 1},  # TODO: Originally 100
}
# Per-channel mean/std and the to_rgb flag; unpacked into the 'Normalize'
# steps of the pipelines in this file.
img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_rgb': True,
}
# Training pipeline: load the image, boxes and referring expression, flip
# at random, apply one of two AutoAugment resize policies, then normalize,
# pad, tokenize and collect.
# NOTE(review): the inherited note said img_scale and Pad's size_divisor
# differ from the mmdet defaults; size_divisor is 32 here -- confirm
# whether that part of the note still applies.
train_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'LoadAnnotations', 'with_bbox': True},
    {'type': 'LoadRefer'},
    {'type': 'RandomFlipWithRefer', 'flip_ratio': 0.5},
    {
        'type': 'AutoAugment',
        'policies': [
            # Policy 1: a single multi-scale resize.
            [
                {
                    'type': 'Resize',
                    # (480, 1333), (512, 1333), ..., (800, 1333)
                    'img_scale': [(s, 1333) for s in range(480, 801, 32)],
                    'multiscale_mode': 'value',
                    'keep_ratio': True,
                },
            ],
            # Policy 2: resize, random crop, then the multi-scale resize.
            [
                {
                    'type': 'Resize',
                    # The aspect ratio of every image in the train
                    # dataset is < 7; follows the original implementation.
                    'img_scale': [(400, 4200), (500, 4200), (600, 4200)],
                    'multiscale_mode': 'value',
                    'keep_ratio': True,
                },
                {
                    'type': 'RandomCrop',
                    'crop_type': 'absolute_range',
                    'crop_size': (384, 600),
                    'allow_negative_crop': False,
                },
                {
                    'type': 'Resize',
                    # Same eleven scales as policy 1.
                    'img_scale': [(s, 1333) for s in range(480, 801, 32)],
                    'multiscale_mode': 'value',
                    'override': True,
                    'keep_ratio': True,
                },
            ],
        ],
    },
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'Pad', 'size_divisor': 32},
    {'type': 'TokenizeRefer', 'max_sent_len': 64},
    {'type': 'DefaultFormatBundle'},
    {
        'type': 'Collect',
        'keys': ['img', 'refer', 'r_mask', 'gt_bboxes', 'gt_labels'],
    },
]
# Test pipeline: load the image and referring expression, then run
# MultiScaleFlipAug over three image scales with flipping enabled.
# NOTE(review): the note previously attached here said Pad's size_divisor
# differs from the default setting (size_divisor=32) and that
# size_divisor=1 performs about the same; the value used below is 32, so
# that note looks stale -- confirm.
test_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'LoadRefer'},
    {
        'type': 'MultiScaleFlipAug',
        'img_scale': [(1333, 600), (1333, 800), (1333, 1000)],
        'flip': True,
        'transforms': [
            {'type': 'Resize', 'keep_ratio': True},
            {'type': 'RandomFlipWithRefer'},
            {'type': 'Normalize', **img_norm_cfg},
            {'type': 'Pad', 'size_divisor': 32},
            {'type': 'ImageToTensor', 'keys': ['img']},
            {'type': 'TokenizeRefer', 'max_sent_len': 64},
            {'type': 'Collect', 'keys': ['img', 'refer', 'r_mask']},
        ],
    },
]
# Loader settings plus the pipeline used by each split. Only these keys
# are set here; dataset type and paths are not defined in this file
# (presumably inherited from the base dataset config -- confirm).
data = {
    'samples_per_gpu': 2,
    'workers_per_gpu': 2,
    'train': {'filter_empty_gt': True, 'pipeline': train_pipeline},
    # val and test share the same test-time pipeline.
    'val': {'pipeline': test_pipeline},
    'test': {'pipeline': test_pipeline},
}
# Optimizer: AdamW (lr 1e-4, weight decay 0.05) built through
# LayerDecayOptimizerConstructor with num_layers=12 and a decay rate of
# 0.65.
# NOTE(review): num_layers equals the backbone depth set in this file;
# presumably the constructor scales the lr per backbone layer -- confirm.
optimizer = {
    'type': 'AdamW',
    'lr': 0.0001,
    'weight_decay': 0.05,
    'constructor': 'LayerDecayOptimizerConstructor',
    'paramwise_cfg': {'num_layers': 12, 'layer_decay_rate': 0.65},
}
# Gradient clipping settings: max_norm 0.1 with norm_type 2.
optimizer_config = {
    'grad_clip': {'max_norm': 0.1, 'norm_type': 2},
}
# Learning-rate policy: 'step' with an empty milestone list (no step
# epochs are configured), preceded by a linear warmup of 500 iterations
# with warmup_ratio 0.001.
lr_config = {
    'policy': 'step',
    'warmup': 'linear',
    'warmup_iters': 500,
    'warmup_ratio': 0.001,
    'step': [],
}
# Epoch-based training for 6 epochs.
runner = {'type': 'EpochBasedRunner', 'max_epochs': 6}
# Checkpointing: interval 1, at most 3 checkpoints kept, save_last on.
checkpoint_config = {
    'interval': 1,
    'max_keep_ckpts': 3,
    'save_last': True,
}
# Evaluation settings: save_best is set to 'auto'.
evaluation = {'save_best': 'auto'}
# Extra training hooks: a single ExpMomentumEMAHook with momentum 1e-4,
# priority 49 and no resume checkpoint.
custom_hooks = [
    {
        'type': 'ExpMomentumEMAHook',
        'resume_from': None,
        'momentum': 0.0001,
        'priority': 49,
    },
]