Skip to content

Commit

Permalink
Update title and font in benchmarkplots
Browse files Browse the repository at this point in the history
Summary:
-updated title and font size in the plot generation of benchmarks.
-added config files that are used by the safety reward constraint module
-remove alpha=0.8 in reward constraint module (gives almost similar results as the baseline)

Reviewed By: yiwan-rl

Differential Revision: D53355512

fbshipit-source-id: 8575677a2e1ee328e20f63e3cc95d8939ab1e6ab
  • Loading branch information
Yonathan Efroni authored and facebook-github-bot committed Feb 2, 2024
1 parent a8d96bb commit 653ab9a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
28 changes: 24 additions & 4 deletions pearl/utils/scripts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,28 @@
benchmark_hopper_v4, # noqa
benchmark_pendulum_v1_lstm,
benchmark_walker2d_v4, # noqa
get_env,
test_dynamic_action_space,
get_env, # noqa
rccsac_ant, # noqa
rccsac_half_cheetah, # noqa
rccsac_hopper, # noqa
rccsac_walker, # noqa
rcddpg_ant, # noqa
rcddpg_half_cheetah, # noqa
rcddpg_hopper, # noqa
rcddpg_walker, # noqa
rctd3_ant, # noqa
rctd3_half_cheetah, # noqa
rctd3_hopper, # noqa
rctd3_walker, # noqa
test_dynamic_action_space, # noqa
)

warnings.filterwarnings("ignore")
attr_to_title = {
"return": "return",
"return_cost": "cummulative cost",
"risk_sa": "risk_sa",
}


def run(experiments) -> None:
Expand Down Expand Up @@ -230,6 +247,7 @@ def generate_plots(experiments, attributes) -> None:

def generate_one_plot(experiment, attributes):
"""Generating learning curves for all tested methods in one environment."""
plt.rcParams.update({"font.size": 15})
env_name = experiment["env_name"]
exp_name = experiment["exp_name"]
num_runs = experiment["num_runs"]
Expand Down Expand Up @@ -262,12 +280,14 @@ def generate_one_plot(experiment, attributes):
mean + std_error,
alpha=0.2,
)
plt.title(env_name)
plt.title(env_name.replace("_", "-"))
plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
plt.ticklabel_format(style="sci", axis="x", scilimits=(0, 0))
if "num_steps" in experiment:
plt.xlabel("Steps")
else:
plt.xlabel("Episodes")
plt.ylabel(attr)
plt.ylabel(attr_to_title[attr])
plt.legend()
plt.savefig(f"outputs/{exp_name}_{env_name}_{attr}.png")
plt.close()
Expand Down
24 changes: 12 additions & 12 deletions pearl/utils/scripts/benchmark_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,7 @@
RCTD3_method_const_0_1,
RCTD3_method_const_0_2,
RCTD3_method_const_0_4,
RCTD3_method_const_0_8,
# RCTD3_method_const_0_8,
TD3_method,
],
"device_id": 3,
Expand All @@ -1018,7 +1018,7 @@
RCDDPG_method_const_0_1,
RCDDPG_method_const_0_2,
RCDDPG_method_const_0_4,
RCDDPG_method_const_0_8,
# RCDDPG_method_const_0_8,
DDPG_method,
],
"device_id": 3,
Expand All @@ -1041,7 +1041,7 @@
RCCSAC_method_const_0_1,
RCCSAC_method_const_0_2,
RCCSAC_method_const_0_4,
RCCSAC_method_const_0_8,
# RCCSAC_method_const_0_8,
CSAC_method,
],
"device_id": 3,
Expand All @@ -1065,7 +1065,7 @@
RCDDPG_method_const_0_1,
RCDDPG_method_const_0_2,
RCDDPG_method_const_0_4,
RCDDPG_method_const_0_8,
# RCDDPG_method_const_0_8,
DDPG_method,
],
"device_id": 2,
Expand All @@ -1088,7 +1088,7 @@
RCTD3_method_const_0_1,
RCTD3_method_const_0_2,
RCTD3_method_const_0_4,
RCTD3_method_const_0_8,
# RCTD3_method_const_0_8,
TD3_method,
],
"device_id": 2,
Expand All @@ -1111,7 +1111,7 @@
RCCSAC_method_const_0_1,
RCCSAC_method_const_0_2,
RCCSAC_method_const_0_4,
RCCSAC_method_const_0_8,
# RCCSAC_method_const_0_8,
CSAC_method,
],
"device_id": 2,
Expand All @@ -1135,7 +1135,7 @@
RCTD3_method_const_0_1,
RCTD3_method_const_0_2,
RCTD3_method_const_0_4,
RCTD3_method_const_0_8,
# RCTD3_method_const_0_8,
TD3_method,
],
"device_id": 0,
Expand All @@ -1158,7 +1158,7 @@
RCDDPG_method_const_0_1,
RCDDPG_method_const_0_2,
RCDDPG_method_const_0_4,
RCDDPG_method_const_0_8,
# RCDDPG_method_const_0_8,
DDPG_method,
],
"device_id": 0,
Expand All @@ -1181,7 +1181,7 @@
RCCSAC_method_const_0_1,
RCCSAC_method_const_0_2,
RCCSAC_method_const_0_4,
RCCSAC_method_const_0_8,
# RCCSAC_method_const_0_8,
CSAC_method,
],
"device_id": 0,
Expand All @@ -1204,7 +1204,7 @@
RCTD3_method_const_0_1,
RCTD3_method_const_0_2,
RCTD3_method_const_0_4,
RCTD3_method_const_0_8,
# RCTD3_method_const_0_8,
TD3_method,
],
"device_id": 1,
Expand All @@ -1228,7 +1228,7 @@
RCDDPG_method_const_0_1,
RCDDPG_method_const_0_2,
RCDDPG_method_const_0_4,
RCDDPG_method_const_0_8,
# RCDDPG_method_const_0_8,
DDPG_method,
],
"device_id": 1,
Expand All @@ -1251,7 +1251,7 @@
RCCSAC_method_const_0_1,
RCCSAC_method_const_0_2,
RCCSAC_method_const_0_4,
RCCSAC_method_const_0_8,
# RCCSAC_method_const_0_8,
CSAC_method,
],
"device_id": 1,
Expand Down

0 comments on commit 653ab9a

Please sign in to comment.