model.py
import torch
import torch.nn.functional as F
from torch import nn
from resnext import ResNeXt101

class DAF(nn.Module):
    """Attentive multi-level feature fusion over a ResNeXt-101 backbone.

    The four backbone stages are projected to a shared 64-channel space,
    fused into a single multi-level feature, re-weighted with per-stage
    attention maps, and refined; every stage has a 1x1 prediction head
    both before and after refinement.
    """

    def __init__(self):
        super(DAF, self).__init__()
        self.resnext = ResNeXt101(32)

        # 1x1 projections mapping each backbone stage to 64 channels.
        self.down4 = nn.Sequential(
            nn.Conv2d(2048, 64, kernel_size=1), nn.BatchNorm2d(64), nn.PReLU()
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(1024, 64, kernel_size=1), nn.BatchNorm2d(64), nn.PReLU()
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(512, 64, kernel_size=1), nn.BatchNorm2d(64), nn.PReLU()
        )
        self.down1 = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=1), nn.BatchNorm2d(64), nn.PReLU()
        )

        # Fuse the four projected stages (4 x 64 = 256 channels in) into a
        # single 64-channel multi-level feature map.
        self.fuse1 = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=1), nn.BatchNorm2d(64), nn.PReLU()
        )
        # Per-stage attention: each block takes [stage feature, fused feature]
        # (128 channels) and emits a 64-channel weighting map, normalized
        # across channels at every pixel by Softmax2d.
        self.attention4 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=1), nn.Softmax2d()
        )
        self.attention3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=1), nn.Softmax2d()
        )
        self.attention2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=1), nn.Softmax2d()
        )
        self.attention1 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=1), nn.Softmax2d()
        )
        # Refinement: each block takes [stage feature, attention-weighted
        # fused feature] (128 channels) and outputs a refined 64-channel map.
        self.refine4 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=1), nn.BatchNorm2d(64), nn.PReLU()
        )
        self.refine3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=1), nn.BatchNorm2d(64), nn.PReLU()
        )
        self.refine2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=1), nn.BatchNorm2d(64), nn.PReLU()
        )
        self.refine1 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=1), nn.BatchNorm2d(64), nn.PReLU()
        )
        # Single-channel 1x1 prediction heads: predictN on the projected
        # stages, predictN_2 on the refined features.
        self.predict4 = nn.Conv2d(64, 1, kernel_size=1)
        self.predict3 = nn.Conv2d(64, 1, kernel_size=1)
        self.predict2 = nn.Conv2d(64, 1, kernel_size=1)
        self.predict1 = nn.Conv2d(64, 1, kernel_size=1)
        self.predict4_2 = nn.Conv2d(64, 1, kernel_size=1)
        self.predict3_2 = nn.Conv2d(64, 1, kernel_size=1)
        self.predict2_2 = nn.Conv2d(64, 1, kernel_size=1)
        self.predict1_2 = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        # Backbone forward pass through the five ResNeXt stages.
        layer0 = self.resnext.layer0(x)
        layer1 = self.resnext.layer1(layer0)
        layer2 = self.resnext.layer2(layer1)
        layer3 = self.resnext.layer3(layer2)
        layer4 = self.resnext.layer4(layer3)

        # Project each stage to 64 channels and resize to layer1's spatial
        # size. F.upsample is deprecated, so F.interpolate is used here;
        # align_corners=True matches the original (pre-0.4) F.upsample
        # bilinear behavior.
        size = layer1.size()[2:]
        down4 = F.interpolate(self.down4(layer4), size=size, mode='bilinear', align_corners=True)
        down3 = F.interpolate(self.down3(layer3), size=size, mode='bilinear', align_corners=True)
        down2 = F.interpolate(self.down2(layer2), size=size, mode='bilinear', align_corners=True)
        down1 = self.down1(layer1)

        # First-round (pre-refinement) predictions.
        predict4 = self.predict4(down4)
        predict3 = self.predict3(down3)
        predict2 = self.predict2(down2)
        predict1 = self.predict1(down1)

        # Fuse all stages, then compute a per-stage attention map over the
        # fused feature.
        fuse1 = self.fuse1(torch.cat((down4, down3, down2, down1), 1))
        attention4 = self.attention4(torch.cat((down4, fuse1), 1))
        attention3 = self.attention3(torch.cat((down3, fuse1), 1))
        attention2 = self.attention2(torch.cat((down2, fuse1), 1))
        attention1 = self.attention1(torch.cat((down1, fuse1), 1))

        # Refine each stage with its attention-weighted fused feature.
        refine4 = self.refine4(torch.cat((down4, attention4 * fuse1), 1))
        refine3 = self.refine3(torch.cat((down3, attention3 * fuse1), 1))
        refine2 = self.refine2(torch.cat((down2, attention2 * fuse1), 1))
        refine1 = self.refine1(torch.cat((down1, attention1 * fuse1), 1))

        # Second-round (post-refinement) predictions.
        predict4_2 = self.predict4_2(refine4)
        predict3_2 = self.predict3_2(refine3)
        predict2_2 = self.predict2_2(refine2)
        predict1_2 = self.predict1_2(refine1)

        # Upsample all eight prediction maps to the input resolution.
        out_size = x.size()[2:]
        predict1 = F.interpolate(predict1, size=out_size, mode='bilinear', align_corners=True)
        predict2 = F.interpolate(predict2, size=out_size, mode='bilinear', align_corners=True)
        predict3 = F.interpolate(predict3, size=out_size, mode='bilinear', align_corners=True)
        predict4 = F.interpolate(predict4, size=out_size, mode='bilinear', align_corners=True)
        predict1_2 = F.interpolate(predict1_2, size=out_size, mode='bilinear', align_corners=True)
        predict2_2 = F.interpolate(predict2_2, size=out_size, mode='bilinear', align_corners=True)
        predict3_2 = F.interpolate(predict3_2, size=out_size, mode='bilinear', align_corners=True)
        predict4_2 = F.interpolate(predict4_2, size=out_size, mode='bilinear', align_corners=True)

        if self.training:
            # Training: return all eight logit maps for deep supervision.
            return predict1, predict2, predict3, predict4, \
                predict1_2, predict2_2, predict3_2, predict4_2
        else:
            # Inference: average the refined logits and squash to [0, 1].
            # (F.sigmoid is deprecated in favor of torch.sigmoid.)
            return torch.sigmoid((predict1_2 + predict2_2 + predict3_2 + predict4_2) / 4)
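
# Minimal smoke-test sketch, not part of the original file: it assumes that
# ResNeXt101(32) builds a standard backbone taking 3-channel (RGB) input and
# that no pretrained weights are needed just to trace output shapes.
if __name__ == '__main__':
    model = DAF().eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 224, 224)  # NCHW batch of one RGB image
        out = model(dummy)
    # In eval mode the model returns a single sigmoid-averaged prediction
    # map at the input resolution, i.e. shape (1, 1, 224, 224) here.
    print(out.shape)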