-
Notifications
You must be signed in to change notification settings - Fork 21
/
shmetro.py
129 lines (120 loc) · 5.56 KB
/
shmetro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# link: https://github.com/ivechan/PVCGN
import util
import pickle
import pandas as pd
import json
# Where the converted files are written; every generated dataset file shares
# the `dataname` path prefix.
outputdir = 'output/SHMETRO'
util.ensure_dir(outputdir)  # project helper; presumably creates the directory if missing
dataurl = 'input/SHMetro/'
dataname = outputdir + '/SHMETRO'

# Gather the arrays of every split into one dict keyed '<field>_<split>',
# e.g. data['x_train'], data['ytime_val'].
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only run this script on trusted copies of the dataset.
data = {}
for category in ['train', 'val', 'test']:
    # A context manager closes the handle promptly instead of leaving it to
    # the garbage collector.
    with open(dataurl + category + '.pkl', 'rb') as split_file:
        cat_data = pickle.load(split_file)
    for field in ('x', 'xtime', 'y', 'ytime'):
        data[field + '_' + category] = cat_data[field]
# Load the three node-by-node graph matrices that become the extra columns of
# the '.rel' file. Each file is opened in a context manager so the handle is
# closed as soon as it has been read.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only run this script on trusted copies of the dataset.
with open(dataurl + 'graph_sh_cor.pkl', 'rb') as graph_file:
    cor = pickle.load(graph_file)
with open(dataurl + 'graph_sh_conn.pkl', 'rb') as graph_file:
    conn = pickle.load(graph_file)
with open(dataurl + 'graph_sh_sml.pkl', 'rb') as graph_file:
    sml = pickle.load(graph_file)
# Write the '.geo' table: one 'Point' entity per row of the connectivity
# matrix, identified by its row index. No coordinates are available, so the
# column carries the literal placeholder '[]'.
node_count = conn.shape[0]
geo = pd.DataFrame(
    [[node, 'Point', '[]'] for node in range(node_count)],
    columns=['geo_id', 'type', 'coordinates'],
)
geo.to_csv(dataname + '.geo', index=False)
# Write the '.rel' table: one row for every (origin, destination) cell of the
# connectivity matrix, carrying the matching cells of the similarity and
# correlation matrices. Rows are emitted in row-major order, so the running
# rel_id equals origin * num_destinations + destination.
num_origins, num_destinations = conn.shape[0], conn.shape[1]
rel = pd.DataFrame(
    [
        [
            origin * num_destinations + destination,
            'geo',
            origin,
            destination,
            conn[origin][destination],
            sml[origin][destination],
            cor[origin][destination],
        ]
        for origin in range(num_origins)
        for destination in range(num_destinations)
    ],
    columns=['rel_id', 'type', 'origin_id', 'destination_id', 'connection', 'similarity', 'correlation'],
)
rel.to_csv(dataname + '.rel', index=False)
# Write the '.dyna' table: one row per (entity, timestamp) holding that
# entity's inflow and outflow values.
#
# Rows are grouped by entity; within an entity the train days come first,
# then the val days, then the test days. Each day contributes
# 66 + 4 + 3 = 73 rows per entity.

# Number of consecutive samples the script treats as one day.
_SAMPLES_PER_DAY = 66
# (split name, number of days taken from that split), in output order.
_DAYS_PER_SPLIT = (('train', 62), ('val', 9), ('test', 21))


def _day_records(x, xtime, y, ytime, day, entity):
    """Yield ``(timestamp, flows)`` pairs for one entity on one day.

    ``flows[0]`` is written to the inflow column and ``flows[1]`` to the
    outflow column.

    Args:
        x, xtime: input values and their timestamps for one split.
        y, ytime: target values and their timestamps for the same split.
        day: zero-based day index within the split.
        entity: index of the entity on the third axis of ``x`` / ``y``.
    """
    start = day * _SAMPLES_PER_DAY
    # First step of each of the day's 66 samples.
    for t in range(start, start + _SAMPLES_PER_DAY):
        yield xtime[t][0], x[t][0][entity]
    # The remaining rows of the day come from the target windows of samples
    # 62 and 65.
    # NOTE(review): presumably these are the trailing time slots that no
    # sample's first input step reaches, and step 0 of sample 65 is skipped
    # because it repeats step 3 of sample 62 -- confirm against the dataset.
    t = start + 62
    for k in range(4):
        yield ytime[t][k], y[t][k][entity]
    t = start + 65
    for k in range(1, 4):
        yield ytime[t][k], y[t][k][entity]


# Resolve each split's arrays once instead of on every row.
_split_arrays = [
    (
        data['x_' + split],
        data['xtime_' + split],
        data['y_' + split],
        data['ytime_' + split],
        num_days,
    )
    for split, num_days in _DAYS_PER_SPLIT
]

dyna_id = 0
# The context manager closes the file even if a row fails to serialize.
with open(dataname + '.dyna', 'w') as dyna_file:
    dyna_file.write('dyna_id,type,time,entity_id,inflow,outflow\n')
    for i in range(data['x_train'].shape[2]):
        for x, xtime, y, ytime, num_days in _split_arrays:
            for day in range(num_days):
                for stamp, flows in _day_records(x, xtime, y, ytime, day, i):
                    # Drop anything after the first '.' in the timestamp text
                    # and append 'Z'.
                    time = str(stamp).split('.')[0] + 'Z'
                    # Fix: the target-window rows used to write column 0 for
                    # both inflow and outflow, duplicating inflow. Every row
                    # now takes outflow from column 1, as the input-window
                    # rows always did.
                    dyna_file.write(','.join([
                        str(dyna_id), 'state', time, str(i),
                        str(flows[0]), str(flows[1]),
                    ]) + '\n')
                    dyna_id += 1
# Describe the generated files in config.json: the entity types and typed
# columns of the .geo / .rel / .dyna tables, plus dataset-level settings
# under 'info'. Key order is kept as-is so the serialized JSON is unchanged.
config = {
    'geo': {
        'including_types': ['Point'],
        'Point': {},
    },
    'rel': {
        'including_types': ['geo'],
        'geo': {'connection': 'num', 'similarity': 'num', 'correlation': 'num'},
    },
    'dyna': {
        'including_types': ['state'],
        'state': {'entity_id': 'geo_id', 'inflow': 'num', 'outflow': 'num'},
    },
    'info': {
        'data_col': ['inflow', 'outflow'],
        'weight_col': ['connection'],
        'data_files': ['SHMETRO'],
        'geo_file': 'SHMETRO',
        'rel_file': 'SHMETRO',
        'output_dim': 2,
        'time_intervals': 900,  # presumably seconds (15 minutes) -- TODO confirm
        'init_weight_inf_or_zero': 'inf',
        'set_weight_link_or_dist': 'dist',
        'calculate_weight_adj': False,
        'weight_adj_epsilon': 0.1,
    },
}
# Write through a context manager so the file is flushed and closed
# deterministically rather than whenever the handle is garbage-collected.
with open(outputdir + '/config.json', 'w', encoding='utf-8') as config_file:
    json.dump(config, config_file, ensure_ascii=False)