From 6b6c0b354feb794b9e4e8ab926adbf6edc0c6f56 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Mon, 21 Oct 2024 15:35:17 -0700 Subject: [PATCH 1/6] Add static and runtime dag info, API to fetch ancestor tasks --- metaflow/client/core.py | 80 +++++++++++++++++++++++++++++++++++ metaflow/metadata/metadata.py | 15 +++++++ metaflow/task.py | 65 ++++++++++++++++++++++++++-- 3 files changed, 157 insertions(+), 3 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 9534ffcca2c..0f287ccb177 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1123,6 +1123,86 @@ def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" + def immediate_ancestors(self) -> Dict[str, Iterator["Task"]]: + """ + Returns a dictionary with iterators over the immediate ancestors of this task. + + Returns + ------- + Dict[str, Iterator[Task]] + Dictionary of immediate ancestors of this task. The keys are the + names of the ancestors steps and the values are iterators over the + tasks of the corresponding steps. + """ + + def _prev_task(flow_id, run_id, previous_step): + # Find any previous task for current step + + step = Step(f"{flow_id}/{run_id}/{previous_step}", _namespace_check=False) + task = next(iter(step.tasks()), None) + if task: + return task + raise MetaflowNotFound(f"No previous task found for step {previous_step}") + + flow_id, run_id, step_name, task_id = self.path_components + previous_steps = self.metadata_dict.get("previous_steps", None) + print( + f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}" + ) + print(f"previous_steps: {previous_steps}") + + if not previous_steps or len(previous_steps) == 0: + return + + cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", [])) + ancestor_iters = {} + if len(previous_steps) > 1: + # This is a static join, so there is no change in foreach stack length + prev_foreach_stack_len = cur_foreach_stack_len + else: + prev_task = _prev_task(flow_id, run_id, previous_steps[0]) + prev_foreach_stack_len = len( + prev_task.metadata_dict.get("foreach-stack", []) + ) + + print( + f"prev_foreach_stack_len: {prev_foreach_stack_len}, cur_foreach_stack_len: {cur_foreach_stack_len}" + ) + if prev_foreach_stack_len == cur_foreach_stack_len: + field_name = "foreach-indices" + field_value = self.metadata_dict.get(field_name) + elif prev_foreach_stack_len > cur_foreach_stack_len: + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices") + else: + # We will compare the foreach-stack-truncated value of current task with the + # foreach-stack value of tasks in previous steps + field_name = "foreach-indices" + field_value = self.metadata_dict.get("foreach-indices-truncated") + + for prev_step in previous_steps: + # print(f"For task {self.pathspec}, findding parent tasks for step {prev_step} with {field_name} and " + # f"{field_value}") + ancestor_iters[prev_step] = ( + self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step_name, prev_step, field_name, field_value + ) + ) + + return ancestor_iters + + # def closest_siblings(self) -> Iterator["Task"]: + # """ + # Returns an iterator over the closest siblings of this task. 
+ # + # Returns + # ------- + # Iterator[Task] + # Iterator over the closest siblings of this task + # """ + # flow_id, run_id, step_name, task_id = self.path_components + # print(f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}") + @property def metadata(self) -> List[Metadata]: """ diff --git a/metaflow/metadata/metadata.py b/metaflow/metadata/metadata.py index 11c3873a85e..a2d607eff7a 100644 --- a/metaflow/metadata/metadata.py +++ b/metaflow/metadata/metadata.py @@ -672,6 +672,21 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt): if metadata: self.register_metadata(run_id, step_name, task_id, metadata) + @classmethod + def _filter_tasks_by_metadata( + cls, flow_id, run_id, step_name, prev_step, field_name, field_value + ): + raise NotImplementedError() + + @classmethod + def filter_tasks_by_metadata( + cls, flow_id, run_id, step_name, prev_step, field_name, field_value + ): + task_ids = cls._filter_tasks_by_metadata( + flow_id, run_id, step_name, prev_step, field_name, field_value + ) + return task_ids + @staticmethod def _apply_filter(elts, filters): if filters is None: diff --git a/metaflow/task.py b/metaflow/task.py index bba15c45471..3f628572bf2 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -4,6 +4,8 @@ import sys import os import time +import json +import hashlib import traceback from types import MethodType, FunctionType @@ -37,6 +39,23 @@ class MetaflowTask(object): MetaflowTask prepares a Flow instance for execution of a single step. """ + @staticmethod + def _dynamic_runtime_metadata(foreach_stack): + foreach_indices = [foreach_frame.index for foreach_frame in foreach_stack] + foreach_indices_truncated = foreach_indices[:-1] + foreach_step_names = [foreach_frame.step for foreach_frame in foreach_stack] + return foreach_indices, foreach_indices_truncated, foreach_step_names + + def _static_runtime_metadata(self, graph_info, step_name): + if step_name == "start": + return [] + + return [ + node_name + for node_name, attributes in graph_info["steps"].items() + if step_name in attributes["next"] + ] + def __init__( self, flow, @@ -493,6 +512,33 @@ def run_step( ) ) + # Add runtime dag info + foreach_indices, foreach_indices_truncated, foreach_step_names = ( + self._dynamic_runtime_metadata(foreach_stack) + ) + metadata.extend( + [ + MetaDatum( + field="foreach-indices", + value=foreach_indices, + type="foreach-indices", + tags=metadata_tags, + ), + MetaDatum( + field="foreach-indices-truncated", + value=foreach_indices_truncated, + type="foreach-indices-truncated", + tags=metadata_tags, + ), + MetaDatum( + field="foreach-step-names", + value=foreach_step_names, + type="foreach-step-names", + tags=metadata_tags, + ), + ] + ) + self.metadata.register_metadata( run_id, step_name, @@ -559,12 +605,17 @@ def run_step( self.flow._success = False self.flow._task_ok = None self.flow._exception = None + # Note: All internal flow attributes (ie: non-user artifacts) # should either be set prior to running the user code or listed in # FlowSpec._EPHEMERAL to allow for proper merging/importing of # user artifacts in the user's step code. if join_type: + if join_type == "foreach": + # We only want to persist one of the input paths + self.flow._input_paths = str(input_paths[0]) + # Join step: # Ensure that we have the right number of inputs. 
The @@ -616,7 +667,9 @@ def run_step( "graph_info": self.flow._graph_info, } ) - + previous_steps = self._static_runtime_metadata( + self.flow._graph_info, step_name + ) for deco in decorators: deco.task_pre_step( step_name, @@ -727,8 +780,14 @@ def run_step( field="attempt_ok", value=attempt_ok, type="internal_attempt_status", - tags=["attempt_id:{0}".format(retry_count)], - ) + tags=metadata_tags, + ), + MetaDatum( + field="previous_steps", + value=previous_steps, + type="previous_steps", + tags=metadata_tags, + ), ], ) From ed8a0008d9df2f6af1c230d28bdad9d0fe4c22a3 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Thu, 31 Oct 2024 02:53:31 -0700 Subject: [PATCH 2/6] Add API to get immediate successors --- metaflow/client/core.py | 82 +++++++++++++++++++++++++++++++++-------- metaflow/task.py | 15 +++++--- 2 files changed, 76 insertions(+), 21 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 0f287ccb177..2f330e99219 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1123,21 +1123,21 @@ def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" - def immediate_ancestors(self) -> Dict[str, Iterator["Task"]]: + def immediate_ancestors(self) -> Dict[str, List[str]]: """ - Returns a dictionary with iterators over the immediate ancestors of this task. + Returns a dictionary of immediate ancestors task ids of this task for each + previous step. Returns ------- - Dict[str, Iterator[Task]] + Dict[str, List[str]] Dictionary of immediate ancestors of this task. The keys are the - names of the ancestors steps and the values are iterators over the - tasks of the corresponding steps. + names of the ancestors steps and the values are the corresponding + task ids of the ancestors. """ def _prev_task(flow_id, run_id, previous_step): # Find any previous task for current step - step = Step(f"{flow_id}/{run_id}/{previous_step}", _namespace_check=False) task = next(iter(step.tasks()), None) if task: @@ -1146,10 +1146,6 @@ def _prev_task(flow_id, run_id, previous_step): flow_id, run_id, step_name, task_id = self.path_components previous_steps = self.metadata_dict.get("previous_steps", None) - print( - f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}" - ) - print(f"previous_steps: {previous_steps}") if not previous_steps or len(previous_steps) == 0: return @@ -1165,9 +1161,6 @@ def _prev_task(flow_id, run_id, previous_step): prev_task.metadata_dict.get("foreach-stack", []) ) - print( - f"prev_foreach_stack_len: {prev_foreach_stack_len}, cur_foreach_stack_len: {cur_foreach_stack_len}" - ) if prev_foreach_stack_len == cur_foreach_stack_len: field_name = "foreach-indices" field_value = self.metadata_dict.get(field_name) @@ -1181,16 +1174,73 @@ def _prev_task(flow_id, run_id, previous_step): field_value = self.metadata_dict.get("foreach-indices-truncated") for prev_step in previous_steps: - # print(f"For task {self.pathspec}, findding parent tasks for step {prev_step} with {field_name} and " - # f"{field_value}") ancestor_iters[prev_step] = ( self._metaflow.metadata.filter_tasks_by_metadata( flow_id, run_id, step_name, prev_step, field_name, field_value ) ) - return ancestor_iters + def immediate_successors(self) -> Dict[str, List[str]]: + """ + Returns a dictionary of immediate successors task ids of this task for each + previous step. + + Returns + ------- + Dict[str, List[str]] + Dictionary of immediate successors of this task. 
The keys are the + names of the successors steps and the values are the corresponding + task ids of the successors. + """ + + def _successor_task(flow_id, run_id, successor_step): + # Find any previous task for current step + step = Step(f"{flow_id}/{run_id}/{successor_step}", _namespace_check=False) + task = next(iter(step.tasks()), None) + if task: + return task + raise MetaflowNotFound(f"No successor task found for step {successor_step}") + + flow_id, run_id, step_name, task_id = self.path_components + successor_steps = self.metadata_dict.get("successor_steps", None) + + if not successor_steps or len(successor_steps) == 0: + return + + cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", [])) + successor_iters = {} + if len(successor_steps) > 1: + # This is a static split, so there is no change in foreach stack length + successor_foreach_stack_len = cur_foreach_stack_len + else: + successor_task = _successor_task(flow_id, run_id, successor_steps[0]) + successor_foreach_stack_len = len( + successor_task.metadata_dict.get("foreach-stack", []) + ) + + if successor_foreach_stack_len == cur_foreach_stack_len: + field_name = "foreach-indices" + field_value = self.metadata_dict.get(field_name) + elif successor_foreach_stack_len > cur_foreach_stack_len: + # We will compare the foreach-indices value of current task with the + # foreach-indices-truncated value of tasks in successor steps + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices") + else: + # We will compare the foreach-stack-truncated value of current task with the + # foreach-stack value of tasks in successor steps + field_name = "foreach-indices" + field_value = self.metadata_dict.get("foreach-indices-truncated") + + for successor_step in successor_steps: + successor_iters[successor_step] = ( + self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step_name, successor_step, field_name, field_value + ) + ) + return successor_iters + # def closest_siblings(self) -> Iterator["Task"]: # """ # Returns an iterator over the closest siblings of this task. 
diff --git a/metaflow/task.py b/metaflow/task.py index 3f628572bf2..ef946d021d4 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -47,14 +47,13 @@ def _dynamic_runtime_metadata(foreach_stack): return foreach_indices, foreach_indices_truncated, foreach_step_names def _static_runtime_metadata(self, graph_info, step_name): - if step_name == "start": - return [] - - return [ + prev_steps = [ node_name for node_name, attributes in graph_info["steps"].items() if step_name in attributes["next"] ] + succesor_steps = graph_info["steps"][step_name]["next"] + return prev_steps, succesor_steps def __init__( self, @@ -667,7 +666,7 @@ def run_step( "graph_info": self.flow._graph_info, } ) - previous_steps = self._static_runtime_metadata( + previous_steps, successor_steps = self._static_runtime_metadata( self.flow._graph_info, step_name ) for deco in decorators: @@ -788,6 +787,12 @@ def run_step( type="previous_steps", tags=metadata_tags, ), + MetaDatum( + field="successor_steps", + value=successor_steps, + type="successor_steps", + tags=metadata_tags, + ), ], ) From 0ef1087f0156e15f2dd314cb65dd5285296a780b Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Thu, 31 Oct 2024 11:13:24 -0700 Subject: [PATCH 3/6] Add API for getting closest siblings --- metaflow/client/core.py | 42 ++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 2f330e99219..8676e99573b 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1241,17 +1241,37 @@ def _successor_task(flow_id, run_id, successor_step): ) return successor_iters - # def closest_siblings(self) -> Iterator["Task"]: - # """ - # Returns an iterator over the closest siblings of this task. - # - # Returns - # ------- - # Iterator[Task] - # Iterator over the closest siblings of this task - # """ - # flow_id, run_id, step_name, task_id = self.path_components - # print(f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}") + def closest_siblings(self) -> Dict[str, List[str]]: + """ + Returns a dictionary of closest siblings of this task for each step. + + Returns + ------- + Dict[str, List[str]] + Dictionary of closest siblings of this task. The keys are the + names of the current step and the values are the corresponding + task ids of the siblings. + """ + flow_id, run_id, step_name, task_id = self.path_components + + foreach_stack = self.metadata_dict.get("foreach-stack", []) + foreach_step_names = self.metadata_dict.get("foreach-step-names", []) + if len(foreach_stack) == 0: + raise MetaflowInternalError("Task is not part of any foreach split") + elif step_name != foreach_step_names[-1]: + raise MetaflowInternalError( + f"Step {step_name} does not have any direct siblings since it is not part " + f"of a new foreach split." 
+ ) + + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices-truncated") + # We find all tasks of the same step that have the same foreach-indices-truncated value + return { + step_name: self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step_name, step_name, field_name, field_value + ) + } @property def metadata(self) -> List[Metadata]: From ec43f14e6131b5173f015d520966362c567b501c Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Fri, 1 Nov 2024 11:34:18 -0700 Subject: [PATCH 4/6] Update metadata API params --- metaflow/metadata/metadata.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/metaflow/metadata/metadata.py b/metaflow/metadata/metadata.py index a2d607eff7a..9e11515abde 100644 --- a/metaflow/metadata/metadata.py +++ b/metaflow/metadata/metadata.py @@ -674,16 +674,16 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt): @classmethod def _filter_tasks_by_metadata( - cls, flow_id, run_id, step_name, prev_step, field_name, field_value + cls, flow_id, run_id, step_name, query_step, field_name, field_value ): raise NotImplementedError() @classmethod def filter_tasks_by_metadata( - cls, flow_id, run_id, step_name, prev_step, field_name, field_value + cls, flow_id, run_id, step_name, query_step, field_name, field_value ): task_ids = cls._filter_tasks_by_metadata( - flow_id, run_id, step_name, prev_step, field_name, field_value + flow_id, run_id, step_name, query_step, field_name, field_value ) return task_ids From a84f463cd41b5eb2e55fcbf6e066c11a910bc887 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Fri, 1 Nov 2024 15:13:53 -0700 Subject: [PATCH 5/6] Refactor ancestor and successor client code --- metaflow/client/core.py | 191 +++++++++++++++++----------------- metaflow/metadata/metadata.py | 7 +- metaflow/task.py | 4 - 3 files changed, 99 insertions(+), 103 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 8676e99573b..ba944bf2dd2 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1123,64 +1123,108 @@ def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" - def immediate_ancestors(self) -> Dict[str, List[str]]: + def _get_task_for_queried_step(self, flow_id, run_id, query_step): """ - Returns a dictionary of immediate ancestors task ids of this task for each - previous step. - - Returns - ------- - Dict[str, List[str]] - Dictionary of immediate ancestors of this task. The keys are the - names of the ancestors steps and the values are the corresponding - task ids of the ancestors. + Returns a Task object corresponding to the queried step. + If the queried step has several tasks, the first task is returned. 
""" + # Find any previous task for current step + step = Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False) + task = next(iter(step.tasks()), None) + if task: + return task + raise MetaflowNotFound(f"No task found for the queried step {query_step}") - def _prev_task(flow_id, run_id, previous_step): - # Find any previous task for current step - step = Step(f"{flow_id}/{run_id}/{previous_step}", _namespace_check=False) - task = next(iter(step.tasks()), None) - if task: - return task - raise MetaflowNotFound(f"No previous task found for step {previous_step}") - - flow_id, run_id, step_name, task_id = self.path_components - previous_steps = self.metadata_dict.get("previous_steps", None) - - if not previous_steps or len(previous_steps) == 0: - return - - cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", [])) - ancestor_iters = {} - if len(previous_steps) > 1: + def _get_filter_query_value( + self, flow_id, run_id, cur_foreach_stack_len, query_steps, query_type + ): + """ + For a given query type, returns the field name and value to be used for filtering tasks + based on the task's metadata. + """ + if len(query_steps) > 1: # This is a static join, so there is no change in foreach stack length - prev_foreach_stack_len = cur_foreach_stack_len + query_foreach_stack_len = cur_foreach_stack_len else: - prev_task = _prev_task(flow_id, run_id, previous_steps[0]) - prev_foreach_stack_len = len( - prev_task.metadata_dict.get("foreach-stack", []) + query_task = self._get_task_for_queried_step( + flow_id, run_id, query_steps[0] + ) + query_foreach_stack_len = len( + query_task.metadata_dict.get("foreach-stack", []) ) - if prev_foreach_stack_len == cur_foreach_stack_len: + # print(f"query_foreach_stack_len: {query_foreach_stack_len} cur_foreach_stack_len: {cur_foreach_stack_len}") + if query_foreach_stack_len == cur_foreach_stack_len: field_name = "foreach-indices" field_value = self.metadata_dict.get(field_name) - elif prev_foreach_stack_len > cur_foreach_stack_len: - field_name = "foreach-indices-truncated" - field_value = self.metadata_dict.get("foreach-indices") + elif query_type == "ancestor": + if query_foreach_stack_len > cur_foreach_stack_len: + # This is a foreach join + # We will compare the foreach-indices-truncated value of ancestor task with the + # foreach-indices value of current task + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices") + else: + # This is a foreach split + # We will compare the foreach-indices value of ancestor task with the + # foreach-indices value of current task + field_name = "foreach-indices" + field_value = self.metadata_dict.get("foreach-indices-truncated") else: - # We will compare the foreach-stack-truncated value of current task with the - # foreach-stack value of tasks in previous steps - field_name = "foreach-indices" - field_value = self.metadata_dict.get("foreach-indices-truncated") + if query_foreach_stack_len > cur_foreach_stack_len: + # This is a foreach split + # We will compare the foreach-indices value of current task with the + # foreach-indices-truncated value of successor tasks + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices") + else: + # This is a foreach join + # We will compare the foreach-indices-truncated value of current task with the + # foreach-indices value of successor tasks + field_name = "foreach-indices" + field_value = self.metadata_dict.get("foreach-indices-truncated") + return field_name, field_value + + 
def _get_related_tasks( + self, steps_key: str, relation_type: str + ) -> Dict[str, List[str]]: + flow_id, run_id, _, _ = self.path_components + query_steps = self.metadata_dict.get(steps_key) + + if not query_steps: + return {} + + field_name, field_value = self._get_filter_query_value( + flow_id, + run_id, + len(self.metadata_dict.get("foreach-stack", [])), + query_steps, + relation_type, + ) - for prev_step in previous_steps: - ancestor_iters[prev_step] = ( - self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step_name, prev_step, field_name, field_value - ) + return { + query_step: self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, query_step, field_name, field_value ) - return ancestor_iters + for query_step in query_steps + } + + @property + def immediate_ancestors(self) -> Dict[str, List[str]]: + """ + Returns a dictionary of immediate ancestors task ids of this task for each + previous step. + + Returns + ------- + Dict[str, List[str]] + Dictionary of immediate ancestors of this task. The keys are the + names of the ancestors steps and the values are the corresponding + task ids of the ancestors. + """ + return self._get_related_tasks("previous_steps", "ancestor") + @property def immediate_successors(self) -> Dict[str, List[str]]: """ Returns a dictionary of immediate successors task ids of this task for each @@ -1193,55 +1237,10 @@ def immediate_successors(self) -> Dict[str, List[str]]: names of the successors steps and the values are the corresponding task ids of the successors. """ + return self._get_related_tasks("successor_steps", "successor") - def _successor_task(flow_id, run_id, successor_step): - # Find any previous task for current step - step = Step(f"{flow_id}/{run_id}/{successor_step}", _namespace_check=False) - task = next(iter(step.tasks()), None) - if task: - return task - raise MetaflowNotFound(f"No successor task found for step {successor_step}") - - flow_id, run_id, step_name, task_id = self.path_components - successor_steps = self.metadata_dict.get("successor_steps", None) - - if not successor_steps or len(successor_steps) == 0: - return - - cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", [])) - successor_iters = {} - if len(successor_steps) > 1: - # This is a static split, so there is no change in foreach stack length - successor_foreach_stack_len = cur_foreach_stack_len - else: - successor_task = _successor_task(flow_id, run_id, successor_steps[0]) - successor_foreach_stack_len = len( - successor_task.metadata_dict.get("foreach-stack", []) - ) - - if successor_foreach_stack_len == cur_foreach_stack_len: - field_name = "foreach-indices" - field_value = self.metadata_dict.get(field_name) - elif successor_foreach_stack_len > cur_foreach_stack_len: - # We will compare the foreach-indices value of current task with the - # foreach-indices-truncated value of tasks in successor steps - field_name = "foreach-indices-truncated" - field_value = self.metadata_dict.get("foreach-indices") - else: - # We will compare the foreach-stack-truncated value of current task with the - # foreach-stack value of tasks in successor steps - field_name = "foreach-indices" - field_value = self.metadata_dict.get("foreach-indices-truncated") - - for successor_step in successor_steps: - successor_iters[successor_step] = ( - self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step_name, successor_step, field_name, field_value - ) - ) - return successor_iters - - def closest_siblings(self) -> Dict[str, List[str]]: + 
@property + def immediate_siblings(self) -> Dict[str, List[str]]: """ Returns a dictionary of closest siblings of this task for each step. @@ -1252,13 +1251,13 @@ def closest_siblings(self) -> Dict[str, List[str]]: names of the current step and the values are the corresponding task ids of the siblings. """ - flow_id, run_id, step_name, task_id = self.path_components + flow_id, run_id, step_name, _ = self.path_components foreach_stack = self.metadata_dict.get("foreach-stack", []) foreach_step_names = self.metadata_dict.get("foreach-step-names", []) if len(foreach_stack) == 0: raise MetaflowInternalError("Task is not part of any foreach split") - elif step_name != foreach_step_names[-1]: + if step_name != foreach_step_names[-1]: raise MetaflowInternalError( f"Step {step_name} does not have any direct siblings since it is not part " f"of a new foreach split." @@ -1269,7 +1268,7 @@ def closest_siblings(self) -> Dict[str, List[str]]: # We find all tasks of the same step that have the same foreach-indices-truncated value return { step_name: self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step_name, step_name, field_name, field_value + flow_id, run_id, step_name, field_name, field_value ) } diff --git a/metaflow/metadata/metadata.py b/metaflow/metadata/metadata.py index 9e11515abde..ac713505099 100644 --- a/metaflow/metadata/metadata.py +++ b/metaflow/metadata/metadata.py @@ -674,16 +674,17 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt): @classmethod def _filter_tasks_by_metadata( - cls, flow_id, run_id, step_name, query_step, field_name, field_value + cls, flow_id, run_id, query_step, field_name, field_value ): raise NotImplementedError() @classmethod def filter_tasks_by_metadata( - cls, flow_id, run_id, step_name, query_step, field_name, field_value + cls, flow_id, run_id, query_step, field_name, field_value ): + # TODO: Do we need to do anything wrt to task attempt? task_ids = cls._filter_tasks_by_metadata( - flow_id, run_id, step_name, query_step, field_name, field_value + flow_id, run_id, query_step, field_name, field_value ) return task_ids diff --git a/metaflow/task.py b/metaflow/task.py index ef946d021d4..c4e78e00660 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -611,10 +611,6 @@ def run_step( # user artifacts in the user's step code. if join_type: - if join_type == "foreach": - # We only want to persist one of the input paths - self.flow._input_paths = str(input_paths[0]) - # Join step: # Ensure that we have the right number of inputs. The From ffbf68ad120f2c44dd1d77e041a720221d283940 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Fri, 1 Nov 2024 15:36:34 -0700 Subject: [PATCH 6/6] Remove unneccessary prints --- metaflow/client/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index ba944bf2dd2..cb2ca382f4b 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1153,7 +1153,6 @@ def _get_filter_query_value( query_task.metadata_dict.get("foreach-stack", []) ) - # print(f"query_foreach_stack_len: {query_foreach_stack_len} cur_foreach_stack_len: {cur_foreach_stack_len}") if query_foreach_stack_len == cur_foreach_stack_len: field_name = "foreach-indices" field_value = self.metadata_dict.get(field_name)
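Usage sketch (reviewer note, not part of the patches above): a minimal example of how the new client properties could be exercised once these patches are applied, assuming a completed run of a hypothetical flow with a foreach split and a metadata provider that implements filter_tasks_by_metadata. The flow name, run id, step names, and task ids below are invented for illustration.

    from metaflow import Task, namespace

    namespace(None)  # inspect tasks regardless of the current namespace

    # Hypothetical pathspec: a task inside the innermost foreach step of a finished run.
    task = Task("ForeachFlow/42/process_item/1234")

    # Each property returns a dict mapping a step name to the matching task ids,
    # as described in the docstrings added by these patches.
    print(task.immediate_ancestors)   # e.g. {"start": ["615"]}
    print(task.immediate_successors)  # e.g. {"join_items": ["1300"]}

    # immediate_siblings raises MetaflowInternalError unless the task belongs to
    # the step that opened the innermost foreach split.
    print(task.immediate_siblings)    # e.g. {"process_item": ["1233", "1234", "1235"]}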