Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 16, 2024
1 parent 6b4284f commit 5721ce5
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions sota-implementations/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
)


torch.set_float32_matmul_precision("high")


@hydra.main(version_base="1.1", config_path="", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
device = cfg.network.device
Expand Down Expand Up @@ -196,9 +199,9 @@ def update(sampled_tensordict, update_actor):
torch.compiler.cudagraph_mark_step_begin()
q_loss, actor_loss = update(sampled_tensordict, update_actor)

q_losses.append(q_loss)
q_losses.append(q_loss.clone())
if update_actor:
actor_losses.append(actor_loss)
actor_losses.append(actor_loss.clone())

# Update priority
if prb:
Expand Down

0 comments on commit 5721ce5

Please sign in to comment.