adding termination condition
manila95 committed Mar 11, 2024
1 parent ca5a5e4 commit ae6ebdc
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions cleanrl/ppo_rnd_envpool.py
@@ -37,6 +37,22 @@ def parse_args():
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--reward-goal", type=float, default=1.0,
help="reward to give when the goal is achieved")
parser.add_argument("--reward-distance", type=float, default=1.0,
help="reward to give when the goal is achieved")
parser.add_argument("--early-termination", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="whether to terminate early i.e. when the catastrophe has happened")
parser.add_argument("--unifying-lidar", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="what kind of sensor is used (same for every environment?)")
parser.add_argument("--term-cost", type=int, default=1,
help="how many violations before you terminate")
parser.add_argument("--failure-penalty", type=float, default=0.0,
help="Reward Penalty when you fail")
parser.add_argument("--collect-data", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="store data while trianing")
parser.add_argument("--storage-path", type=str, default="./data/ppo/term_1",
help="the storage path for the data collected")

     # Algorithm specific arguments
     parser.add_argument("--env-id", type=str, default="MontezumaRevenge-v5",
@@ -140,9 +156,9 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
 def make_env(cfg, idx, capture_video, run_name, gamma):
     def thunk():
         if capture_video:
-            env = gym.make(cfg.env_id)#, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
+            env = gym.make(cfg.env_id, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
         else:
-            env = gym.make(cfg.env_id)#, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
+            env = gym.make(cfg.env_id, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
         env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
         env = gym.wrappers.RecordEpisodeStatistics(env)
         if capture_video:
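
For context, `make_env` builds a thunk rather than an environment instance; in CleanRL-style scripts such thunks are typically passed to a vectorized environment, which calls each one to construct its sub-environments. A minimal sketch of that usage, assuming `gym` is Gymnasium and that the parsed config carries hypothetical `num_envs`, `capture_video`, and `seed` fields not shown in this diff:

```python
# Sketch: how the thunks returned by make_env are typically consumed
# (assumed CleanRL-style usage; not part of this diff).
import gymnasium as gym

envs = gym.vector.SyncVectorEnv(
    [make_env(cfg, i, cfg.capture_video, run_name, gamma) for i in range(cfg.num_envs)]
)
obs, info = envs.reset(seed=cfg.seed)
```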

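The new boolean flags use the `lambda x: bool(strtobool(x))` pattern, which relies on `strtobool` from `distutils.util` being imported at the top of the script; combined with `nargs="?"` and `const=True`, both a bare flag and an explicit true/false value parse correctly. A self-contained sketch of that behavior:

```python
# Sketch of the boolean-flag pattern used by the new arguments.
import argparse
from distutils.util import strtobool  # helper assumed to be imported in the script

parser = argparse.ArgumentParser()
parser.add_argument("--early-termination", type=lambda x: bool(strtobool(x)),
                    default=True, nargs="?", const=True)

print(parser.parse_args([]).early_termination)                                # True (default)
print(parser.parse_args(["--early-termination"]).early_termination)           # True (bare flag -> const)
print(parser.parse_args(["--early-termination", "false"]).early_termination)  # False
```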
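The help strings imply that an episode ends once a set number of safety (cost) violations has occurred, with an optional reward penalty on failure. The wrapper below is a hypothetical sketch of those semantics only; in this commit the logic lives inside the environment that accepts the `early_termination`, `term_cost`, and `failure_penalty` kwargs, and the `"cost"` info key is an assumption:

```python
# Hypothetical illustration of the termination semantics the new flags suggest
# (assumed behavior; the real logic is inside the environment, not a wrapper).
import gymnasium as gym


class EarlyTerminationWrapper(gym.Wrapper):
    def __init__(self, env, early_termination=True, term_cost=1, failure_penalty=0.0):
        super().__init__(env)
        self.early_termination = early_termination
        self.term_cost = term_cost              # violations allowed before terminating
        self.failure_penalty = failure_penalty  # subtracted on failure (assumed sign)
        self.violations = 0

    def reset(self, **kwargs):
        self.violations = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.violations += int(info.get("cost", 0.0) > 0)  # assumed info key
        if self.early_termination and self.violations >= self.term_cost:
            terminated = True
            reward -= self.failure_penalty
        return obs, reward, terminated, truncated, info
```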