-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathdataset.py
125 lines (108 loc) · 4.49 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import glob
class Dataset:
    """Resolve directory paths and fixed sizes for a named dataset.

    Four dataset names are recognised: ``market1501``, ``duke``, ``cuhk03``
    and ``viper``. Most accessors raise ``ValueError`` for any other name.

    Attributes:
        root: Directory that contains one sub-directory per dataset.
        dataset: Name of the selected dataset.
        gan_root: Directory that contains generated images, one
            sub-directory per dataset.
        cluster_root: Directory that contains clustering output, one
            sub-directory per dataset.
    """

    # Names that share the same on-disk layout (bounding_box_train,
    # bounding_box_test, query, train.list).
    _KNOWN = ('market1501', 'duke', 'cuhk03', 'viper')

    # Per-dataset constants, keyed by dataset name.
    _N_CLASSES = {'market1501': 751, 'duke': 702, 'cuhk03': 767, 'viper': 316}
    _TEST_NUM = {'market1501': 19732, 'duke': 17661, 'cuhk03': 6751, 'viper': 316}
    _QUERY_NUM = {'market1501': 3368, 'duke': 2228, 'cuhk03': 6751, 'viper': 316}

    # Number of '*.jpg' files expected in the market1501 training directory.
    _MARKET1501_TRAIN_SIZE = 12936
    # Training-set size reported for duke (returned without touching disk).
    _DUKE_TRAIN_SIZE = 16522

    def __init__(self, root='/home/paul/datasets', dataset='market1501',
                 gan_root='/home/paul/generated',
                 cluster_root=os.path.join('/home/paul', 'clustering')):
        """Store the dataset name and the base directories.

        Args:
            root: Directory holding the dataset folders.
            dataset: Dataset name; it is not validated here, only when an
                accessor is called.
            gan_root: Base directory used by ``gan_path``.
            cluster_root: Base directory used by ``cluster_path``.
        """
        self.dataset = dataset
        self.root = root
        self.gan_root = gan_root
        self.cluster_root = cluster_root

    def _is_known(self):
        """Return True when the selected dataset is one of the known names."""
        return self.dataset in self._KNOWN

    def _subdir(self, name, error_message):
        """Return ``root/dataset/name`` for a known dataset.

        Raises:
            ValueError: With ``error_message % dataset`` for an unknown name.
        """
        if not self._is_known():
            raise ValueError(error_message % self.dataset)
        return os.path.join(self.root, self.dataset, name)

    def _lookup(self, table, error_message):
        """Return ``table[dataset]``.

        Raises:
            ValueError: With ``error_message % dataset`` when the dataset has
                no entry in ``table``.
        """
        try:
            return table[self.dataset]
        except KeyError:
            raise ValueError(error_message % self.dataset) from None

    def train_path(self):
        """Return the ``bounding_box_train`` directory of the dataset.

        Raises:
            ValueError: If the dataset name is unknown.
        """
        return self._subdir('bounding_box_train', 'Unknown train set for %s')

    def test_path(self):
        """Return the ``bounding_box_test`` directory of the dataset.

        Raises:
            ValueError: If the dataset name is unknown.
        """
        return self._subdir('bounding_box_test', 'Unknown test set for %s')

    def gallery_path(self):
        """Return the gallery directory, which is the test directory.

        Raises:
            ValueError: If the dataset name is unknown.
        """
        # The original called a non-existent ``testset`` method here.
        return self.test_path()

    def query_path(self):
        """Return the ``query`` directory of the dataset.

        Raises:
            ValueError: If the dataset name is unknown.
        """
        return self._subdir('query', 'Unknown query set for %s')

    def gan_path(self):
        """Return ``gan_root/dataset``; the dataset name is not validated."""
        return os.path.join(self.gan_root, self.dataset)

    def dataset_path(self):
        """Return ``root/dataset``; the dataset name is not validated."""
        return os.path.join(self.root, self.dataset)

    def n_classe(self):
        """Return the class count recorded for the dataset.

        Raises:
            ValueError: If the dataset name is unknown.
        """
        return self._lookup(self._N_CLASSES, 'Unknown n_classe set for %s')

    def root_path(self):
        """Return the root directory given at construction."""
        return self.root

    def gt_set(self):
        """Return the ``gt_bbox`` directory (market1501 only).

        Raises:
            ValueError: For every dataset other than market1501.
        """
        if self.dataset != 'market1501':
            raise ValueError(
                'Unknown hand-drawn bounding boxes for %s' % self.dataset)
        return os.path.join(self.root, self.dataset, 'gt_bbox')

    def train_list(self):
        """Return the path of the dataset's ``train.list`` file.

        Raises:
            ValueError: If the dataset name is unknown.
            FileNotFoundError: If the file does not exist on disk.
        """
        list_file = self._subdir(
            'train.list', 'Unknown train bounding boxes for %s')
        if not os.path.exists(list_file):
            raise FileNotFoundError('%s not found' % list_file)
        return list_file

    def cluster_path(self):
        """Return ``cluster_root/dataset``.

        Raises:
            ValueError: If the dataset name is unknown.
        """
        if not self._is_known():
            raise ValueError('Unknown cluster path for %s' % self.dataset)
        return os.path.join(self.cluster_root, self.dataset)

    def n_training_set(self):
        """Return the number of training images.

        For market1501 the '*.jpg' files in ``train_path()`` are counted and
        the count must equal 12936. For duke the constant 16522 is returned
        without reading the disk.

        Raises:
            AssertionError: If the market1501 image count is not 12936.
            ValueError: For any dataset other than market1501 and duke.
        """
        if self.dataset == 'market1501':
            pattern = os.path.join(self.train_path(), '*.jpg')
            n = len(glob.glob(pattern))
            # Raised explicitly rather than via ``assert`` so the check is
            # not stripped when Python runs with -O.
            if n != self._MARKET1501_TRAIN_SIZE:
                raise AssertionError(
                    'expected %d training images matching %s, found %d'
                    % (self._MARKET1501_TRAIN_SIZE, pattern, n))
            return n
        if self.dataset == 'duke':
            return self._DUKE_TRAIN_SIZE
        raise ValueError("Unknow training set size for %s" % self.dataset)

    def n_gan_set(self):
        """Return the number of '*.jpg' files in ``gan_path()``.

        Raises:
            ValueError: For every dataset other than market1501.
        """
        if self.dataset != 'market1501':
            raise ValueError(
                'Unknow generated set size for %s' % self.dataset)
        return len(glob.glob(os.path.join(self.gan_path(), '*.jpg')))

    def test_num(self):
        """Return the test-set size recorded for the dataset.

        Raises:
            ValueError: If the dataset name is unknown.
        """
        # The original format string used '% d', an integer conversion,
        # which raised TypeError for a string name instead of ValueError.
        return self._lookup(self._TEST_NUM, 'Unknown test num for %s dataset')

    def query_num(self):
        """Return the query-set size recorded for the dataset.

        Raises:
            ValueError: If the dataset name is unknown.
        """
        # Same '% d' -> '%s' format fix as in ``test_num``.
        return self._lookup(
            self._QUERY_NUM, 'Unknown query num for %s dataset')