# sz_taxi.py
# link: https://github.com/lehaifeng/T-GCN/tree/master/data
import json
from datetime import datetime, timedelta

import numpy as np
import pandas as pd

import util
# ---------------------------------------------------------------------------
# Paths. Every generated file shares the prefix ``output/SZ_TAXI/SZ_TAXI``.
# ---------------------------------------------------------------------------
outputdir = 'output/SZ_TAXI'
# NOTE(review): util is project-local; presumably this creates the directory
# when it does not exist yet -- confirm against util.ensure_dir.
util.ensure_dir(outputdir)

dataurl = 'input/SZ-TAXI/'
dataname = outputdir+'/SZ_TAXI'

# ---------------------------------------------------------------------------
# .geo file: one row per entity.
# ---------------------------------------------------------------------------
# Read only the first line of sz_speed.csv, without treating it as a header,
# so that its values become the list of geo ids.
geo_id_list = pd.read_csv(dataurl+'sz_speed.csv', header=None, nrows=1)
geo_id_list = np.array(geo_id_list)[0]

# Every entity is written as a 'LineString' with the placeholder
# coordinate string '[[]]'.
geo = [[geo_id, 'LineString', '[[]]'] for geo_id in geo_id_list]
geo = pd.DataFrame(geo, columns=['geo_id', 'type', 'coordinates'])
geo.to_csv(dataname+'.geo', index=False)

# ---------------------------------------------------------------------------
# .rel file: one row per (origin, destination) cell of the adjacency matrix.
# ---------------------------------------------------------------------------
sz_adj = pd.read_csv(dataurl+'sz_adj.csv', header=None)
# np.asarray rather than np.mat: the np.mat alias was removed in NumPy 2.0,
# and a plain 2-D ndarray supports the same .shape / [i, j] access used here.
adj = np.asarray(sz_adj)
rel = []
rel_id_counter = 0
for i in range(adj.shape[0]):
    for j in range(adj.shape[1]):
        # Every cell is emitted (no filtering on the weight value); row and
        # column positions are mapped to geo ids by position in geo_id_list.
        rel.append([rel_id_counter, 'geo', geo_id_list[i], geo_id_list[j], adj[i, j]])
        rel_id_counter += 1
rel = pd.DataFrame(rel, columns=['rel_id', 'type', 'origin_id', 'destination_id', 'link_weight'])
rel.to_csv(dataname + '.rel', index=False)

# ---------------------------------------------------------------------------
# Load the speed table for the .dyna file (rows are written further below).
# ---------------------------------------------------------------------------
# This time the first line is consumed as the header, so only the data rows
# remain in the array.
sz_speed = pd.read_csv(dataurl+'sz_speed.csv')
speed = np.asarray(sz_speed)
dyna = []
dyna_id_counter = 0
def num2time(num):
    """Convert a 15-minute slot index into an ISO-8601 style UTC timestamp.

    Slot 0 maps to ``2015-01-01T00:00:00Z`` and each following slot is
    15 minutes later, i.e. 96 slots per day.

    Args:
        num: Zero-based integer index of the 15-minute slot.

    Returns:
        The timestamp as a string formatted ``YYYY-MM-DDTHH:MM:SSZ``.
    """
    # Real calendar arithmetic instead of gluing digits onto '2015-01-':
    # indices past the end of January roll over into the following months
    # rather than producing non-existent dates such as '2015-01-32'.
    # int() keeps NumPy integer indices acceptable to timedelta.
    moment = datetime(2015, 1, 1) + timedelta(minutes=15 * int(num))
    return moment.strftime('%Y-%m-%dT%H:%M:%SZ')
# ---------------------------------------------------------------------------
# .dyna file: one row per (entity, time slot) reading.
# ---------------------------------------------------------------------------
# The timestamp depends only on the row index, so build the list once
# instead of recomputing it for every column.
timestamps = [num2time(i) for i in range(speed.shape[0])]

# Column-major traversal: all time slots of the first entity, then all time
# slots of the second entity, and so on.
for j in range(speed.shape[1]):
    for i in range(speed.shape[0]):
        # dyna_id, type, time, entity_id, traffic_speed
        dyna.append([dyna_id_counter, 'state', timestamps[i], geo_id_list[j], speed[i, j]])
        dyna_id_counter += 1
dyna = pd.DataFrame(dyna, columns=['dyna_id', 'type', 'time', 'entity_id', 'traffic_speed'])
dyna.to_csv(dataname + '.dyna', index=False)

# ---------------------------------------------------------------------------
# config.json: describes the three files written by this script.
# Key order is kept as-is so the serialized JSON stays unchanged.
# ---------------------------------------------------------------------------
config = {
    'geo': {
        'including_types': ['LineString'],
        'LineString': {},
    },
    'rel': {
        'including_types': ['geo'],
        'geo': {'link_weight': 'num'},
    },
    'dyna': {
        'including_types': ['state'],
        'state': {'entity_id': 'geo_id', 'traffic_speed': 'num'},
    },
    'info': {
        'data_col': ['traffic_speed'],
        'weight_col': 'link_weight',
        'data_files': ['SZ_TAXI'],
        'geo_file': 'SZ_TAXI',
        'rel_file': 'SZ_TAXI',
        'output_dim': 1,
        # 900 presumably means seconds (15 min), matching the 15-minute
        # step used by num2time -- confirm against the config consumer.
        'time_intervals': 900,
        'init_weight_inf_or_zero': 'inf',
        'set_weight_link_or_dist': 'dist',
        'calculate_weight_adj': False,
        'weight_adj_epsilon': 0.1,
    },
}

# Context manager so the file handle is flushed and closed deterministically
# (the handle was previously opened inline and never closed).
with open(outputdir+'/config.json', 'w', encoding='utf-8') as config_file:
    json.dump(config, config_file, ensure_ascii=False)