from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms


def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def img_loader(img, device):
    """
    This function loads and pre-processes images for
    style transfer.
    """
    # Convert to RGB so grayscale/RGBA inputs match the 3-channel VGG input.
    PIL_img = Image.open(img).convert('RGB')
    max_dim = max(PIL_img.size)
    max_allow_dim = 256
    if max_dim > max_allow_dim:
        # Downscale so the longer side is at most 256 px; note that PIL
        # reports (width, height) while Resize expects (height, width).
        scale_ratio = max_allow_dim / max_dim
        loader = transforms.Compose([
            transforms.Resize((int(PIL_img.size[1] * scale_ratio),
                               int(PIL_img.size[0] * scale_ratio))),
            transforms.ToTensor()])
    else:
        loader = transforms.ToTensor()
    # Add a batch dimension and move to the target device.
    tensor_img = loader(PIL_img).unsqueeze(0).to(device)
    return tensor_img


def save_output_img(img_tensor, user_id):
    """
    This function saves the resulting style-transfer tensor
    to a JPEG file.
    """
    PIL_transform = transforms.ToPILImage()
    # Detach from the graph before conversion: the optimized image
    # still carries requires_grad=True.
    img_tensor = img_tensor.detach().cpu().squeeze(0)
    img = PIL_transform(img_tensor)
    output_img_name = user_id + '_output_img.jpg'
    img.save(output_img_name)


class VGG_Head(nn.Module):
    """
    This class reproduces the head of VGG19, containing
    only the first five convolutions.
    """
    def __init__(self, cfg):
        super(VGG_Head, self).__init__()
        self.features = self.make_layers(cfg)

    def make_layers(self, cfg):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        # Drop the trailing ReLU so the head ends on the last convolution.
        return nn.Sequential(*layers[:-1])

    def forward(self, x):
        x = self.features(x)
        return x


class Normalization(nn.Module):
    """
    This class represents the normalization layer for the VGG head,
    which has been trained on ImageNet.
    """
    def __init__(self):
        super(Normalization, self).__init__()
        # Register the ImageNet statistics as buffers so that .to(device)
        # moves them along with the module.
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1))
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1))

    def forward(self, img):
        return (img - self.mean) / self.std


class ContentLoss(nn.Module):
    """
    Content loss layer.
    """
    def __init__(self, target):
        super(ContentLoss, self).__init__()
        self.target = target.detach()
        # Initialize the loss to zero so it is defined before any forward pass.
        self.loss = F.mse_loss(self.target, self.target)

    def forward(self, input):
        # The shape check lets differently sized tensors (e.g. the style
        # image while its targets are captured) pass through harmlessly.
        if input.shape == self.target.shape:
            self.loss = F.mse_loss(input, self.target)
        return input


def gram_matrix(input):
    """
    This function computes the Gram matrix, which
    is used to determine the style loss.
    """
    batch_size, ch, h, w = input.shape
    features = input.view(batch_size * ch, h * w)
    G = torch.mm(features, features.t())
    # Normalize by the number of elements in each feature map.
    return G.div(batch_size * ch * h * w)
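

# For intuition: a (1, 64, H, W) activation yields a (64, 64) Gram matrix of
# channel-to-channel correlations; the normalization above keeps the style
# loss scale independent of image size.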


class StyleLoss(nn.Module):
    """
    Style loss layer.
    """
    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()
        self.loss = F.mse_loss(self.target, self.target)

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input


class NST_Assy():
    """
    Assembly of the CNN for neural style transfer.
    """
    def __init__(self, content_img, style_img, input_img, gamma, device):
        self.input_img = input_img
        # Clamp gamma away from zero so the style weight stays positive.
        self.gamma = max(gamma, 1e-5)
        self.model, self.style_loss, self.content_loss = self.assemble_model(
            content_img, style_img, device)
        # Only the input image is optimized; the network weights stay frozen.
        self.optimizer = optim.LBFGS([self.input_img.requires_grad_()])

    def assemble_model(self, content_img, style_img, device):
        """
        Build the VGG19 head and load pretrained weights.
        """
        cfg = [64, 64, 'M', 128, 128, 'M', 256]
        cnn = VGG_Head(cfg).features
        # Load pretrained weights; map_location keeps this working on
        # CPU-only machines even if the weights were saved from a GPU.
        cnn.load_state_dict(torch.load('VGG19_Head.pth', map_location=device))
        cnn.to(device)
        # Normalization module:
        normalization = Normalization().to(device)
        # Lists of content/style losses:
        content_losses = []
        style_losses = []
        # For the TG bot we are not planning to change the model architecture.
        # Therefore, we hardcode the locations of the style and content loss
        # layers within the NST model class:
        content_layers = ['conv_4']
        style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        # Start assembling the new NST model:
        model = nn.Sequential(normalization)
        i = 0  # increment every time we see a conv
        for layer in cnn.children():
            if isinstance(layer, nn.Conv2d):
                i += 1
                name = 'conv_{}'.format(i)
            elif isinstance(layer, nn.ReLU):
                name = 'relu_{}'.format(i)
                # Redefine ReLU to make it out-of-place:
                layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                name = 'pool_{}'.format(i)
            elif isinstance(layer, nn.BatchNorm2d):
                name = 'bn_{}'.format(i)
            else:
                raise RuntimeError('Unrecognized layer: {}'.format(
                    layer.__class__.__name__))
            model.add_module(name, layer)
            if name in content_layers:
                # Add content loss:
                target = model(content_img).detach()
                content_loss = ContentLoss(target)
                model.add_module("content_loss_{}".format(i), content_loss)
                content_losses.append(content_loss)
            if name in style_layers:
                # Add style loss:
                target_feature = model(style_img).detach()
                style_loss = StyleLoss(target_feature)
                model.add_module("style_loss_{}".format(i), style_loss)
                style_losses.append(style_loss)
        # Trim off the layers after the last content and style losses.
        # In fact, this loop will break at the first iteration since we are
        # using a modified CNN that contains only the required layers from
        # VGG19, but this functionality will be useful for possible
        # modifications.
        for i in range(len(model) - 1, -1, -1):
            if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
                break
        model = model[:(i + 1)]
        return model, style_losses, content_losses
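
    # For reference, with cfg = [64, 64, 'M', 128, 128, 'M', 256] the
    # assembled model is: Normalization -> conv_1 -> style_loss_1 -> relu_1
    # -> conv_2 -> style_loss_2 -> relu_2 -> pool_2 -> conv_3 -> style_loss_3
    # -> relu_3 -> conv_4 -> content_loss_4 -> style_loss_4 -> relu_4
    # -> pool_4 -> conv_5 -> style_loss_5.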

    def fit(self, num_steps=300):
        """Run the style transfer."""
        # Set weights for content and style losses:
        content_weight = 1
        style_weight = int(self.gamma * 1e5)
        # A one-element list lets the closure mutate the step counter.
        run = [0]
        while run[0] <= num_steps:
            # LBFGS requires a closure that re-evaluates the model and
            # returns the loss; it may call it several times per step().
            def closure():
                # Values should be in the range [0, 1]:
                self.input_img.data.clamp_(0, 1)
                self.optimizer.zero_grad()
                self.model(self.input_img)
                style_score = 0
                content_score = 0
                for sl in self.style_loss:
                    style_score += sl.loss
                for cl in self.content_loss:
                    content_score += cl.loss
                # Apply weights to the losses:
                style_score *= style_weight
                content_score *= content_weight
                # Total loss:
                loss = style_score + content_score
                loss.backward()
                run[0] += 1
                # Uncomment if you'd like to see progress output
                # in the python terminal:
                # if run[0] % 50 == 0:
                #     print('Run {} : Style Loss = {:4f} Content Loss = {:4f}'.format(
                #         run[0], style_score.item(), content_score.item()))
                return style_score + content_score
            self.optimizer.step(closure)
        # A last correction:
        self.input_img.data.clamp_(0, 1)
        return self.input_img


async def image_style_transfer(gamma, user_id):
    """
    Run the full style-transfer pipeline for a given user.
    """
    # Set the device for the model run:
    device = get_device()
    # Load the content and style images:
    content_img_name = user_id + '_content_img.jpg'
    style_img_name = user_id + '_style_img.jpg'
    content_img = img_loader(content_img_name, device)
    style_img = img_loader(style_img_name, device)
    # Initialize the input image from the content image:
    input_img = content_img.clone()
    # Build the NST CNN:
    nst = NST_Assy(content_img, style_img, input_img, gamma, device)
    # Run the style transfer:
    output_img = nst.fit()
    # Save the result:
    save_output_img(output_img, user_id)
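

# A minimal usage sketch (not part of the bot): assumes the pretrained
# 'VGG19_Head.pth' weights and this user's '42_content_img.jpg' and
# '42_style_img.jpg' files already exist in the working directory; the
# user id '42' is hypothetical.
if __name__ == '__main__':
    import asyncio
    asyncio.run(image_style_transfer(gamma=1.0, user_id='42'))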