Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 16, 2024
2 parents e88ca8d + 785a856 commit 5064558
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions sota-implementations/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ def main(cfg: "DictConfig"): # noqa: F821

prb = cfg.replay_buffer.prb

def update(update_actor, prb=prb):
sampled_tensordict = replay_buffer.sample()
def update(sampled_tensordict, update_actor, prb=prb):

# Compute loss
q_loss, *_ = loss_module.value_loss(sampled_tensordict)
Expand All @@ -138,10 +137,6 @@ def update(update_actor, prb=prb):
else:
actor_loss = q_loss.new_zeros(())

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)

return q_loss.detach(), actor_loss.detach()

if cfg.compile.compile:
Expand Down Expand Up @@ -204,9 +199,16 @@ def update(update_actor, prb=prb):
update_counter += 1
update_actor = update_counter % delayed_updates == 0

with timeit("rb - sample"):
sampled_tensordict = replay_buffer.sample()
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
q_loss, actor_loss = update(update_actor)
q_loss, actor_loss = update(sampled_tensordict, update_actor)

# Update priority
if prb:
with timeit("rb - priority"):
replay_buffer.update_priority(sampled_tensordict)

q_losses.append(q_loss.clone())
if update_actor:
Expand Down

0 comments on commit 5064558

Please sign in to comment.