forked from layumi/Person_reID_baseline_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepare_mydata_rand.py
102 lines (90 loc) · 3.15 KB
/
prepare_mydata_rand.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import shutil
from shutil import copyfile
import re
import numpy as np
import math
# Root of the raw dataset; every path below is built from it.
download_path = '../my_data'

# Raw inputs.
train_path = '{}/train_set'.format(download_path)
# Lines of the form "train/<name> <ID>" (parsed further down in this script).
tag_file = '{}/train_list.txt'.format(download_path)
# The next three paths are not used elsewhere in this script.
query_path_org = '{}/query_a'.format(download_path)
gallery_path_org = '{}/gallery_a'.format(download_path)
query_list = '{}/query_a_list.txt'.format(download_path)

# Output tree; the four split directories all live under pytorch_path.
pytorch_path = '{}/pytorch'.format(download_path)
train_save_path = '{}/train_all'.format(pytorch_path)
val_save_path = '{}/val'.format(pytorch_path)
query_save_path = '{}/query'.format(pytorch_path)
gallery_save_path = '{}/gallery'.format(pytorch_path)
# (Re)create the output tree. WARNING: an existing pytorch_path directory is
# deleted together with everything inside it before being recreated.
# os.mkdir (not makedirs) is used, so download_path itself must already exist.
if os.path.isdir(pytorch_path):
    shutil.rmtree(pytorch_path)
os.mkdir(pytorch_path)

for split_dir in (train_save_path, val_save_path, query_save_path, gallery_save_path):
    if not os.path.isdir(split_dir):
        os.mkdir(split_dir)
MAX_LINE = 25000  # at most this many lines of tag_file are read; the rest are ignored
N1 = 300  # not used by the active code in this script
# train_all & val & query & gallery
#
# Parse tag_file into train_stat = {ID: [image_count, [image names]]}.
# Every non-empty line must match "train/<name> <ID>"; only names ending in
# 'png' are kept. A run of consecutive lines sharing an ID forms one group, and
# groups with fewer than drop_limit images are discarded.
# NOTE(review): grouping is by consecutive runs, so this assumes the list is
# ordered by ID. If an ID reappears later and that later run is large enough,
# it replaces the earlier entry rather than merging with it. Confirm.
train_stat = {}
img_list = []   # names collected for the current run
ID_last = None  # ID of the current run; None before the first valid line
num = 0         # number of images in the current run
drop_limit = 3
line_pattern = re.compile(r"train/(.+?) (.*)")
with open(tag_file, 'r') as fp:  # guarantees the file is closed
    for line_no, line in enumerate(fp):
        if line_no >= MAX_LINE:
            break
        line = line.strip().replace('\r', '').replace('\n', '')
        if line == "":
            continue
        match = line_pattern.search(line)
        if match is None:
            raise ValueError('{}: cannot parse line {}: {!r}'.format(
                tag_file, line_no + 1, line))
        name, ID = match.groups()
        if name == "" or ID == "" or not name.endswith('png'):
            continue
        if ID != ID_last:
            # A new run starts: keep the finished run if it is large enough.
            if ID_last is not None and num >= drop_limit:
                train_stat[ID_last] = [num, img_list]
            ID_last = ID
            num = 1
            img_list = [name]
        else:
            num += 1
            img_list.append(name)
# Store the final run. This must key on ID_last, not ID: ID holds whatever the
# last parsed line contained, which may be an entry skipped above (non-png).
if ID_last is not None and num >= drop_limit:
    train_stat[ID_last] = [num, img_list]
def _copy_identity(ID, names, src_root, all_root, single_root):
    """Copy one identity's images into a pair of output splits.

    Every image in ``names`` is copied from ``src_root`` into ``all_root/ID``.
    Only the first image is also copied into ``single_root/ID``. Each per-ID
    directory is created on first use.
    """
    all_dir = all_root + '/' + ID
    single_dir = single_root + '/' + ID
    for name in names:
        src_path = src_root + '/' + name
        if not os.path.isdir(all_dir):
            os.mkdir(all_dir)
        copyfile(src_path, all_dir + '/' + name)
        if not os.path.isdir(single_dir):  # only the first image lands here
            os.mkdir(single_dir)
            copyfile(src_path, single_dir + '/' + name)


# Every TEST_EVERY-th identity (1-based, in train_stat insertion order) becomes
# a test identity: all images go to gallery and the first image to query.
# Every other identity goes to train_all (all images) + val (first image).
# The split is deterministic; it depends only on the order of tag_file.
# NOTE(review): the val/query image is also copied into train_all/gallery, so
# each single-image split overlaps its partner split. Confirm this is intended.
TEST_EVERY = 5
for id_index, ID in enumerate(train_stat, start=1):
    names = train_stat[ID][1]
    if id_index % TEST_EVERY != 0:
        _copy_identity(ID, names, train_path, train_save_path, val_save_path)
    else:
        _copy_identity(ID, names, train_path, gallery_save_path, query_save_path)