-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
73 lines (61 loc) · 2.83 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# -*- coding: utf-8 -*-
"""
-------------------------------------------------
# @Project :mylearn
# @File :test
# @Date :2021/1/15 20:17
# @Author :Jay_Lee
# @Software :PyCharm
-------------------------------------------------
"""
import argparse
from matplotlib import pyplot as plt
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from conf import settings
from utils import get_network, get_test_dataloader
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``net``, ``gpu``, ``v``, ``b``
        and ``weights``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-net', type=str, default='mobilenet', help='net type')
    # The original ``action='store_true', default=True`` could never be
    # switched off and crashed on machines without CUDA. ``-gpu`` is kept for
    # backward compatibility; the default now follows the available hardware
    # and ``-cpu`` forces evaluation on the CPU.
    use_cuda = torch.cuda.is_available()
    parser.add_argument('-gpu', action='store_true', default=use_cuda,
                        help='use gpu or not (default: use it when CUDA is available)')
    parser.add_argument('-cpu', dest='gpu', action='store_false', default=use_cuda,
                        help='force evaluation on the CPU')
    parser.add_argument('-v', type=str, default='cifar100', help='dataset type')
    parser.add_argument('-b', type=int, default=16, help='batchsize for dataloader')
    parser.add_argument('-weights', type=str,
                        default='./checkpoint/mobilenet/Saturday_16_January_2021_05h_11m_00s/mobilenet-165-best.pth',
                        help='the weights file you want to test')
    return parser.parse_args(argv)


def count_topk_correct(output, label, topk=(1, 5)):
    """Count the samples whose true label is among the k highest scores.

    Args:
        output: score tensor indexed as (sample, class) — ``topk`` is taken
            along dim 1.
        label: integer class index per sample (reshaped to one column).
        topk: the k values to evaluate.

    Returns:
        A list of Python ints, one hit count per entry of ``topk``.
    """
    # Clamp k so models with fewer outputs than max(topk) do not crash topk();
    # a k >= number of classes then simply counts every sample as a hit.
    maxk = min(max(topk), output.size(1))
    _, pred = output.topk(maxk, 1, largest=True, sorted=True)
    hits = pred.eq(label.view(-1, 1).expand_as(pred))
    return [int(hits[:, :k].sum().item()) for k in topk]


def main():
    """Evaluate a trained network on the CIFAR-100 test split.

    Prints the model, per-batch progress, top-1 accuracy, top-1/top-5 error
    and the parameter count.
    """
    import os

    args = parse_args()
    device = torch.device('cuda' if args.gpu else 'cpu')

    net = get_network(args)
    # NOTE(review): whether get_network already moves the model to the GPU
    # depends on utils (not visible here); .to() is a no-op in that case.
    net = net.to(device)

    cifar100_test_loader = get_test_dataloader(settings.CIFAR100_TRAIN_MEAN,
                                               settings.CIFAR100_TRAIN_STD,
                                               num_workers=4, batch_size=args.b)

    # map_location lets GPU-trained checkpoints load on CPU-only hosts.
    # torch.load unpickles the file: only load weight files from trusted sources.
    weights_path = os.path.expanduser(args.weights)
    net.load_state_dict(torch.load(weights_path, map_location=device))
    print(net)
    net.eval()

    correct_1 = 0
    correct_5 = 0
    num_batches = len(cifar100_test_loader)
    with torch.no_grad():
        for n_iter, (image, label) in enumerate(cifar100_test_loader):
            print('iteration: {}\ttotal:{}iterations'.format(n_iter + 1, num_batches))
            image = image.to(device)
            label = label.to(device)
            output = net(image)
            top1, top5 = count_topk_correct(output, label, topk=(1, 5))
            correct_1 += top1
            correct_5 += top5

    num_samples = len(cifar100_test_loader.dataset)
    if num_samples == 0:
        raise SystemExit('test dataset is empty; nothing to evaluate')
    top1_acc = correct_1 / num_samples
    top5_acc = correct_5 / num_samples
    print("Top 1 acc: ", top1_acc)
    print("Top 1 err: ", 1 - top1_acc)
    print("Top 5 err: ", 1 - top5_acc)
    print("Parameter numbers: {}".format(sum(p.numel() for p in net.parameters())))


if __name__ == '__main__':
    main()