diff --git a/flytekit/exceptions/user.py b/flytekit/exceptions/user.py index 3413d172ff..05a783d034 100644 --- a/flytekit/exceptions/user.py +++ b/flytekit/exceptions/user.py @@ -15,14 +15,15 @@ class FlyteUserException(_FlyteException): class FlyteUserRuntimeException(_FlyteException): _ERROR_CODE = "USER:RuntimeError" - def __init__(self, exc_value: Exception): + def __init__(self, exc_value: Exception, timestamp: typing.Optional[float] = None): """ FlyteUserRuntimeException is thrown when a user code raises an exception. :param exc_value: The exception that was raised from user code. + :param timestamp: The timestamp as fractional seconds since epoch when the exception was raised. """ self._exc_value = exc_value - super().__init__(str(exc_value)) + super().__init__(str(exc_value), timestamp=timestamp) @property def value(self): diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index a951bea0a5..4bbcb814a4 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -18,7 +18,7 @@ from flytekit.core.context_manager import FlyteContextManager, OutputMetadata from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import convert_resources_to_resource_model -from flytekit.exceptions.user import FlyteRecoverableException +from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException from flytekit.extend import IgnoreOutputs, TaskPlugins from flytekit.loggers import logger @@ -475,7 +475,7 @@ def fn_partial(): # the automatically assigned timestamp based on exception creation time raise FlyteRecoverableException(e.format_msg(), timestamp=first_failure.timestamp) else: - raise RuntimeError(e.format_msg()) + raise FlyteUserRuntimeException(e, timestamp=first_failure.timestamp) except SignalException as e: logger.exception(f"Elastic launch agent process terminating: {e}") raise IgnoreOutputs() diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index faadc1019f..f8742d1fe9 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -17,7 +17,7 @@ from flytekit import task, workflow from flytekit.core.context_manager import FlyteContext, FlyteContextManager, ExecutionState, ExecutionParameters, OutputMetadataTracker from flytekit.configuration import SerializationSettings -from flytekit.exceptions.user import FlyteRecoverableException +from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException @pytest.fixture(autouse=True, scope="function") def restore_env(): @@ -223,7 +223,7 @@ def wf(recoverable: bool): with pytest.raises(FlyteRecoverableException): wf(recoverable=recoverable) else: - with pytest.raises(RuntimeError): + with pytest.raises(FlyteUserRuntimeException): wf(recoverable=recoverable) @@ -276,3 +276,37 @@ def test_task_omp_set(): assert os.environ["OMP_NUM_THREADS"] == "42" test_task_omp_set() + + +def test_exception_timestamp() -> None: + """Test that the timestamp of the worker process exception is propagated to the task exception.""" + @task( + task_config=Elastic( + nnodes=1, + nproc_per_node=2, + ) + ) + def test_task(): + raise Exception("Test exception") + + with pytest.raises(Exception) as e: + test_task() + + assert e.value.timestamp is not None + + +def test_recoverable_exception_timestamp() -> None: + """Test that the timestamp of the worker process exception is propagated to the task exception.""" + @task( + task_config=Elastic( + nnodes=1, + nproc_per_node=2, + ) + ) + def test_task(): + raise FlyteRecoverableException("Recoverable test exception") + + with pytest.raises(Exception) as e: + test_task() + + assert e.value.timestamp is not None