Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added asynchronous plotting to TensorFlow #126

Open
wants to merge 7 commits into
base: integration
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,3 @@ blackbox.zip
blackbox
rllab/config_personal.py
*.swp
sandbox
70 changes: 38 additions & 32 deletions sandbox/rocky/tf/algos/batch_polopt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import time

import tensorflow as tf

from rllab.algos import RLAlgorithm
import rllab.misc.logger as logger

from sandbox.rocky.tf.plotter import Plotter
from sandbox.rocky.tf.policies.base import Policy
import tensorflow as tf
from sandbox.rocky.tf.samplers import BatchSampler
from sandbox.rocky.tf.samplers import VectorizedSampler
from rllab.sampler.utils import rollout


class BatchPolopt(RLAlgorithm):
Expand All @@ -14,30 +17,28 @@ class BatchPolopt(RLAlgorithm):
This includes various policy gradient methods like vpg, npg, ppo, trpo, etc.
"""

def __init__(
self,
env,
policy,
baseline,
scope=None,
n_itr=500,
start_itr=0,
batch_size=5000,
max_path_length=500,
discount=0.99,
gae_lambda=1,
plot=False,
pause_for_plot=False,
center_adv=True,
positive_adv=False,
store_paths=False,
whole_paths=True,
fixed_horizon=False,
sampler_cls=None,
sampler_args=None,
force_batch_sampler=False,
**kwargs
):
def __init__(self,
env,
policy,
baseline,
scope=None,
n_itr=500,
start_itr=0,
batch_size=5000,
max_path_length=500,
discount=0.99,
gae_lambda=1,
plot=False,
pause_for_plot=False,
center_adv=True,
positive_adv=False,
store_paths=False,
whole_paths=True,
fixed_horizon=False,
sampler_cls=None,
sampler_args=None,
force_batch_sampler=False,
**kwargs):
"""
:param env: Environment
:param policy: Policy
Expand Down Expand Up @@ -86,11 +87,16 @@ def __init__(
self.sampler = sampler_cls(self, **sampler_args)
self.init_opt()

def start_worker(self):
def start_worker(self, sess):
self.sampler.start_worker()
if self.plot:
self.plotter = Plotter(self.env, self.policy, sess)
self.plotter.start()

def shutdown_worker(self):
self.sampler.shutdown_worker()
if self.plot:
self.plotter.shutdown()

def obtain_samples(self, itr):
return self.sampler.obtain_samples(itr)
Expand All @@ -103,9 +109,9 @@ def train(self, sess=None):
if sess is None:
sess = tf.Session()
sess.__enter__()

sess.run(tf.global_variables_initializer())
self.start_worker()
self.start_worker(sess)
start_time = time.time()
for itr in range(self.start_itr, self.n_itr):
itr_start_time = time.time()
Expand All @@ -119,7 +125,7 @@ def train(self, sess=None):
logger.log("Optimizing policy...")
self.optimize_policy(itr, samples_data)
logger.log("Saving snapshot...")
params = self.get_itr_snapshot(itr, samples_data) # , **kwargs)
params = self.get_itr_snapshot(itr, samples_data)
if self.store_paths:
params["paths"] = samples_data["paths"]
logger.save_itr_params(itr, params)
Expand All @@ -128,10 +134,11 @@ def train(self, sess=None):
logger.record_tabular('ItrTime', time.time() - itr_start_time)
logger.dump_tabular(with_prefix=False)
if self.plot:
rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
self.plotter.update_plot(self.policy, self.max_path_length)
if self.pause_for_plot:
input("Plotting evaluation run: Press Enter to "
"continue...")

self.shutdown_worker()
if created_session:
sess.close()
Expand All @@ -157,4 +164,3 @@ def get_itr_snapshot(self, itr, samples_data):

def optimize_policy(self, itr, samples_data):
raise NotImplementedError

1 change: 1 addition & 0 deletions sandbox/rocky/tf/plotter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from sandbox.rocky.tf.plotter.plotter import Plotter
108 changes: 108 additions & 0 deletions sandbox/rocky/tf/plotter/plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import atexit
from collections import namedtuple
import enum
from enum import Enum
from multiprocessing import Process
import platform
from queue import Empty, Queue
from threading import Thread

import numpy as np
import tensorflow as tf

from rllab.sampler.utils import rollout

__all__ = ["Plotter"]


class Op(Enum):
    """Opcodes for messages sent to the plotter's worker thread."""

    STOP = 0  # terminate the worker loop
    UPDATE = 1  # replace the env/policy pair used for rendering
    DEMO = 2  # render a rollout with freshly supplied policy parameters


Message = namedtuple("Message", ["op", "args", "info"])


class Plotter(object):
    """Renders policy rollouts on a background daemon thread.

    Producers (the training loop) push ``Message`` objects onto an internal
    queue; the worker thread consumes them and runs animated rollouts, so
    plotting does not block optimization.
    """

    def __init__(self, env, policy, sess=None, graph=None):
        """
        :param env: Environment to render rollouts in.
        :param policy: Policy used to generate rollouts.
        :param sess: TensorFlow session; defaults to the current default
            session when None.
        :param graph: TensorFlow graph; defaults to the current default
            graph when None.
        """
        if sess is None:
            sess = tf.get_default_session()
        if graph is None:
            graph = tf.get_default_graph()
        self.env = env
        self.policy = policy
        self.sess = sess
        self.graph = graph
        self.queue = Queue()
        # Daemon thread so a forgotten shutdown() cannot hang interpreter exit.
        self.worker_thread = Thread(target=self._start_worker, daemon=True)

    def _start_worker(self):
        """Worker loop: drain the message queue and render rollouts.

        Each pass processes ALL messages currently in the queue (keeping only
        the newest message of each Op), not one message per loop, so the
        plotter never falls behind the training loop.
        """
        env, policy = None, None
        max_length = None
        waiting_for_first_demo = True
        try:
            with self.sess.as_default(), self.sess.graph.as_default():
                while True:
                    latest = {}
                    if waiting_for_first_demo:
                        # Nothing to replay yet: block until work arrives.
                        first = self.queue.get()
                        latest[first.op] = first
                    # Coalesce everything queued; this thread is the only
                    # consumer, so get_nowait() cannot race after empty().
                    while not self.queue.empty():
                        pending = self.queue.get_nowait()
                        latest[pending.op] = pending

                    if Op.STOP in latest:
                        break
                    if Op.UPDATE in latest:
                        env, policy = latest[Op.UPDATE].args
                    if Op.DEMO in latest:
                        param_values, max_length = latest[Op.DEMO].args
                        policy.set_param_values(param_values)
                        waiting_for_first_demo = False
                        rollout(
                            env,
                            policy,
                            max_path_length=max_length,
                            animated=True,
                            speedup=5)
                    elif max_length:
                        # No fresh parameters this pass: replay the last demo.
                        rollout(
                            env,
                            policy,
                            max_path_length=max_length,
                            animated=True,
                            speedup=5)
        except KeyboardInterrupt:
            pass

    def shutdown(self):
        """Stop the worker thread and wait for it to exit (no-op if dead)."""
        if not self.worker_thread.is_alive():
            return
        self.queue.put(Message(op=Op.STOP, args=None, info=None))
        # NOTE(review): task_done() is called from the producer side right
        # after put() (the worker never calls it), so queue.join() returns
        # immediately; the real synchronization is worker_thread.join().
        self.queue.task_done()
        self.queue.join()
        self.worker_thread.join()

    def start(self):
        """Launch the worker and seed it with the env/policy pair (no-op if
        already running)."""
        if self.worker_thread.is_alive():
            return
        # Reuse the training graph's variables instead of duplicating them.
        tf.get_variable_scope().reuse_variables()
        self.worker_thread.start()
        self.queue.put(
            Message(op=Op.UPDATE, args=(self.env, self.policy), info=None))
        self.queue.task_done()
        atexit.register(self.shutdown)

    def update_plot(self, policy, max_length=np.inf):
        """Queue a rendered rollout using *policy*'s current parameter values.

        :param policy: Policy whose parameters are snapshotted and sent to
            the worker.
        :param max_length: Maximum path length for the rendered rollout.
        """
        if not self.worker_thread.is_alive():
            return
        self.queue.put(
            Message(
                op=Op.DEMO,
                args=(policy.get_param_values(), max_length),
                info=None))
        self.queue.task_done()