Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/version-0.5' into version-0.5
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 9, 2024
2 parents de73d67 + 52336a1 commit a8182b8
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,17 +406,16 @@ def _find_sync_values(self):
return _do_nothing, _do_nothing

if worker_device is None:
worker_not_main = [False]
worker_not_main = False

def find_all_worker_devices(item, worker_not_main=worker_not_main):
def find_all_worker_devices(item):
nonlocal worker_not_main
if hasattr(item, "device"):
worker_not_main[0] = worker_not_main[0] or (
item.device != self_device
)
worker_not_main = worker_not_main or (item.device != self_device)

for td in self.shared_tensordicts:
td.apply(find_all_worker_devices, filter_empty=True)
if worker_not_main[0]:
if worker_not_main:
if torch.cuda.is_available():
worker_device = (
torch.device("cuda")
Expand All @@ -431,6 +430,8 @@ def find_all_worker_devices(item, worker_not_main=worker_not_main):
)
else:
raise RuntimeError("Did not find a valid worker device")
else:
worker_device = self_device

if (
worker_device is not None
Expand Down Expand Up @@ -460,6 +461,7 @@ def find_all_worker_devices(item, worker_not_main=worker_not_main):
and self_device.type == "mps"
):
return _mps_sync(self_device), _mps_sync(self_device)
return _do_nothing, _do_nothing

def __getstate__(self):
out = copy(self.__dict__)
Expand Down

0 comments on commit a8182b8

Please sign in to comment.