From 10138839f069965817ef645584b39495bd188ab8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 13 Dec 2024 16:16:19 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/decision_transformer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 0540293e1f7..5f14734addd 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -395,7 +395,7 @@ def make_odt_model(cfg, device: torch.device | None = None) -> TensorDictModule: with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] - actor(td) + actor(td.to(device)) return actor