From a9c3e5d90ba882f9bff51eebe939b018fe8a661c Mon Sep 17 00:00:00 2001 From: Jay Chia Date: Thu, 12 Dec 2024 21:08:15 +0800 Subject: [PATCH] Explicit get instead of a wait --- daft/runners/ray_runner.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 9fe1db3062..c33e640e9a 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -809,14 +809,13 @@ def _await_tasks( runner_tracer.task_received_as_ready(task_id, inflight_tasks[task_id].stage_id) # Run a .wait on the metadatas to retrieve them locally so that subsequent accesses will be faster - ready_metadatas = list( - { - result.get_metadata_objref() - for ready in readies - for result in inflight_tasks[inflight_ref_to_task_id[ready]].get_results() - } - ) - ray.wait(ready_metadatas, fetch_local=True, num_returns=len(ready_metadatas), timeout=5) + ready_results = [ + result for ready in readies for result in inflight_tasks[inflight_ref_to_task_id[ready]].get_results() + ] + ready_results_mapping = {r._metadatas.get_objref(): r for r in ready_results} + retrieved_metadata = ray.get(list(ready_results_mapping.keys()), timeout=None) + for objref, retrieved in zip(list(ready_results_mapping.keys()), retrieved_metadata): + ready_results_mapping[objref]._metadatas.set_metadata(retrieved) return readies @@ -1374,9 +1373,6 @@ def metadata(self) -> PartitionMetadata: def cancel(self) -> None: return ray.cancel(self._partition) - def get_metadata_objref(self) -> ray.ObjectRef: - return self._metadatas.get_objref() - def _noop(self, _: ray.ObjectRef) -> None: return None @@ -1399,6 +1395,9 @@ def get_index(self, key) -> PartitionMetadata: def get_objref(self) -> ray.ObjectRef: return self._ref + def set_metadata(self, metadatas: list[PartitionMetadata]): + self._metadatas = metadatas + @classmethod def from_metadata_list(cls, meta: list[PartitionMetadata]) -> PartitionMetadataAccessor: ref = ray.put(meta)