From 5721ce5475511373df75266112da0b2dccaaad11 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 18:15:06 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- sota-implementations/td3/td3.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 9ab29894a50..36183655fe0 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -38,6 +38,9 @@ ) +torch.set_float32_matmul_precision("high") + + @hydra.main(version_base="1.1", config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 device = cfg.network.device @@ -196,9 +199,9 @@ def update(sampled_tensordict, update_actor): torch.compiler.cudagraph_mark_step_begin() q_loss, actor_loss = update(sampled_tensordict, update_actor) - q_losses.append(q_loss) + q_losses.append(q_loss.clone()) if update_actor: - actor_losses.append(actor_loss) + actor_losses.append(actor_loss.clone()) # Update priority if prb: