run_model.py
#!/usr/bin/env python3
"""Run a trained DQN trading agent over a CSV of quotes and plot the
cumulative reward it collects on that data."""
import argparse
import numpy as np

import torch

from lib import environ, data, models

import matplotlib as mpl
mpl.use("Agg")  # headless backend: render straight to file, no display needed
import matplotlib.pyplot as plt

# Small exploration rate: with probability EPSILON a random action replaces
# the greedy one, so the agent cannot get stuck repeating a single action.
EPSILON = 0.02


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data", required=True, help="CSV file with quotes to run the model on")
    parser.add_argument("-m", "--model", required=True, help="Model file to load")
    parser.add_argument("-b", "--bars", type=int, default=50, help="Count of bars to feed into the model")
    parser.add_argument("-n", "--name", required=True, help="Name to use in output images")
    parser.add_argument("--commission", type=float, default=0.1, help="Commission size in percent, default=0.1")
    parser.add_argument("--conv", default=False, action="store_true", help="Use the convolution model instead of FF")
    args = parser.parse_args()

    # Build a deterministic evaluation environment: no random reset offset,
    # no episode termination on position close, reward credited every bar.
    prices = data.load_relative(args.data)
    env = environ.StocksEnv({"TEST": prices}, bars_count=args.bars, reset_on_close=False,
                            commission=args.commission, state_1d=args.conv,
                            random_ofs_on_reset=False, reward_on_close=False, volumes=False)

    # The network architecture must match the one used during training.
    if args.conv:
        net = models.DQNConv1D(env.observation_space.shape, env.action_space.n)
    else:
        net = models.SimpleFFDQN(env.observation_space.shape[0], env.action_space.n)
    # map_location forces the weights onto the CPU, so a GPU-trained
    # checkpoint can be evaluated on a machine without CUDA.
    net.load_state_dict(torch.load(args.model, map_location=lambda storage, loc: storage))

    obs = env.reset()
    start_price = env._state._cur_close()  # first close price (reaches into internal env state)

    total_reward = 0.0
    step_idx = 0
    rewards = []

    while True:
        step_idx += 1
        obs_v = torch.tensor([obs])
        out_v = net(obs_v)
        # Greedy action: index of the largest Q-value...
        action_idx = out_v.max(dim=1)[1].item()
        # ...occasionally replaced by a random action (epsilon-greedy).
        if np.random.random() < EPSILON:
            action_idx = env.action_space.sample()
        action = environ.Actions(action_idx)  # enum view of the chosen action, for readability

        obs, reward, done, _ = env.step(action_idx)
        total_reward += reward
        rewards.append(total_reward)
        if step_idx % 100 == 0:
            print("%d: reward=%.3f" % (step_idx, total_reward))
        if done:
            break

    # Plot the cumulative reward over the run and save it to an image file.
    plt.clf()
    plt.plot(rewards)
    plt.title("Total reward, data=%s" % args.name)
    plt.ylabel("Reward, %")
    plt.savefig("rewards-%s.png" % args.name)
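
# Usage sketch (a guess at a typical run; quotes.csv and model.dat are
# hypothetical placeholders for your price data and trained checkpoint):
#
#   ./run_model.py -d quotes.csv -m model.dat -n test --bars 50
#
# This prints the cumulative reward every 100 steps and writes the reward
# curve to rewards-test.png.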