diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index 1fa3271da82..8b3efe15102 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -180,7 +180,7 @@ def update(sampled_tensordict): with timeit("update"): torch.compiler.cudagraph_mark_step_begin() sampled_tensordict = sampled_tensordict.to(device) - loss_out = update(sampled_tensordict) + loss_out = update(sampled_tensordict).clone() tds.append(loss_out)