Commit: for velocity environments
manila95 committed Mar 8, 2024
1 parent b65d637 commit d8164a2
Showing 1 changed file with 31 additions and 29 deletions.
60 changes: 31 additions & 29 deletions cleanrl/ppo_continuous_action_wandb.py
@@ -121,11 +121,11 @@ def parse_args():
help="the id of the environment")
parser.add_argument("--binary-risk", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Use risk model in the critic or not ")
parser.add_argument("--model-type", type=str, default="mlp",
parser.add_argument("--model-type", type=str, default="bayesian",
help="specify the NN to use for the risk model")
parser.add_argument("--risk-bnorm", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--risk-type", type=str, default="binary",
parser.add_argument("--risk-type", type=str, default="quantile",
help="whether the risk is binary or continuous")
parser.add_argument("--fear-radius", type=int, default=5,
help="fear radius for training the risk model")
Expand All @@ -147,7 +147,7 @@ def parse_args():
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
parser.add_argument("--start-risk-update", type=int, default=10000,
help="number of epochs to update the risk model")
parser.add_argument("--rb-type", type=str, default="balanced",
parser.add_argument("--rb-type", type=str, default="simple",
help="which type of replay buffer to use for ")
parser.add_argument("--freeze-risk-layers", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
Expand All @@ -168,9 +168,9 @@ def parse_args():
def make_env(cfg, idx, capture_video, run_name, gamma):
def thunk():
if capture_video:
env = gym.make(cfg.env_id, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
env = gym.make(cfg.env_id)#, render_mode="rgb_array", early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
else:
env = gym.make(cfg.env_id, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
env = gym.make(cfg.env_id)#, early_termination=cfg.early_termination, term_cost=cfg.term_cost, failure_penalty=cfg.failure_penalty, reward_goal=cfg.reward_goal, reward_distance=cfg.reward_distance)
env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
@@ -665,9 +665,9 @@ def train(cfg):
# frac = 1.0 - (update - 1.0) / num_updates
# lrnow = frac * cfg.learning_rate
# optimizer.param_groups[0]["lr"] = lrnow
print(episode, cfg.max_episodes)
if episode > cfg.max_episodes:
break
# print(episode, cfg.max_episodes)
# if episode > cfg.max_episodes:
# break

for step in range(0, cfg.num_steps):
risk = torch.Tensor([[0.]]).to(device)
Expand Down Expand Up @@ -753,6 +753,7 @@ def train(cfg):
# writer.add_scalar("risk/risk_loss", risk_loss, global_step)

if cfg.fine_tune_risk == "sync" and cfg.use_risk:
# print(len(rb))
if cfg.use_risk and len(rb) > cfg.start_risk_update and cfg.fine_tune_risk:
if cfg.finetune_risk_online:
print("I am online")
@@ -781,15 +782,15 @@ def train(cfg):
# Skip the envs that are not done
if info is None:
continue
ep_cost = info["cost_sum"]
ep_cost = info["cost"]
cum_cost += ep_cost
ep_len = info["episode"]["l"][0]
buffer_num += ep_len
goal_met += info["cum_goal_met"]
goal_met += 0 #info["cum_goal_met"]
#print(f"global_step={global_step}, episodic_return={info['episode']['r']}, episode_cost={ep_cost}")
scores.append(info['episode']['r'])
goal_scores.append(info["cum_goal_met"])
writer.add_scalar("goals/Ep Goal Achieved ", info["cum_goal_met"], global_step)
goal_scores.append(0)
writer.add_scalar("goals/Ep Goal Achieved ", 0, global_step)
writer.add_scalar("goals/Avg Ep Goal", np.mean(goal_scores[-100:]))
writer.add_scalar("goals/Total Goal Achieved", goal_met, global_step)
ep_goal_met = 0
@@ -819,7 +820,7 @@ def train(cfg):
episode += 1
step_log = 0
# f_dist_to_fail = torch.Tensor(np.array(list(reversed(range(f_obs.size()[0]))))).to(device) if cost > 0 else torch.Tensor(np.array([f_obs.size()[0]]*f_obs.shape[0])).to(device)
e_risks = np.array(list(reversed(range(int(ep_len))))) if cum_cost > 0 else np.array([int(ep_len)]*int(ep_len))
e_risks = np.array(list(reversed(range(int(ep_len))))) if ep_cost > 0 else np.array([int(ep_len)]*int(ep_len))
# print(risks.size())
e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
e_risks = torch.Tensor(e_risks)
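
Note: the hunk above changes the failure-countdown condition from the run-wide cum_cost to this episode's ep_cost. A rough standalone sketch of what those lines compute is shown below; risk_bins is assumed to be the bin-edge array defined elsewhere in the script, and passing it as a parameter here is only for illustration.

import numpy as np
import torch

def episode_risk_targets(ep_len, ep_cost, risk_bins):
    # if this episode incurred any cost, each step's target counts down to
    # the failure at the end; otherwise every step is "ep_len steps" away
    if ep_cost > 0:
        e_risks = np.array(list(reversed(range(int(ep_len)))))
    else:
        e_risks = np.array([int(ep_len)] * int(ep_len))
    # quantile (histogram) encoding of each scalar target, as in the script
    e_risks_quant = torch.Tensor(np.apply_along_axis(
        lambda x: np.histogram(x, bins=risk_bins)[0], 1,
        np.expand_dims(e_risks, 1)))
    return torch.Tensor(e_risks), e_risks_quant
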
@@ ... @@

if cfg.fine_tune_risk in ["off", "sync"] and cfg.use_risk:
f_dist_to_fail = e_risks
if cfg.rb_type == "balanced":
idx_risky = (f_dist_to_fail<=cfg.fear_radius)
idx_safe = (f_dist_to_fail>cfg.fear_radius)
risk_ones = torch.ones_like(f_risks)
risk_zeros = torch.zeros_like(f_risks)

if cfg.risk_type == "binary":
rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], risk_ones[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], risk_zeros[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
if terminated:
if cfg.rb_type == "balanced":
idx_risky = (f_dist_to_fail<=cfg.fear_radius)
idx_safe = (f_dist_to_fail>cfg.fear_radius)
risk_ones = torch.ones_like(f_risks)
risk_zeros = torch.zeros_like(f_risks)

if cfg.risk_type == "binary":
rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], risk_ones[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], risk_zeros[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
else:
rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], f_risks_quant[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], f_risks_quant[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
else:
rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], f_risks_quant[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], f_risks_quant[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
else:
if cfg.risk_type == "binary":
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
else:
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks_quant, f_risks)
if cfg.risk_type == "binary":
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
else:
rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks_quant, f_risks)

f_obs[i] = None
f_next_obs[i] = None
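
Note: putting the re-indented block above in one place, the new if terminated: guard means an episode's transitions are written to the risk replay buffer only after the episode actually ends, and the balanced buffer still splits steps by the fear radius. The sketch below is an assumed simplification; the real rb.add_risky / rb.add_safe / rb.add calls take the full per-step tuples shown in the diff.

import torch

def store_episode(rb, cfg, episode_tensors, dist_to_fail, terminated):
    # episode_tensors: per-step (obs, next_obs, actions, rewards, dones, costs)
    # dist_to_fail:    per-step distance-to-failure targets (e_risks above)
    if not terminated:
        return  # unfinished episodes are no longer written to the buffer
    if cfg.rb_type == "balanced":
        # steps within fear_radius of a failure go to the "risky" pool,
        # the rest to the "safe" pool, so sampling can stay balanced
        idx_risky = dist_to_fail <= cfg.fear_radius
        idx_safe = dist_to_fail > cfg.fear_radius
        risk_ones = torch.ones_like(dist_to_fail)
        risk_zeros = torch.zeros_like(dist_to_fail)
        # binary risk labels shown; the quantile branch passes the
        # histogram-encoded targets instead of ones/zeros
        rb.add_risky(*(t[idx_risky] for t in episode_tensors),
                     risk_ones[idx_risky], dist_to_fail[idx_risky].unsqueeze(1))
        rb.add_safe(*(t[idx_safe] for t in episode_tensors),
                    risk_zeros[idx_safe], dist_to_fail[idx_safe].unsqueeze(1))
    else:
        # "simple" buffer (the new --rb-type default): one pool for everything
        rb.add(*episode_tensors,
               (dist_to_fail <= cfg.fear_radius).float(),
               dist_to_fail.unsqueeze(1))
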