-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 3d2ab2b
Showing
28 changed files
with
1,772 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Person-Attribute-Recognition-MarketDuke |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import os | ||
from PIL import Image | ||
import torch | ||
from torch.utils import data | ||
import numpy as np | ||
from torchvision import transforms as T | ||
from .reid_dataset import import_MarketDuke_nodistractors | ||
from .reid_dataset import import_Market1501Attribute_binary | ||
from .reid_dataset import import_DukeMTMCAttribute_binary | ||
|
||
|
||
class Train_Dataset(data.Dataset): | ||
|
||
def __init__(self, data_dir, dataset_name, transforms=None, train_val='train' ): | ||
|
||
train, query, gallery = import_MarketDuke_nodistractors(data_dir, dataset_name) | ||
|
||
if dataset_name == 'Market-1501': | ||
train_attr, test_attr, self.label = import_Market1501Attribute_binary(data_dir) | ||
elif dataset_name == 'DukeMTMC-reID': | ||
train_attr, test_attr, self.label = import_DukeMTMCAttribute_binary(data_dir) | ||
else: | ||
print('Input should only be Market1501 or DukeMTMC') | ||
|
||
self.num_ids = len(train['ids']) | ||
self.num_labels = len(self.label) | ||
|
||
# distribution:每个属性的正样本占比 | ||
distribution = np.zeros(self.num_labels) | ||
for k, v in train_attr.items(): | ||
distribution += np.array(v) | ||
self.distribution = distribution / len(train_attr) | ||
|
||
if train_val == 'train': | ||
self.train_data = train['data'] | ||
self.train_ids = train['ids'] | ||
self.train_attr = train_attr | ||
elif train_val == 'query': | ||
self.train_data = query['data'] | ||
self.train_ids = query['ids'] | ||
self.train_attr = test_attr | ||
elif train_val == 'gallery': | ||
self.train_data = gallery['data'] | ||
self.train_ids = gallery['ids'] | ||
self.train_attr = test_attr | ||
else: | ||
print('Input should only be train or val') | ||
|
||
self.num_ids = len(self.train_ids) | ||
|
||
if transforms is None: | ||
if train_val == 'train': | ||
self.transforms = T.Compose([ | ||
T.Resize(size=(288, 144)), | ||
T.RandomHorizontalFlip(), | ||
T.ToTensor(), | ||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
]) | ||
else: | ||
self.transforms = T.Compose([ | ||
T.Resize(size=(288, 144)), | ||
T.ToTensor(), | ||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
]) | ||
|
||
def __getitem__(self, index): | ||
''' | ||
一次返回一张图片的数据 | ||
''' | ||
img_path = self.train_data[index][0] | ||
i = self.train_data[index][1] | ||
id = self.train_data[index][2] | ||
cam = self.train_data[index][3] | ||
label = np.asarray(self.train_attr[id]) | ||
data = Image.open(img_path) | ||
data = self.transforms(data) | ||
name = self.train_data[index][4] | ||
return data, i, label, id, cam, name | ||
|
||
def __len__(self): | ||
return len(self.train_data) | ||
|
||
def num_label(self): | ||
return self.num_labels | ||
|
||
def num_id(self): | ||
return self.num_ids | ||
|
||
def labels(self): | ||
return self.label | ||
|
||
|
||
|
||
class Test_Dataset(data.Dataset): | ||
def __init__(self, data_dir, dataset_name, transforms=None, query_gallery='query' ): | ||
train, query, gallery = import_MarketDuke_nodistractors(data_dir, dataset_name) | ||
|
||
if dataset_name == 'Market-1501': | ||
self.train_attr, self.test_attr, self.label = import_Market1501Attribute_binary(data_dir) | ||
elif dataset_name == 'DukeMTMC-reID': | ||
self.train_attr, self.test_attr, self.label = import_DukeMTMCAttribute_binary(data_dir) | ||
else: | ||
print('Input should only be Market1501 or DukeMTMC') | ||
|
||
if query_gallery == 'query': | ||
self.test_data = query['data'] | ||
self.test_ids = query['ids'] | ||
elif query_gallery == 'gallery': | ||
self.test_data = gallery['data'] | ||
self.test_ids = gallery['ids'] | ||
elif query_gallery == 'all': | ||
self.test_data = gallery['data'] + query['data'] | ||
self.test_ids = gallery['ids'] | ||
else: | ||
print('Input shoud only be query or gallery;') | ||
|
||
if transforms is None: | ||
self.transforms = T.Compose([ | ||
T.Resize(size=(288, 144)), | ||
T.ToTensor(), | ||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
]) | ||
|
||
def __getitem__(self, index): | ||
''' | ||
一次返回一张图片的数据 | ||
''' | ||
img_path = self.test_data[index][0] | ||
id = self.test_data[index][2] | ||
label = np.asarray(self.test_attr[id]) | ||
data = Image.open(img_path) | ||
data = self.transforms(data) | ||
name = self.test_data[index][4] | ||
return data, label, id, name | ||
|
||
def __len__(self): | ||
return len(self.test_data) | ||
|
||
def labels(self): | ||
return self.label |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from .reiddataset_downloader import reiddataset_downloader | ||
from .reiddataset_downloader import reiddataset_downloader_all | ||
from .import_VIPeR import import_VIPeR | ||
from .import_CUHK01 import import_CUHK01 | ||
from .import_CUHK03 import import_CUHK03 | ||
from .import_Market1501 import import_Market1501 | ||
from .import_Market1501Attribute import import_Market1501Attribute | ||
from .import_Market1501Attribute import import_Market1501Attribute_binary | ||
from .import_DukeMTMC import import_DukeMTMC | ||
from .import_DukeMTMCAttribute import import_DukeMTMCAttribute | ||
from .import_DukeMTMCAttribute import import_DukeMTMCAttribute_binary | ||
from .import_MarketDuke import import_MarketDuke | ||
from .import_MarketDuke_nodistractors import import_MarketDuke_nodistractors | ||
from .pytorch_prepare import pytorch_prepare | ||
from .pytorch_prepare import pytorch_prepare_all | ||
from .marketduke_to_hdf5 import marketduke_to_hdf5 | ||
from .cuhk03_to_image import cuhk03_to_image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import warnings | ||
warnings.filterwarnings('ignore','.*conversion.*') | ||
|
||
import os | ||
import zipfile | ||
import shutil | ||
import requests | ||
import h5py | ||
import numpy as np | ||
from PIL import Image | ||
import argparse | ||
|
||
def cuhk03_to_image(CUHK03_dir): | ||
|
||
f = h5py.File(os.path.join(CUHK03_dir,'cuhk-03.mat')) | ||
|
||
detected_labeled = ['detected','labeled'] | ||
print('converting') | ||
for data_type in detected_labeled: | ||
|
||
datatype_dir = os.path.join(CUHK03_dir, data_type) | ||
if not os.path.exists(datatype_dir): | ||
os.makedirs(datatype_dir) | ||
|
||
for campair in range(len(f[data_type][0])): | ||
campair_dir = os.path.join(datatype_dir,'P%d'%(campair+1)) | ||
cam1_dir = os.path.join(campair_dir,'cam1') | ||
cam2_dir = os.path.join(campair_dir,'cam2') | ||
|
||
if not os.path.exists(campair_dir): | ||
os.makedirs(campair_dir) | ||
if not os.path.exists(cam1_dir): | ||
os.makedirs(cam1_dir) | ||
if not os.path.exists(cam2_dir): | ||
os.makedirs(cam2_dir) | ||
|
||
for img_no in range(f[f[data_type][0][campair]].shape[0]): | ||
if img_no < 5: | ||
cam_dir = 'cam1' | ||
else: | ||
cam_dir = 'cam2' | ||
for person_id in range(f[f[data_type][0][campair]].shape[1]): | ||
img = np.array(f[f[f[data_type][0][campair]][img_no][person_id]]) | ||
if img.shape[0] !=2: | ||
img = np.transpose(img, (2,1,0)) | ||
im = Image.fromarray(img) | ||
im.save(os.path.join(campair_dir, cam_dir, "%d-%d.jpg"%(person_id+1,img_no+1))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import requests | ||
|
||
def gdrive_downloader(destination, id): | ||
URL = "https://docs.google.com/uc?export=download" | ||
|
||
session = requests.Session() | ||
|
||
response = session.get(URL, params = { 'id' : id }, stream = True) | ||
token = get_confirm_token(response) | ||
|
||
if token: | ||
params = { 'id' : id, 'confirm' : token } | ||
response = session.get(URL, params = params, stream = True) | ||
|
||
save_response_content(response, destination) | ||
|
||
def get_confirm_token(response): | ||
for key, value in response.cookies.items(): | ||
if key.startswith('download_warning'): | ||
return value | ||
|
||
return None | ||
|
||
def save_response_content(response, destination): | ||
CHUNK_SIZE = 32768 | ||
|
||
with open(destination, "wb") as f: | ||
for chunk in response.iter_content(CHUNK_SIZE): | ||
if chunk: # filter out keep-alive new chunks | ||
f.write(chunk) | ||
|
||
if __name__ == "__main__": | ||
var = raw_input("Please enter public file id : ") | ||
file_id = str(var) | ||
name = raw_input("Please enter name with extension : ") | ||
destination = str(name) | ||
gdrive_downloader(file_id, destination) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import os | ||
from .reiddataset_downloader import * | ||
def import_CUHK01(dataset_dir): | ||
cuhk01_dir = os.path.join(dataset_dir,'CUHK01') | ||
|
||
if not os.path.exists(cuhk01_dir): | ||
print('Please Download the CUHK01 Dataset') | ||
|
||
file_list=os.listdir(cuhk01_dir) | ||
name_dict={} | ||
for name in file_list: | ||
if name[-3:]=='png': | ||
id = name[:4] | ||
if id not in name_dict: | ||
name_dict[id]=[] | ||
name_dict[id].append([]) | ||
name_dict[id].append([]) | ||
if int(name[-7:-4])<3: | ||
name_dict[id][0].append(os.path.join(cuhk01_dir,name)) | ||
else: | ||
name_dict[id][1].append(os.path.join(cuhk01_dir,name)) | ||
return name_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import os | ||
from .reiddataset_downloader import * | ||
def import_CUHK03(dataset_dir, detected = False): | ||
|
||
cuhk03_dir = os.path.join(dataset_dir,'CUHK03') | ||
|
||
if not os.path.exists(cuhk03_dir): | ||
Print('Please Download the CUHK03 Dataset') | ||
|
||
if not detected: | ||
cuhk03_dir = os.path.join(cuhk03_dir , 'labeled') | ||
else: | ||
cuhk03_dir = os.path.join(cuhk03_dir , 'detected') | ||
|
||
campair_list = os.listdir(cuhk03_dir) | ||
#campair_list = ['P1','P2','P3'] | ||
name_dict={} | ||
for campair in campair_list: | ||
cam1_list = [] | ||
cam1_list=os.listdir(os.path.join(cuhk03_dir,campair,'cam1')) | ||
cam2_list=os.listdir(os.path.join(cuhk03_dir,campair,'cam2')) | ||
for file in cam1_list: | ||
id = campair[1:]+'-'+file.split('-')[0] | ||
if id not in name_dict: | ||
name_dict[id]=[] | ||
name_dict[id].append([]) | ||
name_dict[id].append([]) | ||
name_dict[id][0].append(os.path.join(cuhk03_dir,campair,'cam1',file)) | ||
for file in cam2_list: | ||
id = campair[1:]+'-'+file.split('-')[0] | ||
if id not in name_dict: | ||
name_dict[id]=[] | ||
name_dict[id].append([]) | ||
name_dict[id].append([]) | ||
name_dict[id][1].append(os.path.join(cuhk03_dir,campair,'cam2',file)) | ||
return name_dict | ||
|
||
def cuhk03_test(data_dir): | ||
CUHK03_dir = os.path.join(data_dir , 'CUHK03') | ||
f = h5py.File(os.path.join(CUHK03_dir,'cuhk-03.mat')) | ||
test = [] | ||
for i in range(20): | ||
test_set = (np.array(f[f['testsets'][0][i]],dtype='int').T).tolist() | ||
test.append(test_set) | ||
|
||
return test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import os | ||
from .reiddataset_downloader import* | ||
def import_DukeMTMC(dataset_dir): | ||
dukemtmc_dir = os.path.join(dataset_dir, 'DukeMTMC-reID') | ||
if not os.path.exists(dukemtmc_dir): | ||
print('Please Download the DukMTMC Dataset') | ||
data_group = ['train','query','gallery'] | ||
for group in data_group: | ||
if group == 'train': | ||
name_dir = os.path.join(dukemtmc_dir , 'bounding_box_train') | ||
elif group == 'query': | ||
name_dir = os.path.join(dukemtmc_dir, 'query') | ||
else: | ||
name_dir = os.path.join(dukemtmc_dir, 'bounding_box_test') | ||
file_list=os.listdir(name_dir) | ||
globals()[group]={} | ||
for name in file_list: | ||
if name[-3:]=='jpg': | ||
id = name.split('_')[0] | ||
if id not in globals()[group]: | ||
globals()[group][id]=[] | ||
globals()[group][id].append([]) | ||
globals()[group][id].append([]) | ||
globals()[group][id].append([]) | ||
globals()[group][id].append([]) | ||
globals()[group][id].append([]) | ||
globals()[group][id].append([]) | ||
globals()[group][id].append([]) | ||
globals()[group][id].append([]) | ||
cam_n = int(name.split('_')[1][1])-1 | ||
globals()[group][id][cam_n].append(os.path.join(name_dir,name)) | ||
return train, query, gallery |
Oops, something went wrong.