diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index a9253059438..45d3acbb85f 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -276,6 +276,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): expert_data = expert_data.to(device) with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() metadata = update(data, expert_data) d_loss = metadata["dloss"] alpha = metadata["alpha"] diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index e03d5930aa6..2fd990d0bb3 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -187,6 +187,7 @@ def update(sampled_tensordict): sampled_tensordict = replay_buffer.sample().to(device) with timeit("training - update"): + torch.compiler.cudagraph_mark_step_begin() metadata = update(sampled_tensordict) # update priority if prb: diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index afb2d0cfd55..1ae559a04f8 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -141,6 +141,7 @@ def update(data): data = data.to(device) with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() loss_info = update(data) # evaluation diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index c20d4074342..e933db0e836 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -181,6 +181,7 @@ def update(sampled_tensordict): # sample from replay buffer sampled_tensordict = replay_buffer.sample().to(device) with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() loss_info = update(sampled_tensordict) # update priority if prb: diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index fa45e856efd..8b97f227490 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -238,7 +238,7 @@ def update(batch, num_network_updates): data_buffer.extend(data_reshape) for k, batch in enumerate(data_buffer): - + torch.compiler.cudagraph_mark_step_begin() loss, num_network_updates = update( batch, num_network_updates=num_network_updates ) diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index 0e6184b8e1b..162b8e701df 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -227,6 +227,7 @@ def update(batch, num_network_updates): data_buffer.extend(data_reshape) for k, batch in enumerate(data_buffer): + torch.compiler.cudagraph_mark_step_begin() loss, num_network_updates = update( batch, num_network_updates=num_network_updates )