-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodel.py
24 lines (18 loc) · 1 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn
# Default compute device: the first CUDA GPU when one is usable, otherwise the CPU.
device = (
    torch.device('cuda:0')
    if torch.cuda.is_available()
    else torch.device('cpu')
)
class StockPredictor(nn.Module):
    """Single-layer LSTM followed by a linear head, predicting from a window of values.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden/cell state.
        output_size: Number of values predicted per sample.
    """

    def __init__(self, input_size=1, hidden_size=200, output_size=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, prices):
        """Predict one output per sample from its window of values.

        Args:
            prices: Tensor of shape ``(batch, window)`` (one feature per step)
                or ``(batch, window, input_size)``.

        Returns:
            Tensor of shape ``(batch,)`` when ``output_size == 1``, otherwise
            ``(batch, output_size)``; row ``i`` is the prediction for sample ``i``.

        Raises:
            ValueError: if ``prices`` is not 2-D or 3-D.
        """
        if prices.dim() == 2:
            # Add the trailing feature axis: (batch, window) -> (batch, window, 1).
            prices = prices.unsqueeze(-1)
        elif prices.dim() != 3:
            raise ValueError(
                f"expected prices of shape (batch, window) or (batch, window, features), "
                f"got {tuple(prices.shape)}"
            )
        # nn.LSTM (batch_first=False) takes input as (window, batch, features).
        seq = prices.transpose(0, 1)
        # Omitting the initial state makes nn.LSTM start from zeros created on the
        # input's own device and dtype, so no global device is needed.
        lstm_out, self.hidden_cell = self.lstm(seq)
        # lstm_out[-1] is the output at the last time step for every sample: (batch, hidden).
        pred = self.linear(lstm_out[-1])
        # Keep all batch rows; only drop the feature axis when it has size 1.
        return pred.squeeze(-1)