Skip to content

Commit

Permalink
[easy] Show failed_node_id in failure node local execution (#2334)
Browse files Browse the repository at this point in the history
  • Loading branch information
Future-Outlier authored Apr 7, 2024
1 parent b475c87 commit bf38b8e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
3 changes: 2 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis
except Exception as exc:
if self.on_failure:
if self.on_failure.python_interface and "err" in self.on_failure.python_interface.inputs:
input_kwargs["err"] = FlyteError(failed_node_id="", message=str(exc))
id = self.failure_node.id if self.failure_node else ""
input_kwargs["err"] = FlyteError(failed_node_id=id, message=str(exc))
self.on_failure(**input_kwargs)
raise exc

Expand Down
48 changes: 47 additions & 1 deletion tests/flytekit/unit/core/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import typing
from collections import OrderedDict
from unittest.mock import patch

import pytest
from typing_extensions import Annotated # type: ignore
Expand All @@ -15,6 +16,7 @@
from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow
from flytekit.exceptions.user import FlyteValidationException, FlyteValueException
from flytekit.tools.translator import get_serializable
from flytekit.types.error.error import FlyteError

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
Expand Down Expand Up @@ -51,7 +53,7 @@ def t1(a: int) -> typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c",
@workflow(interruptible=True, failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE)
def wf(a: int) -> typing.Tuple[str, str]:
x, y = t1(a=a)
u, v = t1(a=x)
_, v = t1(a=x)
return y, v

wf_spec = get_serializable(OrderedDict(), serialization_settings, wf)
Expand Down Expand Up @@ -435,3 +437,47 @@ def wf():
t4()

assert ctx.compilation_state is None


@patch("builtins.print")
def test_failure_node_local_execution(mock_print):
@task
def clean_up(name: str, err: typing.Optional[FlyteError] = None):
print(f"Deleting cluster {name} due to {err}")
print("This is err:", str(err))

@task
def create_cluster(name: str):
print(f"Creating cluster: {name}")

@task
def delete_cluster(name: str, err: typing.Optional[FlyteError] = None):
print(f"Deleting cluster {name}")
print(err)

@task
def t1(a: int, b: str):
print(f"{a} {b}")
raise ValueError("Fail!")

@workflow(on_failure=clean_up)
def wf(name: str = "flyteorg"):
c = create_cluster(name=name)
t = t1(a=1, b="2")
d = delete_cluster(name=name)
c >> t >> d

with pytest.raises(ValueError):
wf()

# Adjusted the error message to match the one in the failure
expected_error_message = str(
FlyteError(message="Error encountered while executing 'wf':\n Fail!", failed_node_id="fn0")
)

assert mock_print.call_count > 0

mock_print.assert_any_call("Creating cluster: flyteorg")
mock_print.assert_any_call("1 2")
mock_print.assert_any_call(f"Deleting cluster flyteorg due to {expected_error_message}")
mock_print.assert_any_call("This is err:", expected_error_message)

0 comments on commit bf38b8e

Please sign in to comment.