Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Always propagate pytorch task worker process exception timestamp to task exception #3057

Merged
merged 2 commits into from
Jan 18, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions flytekit/exceptions/user.py
Original file line number Diff line number Diff line change
@@ -15,14 +15,15 @@
class FlyteUserRuntimeException(_FlyteException):
_ERROR_CODE = "USER:RuntimeError"

def __init__(self, exc_value: Exception):
def __init__(self, exc_value: Exception, timestamp: typing.Optional[float] = None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding timestamp parameter validation

Consider adding validation for the timestamp parameter to ensure it's a valid timestamp value when provided.

Code suggestion
Check the AI-generated fix before applying
 @@ -18,2 +18,5 @@
      def __init__(self, exc_value: Exception, timestamp: typing.Optional[float] = None):
 +        if timestamp is not None and (not isinstance(timestamp, (int, float)) or timestamp < 0):
 +            raise ValueError("timestamp must be a non-negative number representing seconds since epoch")
          super().__init__(str(exc_value), timestamp=timestamp)

Code Review Run #3d1353


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

Sorry, something went wrong.

"""
FlyteUserRuntimeException is thrown when a user code raises an exception.

:param exc_value: The exception that was raised from user code.
:param timestamp: The timestamp as fractional seconds since epoch when the exception was raised.
"""
self._exc_value = exc_value
super().__init__(str(exc_value))
super().__init__(str(exc_value), timestamp=timestamp)

Check warning on line 26 in flytekit/exceptions/user.py

Codecov / codecov/patch

flytekit/exceptions/user.py#L26

Added line #L26 was not covered by tests

@property
def value(self):
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@
from flytekit.core.context_manager import FlyteContextManager, OutputMetadata
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.exceptions.user import FlyteRecoverableException
from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException
from flytekit.extend import IgnoreOutputs, TaskPlugins
from flytekit.loggers import logger

@@ -475,7 +475,7 @@ def fn_partial():
# the automatically assigned timestamp based on exception creation time
raise FlyteRecoverableException(e.format_msg(), timestamp=first_failure.timestamp)
else:
raise RuntimeError(e.format_msg())
raise FlyteUserRuntimeException(e, timestamp=first_failure.timestamp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider preserving original error message

Consider preserving the original error message from e.format_msg() when raising FlyteUserRuntimeException. The error message could provide valuable debugging context that is currently being lost.

Code suggestion
Check the AI-generated fix before applying
Suggested change
raise FlyteUserRuntimeException(e, timestamp=first_failure.timestamp)
raise FlyteUserRuntimeException(e.format_msg(), timestamp=first_failure.timestamp)

Code Review Run #3d1353


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

Sorry, something went wrong.

except SignalException as e:
logger.exception(f"Elastic launch agent process terminating: {e}")
raise IgnoreOutputs()
34 changes: 34 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
@@ -276,3 +276,37 @@ def test_task_omp_set():
assert os.environ["OMP_NUM_THREADS"] == "42"

test_task_omp_set()


def test_exception_timestamp() -> None:
"""Test that the timestamp of the worker process exception is propagated to the task exception."""
@task(
task_config=Elastic(
nnodes=1,
nproc_per_node=2,
)
)
def test_task():
raise Exception("Test exception")

with pytest.raises(Exception) as e:
test_task()

assert e.value.timestamp is not None


def test_recoverable_exception_timestamp() -> None:
"""Test that the timestamp of the worker process exception is propagated to the task exception."""
@task(
task_config=Elastic(
nnodes=1,
nproc_per_node=2,
)
)
def test_task():
raise FlyteRecoverableException("Recoverable test exception")

with pytest.raises(Exception) as e:
test_task()

assert e.value.timestamp is not None
Comment on lines +281 to +312
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider consolidating duplicate test functions

Consider consolidating the two test functions test_exception_timestamp() and test_recoverable_exception_timestamp() into a single parameterized test since they follow the same pattern and only differ in the exception type.

Code suggestion
Check the AI-generated fix before applying
Suggested change
def test_exception_timestamp() -> None:
"""Test that the timestamp of the worker process exception is propagated to the task exception."""
@task(
task_config=Elastic(
nnodes=1,
nproc_per_node=2,
)
)
def test_task():
raise Exception("Test exception")
with pytest.raises(Exception) as e:
test_task()
assert e.value.timestamp is not None
def test_recoverable_exception_timestamp() -> None:
"""Test that the timestamp of the worker process exception is propagated to the task exception."""
@task(
task_config=Elastic(
nnodes=1,
nproc_per_node=2,
)
)
def test_task():
raise FlyteRecoverableException("Recoverable test exception")
with pytest.raises(Exception) as e:
test_task()
assert e.value.timestamp is not None
def test_recoverable_exception_timestamp() -> None:
@pytest.mark.parametrize("exception_cls,message", [
(Exception, "Test exception"),
(FlyteRecoverableException, "Recoverable test exception")
])
def test_exception_timestamp(exception_cls, message) -> None:
"""Test that the timestamp of the worker process exception is propagated to the task exception."""
task_config=Elastic(
nnodes=1,
nproc_per_node=2,
)
)
def test_task():
raise exception_cls(message)
with pytest.raises(Exception) as e:
test_task()
assert e.value.timestamp is not None

Code Review Run #3d1353


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

Sorry, something went wrong.