Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
2 parents 401d58b + d6210cb commit acfe898
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions sota-implementations/iql/discrete_iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def update(sampled_tensordict):
sampled_tensordict = replay_buffer.sample().to(device)

with timeit("training - update"):
torch.compiler.cudagraph_mark_step_begin()
metadata = update(sampled_tensordict)
# update priority
if prb:
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/iql/iql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def update(data):
data = data.to(device)

with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss_info = update(data)

# evaluation
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/iql/iql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def update(sampled_tensordict):
# sample from replay buffer
sampled_tensordict = replay_buffer.sample().to(device)
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss_info = update(sampled_tensordict)
# update priority
if prb:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def update(batch, num_network_updates):
data_buffer.extend(data_reshape)

for k, batch in enumerate(data_buffer):

torch.compiler.cudagraph_mark_step_begin()
loss, num_network_updates = update(
batch, num_network_updates=num_network_updates
)
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def update(batch, num_network_updates):
data_buffer.extend(data_reshape)

for k, batch in enumerate(data_buffer):
torch.compiler.cudagraph_mark_step_begin()
loss, num_network_updates = update(
batch, num_network_updates=num_network_updates
)
Expand Down

0 comments on commit acfe898

Please sign in to comment.