Skip to content

Commit 7575e96

Browse files
author
Vincent Moens
committed
[Refactor] Use default device instead of CPU in losses
ghstack-source-id: 96ab850 Pull Request resolved: #2687
1 parent ad6c994 commit 7575e96

File tree

7 files changed

+8
-8
lines changed

7 files changed

+8
-8
lines changed

torchrl/objectives/cql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def __init__(
323323
try:
324324
device = next(self.parameters()).device
325325
except AttributeError:
326-
device = torch.device("cpu")
326+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
327327
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
328328
if bool(min_alpha) ^ bool(max_alpha):
329329
min_alpha = min_alpha if min_alpha else 0.0

torchrl/objectives/crossq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def __init__(
306306
try:
307307
device = next(self.parameters()).device
308308
except AttributeError:
309-
device = torch.device("cpu")
309+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
310310
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
311311
if bool(min_alpha) ^ bool(max_alpha):
312312
min_alpha = min_alpha if min_alpha else 0.0

torchrl/objectives/decision_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103
try:
104104
device = next(self.parameters()).device
105105
except AttributeError:
106-
device = torch.device("cpu")
106+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
107107

108108
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
109109
if bool(min_alpha) ^ bool(max_alpha):

torchrl/objectives/deprecated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def __init__(
195195
try:
196196
device = next(self.parameters()).device
197197
except AttributeError:
198-
device = torch.device("cpu")
198+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
199199

200200
self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device))
201201
self.register_buffer(

torchrl/objectives/ppo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def __init__(
388388
try:
389389
device = next(self.parameters()).device
390390
except (AttributeError, StopIteration):
391-
device = torch.device("cpu")
391+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
392392

393393
self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
394394
if critic_coef is not None:

torchrl/objectives/redq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def __init__(
318318
try:
319319
device = next(self.parameters()).device
320320
except AttributeError:
321-
device = torch.device("cpu")
321+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
322322

323323
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
324324
self.register_buffer(

torchrl/objectives/sac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def __init__(
393393
try:
394394
device = next(self.parameters()).device
395395
except AttributeError:
396-
device = torch.device("cpu")
396+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
397397
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
398398
if bool(min_alpha) ^ bool(max_alpha):
399399
min_alpha = min_alpha if min_alpha else 0.0
@@ -1119,7 +1119,7 @@ def __init__(
11191119
try:
11201120
device = next(self.parameters()).device
11211121
except AttributeError:
1122-
device = torch.device("cpu")
1122+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
11231123

11241124
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
11251125
if bool(min_alpha) ^ bool(max_alpha):

0 commit comments

Comments
 (0)