import numpy as np
from PIL import Image
from torch.utils import data

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
class Dataset(data.Dataset):
    "Characterizes a dataset for PyTorch"
    def __init__(self, filenames, labels, transform=None):
        "Initialization"
        self.filenames = filenames
        self.labels = labels
        self.transform = transform

    def __len__(self):
        "Denotes the total number of samples"
        return len(self.filenames)

    def __getitem__(self, index):
        "Generates one sample of data"
        # Select sample
        filename = self.filenames[index]
        X = Image.open(filename)

        if self.transform:
            X = self.transform(X)  # transform

        y = torch.LongTensor([self.labels[index]])
        return X, y

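# A minimal usage sketch for the Dataset above (the file list and labels here are
# hypothetical; the 224x224 resize matches the decoder's output size further below):
#
#   train_set = Dataset(filenames=['img0.jpg', 'img1.jpg'], labels=[0, 1],
#                       transform=transforms.Compose([transforms.Resize((224, 224)),
#                                                     transforms.ToTensor()]))
#   train_loader = data.DataLoader(train_set, batch_size=2, shuffle=True)
#   for X, y in train_loader:
#       pass  # X: (batch, 3, 224, 224), y: (batch, 1)
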
## ---------------------- end of Dataloaders ---------------------- ##
def conv2D_output_size(img_size, padding, kernel_size, stride):
    # compute output shape of conv2D
    outshape = (np.floor((img_size[0] + 2 * padding[0] - (kernel_size[0] - 1) - 1) / stride[0] + 1).astype(int),
                np.floor((img_size[1] + 2 * padding[1] - (kernel_size[1] - 1) - 1) / stride[1] + 1).astype(int))
    return outshape

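# Example: one conv2D_output_size step for a 224x224 input with the (3, 3) kernel,
# (2, 2) stride, zero-padding settings defined in ResNet_VAE below:
#   conv2D_output_size((224, 224), (0, 0), (3, 3), (2, 2))  # -> (111, 111)
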
def convtrans2D_output_size(img_size, padding, kernel_size, stride):
    # compute output shape of transposed conv2D
    outshape = ((img_size[0] - 1) * stride[0] - 2 * padding[0] + kernel_size[0],
                (img_size[1] - 1) * stride[1] - 2 * padding[1] + kernel_size[1])
    return outshape

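# Example: convtrans2D_output_size traces how the decoder's 4x4 feature map grows
# through the three transposed-conv blocks below (kernel (3, 3), stride (2, 2),
# no padding), before the final bilinear upsampling to 224x224:
#   convtrans2D_output_size((4, 4), (0, 0), (3, 3), (2, 2))    # -> (9, 9)
#   convtrans2D_output_size((9, 9), (0, 0), (3, 3), (2, 2))    # -> (19, 19)
#   convtrans2D_output_size((19, 19), (0, 0), (3, 3), (2, 2))  # -> (39, 39)
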
## ---------------------- ResNet VAE ---------------------- ##
class ResNet_VAE(nn.Module):
    def __init__(self, fc_hidden1=1024, fc_hidden2=768, drop_p=0.3, CNN_embed_dim=256):
        super(ResNet_VAE, self).__init__()

        self.fc_hidden1, self.fc_hidden2, self.CNN_embed_dim = fc_hidden1, fc_hidden2, CNN_embed_dim
        self.drop_p = drop_p  # store dropout probability (referenced in encode)

        # CNN architectures
        self.ch1, self.ch2, self.ch3, self.ch4 = 16, 32, 64, 128
        self.k1, self.k2, self.k3, self.k4 = (5, 5), (3, 3), (3, 3), (3, 3)      # 2d kernel size
        self.s1, self.s2, self.s3, self.s4 = (2, 2), (2, 2), (2, 2), (2, 2)      # 2d strides
        self.pd1, self.pd2, self.pd3, self.pd4 = (0, 0), (0, 0), (0, 0), (0, 0)  # 2d padding

        # encoding components
        # (note: newer torchvision versions replace pretrained=True with the weights= argument)
        resnet = models.resnet152(pretrained=True)
        modules = list(resnet.children())[:-1]  # delete the last fc layer.
        self.resnet = nn.Sequential(*modules)
        self.fc1 = nn.Linear(resnet.fc.in_features, self.fc_hidden1)
        self.bn1 = nn.BatchNorm1d(self.fc_hidden1, momentum=0.01)
        self.fc2 = nn.Linear(self.fc_hidden1, self.fc_hidden2)
        self.bn2 = nn.BatchNorm1d(self.fc_hidden2, momentum=0.01)
        # Latent vectors mu and sigma
        self.fc3_mu = nn.Linear(self.fc_hidden2, self.CNN_embed_dim)      # output = CNN embedding latent variables
        self.fc3_logvar = nn.Linear(self.fc_hidden2, self.CNN_embed_dim)  # output = CNN embedding latent variables

        # Sampling vector
        self.fc4 = nn.Linear(self.CNN_embed_dim, self.fc_hidden2)
        self.fc_bn4 = nn.BatchNorm1d(self.fc_hidden2)
        self.fc5 = nn.Linear(self.fc_hidden2, 64 * 4 * 4)
        self.fc_bn5 = nn.BatchNorm1d(64 * 4 * 4)
        self.relu = nn.ReLU(inplace=True)

        # Decoder
        self.convTrans6 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=self.k4, stride=self.s4,
                               padding=self.pd4),
            nn.BatchNorm2d(32, momentum=0.01),
            nn.ReLU(inplace=True),
        )
        self.convTrans7 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=8, kernel_size=self.k3, stride=self.s3,
                               padding=self.pd3),
            nn.BatchNorm2d(8, momentum=0.01),
            nn.ReLU(inplace=True),
        )
        self.convTrans8 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=8, out_channels=3, kernel_size=self.k2, stride=self.s2,
                               padding=self.pd2),
            nn.BatchNorm2d(3, momentum=0.01),
            nn.Sigmoid()  # y = (y1, y2, y3) \in [0, 1]^3
        )
    def encode(self, x):
        x = self.resnet(x)  # ResNet
        x = x.view(x.size(0), -1)  # flatten output of conv

        # FC layers
        x = self.bn1(self.fc1(x))
        x = self.relu(x)
        x = self.bn2(self.fc2(x))
        x = self.relu(x)
        # x = F.dropout(x, p=self.drop_p, training=self.training)
        mu, logvar = self.fc3_mu(x), self.fc3_logvar(x)
        return mu, logvar
    def reparameterize(self, mu, logvar):
        if self.training:
            # reparameterization trick: z = mu + std * eps with eps ~ N(0, I),
            # keeping the sampling step differentiable w.r.t. mu and logvar
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps * std + mu
        else:
            return mu
    def decode(self, z):
        x = self.relu(self.fc_bn4(self.fc4(z)))
        x = self.relu(self.fc_bn5(self.fc5(x))).view(-1, 64, 4, 4)
        x = self.convTrans6(x)
        x = self.convTrans7(x)
        x = self.convTrans8(x)
        # upsample the 39x39 output back to the 224x224 input resolution
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        return x
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_reconst = self.decode(z)
        return x_reconst, z, mu, logvar
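
# A minimal end-to-end sketch of the model above. The standard VAE loss (pixel-wise
# reconstruction error plus the KL divergence of N(mu, sigma^2) from N(0, I)) is
# shown for illustration; it is not defined in this module:
#
#   model = ResNet_VAE(CNN_embed_dim=256)
#   model.train()
#   X = torch.rand(4, 3, 224, 224)  # dummy batch of RGB images in [0, 1]
#   X_reconst, z, mu, logvar = model(X)
#   reconst_loss = F.binary_cross_entropy(X_reconst, X, reduction='sum')
#   kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#   loss = reconst_loss + kld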