diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py
index dc2a40cf8..c10ae40f5 100644
--- a/cleanrl/ppo_continuous_action_wandb.py
+++ b/cleanrl/ppo_continuous_action_wandb.py
@@ -248,7 +248,9 @@ def get_action_and_value(self, x, risk, action=None):
         probs = Normal(action_mean, action_std)
         if action is None:
             action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk)
+        candidates = probs.sample((5,))
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk), candidates, probs.log_prob(candidates.squeeze()).sum(1)
+
 
 class RiskAgent1(nn.Module):
     def __init__(self, envs, linear_size=64, risk_size=2):
@@ -281,7 +283,8 @@ def get_action_and_value(self, x, risk, action=None):
         probs = Normal(action_mean, action_std)
         if action is None:
             action = probs.sample()
-        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
+        candidates = probs.sample((5,))
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x), candidates, probs.log_prob(candidates.squeeze()).sum(1)
 
 
 class Agent(nn.Module):
@@ -705,14 +708,27 @@ def train(cfg):
             # ALGO LOGIC: action logic
             with torch.no_grad():
                 if cfg.use_risk:
-                    action, logprob, _, value = agent.get_action_and_value(next_obs, next_risk)
+                    action, logprob, _, value, candidates, logprop_cand = agent.get_action_and_value(next_obs, next_risk)
                 else:
                     action, logprob, _, value = agent.get_action_and_value(next_obs)
                 values[step] = value.flatten()
 
+            #actions[step] = action
+            #logprobs[step] = logprob
+
+            #print(logprop_cand.size())
+            if cfg.use_risk:
+                with torch.no_grad():
+                    candidates = candidates.squeeze()
+                    logprop_cand = logprop_cand.squeeze()
+                    # print(next_obs_risk.repeat(5, 1).size(), candidates.size())
+                    candidates_risk = torch.sum(torch.exp(risk_model(next_obs_risk.repeat(5, 1).to(device), candidates))[:, :2], -1)
+                    best_cand = torch.argmin(candidates_risk)
+                    action = candidates[best_cand]
+                    logprob = logprop_cand[best_cand]
+
             actions[step] = action
             logprobs[step] = logprob
-
             # TRY NOT TO MODIFY: execute the game and log data.
             next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
             done = np.logical_or(terminated, truncated)
@@ -918,7 +934,7 @@ def train(cfg):
                 mb_inds = b_inds[start:end]
 
                 if cfg.use_risk:
-                    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
+                    _, newlogprob, entropy, newvalue, _, _ = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
                 else:
                     _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
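
The hunks above sample a pool of candidate actions from the policy's Gaussian, score each candidate with a learned risk model, and override the executed action (and its log-probability) with the least risky candidate. The standalone sketch below illustrates that selection step only, under stated assumptions: the RiskModel class, the observation/action sizes, and the convention that class index 0 is the "risky" class are hypothetical stand-ins, not the repository's actual risk head (which sums the first two class probabilities instead).

# Minimal sketch of risk-filtered action selection (assumptions noted above).
import torch
import torch.nn as nn
from torch.distributions.normal import Normal


class RiskModel(nn.Module):
    """Hypothetical risk head: log-probabilities over {risky, safe} for an (obs, action) pair."""

    def __init__(self, obs_dim=8, act_dim=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim + act_dim, 64), nn.Tanh(), nn.Linear(64, 2))

    def forward(self, obs, action):
        return torch.log_softmax(self.net(torch.cat([obs, action], dim=-1)), dim=-1)


def pick_safest_action(action_mean, action_std, obs, risk_model, num_candidates=5):
    """Sample candidates from the policy Gaussian and keep the lowest-risk one."""
    probs = Normal(action_mean, action_std)
    candidates = probs.sample((num_candidates,))        # (K, act_dim)
    logprob_cand = probs.log_prob(candidates).sum(-1)   # (K,) joint log-prob of each candidate
    # Risk score per candidate: probability mass on the assumed "risky" class (index 0).
    risk = torch.exp(risk_model(obs.repeat(num_candidates, 1), candidates))[:, 0]
    best = torch.argmin(risk)
    return candidates[best], logprob_cand[best]


if __name__ == "__main__":
    obs = torch.randn(8)
    action, logprob = pick_safest_action(torch.zeros(2), torch.ones(2), obs, RiskModel())
    print(action, logprob)

Whatever the exact risk head, the log-probability stored in the rollout buffer should be the one belonging to the candidate actually executed; otherwise the PPO importance ratio is computed against an action the environment never saw.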