Skip to content

Commit

Permalink
Update
Browse files (browse the repository at this point in the history)
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
2 parents 16d934c + b776b63 commit f872d5c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
4 changes: 2 additions & 2 deletions sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821
device=device,
storing_device=device,
max_frames_per_traj=-1,
- compile_policy={"mode": compile_mode} if compile_mode else False,
+ compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

Expand Down Expand Up @@ -166,7 +166,7 @@ def update(batch, num_network_updates):
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
- num_network_updates += 1
+ num_network_updates = num_network_updates + 1
# Get a data batch
batch = batch.to(device, non_blocking=True)

Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821
device=device,
storing_device=device,
max_frames_per_traj=-1,
- compile_policy={"mode": compile_mode} if compile_mode else False,
+ compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

Expand Down Expand Up @@ -153,7 +153,7 @@ def update(batch, num_network_updates):
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
- num_network_updates += 1
+ num_network_updates = num_network_updates + 1

# Forward pass PPO loss
loss = loss_module(batch)
Expand Down
6 changes: 4 additions & 2 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
_ProcessNoWarn,
_replace_last,
accept_remote_rref_udf_invocation,
+ compile_with_warmup,
logger as torchrl_logger,
prod,
RL_WARNINGS,
Expand All @@ -67,7 +68,6 @@
set_exploration_type,
)


try:
from torch.compiler import cudagraph_mark_step_begin
except ImportError:
Expand Down Expand Up @@ -661,7 +661,9 @@ def __init__(
self.policy_weights = TensorDict()

if self.compiled_policy:
- self.policy = torch.compile(self.policy, **self.compiled_policy_kwargs)
+ self.policy = compile_with_warmup(
+ self.policy, **self.compiled_policy_kwargs
+ )
if self.cudagraphed_policy:
self.policy = CudaGraphModule(self.policy, **self.cudagraphed_policy_kwargs)

Expand Down

0 comments on commit f872d5c

Please sign in to comment.