From 9119439f28a5dd4755c8188751af4af374a342f7 Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Thu, 15 Feb 2024 14:38:30 +0100 Subject: [PATCH] [BugFix] Fix Ray collector example error (#1908) --- examples/distributed/collectors/multi_nodes/lol.py | 3 --- torchrl/collectors/distributed/ray.py | 8 ++++---- 2 files changed, 4 insertions(+), 7 deletions(-) delete mode 100644 examples/distributed/collectors/multi_nodes/lol.py diff --git a/examples/distributed/collectors/multi_nodes/lol.py b/examples/distributed/collectors/multi_nodes/lol.py deleted file mode 100644 index 89d5e66b487..00000000000 --- a/examples/distributed/collectors/multi_nodes/lol.py +++ /dev/null @@ -1,3 +0,0 @@ -from torchrl.envs.libs.gym import GymEnv - -env = GymEnv("ALE/Pong-v5") diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index a467c763fa5..faf4d4a6cce 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -472,7 +472,7 @@ def check_list_length_consistency(*lists): pending_samples = [ e.print_remote_collector_info.remote() for e in self.remote_collectors() ] - ray.wait(object_refs=pending_samples) + ray.wait(pending_samples) @property def num_workers(self): @@ -602,7 +602,7 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]: samples_ready = [] while len(samples_ready) < self.num_collectors: samples_ready, samples_not_ready = ray.wait( - object_refs=pending_tasks, num_returns=len(pending_tasks) + pending_tasks, num_returns=len(pending_tasks) ) # Retrieve and concatenate Tensordicts @@ -645,7 +645,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: raise RuntimeError("Missing pending tasks, something went wrong") # Wait for first worker to finish - wait_results = ray.wait(object_refs=list(pending_tasks.keys())) + wait_results = ray.wait(list(pending_tasks.keys())) future = wait_results[0][0] collector_index = pending_tasks.pop(future) collector = self.remote_collectors()[collector_index] @@ -678,7 +678,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: # Wait for the in-process collections tasks to finish. refs = list(pending_tasks.keys()) - ray.wait(object_refs=refs, num_returns=len(refs)) + ray.wait(refs, num_returns=len(refs)) # Cancel the in-process collections tasks # for ref in refs: