-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
74 lines (63 loc) · 3.34 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import os
import time

import numpy as np
import torch
import torchvision
from torchvision import transforms
from tqdm import tqdm

from dataloader import RUE_Net_DataSet
from metrics_calculation import *
# Names exported by `from <this module> import *`: the three pipeline stages.
__all__ = ["test", "setup", "testing"]
@torch.no_grad()
def test(config, test_dataloader, test_model):
    """Run the model over every test batch and save each generated image.

    Args:
        config: parsed options. Reads ``config.device`` (set by ``setup``) and
            ``config.output_images_path`` (directory the images are written to).
        test_dataloader: iterable yielding ``(img, _, gray, gx, gy, name)``
            tuples; ``name`` holds one output file name per sample in the batch.
        test_model: called as ``test_model(img, gray, gx, gy)``; the first
            element of the returned pair is the generated image batch
            (presumably shaped (B, C, H, W) -- TODO confirm).

    Each generated sample is written to
    ``os.path.join(config.output_images_path, name[i])``.
    """
    test_model.eval()
    device = config.device  # loop-invariant, read once
    for img, _, gray, gx, gy, names in test_dataloader:
        # Autograd is already disabled by the @torch.no_grad() decorator.
        img = img.to(device)
        gray = gray.to(device)
        gx = gx.to(device)
        gy = gy.to(device)
        generate_img, _ = test_model(img, gray, gx, gy)
        # Save every sample under its own name. Saving the whole batch under
        # names[0] would merge a batch of >1 images into a single grid file.
        # os.path.join also works when the output directory lacks a trailing "/".
        for image, file_name in zip(generate_img, names):
            torchvision.utils.save_image(
                image, os.path.join(config.output_images_path, file_name)
            )
def setup(config):
    """Pick a device, load the model snapshot and build the test data loader.

    Side effect: sets ``config.device`` to ``"cuda"`` when CUDA is available,
    otherwise ``"cpu"``.

    Args:
        config: parsed options. Reads ``snapshot_path``, ``test_images_path``
            and ``batch_size``.

    Returns:
        ``(test_dataloader, model)``, where the model has been moved to
        ``config.device``.
    """
    config.device = "cuda" if torch.cuda.is_available() else "cpu"
    # map_location lets a snapshot saved on a GPU be restored on a CPU-only
    # machine. Without it, torch.load tries to put the tensors back on CUDA.
    # NOTE(review): torch.load unpickles an entire model object. Only load
    # trusted snapshots. Newer torch releases default to weights_only=True,
    # which may reject a pickled nn.Module -- confirm against the torch version used.
    model = torch.load(config.snapshot_path, map_location=config.device).to(config.device)
    # Inputs are always resized to 256x256 here. The --resize option is only
    # used by the metric calculation.
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    test_dataset = RUE_Net_DataSet(config.test_images_path, None, transform, False)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=config.batch_size, shuffle=False
    )
    print("Test Dataset Reading Completed.")
    return test_dataloader, model
def testing(config):
    """Top-level test run: prepare the loader and model, then generate outputs.

    Args:
        config: parsed options, passed through to ``setup`` and ``test``.
    """
    test_loader, loaded_model = setup(config)
    test(config, test_loader, loaded_model)
def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``yes``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean flags must be parsed explicitly.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--snapshot_path', type=str, default='',
                        help='snapshot path,such as :xxx/snapshots/model.ckpt')
    parser.add_argument('--test_images_path', type=str, default="",
                        help='path of input images(underwater images) for testing default:./data/input/')
    parser.add_argument('--output_images_path', type=str, default='./output/',
                        help='path to save generated image.')
    parser.add_argument('--batch_size', type=int, default=1, help="default : 1")
    parser.add_argument('--resize', type=int, default=512,
                        help="resize images to resize*resize for metric calculation, default: 512")
    # str2bool makes "--calculate_metrics False" actually disable metrics;
    # a bare "--calculate_metrics" means True.
    parser.add_argument('--calculate_metrics', type=str2bool, nargs='?', const=True, default=True,
                        help="calculate PSNR, SSIM and UIQM on test images")
    parser.add_argument('--label_images_path', type=str, default="",
                        help='path of label images(clear images) default:./data/label/')
    print("-------------------testing---------------------")
    config = parser.parse_args()
    # makedirs also creates missing parents and does not fail if the directory
    # already exists (no exists/mkdir race).
    os.makedirs(config.output_images_path, exist_ok=True)
    start_time = time.perf_counter()
    testing(config)
    print("total testing time", time.perf_counter() - start_time)
    if config.calculate_metrics:
        print("-------------------calculating performance metrics---------------------")
        metric_size = (config.resize, config.resize)
        SSIM_measures, PSNR_measures = calculate_metrics_ssim_psnr(
            config.output_images_path, config.label_images_path, metric_size
        )
        UIQM_measures = calculate_UIQM(config.output_images_path, metric_size)
        for label, measures in (("SSIM", SSIM_measures),
                                ("PSNR", PSNR_measures),
                                ("UIQM", UIQM_measures)):
            print("{0} on {1} samples {2} ± {3}".format(
                label, len(measures),
                np.round(np.mean(measures), 3), np.round(np.std(measures), 3)))