diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py
index ecf379a6b57..cb35521f26c 100644
--- a/torchrl/modules/models/decision_transformer.py
+++ b/torchrl/modules/models/decision_transformer.py
@@ -7,6 +7,7 @@
 
 import dataclasses
 import importlib
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import Any
 
@@ -92,9 +93,6 @@ def __init__(
         config: dict | DTConfig = None,
         device: torch.device | None = None,
     ):
-        if device is not None:
-            with torch.device(device):
-                return self.__init__(state_dim, action_dim, config)
 
         if not _has_transformers:
             raise ImportError(
@@ -117,28 +115,29 @@ def __init__(
 
         super(DecisionTransformer, self).__init__()
 
-        gpt_config = transformers.GPT2Config(
-            n_embd=config["n_embd"],
-            n_layer=config["n_layer"],
-            n_head=config["n_head"],
-            n_inner=config["n_inner"],
-            activation_function=config["activation"],
-            n_positions=config["n_positions"],
-            resid_pdrop=config["resid_pdrop"],
-            attn_pdrop=config["attn_pdrop"],
-            vocab_size=1,
-        )
-        self.state_dim = state_dim
-        self.action_dim = action_dim
-        self.hidden_size = config["n_embd"]
+        with torch.device(device) if device is not None else nullcontext():
+            gpt_config = transformers.GPT2Config(
+                n_embd=config["n_embd"],
+                n_layer=config["n_layer"],
+                n_head=config["n_head"],
+                n_inner=config["n_inner"],
+                activation_function=config["activation"],
+                n_positions=config["n_positions"],
+                resid_pdrop=config["resid_pdrop"],
+                attn_pdrop=config["attn_pdrop"],
+                vocab_size=1,
+            )
+            self.state_dim = state_dim
+            self.action_dim = action_dim
+            self.hidden_size = config["n_embd"]
 
-        self.transformer = GPT2Model(config=gpt_config)
+            self.transformer = GPT2Model(config=gpt_config)
 
-        self.embed_return = torch.nn.Linear(1, self.hidden_size)
-        self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size)
-        self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_size)
+            self.embed_return = torch.nn.Linear(1, self.hidden_size)
+            self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size)
+            self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_size)
 
-        self.embed_ln = nn.LayerNorm(self.hidden_size)
+            self.embed_ln = nn.LayerNorm(self.hidden_size)
 
     def forward(
         self,
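
Note on the pattern in the hunk above: instead of re-entering __init__ recursively when a device is given, construction is wrapped in a device context manager, so parameters are allocated directly on the target device; nullcontext() keeps the default behaviour when device is None. A minimal standalone sketch of the same idea, assuming PyTorch 2.0+ (where torch.device works as a context manager); the TinyModel class below is illustrative, not part of torchrl:

from contextlib import nullcontext

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Illustrative module: submodules are built inside an optional device context."""

    def __init__(self, in_dim: int, hidden: int, device: torch.device | None = None):
        super().__init__()
        # With a device given, weights are created on that device directly,
        # avoiding an extra CPU allocation followed by .to(device).
        with torch.device(device) if device is not None else nullcontext():
            self.embed = nn.Linear(in_dim, hidden)
            self.norm = nn.LayerNorm(hidden)


if __name__ == "__main__":
    m = TinyModel(4, 8, device=torch.device("cpu"))
    print(next(m.parameters()).device)  # -> cpu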