forked from quantylab/rltrader
policy_network_dnn.py
import numpy as np
from keras.models import Sequential
from keras.layers import Activation, Dense, BatchNormalization, Dropout, Flatten
from keras.optimizers import SGD


class PolicyNetwork:
    def __init__(self, input_dim=0, output_dim=0, lr=0.01):
        self.input_dim = input_dim
        self.lr = lr

        # DNN policy network: three fully connected blocks of
        # Dense(128) -> Dropout(0.5) -> BatchNormalization,
        # followed by an output layer with a sigmoid activation.
        self.model = Sequential()
        self.model.add(Dense(128, input_shape=(1, input_dim)))
        self.model.add(Dropout(0.5))
        self.model.add(BatchNormalization())
        self.model.add(Dense(128))
        self.model.add(Dropout(0.5))
        self.model.add(BatchNormalization())
        self.model.add(Dense(128))
        self.model.add(Dropout(0.5))
        self.model.add(BatchNormalization())
        self.model.add(Dense(output_dim))
        self.model.add(Flatten())
        self.model.add(Activation('sigmoid'))

        self.model.compile(optimizer=SGD(lr=lr), loss='mse')
        self.prob = None

    def reset(self):
        self.prob = None

    def predict(self, sample):
        # Reshape the flat sample to (batch=1, timesteps, input_dim) and
        # return the probabilities for the single sample in the batch.
        self.prob = self.model.predict(
            np.array(sample).reshape((1, -1, self.input_dim)))[0]
        return self.prob

    def train_on_batch(self, x, y):
        return self.model.train_on_batch(x, y)

    def save_model(self, model_path):
        if model_path is not None and self.model is not None:
            self.model.save_weights(model_path, overwrite=True)

    def load_model(self, model_path):
        if model_path is not None:
            self.model.load_weights(model_path)
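

# A minimal usage sketch, assuming a hypothetical feature dimension of 15
# and 3 output actions; the random data below is illustrative only and is
# not part of the original training pipeline.
if __name__ == '__main__':
    network = PolicyNetwork(input_dim=15, output_dim=3, lr=0.01)

    # Predict action probabilities for a single flat sample of length input_dim.
    sample = np.random.rand(15)
    probs = network.predict(sample)
    print(probs.shape)  # (3,)

    # Train on a batch shaped (batch_size, timesteps=1, input_dim)
    # against target probabilities shaped (batch_size, output_dim).
    x = np.random.rand(10, 1, 15)
    y = np.random.rand(10, 3)
    loss = network.train_on_batch(x, y)
    print(loss)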