From c0c17483d1f692ad639144050f5f6fa0966e47a5 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Sat, 26 Oct 2024 16:30:52 -0700
Subject: [PATCH] [Jobs] Refactor: Extract task failure state update helper
 (#4185)

refactor: a unified exception handling utility
---
 sky/jobs/controller.py | 61 +++++++++++++++++++++++++-----------------------
 1 file changed, 28 insertions(+), 33 deletions(-)

diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py
index 1faa5dfbe31..73d509be9ef 100644
--- a/sky/jobs/controller.py
+++ b/sky/jobs/controller.py
@@ -340,48 +340,28 @@ def run(self):
                 common_utils.format_exception(reason, use_bracket=True)
                 for reason in e.reasons))
             logger.error(failure_reason)
-            managed_job_state.set_failed(
-                self._job_id,
-                task_id=task_id,
-                failure_type=managed_job_state.ManagedJobStatus.
-                FAILED_PRECHECKS,
-                failure_reason=failure_reason,
-                callback_func=managed_job_utils.event_callback_func(
-                    job_id=self._job_id,
-                    task_id=task_id,
-                    task=self._dag.tasks[task_id]))
+            self._update_failed_task_state(
+                task_id, managed_job_state.ManagedJobStatus.FAILED_PRECHECKS,
+                failure_reason)
         except exceptions.ManagedJobReachedMaxRetriesError as e:
             # Please refer to the docstring of self._run for the cases when
             # this exception can occur.
-            logger.error(common_utils.format_exception(e))
+            failure_reason = common_utils.format_exception(e)
+            logger.error(failure_reason)
             # The managed job should be marked as FAILED_NO_RESOURCE, as the
             # managed job may be able to launch next time.
-            managed_job_state.set_failed(
-                self._job_id,
-                task_id=task_id,
-                failure_type=managed_job_state.ManagedJobStatus.
-                FAILED_NO_RESOURCE,
-                failure_reason=common_utils.format_exception(e),
-                callback_func=managed_job_utils.event_callback_func(
-                    job_id=self._job_id,
-                    task_id=task_id,
-                    task=self._dag.tasks[task_id]))
+            self._update_failed_task_state(
+                task_id, managed_job_state.ManagedJobStatus.FAILED_NO_RESOURCE,
+                failure_reason)
         except (Exception, SystemExit) as e:  # pylint: disable=broad-except
             with ux_utils.enable_traceback():
                 logger.error(traceback.format_exc())
-            msg = ('Unexpected error occurred: '
-                   f'{common_utils.format_exception(e, use_bracket=True)}')
+            msg = ('Unexpected error occurred: ' +
+                   common_utils.format_exception(e, use_bracket=True))
             logger.error(msg)
-            managed_job_state.set_failed(
-                self._job_id,
-                task_id=task_id,
-                failure_type=managed_job_state.ManagedJobStatus.
-                FAILED_CONTROLLER,
-                failure_reason=msg,
-                callback_func=managed_job_utils.event_callback_func(
-                    job_id=self._job_id,
-                    task_id=task_id,
-                    task=self._dag.tasks[task_id]))
+            self._update_failed_task_state(
+                task_id, managed_job_state.ManagedJobStatus.FAILED_CONTROLLER,
+                msg)
         finally:
             # This will set all unfinished tasks to CANCELLING, and will not
             # affect the jobs in terminal states.
@@ -396,6 +376,21 @@ def run(self):
             managed_job_state.set_cancelled(job_id=self._job_id,
                                             callback_func=callback_func)
 
+    def _update_failed_task_state(
+            self, task_id: int,
+            failure_type: managed_job_state.ManagedJobStatus,
+            failure_reason: str):
+        """Update the state of the failed task."""
+        managed_job_state.set_failed(
+            self._job_id,
+            task_id=task_id,
+            failure_type=failure_type,
+            failure_reason=failure_reason,
+            callback_func=managed_job_utils.event_callback_func(
+                job_id=self._job_id,
+                task_id=task_id,
+                task=self._dag.tasks[task_id]))
+
 
 def _run_controller(job_id: int, dag_yaml: str, retry_until_up: bool):
     """Runs the controller in a remote process for interruption."""
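
Note on the refactor: the patch above collapses three near-identical managed_job_state.set_failed(...) call sites, one per except branch, into a single _update_failed_task_state helper, so each branch only names the failure type and the reason string. Below is a minimal, self-contained sketch of that pattern, not SkyPilot's actual code: JobStatus, set_failed, PrechecksError, and MaxRetriesError are simplified stand-ins for managed_job_state.ManagedJobStatus, managed_job_state.set_failed, and the relevant exception classes.

import enum
import logging

logger = logging.getLogger(__name__)


class JobStatus(enum.Enum):
    # Simplified stand-in for managed_job_state.ManagedJobStatus.
    FAILED_PRECHECKS = 'FAILED_PRECHECKS'
    FAILED_NO_RESOURCE = 'FAILED_NO_RESOURCE'
    FAILED_CONTROLLER = 'FAILED_CONTROLLER'


def set_failed(job_id: int, task_id: int, failure_type: JobStatus,
               failure_reason: str) -> None:
    """Stand-in for managed_job_state.set_failed: just record the failure."""
    logger.error('job %s, task %s -> %s: %s', job_id, task_id,
                 failure_type.value, failure_reason)


class PrechecksError(Exception):
    """Stand-in for the prechecks failure raised before provisioning."""


class MaxRetriesError(Exception):
    """Stand-in for the max-retries failure raised while launching."""


class Controller:
    """Sketch of the controller's failure handling after the refactor."""

    def __init__(self, job_id: int) -> None:
        self._job_id = job_id

    def _update_failed_task_state(self, task_id: int,
                                  failure_type: JobStatus,
                                  failure_reason: str) -> None:
        # The one place that knows how to persist a task failure; the
        # except branches below only decide *which* failure it is.
        set_failed(self._job_id,
                   task_id=task_id,
                   failure_type=failure_type,
                   failure_reason=failure_reason)

    def run(self, task_id: int) -> None:
        try:
            self._run(task_id)
        except PrechecksError as e:
            self._update_failed_task_state(task_id,
                                           JobStatus.FAILED_PRECHECKS, str(e))
        except MaxRetriesError as e:
            self._update_failed_task_state(task_id,
                                           JobStatus.FAILED_NO_RESOURCE,
                                           str(e))
        except (Exception, SystemExit) as e:  # pylint: disable=broad-except
            self._update_failed_task_state(task_id,
                                           JobStatus.FAILED_CONTROLLER,
                                           f'Unexpected error occurred: {e}')

    def _run(self, task_id: int) -> None:
        # Placeholder for the real launch/monitor loop.
        raise PrechecksError('no cloud credentials configured')


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    Controller(job_id=1).run(task_id=0)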