eval.py
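"""Evaluate a trained SSD300 detector on the Pascal VOC test split and report per-class AP and overall mAP."""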
import argparse
import os
import warnings
from pprint import PrettyPrinter

import torch
from tqdm import tqdm

from datasets import PascalVOCDataset
from model import SSD300
from utils import calculate_mAP, label_map


def evaluate(test_loader, model, device):
    """
    Evaluate.

    :param test_loader: DataLoader for test data
    :param model: model
    :param device: device (CPU or GPU) on which to run the evaluation
    :return: APs (average precision for each class), mAP (mean average precision)
    """
    # Make sure it's in eval mode
    model.eval()

    # Lists to store detected and true boxes, labels, scores
    det_boxes = list()
    det_labels = list()
    det_scores = list()
    true_boxes = list()
    true_labels = list()
    true_difficulties = list()  # it is necessary to know which objects are 'difficult', see 'calculate_mAP' in utils.py

    with torch.no_grad():
        # Batches
        for i, (images, boxes, labels, difficulties) in enumerate(tqdm(test_loader, desc='evaluating model')):
            images = images.to(device)  # (N, 3, 300, 300)

            # Forward prop.
            predicted_locs, predicted_scores = model(images)

            # Detect objects in SSD output
            det_boxes_batch, det_labels_batch, det_scores_batch = model.detect_objects(predicted_locs, predicted_scores,
                                                                                       min_score=0.01, max_overlap=0.45,
                                                                                       top_k=200, device=device)
            # Evaluation MUST be at min_score=0.01, max_overlap=0.45, top_k=200 for fair comparison with the paper's results and other repos

            # Store this batch's results for mAP calculation
            boxes = [b.to(device) for b in boxes]
            labels = [l.to(device) for l in labels]
            difficulties = [d.to(device) for d in difficulties]

            det_boxes.extend(det_boxes_batch)
            det_labels.extend(det_labels_batch)
            det_scores.extend(det_scores_batch)
            true_boxes.extend(boxes)
            true_labels.extend(labels)
            true_difficulties.extend(difficulties)

    # Calculate mAP
    print('calculating mAP...')
    APs, mAP = calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels, true_difficulties, device)

    return APs, mAP

if __name__ == '__main__':
    warnings.filterwarnings('ignore')

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_folder', required=True)
    parser.add_argument('--checkpoint', help='path of the pretrained model', required=True)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()

    # Good formatting when printing the APs for each class and mAP
    pp = PrettyPrinter()

    # Parameters
    keep_difficult = True  # difficult ground truth objects must always be considered in mAP calculation, because these objects DO exist!
    workers = 4
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)  # prevent the pretrained model from being loaded onto GPU 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model checkpoint that is to be evaluated
    if args.checkpoint == 'pretrained_ssd300.pth.tar':
        checkpoint = torch.load(args.checkpoint, map_location=device)
        model = checkpoint['model']
        # Checkpoints saved with older PyTorch versions may lack the 'padding_mode' attribute on Conv layers, so set it explicitly
        for m in model.modules():
            if 'Conv' in str(type(m)):
                setattr(m, 'padding_mode', 'zeros')
    else:
        checkpoint = torch.load(args.checkpoint, map_location=device)
        model = SSD300(n_classes=len(label_map), device=device)
        model.load_state_dict(checkpoint['model'])
    model = model.to(device)

    # Switch to eval mode
    model.eval()

    # Load test data
    test_dataset = PascalVOCDataset(args.data_folder,
                                    split='test',
                                    keep_difficult=keep_difficult)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                              collate_fn=test_dataset.collate_fn, num_workers=workers, pin_memory=True)

    APs, mAP = evaluate(test_loader, model, device)

    # Print AP for each class
    pp.pprint(APs)

    print('\nMean Average Precision (mAP): %.3f' % mAP)
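    # Example invocation (paths are illustrative; substitute your own VOC data folder and checkpoint file):
    #   python eval.py --data_folder ./VOC_data --checkpoint pretrained_ssd300.pth.tar --batch_size 32 --gpu 0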