[Jobs] Move task retry logic to correct branch in `stream_logs_by_id` (#4407)

* fix(jobs): move task retry logic to correct branch in `stream_logs_by_id`

* refactor: use `next` for better readability

* refactor: add comments explaining why we wait until the status is not RUNNING

* refactor: work around a Pylint bug

* fix: also include failed_setup

* refactor: extract `user_code_failure_states`

* refactor: remove `nonlocal`

* fix: stop logging retry for no-follow

* tests: add smoke test for managed job retries

* format

* format

Co-authored-by: Tian Xia <[email protected]>

* chore: extract yaml file to test_yamls/

* refactor: use `def` rather than lambda

* revert: restore a test lost during merging

* style: format

---------

Co-authored-by: Tian Xia <[email protected]>
andylizf and cblmemo authored Jan 15, 2025
1 parent f31c732 commit 2db9ae0
Showing 7 changed files with 112 additions and 49 deletions.
11 changes: 5 additions & 6 deletions sky/jobs/controller.py
@@ -256,9 +256,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
task.num_nodes == 1):
continue

if job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
if job_status in job_lib.JobStatus.user_code_failure_states():
# Add a grace period before the check of preemption to avoid
# false alarm for job failure.
time.sleep(5)
@@ -288,9 +286,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
if job_status is not None and not job_status.is_terminal():
# The multi-node job is still running, continue monitoring.
continue
elif job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
elif job_status in job_lib.JobStatus.user_code_failure_states():
# The user code has probably crashed, fail immediately.
end_time = managed_job_utils.get_job_timestamp(
self._backend, cluster_name, get_end_time=True)
@@ -493,6 +489,7 @@ def start(job_id, dag_yaml, retry_until_up):
"""Start the controller."""
controller_process = None
cancelling = False
task_id = None
try:
_handle_signal(job_id)
# TODO(suquark): In theory, we should make controller process a
@@ -511,6 +508,7 @@ def start(job_id, dag_yaml, retry_until_up):
except exceptions.ManagedJobUserCancelledError:
dag, _ = _get_dag_and_name(dag_yaml)
task_id, _ = managed_job_state.get_latest_task_id_status(job_id)
assert task_id is not None, job_id
logger.info(
f'Cancelling managed job, job_id: {job_id}, task_id: {task_id}')
managed_job_state.set_cancelling(
@@ -542,6 +540,7 @@ def start(job_id, dag_yaml, retry_until_up):
logger.info(f'Cluster of managed job {job_id} has been cleaned up.')

if cancelling:
assert task_id is not None, job_id # Since it's set with cancelling
managed_job_state.set_cancelled(
job_id=job_id,
callback_func=managed_job_utils.event_callback_func(
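The controller change above initializes `task_id = None` before the `try` block and re-asserts it in the cleanup path: `task_id` is only assigned inside the cancellation handler, so without the up-front initialization the cleanup code could hit an `UnboundLocalError`. A minimal sketch of the same guard pattern, with an illustrative exception class and a hard-coded task id standing in for the real SkyPilot calls:

```python
from typing import Optional


class UserCancelledError(Exception):
    """Illustrative stand-in for exceptions.ManagedJobUserCancelledError."""


def run_controller(job_id: int, cancel: bool = True) -> None:
    cancelling = False
    # Initialized up front so the cleanup path never sees an unbound name.
    task_id: Optional[int] = None
    try:
        if cancel:
            raise UserCancelledError()
        # ... the normal controller loop would run here ...
    except UserCancelledError:
        # In the real code this comes from
        # managed_job_state.get_latest_task_id_status(job_id).
        task_id = 0
        assert task_id is not None, job_id
        cancelling = True
    finally:
        # Cluster cleanup would happen here.
        if cancelling:
            # task_id was necessarily set together with `cancelling`.
            assert task_id is not None, job_id
            print(f'Marking job {job_id}, task {task_id} as CANCELLED.')


run_controller(job_id=1)
```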
10 changes: 6 additions & 4 deletions sky/jobs/state.py
@@ -620,10 +620,12 @@ def get_latest_task_id_status(
id_statuses = _get_all_task_ids_statuses(job_id)
if not id_statuses:
return None, None
task_id, status = id_statuses[-1]
for task_id, status in id_statuses:
if not status.is_terminal():
break
task_id, status = next(
((tid, st) for tid, st in id_statuses if not st.is_terminal()),
id_statuses[-1],
)
# Unpack the tuple first, or it triggers a Pylint's bug on recognizing
# the return type.
return task_id, status


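The `next()` call in `get_latest_task_id_status` replaces the earlier for-loop: it scans `id_statuses` in order, returns the first task whose status is not terminal, and falls back to the last entry when every task has finished. A self-contained sketch of the same pattern, using a trimmed stand-in enum rather than the real `job_lib.JobStatus`:

```python
from enum import Enum
from typing import List, Tuple


class TaskStatus(Enum):
    """Illustrative stand-in for the managed-job status enum."""
    RUNNING = 'RUNNING'
    SUCCEEDED = 'SUCCEEDED'
    FAILED = 'FAILED'

    def is_terminal(self) -> bool:
        return self is not TaskStatus.RUNNING


def latest_task_id_status(
        id_statuses: List[Tuple[int, TaskStatus]]) -> Tuple[int, TaskStatus]:
    # First non-terminal task wins; otherwise fall back to the last task.
    return next(((tid, st) for tid, st in id_statuses if not st.is_terminal()),
                id_statuses[-1])


# Task 0 finished, task 1 still running -> (1, RUNNING) is returned.
print(latest_task_id_status([(0, TaskStatus.SUCCEEDED),
                             (1, TaskStatus.RUNNING)]))
# All tasks terminal -> the last entry, (1, FAILED), is returned.
print(latest_task_id_status([(0, TaskStatus.SUCCEEDED),
                             (1, TaskStatus.FAILED)]))
```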
83 changes: 51 additions & 32 deletions sky/jobs/utils.py
@@ -398,32 +398,15 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
job_statuses = backend.get_job_status(handle, stream_logs=False)
job_status = list(job_statuses.values())[0]
assert job_status is not None, 'No job found.'
assert task_id is not None, job_id

if job_status != job_lib.JobStatus.CANCELLED:
assert task_id is not None, job_id
if task_id < num_tasks - 1 and follow:
# The log for the current job is finished. We need to
# wait until next job to be started.
logger.debug(
f'INFO: Log for the current task ({task_id}) '
'is finished. Waiting for the next task\'s log '
'to be started.')
# Add a newline to avoid the status display below
# removing the last line of the task output.
print()
status_display.update(
ux_utils.spinner_message(
f'Waiting for the next task: {task_id + 1}'))
status_display.start()
original_task_id = task_id
while True:
task_id, managed_job_status = (
managed_job_state.get_latest_task_id_status(
job_id))
if original_task_id != task_id:
break
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue
else:
if not follow:
break

# Logs for retrying failed tasks.
if (job_status
in job_lib.JobStatus.user_code_failure_states()):
task_specs = managed_job_state.get_task_specs(
job_id, task_id)
if task_specs.get('max_restarts_on_errors', 0) == 0:
@@ -436,15 +419,51 @@
ux_utils.spinner_message(
'Waiting for next restart for the failed task'))
status_display.start()
while True:
_, managed_job_status = (
managed_job_state.get_latest_task_id_status(
job_id))
if (managed_job_status !=
managed_job_state.ManagedJobStatus.RUNNING):
break

def is_managed_job_status_updated(
status: Optional[managed_job_state.ManagedJobStatus]
) -> bool:
"""Check if local managed job status reflects remote
job failure.
Ensures synchronization between remote cluster
failure detection (JobStatus.FAILED) and controller
retry logic.
"""
return (status !=
managed_job_state.ManagedJobStatus.RUNNING)

while not is_managed_job_status_updated(
managed_job_status :=
managed_job_state.get_status(job_id)):
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue

if task_id == num_tasks - 1:
break

# The log for the current job is finished. We need to
# wait until next job to be started.
logger.debug(
f'INFO: Log for the current task ({task_id}) '
'is finished. Waiting for the next task\'s log '
'to be started.')
# Add a newline to avoid the status display below
# removing the last line of the task output.
print()
status_display.update(
ux_utils.spinner_message(
f'Waiting for the next task: {task_id + 1}'))
status_display.start()
original_task_id = task_id
while True:
task_id, managed_job_status = (
managed_job_state.get_latest_task_id_status(job_id))
if original_task_id != task_id:
break
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue

# The job can be cancelled by the user or the controller (when
# the cluster is partially preempted).
logger.debug(
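The reordered branch in `stream_logs_by_id` now handles retries explicitly: when the task ends in a user-code failure state and `max_restarts_on_errors` is non-zero, the streamer polls until the controller has moved the managed job out of `RUNNING` (via `is_managed_job_status_updated`), then loops back to stream the next attempt. A simplified, self-contained sketch of that polling loop; the fake `get_status`, the string statuses, and the 1-second gap are stand-ins for the real SkyPilot helpers and constants:

```python
import itertools
import time
from typing import Optional

JOB_STATUS_CHECK_GAP_SECONDS = 1  # Shortened for the demo.

# Simulated controller-side status: RUNNING for two polls, then RECOVERING.
_fake_statuses = itertools.chain(['RUNNING', 'RUNNING'],
                                 itertools.repeat('RECOVERING'))


def get_status(job_id: int) -> Optional[str]:
    """Stand-in for managed_job_state.get_status()."""
    return next(_fake_statuses)


def is_managed_job_status_updated(status: Optional[str]) -> bool:
    # The controller marks the managed job as non-RUNNING once it has
    # noticed the remote failure; only then is it safe to wait for a retry.
    return status != 'RUNNING'


job_id = 1
while not is_managed_job_status_updated(
        managed_job_status := get_status(job_id)):
    time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)

print(f'Controller caught up; managed job status is now {managed_job_status}.')
```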
4 changes: 1 addition & 3 deletions sky/serve/replica_managers.py
@@ -998,9 +998,7 @@ def _fetch_job_status(self) -> None:
# Re-raise the exception if it is not preempted.
raise
job_status = list(job_statuses.values())[0]
if job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
if job_status in job_lib.JobStatus.user_code_failure_states():
info.status_property.user_app_failed = True
serve_state.add_or_update_replica(self._service_name,
info.replica_id, info)
12 changes: 8 additions & 4 deletions sky/skylet/job_lib.py
@@ -12,7 +12,7 @@
import sqlite3
import subprocess
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence

import colorama
import filelock
@@ -162,13 +162,17 @@ class JobStatus(enum.Enum):
def nonterminal_statuses(cls) -> List['JobStatus']:
return [cls.INIT, cls.SETTING_UP, cls.PENDING, cls.RUNNING]

def is_terminal(self):
def is_terminal(self) -> bool:
return self not in self.nonterminal_statuses()

def __lt__(self, other):
@classmethod
def user_code_failure_states(cls) -> Sequence['JobStatus']:
return (cls.FAILED, cls.FAILED_SETUP)

def __lt__(self, other: 'JobStatus') -> bool:
return list(JobStatus).index(self) < list(JobStatus).index(other)

def colored_str(self):
def colored_str(self) -> str:
color = _JOB_STATUS_TO_COLOR[self]
return f'{color}{self.value}{colorama.Style.RESET_ALL}'

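The new `user_code_failure_states()` classmethod centralizes the `(FAILED, FAILED_SETUP)` pair, so the controller, the log streamer, and the serve replica manager can all use one membership test instead of hand-written lists. A hedged sketch of the pattern with a trimmed enum (not the full `JobStatus` defined in `job_lib.py`):

```python
import enum
from typing import Sequence


class JobStatus(enum.Enum):
    RUNNING = 'RUNNING'
    SUCCEEDED = 'SUCCEEDED'
    FAILED = 'FAILED'
    FAILED_SETUP = 'FAILED_SETUP'

    @classmethod
    def user_code_failure_states(cls) -> Sequence['JobStatus']:
        # Failures caused by user code (setup or run), as opposed to
        # infrastructure problems such as preemption.
        return (cls.FAILED, cls.FAILED_SETUP)


job_status = JobStatus.FAILED_SETUP
if job_status in JobStatus.user_code_failure_states():
    print('User code failed; do not treat this as a preemption.')
```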
27 changes: 27 additions & 0 deletions tests/smoke_tests/test_managed_job.py
@@ -655,6 +655,33 @@ def test_managed_jobs_cancellation_gcp():
smoke_tests_utils.run_one_test(test)


@pytest.mark.managed_jobs
def test_managed_jobs_retry_logs():
"""Test managed job retry logs are properly displayed when a task fails."""
name = smoke_tests_utils.get_cluster_name()
yaml_path = 'tests/test_yamls/test_managed_jobs_retry.yaml'

with tempfile.NamedTemporaryFile(mode='w', suffix='.log') as log_file:
test = smoke_tests_utils.Test(
'managed_jobs_retry_logs',
[
f'sky jobs launch -n {name} {yaml_path} -y -d',
f'sky jobs logs -n {name} | tee {log_file.name}',
# First attempt
f'cat {log_file.name} | grep "Job started. Streaming logs..."',
f'cat {log_file.name} | grep "Job 1 failed"',
# Second attempt
f'cat {log_file.name} | grep "Job started. Streaming logs..." | wc -l | grep 2',
f'cat {log_file.name} | grep "Job 1 failed" | wc -l | grep 2',
# Task 2 is not reached
f'! cat {log_file.name} | grep "Job 2"',
],
f'sky jobs cancel -y -n {name}',
timeout=7 * 60, # 7 mins
)
smoke_tests_utils.run_one_test(test)


# ---------- Testing storage for managed job ----------
@pytest.mark.no_fluidstack # Fluidstack does not support spot instances
@pytest.mark.no_lambda_cloud # Lambda Cloud does not support spot instances
14 changes: 14 additions & 0 deletions tests/test_yamls/test_managed_jobs_retry.yaml
@@ -0,0 +1,14 @@
resources:
  cpus: 2+
  job_recovery:
    max_restarts_on_errors: 1

# Task 1: Always fails
run: |
  echo "Task 1 starting"
  exit 1
---
# Task 2: Never reached due to Task 1 failure
run: |
  echo "Task 2 starting"
  exit 0
