-
Notifications
You must be signed in to change notification settings - Fork 70
/
scnet.py
342 lines (299 loc) · 13.9 KB
/
scnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Jiang-Jiang Liu
## Email: [email protected]
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""SCNet variants"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
# Public names exported by a star-import of this module: the network class
# and one factory function per published architecture variant.
__all__ = ['SCNet', 'scnet50', 'scnet101', 'scnet50_v1d', 'scnet101_v1d']
# Checkpoint download locations, keyed by the name of the factory function
# that loads them when called with ``pretrained=True``.
# NOTE: there is deliberately no 'scnet101_v1d' entry yet (marked
# "coming soon" upstream).
model_urls = dict(
    scnet50='https://backseason.oss-cn-beijing.aliyuncs.com/scnet/scnet50-dc6a7e87.pth',
    scnet50_v1d='https://backseason.oss-cn-beijing.aliyuncs.com/scnet/scnet50_v1d-4109d1e1.pth',
    scnet101='https://backseason.oss-cn-beijing.aliyuncs.com/scnet/scnet101-44c5b751.pth',
)
class SCConv(nn.Module):
    """Self-calibrated convolution built from three 3x3 conv + norm branches.

    ``k2`` works on an average-pooled copy of the input, ``k3`` works at
    full resolution, and ``k4`` produces the final output (it is the only
    branch that applies ``stride``).

    Parameters
    ----------
    inplanes, planes : int
        Input / output channels of every 3x3 convolution. ``forward`` adds
        the input to the ``k2`` output, so the two are expected to be
        broadcast-compatible (presumably ``inplanes == planes`` - confirm
        against callers).
    stride : int
        Stride of the final ``k4`` convolution.
    padding, dilation, groups : int
        Forwarded unchanged to all three 3x3 convolutions.
    pooling_r : int
        Kernel size and stride of the average pooling in front of ``k2``.
    norm_layer : callable
        Called as ``norm_layer(planes)`` to build each normalization layer.
    """

    def __init__(self, inplanes, planes, stride, padding, dilation, groups, pooling_r, norm_layer):
        super(SCConv, self).__init__()

        def conv_norm(conv_stride):
            # One bias-free 3x3 convolution followed by its normalization.
            return [
                nn.Conv2d(inplanes, planes, kernel_size=3, stride=conv_stride,
                          padding=padding, dilation=dilation,
                          groups=groups, bias=False),
                norm_layer(planes),
            ]

        # Branches are created in k2, k3, k4 order so that parameter names
        # and random initialization order stay the same.
        pool = nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r)
        self.k2 = nn.Sequential(pool, *conv_norm(1))
        self.k3 = nn.Sequential(*conv_norm(1))
        self.k4 = nn.Sequential(*conv_norm(stride))

    def forward(self, x):
        """Return ``k4(k3(x) * sigmoid(x + upsample(k2(x))))``."""
        spatial = x.size()[2:]
        # Bring the pooled branch back to the input's spatial size.
        context = F.interpolate(self.k2(x), spatial)
        gate = torch.sigmoid(x + context)  # sigmoid(identity + k2)
        calibrated = self.k3(x) * gate  # k3 * sigmoid(identity + k2)
        return self.k4(calibrated)  # k4
class SCBottleneck(nn.Module):
    """SCNet bottleneck block with two parallel branches.

    The input is projected by two separate 1x1 convolutions. Branch ``a``
    goes through a plain 3x3 convolution (``k1``), branch ``b`` through a
    self-calibrated convolution (``scconv``). Both results are concatenated
    on the channel axis, projected to ``planes * expansion`` channels by a
    1x1 convolution and added to the (optionally down-sampled) input.

    Parameters
    ----------
    inplanes : int
        Number of input channels.
    planes : int
        Base width; the block outputs ``planes * expansion`` channels.
    stride : int, default 1
        Spatial stride of the block.
    downsample : nn.Module or None
        Applied to the input to build the residual when given.
    cardinality : int, default 1
        Number of groups of the 3x3 convolutions.
    bottleneck_width : int, default 32
        Each branch has ``int(planes * bottleneck_width / 64) * cardinality``
        channels.
    avd : bool, default False
        When True and (``stride > 1`` or ``is_first``), the 3x3 convolutions
        run with stride 1 and a 3x3 average pooling with the requested
        stride is applied to both branches afterwards.
    dilation : int, default 1
        Dilation (and padding) of the 3x3 convolutions.
    is_first : bool, default False
        See ``avd``.
    norm_layer : callable or None
        Called as ``norm_layer(channels)``; ``None`` selects
        ``nn.BatchNorm2d``.
    """
    expansion = 4
    pooling_r = 4  # down-sampling rate of the avg pooling layer in the K2 path of SC-Conv.

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 cardinality=1, bottleneck_width=32,
                 avd=False, dilation=1, is_first=False,
                 norm_layer=None):
        super(SCBottleneck, self).__init__()
        # The signature advertises None as default, so it must be usable.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        group_width = int(planes * (bottleneck_width / 64.)) * cardinality
        out_planes = planes * self.expansion

        self.conv1_a = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        self.bn1_a = norm_layer(group_width)
        self.conv1_b = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        self.bn1_b = norm_layer(group_width)

        self.avd = avd and (stride > 1 or is_first)
        if self.avd:
            # The pooling layer takes over the striding from the 3x3 convs.
            self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
            stride = 1

        self.k1 = nn.Sequential(
            nn.Conv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False),
            norm_layer(group_width),
        )
        self.scconv = SCConv(
            group_width, group_width, stride=stride,
            padding=dilation, dilation=dilation,
            groups=cardinality, pooling_r=self.pooling_r, norm_layer=norm_layer)

        # Fuses the two concatenated branches back into the block's width.
        self.conv3 = nn.Conv2d(
            group_width * 2, out_planes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(out_planes)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        # NOTE: this is the stride of the 3x3 convs, i.e. 1 when avd is active.
        self.stride = stride

    def forward(self, x):
        """Run both branches on ``x`` and return the activated residual sum."""
        residual = x

        out_a = self.conv1_a(x)
        out_a = self.bn1_a(out_a)
        out_b = self.conv1_b(x)
        out_b = self.bn1_b(out_b)
        out_a = self.relu(out_a)
        out_b = self.relu(out_b)

        out_a = self.k1(out_a)
        out_b = self.scconv(out_b)
        out_a = self.relu(out_a)
        out_b = self.relu(out_b)

        if self.avd:
            out_a = self.avd_layer(out_a)
            out_b = self.avd_layer(out_b)

        out = self.conv3(torch.cat([out_a, out_b], dim=1))
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
class SCNet(nn.Module):
    """SCNet variant definitions.

    Parameters
    ----------
    block : Block
        Class for the residual block. It must expose an ``expansion``
        attribute and accept the keyword arguments ``downsample``,
        ``cardinality``, ``bottleneck_width``, ``avd``, ``dilation``,
        ``is_first`` and ``norm_layer``.
    layers : list of int
        Numbers of blocks in each of the four stages.
    groups : int, default 1
        Passed to every block as ``cardinality``.
    bottleneck_width : int, default 32
        Passed to every block as ``bottleneck_width``.
    num_classes : int, default 1000
        Number of classification classes (output size of the final linear
        layer).
    dilated : bool, default False
        Applying dilation strategy to pretrained SCNet yielding a stride-8
        model (stages 3 and 4 use stride 1 with dilation 2 and 4).
    dilation : int, default 1
        ``4`` behaves like ``dilated=True``; ``2`` keeps the stride of
        stage 3 and only dilates stage 4; any other value builds the
        un-dilated network.
    deep_stem : bool, default False
        Replace 7x7 conv in input stem with 3 3x3 conv.
    stem_width : int, default 64
        Channel width of the first two deep-stem convolutions (the third
        outputs ``stem_width * 2``). Only used when ``deep_stem`` is True.
    avg_down : bool, default False
        Use AvgPool instead of stride conv when
        downsampling in the bottleneck.
    avd : bool, default False
        Passed to every block as ``avd``.
    norm_layer : object
        Normalization layer used (default: :class:`torch.nn.BatchNorm2d`).
        May be a class or any factory callable taking the channel count.

    Reference:
    - He, Kaiming, et al. "Deep residual learning for image recognition."
      Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
    - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
    """

    def __init__(self, block, layers, groups=1, bottleneck_width=32,
                 num_classes=1000, dilated=False, dilation=1,
                 deep_stem=False, stem_width=64, avg_down=False,
                 avd=False, norm_layer=nn.BatchNorm2d):
        # Initialize the Module machinery before touching any attribute.
        super(SCNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params
        self.inplanes = stem_width * 2 if deep_stem else 64
        self.avg_down = avg_down
        self.avd = avd

        conv_layer = nn.Conv2d
        if deep_stem:
            self.conv1 = nn.Sequential(
                conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False),
            )
        else:
            self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
                                    bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        if dilated or dilation == 4:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                           dilation=2, norm_layer=norm_layer)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=4, norm_layer=norm_layer)
        elif dilation == 2:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           dilation=1, norm_layer=norm_layer)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=2, norm_layer=norm_layer)
        else:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           norm_layer=norm_layer)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           norm_layer=norm_layer)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # ``isinstance`` needs a class; when norm_layer is a factory (e.g. a
        # functools.partial) fall back to the common normalization types.
        if isinstance(norm_layer, type):
            norm_types = (norm_layer,)
        else:
            norm_types = (nn.BatchNorm2d, nn.SyncBatchNorm, nn.GroupNorm)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, norm_types):
                # Norm layers without affine parameters have nothing to reset.
                if getattr(m, 'weight', None) is not None:
                    nn.init.constant_(m.weight, 1)
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
                    is_first=True):
        """Build one stage of ``blocks`` residual blocks.

        The first block receives ``stride`` and, when the shape changes, a
        ``downsample`` shortcut; the remaining blocks keep the resolution.
        ``self.inplanes`` is updated to the stage's output width.

        Raises
        ------
        RuntimeError
            If ``dilation`` is not 1, 2 or 4.
        """
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        # The first block of a stage uses half the stage's dilation
        # (and no dilation at all for dilation 1 or 2).
        if dilation in (1, 2):
            first_dilation = 1
        elif dilation == 4:
            first_dilation = 2
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            down_layers = []
            if self.avg_down:
                # Pool first, then project with a stride-1 1x1 convolution.
                pool_size = stride if dilation == 1 else 1
                down_layers.append(nn.AvgPool2d(kernel_size=pool_size, stride=pool_size,
                                                ceil_mode=True, count_include_pad=False))
                down_layers.append(nn.Conv2d(self.inplanes, out_planes,
                                             kernel_size=1, stride=1, bias=False))
            else:
                down_layers.append(nn.Conv2d(self.inplanes, out_planes,
                                             kernel_size=1, stride=stride, bias=False))
            down_layers.append(norm_layer(out_planes))
            downsample = nn.Sequential(*down_layers)

        layers = [block(self.inplanes, planes, stride, downsample=downsample,
                        cardinality=self.cardinality,
                        bottleneck_width=self.bottleneck_width,
                        avd=self.avd, dilation=first_dilation, is_first=is_first,
                        norm_layer=norm_layer)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, dilation=dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of 3-channel images to ``(batch, num_classes)`` scores."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Collapse the pooled (C, 1, 1) feature map into a vector per sample.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
def scnet50(pretrained=False, **kwargs):
    """Build an SCNet-50 (stage depths 3-4-6-3, 7x7 stem, no avg-down/avd).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: extra keyword arguments forwarded to ``SCNet``.
    """
    net = SCNet(SCBottleneck, [3, 4, 6, 3],
                deep_stem=False, stem_width=32, avg_down=False,
                avd=False, **kwargs)
    if not pretrained:
        return net
    # Fetch the published checkpoint and copy it into the fresh network.
    weights = model_zoo.load_url(model_urls['scnet50'])
    net.load_state_dict(weights)
    return net
def scnet50_v1d(pretrained=False, **kwargs):
    """Build an SCNet-50_v1d model, following
    `Bag of Tricks <https://arxiv.org/pdf/1812.01187.pdf>`_ and
    `ResNeSt: Split-Attention Networks <https://arxiv.org/pdf/2004.08955.pdf>`_.

    Compared with the default SCNet (SCNetv1b), SCNetv1d replaces the 7x7
    conv in the input stem with three 3x3 convs. And in the downsampling
    block, a 3x3 avg_pool with stride 2 is added before conv, whose stride
    is changed to 1.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: extra keyword arguments forwarded to ``SCNet``.
    """
    net = SCNet(SCBottleneck, [3, 4, 6, 3],
                deep_stem=True, stem_width=32, avg_down=True,
                avd=True, **kwargs)
    if not pretrained:
        return net
    # Fetch the published checkpoint and copy it into the fresh network.
    weights = model_zoo.load_url(model_urls['scnet50_v1d'])
    net.load_state_dict(weights)
    return net
def scnet101(pretrained=False, **kwargs):
    """Build an SCNet-101 (stage depths 3-4-23-3, 7x7 stem, no avg-down/avd).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: extra keyword arguments forwarded to ``SCNet``.
    """
    net = SCNet(SCBottleneck, [3, 4, 23, 3],
                deep_stem=False, stem_width=64, avg_down=False,
                avd=False, **kwargs)
    if not pretrained:
        return net
    # Fetch the published checkpoint and copy it into the fresh network.
    weights = model_zoo.load_url(model_urls['scnet101'])
    net.load_state_dict(weights)
    return net
def scnet101_v1d(pretrained=False, **kwargs):
    """Constructs a SCNet-101_v1d model described in
    `Bag of Tricks <https://arxiv.org/pdf/1812.01187.pdf>`_.
    `ResNeSt: Split-Attention Networks <https://arxiv.org/pdf/2004.08955.pdf>`_.

    Compared with default SCNet(SCNetv1b), SCNetv1d replaces the 7x7 conv
    in the input stem with three 3x3 convs. And in the downsampling block,
    a 3x3 avg_pool with stride 2 is added before conv, whose stride is
    changed to 1.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: extra keyword arguments forwarded to ``SCNet``.

    Raises:
        KeyError: if ``pretrained`` is True while ``model_urls`` has no
            'scnet101_v1d' entry (no checkpoint has been published yet).
    """
    # Fail before building a 101-layer network whose weights cannot be
    # fetched, and say why instead of surfacing a bare key name.
    if pretrained and 'scnet101_v1d' not in model_urls:
        raise KeyError(
            "no pretrained weights are available for 'scnet101_v1d'; "
            "call scnet101_v1d(pretrained=False) instead")
    model = SCNet(SCBottleneck, [3, 4, 23, 3],
                  deep_stem=True, stem_width=64, avg_down=True,
                  avd=True, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['scnet101_v1d']))
    return model
if __name__ == '__main__':
    # Smoke test: push one random 224x224 RGB image through a pretrained
    # SCNet-101 and print the output size. Uses the first GPU when one is
    # available and falls back to the CPU otherwise, so the script also
    # runs on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    images = torch.rand(1, 3, 224, 224).to(device)
    model = scnet101(pretrained=True)
    model = model.to(device)
    # Only the output shape is needed, so skip autograd bookkeeping.
    with torch.no_grad():
        print(model(images).size())