From 653ab9aa8ac06488adea3de8e439e63927fda144 Mon Sep 17 00:00:00 2001 From: Yonathan Efroni Date: Fri, 2 Feb 2024 12:03:12 -0800 Subject: [PATCH] Update title and font in benchmark plots Summary: - Updated the title and font size in the plot generation of benchmarks. - Added config files that are used by the safety reward constraint module. - Removed alpha=0.8 in the reward constraint module (it gives nearly the same results as the baseline). Reviewed By: yiwan-rl Differential Revision: D53355512 fbshipit-source-id: 8575677a2e1ee328e20f63e3cc95d8939ab1e6ab --- pearl/utils/scripts/benchmark.py | 28 +++++++++++++++++++++---- pearl/utils/scripts/benchmark_config.py | 24 ++++++++++----------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/pearl/utils/scripts/benchmark.py b/pearl/utils/scripts/benchmark.py index 99856259..e46f361a 100644 --- a/pearl/utils/scripts/benchmark.py +++ b/pearl/utils/scripts/benchmark.py @@ -35,11 +35,28 @@ benchmark_hopper_v4, # noqa benchmark_pendulum_v1_lstm, benchmark_walker2d_v4, # noqa - get_env, - test_dynamic_action_space, + get_env, # noqa + rccsac_ant, # noqa + rccsac_half_cheetah, # noqa + rccsac_hopper, # noqa + rccsac_walker, # noqa + rcddpg_ant, # noqa + rcddpg_half_cheetah, # noqa + rcddpg_hopper, # noqa + rcddpg_walker, # noqa + rctd3_ant, # noqa + rctd3_half_cheetah, # noqa + rctd3_hopper, # noqa + rctd3_walker, # noqa + test_dynamic_action_space, # noqa ) warnings.filterwarnings("ignore") +attr_to_title = { + "return": "return", + "return_cost": "cummulative cost", + "risk_sa": "risk_sa", +} def run(experiments) -> None: @@ -230,6 +247,7 @@ def generate_plots(experiments, attributes) -> None: def generate_one_plot(experiment, attributes): """Generating learning curves for all tested methods in one environment.""" + plt.rcParams.update({"font.size": 15}) env_name = experiment["env_name"] exp_name = experiment["exp_name"] num_runs = experiment["num_runs"] @@ -262,12 +280,14 @@ def generate_one_plot(experiment, attributes): mean 
+ std_error, alpha=0.2, ) - plt.title(env_name) + plt.title(env_name.replace("_", "-")) + plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0)) + plt.ticklabel_format(style="sci", axis="x", scilimits=(0, 0)) if "num_steps" in experiment: plt.xlabel("Steps") else: plt.xlabel("Episodes") - plt.ylabel(attr) + plt.ylabel(attr_to_title[attr]) plt.legend() plt.savefig(f"outputs/{exp_name}_{env_name}_{attr}.png") plt.close() diff --git a/pearl/utils/scripts/benchmark_config.py b/pearl/utils/scripts/benchmark_config.py index 9e12561c..7a7903d8 100644 --- a/pearl/utils/scripts/benchmark_config.py +++ b/pearl/utils/scripts/benchmark_config.py @@ -995,7 +995,7 @@ RCTD3_method_const_0_1, RCTD3_method_const_0_2, RCTD3_method_const_0_4, - RCTD3_method_const_0_8, + # RCTD3_method_const_0_8, TD3_method, ], "device_id": 3, @@ -1018,7 +1018,7 @@ RCDDPG_method_const_0_1, RCDDPG_method_const_0_2, RCDDPG_method_const_0_4, - RCDDPG_method_const_0_8, + # RCDDPG_method_const_0_8, DDPG_method, ], "device_id": 3, @@ -1041,7 +1041,7 @@ RCCSAC_method_const_0_1, RCCSAC_method_const_0_2, RCCSAC_method_const_0_4, - RCCSAC_method_const_0_8, + # RCCSAC_method_const_0_8, CSAC_method, ], "device_id": 3, @@ -1065,7 +1065,7 @@ RCDDPG_method_const_0_1, RCDDPG_method_const_0_2, RCDDPG_method_const_0_4, - RCDDPG_method_const_0_8, + # RCDDPG_method_const_0_8, DDPG_method, ], "device_id": 2, @@ -1088,7 +1088,7 @@ RCTD3_method_const_0_1, RCTD3_method_const_0_2, RCTD3_method_const_0_4, - RCTD3_method_const_0_8, + # RCTD3_method_const_0_8, TD3_method, ], "device_id": 2, @@ -1111,7 +1111,7 @@ RCCSAC_method_const_0_1, RCCSAC_method_const_0_2, RCCSAC_method_const_0_4, - RCCSAC_method_const_0_8, + # RCCSAC_method_const_0_8, CSAC_method, ], "device_id": 2, @@ -1135,7 +1135,7 @@ RCTD3_method_const_0_1, RCTD3_method_const_0_2, RCTD3_method_const_0_4, - RCTD3_method_const_0_8, + # RCTD3_method_const_0_8, TD3_method, ], "device_id": 0, @@ -1158,7 +1158,7 @@ RCDDPG_method_const_0_1, RCDDPG_method_const_0_2, 
RCDDPG_method_const_0_4, - RCDDPG_method_const_0_8, + # RCDDPG_method_const_0_8, DDPG_method, ], "device_id": 0, @@ -1181,7 +1181,7 @@ RCCSAC_method_const_0_1, RCCSAC_method_const_0_2, RCCSAC_method_const_0_4, - RCCSAC_method_const_0_8, + # RCCSAC_method_const_0_8, CSAC_method, ], "device_id": 0, @@ -1204,7 +1204,7 @@ RCTD3_method_const_0_1, RCTD3_method_const_0_2, RCTD3_method_const_0_4, - RCTD3_method_const_0_8, + # RCTD3_method_const_0_8, TD3_method, ], "device_id": 1, @@ -1228,7 +1228,7 @@ RCDDPG_method_const_0_1, RCDDPG_method_const_0_2, RCDDPG_method_const_0_4, - RCDDPG_method_const_0_8, + # RCDDPG_method_const_0_8, DDPG_method, ], "device_id": 1, @@ -1251,7 +1251,7 @@ RCCSAC_method_const_0_1, RCCSAC_method_const_0_2, RCCSAC_method_const_0_4, - RCCSAC_method_const_0_8, + # RCCSAC_method_const_0_8, CSAC_method, ], "device_id": 1,