-
Notifications
You must be signed in to change notification settings - Fork 69
/
inference.py
72 lines (58 loc) · 2.17 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# -*- coding:utf-8 -*-
import sys
sys.path.append('/home/gfx/Projects/remote_sensing_image_classification')
import os, argparse, time
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms, models
from torch.nn.parallel.data_parallel import data_parallel
import torchvision
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from dataset.dataset import *
from networks.network import *
from networks.lr_schedule import *
from metrics.metric import *
from utils.plot import *
from config import config
def inference():
# model
# load checkpoint
model = torch.load(os.path.join('./checkpoints', config.checkpoint))
print model
# model = torch.nn.DataParallel(model)
model.cuda()
# validation data
transform = transforms.Compose([transforms.Resize((config.width, config.height)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
dst_valid = RSDataset('./data/valid.txt', width=config.width,
height=config.height, transform=transform)
dataloader_valid = DataLoader(dst_valid, shuffle=False, batch_size=config.batch_size/2, num_workers=config.num_workers)
sum = 0
val_top1_sum = 0
labels = []
preds = []
model.eval()
for ims, label in dataloader_valid:
labels += label.numpy().tolist()
input = Variable(ims).cuda()
target = Variable(label).cuda()
output = model(input)
_, pred = output.topk(1, 1, True, True)
preds += pred.t().cpu().numpy().tolist()[0]
top1_val = accuracy(output.data, target.data, topk=(1,))
sum += 1
val_top1_sum += top1_val[0]
avg_top1 = val_top1_sum / sum
print 'acc: {}'.format(avg_top1.data)
labels_=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
plot_confusion_matrix(labels, preds, labels_)
if __name__ == '__main__':
inference()