Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
2 parents 05e7d88 + 0c6081d commit 16d934c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
5 changes: 3 additions & 2 deletions sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def main(cfg: "DictConfig"): # noqa: F821
policy=actor,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device="cpu",
storing_device="cpu",
device=device,
storing_device=device,
max_frames_per_traj=-1,
compile_policy={"mode": compile_mode} if compile_mode else False,
cudagraph_policy=cfg.compile.cudagraphs,
Expand All @@ -96,6 +96,7 @@ def main(cfg: "DictConfig"): # noqa: F821
value_network=critic,
average_gae=False,
device=device,
vectorized=not cfg.compile.compile,
)
loss_module = ClipPPOLoss(
actor_network=actor,
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def main(cfg: "DictConfig"): # noqa: F821
value_network=critic,
average_gae=False,
device=device,
vectorized=not cfg.compile.compile,
)

loss_module = ClipPPOLoss(
Expand Down

0 comments on commit 16d934c

Please sign in to comment.