Skip to content

Commit

Permalink
Fix so that it works with all Safety Gym environments
Browse files: browse the repository at this point in the history
  • Loading branch information
manila95 committed Mar 20, 2024
1 parent b2b518c commit e31d469
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions cleanrl/ppo_continuous_action_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ def parse_args():
def make_env(cfg, idx, capture_video, run_name, gamma):
def thunk():
if capture_video:
env = gym.make(cfg.env_id, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
env = gym.make(cfg.env_id, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty)
else:
env = gym.make(cfg.env_id, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
env = gym.make(cfg.env_id, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty)
env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
Expand Down Expand Up @@ -512,7 +512,7 @@ def train(cfg):
cfg.use_risk = False if cfg.risk_model_path == "None" else True

import wandb
run = wandb.init(config=vars(cfg), entity="kaustubh95",
run = wandb.init(config=vars(cfg), entity="kaustubh_umontreal",
project="risk_aware_exploration",
monitor_gym=True,
sync_tensorboard=True, save_code=True)
Expand Down Expand Up @@ -574,7 +574,7 @@ def train(cfg):
agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
#else:
# agent = ContRiskAgent(envs=envs).to(device)
risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=120, batch_norm=True, out_size=risk_size)
risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=np.array(envs.single_observation_space.shape).prod(), batch_norm=True, out_size=risk_size)
if os.path.exists(cfg.risk_model_path):
risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
print("Pretrained risk model loaded successfully")
Expand Down Expand Up @@ -792,11 +792,16 @@ def train(cfg):
cum_cost += ep_cost
ep_len = info["episode"]["l"][0]
buffer_num += ep_len
goal_met += info["cum_goal_met"]

#print(f"global_step={global_step}, episodic_return={info['episode']['r']}, episode_cost={ep_cost}")
scores.append(info['episode']['r'])
goal_scores.append(info["cum_goal_met"])
writer.add_scalar("goals/Ep Goal Achieved ", info["cum_goal_met"], global_step)
try:
ep_goal_met = info["cum_goal_met"]
except:
ep_goal_met = 0
goal_met += ep_goal_met
goal_scores.append(ep_goal_met)
writer.add_scalar("goals/Ep Goal Achieved ", ep_goal_met, global_step)
writer.add_scalar("goals/Avg Ep Goal", np.mean(goal_scores[-100:]))
writer.add_scalar("goals/Total Goal Achieved", goal_met, global_step)
ep_goal_met = 0
Expand Down

0 comments on commit e31d469

Please sign in to comment.