diff --git a/metaflow/client/core.py b/metaflow/client/core.py
index 87b6a88c37c..2403181156c 100644
--- a/metaflow/client/core.py
+++ b/metaflow/client/core.py
@@ -379,7 +379,7 @@ def __iter__(self) -> Iterator["MetaflowObject"]:
             _CLASSES[self._CHILD_CLASS]._NAME,
             query_filter,
             self._attempt,
-            *self.path_components
+            *self.path_components,
         )
         unfiltered_children = unfiltered_children if unfiltered_children else []
         children = filter(
@@ -1123,6 +1123,143 @@ def _iter_filter(self, x):
         # exclude private data artifacts
         return x.id[0] != "_"
 
+    def _iter_matching_tasks(self, steps, metadata_key, metadata_pattern):
+        """
+        Yield tasks from specified steps matching a foreach path pattern.
+
+        Parameters
+        ----------
+        steps : List[Step]
+            List of steps to search for tasks
+        metadata_key : str
+            Metadata field name to match against
+        metadata_pattern : str
+            Regex pattern to match foreach-indices metadata
+
+        Returns
+        -------
+        Iterator[Task]
+            Tasks matching the foreach path pattern
+        """
+        flow_id, run_id, _, _ = self.path_components
+
+        for step in steps:
+            task_pathspecs = self._metaflow.metadata.filter_tasks_by_metadata(
+                flow_id, run_id, step.id, metadata_key, metadata_pattern
+            )
+            for task_pathspec in task_pathspecs:
+                yield Task(pathspec=task_pathspec, _namespace_check=False)
+
+    @property
+    def parent_tasks(self) -> Iterator["Task"]:
+        """
+        Yields all parent tasks of the current task, if any exist.
+
+        Yields
+        ------
+        Task
+            Parent task of the current task
+        """
+        flow_id, run_id, _, _ = self.path_components
+
+        steps = list(self.parent.parent_steps)
+        if not steps:
+            return
+
+        current_path = self.metadata_dict.get("foreach-execution-path", "")
+
+        if len(steps) > 1:
+            # Static join - use exact path matching
+            pattern = current_path or ".*"
+            yield from self._iter_matching_tasks(
+                steps, "foreach-execution-path", pattern
+            )
+            return
+
+        # Handle single step case
+        target_task = Step(
+            f"{flow_id}/{run_id}/{steps[0].id}", _namespace_check=False
+        ).task
+        target_path = target_task.metadata_dict.get("foreach-execution-path")
+
+        if not target_path or not current_path:
+            # (Current task, "A:10") and (Parent task, "")
+            # Pattern: ".*"
+            pattern = ".*"
+        else:
+            current_depth = len(current_path.split(","))
+            target_depth = len(target_path.split(","))
+
+            if current_depth < target_depth:
+                # Foreach join
+                # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13,C:21")
+                # Pattern: "A:10,B:13,.*"
+                pattern = f"{current_path},.*"
+            else:
+                # Foreach split or linear step
+                # Option 1:
+                # (Current task, "A:10,B:13,C:21") and (Parent task, "A:10,B:13")
+                # Option 2:
+                # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13")
+                # Pattern: "A:10,B:13"
+                pattern = ",".join(current_path.split(",")[:target_depth])
+
+        yield from self._iter_matching_tasks(steps, "foreach-execution-path", pattern)
+
+    @property
+    def child_tasks(self) -> Iterator["Task"]:
+        """
+        Yields all child tasks of the current task, if any exist.
+
+        Yields
+        ------
+        Task
+            Child task of the current task
+        """
+        flow_id, run_id, _, _ = self.path_components
+        steps = list(self.parent.child_steps)
+        if not steps:
+            return
+
+        current_path = self.metadata_dict.get("foreach-execution-path", "")
+
+        if len(steps) > 1:
+            # Static split - use exact path matching
+            pattern = current_path or ".*"
+            yield from self._iter_matching_tasks(
+                steps, "foreach-execution-path", pattern
+            )
+            return
+
+        # Handle single step case
+        target_task = Step(
+            f"{flow_id}/{run_id}/{steps[0].id}", _namespace_check=False
+        ).task
+        target_path = target_task.metadata_dict.get("foreach-execution-path")
+
+        if not target_path or not current_path:
+            # (Current task, "A:10") and (Child task, "")
+            # Pattern: ".*"
+            pattern = ".*"
+        else:
+            current_depth = len(current_path.split(","))
+            target_depth = len(target_path.split(","))
+
+            if current_depth < target_depth:
+                # Foreach split
+                # (Current task, "A:10,B:13") and (Child task, "A:10,B:13,C:21")
+                # Pattern: "A:10,B:13,.*"
+                pattern = f"{current_path},.*"
+            else:
+                # Foreach join or linear step
+                # Option 1:
+                # (Current task, "A:10,B:13,C:21") and (Child task, "A:10,B:13")
+                # Option 2:
+                # (Current task, "A:10,B:13") and (Child task, "A:10,B:13")
+                # Pattern: "A:10,B:13"
+                pattern = ",".join(current_path.split(",")[:target_depth])
+
+        yield from self._iter_matching_tasks(steps, "foreach-execution-path", pattern)
+
     @property
     def metadata(self) -> List[Metadata]:
         """
@@ -1837,6 +1974,41 @@ def environment_info(self) -> Optional[Dict[str, Any]]:
         for t in self:
             return t.environment_info
 
+    @property
+    def parent_steps(self) -> Iterator["Step"]:
+        """
+        Yields parent steps for the current step.
+
+        Yields
+        ------
+        Step
+            Parent step
+        """
+        graph_info = self.task["_graph_info"].data
+
+        if self.id != "start":
+            flow, run, _ = self.path_components
+            for node_name, attributes in graph_info["steps"].items():
+                if self.id in attributes["next"]:
+                    yield Step(f"{flow}/{run}/{node_name}", _namespace_check=False)
+
+    @property
+    def child_steps(self) -> Iterator["Step"]:
+        """
+        Yields child steps for the current step.
+
+        Yields
+        ------
+        Step
+            Child step
+        """
+        graph_info = self.task["_graph_info"].data
+
+        if self.id != "end":
+            flow, run, _ = self.path_components
+            for next_step in graph_info["steps"][self.id]["next"]:
+                yield Step(f"{flow}/{run}/{next_step}", _namespace_check=False)
+
 
 class Run(MetaflowObject):
     """
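Taken together, the core.py additions let client code walk the runtime DAG in both directions. A minimal usage sketch under assumed names (the "ForeachFlow/1234/..." pathspecs below are hypothetical placeholders, not part of this change):

from metaflow import Step, Task

# Static graph neighbors of a step, derived from the stored _graph_info
step = Step("ForeachFlow/1234/join", _namespace_check=False)  # hypothetical pathspec
print([s.id for s in step.parent_steps])  # steps whose "next" includes "join"
print([s.id for s in step.child_steps])   # steps listed in "join"'s own "next"

# Runtime neighbors of a task, resolved via foreach-execution-path metadata
task = Task("ForeachFlow/1234/join/42", _namespace_check=False)  # hypothetical pathspec
for parent in task.parent_tasks:
    print("parent:", parent.pathspec)
for child in task.child_tasks:
    print("child:", child.pathspec)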
diff --git a/metaflow/metadata_provider/metadata.py b/metaflow/metadata_provider/metadata.py
index 11c3873a85e..11a80462b8c 100644
--- a/metaflow/metadata_provider/metadata.py
+++ b/metaflow/metadata_provider/metadata.py
@@ -5,6 +5,7 @@
 
 from collections import namedtuple
 from itertools import chain
+from typing import List
 from metaflow.exception import MetaflowInternalError, MetaflowTaggingError
 from metaflow.tagging_util import validate_tag
 from metaflow.util import get_username, resolve_identity_as_tuple, is_stringish
@@ -672,6 +673,38 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt):
         if metadata:
             self.register_metadata(run_id, step_name, task_id, metadata)
 
+    @classmethod
+    def filter_tasks_by_metadata(
+        cls,
+        flow_name: str,
+        run_id: str,
+        step_name: str,
+        field_name: str,
+        pattern: str,
+    ) -> List[str]:
+        """
+        Filter tasks by metadata field and pattern, returning task pathspecs that match criteria.
+
+        Parameters
+        ----------
+        flow_name : str
+            Flow name that the run belongs to
+        run_id : str
+            Run id that, together with flow_name, identifies the specific run whose tasks to query
+        step_name : str
+            Step name to query tasks from
+        field_name : str
+            Metadata field name to query
+        pattern : str
+            Pattern to match in metadata field value
+
+        Returns
+        -------
+        List[str]
+            List of task pathspecs that satisfy the query
+        """
+        raise NotImplementedError()
+
     @staticmethod
     def _apply_filter(elts, filters):
         if filters is None:
diff --git a/metaflow/plugins/metadata_providers/local.py b/metaflow/plugins/metadata_providers/local.py
index ea7754cac5f..74de40a61e8 100644
--- a/metaflow/plugins/metadata_providers/local.py
+++ b/metaflow/plugins/metadata_providers/local.py
@@ -2,10 +2,12 @@
 import glob
 import json
 import os
+import re
 import random
 import tempfile
 import time
 from collections import namedtuple
+from typing import List
 
 from metaflow.exception import MetaflowInternalError, MetaflowTaggingError
 from metaflow.metadata_provider.metadata import ObjectOrder
@@ -202,6 +204,70 @@ def _optimistically_mutate():
                 "Tagging failed due to too many conflicting updates from other processes"
             )
 
+    @classmethod
+    def filter_tasks_by_metadata(
+        cls,
+        flow_name: str,
+        run_id: str,
+        step_name: str,
+        field_name: str,
+        pattern: str,
+    ) -> List[str]:
+        """
+        Filter tasks by metadata field and pattern, returning task pathspecs that match criteria.
+
+        Parameters
+        ----------
+        flow_name : str
+            Identifier for the flow
+        run_id : str
+            Identifier for the run
+        step_name : str
+            Name of the step to query tasks from
+        field_name : str
+            Name of metadata field to query
+        pattern : str
+            Pattern to match in metadata field value
+
+        Returns
+        -------
+        List[str]
+            List of task pathspecs that match the query criteria
+        """
+        tasks = cls.get_object("step", "task", {}, None, flow_name, run_id, step_name)
+        if not tasks:
+            return []
+
+        regex = re.compile(pattern)
+        matching_task_pathspecs = []
+
+        for task in tasks:
+            task_id = task.get("task_id")
+            if not task_id:
+                continue
+
+            if pattern == ".*":
+                # If the pattern is ".*", we can match all tasks without reading metadata
+                matching_task_pathspecs.append(
+                    f"{flow_name}/{run_id}/{step_name}/{task_id}"
+                )
+                continue
+
+            metadata = cls.get_object(
+                "task", "metadata", {}, None, flow_name, run_id, step_name, task_id
+            )
+
+            if any(
+                meta.get("field_name") == field_name
+                and regex.match(meta.get("value", ""))
+                for meta in metadata
+            ):
+                matching_task_pathspecs.append(
+                    f"{flow_name}/{run_id}/{step_name}/{task_id}"
+                )
+
+        return matching_task_pathspecs
+
     @classmethod
     def _get_object_internal(
         cls, obj_type, obj_order, sub_type, sub_order, filters, attempt, *args
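The regex semantics the local provider relies on can be seen in isolation with plain `re`. A small sketch, assuming illustrative "step:index" values in the format that the task.py change below writes for "foreach-execution-path" (the concrete paths are made up for the example):

import re

paths = ["A:10,B:13", "A:10,B:13,C:21", "A:10,B:14,C:7"]

# A task at depth 2 looking for neighbors one foreach level deeper
# uses the pattern "<current_path>,.*"
deeper = re.compile("A:10,B:13,.*")
print([p for p in paths if deeper.match(p)])
# -> ['A:10,B:13,C:21']

# A linear step truncates its own path to the target depth. Note that
# re.match anchors at the start only, so deeper paths sharing the same
# prefix also match; same-step tasks share a depth, so this suffices here.
same = re.compile("A:10,B:13")
print([p for p in paths if same.match(p)])
# -> ['A:10,B:13', 'A:10,B:13,C:21']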
diff --git a/metaflow/plugins/metadata_providers/service.py b/metaflow/plugins/metadata_providers/service.py
index 3c1fd588e82..5fd6ce25848 100644
--- a/metaflow/plugins/metadata_providers/service.py
+++ b/metaflow/plugins/metadata_providers/service.py
@@ -4,6 +4,7 @@
 
 import requests
 
+from typing import List
 from metaflow.exception import (
     MetaflowException,
     MetaflowInternalError,
@@ -13,6 +14,7 @@
 from metaflow.metadata_provider.heartbeat import HB_URL_KEY
 from metaflow.metaflow_config import SERVICE_HEADERS, SERVICE_RETRY_COUNT, SERVICE_URL
 from metaflow.sidecar import Message, MessageTypes, Sidecar
+from urllib.parse import urlencode
 from metaflow.util import version_parse
@@ -318,6 +320,55 @@ def _new_task(
             self._register_system_metadata(run_id, step_name, task["task_id"], attempt)
         return task["task_id"], did_create
 
+    @classmethod
+    def filter_tasks_by_metadata(
+        cls,
+        flow_name: str,
+        run_id: str,
+        step_name: str,
+        field_name: str,
+        pattern: str,
+    ) -> List[str]:
+        """
+        Filter tasks by metadata field and pattern, returning task pathspecs that match criteria.
+
+        Parameters
+        ----------
+        flow_name : str
+            Flow name that the run belongs to
+        run_id : str
+            Run id that, together with flow_name, identifies the specific run whose tasks to query
+        step_name : str
+            Step name to query tasks from
+        field_name : str
+            Metadata field name to query
+        pattern : str
+            Pattern to match in metadata field value
+
+        Returns
+        -------
+        List[str]
+            List of task pathspecs that satisfy the query
+        """
+        query_params = {
+            "metadata_field_name": field_name,
+            "pattern": pattern,
+            "step_name": step_name,
+        }
+        url = ServiceMetadataProvider._obj_path(flow_name, run_id, step_name)
+        url = f"{url}/filtered_tasks?{urlencode(query_params)}"
+        try:
+            resp = cls._request(None, url, "GET")
+        except Exception as e:
+            if getattr(e, "http_code", None) == 404:
+                # The filter_tasks_by_metadata endpoint does not exist in the
+                # currently deployed version of the metadata service. Raise a
+                # more informative error message.
+                raise MetaflowInternalError(
+                    "The currently deployed version of the metadata service does not "
+                    "support filtering tasks by metadata. Upgrade the metadata service "
+                    "to version 2.15 or greater to use this feature."
+                ) from e
+            raise
+        return resp
+
     @staticmethod
     def _obj_path(
         flow_name,
diff --git a/metaflow/task.py b/metaflow/task.py
index 6b73302652b..414b7e54710 100644
--- a/metaflow/task.py
+++ b/metaflow/task.py
@@ -493,6 +493,25 @@ def run_step(
                     )
                 )
 
+            # Add runtime DAG information to the metadata of the task
+            foreach_execution_path = ",".join(
+                [
+                    "{}:{}".format(foreach_frame.step, foreach_frame.index)
+                    for foreach_frame in foreach_stack
+                ]
+            )
+            if foreach_execution_path:
+                metadata.extend(
+                    [
+                        MetaDatum(
+                            field="foreach-execution-path",
+                            value=foreach_execution_path,
+                            type="foreach-execution-path",
+                            tags=metadata_tags,
+                        ),
+                    ]
+                )
+
             self.metadata.register_metadata(
                 run_id,
                 step_name,
@@ -559,6 +578,7 @@ def run_step(
         self.flow._success = False
         self.flow._task_ok = None
         self.flow._exception = None
+
         # Note: All internal flow attributes (ie: non-user artifacts)
         # should either be set prior to running the user code or listed in
         # FlowSpec._EPHEMERAL to allow for proper merging/importing of
@@ -616,7 +636,6 @@ def run_step(
                         "graph_info": self.flow._graph_info,
                     }
                 )
-
                 for deco in decorators:
                     deco.task_pre_step(
                         step_name,
@@ -728,7 +747,7 @@ def run_step(
                             value=attempt_ok,
                             type="internal_attempt_status",
                             tags=["attempt_id:{0}".format(retry_count)],
-                        )
+                        ),
                     ],
                 )
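For reference, the metadata value registered in task.py is simply the comma-joined foreach stack. A self-contained sketch of that formatting, using a namedtuple as a stand-in for Metaflow's internal foreach frame (only the fields used above matter):

from collections import namedtuple

# Stand-in for the internal foreach frame type; hypothetical for this example
ForeachFrame = namedtuple("ForeachFrame", ["step", "index"])

foreach_stack = [ForeachFrame("a", 0), ForeachFrame("b", 3)]
foreach_execution_path = ",".join(
    "{}:{}".format(frame.step, frame.index) for frame in foreach_stack
)
print(foreach_execution_path)  # -> "a:0,b:3"
# A task outside any foreach has an empty stack, so no metadata is registered.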
+ """ + + PRIORITY = 1 + + @steps(0, ["start"]) + def step_start(self): + from metaflow import current + + self.step_name = current.step_name + self.task_pathspec = current.pathspec + self.parent_pathspecs = set() + + @steps(1, ["join"]) + def step_join(self): + from metaflow import current + + self.step_name = current.step_name + + # Store the parent task ids + # Store the task pathspec for all the parent tasks + self.parent_pathspecs = set(inp.task_pathspec for inp in inputs) + + # Set the current task id + self.task_pathspec = current.pathspec + + print( + f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}" + ) + + @steps(2, ["all"]) + def step_all(self): + from metaflow import current + + self.step_name = current.step_name + # Store the parent task ids + # Task only has one parent, so we store the parent task id + self.parent_pathspecs = set([self.task_pathspec]) + + # Set the current task id + self.task_pathspec = current.pathspec + + print( + f"Task Pathspec: {self.task_pathspec} and parent_pathspecs: {self.parent_pathspecs}" + ) + + def check_results(self, flow, checker): + def _equals_task(task1, task2): + # Verify that two task instances are equal + # by comparing all their properties + properties = [ + name + for name, value in type(task1).__dict__.items() + if isinstance(value, property) + if name + not in ["parent_tasks", "child_tasks", "metadata", "data", "artifacts"] + ] + + for prop_name in properties: + value1 = getattr(task1, prop_name) + value2 = getattr(task2, prop_name) + if value1 != value2: + raise Exception( + f"Value {value1} of property {prop_name} of task {task1} does not match the expected" + f" value {value2} of task {task2}" + ) + return True + + def _verify_parent_tasks(task): + # Verify that the parent tasks are correct + from metaflow import Task + + parent_tasks = list(task.parent_tasks) + expected_parent_pathspecs = task.data.parent_pathspecs + actual_parent_pathspecs = set([task.pathspec for task in parent_tasks]) + assert actual_parent_pathspecs == expected_parent_pathspecs, ( + f"Mismatch in ancestor task pathspecs for task {task.pathspec}: Expected {expected_parent_pathspecs}, " + f"got {actual_parent_pathspecs}." + ) + + # Verify that all attributes of the parent tasks match the expected values + expected_parent_pathspecs_dict = { + pathspec: Task(pathspec, _namespace_check=False) + for pathspec in expected_parent_pathspecs + } + for parent_task in parent_tasks: + expected_parent_task = expected_parent_pathspecs_dict[ + parent_task.pathspec + ] + + try: + assert _equals_task(parent_task, expected_parent_task), ( + f"Expected parent task {expected_parent_task} does not match " + f"the actual parent task {parent_task}." 
+ ) + except Exception as e: + raise AssertionError( + f"Comparison failed with error: {str(e)}\n" + f"Expected parent task: {expected_parent_task}\n" + f"Actual parent task: {parent_task}" + ) from e + + def _verify_child_tasks(task): + # Verify that the child tasks are correct + from metaflow import Task + + cur_task_pathspec = task.pathspec + child_tasks = task.child_tasks + actual_children_pathspecs_set = set([task.pathspec for task in child_tasks]) + expected_children_pathspecs_set = set() + + # Get child steps for the current task + child_steps = task.parent.child_steps + + # Verify that the current task pathspec is in the parent_pathspecs of the child tasks + for child_task in child_tasks: + assert task.pathspec in child_task.data.parent_pathspecs, ( + f"Task {task.pathspec} is not in the `parent_pathspecs` of the successor task " + f"{child_task.pathspec}" + ) + + # Identify all the expected children pathspecs by iterating over all the tasks + # in the child steps + for child_step in child_steps: + for child_task in child_step: + if cur_task_pathspec in child_task.data.parent_pathspecs: + expected_children_pathspecs_set.add(child_task.pathspec) + + # Assert that None of the tasks in the successor steps have the current task in their + # parent_pathspecs + assert actual_children_pathspecs_set == expected_children_pathspecs_set, ( + f"Expected children pathspecs: {expected_children_pathspecs_set}, got " + f"{actual_children_pathspecs_set}" + ) + + # Verify that all attributes of the child tasks match the expected values + expected_children_pathspecs_dict = { + pathspec: Task(pathspec, _namespace_check=False) + for pathspec in expected_children_pathspecs_set + } + for child_task in child_tasks: + expected_child_task = expected_children_pathspecs_dict[ + child_task.pathspec + ] + + try: + assert _equals_task(child_task, expected_child_task), ( + f"Expected child task {expected_child_task} does not match " + f"the actual child task {child_task}." + ) + except Exception as e: + raise AssertionError( + f"Comparison failed with error: {str(e)}\n" + f"Expected child task: {expected_child_task}\n" + f"Actual child task: {child_task}" + ) from e + + from itertools import chain + + run = checker.get_run() + + if run is None: + print("Run is None") + # very basic sanity check for CLI checker + for step in flow: + checker.assert_artifact(step.name, "step_name", step.name) + return + + # For each step in the flow + for step in run: + # For each task in the step + for task in step: + # Verify that the parent tasks are correct + _verify_parent_tasks(task) + + # Verify that the child tasks are correct + _verify_child_tasks(task)