-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathtest.py
130 lines (115 loc) · 5.34 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# -*- coding: utf-8 -*-
from model import DCTNet
from torch.utils.data import DataLoader
import warnings
from metrics import Rmse
import numpy as np
from scipy.io import savemat
from tensorboardX import SummaryWriter
import torch.nn.functional as F
import torch.nn as nn
import torch
import time
from utils_gdsr import DRSRH5Dataset, DRSRDataset, save_param, output_img
import os
# Process-wide setup, applied as soon as this module is imported.
# NOTE(review): KMP_DUPLICATE_LIB_OK presumably works around an OpenMP
# "duplicate runtime" abort; it is an unsupported workaround — confirm it is
# still needed. CUDA_VISIBLE_DEVICES restricts which GPUs torch can see and
# only takes effect if set before CUDA is first initialised.
os.environ.update(
    KMP_DUPLICATE_LIB_OK='True',
    CUDA_VISIBLE_DEVICES='0,1',
)
# Silence every warning for the rest of the process.
warnings.filterwarnings('ignore')
# Directory prefix of the preprocessed test data for each supported dataset;
# the scale and an 'X' suffix are appended to form the full path.
_TEST_DIR_PREFIX = {
    'Middlebury': r'./DatasetsAfterProcessing/Middlebury_AfterProcessing_',
    'NYU': r'./DatasetsAfterProcessing/NYU_Test_AfterProcessing_',
    'Lu': r'./DatasetsAfterProcessing/Lu_AfterProcessing_',
    'RGBDD': r'./DatasetsAfterProcessing/RGBDD_AfterProcessing_',
}


def _test_data_path(dataset_name, scale):
    """Return the test-data directory for ``dataset_name`` at ``scale``.

    Raises:
        ValueError: if ``dataset_name`` is not a key of ``_TEST_DIR_PREFIX``.
    """
    try:
        prefix = _TEST_DIR_PREFIX[dataset_name]
    except KeyError:
        raise ValueError(
            'Unknown dataset {!r}; expected one of {}'.format(
                dataset_name, sorted(_TEST_DIR_PREFIX))) from None
    return prefix + str(scale) + 'X'


def inference_net_eachDataset(dataset_name, net_Path, scale):
    """Evaluate a DCTNet checkpoint on one test dataset and return its mean RMSE.

    Args:
        dataset_name: 'Middlebury', 'NYU', 'Lu' or 'RGBDD'. Selects the test
            directory and how predictions/ground truth are post-processed
            before scoring.
        net_Path: path to a state_dict loadable into ``nn.DataParallel(DCTNet())``
            (keys presumably carry the ``module.`` prefix — confirm with trainer).
        scale: upsampling factor (e.g. 4, 8, 16) or the string 'RealScene';
            appended to the directory name as ``<scale>X``.

    Returns:
        A 1-element torch tensor: the per-sample RMSEs averaged over the set.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """
    # Validate the dataset before paying for model construction.
    test_path = _test_data_path(dataset_name, scale)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = nn.DataParallel(DCTNet()).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(net_Path, map_location=device))
    net.eval()

    testloader = DataLoader(DRSRDataset(test_path, scale, dataset_name),
                            batch_size=1)
    metrics = torch.zeros(1, len(testloader))
    with torch.no_grad():
        for i, (Depth, RGB, gt, D_min, D_max) in enumerate(testloader):
            # Use the model's device rather than assuming CUDA is present.
            Depth, RGB, gt = Depth.to(device), RGB.to(device), gt.to(device)
            D_min, D_max = D_min.to(device), D_max.to(device)
            # Clamp the prediction to [0, 1], then map it back onto the
            # sample's [D_min, D_max] range.
            pred_norm = net(Depth, RGB).clamp(min=0, max=1)
            pred = pred_norm * (D_max - D_min) + D_min
            if dataset_name in ('Middlebury', 'Lu'):
                pred_img = output_img(pred).clip(min=0, max=255)
                gt_img = output_img(gt).clip(min=0, max=255)
            elif dataset_name == 'NYU':
                # Score only the interior: drop a 6-pixel border.
                pred_img = output_img(pred)[6:-6, 6:-6]
                gt_img = output_img(gt)[6:-6, 6:-6]
            else:
                pred_img = output_img(pred)
                gt_img = output_img(gt)
            metrics[:, i] = Rmse(pred_img, gt_img)
    return metrics.mean(dim=1)
def infrence_all_datasets(net_Path, scale):
    """Evaluate one checkpoint on every test dataset for a given scale.

    Args:
        net_Path: checkpoint path forwarded to ``inference_net_eachDataset``.
        scale: upsampling factor, or 'RealScene' to evaluate on RGBDD only.

    Returns:
        For 'RealScene', the value ``inference_net_eachDataset`` returns for
        RGBDD (unchanged). Otherwise a numpy float array of 4 RMSEs ordered
        Middlebury, NYU, Lu, RGBDD.
    """
    if scale == 'RealScene':
        return inference_net_eachDataset('RGBDD', net_Path, scale)
    dataset_names = ('Middlebury', 'NYU', 'Lu', 'RGBDD')
    Rmses = np.zeros(len(dataset_names))
    for k, name in enumerate(dataset_names):
        # The per-dataset result is a 1-element tensor; convert explicitly,
        # because storing a size-1 array into a scalar slot is deprecated
        # in NumPy >= 1.25.
        Rmses[k] = float(inference_net_eachDataset(name, net_Path, scale))
    return Rmses
def test():
    '''Calculate RMSE value'''
    # One checkpoint per upsampling factor; column k of the results matrix
    # holds factor k, rows follow the dataset order of infrence_all_datasets.
    benchmark_runs = (
        ('models/DCTNet_4X.pth', 4),
        ('models/DCTNet_8X.pth', 8),
        ('models/DCTNet_16X.pth', 16),
    )
    rmseResults = np.zeros((4, 3))
    for column, (ckpt, factor) in enumerate(benchmark_runs):
        rmseResults[:, column] = infrence_all_datasets(ckpt, factor)

    # Real-scene RGBDD evaluation of the 4X model, then of the real-scene model.
    realscene_dctnet = infrence_all_datasets(
        'models/DCTNet_4X.pth', 'RealScene')
    realscene_dctnet_star = infrence_all_datasets(
        'models/DCTNet_RealScene.pth', 'RealScene')

    # Output the final result.
    double_rule = '=============================================='
    single_rule = '----------------------------------------------'
    scale_header = ' X4 X8 X16'
    synthetic_titles = (
        'The testing RMSE results of Middlebury Dataset',
        'The testing RMSE results of NYU V2 Dataset',
        'The testing RMSE results of Lu Dataset',
        'The testing RMSE results of RGBDD Dataset',
    )
    for row, title in enumerate(synthetic_titles):
        print(double_rule)
        print(title)
        print(scale_header)
        print(single_rule)
        print(rmseResults[row, :])
        print(double_rule)

    real_scene_reports = (
        ('DCTNet in real-world branch', realscene_dctnet),
        ('DCTNet* in real-world branch', realscene_dctnet_star),
    )
    for label, result in real_scene_reports:
        print(double_rule)
        print('The testing RMSE results in RealScene RGBDD')
        print(label)
        print(single_rule)
        print(result)
        print(double_rule)
# Run the full evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    test()