From a974cd685733a3837327c0099c7ee6632a333b68 Mon Sep 17 00:00:00 2001
From: Kaustubh Mani
Date: Mon, 18 Dec 2023 18:10:34 -0500
Subject: [PATCH] state action risk

---
 cleanrl/ppo_continuous_action_wandb.py | 68 +++++++++++++-------------
 1 file changed, 34 insertions(+), 34 deletions(-)

diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py
index b9efecb5b..71d3616df 100644
--- a/cleanrl/ppo_continuous_action_wandb.py
+++ b/cleanrl/ppo_continuous_action_wandb.py
@@ -9,6 +9,7 @@
 import gymnasium as gym
 import numpy as np
 import torch
+import tqdm
 import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import Dataset
@@ -247,7 +248,9 @@ def get_action_and_value(self, x, risk, action=None):
         probs = Normal(action_mean, action_std)
         if action is None:
             action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk)
+        candidates = probs.sample_n(5)
+
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk), candidates

 class RiskAgent1(nn.Module):
     def __init__(self, envs, linear_size=64, risk_size=2):
@@ -280,7 +283,9 @@ def get_action_and_value(self, x, risk, action=None):
         probs = Normal(action_mean, action_std)
         if action is None:
             action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
+        candidates = probs.sample_n(5)
+
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x), candidates


 class Agent(nn.Module):
@@ -312,7 +317,8 @@ def get_action_and_value(self, x, action=None):
         probs = Normal(action_mean, action_std)
         if action is None:
             action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
+        candidates = probs.sample_n(5)
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x), candidates


 class ContRiskAgent(nn.Module):
@@ -412,28 +418,6 @@ def risk_sgd_step(cfg, model, data, criterion, opt, device):

     return loss

-def train_risk(cfg, model, data, criterion, opt, device):
-    model.train()
-    dataset = RiskyDataset(data["next_obs"].to('cpu'), None, data["risks"].to('cpu'), False, risk_type=cfg.risk_type,
-                            fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
-    dataloader = DataLoader(dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=10, generator=torch.Generator(device='cpu'))
-    net_loss = 0
-    for batch in dataloader:
-        pred = model(get_risk_obs(cfg, batch[0]).to(device))
-        if cfg.model_type == "mlp":
-            loss = criterion(pred, batch[1].squeeze().to(device))
-        else:
-            loss = criterion(pred, torch.argmax(batch[1].squeeze(), axis=1).to(device))
-        opt.zero_grad()
-        loss.backward()
-        opt.step()
-
-        net_loss += loss.item()
-    torch.save(model.state_dict(), os.path.join(wandb.run.dir, "risk_model.pt"))
-    wandb.save("risk_model.pt")
-    model.eval()
-    print("risk_loss:", net_loss)
-    return net_loss

 def test_policy(cfg, agent, envs, device, risk_model=None):
     envs = gym.vector.SyncVectorEnv(
@@ -574,7 +558,7 @@ def train(cfg):
         agent = RiskAgent(envs=envs, risk_size=risk_size).to(device)
         #else:
         #    agent = ContRiskAgent(envs=envs).to(device)
-        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size)
+        risk_model = risk_model_class[cfg.model_type][cfg.risk_type](obs_size=96, batch_norm=True, out_size=risk_size, action_size=envs.single_action_space.shape[0], model_type="state_action_risk")
         if os.path.exists(cfg.risk_model_path):
             risk_model.load_state_dict(torch.load(cfg.risk_model_path, map_location=device))
             print("Pretrained risk model loaded successfully")
@@ -702,14 +686,20 @@ def train(cfg):
             # ALGO LOGIC: action logic
             with torch.no_grad():
                 if cfg.use_risk:
-                    action, logprob, _, value = agent.get_action_and_value(next_obs, next_risk)
+                    action, logprob, _, value, candidates = agent.get_action_and_value(next_obs, next_risk)
                 else:
-                    action, logprob, _, value = agent.get_action_and_value(next_obs)
+                    action, logprob, _, value, candidates = agent.get_action_and_value(next_obs)
                 values[step] = value.flatten()
             actions[step] = action
             logprobs[step] = logprob

+            if cfg.use_risk:
+                with torch.no_grad():
+                    candidates = candidates.squeeze()
+                    # print(next_obs_risk.repeat(5, 1).size(), candidates.size())
+                    candidates_risk = torch.sum(torch.exp(risk_model(next_obs_risk.repeat(5, 1).to(device), candidates))[:, :2], -1)
+                    action = candidates[torch.argmin(candidates_risk)]

             # TRY NOT TO MODIFY: execute the game and log data.
             next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
             done = np.logical_or(terminated, truncated)
@@ -738,12 +728,13 @@ def train(cfg):
             for i in range(cfg.num_envs):
                 f_obs[i] = obs_[i].unsqueeze(0).to(device) if f_obs[i] is None else torch.concat([f_obs[i], obs_[i].unsqueeze(0).to(device)], axis=0)
                 f_next_obs[i] = next_obs[i].unsqueeze(0).to(device) if f_next_obs[i] is None else torch.concat([f_next_obs[i], next_obs[i].unsqueeze(0).to(device)], axis=0)
-                f_actions[i] = action[i].unsqueeze(0).to(device) if f_actions[i] is None else torch.concat([f_actions[i], action[i].unsqueeze(0).to(device)], axis=0)
+                f_actions[i] = action.unsqueeze(0).to(device) if f_actions[i] is None else torch.concat([f_actions[i], action.unsqueeze(0).to(device)], axis=0)
                 f_rewards[i] = reward[i].unsqueeze(0).to(device) if f_rewards[i] is None else torch.concat([f_rewards[i], reward[i].unsqueeze(0).to(device)], axis=0)
                 # f_risks = risk_ if f_risks is None else torch.concat([f_risks, risk_], axis=0)
                 f_costs[i] = cost[i].unsqueeze(0).to(device) if f_costs[i] is None else torch.concat([f_costs[i], cost[i].unsqueeze(0).to(device)], axis=0)
                 f_dones[i] = next_done[i].unsqueeze(0).to(device) if f_dones[i] is None else torch.concat([f_dones[i], next_done[i].unsqueeze(0).to(device)], axis=0)

+            # print(f_actions[0].size())
             obs_ = next_obs
             # if global_step % cfg.update_risk_model == 0 and cfg.fine_tune_risk:
             # if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
@@ -767,7 +758,7 @@ def train(cfg):
                     writer.add_scalar("risk/risk_loss", risk_loss, global_step)
             elif cfg.fine_tune_risk == "off" and cfg.use_risk:
                 if cfg.use_risk and (global_step > cfg.start_risk_update and cfg.fine_tune_risk) and global_step % cfg.risk_update_period == 0:
-                    for epoch in range(cfg.num_risk_epochs):
+                    for epoch in tqdm.tqdm(range(cfg.num_risk_epochs)):
                         total_risk_updates +=1
                         print(total_risk_updates)
                         if cfg.finetune_risk_online:
@@ -775,7 +766,16 @@ def train(cfg):
                             data = rb.slice_data(-cfg.risk_batch_size*cfg.num_update_risk, 0)
                         else:
                             data = rb.sample(cfg.risk_batch_size*cfg.num_update_risk)
-                        risk_loss = train_risk(cfg, risk_model, data, criterion, opt_risk, device)
+                        state = torch.cat([data["obs"], data["next_obs"]], axis=0)
+                        actions = torch.cat([data["actions"], torch.zeros_like(data["actions"])], axis=0)
+                        dist_to_fail = torch.cat([data["dist_to_fail"], data["dist_to_fail"]], axis=0)
+                        print(state.size(), actions.size(), dist_to_fail.size())
+                        risk_dataset = RiskyDataset(state.to(device), actions.to(device), dist_to_fail.to(device), True, risk_type=cfg.risk_type,
+                                                    fear_clip=None, fear_radius=cfg.fear_radius, one_hot=True, quantile_size=cfg.quantile_size, quantile_num=cfg.quantile_num)
+                        risk_dataloader = DataLoader(risk_dataset, batch_size=cfg.risk_batch_size, shuffle=True, num_workers=4, generator=torch.Generator(device=device))
+
+                        risk_loss = train_risk(risk_model, risk_dataloader, criterion, opt_risk, 1, device, train_mode="state_action")
+
                         writer.add_scalar("risk/risk_loss", risk_loss, global_step)

            # Only print when at least 1 env is done
@@ -853,7 +853,7 @@ def train(cfg):
                 if cfg.risk_type == "binary":
                     rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], (f_risks <= cfg.fear_radius).float(), e_risks.unsqueeze(1))
                 else:
-                    rb.add(f_obs[i], f_next_obs[i], f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)
+                    rb.add(get_risk_obs(cfg, f_obs[i]), get_risk_obs(cfg, f_next_obs[i]), f_actions[i], f_rewards[i], f_dones[i], f_costs[i], f_risks, f_risks)

                 f_obs[i] = None
                 f_next_obs[i] = None
@@ -915,9 +915,9 @@ def train(cfg):
                 mb_inds = b_inds[start:end]

                 if cfg.use_risk:
-                    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
+                    _, newlogprob, entropy, newvalue, cands = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
                 else:
-                    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
+                    _, newlogprob, entropy, newvalue, cands = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])

                 logratio = newlogprob - b_logprobs[mb_inds]
                 ratio = logratio.exp()
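
The core change in this patch samples five candidate actions from the policy's Gaussian (probs.sample_n(5)), scores each candidate with the state-action risk model, and then executes the candidate with the lowest predicted risk, taken as the summed probability mass of the first two output bins (presumably the near-failure bins). A minimal, self-contained sketch of that selection step follows, under stated assumptions: ToyPolicy, ToyRiskModel, and least_risky_action are hypothetical stand-ins for illustration only, not classes or functions from this repository.

import torch
import torch.nn as nn
from torch.distributions.normal import Normal

NUM_CANDIDATES = 5  # the patch hard-codes 5 candidates via probs.sample_n(5)

class ToyPolicy(nn.Module):
    """Hypothetical Gaussian policy head (stand-in for the patch's agent)."""
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.mean = nn.Linear(obs_dim, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def dist(self, obs):
        return Normal(self.mean(obs), self.log_std.exp())

class ToyRiskModel(nn.Module):
    """Hypothetical state-action risk model returning log-probabilities over risk bins."""
    def __init__(self, obs_dim, act_dim, n_bins=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 64), nn.ReLU(),
            nn.Linear(64, n_bins), nn.LogSoftmax(dim=-1),
        )

    def forward(self, obs, act):
        return self.net(torch.cat([obs, act], dim=-1))

def least_risky_action(policy, risk_model, obs):
    """Sample candidates, score them with the risk model, return the safest one.

    Mirrors the patch's rollout logic: repeat the observation once per candidate,
    exponentiate the model's log-probabilities, sum the first two bins as the
    risk score, and pick the argmin.
    """
    with torch.no_grad():
        dist = policy.dist(obs)                      # obs: (obs_dim,)
        candidates = dist.sample((NUM_CANDIDATES,))  # (NUM_CANDIDATES, act_dim)
        obs_rep = obs.unsqueeze(0).repeat(NUM_CANDIDATES, 1)
        bin_log_probs = risk_model(obs_rep, candidates)
        risk = torch.exp(bin_log_probs)[:, :2].sum(-1)
        return candidates[torch.argmin(risk)]

# Usage sketch with random weights and a random observation.
obs_dim, act_dim = 8, 2
policy, risk_model = ToyPolicy(obs_dim, act_dim), ToyRiskModel(obs_dim, act_dim)
print(least_risky_action(policy, risk_model, torch.randn(obs_dim)))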