Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
2 parents d0065b9 + 76486a9 commit 72270ae
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def update(data, expert_data, num_network_updates=num_network_updates):
expert_data = expert_data.to(device)

with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
metadata = update(data, expert_data)
d_loss = metadata["dloss"]
alpha = metadata["alpha"]
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/iql/discrete_iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def update(sampled_tensordict):
sampled_tensordict = replay_buffer.sample().to(device)

with timeit("training - update"):
torch.compiler.cudagraph_mark_step_begin()
metadata = update(sampled_tensordict)
# update priority
if prb:
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/iql/iql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def update(data):
data = data.to(device)

with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss_info = update(data)

# evaluation
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/iql/iql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def update(sampled_tensordict):
# sample from replay buffer
sampled_tensordict = replay_buffer.sample().to(device)
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss_info = update(sampled_tensordict)
# update priority
if prb:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def update(batch, num_network_updates):
data_buffer.extend(data_reshape)

for k, batch in enumerate(data_buffer):

torch.compiler.cudagraph_mark_step_begin()
loss, num_network_updates = update(
batch, num_network_updates=num_network_updates
)
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/ppo/ppo_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def update(batch, num_network_updates):
data_buffer.extend(data_reshape)

for k, batch in enumerate(data_buffer):
torch.compiler.cudagraph_mark_step_begin()
loss, num_network_updates = update(
batch, num_network_updates=num_network_updates
)
Expand Down

0 comments on commit 72270ae

Please sign in to comment.