Skip to content

Commit

Permalink
hopefully the final change
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaustubh Mani committed Aug 5, 2023
1 parent 44632a5 commit 859ff36
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions cleanrl/ppo_continuous_action_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def train(cfg):
next_obs, _ = envs.reset(seed=cfg.seed)
next_obs = torch.Tensor(next_obs).to(device)
next_done = torch.zeros(cfg.num_envs).to(device)
num_updates = cfg.total_timesteps // batch_size
num_updates = cfg.total_timesteps // cfg.batch_size

cum_cost, ep_cost, ep_risk_cost_int, cum_risk_cost_int, ep_risk, cum_risk = 0, 0, 0, 0, 0, 0
cost = 0
Expand Down Expand Up @@ -452,12 +452,12 @@ def train(cfg):
b_risks = risks.reshape((-1, ) + (2, ))

# Optimizing the policy and value network
b_inds = np.arange(batch_size)
b_inds = np.arange(cfg.batch_size)
clipfracs = []
for epoch in range(cfg.update_epochs):
np.random.shuffle(b_inds)
for start in range(0, batch_size, minibatch_size):
end = start + minibatch_size
for start in range(0, cfg.batch_size, cfg.minibatch_size):
end = start + cfg.minibatch_size
mb_inds = b_inds[start:end]

if cfg.use_risk:
Expand Down Expand Up @@ -506,7 +506,7 @@ def train(cfg):
nn.utils.clip_grad_norm_(agent.parameters(), cfg.max_grad_norm)
optimizer.step()

if cfg.target_kl != "None":
if cfg.target_kl is not None:
if approx_kl > cfg.target_kl:
break

Expand Down Expand Up @@ -550,4 +550,4 @@ def train(cfg):

if __name__ == "__main__":
cfg = parse_args()
train(cfg)
train(cfg)

0 comments on commit 859ff36

Please sign in to comment.