From 77ef3e0787e6fb2ea22777ed9dc61fd54ffaa41a Mon Sep 17 00:00:00 2001 From: cjcchen Date: Fri, 30 Mar 2018 20:31:11 -0700 Subject: [PATCH] Add TensorBoard Support Adds TensorBoard support for basic key-value pairs. Anything logged via `logger.record_tabular()` is also available via TensorBoard. --- environment.yml | 1 + rllab/misc/logger.py | 69 ++++++++++++++++-- scripts/run_experiment_lite.py | 127 ++++++++++++++++++++++++--------- 3 files changed, 155 insertions(+), 42 deletions(-) diff --git a/environment.yml b/environment.yml index 2623436e9..b0abb464e 100644 --- a/environment.yml +++ b/environment.yml @@ -62,3 +62,4 @@ dependencies: - pylru==1.0.9 - hyperopt - polling + - tensorboard diff --git a/rllab/misc/logger.py b/rllab/misc/logger.py index 5f25095aa..5a039c643 100644 --- a/rllab/misc/logger.py +++ b/rllab/misc/logger.py @@ -15,6 +15,7 @@ import json import pickle import base64 +import tensorflow as tf _prefixes = [] _prefix_str = '' @@ -31,6 +32,7 @@ _tabular_fds = {} _tabular_header_written = set() +_tensorboard_writer = None _snapshot_dir = None _snapshot_mode = 'all' _snapshot_gap = 1 @@ -38,6 +40,9 @@ _log_tabular_only = False _header_printed = False +_tensorboard_default_step = 0 +_tensorboard_step_key = None + def _add_output(file_name, arr, fds, mode='a'): if file_name not in arr: @@ -77,6 +82,20 @@ def remove_tabular_output(file_name): _remove_output(file_name, _tabular_outputs, _tabular_fds) +def set_tensorboard_dir(dir_name): + global _tensorboard_writer + if not dir_name: + if _tensorboard_writer: + _tensorboard_writer.close() + _tensorboard_writer = None + else: + mkdir_p(os.path.dirname(dir_name)) + _tensorboard_writer = tf.summary.FileWriter(dir_name) + _tensorboard_default_step = 0 + assert _tensorboard_writer is not None + print("tensorboard data will be logged into:", dir_name) + + def set_snapshot_dir(dir_name): global _snapshot_dir _snapshot_dir = dir_name @@ -94,18 +113,26 @@ def set_snapshot_mode(mode): global _snapshot_mode _snapshot_mode = mode + def get_snapshot_gap(): return _snapshot_gap + def set_snapshot_gap(gap): global _snapshot_gap _snapshot_gap = gap + def set_log_tabular_only(log_tabular_only): global _log_tabular_only _log_tabular_only = log_tabular_only +def set_tensorboard_step_key(key): + global _tensorboard_step_key + _tensorboard_step_key = key + + def get_log_tabular_only(): return _log_tabular_only @@ -186,6 +213,23 @@ def refresh(self): table_printer = TerminalTablePrinter() +def dump_tensorboard(*args, **kwargs): + if len(_tabular) > 0 and _tensorboard_writer: + tabular_dict = dict(_tabular) + if _tensorboard_step_key and _tensorboard_step_key in tabular_dict: + step = tabular_dict[_tensorboard_step_key] + else: + global _tensorboard_default_step + step = _tensorboard_default_step + _tensorboard_default_step += 1 + + summary = tf.Summary() + for k, v in tabular_dict.items(): + summary.value.add(tag=k, simple_value=float(v)) + _tensorboard_writer.add_summary(summary, int(step)) + _tensorboard_writer.flush() + + def dump_tabular(*args, **kwargs): wh = kwargs.pop("write_header", None) if len(_tabular) > 0: @@ -195,11 +239,18 @@ def dump_tabular(*args, **kwargs): for line in tabulate(_tabular).split('\n'): log(line, *args, **kwargs) tabular_dict = dict(_tabular) + + # write to the tensorboard folder + # This assumes that the keys in each iteration won't change! + dump_tensorboard(args, kwargs) + # Also write to the csv files # This assumes that the keys in each iteration won't change! for tabular_fd in list(_tabular_fds.values()): - writer = csv.DictWriter(tabular_fd, fieldnames=list(tabular_dict.keys())) - if wh or (wh is None and tabular_fd not in _tabular_header_written): + writer = csv.DictWriter( + tabular_fd, fieldnames=list(tabular_dict.keys())) + if wh or (wh is None + and tabular_fd not in _tabular_header_written): writer.writeheader() _tabular_header_written.add(tabular_fd) writer.writerow(tabular_dict) @@ -245,7 +296,8 @@ def log_parameters(log_file, args, classes): log_params[name] = params else: log_params[name] = getattr(cls, "__kwargs", dict()) - log_params[name]["_name"] = cls.__module__ + "." + cls.__class__.__name__ + log_params[name][ + "_name"] = cls.__module__ + "." + cls.__class__.__name__ mkdir_p(os.path.dirname(log_file)) with open(log_file, "w") as f: json.dump(log_params, f, indent=2, sort_keys=True) @@ -258,13 +310,13 @@ def stub_to_json(stub_sth): data = dict() for k, v in stub_sth.kwargs.items(): data[k] = stub_to_json(v) - data["_name"] = stub_sth.proxy_class.__module__ + "." + stub_sth.proxy_class.__name__ + data[ + "_name"] = stub_sth.proxy_class.__module__ + "." + stub_sth.proxy_class.__name__ return data elif isinstance(stub_sth, instrument.StubAttr): return dict( obj=stub_to_json(stub_sth.obj), - attr=stub_to_json(stub_sth.attr_name) - ) + attr=stub_to_json(stub_sth.attr_name)) elif isinstance(stub_sth, instrument.StubMethodCall): return dict( obj=stub_to_json(stub_sth.obj), @@ -294,7 +346,10 @@ def default(self, o): if isinstance(o, type): return {'$class': o.__module__ + "." + o.__name__} elif isinstance(o, Enum): - return {'$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name} + return { + '$enum': + o.__module__ + "." + o.__class__.__name__ + '.' + o.name + } return json.JSONEncoder.default(self, o) diff --git a/scripts/run_experiment_lite.py b/scripts/run_experiment_lite.py index 528eb1d2e..386ddf6d3 100644 --- a/scripts/run_experiment_lite.py +++ b/scripts/run_experiment_lite.py @@ -29,41 +29,95 @@ def run_experiment(argv): default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id) parser = argparse.ArgumentParser() - parser.add_argument('--n_parallel', type=int, default=1, - help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers') - parser.add_argument( - '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.') - parser.add_argument('--log_dir', type=str, default=None, - help='Path to save the log and iteration snapshot.') - parser.add_argument('--snapshot_mode', type=str, default='all', - help='Mode to save the snapshot. Can be either "all" ' - '(all iterations will be saved), "last" (only ' - 'the last iteration will be saved), "gap" (every' - '`snapshot_gap` iterations are saved), or "none" ' - '(do not save snapshots)') - parser.add_argument('--snapshot_gap', type=int, default=1, - help='Gap between snapshot iterations.') - parser.add_argument('--tabular_log_file', type=str, default='progress.csv', - help='Name of the tabular log file (in csv).') - parser.add_argument('--text_log_file', type=str, default='debug.log', - help='Name of the text log file (in pure text).') - parser.add_argument('--params_log_file', type=str, default='params.json', - help='Name of the parameter log file (in json).') - parser.add_argument('--variant_log_file', type=str, default='variant.json', - help='Name of the variant log file (in json).') - parser.add_argument('--resume_from', type=str, default=None, - help='Name of the pickle file to resume experiment from.') - parser.add_argument('--plot', type=ast.literal_eval, default=False, - help='Whether to plot the iteration results') - parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False, - help='Whether to only print the tabular log information (in a horizontal format)') - parser.add_argument('--seed', type=int, - help='Random seed for numpy') - parser.add_argument('--args_data', type=str, - help='Pickled data for stub objects') - parser.add_argument('--variant_data', type=str, - help='Pickled data for variant configuration') - parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False) + parser.add_argument( + '--n_parallel', + type=int, + default=1, + help= + 'Number of parallel workers to perform rollouts. 0 => don\'t start any workers' + ) + parser.add_argument( + '--exp_name', + type=str, + default=default_exp_name, + help='Name of the experiment.') + parser.add_argument( + '--log_dir', + type=str, + default=None, + help='Path to save the log and iteration snapshot.') + parser.add_argument( + '--snapshot_mode', + type=str, + default='all', + help='Mode to save the snapshot. Can be either "all" ' + '(all iterations will be saved), "last" (only ' + 'the last iteration will be saved), "gap" (every' + '`snapshot_gap` iterations are saved), or "none" ' + '(do not save snapshots)') + parser.add_argument( + '--snapshot_gap', + type=int, + default=1, + help='Gap between snapshot iterations.') + parser.add_argument( + '--tabular_log_file', + type=str, + default='progress.csv', + help='Name of the tabular log file (in csv).') + parser.add_argument( + '--text_log_file', + type=str, + default='debug.log', + help='Name of the text log file (in pure text).') + parser.add_argument( + '--tensorboard_log_dir', + type=str, + default='progress', + help='Name of the folder for tensorboard_summary.') + parser.add_argument( + '--tensorboard_step_key', + type=str, + default=None, + help= + 'Name of the step key in log data which shows the step in tensorboard_summary.' + ) + parser.add_argument( + '--params_log_file', + type=str, + default='params.json', + help='Name of the parameter log file (in json).') + parser.add_argument( + '--variant_log_file', + type=str, + default='variant.json', + help='Name of the variant log file (in json).') + parser.add_argument( + '--resume_from', + type=str, + default=None, + help='Name of the pickle file to resume experiment from.') + parser.add_argument( + '--plot', + type=ast.literal_eval, + default=False, + help='Whether to plot the iteration results') + parser.add_argument( + '--log_tabular_only', + type=ast.literal_eval, + default=False, + help= + 'Whether to only print the tabular log information (in a horizontal format)' + ) + parser.add_argument('--seed', type=int, help='Random seed for numpy') + parser.add_argument( + '--args_data', type=str, help='Pickled data for stub objects') + parser.add_argument( + '--variant_data', + type=str, + help='Pickled data for variant configuration') + parser.add_argument( + '--use_cloudpickle', type=ast.literal_eval, default=False) args = parser.parse_args(argv[1:]) @@ -87,6 +141,7 @@ def run_experiment(argv): tabular_log_file = osp.join(log_dir, args.tabular_log_file) text_log_file = osp.join(log_dir, args.text_log_file) params_log_file = osp.join(log_dir, args.params_log_file) + tensorboard_log_dir = osp.join(log_dir, args.tensorboard_log_dir) if args.variant_data is not None: variant_data = pickle.loads(base64.b64decode(args.variant_data)) @@ -100,12 +155,14 @@ def run_experiment(argv): logger.add_text_output(text_log_file) logger.add_tabular_output(tabular_log_file) + logger.set_tensorboard_dir(tensorboard_log_dir) prev_snapshot_dir = logger.get_snapshot_dir() prev_mode = logger.get_snapshot_mode() logger.set_snapshot_dir(log_dir) logger.set_snapshot_mode(args.snapshot_mode) logger.set_snapshot_gap(args.snapshot_gap) logger.set_log_tabular_only(args.log_tabular_only) + logger.set_tensorboard_step_key(args.tensorboard_step_key) logger.push_prefix("[%s] " % args.exp_name) if args.resume_from is not None: