diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 41f2d4c775..9670ff42b7 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -51,6 +51,9 @@ class PartitionTask(Generic[PartitionT]): # This is used when a specific executor (e.g. an Actor pool) must be provisioned and used for the task actor_pool_id: str | None + # Indicates that the metadata of the result partition should be cached when the task is done + cache_metadata_on_done: bool = True + # Indicates if the PartitionTask is "done" or not is_done: bool = False @@ -70,11 +73,17 @@ def set_done(self): """Sets the PartitionTask as done.""" assert not self.is_done, "Cannot set PartitionTask as done more than once" self.is_done = True + if self.cache_metadata_on_done: + self.cache_metadata() def cancel(self) -> None: """If possible, cancel the execution of this PartitionTask.""" raise NotImplementedError() + def cache_metadata(self) -> None: + """Cache the metadata of the result partition.""" + raise NotImplementedError() + def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None: """Set the result of this Task. For use by the Task executor. @@ -140,7 +149,9 @@ def is_empty(self) -> bool: """Whether this partition task is guaranteed to result in an empty partition.""" return len(self.partial_metadatas) > 0 and all(meta.num_rows == 0 for meta in self.partial_metadatas) - def finalize_partition_task_single_output(self, stage_id: int) -> SingleOutputPartitionTask[PartitionT]: + def finalize_partition_task_single_output( + self, stage_id: int, cache_metadata_on_done: bool = True + ) -> SingleOutputPartitionTask[PartitionT]: """Create a SingleOutputPartitionTask from this PartitionTaskBuilder. Returns a "frozen" version of this PartitionTask that cannot have instructions added. @@ -162,9 +173,12 @@ def finalize_partition_task_single_output(self, stage_id: int) -> SingleOutputPa partial_metadatas=self.partial_metadatas, actor_pool_id=self.actor_pool_id, node_id=self.node_id, + cache_metadata_on_done=cache_metadata_on_done, ) - def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPartitionTask[PartitionT]: + def finalize_partition_task_multi_output( + self, stage_id: int, cache_metadata_on_done: bool = True + ) -> MultiOutputPartitionTask[PartitionT]: """Create a MultiOutputPartitionTask from this PartitionTaskBuilder. Same as finalize_partition_task_single_output, except the output of this PartitionTask is a list of partitions. @@ -184,6 +198,7 @@ def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPart partial_metadatas=self.partial_metadatas, actor_pool_id=self.actor_pool_id, node_id=self.node_id, + cache_metadata_on_done=cache_metadata_on_done, ) def __str__(self) -> str: @@ -201,6 +216,7 @@ class SingleOutputPartitionTask(PartitionTask[PartitionT]): # When available, the partition created from running the PartitionTask. _result: None | MaterializedResult[PartitionT] = None + _partition_metadata: None | PartitionMetadata = None def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None: assert self._result is None, f"Cannot set result twice. Result is already {self._result}" @@ -220,13 +236,22 @@ def partition(self) -> PartitionT: """Get the PartitionT resulting from running this PartitionTask.""" return self.result().partition() + def cache_metadata(self) -> None: + assert self._result is not None, "Cannot cache metadata without a result" + if self._partition_metadata is not None: + return + + [partial_metadata] = self.partial_metadatas + self._partition_metadata = self.result().metadata().merge_with_partial(partial_metadata) + def partition_metadata(self) -> PartitionMetadata: """Get the metadata of the result partition. (Avoids retrieving the actual partition itself if possible.) """ - [partial_metadata] = self.partial_metadatas - return self.result().metadata().merge_with_partial(partial_metadata) + self.cache_metadata() + assert self._partition_metadata is not None + return self._partition_metadata def micropartition(self) -> MicroPartition: """Get the raw vPartition of the result.""" @@ -249,6 +274,7 @@ class MultiOutputPartitionTask(PartitionTask[PartitionT]): # When available, the partitions created from running the PartitionTask. _results: None | list[MaterializedResult[PartitionT]] = None + _partition_metadatas: None | list[PartitionMetadata] = None def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None: assert self._results is None, f"Cannot set result twice. Result is already {self._results}" @@ -264,16 +290,24 @@ def partitions(self) -> list[PartitionT]: assert self._results is not None return [result.partition() for result in self._results] + def cache_metadata(self) -> None: + assert self._results is not None, "Cannot cache metadata without a result" + if self._partition_metadatas is not None: + return + + self._partition_metadatas = [ + result.metadata().merge_with_partial(partial_metadata) + for result, partial_metadata in zip(self._results, self.partial_metadatas) + ] + def partition_metadatas(self) -> list[PartitionMetadata]: """Get the metadata of the result partitions. (Avoids retrieving the actual partition itself if possible.) """ - assert self._results is not None - return [ - result.metadata().merge_with_partial(partial_metadata) - for result, partial_metadata in zip(self._results, self.partial_metadatas) - ] + self.cache_metadata() + assert self._partition_metadatas is not None + return self._partition_metadatas def micropartition(self, index: int) -> MicroPartition: """Get the raw vPartition of the result.""" diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index be3f6739c3..7d1ae3cbe2 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -1681,7 +1681,7 @@ def _best_effort_next_step( return (None, False) else: if isinstance(step, PartitionTaskBuilder): - step = step.finalize_partition_task_single_output(stage_id=stage_id) + step = step.finalize_partition_task_single_output(stage_id=stage_id, cache_metadata_on_done=False) return (step, True) elif isinstance(step, PartitionTask): return (step, False) @@ -1771,7 +1771,7 @@ def __iter__(self) -> MaterializedPhysicalPlan: try: step = next(self.child_plan) if isinstance(step, PartitionTaskBuilder): - step = step.finalize_partition_task_single_output(stage_id=stage_id) + step = step.finalize_partition_task_single_output(stage_id=stage_id, cache_metadata_on_done=False) self.materializations.append(step) num_final_yielded += 1 logger.debug("[plan-%s] YIELDING final task (%s so far)", stage_id, num_final_yielded)