Skip to content

Commit

Permalink
Merge branch 'integration' into tf_plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
jonashen authored Jun 7, 2018
2 parents faa064e + 00ec862 commit 89d5d08
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
9 changes: 3 additions & 6 deletions rllab/algos/batch_polopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def __init__(self,

def start_worker(self):
self.sampler.start_worker()
if self.plot:
plotter.init_plot(self.env, self.policy)

def shutdown_worker(self):
self.sampler.shutdown_worker()
Expand All @@ -130,12 +132,7 @@ def train(self):
logger.log("saved")
logger.dump_tabular(with_prefix=False)
if self.plot:
rollout(
self.env,
self.policy,
max_path_length=self.max_path_length,
animated=True,
speedup=5)
self.update_plot()
if self.pause_for_plot:
input("Plotting evaluation run: Press Enter to "
"continue...")
Expand Down
19 changes: 16 additions & 3 deletions rllab/plotter/plotter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import atexit
from queue import Empty
from multiprocessing import Process, Queue
from rllab.sampler.utils import rollout
import numpy as np
import platform
from threading import Thread
from queue import Empty

from rllab.sampler.utils import rollout

__all__ = ['init_worker', 'init_plot', 'update_plot']

Expand Down Expand Up @@ -60,12 +63,22 @@ def _shutdown_worker():
def init_worker():
global process, queue
queue = Queue()
process = Process(target=_worker_start)
process = Thread(target=_worker_start) if (
'Darwin' in platform.platform()) else Process(target=_worker_start)
process.daemon = True
process.start()
atexit.register(_shutdown_worker)


def init_plot(env, policy):
global process, queue
if not (process and queue):
init_worker()

# Needed in order to draw glfw window on the main thread
if ('Darwin' in platform.platform()):
rollout(env, policy, max_path_length=np.inf, animated=True, speedup=5)

queue.put(['update', env, policy])


Expand Down

0 comments on commit 89d5d08

Please sign in to comment.