Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add static and runtime dag info, API to fetch ancestor tasks #2124

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions metaflow/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,154 @@ def _iter_filter(self, x):
# exclude private data artifacts
return x.id[0] != "_"

def _get_task_for_queried_step(self, flow_id, run_id, query_step):
    """
    Return a representative Task for the queried step.

    When the queried step has multiple tasks, an arbitrary (first) one is
    returned.

    Raises
    ------
    MetaflowNotFound
        If the queried step has no tasks at all.
    """
    # Look up the step outside the current namespace so cross-namespace
    # runs can still be resolved.
    queried_step = Step(f"{flow_id}/{run_id}/{query_step}", _namespace_check=False)
    for candidate in queried_step.tasks():
        return candidate
    raise MetaflowNotFound(f"No task found for the queried step {query_step}")

def _get_filter_query_value(
self, flow_id, run_id, cur_foreach_stack_len, query_steps, query_type
):
"""
For a given query type, returns the field name and value to be used for filtering tasks
based on the task's metadata.
"""
if len(query_steps) > 1:
# This is a static join, so there is no change in foreach stack length
query_foreach_stack_len = cur_foreach_stack_len
else:
query_task = self._get_task_for_queried_step(
flow_id, run_id, query_steps[0]
)
query_foreach_stack_len = len(
query_task.metadata_dict.get("foreach-stack", [])
)

if query_foreach_stack_len == cur_foreach_stack_len:
field_name = "foreach-indices"
field_value = self.metadata_dict.get(field_name)
elif query_type == "ancestor":
if query_foreach_stack_len > cur_foreach_stack_len:
# This is a foreach join
# We will compare the foreach-indices-truncated value of ancestor task with the
# foreach-indices value of current task
field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices")
else:
# This is a foreach split
# We will compare the foreach-indices value of ancestor task with the
# foreach-indices value of current task
field_name = "foreach-indices"
field_value = self.metadata_dict.get("foreach-indices-truncated")
else:
if query_foreach_stack_len > cur_foreach_stack_len:
# This is a foreach split
# We will compare the foreach-indices value of current task with the
# foreach-indices-truncated value of successor tasks
field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices")
else:
# This is a foreach join
# We will compare the foreach-indices-truncated value of current task with the
# foreach-indices value of successor tasks
field_name = "foreach-indices"
field_value = self.metadata_dict.get("foreach-indices-truncated")
return field_name, field_value

def _get_related_tasks(
self, steps_key: str, relation_type: str
) -> Dict[str, List[str]]:
flow_id, run_id, _, _ = self.path_components
query_steps = self.metadata_dict.get(steps_key)

if not query_steps:
return {}

field_name, field_value = self._get_filter_query_value(
flow_id,
run_id,
len(self.metadata_dict.get("foreach-stack", [])),
query_steps,
relation_type,
)

return {
query_step: self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, query_step, field_name, field_value
)
for query_step in query_steps
}

@property
def immediate_ancestors(self) -> Dict[str, List[str]]:
    """
    Returns the task ids of this task's immediate ancestors, grouped by the
    ancestor step they belong to.

    Returns
    -------
    Dict[str, List[str]]
        Dictionary of immediate ancestors of this task. The keys are the
        names of the ancestor steps and the values are the corresponding
        task ids of the ancestors.
    """
    return self._get_related_tasks("previous_steps", "ancestor")

@property
def immediate_successors(self) -> Dict[str, List[str]]:
    """
    Returns the task ids of this task's immediate successors, grouped by the
    successor step they belong to.

    Returns
    -------
    Dict[str, List[str]]
        Dictionary of immediate successors of this task. The keys are the
        names of the successor steps and the values are the corresponding
        task ids of the successors.
    """
    return self._get_related_tasks("successor_steps", "successor")

@property
def immediate_siblings(self) -> Dict[str, List[str]]:
    """
    Returns a dictionary of closest siblings of this task for each step.

    Returns
    -------
    Dict[str, List[str]]
        Dictionary of closest siblings of this task. The keys are the
        names of the current step and the values are the corresponding
        task ids of the siblings.
    """
    flow_id, run_id, step_name, _ = self.path_components

    foreach_frames = self.metadata_dict.get("foreach-stack", [])
    split_step_names = self.metadata_dict.get("foreach-step-names", [])

    # Siblings only exist for tasks created by a foreach split, and only for
    # the step that introduced the innermost split.
    if len(foreach_frames) == 0:
        raise MetaflowInternalError("Task is not part of any foreach split")
    if step_name != split_step_names[-1]:
        raise MetaflowInternalError(
            f"Step {step_name} does not have any direct siblings since it is not part "
            f"of a new foreach split."
        )

    # Siblings share the same parent split, i.e. the same truncated
    # foreach-indices value.
    truncated_value = self.metadata_dict.get("foreach-indices-truncated")
    sibling_task_ids = self._metaflow.metadata.filter_tasks_by_metadata(
        flow_id, run_id, step_name, "foreach-indices-truncated", truncated_value
    )
    return {step_name: sibling_task_ids}

@property
def metadata(self) -> List[Metadata]:
"""
Expand Down
16 changes: 16 additions & 0 deletions metaflow/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,22 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt):
if metadata:
self.register_metadata(run_id, step_name, task_id, metadata)

@classmethod
def _filter_tasks_by_metadata(
    cls, flow_id, run_id, query_step, field_name, field_value
):
    # Backend-specific hook: implementations return the task ids of tasks in
    # `query_step` of run `flow_id/run_id` whose metadata entry `field_name`
    # matches `field_value`.
    raise NotImplementedError()

@classmethod
def filter_tasks_by_metadata(
    cls, flow_id, run_id, query_step, field_name, field_value
):
    """
    Return the task ids of tasks in `query_step` of run `flow_id/run_id`
    whose metadata entry `field_name` matches `field_value`.

    Delegates to the backend-specific `_filter_tasks_by_metadata`.
    """
    # TODO: Do we need to do anything wrt to task attempt? Ancestors should
    # be identical across attempts, but sibling queries may need to
    # de-duplicate attempts of the same task — to be confirmed.
    return cls._filter_tasks_by_metadata(
        flow_id, run_id, query_step, field_name, field_value
    )

Comment on lines +675 to +690
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a need for the private method, or could this simply be contained in the public-facing one? right now its not doing anything before calling the private one.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, did you have an implementation of this for service.py yet?

@staticmethod
def _apply_filter(elts, filters):
if filters is None:
Expand Down
66 changes: 63 additions & 3 deletions metaflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import sys
import os
import time
import json
import hashlib
import traceback

from types import MethodType, FunctionType
Expand Down Expand Up @@ -37,6 +39,22 @@ class MetaflowTask(object):
MetaflowTask prepares a Flow instance for execution of a single step.
"""

@staticmethod
def _dynamic_runtime_metadata(foreach_stack):
foreach_indices = [foreach_frame.index for foreach_frame in foreach_stack]
foreach_indices_truncated = foreach_indices[:-1]
foreach_step_names = [foreach_frame.step for foreach_frame in foreach_stack]
return foreach_indices, foreach_indices_truncated, foreach_step_names

def _static_runtime_metadata(self, graph_info, step_name):
prev_steps = [
node_name
for node_name, attributes in graph_info["steps"].items()
if step_name in attributes["next"]
]
succesor_steps = graph_info["steps"][step_name]["next"]
return prev_steps, succesor_steps

def __init__(
self,
flow,
Expand Down Expand Up @@ -493,6 +511,33 @@ def run_step(
)
)

# Add runtime dag info
foreach_indices, foreach_indices_truncated, foreach_step_names = (
self._dynamic_runtime_metadata(foreach_stack)
)
metadata.extend(
[
MetaDatum(
field="foreach-indices",
value=foreach_indices,
type="foreach-indices",
tags=metadata_tags,
),
MetaDatum(
field="foreach-indices-truncated",
value=foreach_indices_truncated,
type="foreach-indices-truncated",
tags=metadata_tags,
),
MetaDatum(
field="foreach-step-names",
value=foreach_step_names,
type="foreach-step-names",
tags=metadata_tags,
),
]
)

self.metadata.register_metadata(
run_id,
step_name,
Expand Down Expand Up @@ -559,6 +604,7 @@ def run_step(
self.flow._success = False
self.flow._task_ok = None
self.flow._exception = None

# Note: All internal flow attributes (ie: non-user artifacts)
# should either be set prior to running the user code or listed in
# FlowSpec._EPHEMERAL to allow for proper merging/importing of
Expand Down Expand Up @@ -616,7 +662,9 @@ def run_step(
"graph_info": self.flow._graph_info,
}
)

previous_steps, successor_steps = self._static_runtime_metadata(
self.flow._graph_info, step_name
)
for deco in decorators:
deco.task_pre_step(
step_name,
Expand Down Expand Up @@ -727,8 +775,20 @@ def run_step(
field="attempt_ok",
value=attempt_ok,
type="internal_attempt_status",
tags=["attempt_id:{0}".format(retry_count)],
)
tags=metadata_tags,
),
MetaDatum(
field="previous_steps",
value=previous_steps,
type="previous_steps",
tags=metadata_tags,
),
MetaDatum(
field="successor_steps",
value=successor_steps,
type="successor_steps",
tags=metadata_tags,
),
],
)

Expand Down
Loading