From 705ecc2bb780b9d23052fe16838714f74890d2c0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 19 Nov 2024 17:55:35 +0000 Subject: [PATCH] [BugFix] Fix get_default_device calls in older PT versions ghstack-source-id: fd3a739d38feba075073801dda362be598822a94 Pull Request resolved: https://github.com/pytorch/rl/pull/2586 --- benchmarks/test_objectives_benchmarks.py | 2 +- torchrl/envs/libs/vmas.py | 4 +++- torchrl/objectives/utils.py | 2 +- torchrl/objectives/value/advantages.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index 7ff3f23b7a5..9932c8ba8b7 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -52,7 +52,7 @@ @pytest.fixture(scope="module", autouse=True) def set_default_device(): - cur_device = torch.get_default_device() + cur_device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") torch.set_default_device(device) yield diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 22f9835303b..8d2e3387e3c 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -803,7 +803,9 @@ def _build_env( num_envs=num_envs, device=self.device if self.device is not None - else torch.get_default_device(), + else getattr( + torch, "get_default_device", lambda: torch.device("cpu") + )(), continuous_actions=continuous_actions, max_steps=max_steps, seed=seed, diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index b0ba254d2b3..4dfed60e5a9 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -596,7 +596,7 @@ def _get_default_device(net): for p in net.parameters(): return p.device else: - return torch.get_default_device() + return getattr(torch, "get_default_device", lambda: torch.device("cpu"))() def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer: diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 8ac64bf3d21..fadfe932c50 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -219,7 +219,7 @@ def __init__( ): super().__init__() if device is None: - device = torch.get_default_device() + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() # this is saved for tracking only and should not be used to cast anything else than buffers during # init. self._device = device