-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval.py
159 lines (123 loc) · 5.93 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
# Set before `import tensorflow` below so TensorFlow's C++ logging picks it up
# when the library loads; '2' hides INFO and WARNING messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import warnings
# Suppress all Python warnings for the whole process.
warnings.filterwarnings("ignore")
import tensorflow as tf
# GPU: restrict TensorFlow to the first visible GPU and enable memory growth so
# TF allocates device memory on demand instead of reserving it all up front.
# Guarded so the script can still be imported on a machine without a GPU
# (indexing gpus[0] on an empty list would raise IndexError at import time).
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    tf.config.experimental.set_memory_growth(gpus[0], True)
import argparse
import cv2
import glob
import time
import numpy as np
from collections import OrderedDict
import torch
from utils.util_calculate_psnr_ssim import to_y_channel, calculate_psnr, calculate_ssim
from utils.tools import read_yaml, Logger
from utils.train import Trainer
from utils.niqe import niqe as calculate_niqe
from lpips import LPIPS, im2tensor
def main():
    """Evaluate a super-resolution model (or precomputed outputs) on one test set.

    For every file in ``--dataset_hr`` the matching low-resolution image is
    loaded from ``--dataset_lr`` (see ``get_image_pair``). With ``--predict``
    (the default) the model is run on it; with ``--no-predict`` the image from
    the LR folder is scored directly (presumably an already-upscaled result).
    PSNR, SSIM and NIQE are computed on the Y channel, LPIPS on RGB. Per-image
    scores are printed; dataset averages are printed and passed to
    ``Logger.save_log``.

    Raises:
        FileNotFoundError: if ``--dataset_hr`` contains no files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--scale', type=int, default=4, help='scale factor: 4, 8')  # 1 for dn and jpeg car
    parser.add_argument('--model', type=str, default='tiny')
    # NOTE(review): --train_type is parsed but not used in this function.
    parser.add_argument('--train_type', type=str, default='pre')
    parser.add_argument('--dataset_lr', type=str, default='/media/Datasets/super-resolution/Set5/LR_bicubic/',
                        help='input low-quality test image folder')
    parser.add_argument('--dataset_hr', type=str, default='/media/Datasets/super-resolution/Set5/HR',
                        help='input ground-truth test image folder')
    parser.add_argument('--config', default='config.yaml', type=str, help='Config path')
    parser.add_argument('--predict', help='Predict', default=True, action=argparse.BooleanOptionalAction)
    parser.add_argument('--logfile', default='result_log.txt', type=str, help='Log file path')
    args = parser.parse_args()

    # Dataset name = parent directory of the HR folder (".../Set5/HR" -> "Set5").
    # normpath drops a trailing separator, so ".../Set5/HR/" also yields "Set5"
    # (split('/')[-2] would have returned "HR" there).
    DATASET = os.path.basename(os.path.dirname(os.path.normpath(args.dataset_hr)))
    print(f"\n\nTesting on {DATASET}")

    hr_paths = sorted(glob.glob(os.path.join(args.dataset_hr, '*')))
    if not hr_paths:
        # Otherwise every average below is np.mean([]) == nan (its warning is
        # hidden by the global warnings filter) and nan gets logged silently.
        raise FileNotFoundError(f'No test images found in {args.dataset_hr}')

    logger = Logger(args.logfile)
    # LPIPS is a PyTorch network: use the GPU when available, else the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    calculate_lpips = LPIPS(net='alex', verbose=False).to(device)
    config = read_yaml(args.config)
    weights = f'weights/srgan/{args.model}'
    config['MODE'] = 'TEST'
    if args.predict:
        model = get_model(config, weights)

    # Pixels cropped from each border before PSNR/SSIM.
    # NOTE(review): taken from config['SCALE'], not --scale; keep them in sync.
    border = config['SCALE']
    test_results = OrderedDict((metric, []) for metric in ['psnr', 'ssim', 'lpips', 'niqe', 'timing'])

    for idx, path in enumerate(hr_paths):
        # HWC, BGR, uint8 images as returned by cv2.imread (normalize=False).
        imgname, img_lr, img_hr = get_image_pair(path, args.dataset_lr, args.scale)
        h_lr, w_lr = img_lr.shape[:-1]
        if args.predict:
            # Feed the model a batch of one RGB image; take the single result.
            output, timing = resolve(model, img_lr[None, ..., ::-1])
            output = output[0]
        else:
            # No model: the "LR" image is scored as-is, so derive the nominal
            # LR size from it for the common crop below.
            h_lr, w_lr = h_lr // args.scale, w_lr // args.scale
            output = img_lr[..., ::-1]
            timing = -1
        test_results['timing'].append(timing)

        # Crop prediction and ground truth to scale * LR size, and flip the
        # prediction from RGB back to BGR to match img_hr.
        output = output[:h_lr * args.scale, :w_lr * args.scale][:, :, ::-1]
        img_hr = img_hr[:h_lr * args.scale, :w_lr * args.scale]

        # PSNR/SSIM/NIQE on the Y channel; LPIPS on RGB input.
        output_y = to_y_channel(output)
        img_hr_y = to_y_channel(img_hr)
        psnr = calculate_psnr(output_y, img_hr_y, crop_border=border)
        ssim = calculate_ssim(output_y, img_hr_y, crop_border=border)
        with torch.no_grad():
            lpips = calculate_lpips(im2tensor(output[..., ::-1]).to(device),
                                    im2tensor(img_hr[..., ::-1]).to(device)).item()
        niqe = calculate_niqe(output_y)

        test_results['psnr'].append(psnr)
        test_results['ssim'].append(ssim)
        test_results['lpips'].append(lpips)
        test_results['niqe'].append(niqe)
        print('Testing {:d} {:10s} - PSNR: {:.4f} dB; SSIM: {:.4f}; '
              'LPIPS: {:.4f}; NIQE: {:.4f}; \ttiming: {:.4f} ms '.format(idx, imgname, psnr, ssim, lpips, niqe, timing))

    # Summarize psnr/ssim/lpips/niqe/timing over the dataset.
    ave_psnr = np.mean(test_results['psnr'])
    ave_ssim = np.mean(test_results['ssim'])
    ave_lpips = np.mean(test_results['lpips'])
    ave_niqe = np.mean(test_results['niqe'])
    ave_timing = np.mean(test_results['timing'])
    print('\nAverage PSNR/SSIM/LPIPS/NIQE: {:.4f} dB; {:.4f}; {:.4f}; {:.4f}'.format(ave_psnr, ave_ssim, ave_lpips, ave_niqe))
    print('Average timing: {:.4f} ms'.format(ave_timing))
    text = '{:s}: PSNR/SSIM/LPIPS/NIQE {:.4f} dB; {:.4f}; {:.4f}; {:.4f}'.format(DATASET, ave_psnr, ave_ssim, ave_lpips, ave_niqe)
    logger.save_log(text)
def get_model(config, model_weights):
    """Build the generator network described by ``config`` and load weights.

    Args:
        config: configuration mapping handed to ``Trainer``.
        model_weights: weights path handed to ``load_weights``.

    Returns:
        The ``generator`` of a freshly constructed ``Trainer``, with weights
        loaded strictly (``by_name=False``, ``skip_mismatch=False``).
    """
    generator = Trainer(config=config).generator
    generator.load_weights(model_weights, by_name=False, skip_mismatch=False)
    return generator
def get_image_pair(path, dataset_lr, scale=4, normalize=False):
    """Load a ground-truth image and its matching low-resolution counterpart.

    The LR file is looked up as ``<dataset_lr>/X<scale>/<name>x<scale><ext>``,
    where ``name`` and ``ext`` come from the basename of ``path``.

    Args:
        path: path to the ground-truth (HR) image.
        dataset_lr: root folder containing the ``X<scale>`` LR sub-folders.
        scale: scale factor; selects the sub-folder and the filename suffix.
        normalize: if True, divide both images by 255 (yielding float arrays).

    Returns:
        ``(imgname, img_lq, img_gt)``: the HR basename without extension, the
        LR image and the HR image, both as read by ``cv2.imread`` with
        ``IMREAD_COLOR`` (BGR, uint8 unless ``normalize`` is set).

    Raises:
        FileNotFoundError: if either image cannot be read.
    """
    imgname, imgext = os.path.splitext(os.path.basename(path))
    lr_path = os.path.join(dataset_lr, f'X{scale}', f'{imgname}x{scale}{imgext}')
    img_gt = cv2.imread(path, cv2.IMREAD_COLOR)
    img_lq = cv2.imread(lr_path, cv2.IMREAD_COLOR)
    # cv2.imread returns None (instead of raising) for a missing or unreadable
    # file; fail here with the offending path rather than later on
    # `None.shape` or `None / 255.`.
    for img_path, img in ((path, img_gt), (lr_path, img_lq)):
        if img is None:
            raise FileNotFoundError(f'Could not read image: {img_path}')
    if normalize:
        img_gt = img_gt / 255.
        img_lq = img_lq / 255.
    return imgname, img_lq, img_gt
def resolve(model, lr_batch, to_numpy=True):
    """Run ``model`` on an LR batch and return uint8 SR output plus its timing.

    Args:
        model: callable model invoked as ``model(x, training=False)``.
        lr_batch: array-like input batch, cast to float32 before the call.
            Values are passed through unscaled (presumably 0-255, given the
            [0, 255] clip on the output; confirm).
        to_numpy: if True return a NumPy array, otherwise a ``tf.Tensor``.

    Returns:
        ``(sr_batch, timing)``: element ``[0]`` of the model's output, clipped
        to [0, 255], rounded and cast to uint8; and the duration of the model
        call in milliseconds.
    """
    lr_batch = tf.cast(lr_batch, tf.float32)
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # (adjustable) system clock and is not meant for interval measurement.
    start = time.perf_counter()
    # NOTE(review): `[0]` selects the first element of the model's result;
    # main() indexes [0] again to get the single image, which suggests the
    # generator returns a list/tuple of outputs; confirm against Trainer.
    sr_batch = model(lr_batch, training=False)[0]
    # NOTE(review): no explicit device sync is done, so on GPU this may
    # under-report compute time if kernels are still running at return.
    timing = (time.perf_counter() - start) * 1000
    sr_batch = tf.cast(tf.round(tf.clip_by_value(sr_batch, 0, 255)), tf.uint8)
    if to_numpy:
        sr_batch = sr_batch.numpy()
    return sr_batch, timing
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()