From ae6ebdc30a3a9857002cb4a35c5f91b74416d2cc Mon Sep 17 00:00:00 2001
From: kaustubh
Date: Mon, 11 Mar 2024 17:57:03 -0400
Subject: [PATCH] adding termination condition

---
 cleanrl/ppo_rnd_envpool.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/cleanrl/ppo_rnd_envpool.py b/cleanrl/ppo_rnd_envpool.py
index daa23960..b6840c43 100644
--- a/cleanrl/ppo_rnd_envpool.py
+++ b/cleanrl/ppo_rnd_envpool.py
@@ -37,6 +37,22 @@ def parse_args():
         help="the wandb's project name")
     parser.add_argument("--wandb-entity", type=str, default=None,
         help="the entity (team) of wandb's project")
+    parser.add_argument("--reward-goal", type=float, default=1.0,
+        help="reward to give when the goal is achieved")
+    parser.add_argument("--reward-distance", type=float, default=1.0,
+        help="dense reward for decreasing the distance to the goal")
+    parser.add_argument("--early-termination", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="whether to terminate early, i.e. when a catastrophe has happened")
+    parser.add_argument("--unifying-lidar", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="whether to use the same lidar sensor configuration for every environment")
+    parser.add_argument("--term-cost", type=int, default=1,
+        help="the number of constraint violations allowed before terminating")
+    parser.add_argument("--failure-penalty", type=float, default=0.0,
+        help="the reward penalty applied on failure")
+    parser.add_argument("--collect-data", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to store data while training")
+    parser.add_argument("--storage-path", type=str, default="./data/ppo/term_1",
+        help="the storage path for the collected data")

     # Algorithm specific arguments
     parser.add_argument("--env-id", type=str, default="MontezumaRevenge-v5",
@@ -140,9 +156,9 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
 def make_env(cfg, idx, capture_video, run_name, gamma):
     def thunk():
         if capture_video:
-            env = gym.make(cfg.env_id)#, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
+            env = gym.make(cfg.env_id, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
         else:
-            env = gym.make(cfg.env_id)#, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
+            env = gym.make(cfg.env_id, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
         env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
         env = gym.wrappers.RecordEpisodeStatistics(env)
         if capture_video:
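
Note (not part of the patch): the new term_cost, failure_penalty, reward_goal, and reward_distance options only take effect if the registered environment actually accepts them as gym.make keyword arguments. As a hedged illustration of the intended termination behaviour, the sketch below shows one way the same condition could be enforced from outside the environment instead. The wrapper name is hypothetical, and it assumes a Safety-Gym-style convention where info["cost"] > 0 marks a constraint violation and the Gymnasium 5-tuple step API is in use; neither assumption comes from this patch.

    import gymnasium as gym


    class CostTerminationWrapper(gym.Wrapper):
        """Hypothetical sketch: end the episode after `term_cost` violations.

        Assumes the wrapped environment reports constraint violations via
        info["cost"] (Safety-Gym style) and the Gymnasium 5-tuple step API.
        """

        def __init__(self, env, term_cost=1, failure_penalty=0.0):
            super().__init__(env)
            self.term_cost = term_cost              # violations tolerated before terminating
            self.failure_penalty = failure_penalty  # reward penalty applied at termination
            self.violations = 0

        def reset(self, **kwargs):
            self.violations = 0
            return self.env.reset(**kwargs)

        def step(self, action):
            obs, reward, terminated, truncated, info = self.env.step(action)
            if info.get("cost", 0.0) > 0:
                self.violations += 1
            if self.violations >= self.term_cost:   # termination condition reached
                reward -= self.failure_penalty
                terminated = True
            return obs, reward, terminated, truncated, info

If taken this route, the wrapper would be applied inside make_env right after gym.make, e.g. env = CostTerminationWrapper(env, cfg.term_cost, cfg.failure_penalty), leaving the rest of the thunk unchanged.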