forked from layumi/Person_reID_baseline_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
111 lines (98 loc) · 3.67 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import scipy.io
import torch
import numpy as np
import time
import os
from torchvision import datasets, models, transforms
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
#######################################################################
# Evaluate
# Command-line options: which query image to inspect and where the
# prepared dataset folders live.
parser = argparse.ArgumentParser(description='Demo')
parser.add_argument(
    '--query_index',
    type=int,
    default=777,
    help='test_image_index',
)
parser.add_argument(
    '--test_dir',
    type=str,
    default='../Market/pytorch',
    help='./test_data',
)
opts = parser.parse_args()

data_dir = opts.test_dir
# One ImageFolder per split; the script later reads their .imgs lists to
# translate ranking indices back into image file paths.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split))
    for split in ('gallery', 'query')
}
#####################################################################
#Show result
def imshow(path, title=None):
    """Read the image file at ``path`` and draw it on the current axes.

    When ``title`` is given (anything other than ``None``) it is set as the
    title of the current axes.
    """
    plt.imshow(plt.imread(path))
    has_title = title is not None
    if has_title:
        plt.title(title)
    # Short pause so that the plot gets a chance to refresh.
    plt.pause(0.001)
######################################################################
# Load the features, camera ids and labels stored in 'pytorch_result.mat'.
# The cam/label entries are indexed with [0] to take their first row.
result = scipy.io.loadmat('pytorch_result.mat')
query_feature = torch.FloatTensor(result['query_f'])
query_cam = result['query_cam'][0]
query_label = result['query_label'][0]
gallery_feature = torch.FloatTensor(result['gallery_f'])
gallery_cam = result['gallery_cam'][0]
gallery_label = result['gallery_label'][0]

# Use the GPU when one is available, otherwise stay on the CPU so the demo
# still runs on machines without CUDA (a bare .cuda() would raise there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Multi-query results are optional; load them only if the file exists.
multi = os.path.isfile('multi_query.mat')
if multi:
    m_result = scipy.io.loadmat('multi_query.mat')
    mquery_feature = torch.FloatTensor(m_result['mquery_f'])
    mquery_cam = m_result['mquery_cam'][0]
    mquery_label = m_result['mquery_label'][0]
    mquery_feature = mquery_feature.to(device)

query_feature = query_feature.to(device)
gallery_feature = gallery_feature.to(device)
#######################################################################
# sort the images
def sort_img(qf, ql, qc, gf, gl, gc):
    """Rank gallery entries by score against one query, dropping junk ones.

    Args:
        qf: Query feature tensor; reshaped to a single column before scoring.
        ql: Label of the query.
        qc: Camera id of the query.
        gf: 2-D gallery feature tensor, one row per gallery entry.
        gl: NumPy array of gallery labels, compared elementwise with ``ql``.
        gc: NumPy array of gallery camera ids, compared elementwise with
            ``qc``.

    Returns:
        NumPy array of gallery indices ordered from highest to lowest
        dot-product score. Entries whose label is -1, and entries that share
        both label and camera id with the query, are removed.
    """
    query = qf.view(-1, 1)
    # One dot product per gallery row -> (num_gallery, 1), then flatten.
    score = torch.mm(gf, query)
    score = score.squeeze(1).cpu().numpy()

    # argsort is ascending; reverse it so the best match comes first.
    index = np.argsort(score)[::-1]

    same_label = np.argwhere(gl == ql)
    same_camera = np.argwhere(gc == qc)

    # Junk = entries labelled -1, plus entries with the query's label that
    # were also taken by the query's camera.
    junk_index1 = np.argwhere(gl == -1)
    junk_index2 = np.intersect1d(same_label, same_camera)
    junk_index = np.append(junk_index2, junk_index1)

    # np.isin replaces the deprecated np.in1d.
    keep = np.isin(index, junk_index, invert=True)
    return index[keep]
i = opts.query_index
index = sort_img(
    query_feature[i], query_label[i], query_cam[i],
    gallery_feature, gallery_label, gallery_cam,
)

########################################################################
# Visualize the rank result
query_path, _ = image_datasets['query'].imgs[i]
query_label = query_label[i]
print(query_path)
print('Top 10 images are as follow:')

# Filtering junk entries can leave fewer than 10 results; never index past
# the end of the ranking.
top_k = min(10, len(index))
# Stays None if figure creation fails, so the save step below can be skipped.
fig = None
try:  # Visualize Ranking Result
    # Graphical User Interface is needed
    fig = plt.figure(figsize=(16, 4))
    ax = plt.subplot(1, 11, 1)
    ax.axis('off')
    imshow(query_path, 'query')
    # `rank` is kept separate from `i`, which still holds the query index.
    for rank in range(top_k):
        ax = plt.subplot(1, 11, rank + 2)
        ax.axis('off')
        img_path, _ = image_datasets['gallery'].imgs[index[rank]]
        label = gallery_label[index[rank]]
        imshow(img_path)
        # Green title: same label as the query; red title: different label.
        if label == query_label:
            ax.set_title('%d' % (rank + 1), color='green')
        else:
            ax.set_title('%d' % (rank + 1), color='red')
        print(img_path)
except RuntimeError:
    # Text-only fallback: list the top gallery paths instead of plotting.
    # image_datasets is a dict, so the gallery split must be selected first.
    for rank in range(top_k):
        img_path = image_datasets['gallery'].imgs[index[rank]]
        print(img_path[0])
    print('If you want to see the visualization of the ranking result, graphical user interface is needed.')

# Save only if a figure was actually created.
if fig is not None:
    fig.savefig("show.png")