Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added asynchronous plotting to TensorFlow #126

Open
wants to merge 7 commits into
base: integration
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,3 @@ blackbox.zip
blackbox
rllab/config_personal.py
*.swp
sandbox
70 changes: 38 additions & 32 deletions sandbox/rocky/tf/algos/batch_polopt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import time

import tensorflow as tf

from rllab.algos import RLAlgorithm
import rllab.misc.logger as logger

from sandbox.rocky.tf.plotter import Plotter
from sandbox.rocky.tf.policies.base import Policy
import tensorflow as tf
from sandbox.rocky.tf.samplers import BatchSampler
from sandbox.rocky.tf.samplers import VectorizedSampler
from rllab.sampler.utils import rollout


class BatchPolopt(RLAlgorithm):
Expand All @@ -14,30 +17,28 @@ class BatchPolopt(RLAlgorithm):
This includes various policy gradient methods like vpg, npg, ppo, trpo, etc.
"""

def __init__(
self,
env,
policy,
baseline,
scope=None,
n_itr=500,
start_itr=0,
batch_size=5000,
max_path_length=500,
discount=0.99,
gae_lambda=1,
plot=False,
pause_for_plot=False,
center_adv=True,
positive_adv=False,
store_paths=False,
whole_paths=True,
fixed_horizon=False,
sampler_cls=None,
sampler_args=None,
force_batch_sampler=False,
**kwargs
):
def __init__(self,
env,
policy,
baseline,
scope=None,
n_itr=500,
start_itr=0,
batch_size=5000,
max_path_length=500,
discount=0.99,
gae_lambda=1,
plot=False,
pause_for_plot=False,
center_adv=True,
positive_adv=False,
store_paths=False,
whole_paths=True,
fixed_horizon=False,
sampler_cls=None,
sampler_args=None,
force_batch_sampler=False,
**kwargs):
"""
:param env: Environment
:param policy: Policy
Expand Down Expand Up @@ -86,11 +87,16 @@ def __init__(
self.sampler = sampler_cls(self, **sampler_args)
self.init_opt()

def start_worker(self):
def start_worker(self, sess):
    """Start the sampler's worker processes and, when plotting is enabled,
    launch the asynchronous plotter bound to *sess*.

    :param sess: TensorFlow session handed to the Plotter so its worker
        thread can run rollouts in the same graph/session.
    """
    self.sampler.start_worker()
    if not self.plot:
        return
    # Plotter runs rollouts in a background thread; it needs the live session.
    self.plotter = Plotter(self.env, self.policy, sess)
    self.plotter.start()

def shutdown_worker(self):
    """Tear down the sampler workers and, if a plotter was started, shut it
    down as well (mirror image of start_worker)."""
    self.sampler.shutdown_worker()
    if not self.plot:
        return
    self.plotter.shutdown()

def obtain_samples(self, itr):
    """Collect a batch of trajectories for iteration *itr* by delegating to
    the configured sampler.

    :param itr: iteration index, forwarded unchanged to the sampler.
    :return: whatever the sampler's ``obtain_samples`` returns (list of paths).
    """
    sampler = self.sampler
    return sampler.obtain_samples(itr)
Expand All @@ -103,9 +109,9 @@ def train(self, sess=None):
if sess is None:
sess = tf.Session()
sess.__enter__()

sess.run(tf.global_variables_initializer())
self.start_worker()
self.start_worker(sess)
start_time = time.time()
for itr in range(self.start_itr, self.n_itr):
itr_start_time = time.time()
Expand All @@ -119,7 +125,7 @@ def train(self, sess=None):
logger.log("Optimizing policy...")
self.optimize_policy(itr, samples_data)
logger.log("Saving snapshot...")
params = self.get_itr_snapshot(itr, samples_data) # , **kwargs)
params = self.get_itr_snapshot(itr, samples_data)
if self.store_paths:
params["paths"] = samples_data["paths"]
logger.save_itr_params(itr, params)
Expand All @@ -128,10 +134,11 @@ def train(self, sess=None):
logger.record_tabular('ItrTime', time.time() - itr_start_time)
logger.dump_tabular(with_prefix=False)
if self.plot:
rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
self.plotter.update_plot(self.policy, self.max_path_length)
if self.pause_for_plot:
input("Plotting evaluation run: Press Enter to "
"continue...")

self.shutdown_worker()
if created_session:
sess.close()
Expand All @@ -157,4 +164,3 @@ def get_itr_snapshot(self, itr, samples_data):

def optimize_policy(self, itr, samples_data):
    """Run one policy-optimization step on the processed batch.

    Abstract hook: concrete algorithms (vpg, trpo, ppo, ...) must override
    this; the base class deliberately raises.

    :param itr: current iteration index.
    :param samples_data: processed sample batch for this iteration.
    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError

1 change: 1 addition & 0 deletions sandbox/rocky/tf/plotter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from sandbox.rocky.tf.plotter.plotter import Plotter
89 changes: 89 additions & 0 deletions sandbox/rocky/tf/plotter/plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import atexit
from multiprocessing import Process
import numpy as np
import platform
from queue import Empty, Queue
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PEP8: import grouping

import tensorflow as tf
from threading import Thread

from rllab.sampler.utils import rollout

__all__ = ["Plotter"]


class Plotter(object):
def __init__(self, env, policy, sess=None, graph=None):
    """Set up an asynchronous plotter for *env*/*policy*.

    :param env: environment used for demo rollouts.
    :param policy: policy to visualize.
    :param sess: TensorFlow session; defaults to the current default session.
    :param graph: TensorFlow graph; defaults to the current default graph.
    """
    self.env = env
    self.policy = policy
    # Fall back to the ambient TF session/graph when none is supplied.
    if sess is None:
        sess = tf.get_default_session()
    self.sess = sess
    if graph is None:
        graph = tf.get_default_graph()
    self.graph = graph
    # Daemon thread so an un-shut-down plotter never blocks interpreter exit.
    self.queue = Queue()
    self.worker_thread = Thread(target=self._start_worker, daemon=True)

def _start_worker(self):
env = None
policy = None
max_length = None
initial_rollout = True
try:
with self.sess.as_default(), self.sess.graph.as_default():
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment here explaining that the worker processes all messages in the queue per loop, not one message per loop?

while True:
msgs = {}
# If true, block and yield processor
if initial_rollout:
msg = self.queue.get()
msgs[msg[0]] = msg[1:]
# Only fetch the last message of each type
while not self.queue.empty():
msg = self.queue.get()
msgs[msg[0]] = msg[1:]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

msg[0] and msg[1:] is very difficult to read. What if we used a namedtuple instead?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from collections import namedtuple
import enum
from enum import Enum

class Op(Enum):
    STOP = enum.auto()
    UPDATE = enum.auto()
    DEMO = enum.auto()

Message = namedtuple("Message", ["op", "args", "kwargs"])

class Plotter:

    def _start_worker(self):
        while True:

            if initial_rollout:
                msg = self.queue.get()
                msgs[msg.op] = msg

            if Op.STOP in msgs:
                break
            elif Op.DEMO in msgs:
                env, policy = msgs[Op.DEMO].args

    def update_plot(self, policy, max_length=np.inf):
        if self.worker_thread.is_alive():
        if self.worker_thread.is_alive():
            self.queue.put(Message(op=Op.DEMO, args=(policy.get_param_values(), max_length)))
            self.queue.task_done()

else:
# Only fetch the last message of each type
while not self.queue.empty():
msg = self.queue.get_nowait()
msgs[msg[0]] = msg[1:]

if 'stop' in msgs:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's replace the strings with enums.

break
elif 'update' in msgs:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be if and not elif?

env, policy = msgs['update']
elif 'demo' in msgs:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be if and not elif?

param_values, max_length = msgs['demo']
policy.set_param_values(param_values)
initial_rollout = False
rollout(
env,
policy,
max_path_length=max_length,
animated=True,
speedup=5)
else:
if max_length:
rollout(
env,
policy,
max_path_length=max_length,
animated=True,
speedup=5)
except KeyboardInterrupt:
pass

def shutdown(self):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check that the worker thread dies correctly at the end of the simulation and when user interrupts the simulation with keyboard interruption.

if self.worker_thread.is_alive():
self.queue.put(['stop'])
self.queue.task_done()
self.queue.join()
self.worker_thread.join()

def start(self):
    """Launch the background plotting thread (no-op if already running).

    Enables variable reuse so the worker's rollouts share the trained
    policy's variables, seeds the worker with the env/policy via an
    'update' message, and registers shutdown to run at interpreter exit.
    """
    if self.worker_thread.is_alive():
        return
    tf.get_variable_scope().reuse_variables()
    self.worker_thread.start()
    self.queue.put(['update', self.env, self.policy])
    # Pre-complete the task so a later queue.join() does not block on it.
    self.queue.task_done()
    atexit.register(self.shutdown)

def update_plot(self, policy, max_length=np.inf):
    """Ask the worker thread to demo the current policy.

    :param policy: policy whose parameter values are snapshotted and sent
        to the worker (the worker holds its own policy copy).
    :param max_length: horizon cap for the demo rollout (default: unbounded).
    """
    if not self.worker_thread.is_alive():
        return
    message = ['demo', policy.get_param_values(), max_length]
    self.queue.put(message)
    # Pre-complete the task so a later queue.join() does not block on it.
    self.queue.task_done()