Updated for TensorFlow 1.0, support for Mnih 2015 network architecture #38

Open · wants to merge 1 commit into base: gym
26 changes: 12 additions & 14 deletions a3c.py
@@ -49,7 +49,7 @@ def log_uniform(lo, hi, rate):
if USE_LSTM:
global_network = GameACLSTMNetwork(ACTION_SIZE, -1, device)
else:
global_network = GameACFFNetwork(ACTION_SIZE, device)
global_network = GameACFFNetwork(ACTION_SIZE, -1, device)


training_threads = []
@@ -74,15 +74,15 @@ def log_uniform(lo, hi, rate):
sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
allow_soft_placement=True))

init = tf.initialize_all_variables()
init = tf.global_variables_initializer()
sess.run(init)

# summary for tensorboard
score_input = tf.placeholder(tf.int32)
tf.scalar_summary("score", score_input)
tf.summary.scalar("score", score_input)

summary_op = tf.merge_all_summaries()
summary_writer = tf.train.SummaryWriter(LOG_FILE, sess.graph_def)
summary_op = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter(LOG_FILE, sess.graph)

# init or load checkpoint with saver
saver = tf.train.Saver()
@@ -106,7 +106,7 @@ def log_uniform(lo, hi, rate):

def train_function(parallel_index):
global global_t

training_thread = training_threads[parallel_index]
# set start_time
start_time = time.time() - wall_t
@@ -121,17 +121,17 @@ def train_function(parallel_index):
diff_global_t = training_thread.process(sess, global_t, summary_writer,
summary_op, score_input)
global_t += diff_global_t


def signal_handler(signal, frame):
global stop_requested
print('You pressed Ctrl+C!')
stop_requested = True

train_threads = []
for i in range(PARALLEL_SIZE):
train_threads.append(threading.Thread(target=train_function, args=(i,)))

signal.signal(signal.SIGINT, signal_handler)

# set start time
@@ -144,12 +144,12 @@ def signal_handler(signal, frame):
signal.pause()

print('Now saving data. Please wait')

for t in train_threads:
t.join()

if not os.path.exists(CHECKPOINT_DIR):
os.mkdir(CHECKPOINT_DIR)
os.mkdir(CHECKPOINT_DIR)

# write wall time
wall_t = time.time() - start_time
@@ -158,5 +158,3 @@ def signal_handler(signal, frame):
f.write(str(wall_t))

saver.save(sess, CHECKPOINT_DIR + '/' + 'checkpoint', global_step = global_t)
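Note on the a3c.py hunks above: the pre-1.0 TensorFlow names are swapped for their 1.x equivalents (initialize_all_variables → global_variables_initializer, scalar_summary → tf.summary.scalar, merge_all_summaries → tf.summary.merge_all, tf.train.SummaryWriter → tf.summary.FileWriter, which takes sess.graph instead of sess.graph_def). A minimal, self-contained sketch of the new summary plumbing, assuming TF 1.x and an illustrative log directory rather than the repo's LOG_FILE constant:

```python
import tensorflow as tf

# TF 1.x summary setup mirroring the diff
score_input = tf.placeholder(tf.int32)
tf.summary.scalar("score", score_input)
summary_op = tf.summary.merge_all()

sess = tf.Session()
sess.run(tf.global_variables_initializer())  # replaces tf.initialize_all_variables()

# FileWriter takes the graph itself rather than graph_def
summary_writer = tf.summary.FileWriter("/tmp/a3c_log", sess.graph)

# Record one scalar value at global step 0
summary_str = sess.run(summary_op, feed_dict={score_input: 10})
summary_writer.add_summary(summary_str, 0)
summary_writer.flush()
```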


30 changes: 4 additions & 26 deletions a3c_display.py
@@ -18,27 +18,15 @@
from constants import USE_LSTM

def choose_action(pi_values):
values = []
sum = 0.0
for rate in pi_values:
sum = sum + rate
value = sum
values.append(value)

r = random.random() * sum
for i in range(len(values)):
if values[i] >= r:
return i;
#fail safe
return len(values)-1
return np.random.choice(range(len(pi_values)), p=pi_values)

# use CPU for display tool
device = "/cpu:0"

if USE_LSTM:
global_network = GameACLSTMNetwork(ACTION_SIZE, -1, device)
else:
global_network = GameACFFNetwork(ACTION_SIZE, device)
global_network = GameACFFNetwork(ACTION_SIZE, -1, device)

learning_rate_input = tf.placeholder("float")

Expand All @@ -49,17 +37,8 @@ def choose_action(pi_values):
clip_norm = GRAD_NORM_CLIP,
device = device)

# training_threads = []
# for i in range(PARALLEL_SIZE):
# training_thread = A3CTrainingThread(i, global_network, 1.0,
# learning_rate_input,
# grad_applier,
# 8000000,
# device = device)
# training_threads.append(training_thread)

sess = tf.Session()
init = tf.initialize_all_variables()
init = tf.global_variables_initializer()
sess.run(init)

saver = tf.train.Saver()
Expand All @@ -70,7 +49,7 @@ def choose_action(pi_values):
else:
print("Could not find old checkpoint")

game_state = GameState(display=True, no_op_max=0)
game_state = GameState(0, display=True, no_op_max=0)

while True:
pi_values = global_network.run_policy(sess, game_state.s_t)
Expand All @@ -82,4 +61,3 @@ def choose_action(pi_values):
game_state.reset()
else:
game_state.update()
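The choose_action rewrite above (mirrored in a3c_training_thread.py) replaces the hand-rolled cumulative-sum sampling loop with NumPy's categorical sampling. A small sketch of the idea with a made-up four-action policy; note that np.random.choice requires the probabilities to sum to 1, which the policy's softmax output should already guarantee:

```python
import numpy as np

def choose_action(pi_values):
    # Sample an action index i with probability pi_values[i]
    return np.random.choice(range(len(pi_values)), p=pi_values)

# Illustrative policy over 4 actions
pi = np.array([0.1, 0.2, 0.3, 0.4])
samples = [choose_action(pi) for _ in range(10000)]
print(np.bincount(samples, minlength=4) / 10000.0)  # roughly [0.1, 0.2, 0.3, 0.4]
```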

69 changes: 26 additions & 43 deletions a3c_training_thread.py
@@ -5,7 +5,6 @@
import time
import sys

from accum_trainer import AccumTrainer
from game_state import GameState
from game_state import ACTION_SIZE
from game_ac_network import GameACFFNetwork, GameACLSTMNetwork
Expand Down Expand Up @@ -35,26 +34,26 @@ def __init__(self,
if USE_LSTM:
self.local_network = GameACLSTMNetwork(ACTION_SIZE, thread_index, device)
else:
self.local_network = GameACFFNetwork(ACTION_SIZE, device)
self.local_network = GameACFFNetwork(ACTION_SIZE, thread_index, device)

self.local_network.prepare_loss(ENTROPY_BETA)

# TODO: don't need accum trainer anymore with batch
self.trainer = AccumTrainer(device)
self.trainer.prepare_minimize( self.local_network.total_loss,
self.local_network.get_vars() )

self.accum_gradients = self.trainer.accumulate_gradients()
self.reset_gradients = self.trainer.reset_gradients()
with tf.device(device):
var_refs = [v._ref() for v in self.local_network.get_vars()]
self.gradients = tf.gradients(
self.local_network.total_loss, var_refs,
gate_gradients=False,
aggregation_method=None,
colocate_gradients_with_ops=False)

self.apply_gradients = grad_applier.apply_gradients(
global_network.get_vars(),
self.trainer.get_accum_grad_list() )
self.gradients )

self.sync = self.local_network.sync_from(global_network)

self.game_state = GameState()

self.local_t = 0

self.initial_learning_rate = initial_learning_rate
Expand All @@ -71,26 +70,15 @@ def _anneal_learning_rate(self, global_time_step):
return learning_rate

def choose_action(self, pi_values):
values = []
sum = 0.0
for rate in pi_values:
sum = sum + rate
value = sum
values.append(value)

r = random.random() * sum
for i in range(len(values)):
if values[i] >= r:
return i;
#fail safe
return len(values)-1
return np.random.choice(range(len(pi_values)), p=pi_values)

def _record_score(self, sess, summary_writer, summary_op, score_input, score, global_t):
summary_str = sess.run(summary_op, feed_dict={
score_input: score
})
summary_writer.add_summary(summary_str, global_t)

summary_writer.flush()

def set_start_time(self, start_time):
self.start_time = start_time

Expand All @@ -102,17 +90,14 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):

terminal_end = False

# reset accumulated gradients
sess.run( self.reset_gradients )

# copy weights from shared to local
sess.run( self.sync )

start_local_t = self.local_t

if USE_LSTM:
start_lstm_state = self.local_network.lstm_state_out

# t_max times loop
for i in range(LOCAL_T_MAX):
pi_, value_ = self.local_network.run_policy_and_value(sess, self.game_state.s_t)
Expand Down Expand Up @@ -142,14 +127,14 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):

# s_t1 -> s_t
self.game_state.update()

if terminal:
terminal_end = True
print("score={}".format(self.episode_reward))

self._record_score(sess, summary_writer, summary_op, score_input,
self.episode_reward, global_t)

self.episode_reward = 0
self.game_state.reset()
if USE_LSTM:
Expand Down Expand Up @@ -182,32 +167,31 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):
batch_td.append(td)
batch_R.append(R)

cur_learning_rate = self._anneal_learning_rate(global_t)

if USE_LSTM:
batch_si.reverse()
batch_a.reverse()
batch_td.reverse()
batch_R.reverse()

sess.run( self.accum_gradients,
sess.run( self.apply_gradients,
feed_dict = {
self.local_network.s: batch_si,
self.local_network.a: batch_a,
self.local_network.td: batch_td,
self.local_network.r: batch_R,
self.local_network.initial_lstm_state: start_lstm_state,
self.local_network.step_size : [len(batch_a)] } )
self.local_network.step_size : [len(batch_a)],
self.learning_rate_input: cur_learning_rate } )
else:
sess.run( self.accum_gradients,
sess.run( self.apply_gradients,
feed_dict = {
self.local_network.s: batch_si,
self.local_network.a: batch_a,
self.local_network.td: batch_td,
self.local_network.r: batch_R} )

cur_learning_rate = self._anneal_learning_rate(global_t)

sess.run( self.apply_gradients,
feed_dict = { self.learning_rate_input: cur_learning_rate } )
self.local_network.r: batch_R,
self.learning_rate_input: cur_learning_rate} )

if (self.thread_index == 0) and (self.local_t - self.prev_local_t >= PERFORMANCE_LOG_INTERVAL):
self.prev_local_t += PERFORMANCE_LOG_INTERVAL
Expand All @@ -219,4 +203,3 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):
# return advanced local step size
diff_local_t = self.local_t - start_local_t
return diff_local_t
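The a3c_training_thread.py changes drop the AccumTrainer accumulate/reset/apply cycle: gradients are now computed once with tf.gradients and applied in the same sess.run that feeds the batch and the annealed learning rate. A minimal sketch of that single-step pattern, using a toy loss and a stock optimizer in place of the repo's local network and RMSProp grad_applier (all names here are illustrative):

```python
import tensorflow as tf

# Toy stand-ins for the local network's loss and trainable variables
x = tf.placeholder(tf.float32, [None, 2])
w = tf.Variable(tf.ones([2, 1]))
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

learning_rate_input = tf.placeholder(tf.float32)

# Compute gradients directly, as the diff does for the local network's variables
grads = tf.gradients(loss, [w])

# Apply them; the repo uses its own grad_applier, a plain optimizer suffices for the sketch
optimizer = tf.train.GradientDescentOptimizer(learning_rate_input)
apply_gradients = optimizer.apply_gradients(list(zip(grads, [w])))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One combined run: batch data and annealed learning rate in a single feed_dict
    sess.run(apply_gradients,
             feed_dict={x: [[1.0, 2.0]], learning_rate_input: 7e-4})
```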

14 changes: 3 additions & 11 deletions a3c_visualize.py
@@ -27,7 +27,7 @@
if USE_LSTM:
global_network = GameACLSTMNetwork(ACTION_SIZE, -1, device)
else:
global_network = GameACFFNetwork(ACTION_SIZE, device)
global_network = GameACFFNetwork(ACTION_SIZE, -1, device)

training_threads = []

@@ -40,15 +40,8 @@
clip_norm = GRAD_NORM_CLIP,
device = device)

# for i in range(PARALLEL_SIZE):
# training_thread = A3CTrainingThread(i, global_network, 1.0,
# learning_rate_input,
# grad_applier, MAX_TIME_STEP,
# device = device)
# training_threads.append(training_thread)

sess = tf.Session()
init = tf.initialize_all_variables()
init = tf.global_variables_initializer()
sess.run(init)

saver = tf.train.Saver()
Expand All @@ -58,7 +51,7 @@
print("checkpoint loaded:", checkpoint.model_checkpoint_path)
else:
print("Could not find old checkpoint")

W_conv1 = sess.run(global_network.W_conv1)

# show graph of W_conv1
Expand All @@ -74,4 +67,3 @@
ax.set_title(str(inch) + "," + str(outch))

plt.show()
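For reference, a3c_visualize.py fetches the first convolution layer's weights from the restored session and plots each (input channel, output channel) slice. A self-contained sketch of that filter grid, substituting a random tensor for the checkpointed W_conv1; the 8x8x4x16 shape is an assumption about the first-layer convolution, so adjust it to the actual variable:

```python
import numpy as np
import matplotlib.pyplot as plt

# Stand-in for sess.run(global_network.W_conv1); shape [height, width, in_ch, out_ch]
W_conv1 = np.random.randn(8, 8, 4, 16)

in_ch, out_ch = W_conv1.shape[2], W_conv1.shape[3]
fig = plt.figure()

# One small subplot per (input channel, output channel) filter slice
for i in range(in_ch):
    for o in range(out_ch):
        ax = fig.add_subplot(in_ch, out_ch, i * out_ch + o + 1)
        ax.imshow(W_conv1[:, :, i, o], cmap="gray", interpolation="nearest")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title("{},{}".format(i, o), fontsize=6)

plt.show()
```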
