-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathtransxnet.py
751 lines (677 loc) · 28.9 KB
/
transxnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
import os
import math
import copy
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import checkpoint
from mmcv.runner.checkpoint import load_checkpoint
from timm.models.layers import DropPath, to_2tuple
from mmcv.cnn.bricks import ConvModule, build_activation_layer, build_norm_layer
try:
from mmseg.models.builder import BACKBONES as seg_BACKBONES
from mmseg.utils import get_root_logger
from mmcv.runner import _load_checkpoint
has_mmseg = True
except ImportError:
print("If for semantic segmentation, please install mmsegmentation first")
has_mmseg = False
try:
from mmdet.models.builder import BACKBONES as det_BACKBONES
from mmdet.utils import get_root_logger
from mmcv.runner import _load_checkpoint
has_mmdet = True
except ImportError:
print("If for detection, please install mmdetection first")
has_mmdet = False
class PatchEmbed(nn.Module):
    """Patch embedding implemented as a single ``ConvModule``.

    Input: tensor in shape [B, in_chans, H, W]
    Output: tensor with ``embed_dim`` channels; the spatial size follows
        from the convolution's kernel size, stride and padding.

    Args:
        patch_size (int): Kernel size of the projection. Defaults to 16.
        stride (int): Stride of the projection. Defaults to 16.
        padding (int): Padding of the projection. Defaults to 0.
        in_chans (int): Input channels. Defaults to 3.
        embed_dim (int): Output channels of the projection.
            Defaults to 768.
        norm_layer (dict | None): Norm config forwarded to ``ConvModule``
            as ``norm_cfg``. Defaults to ``dict(type='BN2d')``.
        act_cfg (dict | None): Activation config forwarded to
            ``ConvModule``. Defaults to None.
    """

    def __init__(self,
                 patch_size=16,
                 stride=16,
                 padding=0,
                 in_chans=3,
                 embed_dim=768,
                 norm_layer=dict(type='BN2d'),
                 act_cfg=None,):
        super().__init__()
        # Gather the convolution hyper-parameters once, then build the layer.
        conv_options = dict(
            kernel_size=patch_size,
            stride=stride,
            padding=padding,
            norm_cfg=norm_layer,
            act_cfg=act_cfg,
        )
        self.proj = ConvModule(in_chans, embed_dim, **conv_options)

    def forward(self, x):
        """Project ``x`` with the embedding convolution."""
        embedded = self.proj(x)
        return embedded
class Attention(nn.Module):  ### OSRA
    """Multi-head attention over 2D feature maps with optional spatial
    reduction of the key/value source.

    Queries come from a 1x1 convolution on the full-resolution input. Keys
    and values come from ``self.sr(x)`` (a strided depthwise ``ConvModule``
    followed by a 1x1 ``ConvModule`` when ``sr_ratio > 1``, identity
    otherwise), refined by a residual 3x3 depthwise convolution.

    Args:
        dim (int): Number of input/output channels; must be divisible by
            ``num_heads``.
        num_heads (int): Number of attention heads. Defaults to 1.
        qk_scale (float | None): Scale applied to the attention logits.
            A falsy value selects ``head_dim ** -0.5``. Defaults to None.
        attn_drop (float): Dropout probability on the attention weights.
            Defaults to 0.
        sr_ratio (int): Stride of the key/value reduction. Defaults to 1.
    """

    def __init__(self, dim,
                 num_heads=1,
                 qk_scale=None,
                 attn_drop=0,
                 sr_ratio=1,):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5
        self.sr_ratio = sr_ratio

        self.q = nn.Conv2d(dim, dim, kernel_size=1)
        self.kv = nn.Conv2d(dim, dim*2, kernel_size=1)
        self.attn_drop = nn.Dropout(attn_drop)

        if sr_ratio > 1:
            # The kernel is larger than the stride, so neighbouring
            # reduction windows overlap.
            reduce_kernel = sr_ratio + 3
            strided_dw = ConvModule(dim, dim,
                                    kernel_size=reduce_kernel,
                                    stride=sr_ratio,
                                    padding=reduce_kernel//2,
                                    groups=dim,
                                    bias=False,
                                    norm_cfg=dict(type='BN2d'),
                                    act_cfg=dict(type='GELU'))
            pointwise_dw = ConvModule(dim, dim,
                                      kernel_size=1,
                                      groups=dim,
                                      bias=False,
                                      norm_cfg=dict(type='BN2d'),
                                      act_cfg=None,)
            self.sr = nn.Sequential(strided_dw, pointwise_dw)
        else:
            self.sr = nn.Identity()

        self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

    def forward(self, x, relative_pos_enc=None):
        """Apply attention to ``x`` of shape [B, C, H, W].

        Args:
            x (Tensor): Input feature map.
            relative_pos_enc (Tensor | None): Optional bias added to the
                attention logits. When its last two dimensions differ from
                those of the logits it is resized with bicubic
                interpolation first.

        Returns:
            Tensor: Output with the same shape as ``x``.
        """
        B, C, H, W = x.shape
        heads = self.num_heads
        per_head = C // heads

        # Queries: [B, heads, H*W, per_head]
        q = self.q(x).reshape(B, heads, per_head, -1).transpose(-1, -2)

        # Keys/values come from the (possibly reduced) map plus a local
        # depthwise refinement.
        reduced = self.sr(x)
        reduced = self.local_conv(reduced) + reduced
        k, v = self.kv(reduced).chunk(2, dim=1)
        k = k.reshape(B, heads, per_head, -1)
        v = v.reshape(B, heads, per_head, -1).transpose(-1, -2)

        attn = torch.matmul(q, k) * self.scale
        if relative_pos_enc is not None:
            if attn.shape[2:] != relative_pos_enc.shape[2:]:
                relative_pos_enc = F.interpolate(relative_pos_enc, size=attn.shape[2:],
                                                 mode='bicubic', align_corners=False)
            attn = attn + relative_pos_enc
        attn = self.attn_drop(attn.softmax(dim=-1))

        out = torch.matmul(attn, v).transpose(-1, -2)
        return out.reshape(B, C, H, W)
class DynamicConv2d(nn.Module):  ### IDConv
    """Depthwise convolution whose kernel is generated per input sample.

    ``num_groups`` learnable kernel sets are blended with softmax weights
    predicted from an adaptively pooled copy of the input; the blended
    kernels are then applied as one grouped convolution over the batch.

    Args:
        dim (int): Number of input/output channels.
        kernel_size (int): Spatial size of the kernels. Defaults to 3.
        reduction_ratio (int): Channel reduction inside the weight
            predictor. Defaults to 4.
        num_groups (int): Number of kernel sets to blend; the constructor
            asserts it is greater than 1. Defaults to 1.
        bias (bool): Whether a blended bias is also produced.
            Defaults to True.
    """

    def __init__(self,
                 dim,
                 kernel_size=3,
                 reduction_ratio=4,
                 num_groups=1,
                 bias=True):
        super().__init__()
        assert num_groups > 1, f"num_groups {num_groups} should > 1."

        self.num_groups = num_groups
        self.K = kernel_size
        self.bias_type = bias

        self.weight = nn.Parameter(torch.empty(num_groups, dim, kernel_size, kernel_size), requires_grad=True)
        self.pool = nn.AdaptiveAvgPool2d(output_size=(kernel_size, kernel_size))

        hidden = dim//reduction_ratio
        self.proj = nn.Sequential(
            ConvModule(dim,
                       hidden,
                       kernel_size=1,
                       norm_cfg=dict(type='BN2d'),
                       act_cfg=dict(type='GELU'),),
            nn.Conv2d(hidden, dim*num_groups, kernel_size=1),)

        self.bias = nn.Parameter(torch.empty(num_groups, dim), requires_grad=True) if bias else None

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the kernel sets (and bias, if any) with a truncated normal."""
        nn.init.trunc_normal_(self.weight, std=0.02)
        if self.bias is None:
            return
        nn.init.trunc_normal_(self.bias, std=0.02)

    def forward(self, x):
        """Convolve ``x`` of shape [B, C, H, W] with its own blended kernels.

        Returns:
            Tensor: Output with the same shape as ``x``.
        """
        B, C, H, W = x.shape
        G, K = self.num_groups, self.K

        # Per-sample mixing weights over the G kernel sets: [B, G, C, K, K].
        mix = self.proj(self.pool(x)).reshape(B, G, C, K, K).softmax(dim=1)
        kernels = (mix * self.weight.unsqueeze(0)).sum(dim=1, keepdim=False)
        kernels = kernels.reshape(-1, 1, K, K)

        bias = None
        if self.bias is not None:
            # The bias is blended from the globally averaged input.
            pooled = x.mean(dim=[-2, -1], keepdim=True)
            mix = self.proj(pooled).reshape(B, G, C).softmax(dim=1)
            bias = (mix * self.bias.unsqueeze(0)).sum(dim=1).flatten(0)

        # Fold the batch into the channel axis so every (sample, channel)
        # pair is convolved with its own kernel in one grouped call.
        folded = x.reshape(1, -1, H, W)
        out = F.conv2d(folded,
                       weight=kernels,
                       padding=K//2,
                       groups=B*C,
                       bias=bias)
        return out.reshape(B, C, H, W)
class HybridTokenMixer(nn.Module):  ### D-Mixer
    """Token mixer that splits the channels between a dynamic convolution
    and an attention branch, then fuses the two halves.

    Args:
        dim (int): Number of input/output channels; must be even.
        kernel_size (int): Kernel size of the dynamic convolution.
            Defaults to 3.
        num_groups (int): ``num_groups`` of the dynamic convolution.
            Defaults to 2.
        num_heads (int): Number of attention heads. Defaults to 1.
        sr_ratio (int): ``sr_ratio`` of the attention branch. Defaults to 1.
        reduction_ratio (int): Channel reduction of the fusion bottleneck
            (its width is at least 16). Defaults to 8.
    """

    def __init__(self,
                 dim,
                 kernel_size=3,
                 num_groups=2,
                 num_heads=1,
                 sr_ratio=1,
                 reduction_ratio=8):
        super().__init__()
        assert dim % 2 == 0, f"dim {dim} should be divided by 2."

        half = dim//2
        self.local_unit = DynamicConv2d(
            dim=half, kernel_size=kernel_size, num_groups=num_groups)
        self.global_unit = Attention(
            dim=half, num_heads=num_heads, sr_ratio=sr_ratio)

        inner_dim = max(16, dim//reduction_ratio)
        fusion_layers = [
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim),
            nn.GELU(),
            nn.BatchNorm2d(dim),
            nn.Conv2d(dim, inner_dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(inner_dim),
            nn.Conv2d(inner_dim, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
        ]
        self.proj = nn.Sequential(*fusion_layers)

    def forward(self, x, relative_pos_enc=None):
        """Mix ``x``: first channel half through the dynamic convolution,
        second half through attention, then a residual fusion projection.
        """
        local_in, global_in = x.chunk(2, dim=1)
        local_out = self.local_unit(local_in)
        global_out = self.global_unit(global_in, relative_pos_enc)
        merged = torch.cat((local_out, global_out), dim=1)
        return self.proj(merged) + merged  ## STE
class MultiScaleDWConv(nn.Module):
    """Depthwise convolutions with different kernel sizes applied to
    disjoint channel groups.

    The channels are split into ``len(scale)`` groups; group ``i`` is
    filtered by a depthwise convolution of kernel size ``scale[i]``. All
    groups get ``dim // len(scale)`` channels except the first, which also
    takes the remainder.

    Args:
        dim (int): Number of input/output channels.
        scale (Sequence[int]): Kernel size of each group.
            Defaults to (1, 3, 5, 7).
    """

    def __init__(self, dim, scale=(1, 3, 5, 7)):
        super().__init__()
        self.scale = scale
        n_groups = len(scale)
        # First group = even share + whatever the integer division left over.
        self.channels = [
            dim - dim // n_groups * (n_groups - 1) if pos == 0 else dim // n_groups
            for pos in range(n_groups)
        ]
        self.proj = nn.ModuleList([
            nn.Conv2d(width, width,
                      kernel_size=ksize,
                      padding=ksize//2,
                      groups=width)
            for width, ksize in zip(self.channels, scale)
        ])

    def forward(self, x):
        """Filter each channel group with its own kernel and re-concatenate."""
        groups = torch.split(x, split_size_or_sections=self.channels, dim=1)
        filtered = [conv(group) for conv, group in zip(self.proj, groups)]
        return torch.cat(filtered, dim=1)
class Mlp(nn.Module):  ### MS-FFN
    """
    Feed-forward network implemented with 1x1 convolutions and a
    multi-scale depthwise convolution in between.
    Input: Tensor with shape [B, in_features, H, W].
    Output: Tensor with shape [B, out_features, H, W].
    Args:
        in_features (int): Dimension of input features.
        hidden_features (int): Dimension of hidden features.
            Defaults to ``in_features``.
        out_features (int): Dimension of output features.
            Defaults to ``in_features``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        drop (float): Dropout rate. Defaults to 0.0.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_cfg=dict(type='GELU'),
                 drop=0,):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=False),
            build_activation_layer(act_cfg),
            nn.BatchNorm2d(hidden_features),
        )
        self.dwconv = MultiScaleDWConv(hidden_features)
        self.act = build_activation_layer(act_cfg)
        self.norm = nn.BatchNorm2d(hidden_features)
        # Project to `out_features`; previously this always projected back to
        # `in_features`, silently ignoring the `out_features` argument. With
        # the default (out_features=None) the layer shapes are unchanged.
        self.fc2 = nn.Sequential(
            nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_features),
        )
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run expand -> multi-scale depthwise (residual) -> project."""
        x = self.fc1(x)
        # Residual around the multi-scale depthwise convolution.
        x = self.dwconv(x) + x
        x = self.norm(self.act(x))
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class LayerScale(nn.Module):
    """Learnable per-channel scale and shift, applied as a 1x1 depthwise
    convolution.

    Args:
        dim (int): Number of channels.
        init_value (float): Initial value of every scale. Defaults to 1e-5.
    """

    def __init__(self, dim, init_value=1e-5):
        super().__init__()
        initial_scale = torch.ones(dim, 1, 1, 1) * init_value
        self.weight = nn.Parameter(initial_scale, requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(dim), requires_grad=True)

    def forward(self, x):
        """Scale and shift each channel of ``x`` ([B, C, H, W])."""
        num_channels = x.shape[1]
        return F.conv2d(x, weight=self.weight, bias=self.bias, groups=num_channels)
class Block(nn.Module):
    """
    Network Block: depthwise positional embedding, hybrid token mixer and
    MLP, each of the latter two wrapped in norm -> layer scale -> drop path
    with a residual connection.

    Args:
        dim (int): Embedding dim.
        kernel_size (int): kernel size of dynamic conv. Defaults to 3.
        sr_ratio (int): ``sr_ratio`` passed to the token mixer.
            Defaults to 1.
        num_groups (int): num_groups of dynamic conv. Defaults to 2.
        num_heads (int): number of heads of self-attention. Defaults to 1.
        mlp_ratio (float): Mlp expansion ratio. Defaults to 4.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='GN', num_groups=1)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        drop (float): Dropout rate. Defaults to 0.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float | None): Init value for Layer Scale;
            ``None`` replaces both layer scales with identities.
            Defaults to 1e-5.
        grad_checkpoint (bool): Run the block under gradient checkpointing
            when the input requires grad. Defaults to False.
    """

    def __init__(self,
                 dim=64,
                 kernel_size=3,
                 sr_ratio=1,
                 num_groups=2,
                 num_heads=1,
                 mlp_ratio=4,
                 norm_cfg=dict(type='GN', num_groups=1),
                 act_cfg=dict(type='GELU'),
                 drop=0,
                 drop_path=0,
                 layer_scale_init_value=1e-5,
                 grad_checkpoint=False):
        super().__init__()
        self.grad_checkpoint = grad_checkpoint

        self.pos_embed = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)

        # Token-mixing branch.
        self.norm1 = build_norm_layer(norm_cfg, dim)[1]
        self.token_mixer = HybridTokenMixer(dim,
                                            kernel_size=kernel_size,
                                            num_groups=num_groups,
                                            num_heads=num_heads,
                                            sr_ratio=sr_ratio)

        # Channel-mixing branch.
        self.norm2 = build_norm_layer(norm_cfg, dim)[1]
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_cfg=act_cfg,
                       drop=drop,)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        if layer_scale_init_value is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = LayerScale(dim, layer_scale_init_value)
            self.layer_scale_2 = LayerScale(dim, layer_scale_init_value)

    def _forward_impl(self, x, relative_pos_enc=None):
        """Un-checkpointed forward pass of the block."""
        x = x + self.pos_embed(x)

        mixed = self.token_mixer(self.norm1(x), relative_pos_enc)
        x = x + self.drop_path(self.layer_scale_1(mixed))

        fed_forward = self.mlp(self.norm2(x))
        x = x + self.drop_path(self.layer_scale_2(fed_forward))
        return x

    def forward(self, x, relative_pos_enc=None):
        """Run the block, under gradient checkpointing when it is enabled
        and ``x`` requires grad.
        """
        if self.grad_checkpoint and x.requires_grad:
            return checkpoint.checkpoint(self._forward_impl, x, relative_pos_enc)
        return self._forward_impl(x, relative_pos_enc)
def basic_blocks(dim,
                 index,
                 layers,
                 kernel_size=3,
                 num_groups=2,
                 num_heads=1,
                 sr_ratio=1,
                 mlp_ratio=4,
                 norm_cfg=dict(type='GN', num_groups=1),
                 act_cfg=dict(type='GELU'),
                 drop_rate=0,
                 drop_path_rate=0,
                 layer_scale_init_value=1e-5,
                 grad_checkpoint=False):
    """Build the ``Block`` modules of one stage.

    The stochastic-depth rate grows linearly with the block's position in
    the whole network: ``drop_path_rate * global_index / (total - 1)``.

    Args:
        dim (int): Channel width of the stage.
        index (int): Index of the stage within ``layers``.
        layers (Sequence[int]): Number of blocks of every stage.
        kernel_size, num_groups, num_heads, sr_ratio, mlp_ratio, norm_cfg,
        act_cfg, layer_scale_init_value, grad_checkpoint: Forwarded
            unchanged to every ``Block``.
        drop_rate (float): Forwarded to ``Block`` as ``drop``.
        drop_path_rate (float): Maximum stochastic-depth rate, reached by
            the last block of the network.

    Returns:
        nn.ModuleList: The ``layers[index]`` blocks of the stage.
    """
    blocks_before_stage = sum(layers[:index])
    # A network with a single block would make `sum(layers) - 1` zero and
    # raise ZeroDivisionError; clamp the denominator so that block simply
    # gets a rate of 0.
    schedule_span = max(sum(layers) - 1, 1)

    blocks = nn.ModuleList()
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (
            block_idx + blocks_before_stage) / schedule_span
        blocks.append(
            Block(
                dim,
                kernel_size=kernel_size,
                num_groups=num_groups,
                num_heads=num_heads,
                sr_ratio=sr_ratio,
                mlp_ratio=mlp_ratio,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                drop=drop_rate,
                drop_path=block_dpr,
                layer_scale_init_value=layer_scale_init_value,
                grad_checkpoint=grad_checkpoint,
            ))
    return blocks
class TransXNet(nn.Module):
    """TransXNet network, usable as an image classifier or, with
    ``fork_feat=True``, as a multi-scale feature backbone.

    Args:
        image_size (int | tuple[int]): Reference input size used only to
            size the learnable relative positional encodings.
            Defaults to 224.
        arch (str | dict): The model's architecture. If string, it should be
            one of architecture in ``arch_settings``. And if dict, it
            should include the following keys:
            - layers (list[int]): Number of blocks at each stage.
            - embed_dims (list[int]): The number of channels at each stage.
            - kernel_size (list[int]): Dynamic conv kernel size per stage.
            - num_groups (list[int]): Dynamic conv groups per stage.
            - sr_ratio (list[int]): Attention reduction ratio per stage.
            - num_heads (list[int]): Attention heads per stage.
            - mlp_ratios (list[int], optional): Expansion ratio of MLPs.
              Defaults to ``[4, 4, 4, 4]``.
            - layer_scale_init_value (float, optional): Init value for
              Layer Scale. Defaults to 1e-5.
            Defaults to 'tiny'.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='GN', num_groups=1)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        in_chans (int): Number of input image channels. Defaults to 3.
        in_patch_size (int): The patch size of input image patch embedding.
            Defaults to 7.
        in_stride (int): The stride of input image patch embedding.
            Defaults to 4.
        in_pad (int): The padding of input image patch embedding.
            Defaults to 3.
        down_patch_size (int): The patch size of downsampling patch embedding.
            Defaults to 3.
        down_stride (int): The stride of downsampling patch embedding.
            Defaults to 2.
        down_pad (int): The padding of downsampling patch embedding.
            Defaults to 1.
        drop_rate (float): Dropout rate. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        grad_checkpoint (bool): Using grad checkpointing for saving memory.
            Defaults to False.
        checkpoint_stage (Sequence | bool): Decide which layer uses grad checkpointing.
            For example, checkpoint_stage=[0,0,1,1] means that stage3 and stage4 use gd.
            Ignored (all zeros) unless ``grad_checkpoint`` is True.
        num_classes (int): Number of classes of the classification head;
            a value <= 0 replaces the head with an identity. Only used when
            ``fork_feat`` is False. Defaults to 1000.
        fork_feat (bool): If True, ``forward`` returns the normalized
            outputs of the four stages (network positions 0, 2, 4, 6)
            instead of class logits. Defaults to False.
        start_level (int): With ``fork_feat``, stage outputs below this
            level get an identity instead of a norm layer. Defaults to 0.
        init_cfg (dict, optional): Initialization config dict; its
            ``checkpoint`` entry is the checkpoint to load.
        pretrained (str, optional): Checkpoint to load when ``init_cfg`` is
            not given. Only used when ``fork_feat`` is True.
        **kwargs: Accepted for config compatibility and not used.
    """  # noqa: E501

    # --layers: [x,x,x,x], numbers of layers for the four stages
    # --embed_dims, --mlp_ratios:
    #     embedding dims and mlp ratios for the four stages
    # --kernel_size, --num_groups, --sr_ratio, --num_heads:
    #     per-stage settings of the dynamic conv and the attention
    arch_settings = {
        **dict.fromkeys(['t', 'tiny', 'T'],
                        {'layers': [3, 3, 9, 3],
                         'embed_dims': [48, 96, 224, 448],
                         'kernel_size': [7, 7, 7, 7],
                         'num_groups': [2, 2, 2, 2],
                         'sr_ratio': [8, 4, 2, 1],
                         'num_heads': [1, 2, 4, 8],
                         'mlp_ratios': [4, 4, 4, 4],
                         'layer_scale_init_value': 1e-5,}),
        **dict.fromkeys(['s', 'small', 'S'],
                        {'layers': [4, 4, 12, 4],
                         'embed_dims': [64, 128, 320, 512],
                         'kernel_size': [7, 7, 7, 7],
                         'num_groups': [2, 2, 3, 4],
                         'sr_ratio': [8, 4, 2, 1],
                         'num_heads': [1, 2, 5, 8],
                         'mlp_ratios': [6, 6, 4, 4],
                         'layer_scale_init_value': 1e-5,}),
        **dict.fromkeys(['b', 'base', 'B'],
                        {'layers': [4, 4, 21, 4],
                         'embed_dims': [76, 152, 336, 672],
                         'kernel_size': [7, 7, 7, 7],
                         'num_groups': [2, 2, 4, 4],
                         'sr_ratio': [8, 4, 2, 1],
                         'num_heads': [2, 4, 8, 16],
                         'mlp_ratios': [8, 8, 4, 4],
                         'layer_scale_init_value': 1e-5,}),}

    def __init__(self,
                 image_size=224,
                 arch='tiny',
                 norm_cfg=dict(type='GN', num_groups=1),
                 act_cfg=dict(type='GELU'),
                 in_chans=3,
                 in_patch_size=7,
                 in_stride=4,
                 in_pad=3,
                 down_patch_size=3,
                 down_stride=2,
                 down_pad=1,
                 drop_rate=0,
                 drop_path_rate=0,
                 grad_checkpoint=False,
                 checkpoint_stage=[0] * 4,
                 num_classes=1000,
                 fork_feat=False,
                 start_level=0,
                 init_cfg=None,
                 pretrained=None,
                 **kwargs):
        super().__init__()
        # The above image_size does not need to be adjusted,
        # even if the image input size is not 224x224,
        # unless you want to change the size of the relative positional
        # encoding (the attention resizes a mismatching encoding on the fly).

        if not fork_feat:
            self.num_classes = num_classes
        self.fork_feat = fork_feat
        self.grad_checkpoint = grad_checkpoint

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'layers' in arch and 'embed_dims' in arch, \
                f'The arch dict must have "layers" and "embed_dims", ' \
                f'but got {list(arch.keys())}.'

        layers = arch['layers']
        embed_dims = arch['embed_dims']
        kernel_size = arch['kernel_size']
        num_groups = arch['num_groups']
        sr_ratio = arch['sr_ratio']
        num_heads = arch['num_heads']

        if not grad_checkpoint:
            checkpoint_stage = [0] * 4

        mlp_ratios = arch['mlp_ratios'] if 'mlp_ratios' in arch else [4, 4, 4, 4]
        layer_scale_init_value = arch['layer_scale_init_value'] if 'layer_scale_init_value' in arch else 1e-5

        # Kept so that `reset_classifier` can rebuild the head later.
        self.num_features = embed_dims[-1]
        self._norm_cfg = copy.deepcopy(norm_cfg)

        self.patch_embed = PatchEmbed(patch_size=in_patch_size,
                                      stride=in_stride,
                                      padding=in_pad,
                                      in_chans=in_chans,
                                      embed_dim=embed_dims[0])

        # One learnable relative positional encoding per stage, shaped
        # [1, heads, query tokens, reduced key tokens] for `image_size`.
        relative_pos_enc = []
        self.pos_enc_record = []
        image_size = to_2tuple(image_size)
        image_size = [math.ceil(image_size[0]/in_stride),
                      math.ceil(image_size[1]/in_stride)]
        for i in range(4):
            num_patches = image_size[0]*image_size[1]
            sr_patches = math.ceil(
                image_size[0]/sr_ratio[i])*math.ceil(image_size[1]/sr_ratio[i])
            relative_pos_enc.append(
                nn.Parameter(torch.zeros(1, num_heads[i], num_patches, sr_patches), requires_grad=True))
            self.pos_enc_record.append([image_size[0], image_size[1],
                                        math.ceil(image_size[0]/sr_ratio[i]),
                                        math.ceil(image_size[1]/sr_ratio[i]),])
            image_size = [math.ceil(image_size[0]/2),
                          math.ceil(image_size[1]/2)]
        self.relative_pos_enc = nn.ParameterList(relative_pos_enc)

        # set the main block in network
        network = []
        for i in range(len(layers)):
            stage = basic_blocks(
                embed_dims[i],
                i,
                layers,
                kernel_size=kernel_size[i],
                num_groups=num_groups[i],
                num_heads=num_heads[i],
                sr_ratio=sr_ratio[i],
                mlp_ratio=mlp_ratios[i],
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                drop_rate=drop_rate,
                drop_path_rate=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
                grad_checkpoint=checkpoint_stage[i],)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if embed_dims[i] != embed_dims[i + 1]:
                # downsampling between two stages
                network.append(
                    PatchEmbed(
                        patch_size=down_patch_size,
                        stride=down_stride,
                        padding=down_pad,
                        in_chans=embed_dims[i],
                        embed_dim=embed_dims[i+1]))
        self.network = nn.ModuleList(network)

        if self.fork_feat:
            # add a norm layer for each output
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb < start_level:
                    layer = nn.Identity()
                else:
                    layer = build_norm_layer(norm_cfg, embed_dims[(i_layer + 1) // 2])[1]
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)
        else:
            # Classifier
            self.classifier = self._build_classifier(num_classes)

        self.apply(self._init_model_weights)
        self.init_cfg = copy.deepcopy(init_cfg)

        # load pre-trained model
        if self.fork_feat and (self.init_cfg is not None or pretrained is not None):
            # Forward `pretrained`; it used to be dropped here, so a
            # checkpoint given only through `pretrained` was never loaded.
            self.init_weights(pretrained)
            # Converts BatchNorm children in place (the root module is not a
            # BatchNorm, so the returned object is this same module).
            nn.SyncBatchNorm.convert_sync_batchnorm(self)
            self.train()

    def _build_classifier(self, num_classes):
        """Return the classification head: norm -> global pool -> 1x1 conv,
        or an identity when ``num_classes`` is not positive.
        """
        if num_classes > 0:
            return nn.Sequential(
                build_norm_layer(self._norm_cfg, self.num_features)[1],
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(self.num_features, num_classes, kernel_size=1),
            )
        return nn.Identity()

    # init for image classification
    def _init_model_weights(self, m):
        """Initialise one sub-module; meant to be used with ``self.apply``."""
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.GroupNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    # init for mmdetection or mmsegmentation
    # by loading imagenet pre-trained weights
    def init_weights(self, pretrained=None):
        """Load pre-trained weights from ``self.init_cfg['checkpoint']`` or,
        when no ``init_cfg`` is set, from ``pretrained``.

        The state dict is loaded non-strictly; missing and unexpected keys
        are printed. When neither source is given only a warning is logged.

        Args:
            pretrained (str, optional): Checkpoint location used when
                ``self.init_cfg`` is None.
        """
        logger = get_root_logger()
        if self.init_cfg is None and pretrained is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training start from scratch')
            return

        if self.init_cfg is not None:
            # Only validate init_cfg when it exists; checking it
            # unconditionally raised TypeError for `pretrained`-only calls.
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            ckpt_path = self.init_cfg['checkpoint']
        else:
            ckpt_path = pretrained

        ckpt = _load_checkpoint(
            ckpt_path, logger=logger, map_location='cpu')
        # Checkpoints may wrap the weights under 'state_dict' or 'model'.
        if 'state_dict' in ckpt:
            state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            state_dict = ckpt['model']
        else:
            state_dict = ckpt

        missing_keys, unexpected_keys = self.load_state_dict(state_dict, False)

        # show for debug
        print('missing_keys: ', missing_keys)
        print('unexpected_keys: ', unexpected_keys)

    def get_classifier(self):
        """Return the classification head."""
        return self.classifier

    def reset_classifier(self, num_classes):
        """Replace the classification head for ``num_classes`` classes.

        A non-positive ``num_classes`` installs an identity head. Otherwise
        the final 1x1 convolution is rebuilt with freshly initialised
        weights (the existing head norm is kept); if the current head is an
        identity the whole head is rebuilt.

        Args:
            num_classes (int): New number of classes.
        """
        self.num_classes = num_classes
        current = self.classifier
        if num_classes <= 0:
            self.classifier = nn.Identity()
            return

        if isinstance(current, nn.Sequential):
            # Assigning `out_channels` alone (the old behaviour) left the
            # weight tensor at its old shape; build a real replacement.
            old_conv = current[-1]
            new_conv = nn.Conv2d(old_conv.in_channels, num_classes, kernel_size=1)
            new_conv = new_conv.to(device=old_conv.weight.device,
                                   dtype=old_conv.weight.dtype)
            self._init_model_weights(new_conv)
            current[-1] = new_conv
            return

        rebuilt = self._build_classifier(num_classes)
        rebuilt.apply(self._init_model_weights)
        anchor = next(self.parameters(), None)
        if anchor is not None:
            rebuilt = rebuilt.to(device=anchor.device)
        self.classifier = rebuilt

    def forward_embeddings(self, x):
        """Apply the input patch embedding."""
        return self.patch_embed(x)

    def forward_tokens(self, x):
        """Run the stages and downsampling layers.

        Returns:
            list[Tensor] | Tensor: The normalized stage outputs when
            ``fork_feat`` is set, otherwise the last feature map.
        """
        outs = []
        pos_idx = 0
        for idx, layer in enumerate(self.network):
            # NOTE(review): even positions are assumed to be stages and odd
            # positions downsampling layers, which holds when every pair of
            # consecutive embed_dims differs (true for all arch_settings).
            if idx in [0, 2, 4, 6]:
                for blk in layer:
                    x = blk(x, self.relative_pos_enc[pos_idx])
                pos_idx += 1
            else:
                x = layer(x)
            if self.fork_feat and (idx in self.out_indices):
                x_out = getattr(self, f'norm{idx}')(x)
                outs.append(x_out)
        if self.fork_feat:
            # output the features of four stages for dense prediction
            return outs
        # output only the features of last layer for image classification
        return x

    def forward(self, x):
        """Return stage features (``fork_feat``) or flattened class logits."""
        # input embedding
        x = self.forward_embeddings(x)
        # through backbone
        x = self.forward_tokens(x)
        if self.fork_feat:
            # features of four stages for dense prediction
            return x
        # for image classification
        return self.classifier(x).flatten(1)
@det_BACKBONES.register_module()
def transxnet_t(pretrained=False, init_cfg=None, **kwargs):
    """Build the 't' TransXNet with ``fork_feat=True``.

    Args:
        pretrained (bool): If truthy, ``init_cfg`` is replaced by a
            ``Pretrained`` config pointing at the released 't' checkpoint.
        init_cfg (dict, optional): Initialization config passed to the
            model when ``pretrained`` is falsy.
        **kwargs: Extra keyword arguments forwarded to ``TransXNet``.

    Returns:
        TransXNet: The constructed model.
    """
    if pretrained:
        init_cfg = {
            'type': 'Pretrained',
            'checkpoint': 'https://github.com/LMMMEng/TransXNet/releases/download/v1.0/transx-t.pth.tar',
        }
    return TransXNet(arch='t', fork_feat=True, init_cfg=init_cfg, **kwargs)
@det_BACKBONES.register_module()
def transxnet_s(pretrained=False, init_cfg=None, **kwargs):
    """Build the 's' TransXNet with ``fork_feat=True``.

    Args:
        pretrained (bool): If truthy, ``init_cfg`` is replaced by a
            ``Pretrained`` config pointing at the released 's' checkpoint.
        init_cfg (dict, optional): Initialization config passed to the
            model when ``pretrained`` is falsy.
        **kwargs: Extra keyword arguments forwarded to ``TransXNet``.

    Returns:
        TransXNet: The constructed model.
    """
    if pretrained:
        init_cfg = {
            'type': 'Pretrained',
            'checkpoint': 'https://github.com/LMMMEng/TransXNet/releases/download/v1.0/transx-s.pth.tar',
        }
    return TransXNet(arch='s', fork_feat=True, init_cfg=init_cfg, **kwargs)
@det_BACKBONES.register_module()
def transxnet_b(pretrained=False, init_cfg=None, **kwargs):
    """Build the 'b' TransXNet with ``fork_feat=True``.

    Args:
        pretrained (bool): If truthy, ``init_cfg`` is replaced by a
            ``Pretrained`` config pointing at the released 'b' checkpoint.
        init_cfg (dict, optional): Initialization config passed to the
            model when ``pretrained`` is falsy.
        **kwargs: Extra keyword arguments forwarded to ``TransXNet``.

    Returns:
        TransXNet: The constructed model.
    """
    if pretrained:
        init_cfg = {
            'type': 'Pretrained',
            'checkpoint': 'https://github.com/LMMMEng/TransXNet/releases/download/v1.0/transx-b.pth.tar',
        }
    return TransXNet(arch='b', fork_feat=True, init_cfg=init_cfg, **kwargs)