Skip to content

Commit

Permalink
Asynchronous plotting for Theano (#124)
Browse files Browse the repository at this point in the history
Added support for asynchronous plotting for Theano. The main caveat is that Linux machines need to use multiprocessing.Process and Mac OS X machines need to use threading.Thread. If Linux machines use Threads, the program slows to a crawl (like as a result of Python's Global Interpreter Lock GIL); conversely, if Mac machines use Process, the glfw will throw a segmentation fault and fail to draw the window.
  • Loading branch information
jonashen authored Jun 7, 2018
1 parent 5c5d3a2 commit 00ec862
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
9 changes: 3 additions & 6 deletions rllab/algos/batch_polopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def __init__(self,

def start_worker(self):
self.sampler.start_worker()
if self.plot:
plotter.init_plot(self.env, self.policy)

def shutdown_worker(self):
self.sampler.shutdown_worker()
Expand All @@ -130,12 +132,7 @@ def train(self):
logger.log("saved")
logger.dump_tabular(with_prefix=False)
if self.plot:
rollout(
self.env,
self.policy,
max_path_length=self.max_path_length,
animated=True,
speedup=5)
self.update_plot()
if self.pause_for_plot:
input("Plotting evaluation run: Press Enter to "
"continue...")
Expand Down
19 changes: 16 additions & 3 deletions rllab/plotter/plotter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import atexit
from queue import Empty
from multiprocessing import Process, Queue
from rllab.sampler.utils import rollout
import numpy as np
import platform
from threading import Thread
from queue import Empty

from rllab.sampler.utils import rollout

__all__ = ['init_worker', 'init_plot', 'update_plot']

Expand Down Expand Up @@ -60,12 +63,22 @@ def _shutdown_worker():
def init_worker():
global process, queue
queue = Queue()
process = Process(target=_worker_start)
process = Thread(target=_worker_start) if (
'Darwin' in platform.platform()) else Process(target=_worker_start)
process.daemon = True
process.start()
atexit.register(_shutdown_worker)


def init_plot(env, policy):
global process, queue
if not (process and queue):
init_worker()

# Needed in order to draw glfw window on the main thread
if ('Darwin' in platform.platform()):
rollout(env, policy, max_path_length=np.inf, animated=True, speedup=5)

queue.put(['update', env, policy])


Expand Down

0 comments on commit 00ec862

Please sign in to comment.