Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TensorBoard Support #226

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,4 @@ dependencies:
- pylru==1.0.9
- hyperopt
- polling
- tensorboard
69 changes: 62 additions & 7 deletions rllab/misc/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import pickle
import base64
import tensorflow as tf

_prefixes = []
_prefix_str = ''
Expand All @@ -31,13 +32,17 @@
_tabular_fds = {}
_tabular_header_written = set()

_tensorboard_writer = None
_snapshot_dir = None
_snapshot_mode = 'all'
_snapshot_gap = 1

_log_tabular_only = False
_header_printed = False

_tensorboard_default_step = 0
_tensorboard_step_key = None


def _add_output(file_name, arr, fds, mode='a'):
if file_name not in arr:
Expand Down Expand Up @@ -77,6 +82,20 @@ def remove_tabular_output(file_name):
_remove_output(file_name, _tabular_outputs, _tabular_fds)


def set_tensorboard_dir(dir_name):
    """Point tensorboard logging at *dir_name*, or disable it.

    A falsy *dir_name* (None / empty string) closes any open writer and
    turns tensorboard logging off.  Otherwise a ``tf.summary.FileWriter``
    is opened on the directory and the fallback step counter is reset.

    Args:
        dir_name: Directory for tensorboard event files, or a falsy value
            to disable tensorboard output.
    """
    # Bug fix: the original only declared _tensorboard_writer global, so
    # `_tensorboard_default_step = 0` below created a function local and the
    # module-level counter was never actually reset.
    global _tensorboard_writer, _tensorboard_default_step
    if not dir_name:
        if _tensorboard_writer:
            _tensorboard_writer.close()
            _tensorboard_writer = None
    else:
        # NOTE(review): this creates the *parent* of dir_name, not dir_name
        # itself; FileWriter creates its own directory, so behavior is kept.
        mkdir_p(os.path.dirname(dir_name))
        _tensorboard_writer = tf.summary.FileWriter(dir_name)
        _tensorboard_default_step = 0
        assert _tensorboard_writer is not None
        print("tensorboard data will be logged into:", dir_name)


def set_snapshot_dir(dir_name):
    """Record the directory into which iteration snapshots are written."""
    global _snapshot_dir
    _snapshot_dir = dir_name
Expand All @@ -94,18 +113,26 @@ def set_snapshot_mode(mode):
global _snapshot_mode
_snapshot_mode = mode


def get_snapshot_gap():
    """Return the currently configured gap between saved snapshots."""
    return _snapshot_gap


def set_snapshot_gap(gap):
    """Set how many iterations pass between saved snapshots ('gap' mode)."""
    global _snapshot_gap
    _snapshot_gap = gap


def set_log_tabular_only(log_tabular_only):
    """Toggle whether only tabular output (no free-form text) is printed."""
    global _log_tabular_only
    _log_tabular_only = log_tabular_only


def set_tensorboard_step_key(key):
    """Choose which tabular column supplies the tensorboard step number.

    When *key* is None (or absent from the tabular data) an internal
    auto-incrementing counter is used instead.
    """
    global _tensorboard_step_key
    _tensorboard_step_key = key


def get_log_tabular_only():
    """Return True when the logger is restricted to tabular output only."""
    return _log_tabular_only

Expand Down Expand Up @@ -186,6 +213,23 @@ def refresh(self):
table_printer = TerminalTablePrinter()


def dump_tensorboard(*args, **kwargs):
    """Flush the current tabular values to tensorboard as scalar summaries.

    Does nothing unless there is tabular data and a writer was configured
    via ``set_tensorboard_dir``.  The step is taken from the tabular column
    named by ``_tensorboard_step_key`` when present; otherwise an internal
    counter is used and incremented.

    Extra positional/keyword arguments are accepted for call-site
    compatibility with ``dump_tabular`` and are ignored.
    """
    if len(_tabular) > 0 and _tensorboard_writer:
        tabular_dict = dict(_tabular)
        if _tensorboard_step_key and _tensorboard_step_key in tabular_dict:
            step = tabular_dict[_tensorboard_step_key]
        else:
            global _tensorboard_default_step
            step = _tensorboard_default_step
            _tensorboard_default_step += 1

        summary = tf.Summary()
        for k, v in tabular_dict.items():
            # Robustness fix: non-numeric tabular values (e.g. strings)
            # previously raised in float(v) and aborted the entire dump;
            # skip such entries and keep writing the numeric ones.
            try:
                summary.value.add(tag=k, simple_value=float(v))
            except (TypeError, ValueError):
                continue
        _tensorboard_writer.add_summary(summary, int(step))
        _tensorboard_writer.flush()


def dump_tabular(*args, **kwargs):
wh = kwargs.pop("write_header", None)
if len(_tabular) > 0:
Expand All @@ -195,11 +239,18 @@ def dump_tabular(*args, **kwargs):
for line in tabulate(_tabular).split('\n'):
log(line, *args, **kwargs)
tabular_dict = dict(_tabular)

# write to the tensorboard folder
# This assumes that the keys in each iteration won't change!
dump_tensorboard(args, kwargs)

# Also write to the csv files
# This assumes that the keys in each iteration won't change!
for tabular_fd in list(_tabular_fds.values()):
writer = csv.DictWriter(tabular_fd, fieldnames=list(tabular_dict.keys()))
if wh or (wh is None and tabular_fd not in _tabular_header_written):
writer = csv.DictWriter(
tabular_fd, fieldnames=list(tabular_dict.keys()))
if wh or (wh is None
and tabular_fd not in _tabular_header_written):
writer.writeheader()
_tabular_header_written.add(tabular_fd)
writer.writerow(tabular_dict)
Expand Down Expand Up @@ -245,7 +296,8 @@ def log_parameters(log_file, args, classes):
log_params[name] = params
else:
log_params[name] = getattr(cls, "__kwargs", dict())
log_params[name]["_name"] = cls.__module__ + "." + cls.__class__.__name__
log_params[name][
"_name"] = cls.__module__ + "." + cls.__class__.__name__
mkdir_p(os.path.dirname(log_file))
with open(log_file, "w") as f:
json.dump(log_params, f, indent=2, sort_keys=True)
Expand All @@ -258,13 +310,13 @@ def stub_to_json(stub_sth):
data = dict()
for k, v in stub_sth.kwargs.items():
data[k] = stub_to_json(v)
data["_name"] = stub_sth.proxy_class.__module__ + "." + stub_sth.proxy_class.__name__
data[
"_name"] = stub_sth.proxy_class.__module__ + "." + stub_sth.proxy_class.__name__
return data
elif isinstance(stub_sth, instrument.StubAttr):
return dict(
obj=stub_to_json(stub_sth.obj),
attr=stub_to_json(stub_sth.attr_name)
)
attr=stub_to_json(stub_sth.attr_name))
elif isinstance(stub_sth, instrument.StubMethodCall):
return dict(
obj=stub_to_json(stub_sth.obj),
Expand Down Expand Up @@ -294,7 +346,10 @@ def default(self, o):
if isinstance(o, type):
return {'$class': o.__module__ + "." + o.__name__}
elif isinstance(o, Enum):
return {'$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name}
return {
'$enum':
o.__module__ + "." + o.__class__.__name__ + '.' + o.name
}
return json.JSONEncoder.default(self, o)


Expand Down
127 changes: 92 additions & 35 deletions scripts/run_experiment_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,41 +29,95 @@ def run_experiment(argv):

default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
parser = argparse.ArgumentParser()
parser.add_argument('--n_parallel', type=int, default=1,
help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
parser.add_argument(
'--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
parser.add_argument('--log_dir', type=str, default=None,
help='Path to save the log and iteration snapshot.')
parser.add_argument('--snapshot_mode', type=str, default='all',
help='Mode to save the snapshot. Can be either "all" '
'(all iterations will be saved), "last" (only '
'the last iteration will be saved), "gap" (every'
'`snapshot_gap` iterations are saved), or "none" '
'(do not save snapshots)')
parser.add_argument('--snapshot_gap', type=int, default=1,
help='Gap between snapshot iterations.')
parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
help='Name of the tabular log file (in csv).')
parser.add_argument('--text_log_file', type=str, default='debug.log',
help='Name of the text log file (in pure text).')
parser.add_argument('--params_log_file', type=str, default='params.json',
help='Name of the parameter log file (in json).')
parser.add_argument('--variant_log_file', type=str, default='variant.json',
help='Name of the variant log file (in json).')
parser.add_argument('--resume_from', type=str, default=None,
help='Name of the pickle file to resume experiment from.')
parser.add_argument('--plot', type=ast.literal_eval, default=False,
help='Whether to plot the iteration results')
parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
help='Whether to only print the tabular log information (in a horizontal format)')
parser.add_argument('--seed', type=int,
help='Random seed for numpy')
parser.add_argument('--args_data', type=str,
help='Pickled data for stub objects')
parser.add_argument('--variant_data', type=str,
help='Pickled data for variant configuration')
parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)
parser.add_argument(
'--n_parallel',
type=int,
default=1,
help=
'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
)
parser.add_argument(
'--exp_name',
type=str,
default=default_exp_name,
help='Name of the experiment.')
parser.add_argument(
'--log_dir',
type=str,
default=None,
help='Path to save the log and iteration snapshot.')
parser.add_argument(
'--snapshot_mode',
type=str,
default='all',
help='Mode to save the snapshot. Can be either "all" '
'(all iterations will be saved), "last" (only '
'the last iteration will be saved), "gap" (every'
'`snapshot_gap` iterations are saved), or "none" '
'(do not save snapshots)')
parser.add_argument(
'--snapshot_gap',
type=int,
default=1,
help='Gap between snapshot iterations.')
parser.add_argument(
'--tabular_log_file',
type=str,
default='progress.csv',
help='Name of the tabular log file (in csv).')
parser.add_argument(
'--text_log_file',
type=str,
default='debug.log',
help='Name of the text log file (in pure text).')
parser.add_argument(
'--tensorboard_log_dir',
type=str,
default='progress',
help='Name of the folder for tensorboard_summary.')
parser.add_argument(
'--tensorboard_step_key',
type=str,
default=None,
help=
'Name of the step key in log data which shows the step in tensorboard_summary.'
)
parser.add_argument(
'--params_log_file',
type=str,
default='params.json',
help='Name of the parameter log file (in json).')
parser.add_argument(
'--variant_log_file',
type=str,
default='variant.json',
help='Name of the variant log file (in json).')
parser.add_argument(
'--resume_from',
type=str,
default=None,
help='Name of the pickle file to resume experiment from.')
parser.add_argument(
'--plot',
type=ast.literal_eval,
default=False,
help='Whether to plot the iteration results')
parser.add_argument(
'--log_tabular_only',
type=ast.literal_eval,
default=False,
help=
'Whether to only print the tabular log information (in a horizontal format)'
)
parser.add_argument('--seed', type=int, help='Random seed for numpy')
parser.add_argument(
'--args_data', type=str, help='Pickled data for stub objects')
parser.add_argument(
'--variant_data',
type=str,
help='Pickled data for variant configuration')
parser.add_argument(
'--use_cloudpickle', type=ast.literal_eval, default=False)

args = parser.parse_args(argv[1:])

Expand All @@ -87,6 +141,7 @@ def run_experiment(argv):
tabular_log_file = osp.join(log_dir, args.tabular_log_file)
text_log_file = osp.join(log_dir, args.text_log_file)
params_log_file = osp.join(log_dir, args.params_log_file)
tensorboard_log_dir = osp.join(log_dir, args.tensorboard_log_dir)

if args.variant_data is not None:
variant_data = pickle.loads(base64.b64decode(args.variant_data))
Expand All @@ -100,12 +155,14 @@ def run_experiment(argv):

logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
logger.set_tensorboard_dir(tensorboard_log_dir)
prev_snapshot_dir = logger.get_snapshot_dir()
prev_mode = logger.get_snapshot_mode()
logger.set_snapshot_dir(log_dir)
logger.set_snapshot_mode(args.snapshot_mode)
logger.set_snapshot_gap(args.snapshot_gap)
logger.set_log_tabular_only(args.log_tabular_only)
logger.set_tensorboard_step_key(args.tensorboard_step_key)
logger.push_prefix("[%s] " % args.exp_name)

if args.resume_from is not None:
Expand Down