diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 20957ec7bee..a400ade0e68 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -371,10 +371,10 @@ def make_odt_model(cfg, device: torch.device | None = None) -> TensorDictModule: ) dist_class = TanhNormal dist_kwargs = { - "low": -1.0, - "high": 1.0, + "low": -torch.ones((), device=device), + "high": torch.ones((), device=device), "tanh_loc": False, - "upscale": 5.0, + "upscale": torch.full((), 5, device=device), # "safe_tanh": not cfg.compile.compile, } diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index cad4065f54a..87879ff70c3 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1558,6 +1558,7 @@ def __init__( state_dim=state_dim, action_dim=action_dim, config=transformer_config, + device=device, ) self.action_layer_mean = nn.Linear( transformer_config["n_embd"], action_dim, device=device