From e4481ef5cc56a984cece7010a0f3bd4c10619c0e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 13 Dec 2024 12:55:31 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/decision_transformer/utils.py | 6 +++--- torchrl/modules/models/models.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 20957ec7bee..a400ade0e68 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -371,10 +371,10 @@ def make_odt_model(cfg, device: torch.device | None = None) -> TensorDictModule: ) dist_class = TanhNormal dist_kwargs = { - "low": -1.0, - "high": 1.0, + "low": -torch.ones((), device=device), + "high": torch.ones((), device=device), "tanh_loc": False, - "upscale": 5.0, + "upscale": torch.full((), 5, device=device), # "safe_tanh": not cfg.compile.compile, } diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index cad4065f54a..87879ff70c3 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1558,6 +1558,7 @@ def __init__( state_dim=state_dim, action_dim=action_dim, config=transformer_config, + device=device, ) self.action_layer_mean = nn.Linear( transformer_config["n_embd"], action_dim, device=device