diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py
index b9efecb5b..a6d2a3c3e 100644
--- a/cleanrl/ppo_continuous_action_wandb.py
+++ b/cleanrl/ppo_continuous_action_wandb.py
@@ -159,8 +159,10 @@ def parse_args():
     parser.add_argument("--quantile-num", type=int, default=5, help="number of quantiles to make")
     parser.add_argument("--risk-penalty", type=float, default=0., help="penalty to impose for entering risky states")
     parser.add_argument("--risk-penalty-start", type=float, default=20., help="penalty to impose for entering risky states")
+    parser.add_argument("--risk-gamma", type=float, default=0.99, help="discount factor applied to the distance to failure when computing risk targets")
     args = parser.parse_args()
+    args.quantile_size = 1 / float(args.quantile_num)
     args.batch_size = int(args.num_envs * args.num_steps)
     args.minibatch_size = int(args.batch_size // args.num_minibatches)
     # fmt: on
     return args
@@ -510,7 +512,7 @@ def train(cfg):
     cfg.use_risk = False if cfg.risk_model_path == "None" else True
 
     import wandb
-    run = wandb.init(config=vars(cfg), entity="kaustubh95",
+    run = wandb.init(config=vars(cfg), entity="kaustubh_umontreal",
                     project="risk_aware_exploration", monitor_gym=True, sync_tensorboard=True, save_code=True)
@@ -624,7 +626,6 @@ def train(cfg):
     global_step = 0
     start_time = time.time()
     next_obs, _ = envs.reset(seed=cfg.seed)
-# next_obs = torch.Tensor(next_obs).to(device)
     next_done = torch.zeros(cfg.num_envs).to(device)
     num_updates = cfg.total_timesteps // cfg.batch_size
 
@@ -658,6 +659,7 @@ def train(cfg):
     buffer_num = 0
     goal_met = 0; ep_goal_met = 0
     #update = 0
+    risk_penalty, ep_risk_penalty, total_risk_updates = 0, 0, 0
 
     for update in range(1, num_updates + 1):
         # Annealing the rate if instructed to do so.
@@ -718,18 +720,7 @@ def train(cfg):
             rewards[step] = torch.tensor(reward).to(device).view(-1) - risk_penalty
             info_dict = {'reward': reward, 'done': done, 'cost': cost, 'obs': obs}
-            # if cfg.collect_data:
-            #     store_data(next_obs, info_dict, storage_path, episode, step_log)
             step_log += 1
-            # for i in range(cfg.num_envs):
-            #     if not done[i]:
-            #         cost = torch.Tensor(infos["cost"]).to(device).view(-1)
-            #         ep_cost += infos["cost"]; cum_cost += infos["cost"]
-            #     else:
-            #         cost = torch.Tensor(np.array([infos["final_info"][i]["cost"]])).to(device).view(-1)
-            #         ep_cost += np.array([infos["final_info"][i]["cost"]]); cum_cost += np.array([infos["final_info"][i]["cost"]])
-
             next_obs, next_done, reward = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device), torch.Tensor(reward).to(device)
             # print(obs_.size())
             cost = torch.Tensor(np.zeros(cfg.num_envs)).to(device)
@@ -745,16 +736,6 @@ def train(cfg):
                 f_dones[i] = next_done[i].unsqueeze(0).to(device) if f_dones[i] is None else torch.concat([f_dones[i], next_done[i].unsqueeze(0).to(device)], axis=0)
             obs_ = next_obs
 
-            # if global_step % cfg.update_risk_model == 0 and cfg.fine_tune_risk:
-            # if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
-            #     for epoch in range(cfg.num_risk_epochs):
-            #         if cfg.finetune_risk_online:
-            #             print("I am online")
-            #             data = rb.slice_data(-cfg.risk_batch_size*cfg.num_update_risk, 0)
-            #         else:
-            #             data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
-            #         risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
-            #         writer.add_scalar("risk/risk_loss", risk_loss, global_step)
 
             if cfg.fine_tune_risk == "sync" and cfg.use_risk:
                 if cfg.use_risk and buffer_num > cfg.risk_batch_size and cfg.fine_tune_risk:
@@ -823,46 +804,47 @@ def train(cfg):
                 step_log = 0
                 ep_risk_penalty = 0
                 # f_dist_to_fail = torch.Tensor(np.array(list(reversed(range(f_obs.size()[0]))))).to(device) if cost > 0 else torch.Tensor(np.array([f_obs.size()[0]]*f_obs.shape[0])).to(device)
-                e_risks = np.array(list(reversed(range(int(ep_len))))) if cum_cost > 0 else np.array([int(ep_len)]*int(ep_len))
-                # print(risks.size())
-                e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
-                e_risks = torch.Tensor(e_risks)
-
-                print(e_risks_quant.size())
-                if cfg.fine_tune_risk != "None" and cfg.use_risk:
-                    f_risks = e_risks.unsqueeze(1)
-                    f_risks_quant = e_risks_quant
-                elif cfg.collect_data:
-                    f_risks = e_risks.unsqueeze(1) if f_risks is None else torch.concat([f_risks, e_risks.unsqueeze(1)], axis=0)
-
-                if cfg.fine_tune_risk in ["off", "sync"] and cfg.use_risk:
-                    f_dist_to_fail = e_risks
-                    if cfg.rb_type == "balanced":
-                        idx_risky = (f_dist_to_fail<=cfg.fear_radius)
-                        idx_safe = (f_dist_to_fail>cfg.fear_radius)
-                        risk_ones = torch.ones_like(f_risks)
-                        risk_zeros = torch.zeros_like(f_risks)
-
-                        if cfg.risk_type == "binary":
-                            rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], risk_ones[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
-                            rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], risk_zeros[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
-                        else:
-                            rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], f_risks_quant[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
-                            rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], f_risks_quant[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
-                    else:
-                        if cfg.risk_type == "binary":
-                            rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
+                if cum_cost > 0:
+                    e_risks = cfg.risk_gamma**(np.array(list(reversed(range(int(ep_len))))) if cum_cost > 0 else np.array([int(ep_len)]*int(ep_len)))
+                    # print(risks.size())
+                    e_risks_quant = torch.Tensor(np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0], 1, np.expand_dims(e_risks, 1)))
+                    e_risks = torch.Tensor(e_risks)
+
+                    print(e_risks_quant.size())
+                    if cfg.fine_tune_risk != "None" and cfg.use_risk:
+                        f_risks = e_risks.unsqueeze(1)
+                        f_risks_quant = e_risks_quant
+                    elif cfg.collect_data:
+                        f_risks = e_risks.unsqueeze(1) if f_risks is None else torch.concat([f_risks, e_risks.unsqueeze(1)], axis=0)
+
+                    if cfg.fine_tune_risk in ["off", "sync"] and cfg.use_risk:
+                        f_dist_to_fail = e_risks
+                        if cfg.rb_type == "balanced":
+                            idx_risky = (f_dist_to_fail<=cfg.fear_radius)
+                            idx_safe = (f_dist_to_fail>cfg.fear_radius)
+                            risk_ones = torch.ones_like(f_risks)
+                            risk_zeros = torch.zeros_like(f_risks)
+
+                            if cfg.risk_type == "binary":
+                                rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], risk_ones[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
+                                rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], risk_zeros[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
+                            else:
+                                rb.add_risky(f_obs[i][idx_risky], f_next_obs[i][idx_risky], f_actions[i][idx_risky], f_rewards[i][idx_risky], f_dones[i][idx_risky], f_costs[i][idx_risky], f_risks_quant[idx_risky], f_dist_to_fail.unsqueeze(1)[idx_risky])
+                                rb.add_safe(f_obs[i][idx_safe], f_next_obs[i][idx_safe], f_actions[i][idx_safe], f_rewards[i][idx_safe], f_dones[i][idx_safe], f_costs[i][idx_safe], f_risks_quant[idx_safe], f_dist_to_fail.unsqueeze(1)[idx_safe])
                         else:
-                            rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
-
-                    f_obs[i] = None
-                    f_next_obs[i] = None
-                    f_risks = None
-                    #f_ep_len = None
-                    f_actions[i] = None
-                    f_rewards[i] = None
-                    f_dones[i] = None
-                    f_costs[i] = None
+                            if cfg.risk_type == "binary":
+                                rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
+                            else:
+                                rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
+
+                        f_obs[i] = None
+                        f_next_obs[i] = None
+                        f_risks = None
+                        #f_ep_len = None
+                        f_actions[i] = None
+                        f_rewards[i] = None
+                        f_dones[i] = None
+                        f_costs[i] = None
 
                 ## Save all the data
                 if cfg.collect_data:
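
Reviewer note, not part of the patch: the sketch below is a minimal, self-contained illustration of what the new "if cum_cost > 0:" block above computes for one finished episode, i.e. a risk_gamma-discounted distance-to-failure per step plus its histogram ("quantile") encoding. The helper name episode_risk_targets and the uniform construction of risk_bins over [0, 1] are assumptions made for this example; ep_len, cum_cost, risk_gamma (the new --risk-gamma flag) and the np.histogram-over-bins encoding mirror the diff.

import numpy as np
import torch

def episode_risk_targets(ep_len, cum_cost, risk_gamma=0.99, quantile_num=5):
    # Distance to failure for each step: counts down to 0 at the final (failing)
    # step when any cost was incurred; otherwise every step is treated as being
    # ep_len steps away from failure, as in the diff above.
    dist_to_fail = (
        np.array(list(reversed(range(int(ep_len)))))
        if cum_cost > 0
        else np.array([int(ep_len)] * int(ep_len))
    )
    # New in this patch: discount the distance with risk_gamma, so states close
    # to failure get values near 1 and far-away states decay toward 0.
    e_risks = risk_gamma ** dist_to_fail
    # Histogram encoding of each risk value, analogous to the
    # np.apply_along_axis(np.histogram, ...) call in the diff.
    # ASSUMPTION: risk_bins are uniform edges over [0, 1], one bin per quantile.
    risk_bins = np.linspace(0.0, 1.0, quantile_num + 1)
    e_risks_quant = torch.Tensor(
        np.apply_along_axis(lambda x: np.histogram(x, bins=risk_bins)[0],
                            1, np.expand_dims(e_risks, 1))
    )
    return torch.Tensor(e_risks), e_risks_quant

if __name__ == "__main__":
    risks, risks_quant = episode_risk_targets(ep_len=8, cum_cost=1.0)
    print(risks.size(), risks_quant.size())  # torch.Size([8]) torch.Size([8, 5])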