import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
"""
Convolutional Dense Encoder-Decoder Networks
Reference:
https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
Yinhao Zhu
Dec 21, 2017
Dec 30, 2017
Jan 03, 2018
Shaoxing Mo
May 07, 2019
"""


class _DenseLayer(nn.Sequential):
    # Bottleneck layer; bn_size is the bottleneck width multiplier on growth_rate.
    def __init__(self, in_features, growth_rate, drop_rate=0, bn_size=4,
                 bottleneck=False):
        # Use the 1x1 bottleneck only when the number of input features exceeds
        # bn_size * growth_rate. It saves little memory and loses one ReLU, so
        # the bottleneck is disabled in the current implementation.
        super(_DenseLayer, self).__init__()
        if bottleneck and in_features > bn_size * growth_rate:
            self.add_module('norm1', nn.BatchNorm2d(in_features))
            self.add_module('relu1', nn.ReLU(inplace=True))
            self.add_module('conv1', nn.Conv2d(in_features, bn_size * growth_rate,
                                               kernel_size=1, stride=1, bias=False))
            self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
            self.add_module('relu2', nn.ReLU(inplace=True))
            self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                               kernel_size=3, stride=1, padding=1,
                                               bias=False))
        else:
            self.add_module('norm1', nn.BatchNorm2d(in_features))
            self.add_module('relu1', nn.ReLU(inplace=True))
            self.add_module('conv1', nn.Conv2d(in_features, growth_rate,
                                               kernel_size=3, stride=1, padding=1,
                                               bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        y = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            y = F.dropout2d(y, p=self.drop_rate, training=self.training)
        # dense connectivity: concatenate the input with the new feature maps
        z = torch.cat([x, y], 1)
        return z
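

# Channel bookkeeping (an illustrative note, not part of the original file):
# every _DenseLayer keeps the spatial size unchanged and concatenates
# growth_rate new feature maps onto its input, so a block of L such layers
# maps C_in channels to C_in + L * growth_rate channels. For example, with
# in_features=48, growth_rate=16 and 3 layers: 48 -> 64 -> 80 -> 96.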


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, in_features, growth_rate, drop_rate,
                 bn_size=4, bottleneck=False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(in_features + i * growth_rate, growth_rate,
                                drop_rate=drop_rate, bn_size=bn_size,
                                bottleneck=bottleneck)
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, in_features, out_features, encoding=True, drop_rate=0.,
                 last=False, out_channels=3, outsize_even=True):
        super(_Transition, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(in_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        if encoding:
            # reduce feature maps; half image size (input feature size is even)
            # bottleneck impl, save memory, add nonlinearity
            self.add_module('conv1', nn.Conv2d(in_features, out_features,
                                               kernel_size=1, stride=1,
                                               padding=0, bias=False))
            if drop_rate > 0:
                self.add_module('dropout1', nn.Dropout2d(p=drop_rate))
            self.add_module('norm2', nn.BatchNorm2d(out_features))
            self.add_module('relu2', nn.ReLU(inplace=True))
            self.add_module('conv2', nn.Conv2d(out_features, out_features,
                                               kernel_size=3, stride=2,
                                               padding=1, bias=False))
            if drop_rate > 0:
                self.add_module('dropout2', nn.Dropout2d(p=drop_rate))
        else:
            # decoding, transition up
            if last:
                ks = 6 if outsize_even else 5
                out_convt = nn.ConvTranspose2d(out_features, out_channels,
                                               kernel_size=ks, stride=2,
                                               padding=2, bias=False)
            else:
                out_convt = nn.ConvTranspose2d(
                    out_features, out_features, kernel_size=3, stride=2,
                    padding=1, output_padding=0, bias=False)
            # bottleneck impl, save memory, add nonlinearity
            self.add_module('conv1', nn.Conv2d(in_features, out_features,
                                               kernel_size=1, stride=1,
                                               padding=0, bias=False))
            if drop_rate > 0:
                self.add_module('dropout1', nn.Dropout2d(p=drop_rate))
            self.add_module('norm2', nn.BatchNorm2d(out_features))
            self.add_module('relu2', nn.ReLU(inplace=True))
            self.add_module('convT2', out_convt)
            if drop_rate > 0:
                self.add_module('dropout2', nn.Dropout2d(p=drop_rate))
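

# Note on the transposed-conv kernel sizes used in _Transition above (standard
# ConvTranspose2d output-size arithmetic with dilation=1, output_padding=0):
#     H_out = (H_in - 1) * stride - 2 * padding + kernel_size
# With stride=2 and padding=2, kernel_size=6 gives H_out = 2 * H_in (even
# output size), while kernel_size=5 gives H_out = 2 * H_in - 1 (odd output,
# e.g. 65 x 65). The intermediate up-transition (kernel 3, stride 2, padding 1)
# maps H_in to 2 * H_in - 1.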


class DenseED(nn.Module):
    def __init__(self, in_channels, out_channels, blocks, growth_rate=16,
                 num_init_features=64, bn_size=4, drop_rate=0, outsize_even=False,
                 bottleneck=False):
        """
        Args:
            in_channels (int): number of input channels
            out_channels (int): number of output channels
            blocks: list (of odd length) of integers, giving the number of
                dense layers in each dense block
            growth_rate (int): K, the number of feature maps added per dense layer
            num_init_features (int): number of feature maps after the first
                conv layer
            bn_size: bottleneck size for the number of feature maps (not useful...)
            bottleneck (bool): whether to use the bottleneck in the dense blocks
            drop_rate (float): dropout rate
            outsize_even (bool): whether the output size is even or odd (e.g.
                65 x 65 is odd, 64 x 64 is even)
        """
        super(DenseED, self).__init__()
        self.out_channels = out_channels

        if len(blocks) > 1 and len(blocks) % 2 == 0:
            raise ValueError('length of blocks must be an odd number, but got {}'
                             .format(len(blocks)))
        enc_block_layers = blocks[: len(blocks) // 2]
        dec_block_layers = blocks[len(blocks) // 2:]

        self.features = nn.Sequential()
        # First convolution ================
        # only conv, half image size
        self.features.add_module('in_conv',
                                 nn.Conv2d(in_channels, num_init_features,
                                           kernel_size=7, stride=2, padding=3,
                                           bias=False))

        # Encoding / transition down ================
        # dense block --> encoding --> dense block --> encoding
        num_features = num_init_features
        for i, num_layers in enumerate(enc_block_layers):
            block = _DenseBlock(num_layers=num_layers,
                                in_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate,
                                drop_rate=drop_rate, bottleneck=bottleneck)
            self.features.add_module('encblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate

            trans = _Transition(in_features=num_features,
                                out_features=num_features // 2,
                                encoding=True, drop_rate=drop_rate)
            self.features.add_module('down%d' % (i + 1), trans)
            num_features = num_features // 2

        # Decoding / transition up ==============
        # dense block --> transition up --> dense block --> transition up
        # (the last transition up is the output layer)
        for i, num_layers in enumerate(dec_block_layers):
            block = _DenseBlock(num_layers=num_layers,
                                in_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate,
                                drop_rate=drop_rate, bottleneck=bottleneck)
            self.features.add_module('decblock%d' % (i + 1), block)
            num_features += num_layers * growth_rate

            # the last decoding transition is the output layer
            last_layer = (i == len(dec_block_layers) - 1)

            trans = _Transition(in_features=num_features,
                                out_features=num_features // 2,
                                encoding=False, drop_rate=drop_rate,
                                last=last_layer, out_channels=out_channels,
                                outsize_even=outsize_even)
            self.features.add_module('up%d' % (i + 1), trans)
            num_features = num_features // 2

    def forward(self, x):
        y = self.features(x)
        # use the softplus activation for concentration
        y[:, :self.out_channels - 1] = F.softplus(
            y[:, :self.out_channels - 1].clone(), beta=5)
        # in the example, pressure is the last output channel;
        # use the sigmoid activation for pressure
        y[:, self.out_channels - 1] = torch.sigmoid(y[:, self.out_channels - 1])
        return y

    def _num_parameters_convlayers(self):
        n_params, n_conv_layers = 0, 0
        for name, param in self.named_parameters():
            if 'conv' in name:
                n_conv_layers += 1
            n_params += param.numel()
        return n_params, n_conv_layers

    def _count_parameters(self):
        n_params = 0
        for name, param in self.named_parameters():
            print(name)
            print(param.size())
            print(param.numel())
            n_params += param.numel()
            print('num of parameters so far: {}'.format(n_params))

    def reset_parameters(self, verbose=False):
        for module in self.modules():
            # skip self, otherwise infinite loop
            if isinstance(module, self.__class__):
                continue
            if 'reset_parameters' in dir(module):
                if callable(module.reset_parameters):
                    module.reset_parameters()
                    if verbose:
                        print("Reset parameters in {}".format(module))