From 859ff3649a2fb7a0444809077e5fd183bd1c6c48 Mon Sep 17 00:00:00 2001 From: Kaustubh Mani Date: Sat, 5 Aug 2023 15:34:22 -0400 Subject: [PATCH] hopefully the final change --- cleanrl/ppo_continuous_action_wandb.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cleanrl/ppo_continuous_action_wandb.py b/cleanrl/ppo_continuous_action_wandb.py index f859d1a35..074955460 100644 --- a/cleanrl/ppo_continuous_action_wandb.py +++ b/cleanrl/ppo_continuous_action_wandb.py @@ -302,7 +302,7 @@ def train(cfg): next_obs, _ = envs.reset(seed=cfg.seed) next_obs = torch.Tensor(next_obs).to(device) next_done = torch.zeros(cfg.num_envs).to(device) - num_updates = cfg.total_timesteps // batch_size + num_updates = cfg.total_timesteps // cfg.batch_size cum_cost, ep_cost, ep_risk_cost_int, cum_risk_cost_int, ep_risk, cum_risk = 0, 0, 0, 0, 0, 0 cost = 0 @@ -452,12 +452,12 @@ def train(cfg): b_risks = risks.reshape((-1, ) + (2, )) # Optimizing the policy and value network - b_inds = np.arange(batch_size) + b_inds = np.arange(cfg.batch_size) clipfracs = [] for epoch in range(cfg.update_epochs): np.random.shuffle(b_inds) - for start in range(0, batch_size, minibatch_size): - end = start + minibatch_size + for start in range(0, cfg.batch_size, cfg.minibatch_size): + end = start + cfg.minibatch_size mb_inds = b_inds[start:end] if cfg.use_risk: @@ -506,7 +506,7 @@ def train(cfg): nn.utils.clip_grad_norm_(agent.parameters(), cfg.max_grad_norm) optimizer.step() - if cfg.target_kl != "None": + if cfg.target_kl is not None: if approx_kl > cfg.target_kl: break @@ -550,4 +550,4 @@ def train(cfg): if __name__ == "__main__": cfg = parse_args() - train(cfg) \ No newline at end of file + train(cfg)