-
Notifications
You must be signed in to change notification settings - Fork 7
/
test.py
129 lines (97 loc) · 4.6 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 10 22:56:02 2019
@author: aneesh
"""
import os
import os.path as osp
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
from helpers.utils import Metrics, AeroCLoader, parse_args
from networks.resnet6 import ResnetGenerator
from networks.segnet import segnet, segnetm
from networks.unet import unet, unetm
import argparse
# Number of output classes every network below is constructed with (the second
# positional argument of each constructor; presumably the AeroRIT class count).
_NUM_CLASSES = 6

# Checkpoint the script previously hard-coded; used only as a fallback when no
# --network_weights_path is supplied.
_DEFAULT_WEIGHTS_PATH = 'savedmodels/unetm.pt'


def _build_parser():
    """Return the command-line parser for the AeroRIT baseline evaluation."""
    parser = argparse.ArgumentParser(description='AeroRIT baseline evaluations')
    ### 0. Config file?
    parser.add_argument('--config-file', default=None, help='Path to configuration file')
    ### 1. Data Loading
    parser.add_argument('--bands', default=51, type=int,
                        help='Which bands category to load '
                             '- 3: RGB, 4: RGB + 1 Infrared, 6: RGB + 3 Infrared, '
                             '31: Visible, 51: All')
    parser.add_argument('--hsi_c', default='rad', help='Load HSI Radiance or Reflectance data?')
    ### 2. Network selections
    ### a. Which network?
    parser.add_argument('--network_arch', default='unet', help='Network architecture?')
    parser.add_argument('--use_mini', action='store_true', help='Use mini version of network?')
    ### b. ResNet config
    parser.add_argument('--resnet_blocks', default=6, type=int,
                        help='How many blocks if ResNet architecture?')
    ### c. UNet configs
    parser.add_argument('--use_SE', action='store_true', help='Network uses SE Layer?')
    parser.add_argument('--use_preluSE', action='store_true',
                        help='SE layer uses ReLU or PReLU activation?')
    ### Load weights post network config
    parser.add_argument('--network_weights_path', default=None, help='Path to Saved Network weights')
    ### Use GPU or not
    parser.add_argument('--use_cuda', action='store_true', help='use GPUs?')
    return parser


def _load_testset(bands, hsi_sign, tx):
    """Build the AeroCLoader 'right' / 'test' / 'small' split for ``bands``.

    Raises:
        NotImplementedError: if ``bands`` is not one of 3, 4, 6, 31 or 51.
    """
    if bands in (3, 4, 6):
        hsi_mode = '{}b'.format(bands)
    elif bands == 31:
        hsi_mode = 'visible'
    elif bands == 51:
        hsi_mode = 'all'
    else:
        raise NotImplementedError('required parameter not found in dictionary')
    return AeroCLoader(set_loc='right', set_type='test', size='small',
                       hsi_sign=hsi_sign, hsi_mode=hsi_mode, transforms=tx)


def _build_network(args):
    """Instantiate (without weights) the architecture named by ``args.network_arch``.

    Raises:
        NotImplementedError: for an unknown architecture name.
    """
    if args.network_arch == 'resnet':
        return ResnetGenerator(args.bands, _NUM_CLASSES, n_blocks=args.resnet_blocks)
    if args.network_arch == 'segnet':
        # The option is ``use_mini``; the old code read a non-existent ``args.mini``.
        if args.use_mini:
            return segnetm(args.bands, _NUM_CLASSES)
        return segnet(args.bands, _NUM_CLASSES)
    if args.network_arch == 'unet':
        if args.use_mini:
            return unetm(args.bands, _NUM_CLASSES,
                         use_SE=args.use_SE, use_PReLU=args.use_preluSE)
        return unet(args.bands, _NUM_CLASSES)
    raise NotImplementedError('required parameter not found in dictionary')


def _predict_labels(net, testset, device):
    """Run ``net`` over every sample of ``testset``.

    Returns:
        (labels_gt, labels_pred): two 1-D numpy arrays holding every sample's
        ground-truth and predicted labels, flattened and concatenated in order.
    """
    gt_chunks = []
    pred_chunks = []
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for img_idx in range(len(testset)):
            _, hsi, label = testset[img_idx]
            logits = net(hsi.unsqueeze(0).to(device))
            # argmax over dim 1 (the output-channel axis) gives the predicted
            # class index per position; batch dim is dropped by ravel().
            pred_chunks.append(logits.argmax(1).cpu().numpy().ravel())
            gt_chunks.append(label.numpy().ravel())
    if not gt_chunks:
        return np.array([]), np.array([])
    # One concatenate instead of np.append per sample (which was O(n^2)).
    return np.concatenate(gt_chunks), np.concatenate(pred_chunks)


def main():
    """Evaluate a pretrained AeroRIT baseline on the test split and print scores."""
    args = parse_args(_build_parser())
    # The script used to overwrite --use_mini and --network_weights_path
    # unconditionally, ignoring what the user passed. Keep that only as the
    # fallback when no weights are given; the bundled file name ('unetm')
    # suggests a mini network, hence use_mini (TODO confirm).
    if args.network_weights_path is None:
        args.network_weights_path = _DEFAULT_WEIGHTS_PATH
        args.use_mini = True
    print(args)

    device = 'cuda' if args.use_cuda and torch.cuda.is_available() else 'cpu'

    perf = Metrics()
    tx = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    testset = _load_testset(args.bands, args.hsi_c, tx)
    print('Completed loading data...')

    net = _build_network(args)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    net.load_state_dict(torch.load(args.network_weights_path, map_location=device))
    net.eval()
    net.to(device)
    print('Completed loading pretrained network weights...')

    print('Calculating prediction accuracy...')
    labels_gt, labels_pred = _predict_labels(net, testset, device)
    scores = perf(labels_gt, labels_pred)

    print('Statistics on Test set:\n')
    print('Overall accuracy = {:.2f}%\n'
          'Average Accuracy = {:.2f}%\n'
          'Mean IOU is {:.2f}\n'
          'Mean DICE score is {:.2f}'.format(scores[0] * 100, scores[1] * 100,
                                              scores[2] * 100, scores[3] * 100))


if __name__ == "__main__":
    main()