-
Notifications
You must be signed in to change notification settings - Fork 1
/
lapsrn.py
349 lines (283 loc) · 12.7 KB
/
lapsrn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
# NOTE: the star imports below are order-sensitive (later ones shadow earlier
# names) and are intentionally kept in their original order.  ``ToTensor``
# comes from torchvision.transforms; PIL's ``Image`` (used by
# LapSRN.test_single) presumably also arrives via a star import - confirm.
from base_networks import *
from torch.utils.data import DataLoader
from torchvision.transforms import *
from data import get_training_set, get_test_set
import utils
from logger import Logger
from torchvision.transforms import *
def get_upsample_filter(size):
    """Return a ``size x size`` float32 bilinear-interpolation kernel.

    The kernel is the outer product of a 1-D triangular ("tent") profile
    centred on the middle of the window, which is the usual initialisation
    for transposed convolutions that perform bilinear upsampling.

    Args:
        size: side length of the square kernel.

    Returns:
        A ``torch.FloatTensor`` of shape ``(size, size)``.
    """
    half_width = (size + 1) // 2
    # odd sizes have an exact centre tap; even sizes centre between two taps
    midpoint = half_width - 1 if size % 2 == 1 else half_width - 0.5
    rows, cols = np.ogrid[:size, :size]
    row_profile = 1 - np.abs(rows - midpoint) / half_width
    col_profile = 1 - np.abs(cols - midpoint) / half_width
    kernel = row_profile * col_profile
    return torch.from_numpy(kernel).float()
class Net(torch.nn.Module):
    """Two-level Laplacian-pyramid super-resolution network.

    ``forward`` returns ``(x_coarse_, x_finer_)``.  At each level the image
    branch (``convt_I*``, a ``DeconvBlock``) upsamples the previous image and
    the residual branch (``convt_R*``) turns the level's features
    (``convt_F*``) into a residual that is added to it.  The ``DeconvBlock``
    arguments ``4, 2, 1`` are presumably kernel/stride/padding, i.e. a 2x
    upsampling per level - confirm against ``base_networks``.

    Args:
        num_channels: image channels of the input and of both outputs.
        base_filter: number of feature maps inside the feature branches.
        num_convs: number of 3x3 ``ConvBlock``s per feature-extraction level.
    """

    def __init__(self, num_channels, base_filter, num_convs):
        super(Net, self).__init__()
        self.input_conv = ConvBlock(num_channels, base_filter, 3, 1, 1, activation='lrelu', norm=None, bias=False)

        # level 1
        self.convt_I1 = DeconvBlock(num_channels, num_channels, 4, 2, 1, activation=None, norm=None, bias=False)
        self.convt_R1 = ConvBlock(base_filter, num_channels, 3, 1, 1, activation=None, norm=None, bias=False)
        # Each level gets its OWN feature branch.  Previously both levels were
        # built from the same list of module instances, which silently tied
        # their weights together.
        self.convt_F1 = self._make_feature_branch(base_filter, num_convs)

        # level 2
        self.convt_I2 = DeconvBlock(num_channels, num_channels, 4, 2, 1, activation=None, norm=None, bias=False)
        self.convt_R2 = ConvBlock(base_filter, num_channels, 3, 1, 1, activation=None, norm=None, bias=False)
        self.convt_F2 = self._make_feature_branch(base_filter, num_convs)

    @staticmethod
    def _make_feature_branch(base_filter, num_convs):
        """Build one level's feature branch: ``num_convs`` conv blocks followed by one deconv block."""
        layers = [ConvBlock(base_filter, base_filter, 3, 1, 1, activation='lrelu', norm=None, bias=False)
                  for _ in range(num_convs)]
        layers.append(DeconvBlock(base_filter, base_filter, 4, 2, 1, activation='lrelu', norm=None, bias=False))
        return nn.Sequential(*layers)

    def weight_init(self):
        """Initialise all conv and transposed-conv layers in place.

        * ``nn.Conv2d``: weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels))),
          biases (if any) zeroed.
        * ``nn.ConvTranspose2d``: every (in, out) channel pair is set to the
          bilinear kernel from ``get_upsample_filter``, biases (if any) zeroed.
          Assumes square kernels (only the height is used to build the filter).
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.ConvTranspose2d):
                c1, c2, h, w = m.weight.data.size()
                weight = get_upsample_filter(h)
                # copy_ keeps the parameter's device/dtype, so this also works
                # after the model has been moved to the GPU
                m.weight.data.copy_(weight.view(1, 1, h, w).repeat(c1, c2, 1, 1))
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        """Return ``(x_coarse_, x_finer_)``, the level-1 and level-2 reconstructions of ``x``."""
        out = self.input_conv(x)

        # level 1: upsampled input image + predicted residual
        convt_F1 = self.convt_F1(out)
        convt_I1 = self.convt_I1(x)
        convt_R1 = self.convt_R1(convt_F1)
        x_coarse_ = convt_I1 + convt_R1

        # level 2: continues from the level-1 features and the level-1 image
        convt_F2 = self.convt_F2(convt_F1)
        convt_I2 = self.convt_I2(x_coarse_)
        convt_R2 = self.convt_R2(convt_F2)
        x_finer_ = convt_I2 + convt_R2

        return x_coarse_, x_finer_
class L1_Charbonnier_loss(torch.nn.Module):
    """L1 Charbonnier loss: ``mean(sqrt((x - y)**2 + eps))``.

    A smooth approximation of the L1 loss: ``sqrt(d**2 + eps)`` is close to
    ``|d|`` for ``|d|`` large relative to ``sqrt(eps)`` and stays
    differentiable at ``d == 0``.

    Args:
        eps: constant added under the square root (default ``1e-6``).
    """

    def __init__(self, eps=1e-6):
        super(L1_Charbonnier_loss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        """Return the mean Charbonnier penalty between ``x`` and ``y``."""
        diff = x - y
        error = torch.sqrt(diff * diff + self.eps)
        return torch.mean(error)
class LapSRN(object):
    """Training and evaluation driver for the LapSRN ``Net``.

    All hyper-parameters are copied from ``args`` (e.g. an argparse
    namespace) in ``__init__``.  ``train`` fits a freshly initialised model,
    ``test`` evaluates the test split, and ``test_single`` super-resolves
    a single image file.
    """

    def __init__(self, args):
        # parameters
        self.model_name = args.model_name
        self.train_dataset = args.train_dataset
        self.test_dataset = args.test_dataset
        self.crop_size = args.crop_size
        self.num_threads = args.num_threads
        self.num_channels = args.num_channels
        self.scale_factor = args.scale_factor
        self.num_epochs = args.num_epochs
        self.save_epochs = args.save_epochs
        self.batch_size = args.batch_size
        self.test_batch_size = args.test_batch_size
        self.lr = args.lr
        self.data_dir = args.data_dir
        self.save_dir = args.save_dir
        self.gpu_mode = args.gpu_mode

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (and any missing parents) if needed; return ``path``."""
        if not os.path.isdir(path):
            os.makedirs(path)
        return path

    @staticmethod
    def _scalar(value):
        """Return a single-element loss tensor/Variable as a Python float."""
        # ``loss.data[0]`` fails on 0-dim tensors (PyTorch >= 0.4); flattening
        # first works for both 1-element and 0-dim results.
        return float(value.data.view(-1)[0])

    def _to_device(self, tensor):
        """Return ``tensor`` moved to the GPU when ``gpu_mode`` is set, else unchanged."""
        return tensor.cuda() if self.gpu_mode else tensor

    def _build_model(self, init_weights=False):
        """Create ``self.model`` (64 base filters, 10 convs per level) and return it.

        When ``init_weights`` is true, ``Net.weight_init`` runs before the
        optional GPU move, so any CPU tensors it creates end up on the GPU.
        """
        self.model = Net(num_channels=self.num_channels, base_filter=64, num_convs=10)
        if init_weights:
            self.model.weight_init()
        if self.gpu_mode:
            self.model.cuda()
        return self.model

    # ------------------------------------------------------------------- data

    def load_dataset(self, dataset='train'):
        """Return a ``DataLoader`` over the ``'train'`` or ``'test'`` split.

        The train loader shuffles and uses ``batch_size``; the test loader
        keeps order and uses ``test_batch_size``.  Images are loaded as
        grayscale when ``num_channels == 1``.

        Raises:
            ValueError: if ``dataset`` is neither ``'train'`` nor ``'test'``.
        """
        is_gray = self.num_channels == 1

        if dataset == 'train':
            print('Loading train datasets...')
            train_set = get_training_set(self.data_dir, self.train_dataset, self.crop_size, self.scale_factor,
                                         is_gray=is_gray, normalize=False)
            return DataLoader(dataset=train_set, num_workers=self.num_threads, batch_size=self.batch_size,
                              shuffle=True)

        if dataset == 'test':
            print('Loading test datasets...')
            test_set = get_test_set(self.data_dir, self.test_dataset, self.scale_factor, is_gray=is_gray,
                                    normalize=False)
            return DataLoader(dataset=test_set, num_workers=self.num_threads, batch_size=self.test_batch_size,
                              shuffle=False)

        raise ValueError("dataset must be 'train' or 'test', got %r" % (dataset,))

    # ------------------------------------------------------------------ train

    def train(self):
        """Train a freshly initialised model.

        Each epoch logs the per-batch loss (stdout + TensorBoard logger),
        saves a preview of one fixed test image, and every ``save_epochs``
        epochs writes a checkpoint.  The final model is saved at the end.
        """
        # networks (weights initialised before any device move)
        self._build_model(init_weights=True)

        # optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        # loss function
        self.loss = L1_Charbonnier_loss()
        if self.gpu_mode:
            self.loss = self.loss.cuda()

        print('---------- Networks architecture -------------')
        utils.print_network(self.model)
        print('----------------------------------------------')

        # load dataset
        train_data_loader = self.load_dataset(dataset='train')
        test_data_loader = self.load_dataset(dataset='test')

        # set the logger (makedirs also creates a missing save_dir)
        logger = Logger(self._ensure_dir(os.path.join(self.save_dir, 'logs')))

        ################# Train #################
        print('Training is started.')
        avg_loss = []
        step = 0

        # fixed test image used for the per-epoch preview (index 2 when it
        # exists, otherwise the last image of a tiny test set)
        test_set = test_data_loader.dataset
        test_input, test_target = test_set[min(2, len(test_set) - 1)]
        test_input = test_input.unsqueeze(0)
        test_target = test_target.unsqueeze(0)

        # the coarse target is the HR target resized by 2 / scale_factor
        # (2.0 keeps this a float division on Python 2 as well)
        coarse_ratio = 2.0 / self.scale_factor

        self.model.train()
        for epoch in range(self.num_epochs):
            # learning rate is divided by 10 at the start of every 100th epoch
            if (epoch + 1) % 100 == 0:
                for param_group in self.optimizer.param_groups:
                    param_group["lr"] /= 10.0
                print("Learning rate decay: lr={}".format(self.optimizer.param_groups[0]["lr"]))

            epoch_loss = 0
            for batch_idx, (lr_batch, hr_batch) in enumerate(train_data_loader):
                # targets for both pyramid levels, and the low-resolution input
                x_finer_ = Variable(self._to_device(hr_batch))
                x_coarse_ = Variable(self._to_device(utils.img_interp(hr_batch, coarse_ratio)))
                y_ = Variable(self._to_device(lr_batch))

                # update network: one backward pass over the summed loss yields
                # the same gradients as back-propagating each level separately
                self.optimizer.zero_grad()
                recon_coarse_, recon_finer_ = self.model(y_)
                loss_coarse = self.loss(recon_coarse_, x_coarse_)
                loss_finer = self.loss(recon_finer_, x_finer_)
                loss = loss_coarse + loss_finer
                loss.backward()
                self.optimizer.step()

                # log
                loss_value = self._scalar(loss)
                epoch_loss += loss_value
                print("Epoch: [%2d] [%4d/%4d] loss: %.8f" % ((epoch + 1), (batch_idx + 1), len(train_data_loader),
                                                             loss_value))

                # tensorboard logging
                step += 1
                logger.scalar_summary('loss', loss_value, step)

            # avg. loss per epoch
            avg_loss.append(epoch_loss / len(train_data_loader))

            # prediction on the fixed preview image
            self._save_preview(test_input, test_target, epoch + 1)
            print("Saving training result images at epoch %d" % (epoch + 1))

            # Save trained parameters of model
            if (epoch + 1) % self.save_epochs == 0:
                self.save_model(epoch + 1)

        # Plot avg. loss
        utils.plot_loss([avg_loss], self.num_epochs, save_dir=self.save_dir)
        print("Training is finished.")

        # Save final trained parameters of model
        self.save_model(epoch=None)

    def _save_preview(self, test_input, test_target, epoch_num):
        """Super-resolve one test image and plot it next to GT, LR and ``utils.img_interp``."""
        # inference only: switch to eval mode and restore training mode after
        self.model.eval()
        _, recon_imgs = self.model(Variable(self._to_device(test_input)))
        self.model.train()

        recon_img = recon_imgs[0].cpu().data
        gt_img = test_target[0]
        lr_img = test_input[0]
        bc_img = utils.img_interp(test_input[0], self.scale_factor)

        # calculate psnrs
        bc_psnr = utils.PSNR(bc_img, gt_img)
        recon_psnr = utils.PSNR(recon_img, gt_img)

        # save result images
        result_imgs = [gt_img, lr_img, bc_img, recon_img]
        psnrs = [None, None, bc_psnr, recon_psnr]
        utils.plot_test_result(result_imgs, psnrs, epoch_num, save_dir=self.save_dir, is_training=True)

    # ------------------------------------------------------------------- test

    def test(self):
        """Evaluate the saved model on the test split and plot every result."""
        # networks
        self._build_model()

        # load model
        self.load_model()

        # load dataset
        test_data_loader = self.load_dataset(dataset='test')

        # Test
        print('Test is started.')
        img_num = 0
        self.model.eval()
        for lr_batch, hr_batch in test_data_loader:
            # prediction: only the finer (final-scale) output is evaluated
            _, recon_imgs = self.model(Variable(self._to_device(lr_batch)))

            for i, recon in enumerate(recon_imgs):
                img_num += 1
                recon_img = recon.cpu().data
                gt_img = hr_batch[i]
                lr_img = lr_batch[i]
                bc_img = utils.img_interp(lr_batch[i], self.scale_factor)

                # calculate psnrs
                bc_psnr = utils.PSNR(bc_img, gt_img)
                recon_psnr = utils.PSNR(recon_img, gt_img)

                # save result images
                result_imgs = [gt_img, lr_img, bc_img, recon_img]
                psnrs = [None, None, bc_psnr, recon_psnr]
                utils.plot_test_result(result_imgs, psnrs, img_num, save_dir=self.save_dir)

        print("Saving %d test result images..." % img_num)

    def test_single(self, img_fn):
        """Super-resolve one image file and save ``<save_dir>/result/SR_result.png``.

        Only the Y channel of the image's YCbCr conversion is fed to the
        network (so this presumably requires ``num_channels == 1`` - confirm);
        Cb/Cr are resized to the output size with bicubic interpolation.
        """
        # networks
        self._build_model()

        # load model
        self.load_model()

        # load data: PIL sizes are (width, height); build a (1, C, H, W) batch
        img = Image.open(img_fn).convert('YCbCr')
        y, cb, cr = img.split()
        input_y = Variable(ToTensor()(y)).view(1, -1, y.size[1], y.size[0])
        input_y = self._to_device(input_y)

        self.model.eval()
        # the network returns (coarse, finer); only the finer output is used
        _, recon_img = self.model(input_y)

        # save result images
        utils.save_img(recon_img.cpu().data, 1, save_dir=self.save_dir)

        # ToTensor maps 8-bit pixels to [0, 1]; map the prediction back by
        # clipping and scaling.  (Min-max stretching, used previously, shifted
        # the luminance relative to Cb/Cr and divided by zero on flat images.)
        out_img_y = recon_img.cpu().data[0]
        out_img_y = (out_img_y.clamp(0, 1) * 255.0).numpy()
        out_img_y = Image.fromarray(np.uint8(out_img_y[0].round()), mode='L')

        out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
        out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
        out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')

        # save img
        result_dir = self._ensure_dir(os.path.join(self.save_dir, 'result'))
        out_img.save(os.path.join(result_dir, 'SR_result.png'))

    # ------------------------------------------------------------ persistence

    def save_model(self, epoch=None):
        """Save ``self.model``'s state_dict under ``<save_dir>/model``.

        With ``epoch`` the file is ``<model_name>_param_epoch_<epoch>.pkl``;
        without it the file is ``<model_name>_param.pkl``, which is the file
        ``load_model`` reads.
        """
        model_dir = self._ensure_dir(os.path.join(self.save_dir, 'model'))
        if epoch is not None:
            file_name = self.model_name + '_param_epoch_%d.pkl' % epoch
        else:
            file_name = self.model_name + '_param.pkl'
        torch.save(self.model.state_dict(), os.path.join(model_dir, file_name))
        print('Trained model is saved.')

    def load_model(self):
        """Load ``<save_dir>/model/<model_name>_param.pkl`` into ``self.model``.

        Returns:
            True if the file existed and was loaded, False otherwise.
        """
        model_path = os.path.join(self.save_dir, 'model', self.model_name + '_param.pkl')
        if not os.path.exists(model_path):
            print('No model exists to load.')
            return False
        # Deserialize onto the CPU so GPU-trained checkpoints also load in CPU
        # mode; load_state_dict then copies values onto the model's device.
        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
        self.model.load_state_dict(state_dict)
        print('Trained model is loaded.')
        return True