upernet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
#from torchvision import models
#from base import BaseModel
#from utils.helpers import initialize_weights
from itertools import chain
from swin_transformer_pytorch import SwinTransformer


class PSPModule(nn.Module):
    # The original implementation uses precise RoI pooling;
    # adaptive average pooling is used here instead.
    def __init__(self, in_channels, bin_sizes=[1, 2, 3, 6]):
        super(PSPModule, self).__init__()
        out_channels = in_channels // len(bin_sizes)
        self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s)
                                     for b_s in bin_sizes])
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels + (out_channels * len(bin_sizes)), in_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )

    def _make_stages(self, in_channels, out_channels, bin_sz):
        prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        bn = nn.BatchNorm2d(out_channels)
        relu = nn.ReLU(inplace=True)
        return nn.Sequential(prior, conv, bn, relu)

    def forward(self, features):
        h, w = features.size()[2], features.size()[3]
        pyramids = [features]
        pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear',
                                       align_corners=True) for stage in self.stages])
        output = self.bottleneck(torch.cat(pyramids, dim=1))
        return output
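
# A minimal shape check for PSPModule (illustrative sketch; the 384-channel,
# 16x16 input is an arbitrary assumption matching the deepest Swin stage):
#
#     ppm = PSPModule(in_channels=384)
#     out = ppm(torch.randn(1, 384, 16, 16))
#     assert out.shape == (1, 384, 16, 16)  # channels and spatial size preserved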


def up_and_add(x, y):
    # Upsample x to y's spatial size, then add element-wise (FPN top-down step).
    return F.interpolate(x, size=(y.size(2), y.size(3)), mode='bilinear', align_corners=True) + y
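
# For example (hypothetical shapes), up_and_add(torch.randn(1, 48, 8, 8),
# torch.randn(1, 48, 16, 16)) upsamples the first tensor to 16x16 before summing.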


class FPN_fuse(nn.Module):
    def __init__(self, feature_channels=[48, 96, 192, 384], fpn_out=48):
        super(FPN_fuse, self).__init__()
        assert feature_channels[0] == fpn_out
        self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1)
                                      for ft_size in feature_channels[1:]])
        # A comprehension (not `[module] * n`, which would share one module's
        # weights) gives each pyramid level its own smoothing conv.
        self.smooth_conv = nn.ModuleList([nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)
                                          for _ in range(len(feature_channels) - 1)])
        self.conv_fusion = nn.Sequential(
            nn.Conv2d(len(feature_channels) * fpn_out, fpn_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(fpn_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, features):
        features[1:] = [conv1x1(feature) for feature, conv1x1 in zip(features[1:], self.conv1x1)]
        P = [up_and_add(features[i], features[i - 1]) for i in reversed(range(1, len(features)))]
        P = [smooth_conv(x) for smooth_conv, x in zip(self.smooth_conv, P)]
        P = list(reversed(P))
        P.append(features[-1])  # P = [P1, P2, P3, P4]
        H, W = P[0].size(2), P[0].size(3)
        P[1:] = [F.interpolate(feature, size=(H, W), mode='bilinear', align_corners=True) for feature in P[1:]]
        x = self.conv_fusion(torch.cat(P, dim=1))
        return x
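
# Illustrative FPN_fuse usage (a sketch; the 64x64 base resolution is an
# arbitrary assumption, with each stage halving the spatial size):
#
#     fpn = FPN_fuse(feature_channels=[48, 96, 192, 384], fpn_out=48)
#     feats = [torch.randn(1, c, 64 // 2**i, 64 // 2**i)
#              for i, c in enumerate([48, 96, 192, 384])]
#     fused = fpn(feats)  # -> (1, 48, 64, 64), the finest stage's resolution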


class UperNet(nn.Module):
    # Implementing only the object path
    def __init__(self, num_classes=6, use_aux=True, fpn_out=48, freeze_bn=False, **_):
        super(UperNet, self).__init__()
        feature_channels = [48, 96, 192, 384]
        self.backbone = SwinTransformer(
            hidden_dim=48,                     # channel width of the first stage
            layers=(2, 2, 6, 2),               # blocks per stage
            heads=(3, 6, 12, 24),              # attention heads per stage
            channels=3,
            num_classes=num_classes,
            head_dim=32,                       # dimension of each attention head
            window_size=8,
            downscaling_factors=(2, 2, 2, 2),
            relative_pos_embedding=True
        )
        self.PPN = PSPModule(feature_channels[-1])
        self.FPN = FPN_fuse(feature_channels, fpn_out=fpn_out)
        self.head = nn.Conv2d(fpn_out, num_classes, kernel_size=3, padding=1)
        if freeze_bn:
            self.freeze_bn()
        # if freeze_backbone:
        #     set_trainable([self.backbone], False)

    def forward(self, x):
        input_size = (x.size()[-2], x.size()[-1])
        # The backbone is expected to return the list of feature maps from its
        # four stages (not classification logits).
        features = self.backbone(x)
        features[-1] = self.PPN(features[-1])
        x = self.head(self.FPN(features))
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)
        return x

    def get_backbone_params(self):
        return self.backbone.parameters()

    def get_decoder_params(self):
        return chain(self.PPN.parameters(), self.FPN.parameters(), self.head.parameters())

    def freeze_bn(self):
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
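
# End-to-end usage sketch (hedged: it assumes the SwinTransformer backbone has
# been modified to return the four-stage feature list that forward() expects;
# the stock swin_transformer_pytorch package returns classification logits,
# so this would not run against it unchanged):
#
#     model = UperNet(num_classes=6)
#     logits = model(torch.randn(1, 3, 256, 256))  # -> (1, 6, 256, 256)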