Showing 12 changed files with 529 additions and 34 deletions.
.gitignore
@@ -0,0 +1,6 @@
.git
__pycache__
.ipynb_checkpoints

models
runs
README.md
@@ -0,0 +1,30 @@
# Q-Trader

**Use at your own risk**

A PyTorch implementation of [q-trader](https://github.com/edwardhdlu/q-trader).

## Results

Some examples of results on test sets:

![HSI2018](images/%5EHSI_2018.png)
Starting capital: $100,000.
HSI, 2017-2018. Profit of $10,702.13.

## Running the Code

To train the model, download training and test CSV files from [Yahoo! Finance](https://ca.finance.yahoo.com/quote/%5EGSPC/history?p=%5EGSPC) into `data/`:
```
mkdir models
python train ^GSPC 10 1000
```

Then, when training finishes (minimum 200 episodes for results), open the visualization notebook:
```
jupyter notebook -> visualize.ipynb
```

## References

[Deep Q-Learning with Keras and Gym](https://keon.io/deep-q-learning/) - Q-learning overview and Agent skeleton code
Empty file.
agent/agent.py
@@ -0,0 +1,88 @@
from agent.memory import Transition, ReplayMemory
from agent.model import DQN

import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
# print(device)


class Agent:
    def __init__(self, state_size, is_eval=False):
        self.state_size = state_size  # normalized previous days
        self.action_size = 3  # sit, buy, sell
        self.memory = ReplayMemory(10000)
        self.inventory = []
        self.is_eval = is_eval

        self.gamma = 0.95
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.99995
        self.batch_size = 32
        if os.path.exists('models/target_model'):
            self.policy_net = torch.load('models/policy_model', map_location=device)
            self.target_net = torch.load('models/target_model', map_location=device)
        else:
            self.policy_net = DQN(state_size, self.action_size)
            self.target_net = DQN(state_size, self.action_size)
        self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=0.005, momentum=0.9)

    def act(self, state):
        if not self.is_eval and np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)

        tensor = torch.FloatTensor(state).to(device)
        options = self.target_net(tensor)
        return np.argmax(options[0].detach().numpy())

    def optimize(self):
        if len(self.memory) < self.batch_size:
            return
        transitions = self.memory.sample(self.batch_size)
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for a
        # detailed explanation). This converts a batch-array of Transitions
        # to a Transition of batch-arrays.
        batch = Transition(*zip(*transitions))

        # Compute a mask of non-final states and concatenate the batch elements
        # (a final state would've been the one after which the simulation ended).
        next_state = torch.FloatTensor(batch.next_state).to(device)
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, next_state)))
        non_final_next_states = torch.cat([s for s in next_state if s is not None])
        state_batch = torch.FloatTensor(batch.state).to(device)
        action_batch = torch.LongTensor(batch.action).to(device)
        reward_batch = torch.FloatTensor(batch.reward).to(device)

        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of actions taken. These are the actions which would've been taken
        # for each batch state according to policy_net.
        # print(state_batch.shape, action_batch.shape, self.batch_size)
        state_action_values = self.policy_net(state_batch).reshape((self.batch_size, 3)).gather(1, action_batch.reshape((self.batch_size, 1)))

        # Compute V(s_{t+1}) for all next states.
        # Expected values of actions for non_final_next_states are computed based
        # on the "older" target_net; selecting their best reward with max(1)[0].
        # This is merged based on the mask, so that we have either the expected
        # state value or 0 in case the state was final.
        next_state_values = torch.zeros(self.batch_size, device=device)
        next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach()
        # Compute the expected Q values
        expected_state_action_values = (next_state_values * self.gamma) + reward_batch

        # Compute Huber loss
        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
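This commit does not include the loop that drives `Agent`, so the following is only a rough sketch of how `act`, `memory.push`, and `optimize` could be wired together; `optimize()` regresses Q_policy(s, a) toward the usual one-step target r + gamma * max_a' Q_target(s', a') with a Huber loss. The synthetic price series, the toy reward, and the epsilon-decay step below are illustrative assumptions, not code from this repository:
```
import numpy as np
from agent.agent import Agent

state_size = 5
agent = Agent(state_size)

# Fake random-walk prices just to exercise the API (assumption, not real data).
prices = np.cumsum(np.random.randn(300)) + 100.0
get_state = lambda t: prices[t:t + state_size].reshape(1, state_size)

for t in range(len(prices) - state_size - 1):
    state = get_state(t)
    action = agent.act(state)                 # 0 = sit, 1 = buy, 2 = sell
    next_state = get_state(t + 1)
    # Toy reward: next-day price change if we "bought", else 0 (assumption).
    reward = float(prices[t + 1] - prices[t]) if action == 1 else 0.0
    agent.memory.push(state, action, next_state, reward)
    agent.optimize()                          # one gradient step once 32 transitions exist
    agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)
```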
agent/memory.py
@@ -0,0 +1,26 @@
import random
from collections import namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
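To make the ring-buffer behavior of `ReplayMemory` concrete, here is a tiny self-contained example (the transition values are invented):
```
from agent.memory import ReplayMemory

mem = ReplayMemory(3)
for i in range(5):
    # push(state, action, next_state, reward); once full, the oldest slot is overwritten
    mem.push([float(i)], i % 3, [float(i + 1)], 0.1 * i)

print(len(mem))       # 3 -- capacity caps the buffer size
print(mem.sample(2))  # two randomly chosen Transition namedtuples
```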
agent/model.py
@@ -0,0 +1,18 @@
import torch
import torch.nn as nn


class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(state_size, 64),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(64, 32),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(32, 8),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(8, action_size),
        )

    def forward(self, input):
        return self.main(input)
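A quick smoke test of the network above (the 5-feature state size is arbitrary, chosen only for illustration):
```
import torch
from agent.model import DQN

net = DQN(state_size=5, action_size=3)
dummy_state = torch.randn(1, 5)                # batch of one 5-feature state
q_values = net(dummy_state)                    # shape (1, 3): one Q-value per action
print(q_values.shape, q_values.argmax(dim=1))  # greedy action index
```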
@@ -0,0 +1,164 @@
import torch
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import numpy as np

from agent.agent import Agent
from functions import *
import pandas as pd

# stock_name = '^HSI_2018'
window_size = 5

'''
agent = Agent(window_size, True)
data, adj_close = getStockDataVec(stock_name)
l = len(data) - 1
batch_size = 32
episode_count = 1000
'''
action_map = {0: "HOLD", 1: "BUY", 2: "SELL"}


def DQN(stock_table=None, money=None, inc=None, original_shares=None, commission=None):

    window_size = 3
    cash, num_shares = 100_000, 0
    sh = 50
    agent = Agent(window_size)
    data, adj_close, date = getStockDataVec(path="../data/test_dqn_data.csv")
    l = len(adj_close) - 1
    batch_size = 32
    episode_count = 5

    closes = []
    buys = []
    sells = []

    final_vals, actions, shares, cashes, dates = [], [], [], [], []

    episode_count = 1
    for e in range(episode_count):
        closes = []
        buys = []
        sells = []

        state = getState(data, 0, window_size + 1)
        total_profit = 0
        agent.inventory = []

        # capital = 100000
        for t in range(l):
            # action = agent.act(state)
            action = np.random.randint(0, 3)
            closes.append(data[t])

            # sit
            next_state = getState(data, t + 1, window_size + 1)
            reward = 0
            '''
            if action == 1:  # buy
                if capital > adj_close[t]:
                    agent.inventory.append(adj_close[t])
                    buys.append(adj_close[t])
                    sells.append(None)
                    capital -= adj_close[t]
                else:
                    buys.append(None)
                    sells.append(None)
            elif action == 2:  # sell
                if len(agent.inventory) > 0:
                    bought_price = agent.inventory.pop(0)
                    reward = max(adj_close[t] - bought_price, 0)
                    total_profit += adj_close[t] - bought_price
                    buys.append(None)
                    sells.append(adj_close[t])
                    capital += adj_close[t]
                else:
                    buys.append(None)
                    sells.append(None)
            elif action == 0:
                buys.append(None)
                sells.append(None)
            '''

            next_adj_close = adj_close[t + 1]
            current_adj_close = adj_close[t]

            # get reward
            if action == 0:  # hold
                if num_shares > 0:
                    next_cash = cash  # no change
                    reward = (cash + num_shares * next_adj_close) - (cash + num_shares * current_adj_close)
                else:
                    reward = 0

            if action == 1:  # buy
                if cash > sh * current_adj_close:
                    next_cash = cash - sh * current_adj_close
                    # reward = (cash - current_adj_close + ((num_shares+1)*next_adj_close)) - (cash + num_shares*current_adj_close)
                    reward = (next_cash + ((num_shares + sh) * next_adj_close)) - (cash + num_shares * current_adj_close)
                    num_shares += sh
                    cash = next_cash
                else:
                    reward = 0

            if action == 2:  # sell
                if num_shares > 0:
                    next_cash = cash + sh * current_adj_close
                    # reward = (cash + current_adj_close + ((num_shares-1)*next_adj_close)) - (cash + num_shares*current_adj_close)
                    reward = (next_cash + ((num_shares - sh) * next_adj_close)) - (cash + num_shares * current_adj_close)
                    num_shares -= sh
                    cash = next_cash
                else:
                    reward = 0

            done = True if t == l - 1 else False
            agent.memory.push(state, action, next_state, reward)
            state = next_state

            '''
            if done:
                print("--------------------------------")
                print(" Total Profit: " + formatPrice(total_profit))
                print(" Total Shares: ", )
                print("--------------------------------")
            '''
            if done:
                print("--------------------------------")
                print("Total Profit: " + formatPrice(total_profit))
                print("Total Reward: ", reward)
                print("Total shares: ", num_shares)
                print("Total cash: ", cash)
                print("--------------------------------")

            cur_cash, cur_shares = cash, num_shares
            final_vals.append(cur_cash + (cur_shares * adj_close[t]))
            cashes.append(cur_cash)
            actions.append(action_map[action])
            shares.append(num_shares)
            dates.append(date[t])

    cashes = pd.Series(cashes, index=pd.to_datetime(dates))
    shares = pd.Series(shares, index=pd.to_datetime(dates))
    actions = pd.Series(actions, index=pd.to_datetime(dates))
    final_vals = pd.Series(final_vals, index=pd.to_datetime(dates))

    results = {'final_vals': final_vals, 'actions': actions, 'shares': shares, 'cash': cashes}

    return results


if __name__ == "__main__":
    res = DQN()
    dic = pd.DataFrame(res)
    print(dic)
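The reward in each action branch above is the one-step change in total portfolio value (cash plus shares marked to the next close). A quick numeric check of the buy branch, using made-up prices:
```
# Made-up numbers: a 50-share lot (sh), price moving from 100 to 102.
cash, num_shares, sh = 100_000, 0, 50
current_adj_close, next_adj_close = 100.0, 102.0

next_cash = cash - sh * current_adj_close                    # 95,000 left after buying
reward = (next_cash + (num_shares + sh) * next_adj_close) \
    - (cash + num_shares * current_adj_close)                # change in portfolio value
print(reward)                                                # 100.0 = 50 shares * $2 move
```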
functions.py
@@ -0,0 +1,23 @@
import numpy as np
import math

# prints formatted price
def formatPrice(n):
    return ("-$" if n < 0 else "$") + "{0:.2f}".format(abs(n))


def getStockDataVec(key=None, path=None):
    vec, states, date = [], [], []
    lines = open(path, "r").read().splitlines()
    for line in lines[5:]:
        row = line.split(",")
        close = row[6]
        if close != 'null':
            date.append(row[0])
            vec.append(float(row[6]))
            states.append(list(map(float, row[11:13])))

    return states, vec, date


# returns an n-day state representation ending at time t
def getState(data, t, n):
    return np.array([data[t]])
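A short illustration of these helpers (the two-value rows mimic what `getStockDataVec` builds from `row[11:13]`; the numbers are invented):
```
from functions import formatPrice, getState

print(formatPrice(-1234.5))       # -$1234.50

data = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]   # per-day feature rows
state = getState(data, 1, 3)
print(state, state.shape)         # [[0.3 0.4]] (1, 2) -- only day t is used; n is currently ignored
```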