diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 0540293e1f7..5f14734addd 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -395,7 +395,7 @@ def make_odt_model(cfg, device: torch.device | None = None) -> TensorDictModule: with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] - actor(td) + actor(td.to(device)) return actor