-
Notifications
You must be signed in to change notification settings - Fork 5
/
denoise_eval.py
executable file
·317 lines (239 loc) · 10.6 KB
/
denoise_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import Denoisenet
import tensorflow as tf
from PIL import Image
from tqdm import tqdm
import time
import os
import argparse
# (height, width, channels) of every patch fed to the network: a square
# 128 x 128 window with 3 channels.
PATCH_SHAPE = (128,) * 2 + (3,)
def _log10(x):
return np.log(x) / np.log(10)
def psnr(test, gt, data_range=None):
    """Return the peak signal-to-noise ratio of ``test`` against ``gt`` in dB.

    PSNR = 20 * log10(peak) - 10 * log10(MSE)

    Inputs:
        test: image under evaluation, array-like with the same shape as ``gt``.
        gt: ground-truth image, array-like.
        data_range: peak value used in the formula. Defaults to ``np.max(gt)``
            (the historical behaviour of this function); pass a fixed value
            such as 1.0 for [0, 1] images or 255 for 8-bit images to use the
            conventional definition.
    Returns:
        PSNR in decibels; ``inf`` when the two images are identical.
    """
    # Compute in float64: for integer inputs (e.g. uint8) ``test - gt`` would
    # wrap around instead of going negative and corrupt the MSE.
    test = np.asarray(test, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    mse = np.mean((test - gt) ** 2)
    if mse == 0:
        # Identical images: PSNR is unbounded; avoid log10(0) warnings.
        return float("inf")
    peak = np.max(gt) if data_range is None else data_range
    return 20 * np.log10(peak) - 10 * np.log10(mse)
def image_to_patches(img_pad, stride_x, stride_y, mini_batch_size=200, patch_shape=None):
    """Split a padded image into overlapping patches grouped into mini-batches.

    Patch origins form a row-major grid starting at the top-left corner:
    left to right in steps of ``stride_x``, then top to bottom in steps of
    ``stride_y``.

    Input:
        img_pad: numpy array of shape (Hp, Wp, C), already padded so that
            ``Hp - patch_h`` and ``Wp - patch_w`` are non-negative multiples
            of ``stride_y`` and ``stride_x`` respectively.
        stride_x: step between patch origins along the width.
        stride_y: step between patch origins along the height.
        mini_batch_size: number of patches per batch.
        patch_shape: (patch_h, patch_w, C); defaults to ``PATCH_SHAPE``.
    Output:
        batches: float64 array of shape
            (num_batches, mini_batch_size, patch_h, patch_w, C). The network
            takes fixed-size batches, so when the patch count is not a
            multiple of ``mini_batch_size`` the last batch is filled up with
            copies of the final patch.
        res: number of real patches in the last batch (0 when every batch is
            full).
    Raises:
        ValueError: if the image's channel count does not match
            ``patch_shape`` or its size cannot be tiled exactly.
    """
    if patch_shape is None:
        patch_shape = PATCH_SHAPE
    patch_shape = tuple(patch_shape)
    patch_h, patch_w, channels = patch_shape

    if img_pad.ndim != 3 or img_pad.shape[2] != channels:
        raise ValueError("Expected an image of shape (H, W, %d), got %s"
                         % (channels, str(img_pad.shape)))
    span_y = img_pad.shape[0] - patch_h
    span_x = img_pad.shape[1] - patch_w
    if span_x < 0 or span_y < 0 or span_x % stride_x or span_y % stride_y:
        raise ValueError(
            "Image of shape %s cannot be tiled exactly by %s patches with "
            "strides (x=%d, y=%d); pad it first."
            % (str(img_pad.shape), str(patch_shape), stride_x, stride_y))

    origins = [(top, left)
               for top in range(0, span_y + 1, stride_y)
               for left in range(0, span_x + 1, stride_x)]
    total_patches = len(origins)
    patches = np.zeros((total_patches,) + patch_shape)
    for n, (top, left) in enumerate(origins):
        # Rows index the height, columns the width.
        patches[n] = img_pad[top:top + patch_h, left:left + patch_w, :]

    res = total_patches % mini_batch_size
    if res:
        # Fill the last batch with copies of the final patch; the stitching
        # step stops after the last real patch, so these are never used.
        filler = np.repeat(patches[-1:], mini_batch_size - res, axis=0)
        patches = np.concatenate((patches, filler), axis=0)
    return patches.reshape((-1, mini_batch_size) + patch_shape), res
def patches_to_image(img, batches_of_patches, stride_x, stride_y, patch_shape=None):
    """Stitch batches of (denoised) patches back into a full image.

    Patch origins follow the same row-major grid as ``image_to_patches``.
    Only the centre of each patch is written: ``(patch_w - stride_x) // 2``
    columns and ``(patch_h - stride_y) // 2`` rows are cropped from every
    side (when ``patch - stride`` is even the centres tile exactly; otherwise
    a later patch overwrites the one-pixel overlap). The cropped outer border
    of the whole output is therefore left at zero. Patches beyond the number
    of grid positions (batch filler) are ignored.

    Input:
        img: array whose shape (H, W, C) defines the output shape; its values
            are not read.
        batches_of_patches: iterable of batches, each an array of shape
            (n, patch_h, patch_w, C) -- e.g. a list of arrays or one 5-D array.
        stride_x: the stride of the patches along the width.
        stride_y: the stride of the patches along the height.
        patch_shape: (patch_h, patch_w, C); defaults to ``PATCH_SHAPE``.
    Output:
        float64 image of shape ``img.shape`` reconstructed from the patches.
    Raises:
        ValueError: if ``img`` cannot be tiled exactly by the patch grid.
    """
    if patch_shape is None:
        patch_shape = PATCH_SHAPE
    patch_h, patch_w = patch_shape[0], patch_shape[1]
    span_y = img.shape[0] - patch_h
    span_x = img.shape[1] - patch_w
    if span_x < 0 or span_y < 0 or span_x % stride_x or span_y % stride_y:
        raise ValueError(
            "Image of shape %s cannot be tiled exactly by %s patches with "
            "strides (x=%d, y=%d)."
            % (str(img.shape), str(tuple(patch_shape)), stride_x, stride_y))
    crop_x = (patch_w - stride_x) // 2
    crop_y = (patch_h - stride_y) // 2

    output = np.zeros(img.shape)
    origins = ((top, left)
               for top in range(0, span_y + 1, stride_y)
               for left in range(0, span_x + 1, stride_x))
    patches = (patch for batch in batches_of_patches for patch in batch)
    # zip stops at the last grid position, dropping any filler patches.
    for (top, left), patch in zip(origins, patches):
        output[top + crop_y:top + patch_h - crop_y,
               left + crop_x:left + patch_w - crop_x, :] = \
            patch[crop_y:patch_h - crop_y, crop_x:patch_w - crop_x, :]
    return output
def eval_patch(X, y, sess=None):
    """Run the restored network on one mini-batch of patches.

    Inputs:
        X: noisy image patches, numpy array of shape (N, *PATCH_SHAPE).
        y: ground-truth image patches, same shape as X.
        sess: tf.Session whose graph already holds the restored model; passed
            in by eval_image() when denoising a whole image.
    Return:
        denoised: network output for X (value of tensor "denoised:0").
        loss: value of ``Denoisenet.loss(denoised, y)`` for this batch
            (presumably a mean squared error -- confirm in Denoisenet).
    Raises:
        ValueError: if ``sess`` is None or X and y have different shapes.
    """
    if sess is None:
        raise ValueError("Session is NoneType when evaluating single patch. Check eval_image().")
    if X.shape != y.shape:
        raise ValueError("Shape mismatch when evaluating single patch, shape X %s, shape Y %s" % (str(X.shape), str(y.shape)))

    graph = sess.graph
    # Tensor names come from the exported training graph; X and y are fed
    # straight into them. NOTE(review): confirm what "input/mul_1" and
    # "input/mul" compute in the training script.
    X_in = graph.get_tensor_by_name("input/mul_1:0")
    y_in = graph.get_tensor_by_name("input/mul:0")
    denoised = graph.get_tensor_by_name("denoised:0")

    # Denoisenet.loss is called with symbolic tensors, so it presumably adds
    # new ops every time it runs. Build it once per graph and cache it in a
    # graph collection so evaluating many batches does not keep growing the
    # graph.
    cached = graph.get_collection("denoise_eval/loss")
    if cached:
        loss = cached[0]
    else:
        with graph.as_default():
            loss = Denoisenet.loss(denoised, y_in)
        graph.add_to_collection("denoise_eval/loss", loss)

    denoised_value, loss_value = sess.run([denoised, loss], feed_dict={X_in: X, y_in: y})
    return denoised_value, loss_value
def _pad_amounts(length, patch_len, stride, crop):
    """Return ``(before, after)`` zero padding for one image axis.

    Only the central ``patch_len - 2 * crop`` pixels of every patch are kept
    when stitching, so the first and last ``crop`` pixels of the padded axis
    are never reconstructed. Reserve ``crop`` pixels of padding on each side
    for that, then add just enough extra padding that
    ``padded_length - patch_len`` is a whole number of strides (the larger
    half of the extra goes before, matching the previous split).
    """
    needed = max(length + 2 * crop, patch_len)
    steps = -(-(needed - patch_len) // stride)  # ceiling division
    extra = patch_len + steps * stride - length - 2 * crop
    after = crop + extra // 2
    before = crop + extra - extra // 2
    return before, after


def eval_image(X, y, model=None, checkpoint=None, mini_batch_size=16, crop_in=5):
    """Evaluate a full image using a trained model.

    The image is zero-padded, cut into overlapping PATCH_SHAPE patches with
    stride ``PATCH_SHAPE - 2 * crop_in``, run through the network batch by
    batch, and stitched back keeping only each patch's centre (the outer
    ``crop_in`` pixels of every output patch are discarded to escape
    convolution artifacts at patch borders).

    Inputs:
        X: noisy image input, numpy array of shape (H, W, 3).
        y: ground-truth image (long-exposure image), same shape as X; only
            used as the network's target when computing the loss.
        model: path of the meta graph passed to tf.train.import_meta_graph.
        checkpoint: directory searched by tf.train.latest_checkpoint for the
            weights to restore.
        mini_batch_size: number of patches per network call.
        crop_in: pixels cropped from each side of every output patch.
    Console outputs:
        Loss averaged over the number of real patches in the image.
    Returns:
        output: denoised image of shape (H, W, 3), clipped to [0, 1].
        total_loss: sum of the per-batch losses.
        pads: [pad_top, pad_bottom, pad_left, pad_right] zero padding that was
            applied before patching.
    Raises:
        ValueError: on missing model/checkpoint, mismatched shapes, a channel
            count other than PATCH_SHAPE[2], or a crop_in that leaves no stride.
    """
    if model is None:
        raise ValueError("Eval Image: Trained model location is not specified.")
    if checkpoint is None:
        raise ValueError("Eval Image: Tensorflow Checkpoint is not specified.")
    if X.shape != y.shape:
        raise ValueError("Eval Image: X(%s) and y(%s) shape mismatch." % (X.shape, y.shape))
    if X.ndim != 3 or X.shape[2] == 1:
        raise ValueError("Monochromatic image not supported.")
    if X.shape[2] != PATCH_SHAPE[2]:
        raise ValueError("Eval Image: expected %d channels, got %d." % (PATCH_SHAPE[2], X.shape[2]))

    # Stride is patch length minus the border cropped on both sides, so the
    # kept patch centres tile the image.
    stride_x = PATCH_SHAPE[1] - 2 * crop_in
    stride_y = PATCH_SHAPE[0] - 2 * crop_in
    if crop_in < 0 or stride_x <= 0 or stride_y <= 0:
        raise ValueError("Eval Image: crop_in=%d leaves no usable patch centre." % crop_in)

    # Pad so that every real pixel lies inside some patch centre and the
    # padded size is tiled exactly by the patch grid.
    H, W = X.shape[0], X.shape[1]
    pad_top, pad_bottom = _pad_amounts(H, PATCH_SHAPE[0], stride_y, crop_in)
    pad_left, pad_right = _pad_amounts(W, PATCH_SHAPE[1], stride_x, crop_in)
    pad_width = ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0))
    X_pad = np.pad(X, pad_width, 'constant', constant_values=0)
    y_pad = np.pad(y, pad_width, 'constant', constant_values=0)

    batches_X, res = image_to_patches(X_pad, stride_x, stride_y, mini_batch_size=mini_batch_size)
    batches_y, _ = image_to_patches(y_pad, stride_x, stride_y, mini_batch_size=mini_batch_size)
    num_batches = batches_X.shape[0]
    # Every batch is full except possibly the last, which holds `res` real
    # patches; res == 0 means the last batch is full as well.
    num_patches = mini_batch_size * (num_batches - 1) + (res or mini_batch_size)

    tf.reset_default_graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    total_loss = 0.
    output_patches = []
    with tf.Session(config=config) as sess:
        # Load previous model
        saver = tf.train.import_meta_graph(model)
        saver.restore(sess, tf.train.latest_checkpoint(checkpoint))
        for i in tqdm(range(num_batches), desc="Denoise with network"):
            denoised_batch, loss = eval_patch(batches_X[i], batches_y[i], sess)
            output_patches.append(denoised_batch)
            # NOTE(review): filler patches in the last batch are fed to the
            # network too; depending on what Denoisenet.loss returns, their
            # loss may be included here -- confirm whether it should be.
            total_loss += np.sum(loss)
    print("Average MSE across image: %.6E" % (total_loss / num_patches))

    output = patches_to_image(y_pad, output_patches, stride_x, stride_y)
    # Keep only the region corresponding to the unpadded input.
    output = output[pad_top:pad_top + H, pad_left:pad_left + W, 0:3]
    # Pixel clipping to the valid [0, 1] range.
    output = np.clip(output, 0, 1)
    return output, total_loss, [pad_top, pad_bottom, pad_left, pad_right]
def main(args):
    """Command-line driver: denoise one image, report PSNR and save the result.

    Reads the noisy image ``args.EvalX`` and the ground truth ``args.EvalY``.
    Both are converted to YCbCr and scaled to [0, 1], then eval_image() runs
    with ``args.Model`` / ``args.Checkpoint``. The PSNR against the ground
    truth (computed in YCbCr space) and the elapsed time are printed. The
    denoised image is written into the directory ``args.Output`` (default
    ``./Output/``) under the input's file name. Single-band, non-palette
    inputs are saved as 8-bit grayscale from the Y channel; everything else
    is converted back to RGB.
    """
    start = time.time()
    X_path = args.EvalX
    y_path = args.EvalY
    output_dir = args.Output if args.Output else './Output/'

    Ximg = Image.open(X_path)
    Yimg = Image.open(y_path)
    # Decide grayscale from the band layout: palette ("P") images are also
    # single-band 2-D arrays, but they are colour and must be saved as RGB.
    is_gray = len(Ximg.getbands()) == 1 and Ximg.mode != "P"

    # YCbCr conversion always yields 3 bands, which also discards any alpha.
    X = np.asarray(Ximg.convert('YCbCr')) / 255.
    y = np.asarray(Yimg.convert('YCbCr')) / 255.

    output, _, _ = eval_image(X, y, args.Model, args.Checkpoint)
    print("PSNR: ", psnr(output, y))

    # Round rather than truncate: k / 255 * 255 can land a hair below k.
    output = np.round(output * 255.).astype(np.uint8)

    name = os.path.basename(X_path)
    os.makedirs(output_dir, exist_ok=True)
    out_file = os.path.join(output_dir, name)
    if is_gray:
        Image.fromarray(np.ascontiguousarray(output[:, :, 0]), mode="L").save(out_file)
    else:
        Image.fromarray(output, mode="YCbCr").convert("RGB").save(out_file)

    print("time used(s)", time.time() - start)
if __name__ == '__main__':
    # Command-line entry point: four required positional paths plus an
    # optional --Output, all forwarded to main() as an argparse Namespace.
    parser = argparse.ArgumentParser(
        description="Evaluate a single noisy image and compute total loss.")
    positional_args = (
        ('EvalX', "Path of the noisy image to evaluate."),
        ('EvalY', "Path of the ground truth image"),
        ('Model', "Path of the trained Tensorflow Model, this builds tensorflow graph."),
        ('Checkpoint', "Path of Tensorflow checkpoint, this restores parameters."),
    )
    for arg_name, help_text in positional_args:
        parser.add_argument(arg_name, type=str, help=help_text)
    parser.add_argument('--Output', type=str, help="Path of the output image")
    args = parser.parse_args()
    main(args)