From 1b4adfd981b1f2b033f1ebec498cee8ffe1a461e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 13:47:59 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/ppo/ppo_atari.py | 1 + sota-implementations/ppo/ppo_mujoco.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 70c36b4a82f..2b0b7ec5e98 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -235,6 +235,7 @@ def update(batch, num_network_updates): # Compute GAE with torch.no_grad(), timeit("adv"): + torch.compiler.cudagraph_mark_step_begin() data = adv_module(data) with timeit("rb - extend"): # Update the data buffer diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index 05713c44dde..1f284fc7634 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -224,7 +224,8 @@ def update(batch, num_network_updates): # Compute GAE with torch.no_grad(), timeit("adv"): - data = adv_module(data.to(device)) + torch.compiler.cudagraph_mark_step_begin() + data = adv_module(data) with timeit("rb - extend"): # Update the data buffer data_reshape = data.reshape(-1)