Commit

old changes
manila95 committed Dec 20, 2023
1 parent b2b518c commit 750f33a
Showing 1 changed file with 21 additions and 5 deletions.
cleanrl/ppo_continuous_action_wandb.py (26 changes: 21 additions & 5 deletions)
@@ -248,7 +248,9 @@ def get_action_and_value(self, x, risk, action=None):
probs = Normal(action_mean, action_std)
if action is None:
action = probs.sample()
return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk)
candidates = probs.sample_n(5)
return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.get_value(x, risk), candidates, probs.log_prob(candidates.squeeze()).sum(1)


class RiskAgent1(nn.Module):
def __init__(self, envs, linear_size=64, risk_size=2):
@@ -281,7 +283,8 @@ def get_action_and_value(self, x, risk, action=None):
probs = Normal(action_mean, action_std)
if action is None:
action = probs.sample()
return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
candidates = probs.sample((5,))
return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x), candidates, probs.log_prob(candidates.squeeze())


class Agent(nn.Module):
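Both RiskAgent variants now return, on top of the usual PPO outputs, a small batch of candidate actions and their log-probabilities. Below is a minimal sketch of that pattern in isolation, assuming a Normal action distribution as in the hunks above; the helper name and the n_candidates argument are illustrative, not part of the commit.

    # Sketch: sample the usual action plus n candidate actions from the same
    # policy distribution, returning log-probs for both so a risk model can
    # re-rank the candidates later.
    import torch
    from torch.distributions import Normal

    def sample_with_candidates(action_mean, action_std, n_candidates=5):
        probs = Normal(action_mean, action_std)
        action = probs.sample()                               # (batch, act_dim)
        # sample((n,)) is the non-deprecated equivalent of sample_n(n)
        candidates = probs.sample((n_candidates,))            # (n_candidates, batch, act_dim)
        cand_logprob = probs.log_prob(candidates).sum(-1)     # (n_candidates, batch)
        return action, probs.log_prob(action).sum(-1), candidates, cand_logprob
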
@@ -705,14 +708,27 @@ def train(cfg):
# ALGO LOGIC: action logic
with torch.no_grad():
if cfg.use_risk:
action, logprob, _, value = agent.get_action_and_value(next_obs, next_risk)
action, logprob, _, value, candidates, logprop_cand = agent.get_action_and_value(next_obs, next_risk)
else:
action, logprob, _, value = agent.get_action_and_value(next_obs)

values[step] = value.flatten()
#actions[step] = action
#logprobs[step] = logprob

#print(logprop_cand.size())
if cfg.use_risk:
with torch.no_grad():
candidates = candidates.squeeze()
logprop_cand = logprop_cand.squeeze()
# print(next_obs_risk.repeat(5, 1).size(), candidates.size())
candidates_risk = torch.sum(torch.exp(risk_model(next_obs_risk.repeat(5, 1).to(device), candidates))[:, :2], -1)
best_cand = torch.argmin(candidates_risk)
action = candidates[best_cand]
logprob = logprop_cand[best_cand]

actions[step] = action
logprobs[step] = logprob

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
done = np.logical_or(terminated, truncated)
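During the rollout, the risk model re-ranks the sampled candidates: each candidate is scored by its predicted probability of an unsafe outcome, and the least risky one replaces the originally sampled action (and its stored log-probability). A condensed sketch of that selection step follows; variable names track the diff, and treating the first two output columns as the "unsafe" classes is an assumption carried over from it.

    # Sketch: score each candidate action with the risk model and keep the
    # least risky one, assuming risk_model returns per-class log-probabilities.
    with torch.no_grad():
        candidates = candidates.squeeze()                  # (n_candidates, act_dim)
        cand_logprob = logprop_cand.squeeze()              # (n_candidates,)
        obs_batch = next_obs_risk.repeat(len(candidates), 1).to(device)
        risk = torch.exp(risk_model(obs_batch, candidates))[:, :2].sum(-1)
        best = torch.argmin(risk)                          # index of least-risky candidate
        action, logprob = candidates[best], cand_logprob[best]

The extra return values are only used at rollout time; the minibatch update below simply discards them.
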
@@ -918,7 +934,7 @@ def train(cfg):
mb_inds = b_inds[start:end]

if cfg.use_risk:
_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
_, newlogprob, entropy, newvalue, _, _ = agent.get_action_and_value(b_obs[mb_inds], b_risks[mb_inds], b_actions[mb_inds])
else:
_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])

