-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTradeEnv.py
421 lines (404 loc) · 19.2 KB
/
TradeEnv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
import gym
import pandas as pd
import plotly as py
import plotly.graph_objs as go
from gym import spaces
from datetime import datetime
from Util.Util import *
import numpy as np
import wandb
"""
日间择时,开盘或收盘交易
"""
# noinspection PyAttributeOutsideInit
class TradeEnv(gym.Env):
def __init__(self, stock_data_path, start_episode=0, episode_len=720, obs_time_size='60 day',
             obs_delta_frequency='1 day',
             sim_delta_time='1 day', stock_codes='000938_XSHE',
             result_path="E:/运行结果/train/", principal=1e5, origin_stock_amount=0, poundage_rate=5e-3,
             time_format="%Y-%m-%d", auto_open_result=False, reward_verbose=1,
             post_processor=None, start_index_bound=None, end_index_bound=None, trade_time='open', mode='test',
             agent_state=True):
    """
    Configure the environment and preload the data of every stock code.

    :param stock_data_path: path prefix of the CSV files; ``<code>_with_indicator.csv`` is appended
    :param start_episode: initial value of the episode counter
    :param episode_len: episode length in steps
    :param obs_time_size: observation look-back such as ``'60 day'``; only the leading integer is
        used (the last four characters are dropped)
    :param obs_delta_frequency: spacing between observed rows, same ``'<n> day'`` format
    :param sim_delta_time: number of rows the clock advances per step, same ``'<n> day'`` format
    :param stock_codes: one stock code or a sequence of stock codes
    :param result_path: directory prefix where the HTML charts are written
    :param principal: initial cash
    :param origin_stock_amount: initial number of shares held
    :param poundage_rate: commission rate applied to every trade
    :param time_format: ``strptime`` format of the first CSV column
    :param auto_open_result: whether plotly opens the generated charts automatically
    :param reward_verbose: 0 draws no reward chart, 1 overwrites a single ``reward.html``,
        any other value writes one reward chart per episode
    :param post_processor: optional callable ``post_processor(state, agent_state)`` applied to
        every observation
    :param start_index_bound: lowest row index an episode may start from; must not be smaller
        than the look-back size
    :param end_index_bound: offset added to the number of rows to get the exclusive upper bound
        of the start index (presumably negative - TODO confirm with callers)
    :param trade_time: ``'open'`` or ``'close'``, selects the price column used for trading
    :param mode: ``'train'``, ``'test'`` or ``'eval'``; train mode logs charts to wandb
    :param agent_state: whether cash and position are appended to the observation
    :return:
    """
    super(TradeEnv, self).__init__()
    # Validate the cheap options first so a typo fails before any CSV is parsed.
    assert trade_time == "open" or trade_time == "close", "trade_time must be 'open' or 'close'"
    assert mode == "train" or mode == "test" or mode == "eval", "mode must be 'train', 'test' or 'eval'"
    # A bare string (the default) would be iterated character by character below and
    # cannot be sampled by np.random.choice, so wrap it into a one-element list.
    if isinstance(stock_codes, str):
        stock_codes = [stock_codes]
    self.delta_time = int(sim_delta_time[0:-4])
    self.stock_codes = stock_codes
    self.stock_data_path = stock_data_path
    self.result_path = result_path
    self.principal = principal
    self.origin_stock_amount = origin_stock_amount
    self.poundage_rate = poundage_rate
    self.time_format = time_format
    self.auto_open_result = auto_open_result
    self.episode = start_episode
    self.episode_len = episode_len
    # read_stock_data relies on stock_data_path and time_format, both set above.
    self.stock_datas = {code: self.read_stock_data(code) for code in self.stock_codes}
    self.reward_verbose = reward_verbose
    self.obs_time = int(obs_time_size[0:-4])
    self.obs_delta_frequency = int(obs_delta_frequency[0:-4])
    self.action_space = spaces.Box(low=np.array([-1]), high=np.array([1]))
    self.agent_state = agent_state
    # 26 values per observed row are hard-coded here; assumes every CSV row carries
    # 26 feature columns after the date - TODO confirm against the data files.
    feature_count = 26 * (self.obs_time // self.obs_delta_frequency)
    if agent_state:
        # Cash and position are appended to the features and bounded below by zero.
        low = np.array([float('-inf')] * feature_count + [0, 0])
        high = np.array([float('inf')] * (feature_count + 2))
    else:
        low = np.array([float('-inf')] * feature_count)
        high = np.array([float('inf')] * feature_count)
    self.observation_space = spaces.Box(low=low, high=high)
    self.step_ = 0
    self.post_processor = post_processor
    self.trade_time = trade_time
    self.mode = mode
    # An episode cannot start before a full look-back window is available.
    self.start_index_bound = self.obs_time
    if start_index_bound is not None:
        assert self.start_index_bound <= start_index_bound
        self.start_index_bound = start_index_bound
    self.customize_end_index_bound = end_index_bound
def seed(self, seed=None):
    """
    Seed NumPy's global random generator, which ``reset`` uses for sampling.

    :param seed: value forwarded unchanged to ``np.random.seed``
    :return: one-element list holding the seed, following the gym ``Env.seed`` convention
    """
    np.random.seed(seed)
    return [seed]
def reset(self):
    """
    Start a new episode and return its first observation.

    Picks a stock (round-robin by episode number in eval mode, at random otherwise),
    draws a random start date between the index bounds and restores cash, position
    and trade history to their initial values.

    :return: the observation produced by ``get_state``
    """
    codes = self.stock_codes
    if self.mode != 'eval':
        # Pick one stock at random.
        self.stock_code = np.random.choice(codes)
    else:
        self.stock_code = codes[self.episode % len(codes)]
    # Fetch the data preloaded by the constructor.
    self.stock_data, self.keys = self.stock_datas[self.stock_code]

    # Exclusive upper bound of the start index.
    row_count = len(self.stock_data)
    offset = self.customize_end_index_bound
    self.end_index_bound = row_count - self.episode_len if offset is None else row_count + offset
    assert self.start_index_bound < self.end_index_bound

    # Draw the start date at random from the allowed window.
    candidates = np.array(list(self.stock_data))[self.start_index_bound:self.end_index_bound]
    self.current_time = np.random.choice(candidates, 1)[0]
    self.index = self.keys.index(self.current_time)
    self.episode_end_index = min(self.index + self.episode_len, len(self.keys) - 1)

    self.done = False
    self.money = self.principal
    # Number of shares held.
    self.stock_amount = self.origin_stock_amount
    self.trade_history = []
    self.episode += 1
    self.step_ = 0
    self.start_time = self.current_time
    return self.get_state()
def step(self, action):
    """
    Execute one trading decision and advance the clock.

    A positive action buys that fraction of the lots the current cash can afford
    (commission included); a negative action sells that fraction of the lots held;
    zero holds. One lot is 100 shares. A trade that would drive cash or position
    below zero is cancelled.

    :param action: scalar-like value, squeezed with ``np.squeeze`` before use
    :return: tuple ``(state, reward, done, info)`` where ``info`` is an empty dict
    """
    # Count this step first so the episode ends after exactly episode_len steps.
    self.step_ += 1
    if self.step_ >= self.episode_len:
        self.done = True

    # Remember when the trade happens; set_next_day moves current_time afterwards.
    trade_time = self.current_time
    # Column 1 is the price used for 'close' trading, column 0 for 'open' trading.
    price = self.stock_data[self.current_time][1 if self.trade_time == 'close' else 0]

    action = np.squeeze(action)
    quant = 0
    if action > 0:
        # Lots the current cash can pay for, commission included.
        affordable_lots = self.money // (100 * price * (1 + self.poundage_rate))
        quant = int(action * affordable_lots)
    elif action < 0:
        # Lots currently held; quant is negative for a sale.
        held_lots = self.stock_amount / 100
        quant = int(action * held_lots)

    # Snapshot so an invalid trade can be rolled back.
    cash_before = self.money
    shares_before = self.stock_amount
    # cash -= price * 100 * lots + commission (commission is charged on buys and sells).
    turnover = price * 100 * quant
    self.money = self.money - turnover - abs(turnover * self.poundage_rate)
    if self.money < 0:
        # Not enough cash: cancel the trade.
        self.money = cash_before
        quant = 0
    else:
        self.stock_amount += 100 * quant
        if self.stock_amount < 0:
            # Sold more than held: cancel the trade.
            self.money = cash_before
            self.stock_amount = shares_before
            quant = 0

    self.set_next_day()
    # Append the row with an empty reward first because get_reward reads the history.
    self.trade_history.append(
        [trade_time, price, quant, self.stock_amount, self.money, None, action])
    reward = self.get_reward()
    self.trade_history[-1][5] = reward
    return self.get_state(), reward, self.done, {}
def get_value(self, last_hist):
    """
    Compute the total asset value stored in one trade-history row.

    :param last_hist: history row; index 1 is the price, 3 the shares held, 4 the cash
    :return: price * shares + cash
    """
    price, shares, cash = last_hist[1], last_hist[3], last_hist[4]
    return price * shares + cash
def get_reward(self):
    """
    Score the latest step as excess return over simply holding the stock.

    The reward is the percentage change of total assets since the previous history
    row minus the percentage change of the price over the same span. On the first
    step the baseline is the initial cash plus the initial shares valued at the
    episode's starting price.

    :return: (asset return - price return) * 100
    """
    history = self.trade_history
    current_row = history[-1]
    now_value = self.get_value(current_row)
    now_price = current_row[1]

    if len(history) < 2:
        # First step: compare against the starting portfolio.
        if self.trade_time == 'close':
            last_price = self.stock_data[self.start_time][1]
        elif self.trade_time == 'open':
            last_price = self.stock_data[self.start_time][0]
        last_value = self.principal + self.origin_stock_amount * last_price
    else:
        previous_row = history[-2]
        last_value = self.get_value(previous_row)
        last_price = previous_row[1]

    # Avoid dividing by a zero baseline.
    if last_value == 0:
        last_value = self.principal

    asset_return = (now_value - last_value) / last_value
    price_return = (now_price - last_price) / last_price
    return (asset_return - price_return) * 100
def render(self, mode='auto'):
    """
    Draw the episode charts on demand or on every 20th step.

    :param mode: ``'manual'`` always draws; any other value draws only when the step
        counter is a non-zero multiple of 20
    :return: whatever ``draw`` returns, or ``None`` when nothing is drawn
    """
    if mode == 'manual':
        return self.draw()
    if self.step_ != 0 and self.step_ % 20 == 0:
        return self.draw()
    return None
def get_last_time(self):
    """
    Look up the row just before the current one.

    :return: tuple ``(key, index)`` of the previous row
    :raises AssertionError: when the current index is 0 and no previous row exists
    """
    previous = self.index - 1
    assert previous >= 0
    return self.keys[previous], previous
def set_next_day(self):
    """
    Move the clock forward by ``delta_time`` rows, or end the episode when the data runs out.
    """
    target = self.index + self.delta_time
    if target > len(self.keys) - 1:
        # No row left to move to: flag the episode as finished and stay in place.
        self.done = True
        return
    self.current_time = self.keys[target]
    self.index = target
def read_stock_data(self, stock_code):
    """
    Load one stock's indicator CSV into a date-keyed lookup.

    Rows containing any missing value are dropped and the remaining table is passed
    through ``fill_inf`` before it is split into date and feature parts.

    :param stock_code: code inserted between the data path prefix and ``_with_indicator.csv``
    :return: tuple ``(rows_by_date, dates)``: a dict mapping each parsed date to the rest
        of its row, and the list of dates in file order
    """
    csv_path = self.stock_data_path + stock_code + '_with_indicator.csv'
    frame = pd.read_csv(csv_path, index_col=False).dropna(axis=0, how='any')
    rows = fill_inf(np.array(frame))
    rows_by_date = {}
    dates = []
    for row in rows:
        # The first column holds the date string; everything after it is feature data.
        stamp = datetime.strptime(str(row[0]), self.time_format)
        rows_by_date[stamp] = row[1:]
        dates.append(stamp)
    return rows_by_date, dates
def get_state(self):
    """
    Assemble the observation from the rows preceding the current index.

    Starting at the previous row, ``obs_time // obs_delta_frequency`` rows are collected
    while stepping backwards by ``obs_delta_frequency``; they are then put in
    chronological order and flattened to float32. Running past the first row reuses
    the first row and marks the episode as done.

    :return: flat array, with cash and position appended when ``agent_state`` is set
        and passed through ``post_processor`` when one is configured
    """
    # Look back from the previous row's data.
    stamp, cursor = self.get_last_time()
    if stamp is None:
        self.done = True
        return None

    window = []
    for _ in range(self.obs_time // self.obs_delta_frequency):
        window.append(self.stock_data[stamp].tolist())
        cursor -= self.obs_delta_frequency
        if cursor < 0:
            # History exhausted: fall back to the first row and end the episode.
            stamp = self.keys[0]
            self.done = True
        else:
            stamp = self.keys[cursor]

    # Rows were gathered newest-first; flip so the oldest row comes first.
    state = np.flip(window, axis=0).astype(np.float32).flatten()
    if self.agent_state:
        state = np.append(state, np.array([self.money, self.stock_amount]))
    if self.post_processor is not None:
        state = self.post_processor(state, self.agent_state)
    return state
def draw(self):
    """
    Write the episode charts to HTML and, in train mode, upload them to wandb.

    The episode chart overlays the strategy return, the buy-and-hold return, the price,
    the traded lots and the position. Unless ``reward_verbose`` is 0 a second chart with
    the per-step reward is written as well. In train mode the charts use a black
    background, are logged to wandb and are then deleted from disk.

    :return: tuple ``(profit_list, base_list)`` with the strategy and buy-and-hold return
        series, or ``None`` when fewer than two steps have been recorded
    """
    # Imported locally so this method does not depend on ``os`` leaking in via a star import.
    import os

    if len(self.trade_history) <= 1:
        return None

    # History columns: 0 time, 1 price, 2 traded lots, 3 shares held, 4 cash, 5 reward.
    his = np.array(self.trade_history)
    time_list = np.squeeze(his[:, 0]).tolist()
    profit_list = np.squeeze(
        (his[:, 4].astype(np.float32) + his[:, 1].astype(np.float32) * his[:, 3].astype(
            np.float32) - self.principal) / self.principal).tolist()
    price_list = np.squeeze(his[:, 1]).tolist()
    quant_list = np.squeeze(his[:, 2]).tolist()
    amount_list = np.squeeze(his[:, 3]).tolist()
    reward_list = np.squeeze(his[:, 5]).tolist()
    price_array = np.array(price_list).astype(np.float32)
    # Buy-and-hold return relative to the first recorded price.
    base_list = ((price_array - price_array[0]) / price_array[0]).tolist()

    dis = self.result_path
    episode_path = dis + ("episode_" + str(self.episode - 1) + ".html").replace(':', "_")
    os.makedirs(dis, exist_ok=True)

    # Only train mode uses the black background; the rest of the layout is shared.
    theme = {}
    if self.mode == "train":
        theme = dict(paper_bgcolor='#000000', plot_bgcolor='#000000')

    profit_scatter = go.Scatter(x=time_list,
                                y=profit_list,
                                name='RL',
                                line=dict(color='red'),
                                mode='lines')
    base_scatter = go.Scatter(x=time_list,
                              y=base_list,
                              name='Base',
                              line=dict(color='blue'),
                              mode='lines')
    price_scatter = go.Scatter(x=time_list,
                               y=price_list,
                               name='股价',
                               line=dict(color='orange'),
                               mode='lines',
                               xaxis='x',
                               yaxis='y2',
                               opacity=1)
    trade_bar = go.Bar(x=time_list,
                       y=quant_list,
                       name='交易量(手)',
                       marker_color='#000099',
                       xaxis='x',
                       yaxis='y3',
                       opacity=0.5)
    amount_scatter = go.Scatter(x=time_list,
                                y=amount_list,
                                name='持股数量(手)',
                                line=dict(color='rgba(0,204,255,0.4)'),
                                mode='lines',
                                fill='tozeroy',
                                fillcolor='rgba(0,204,255,0.2)',
                                xaxis='x',
                                yaxis='y4',
                                opacity=0.6)
    conf = {
        "data": [profit_scatter, base_scatter, price_scatter, trade_bar,
                 amount_scatter],
        "layout": go.Layout(
            title=self.stock_code + " 回测结果" + " 初始资金:" + str(
                self.principal) + " 初始股票总量(股):" + str(
                self.origin_stock_amount),
            xaxis=dict(title='日期', type="category", showgrid=False, zeroline=False),
            yaxis=dict(title='收益率', showgrid=False, zeroline=False, titlefont={'color': 'red'},
                       tickfont={'color': 'red'}),
            yaxis2=dict(title='股价', overlaying='y', side='right',
                        titlefont={'color': 'orange'}, tickfont={'color': 'orange'},
                        showgrid=False,
                        zeroline=False),
            yaxis3=dict(title='交易量', overlaying='y', side='right',
                        titlefont={'color': '#000099'}, tickfont={'color': '#000099'},
                        showgrid=False, position=0.97, zeroline=False, anchor='free'),
            yaxis4=dict(title='持股量', overlaying='y', side='left',
                        titlefont={'color': '#00ccff'}, tickfont={'color': '#00ccff'},
                        showgrid=False, position=0.03, zeroline=False, anchor='free'),
            **theme
        )
    }
    py.offline.plot(conf, auto_open=self.auto_open_result, filename=episode_path)

    # Stays None when no reward chart is requested, so nothing stale is uploaded below.
    reward_path = None
    if self.reward_verbose != 0:
        reward_scatter = go.Scatter(x=list(range(len(reward_list))),
                                    y=reward_list,
                                    name='reward',
                                    line=dict(color='orange'),
                                    mode='lines',
                                    opacity=1)
        if self.reward_verbose == 1:
            # A single file that every episode overwrites.
            reward_path = dis + "reward.html"
        else:
            reward_path = dis + "reward_{}.html".format(self.episode - 1)
        conf = {
            "data": [reward_scatter],
            "layout": go.Layout(
                title="reward",
                xaxis=dict(title='训练次数', showgrid=False, zeroline=False, titlefont={'color': 'white'},
                           tickfont={'color': 'white'}),
                yaxis=dict(title='reward', showgrid=False, zeroline=False, titlefont={'color': 'orange'},
                           tickfont={'color': 'orange'}),
                **theme
            )
        }
        py.offline.plot(conf, auto_open=self.auto_open_result, filename=reward_path)

    if self.mode == 'train':
        uploads = [("episode", episode_path)]
        if reward_path is not None:
            uploads.append(("episode_reward", reward_path))
        for key, html_path in uploads:
            # Close the handle before deleting; removing an open file fails on Windows.
            with open(html_path) as handle:
                wandb.log({key: wandb.Html(handle)}, sync=False)
            os.remove(html_path)
    return profit_list, base_list