Skip to content

Commit

Permalink
[BugFix] Fix get_default_device calls in older PT versions
Browse files Browse the repository at this point in the history
ghstack-source-id: fd3a739d38feba075073801dda362be598822a94
Pull Request resolved: #2586
  • Loading branch information
vmoens committed Nov 19, 2024
1 parent 236d38f commit 705ecc2
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

@pytest.fixture(scope="module", autouse=True)
def set_default_device():
cur_device = torch.get_default_device()
cur_device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)
yield
Expand Down
4 changes: 3 additions & 1 deletion torchrl/envs/libs/vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,9 @@ def _build_env(
num_envs=num_envs,
device=self.device
if self.device is not None
else torch.get_default_device(),
else getattr(
torch, "get_default_device", lambda: torch.device("cpu")
)(),
continuous_actions=continuous_actions,
max_steps=max_steps,
seed=seed,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def _get_default_device(net):
for p in net.parameters():
return p.device
else:
return torch.get_default_device()
return getattr(torch, "get_default_device", lambda: torch.device("cpu"))()


def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/value/advantages.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(
):
super().__init__()
if device is None:
device = torch.get_default_device()
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
# this is saved for tracking only and should not be used to cast anything else than buffers during
# init.
self._device = device
Expand Down

0 comments on commit 705ecc2

Please sign in to comment.