added the first step for test policy
Kaustubh Mani committed Sep 19, 2023
1 parent 5ccc53d commit 2d0a4bc
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions cleanrl/ppo_continuous_action_wandb.py
@@ -380,8 +380,29 @@ def train_risk(cfg, model, data, criterion, opt, device):
    model.eval()
    return net_loss



def test_policy(cfg, agent, envs, device, risk_model=None):
    """Roll out the current policy until the episode ends and return the accumulated reward."""
    next_obs, _ = envs.reset()
    total_reward = 0
    risk = None
    step = 0
    while True:
        step += 1
        next_obs = torch.from_numpy(next_obs).to(device)
        with torch.no_grad():
            if cfg.use_risk:
                # Condition the policy on the risk estimate for the current observation.
                risk = risk_model(get_risk_obs(cfg, next_obs).to(device))
                action, logprob, _, value = agent.get_action_and_value(next_obs, risk)
            else:
                action, logprob, _, value = agent.get_action_and_value(next_obs)
        next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
        done = np.logical_or(terminated, truncated)
        total_reward += reward
        if done:
            break
    return total_reward
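# Editor's sketch (not part of this commit): how test_policy might be called to
# get a smoothed evaluation signal; train() wires it up the same way further
# down in this diff. Assumes `envs` is the training vector environment and
# `risk_model` is the already-loaded risk network; the rollout count of 10 and
# the variable names below are illustrative.
eval_returns = [test_policy(cfg, agent, envs, device=device, risk_model=risk_model)
                for _ in range(10)]
avg_eval_return = float(np.mean(eval_returns))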


def get_risk_obs(cfg, next_obs):
    if "goal" in cfg.risk_model_path.lower():
        if "push" in cfg.env_id.lower():
@@ -671,6 +692,9 @@ def train(cfg):

print(f"global_step={global_step}, episodic_return={info['episode']['r']}, episode_cost={ep_cost}")


avg_total_reward = np.mean([test_policy(cfg, agent, envs, device=device, risk_model=risk) for i in range(20)])
writer.add_scalar("Results/Test_Avg_Total_Reward", avg_total_reward, global_step)
if cfg.use_risk:
ep_risk = torch.sum(all_risks[last_step:global_step]).item()
cum_risk += ep_risk
