forked from xahidbuffon/FUnIE-GAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_funieGAN.py
80 lines (72 loc) · 2.37 KB
/
test_funieGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
# > Script for testing FUnIE-GAN
# > Notes and Usage:
# - set data_dir and model paths
# - python test_funieGAN.py
"""
import os
import time
import ntpath
import numpy as np
from PIL import Image
from os.path import join, exists
from keras.models import model_from_json
## local libs
from utils.data_utils import getPaths, read_and_resize, preprocess, deprocess
## for testing arbitrary local data
# Directory holding the input images to enhance (relative to the CWD).
data_dir = "../data/test/A/"
# NOTE(review): get_local_test_data is not used anywhere in this script;
# the import is kept only so nothing relying on it breaks.
from utils.data_utils import get_local_test_data
# presumably getPaths lists the image files under data_dir — verify in utils
test_paths = getPaths(data_dir)
print ("{0} test images are loaded".format(len(test_paths)))
## create dir for log and (sampled) validation data
samples_dir = "../data/output/"
# exist_ok=True replaces the racy "check exists, then create" pattern.
os.makedirs(samples_dir, exist_ok=True)
## test funie-gan
checkpoint_dir = 'models/gen_p/'
model_name_by_epoch = "model_15320_"
## test funie-gan-up
#checkpoint_dir = 'models/gen_up/'
#model_name_by_epoch = "model_35442_"
# The generator is stored as an architecture JSON file plus a `.h5` weights
# file that share the same base name inside checkpoint_dir.
model_h5 = join(checkpoint_dir, model_name_by_epoch + ".h5")
model_json = join(checkpoint_dir, model_name_by_epoch + ".json")
# sanity: fail loudly with the offending path (an `assert` would be stripped
# when Python runs with -O and would not say which file is missing)
for required_path in (model_h5, model_json):
    if not exists(required_path):
        raise FileNotFoundError(
            "Model file not found: {0}".format(required_path))
# load model architecture from JSON
with open(model_json, "r") as json_file:
    loaded_model_json = json_file.read()
funie_gan_generator = model_from_json(loaded_model_json)
# load weights into new model
funie_gan_generator.load_weights(model_h5)
print("\nLoaded data and model")
# testing loop
# Per-image processing times in seconds (predict + deprocess only; file I/O
# is excluded).
times = []
count = 0
maxCount = 10  # upper bound on the number of images processed in this run
for img_path in test_paths:
    # Check before doing any work so exactly maxCount images are processed
    # (and maxCount = 0 processes none).
    if count >= maxCount:
        break
    # prepare data
    inp_img = read_and_resize(img_path, (256, 256))
    im = preprocess(inp_img)
    im = np.expand_dims(im, axis=0)  # (1,256,256,3)
    # generate enhanced image; perf_counter is the clock meant for intervals
    s = time.perf_counter()
    gen = funie_gan_generator.predict(im)
    gen_img = deprocess(gen)[0]
    tot = time.perf_counter() - s
    times.append(tot)
    # save input and enhanced output side by side, under the input's name
    img_name = ntpath.basename(img_path)
    out_img = np.hstack((inp_img, gen_img)).astype('uint8')
    Image.fromarray(out_img).save(join(samples_dir, img_name))
    count += 1
# some statistics
# Report the images actually processed: the loop may stop early, so
# len(test_paths) would overstate the count.
num_test = len(times)
if num_test == 0:
    print("\nFound no images for test")
else:
    print("\nTotal images: {0}".format(num_test))
    # accumulate frame processing times, dropping the first (bootstrap /
    # warm-up) call when there is more than one sample; with a single sample
    # times[1:] would be empty and the mean would be NaN.
    timed = times[1:] if num_test > 1 else times
    Ttime, Mtime = np.sum(timed), np.mean(timed)
    fps = 1.0 / Mtime if Mtime > 0 else float("inf")
    print("Time taken: {0} sec at {1} fps".format(Ttime, fps))
    print("\nSaved generated images in {0}\n".format(samples_dir))