diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index bb0d2005a13..cc42ef38f9d 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -236,6 +236,8 @@ def update(batch, num_network_updates): with torch.no_grad(), timeit("adv"): torch.compiler.cudagraph_mark_step_begin() data = adv_module(data) + if compile_mode: + data = data.clone() with timeit("rb - extend"): # Update the data buffer data_reshape = data.reshape(-1) diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index 8615fb47084..5c3c621c5e6 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -225,6 +225,9 @@ def update(batch, num_network_updates): with torch.no_grad(), timeit("adv"): torch.compiler.cudagraph_mark_step_begin() data = adv_module(data) + if compile_mode: + data = data.clone() + with timeit("rb - extend"): # Update the data buffer data_reshape = data.reshape(-1)