-
Notifications
You must be signed in to change notification settings - Fork 0
/
deep_q_network.py
89 lines (76 loc) · 3.31 KB
/
deep_q_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""
This code is modified from https://github.com/hunkim/ReinforcementZeroToAll/
DQN class
DQN(NIPS-2013)
"Playing Atari with Deep Reinforcement Learning"
https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
DQN(Nature-2015)
"Human-level control through deep reinforcement learning"
http://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf
"""
import os
import numpy as np
import tensorflow as tf
class DQN:
    """Deep Q-Network: a one-hidden-layer MLP that approximates Q(s, a) for discrete actions."""

    def __init__(self, session: tf.Session, input_size: int, output_size: int, name: str = "main") -> None:
        """DQN Agent can
        1) Build network
        2) Predict Q_value given state
        3) Train parameters

        Args:
            session (tf.Session): Tensorflow session used to run every op of this network
            input_size (int): Input (state) dimension
            output_size (int): Number of discrete actions
            name (str, optional): TF Graph will be built under this variable scope
        """
        self.session = session
        self.input_size = input_size
        self.output_size = output_size
        self.net_name = name
        # A single tf.train.Saver is reused by save()/load(); see _get_saver().
        self._saver = None
        self._saver_var_names = ()
        self._build_network()

    def _build_network(self, l_rate=0.0001) -> None:
        """DQN Network architecture (simple MLP)

        input (None, input_size) -> dense(512, ReLU) -> dense(output_size, linear)
        trained with mean-squared error against the target placeholder using Adam.

        Args:
            l_rate (float, optional): Learning rate for the Adam optimizer
        """
        with tf.variable_scope(self.net_name):
            self._X = tf.placeholder(tf.float32, [None, self.input_size], name="input_x")
            net = self._X
            hidden = tf.layers.dense(net, 512, activation=tf.nn.relu)
            # Linear output layer: one Q value per action.
            net = tf.layers.dense(hidden, self.output_size)
            self._Qpred = net
            # Target Q values, same shape as the predictions.
            self._Y = tf.placeholder(tf.float32, shape=[None, self.output_size])
            self._loss = tf.losses.mean_squared_error(self._Y, self._Qpred)
            optimizer = tf.train.AdamOptimizer(learning_rate=l_rate)
            self._train = optimizer.minimize(self._loss)

    def predict(self, state: np.ndarray) -> np.ndarray:
        """Returns Q(s, a)

        Args:
            state (np.ndarray): State array; reshaped to (n, input_dim), so a
                single flat state is also accepted

        Returns:
            np.ndarray: Q value array, shape (n, output_dim)
        """
        x = np.reshape(state, [-1, self.input_size])
        return self.session.run(self._Qpred, feed_dict={self._X: x})

    def update(self, x_stack: np.ndarray, y_stack: np.ndarray) -> list:
        """Performs one optimizer step on the given X and y and returns a result

        Args:
            x_stack (np.ndarray): State array, shape (n, input_dim)
            y_stack (np.ndarray): Target Q array, shape (n, output_dim)

        Returns:
            list: First element is loss, second element is a result from train step
        """
        feed = {
            self._X: x_stack,
            self._Y: y_stack
        }
        return self.session.run([self._loss, self._train], feed)

    def _get_saver(self) -> tf.train.Saver:
        """Return a Saver over the current global variables, building it only when needed.

        Constructing a tf.train.Saver adds save/restore ops to the graph, so
        creating one per save() call grows the graph on every call. The saver
        is rebuilt only when the set of global variable names changes.
        max_to_keep=None matches the old one-Saver-per-call behaviour, in
        which old checkpoints were never deleted.
        """
        var_names = tuple(v.name for v in tf.global_variables())
        saver = getattr(self, "_saver", None)
        if saver is None or var_names != getattr(self, "_saver_var_names", ()):
            saver = tf.train.Saver(max_to_keep=None)
            self._saver = saver
            self._saver_var_names = var_names
        return saver

    def load(self, name: str) -> None:
        """Restore variable values from a checkpoint written by save().

        Args:
            name (str): Either a checkpoint directory (its newest checkpoint is
                used) or a checkpoint prefix such as
                "models/<run>_<score>_<loss>/model". If it is neither, the newest
                checkpoint in the current directory is used, which is where the
                previous implementation always looked.

        Raises:
            ValueError: If no checkpoint can be found.
        """
        if isinstance(name, str) and os.path.isdir(name):
            checkpoint = tf.train.latest_checkpoint(name)
        elif isinstance(name, str) and os.path.exists(name + ".index"):
            checkpoint = name
        else:
            # Legacy fallback: `name` used to be ignored entirely.
            checkpoint = tf.train.latest_checkpoint('./')
        if checkpoint is None:
            raise ValueError("No checkpoint found for {!r}".format(name))
        # Restore into the variables this graph already contains. The former
        # tf.train.import_meta_graph call imported a second, renamed copy of the
        # graph, so restored values never reached the ops used by predict()/update().
        self._get_saver().restore(self.session, checkpoint)

    def save(self, run, score_mean, loss_mean):
        """Save all global variables to models/<run>_<score>_<loss>/model.

        Args:
            run: Run identifier; formatted with str()
            score_mean (float): Mean score, rounded to 3 decimals in the path
            loss_mean (float): Mean loss, rounded to 6 decimals in the path

        Returns:
            str: The checkpoint path prefix that was written
        """
        dirname = "models/{}_{}_{}".format(run, round(score_mean, 3), round(loss_mean, 6))
        # Create the full target directory up front (also tolerates it already existing).
        os.makedirs(dirname, exist_ok=True)
        filename = dirname + "/model"
        return self._get_saver().save(self.session, filename)