-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added asynchronous plotting to TensorFlow #126
base: integration
Are you sure you want to change the base?
Changes from 6 commits
c07843e
edfc757
c7bc9d6
5483eef
3c26089
e0ca0b8
47a9604
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,4 +36,3 @@ blackbox.zip | |
blackbox | ||
rllab/config_personal.py | ||
*.swp | ||
sandbox |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from sandbox.rocky.tf.plotter.plotter import Plotter |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import atexit | ||
from multiprocessing import Process | ||
import numpy as np | ||
import platform | ||
from queue import Empty, Queue | ||
import tensorflow as tf | ||
from threading import Thread | ||
|
||
from rllab.sampler.utils import rollout | ||
|
||
__all__ = ["Plotter"] | ||
|
||
|
||
class Plotter(object): | ||
def __init__(self, env, policy, sess=None, graph=None): | ||
self.env = env | ||
self.policy = policy | ||
self.sess = tf.get_default_session() if sess is None else sess | ||
self.graph = tf.get_default_graph() if graph is None else graph | ||
self.worker_thread = Thread(target=self._start_worker, daemon=True) | ||
self.queue = Queue() | ||
|
||
def _start_worker(self): | ||
env = None | ||
policy = None | ||
max_length = None | ||
initial_rollout = True | ||
try: | ||
with self.sess.as_default(), self.sess.graph.as_default(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a comment here explaining that the worker processes all messages in the queue per loop, not one message per loop? |
||
while True: | ||
msgs = {} | ||
# If true, block and yield processor | ||
if initial_rollout: | ||
msg = self.queue.get() | ||
msgs[msg[0]] = msg[1:] | ||
# Only fetch the last message of each type | ||
while not self.queue.empty(): | ||
msg = self.queue.get() | ||
msgs[msg[0]] = msg[1:] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. from collections import namedtuple
import enum
from enum import Enum
class Op(Enum):
STOP = enum.auto()
UPDATE = enum.auto()
DEMO = enum.auto()
Message = namedtuple("Message", ["op", "args", "kwargs"])
class Plotter:
def _start_worker(self):
while True:
if initial_rollout:
msg = self.queue.get()
msgs[msg.op] = msg
if Op.STOP in msgs:
break
elif Op.DEMO in msgs:
env, policy = msgs[Op.DEMO].args
def update_plot(self, policy, max_length=np.inf):
if self.worker_thread.is_alive():
self.queue.put(Message(op=Op.DEMO, args=(policy.get_param_values(), max_length))
self.queue.task_done() |
||
else: | ||
# Only fetch the last message of each type | ||
while not self.queue.empty(): | ||
msg = self.queue.get_nowait() | ||
msgs[msg[0]] = msg[1:] | ||
|
||
if 'stop' in msgs: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's replace the strings with enums. |
||
break | ||
elif 'update' in msgs: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't this be |
||
env, policy = msgs['update'] | ||
elif 'demo' in msgs: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't this be |
||
param_values, max_length = msgs['demo'] | ||
policy.set_param_values(param_values) | ||
initial_rollout = False | ||
rollout( | ||
env, | ||
policy, | ||
max_path_length=max_length, | ||
animated=True, | ||
speedup=5) | ||
else: | ||
if max_length: | ||
rollout( | ||
env, | ||
policy, | ||
max_path_length=max_length, | ||
animated=True, | ||
speedup=5) | ||
except KeyboardInterrupt: | ||
pass | ||
|
||
def shutdown(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check that the worker thread dies correctly at the end of the simulation and when user interrupts the simulation with keyboard interruption. |
||
if self.worker_thread.is_alive(): | ||
self.queue.put(['stop']) | ||
self.queue.task_done() | ||
self.queue.join() | ||
self.worker_thread.join() | ||
|
||
def start(self): | ||
if not self.worker_thread.is_alive(): | ||
tf.get_variable_scope().reuse_variables() | ||
self.worker_thread.start() | ||
self.queue.put(['update', self.env, self.policy]) | ||
self.queue.task_done() | ||
atexit.register(self.shutdown) | ||
|
||
def update_plot(self, policy, max_length=np.inf): | ||
if self.worker_thread.is_alive(): | ||
self.queue.put(['demo', policy.get_param_values(), max_length]) | ||
self.queue.task_done() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PEP8: import grouping