-
Notifications
You must be signed in to change notification settings - Fork 0
/
probabilistic_unet.py
273 lines (221 loc) · 12.1 KB
/
probabilistic_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
#This code is based on: https://github.com/SimonKohl/probabilistic_unet
from unet_blocks import *
from unet import *
import torch.nn.functional as F
from torch.distributions import Normal, Independent, kl
from utils import FocalLoss2d
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class Encoder(nn.Module):
    """
    Convolutional feature extractor made of ``len(num_filters)`` blocks.

    Block ``i`` is ``no_convs_per_block`` 3x3 convolutions with ``num_filters[i]``
    output channels, each followed by an in-place ReLU (at least one convolution
    per block is always created). Every block except the first is *preceded* by a
    2x2 average pooling (stride 2, ``ceil_mode=True``). Convolutions use
    ``padding=int(padding)``.

    When ``posterior`` is True, one extra input channel is reserved for the
    segmentation mask that the caller concatenates onto the image.
    """
    def __init__(self, input_channels, num_filters, no_convs_per_block, padding=True, posterior=False):
        super(Encoder, self).__init__()
        # Unused within this class; retained so the attribute keeps existing.
        self.contracting_path = nn.ModuleList()
        # The posterior also receives the mask, concatenated on the channel axis.
        self.input_channels = input_channels + 1 if posterior else input_channels
        self.num_filters = num_filters
        pad = int(padding)

        modules = []
        in_ch = self.input_channels
        for block_idx, out_ch in enumerate(num_filters):
            if block_idx > 0:
                # Downsample in front of every block but the first.
                modules.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))
            # The first conv of a block maps in_ch -> out_ch; the rest keep out_ch.
            for conv_in in [in_ch] + [out_ch] * (no_convs_per_block - 1):
                modules.append(nn.Conv2d(conv_in, out_ch, kernel_size=3, padding=pad))
                modules.append(nn.ReLU(inplace=True))
            in_ch = out_ch
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the stacked pooling/conv/ReLU layers."""
        return self.layers(x)
class AxisAlignedConvGaussian(nn.Module):
    """
    Convolutional network that outputs a Gaussian with diagonal (axis-aligned) covariance.

    An ``Encoder`` produces a feature map. The map is averaged over its two
    spatial axes and passed through a 1x1 convolution that yields
    ``2 * latent_dim`` channels: the first ``latent_dim`` are the mean and the
    remaining ones the log standard deviation.

    With ``posterior=True`` the segmentation passed to ``forward`` is
    concatenated to the input along the channel axis before encoding.
    """
    def __init__(self, input_channels, num_filters, no_convs_per_block, latent_dim, posterior=False):
        super(AxisAlignedConvGaussian, self).__init__()
        self.input_channels = input_channels
        self.channel_axis = 1
        self.num_filters = num_filters
        self.no_convs_per_block = no_convs_per_block
        self.latent_dim = latent_dim
        self.posterior = posterior
        self.name = 'Posterior' if posterior else 'Prior'
        # Construction order (encoder, then head, then re-init of the head) is kept
        # so modules are created and initialised in the same sequence as before.
        self.encoder = Encoder(input_channels, num_filters, no_convs_per_block, posterior=posterior)
        self.conv_layer = nn.Conv2d(num_filters[-1], 2 * latent_dim, (1, 1), stride=1)
        nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu')
        nn.init.normal_(self.conv_layer.bias)
        # Inspection slots; forward() overwrites them with intermediate tensors.
        self.show_img = 0
        self.show_seg = 0
        self.show_concat = 0
        self.show_enc = 0
        self.sum_input = 0

    def forward(self, inputs, segm=None):
        """
        Return ``Independent(Normal(mu, exp(log_sigma)), 1)`` for ``inputs``.

        If ``segm`` is given it is concatenated to ``inputs`` on dim 1 first; the
        image, the mask, the concatenation and its sum are stored on ``self``.
        The encoder output is always stored in ``self.show_enc``.
        """
        if segm is not None:
            self.show_img, self.show_seg = inputs, segm
            inputs = torch.cat((inputs, segm), dim=1)
            self.show_concat = inputs
            self.sum_input = torch.sum(inputs)

        features = self.encoder(inputs)
        self.show_enc = features

        # Average over both spatial axes, keeping them as singleton dims.
        pooled = features.mean(dim=[2, 3], keepdim=True)
        # Drop the two singleton spatial axes one at a time; a plain squeeze()
        # would also drop the batch axis when the batch size is 1.
        params = self.conv_layer(pooled).squeeze(2).squeeze(2)
        mu = params[:, :self.latent_dim]
        log_sigma = params[:, self.latent_dim:]
        # Independent(..., 1) makes the last axis the event axis, i.e. a
        # multivariate normal with diagonal covariance.
        return Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)
class Fcomb(nn.Module):
    """
    Combine a latent sample with the UNet feature map and map the result to class logits.

    The latent vector ``z`` is tiled over the spatial grid of the feature map,
    concatenated to it along the channel axis, and passed through a stack of 1x1
    convolutions. There are ``max(0, no_convs_fcomb - 2) + 1`` conv+ReLU pairs,
    followed by a final 1x1 convolution to ``num_classes`` channels with no
    activation.

    Only ``use_tile=True`` is implemented. With ``use_tile=False`` no layers are
    built and ``forward`` raises ``NotImplementedError``.
    """
    def __init__(self, num_filters, latent_dim, num_output_channels, num_classes, no_convs_fcomb, use_tile=True):
        super(Fcomb, self).__init__()
        self.num_channels = num_output_channels  # stored only; not used by this class
        self.num_classes = num_classes
        self.channel_axis = 1
        self.spatial_axes = [2, 3]
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.use_tile = use_tile
        self.no_convs_fcomb = no_convs_fcomb
        self.name = 'Fcomb'
        if self.use_tile:
            hidden = self.num_filters[0]
            # The first 1x1 conv takes the UNet features plus the tiled latent channels.
            layers = [nn.Conv2d(hidden + self.latent_dim, hidden, kernel_size=1), nn.ReLU(inplace=True)]
            for _ in range(no_convs_fcomb - 2):
                layers.append(nn.Conv2d(hidden, hidden, kernel_size=1))
                layers.append(nn.ReLU(inplace=True))
            self.layers = nn.Sequential(*layers)
            # Final projection to class logits (no activation).
            self.last_layer = nn.Conv2d(hidden, self.num_classes, kernel_size=1)

    def forward(self, feature_map, z):
        """
        Tile ``z`` over the feature map, concatenate, and decode to logits.

        Args:
            feature_map: UNet features; indexed as (batch, channels, H, W).
            z: latent sample, reshaped to (batch, latent_dim, 1, 1) before tiling.

        Returns:
            Logits of shape (batch, num_classes, H, W).

        Raises:
            NotImplementedError: if the module was built with ``use_tile=False``.
        """
        if not self.use_tile:
            raise NotImplementedError("Fcomb only supports use_tile=True")
        height = feature_map.shape[self.spatial_axes[0]]
        width = feature_map.shape[self.spatial_axes[1]]
        # Broadcast z to (batch, latent_dim, H, W), same result as tf.tile. expand()
        # builds a view on z's own device, whereas the previous code allocated a
        # ones tensor on a hard-coded 'cuda' device and crashed on CPU.
        z = z.reshape(-1, self.latent_dim, 1, 1).expand(-1, -1, height, width)
        # Concatenate the UNet feature map and the latent sample along channels.
        feature_map = torch.cat((feature_map, z), dim=self.channel_axis)
        return self.last_layer(self.layers(feature_map))
class ProbabilisticUnet(nn.Module):
    """
    A probabilistic UNet (https://arxiv.org/abs/1806.05034) implementation.

    input_channels: number of channels in the image (1 for greyscale, 3 for RGB)
    num_classes: number of classes to predict
    num_filters: list with the number of filters per level; defaults to [32, 64, 128, 192]
    latent_dim: dimension of the latent space
    no_convs_fcomb: number of 1x1 convolutions in the Fcomb head
    beta: weight applied to the KL term returned by elbo()

    The prior and posterior encoders use 3 convolutions per block.
    """
    def __init__(self, input_channels=1, num_classes=1, num_filters=None, latent_dim=6, no_convs_fcomb=4, beta=10.0):
        super(ProbabilisticUnet, self).__init__()
        if num_filters is None:
            # Fresh list per instance instead of a shared mutable default argument.
            num_filters = [32, 64, 128, 192]
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.beta = beta
        self.z_prior_sample = 0
        self.unet = Unet(self.input_channels, self.num_classes, self.num_filters, apply_last_layer=False).to(device)
        self.prior = AxisAlignedConvGaussian(self.input_channels, self.num_filters, self.no_convs_per_block, self.latent_dim).to(device)
        self.posterior = AxisAlignedConvGaussian(self.input_channels, self.num_filters, self.no_convs_per_block, self.latent_dim, posterior=True).to(device)
        self.fcomb = Fcomb(self.num_filters, self.latent_dim, self.input_channels, self.num_classes, self.no_convs_fcomb, use_tile=True).to(device)
        self.criterion = FocalLoss2d(gamma=0.5)

    def forward(self, patch, segm, training=True):
        """
        Build the latent distributions and UNet features for ``patch``.

        Stores ``prior_latent_space`` and ``unet_features`` on ``self``. When
        ``segm`` is given, also stores ``posterior_latent_space``. When ``segm``
        is None, ``posterior_latent_space`` is reset to None so a stale posterior
        from an earlier batch cannot be reused. ``training`` is accepted for API
        compatibility but is not used.
        """
        if segm is not None:
            self.posterior_latent_space = self.posterior(patch, segm)
        else:
            self.posterior_latent_space = None
        self.prior_latent_space = self.prior(patch)
        self.unet_features = self.unet(patch, False)

    def sample(self, testing=False):
        """
        Decode a segmentation from a prior sample combined with the UNet features.

        testing=False draws with rsample() (reparameterised, differentiable);
        testing=True draws with sample(). The draw is stored in
        ``self.z_prior_sample`` and its density exp(log_prob) in ``self.z_prior_prob``.
        """
        if testing:
            # For the GED a random sample is needed; the prior mean
            # (self.prior_latent_space.mean) could be used here instead.
            z_prior = self.prior_latent_space.sample()
        else:
            z_prior = self.prior_latent_space.rsample()
        self.z_prior_sample = z_prior
        self.z_prior_prob = torch.exp(self.prior_latent_space.log_prob(z_prior))
        return self.fcomb(self.unet_features, z_prior)

    def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None):
        """
        Decode a segmentation from a posterior latent and the UNet feature map.

        use_posterior_mean: decode the posterior mean instead of a sample.
        calculate_posterior: draw a fresh rsample() from the posterior.
        z_posterior: latent to decode otherwise; if None, a sample is drawn.
        The result is also stored in ``self.recon_seg``.
        """
        if use_posterior_mean:
            # Independent has no `.loc`; `.mean` forwards to the base Normal's loc.
            z_posterior = self.posterior_latent_space.mean
        elif calculate_posterior or z_posterior is None:
            z_posterior = self.posterior_latent_space.rsample()
        self.recon_seg = self.fcomb(self.unet_features, z_posterior)
        return self.recon_seg

    def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
        """
        Calculate KL(Q||P) between the posterior Q and the prior P.

        analytic: use the closed-form KL; otherwise a one-sample estimate
            log q(z) - log p(z).
        calculate_posterior / z_posterior: for the sampled estimate, draw a new
            posterior sample or use the supplied one (a sample is drawn if None).
        """
        if analytic:
            # Needs the Independent/Independent KL in torch.distributions
            # (older PyTorch had to be patched, see pytorch/pytorch#13545).
            return kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)
        if calculate_posterior or z_posterior is None:
            z_posterior = self.posterior_latent_space.rsample()
        log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
        log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
        return log_posterior_prob - log_prior_prob

    def elbo(self, segm, mask=None, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Compute the loss terms of the evidence lower bound of log P(Y|X).

        Requires forward() to have been called with a segmentation. A single
        posterior sample is shared by the sampled KL estimate and the
        reconstruction. Returns {'seg_loss': reconstruction loss, 'kl': beta * mean KL}.
        """
        z_posterior = self.posterior_latent_space.rsample()
        self.kl = torch.mean(self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior))
        # Reuse the posterior sample drawn above.
        self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, calculate_posterior=False, z_posterior=z_posterior)
        # Assumes FocalLoss2d returns (unreduced loss, extra) — TODO confirm.
        reconstruction_loss, _ = self.criterion(input=self.reconstruction, target=segm)
        if mask is None:
            self.reconstruction_loss = torch.sum(reconstruction_loss) / segm.shape[0] / self.reconstruction.shape[1]
        else:
            # mask[:, 0, ...] assumes a channel axis at dim 1 — TODO confirm.
            masked_loss = reconstruction_loss * mask[:, 0, ...]
            # NOTE(review): this divides by the mask *mean* (not its sum), so it is not
            # on the same scale as the unmasked branch above — confirm intended.
            mask_mean = mask.mean()
            self.reconstruction_loss = masked_loss.sum() / (mask_mean + 1e-5)
        return {'seg_loss': self.reconstruction_loss, 'kl': self.beta * self.kl}