[Jobs] Move task retry logic to correct branch in stream_logs_by_id #4407

Open · wants to merge 14 commits into base: master
11 changes: 5 additions & 6 deletions sky/jobs/controller.py
@@ -236,9 +236,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
task.num_nodes == 1):
continue

if job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
if job_status in job_lib.JobStatus.user_code_failure_states():
# Add a grace period before the check of preemption to avoid
# false alarm for job failure.
time.sleep(5)
@@ -268,9 +266,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
if job_status is not None and not job_status.is_terminal():
# The multi-node job is still running, continue monitoring.
continue
elif job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
elif job_status in job_lib.JobStatus.user_code_failure_states():
# The user code has probably crashed, fail immediately.
end_time = managed_job_utils.get_job_timestamp(
self._backend, cluster_name, get_end_time=True)
@@ -473,6 +469,7 @@ def start(job_id, dag_yaml, retry_until_up):
"""Start the controller."""
controller_process = None
cancelling = False
task_id = None
try:
_handle_signal(job_id)
# TODO(suquark): In theory, we should make controller process a
@@ -491,6 +488,7 @@ def start(job_id, dag_yaml, retry_until_up):
except exceptions.ManagedJobUserCancelledError:
dag, _ = _get_dag_and_name(dag_yaml)
task_id, _ = managed_job_state.get_latest_task_id_status(job_id)
assert task_id is not None, job_id
logger.info(
f'Cancelling managed job, job_id: {job_id}, task_id: {task_id}')
managed_job_state.set_cancelling(
@@ -522,6 +520,7 @@ def start(job_id, dag_yaml, retry_until_up):
logger.info(f'Cluster of managed job {job_id} has been cleaned up.')

if cancelling:
assert task_id is not None, job_id # Since it's set with cancelling
managed_job_state.set_cancelled(
job_id=job_id,
callback_func=managed_job_utils.event_callback_func(
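
Reviewer note: a minimal, self-contained sketch of the initialization pattern above: task_id is bound to None before the try block so the later if cancelling: branch can assert on it without risking a NameError. The helper _latest_task_id and the use of KeyboardInterrupt are hypothetical stand-ins for the actual SkyPilot APIs, not the real ones.

    from typing import Optional

    def _latest_task_id(job_id: int) -> Optional[int]:
        # Hypothetical stand-in for managed_job_state.get_latest_task_id_status().
        return 0

    def start_sketch(job_id: int) -> None:
        cancelling = False
        task_id: Optional[int] = None  # Bound up front so the final branch is always safe.
        try:
            raise KeyboardInterrupt  # Stands in for exceptions.ManagedJobUserCancelledError.
        except KeyboardInterrupt:
            task_id = _latest_task_id(job_id)
            assert task_id is not None, job_id
            cancelling = True
        finally:
            if cancelling:
                # Safe: task_id was set on the same path that set `cancelling`.
                assert task_id is not None, job_id
                print(f'Cancelled managed job {job_id}, task {task_id}')

    start_sketch(1)
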
10 changes: 6 additions & 4 deletions sky/jobs/state.py
@@ -575,10 +575,12 @@ def get_latest_task_id_status(
id_statuses = _get_all_task_ids_statuses(job_id)
if len(id_statuses) == 0:
return None, None
task_id, status = id_statuses[-1]
for task_id, status in id_statuses:
if not status.is_terminal():
break
task_id, status = next(
((tid, st) for tid, st in id_statuses if not st.is_terminal()),
id_statuses[-1],
)
    # Unpack the tuple first, or it triggers a Pylint bug in recognizing
    # the return type.
return task_id, status


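Reviewer note: a standalone sketch of the next()-with-default pattern introduced in get_latest_task_id_status (the status strings below are illustrative, not the real JobStatus/ManagedJobStatus values). It picks the first non-terminal (task_id, status) pair and falls back to the last entry when every task has finished.

    from typing import List, Tuple

    def is_terminal(status: str) -> bool:
        # Simplified stand-in for JobStatus.is_terminal().
        return status in ('SUCCEEDED', 'FAILED', 'FAILED_SETUP', 'CANCELLED')

    id_statuses: List[Tuple[int, str]] = [(0, 'SUCCEEDED'), (1, 'RUNNING'), (2, 'PENDING')]

    # First task whose status is not terminal; otherwise the last (task_id, status) pair.
    task_id, status = next(
        ((tid, st) for tid, st in id_statuses if not is_terminal(st)),
        id_statuses[-1],
    )
    assert (task_id, status) == (1, 'RUNNING')
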
83 changes: 51 additions & 32 deletions sky/jobs/utils.py
@@ -384,32 +384,15 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
job_statuses = backend.get_job_status(handle, stream_logs=False)
job_status = list(job_statuses.values())[0]
assert job_status is not None, 'No job found.'
assert task_id is not None, job_id

if job_status != job_lib.JobStatus.CANCELLED:
assert task_id is not None, job_id
if task_id < num_tasks - 1 and follow:
# The log for the current job is finished. We need to
# wait until next job to be started.
logger.debug(
f'INFO: Log for the current task ({task_id}) '
'is finished. Waiting for the next task\'s log '
'to be started.')
# Add a newline to avoid the status display below
# removing the last line of the task output.
print()
status_display.update(
ux_utils.spinner_message(
f'Waiting for the next task: {task_id + 1}'))
status_display.start()
original_task_id = task_id
while True:
task_id, managed_job_status = (
managed_job_state.get_latest_task_id_status(
job_id))
if original_task_id != task_id:
break
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue
else:
if not follow:
break

# Logs for retrying failed tasks.
if (job_status
in job_lib.JobStatus.user_code_failure_states()):
task_specs = managed_job_state.get_task_specs(
job_id, task_id)
if task_specs.get('max_restarts_on_errors', 0) == 0:
@@ -422,15 +405,51 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
ux_utils.spinner_message(
'Waiting for next restart for the failed task'))
status_display.start()
while True:
_, managed_job_status = (
managed_job_state.get_latest_task_id_status(
job_id))
if (managed_job_status !=
managed_job_state.ManagedJobStatus.RUNNING):
break

def is_managed_job_status_updated(
status: Optional[managed_job_state.ManagedJobStatus]
) -> bool:
"""Check if local managed job status reflects remote
job failure.

Ensures synchronization between remote cluster
failure detection (JobStatus.FAILED) and controller
retry logic.
"""
return (status !=
managed_job_state.ManagedJobStatus.RUNNING)

while not is_managed_job_status_updated(
managed_job_status :=
managed_job_state.get_status(job_id)):
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue

if task_id == num_tasks - 1:
break

# The log for the current job is finished. We need to
# wait until next job to be started.
logger.debug(
f'INFO: Log for the current task ({task_id}) '
'is finished. Waiting for the next task\'s log '
'to be started.')
# Add a newline to avoid the status display below
# removing the last line of the task output.
print()
status_display.update(
ux_utils.spinner_message(
f'Waiting for the next task: {task_id + 1}'))
status_display.start()
original_task_id = task_id
while True:
task_id, managed_job_status = (
managed_job_state.get_latest_task_id_status(job_id))
if original_task_id != task_id:
break
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue

# The job can be cancelled by the user or the controller (when
# the cluster is partially preempted).
logger.debug(
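
Reviewer note: a self-contained sketch of the polling loop added above, with made-up fetch_status values standing in for managed_job_state.get_status(). The walrus operator reassigns the latest status on every iteration, and the loop exits once the predicate reports that the controller has moved the job out of RUNNING, i.e. it has noticed the failure and can schedule a retry.

    import itertools
    import time

    _fake_statuses = itertools.chain(['RUNNING', 'RUNNING'], itertools.repeat('RECOVERING'))

    def fetch_status(job_id: int) -> str:
        # Hypothetical stand-in for managed_job_state.get_status(job_id).
        return next(_fake_statuses)

    def is_status_updated(status: str) -> bool:
        # Mirrors is_managed_job_status_updated(): True once the job has left RUNNING.
        return status != 'RUNNING'

    job_status = None
    while not is_status_updated(job_status := fetch_status(job_id=1)):
        time.sleep(0.01)  # JOB_STATUS_CHECK_GAP_SECONDS in the real code.
    print(job_status)  # 'RECOVERING'
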
4 changes: 1 addition & 3 deletions sky/serve/replica_managers.py
@@ -998,9 +998,7 @@ def _fetch_job_status(self) -> None:
# Re-raise the exception if it is not preempted.
raise
job_status = list(job_statuses.values())[0]
if job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
if job_status in job_lib.JobStatus.user_code_failure_states():
info.status_property.user_app_failed = True
serve_state.add_or_update_replica(self._service_name,
info.replica_id, info)
12 changes: 8 additions & 4 deletions sky/skylet/job_lib.py
@@ -12,7 +12,7 @@
import sqlite3
import subprocess
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence

import colorama
import filelock
@@ -162,13 +162,17 @@ class JobStatus(enum.Enum):
def nonterminal_statuses(cls) -> List['JobStatus']:
return [cls.INIT, cls.SETTING_UP, cls.PENDING, cls.RUNNING]

def is_terminal(self):
def is_terminal(self) -> bool:
return self not in self.nonterminal_statuses()

def __lt__(self, other):
@classmethod
def user_code_failure_states(cls) -> Sequence['JobStatus']:
return (cls.FAILED, cls.FAILED_SETUP)

def __lt__(self, other: 'JobStatus') -> bool:
return list(JobStatus).index(self) < list(JobStatus).index(other)

def colored_str(self):
def colored_str(self) -> str:
color = _JOB_STATUS_TO_COLOR[self]
return f'{color}{self.value}{colorama.Style.RESET_ALL}'

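Reviewer note: a reduced sketch of the refactor applied at the call sites in controller.py, utils.py, and replica_managers.py (the enum below is a toy with only three members, not the full JobStatus). Membership is checked against the shared user_code_failure_states() tuple instead of repeating the literal [FAILED, FAILED_SETUP] list at every site.

    import enum
    from typing import Sequence

    class JobStatus(enum.Enum):
        RUNNING = 'RUNNING'
        FAILED = 'FAILED'
        FAILED_SETUP = 'FAILED_SETUP'

        @classmethod
        def user_code_failure_states(cls) -> Sequence['JobStatus']:
            return (cls.FAILED, cls.FAILED_SETUP)

    # Call sites test membership against the shared tuple.
    for status in (JobStatus.RUNNING, JobStatus.FAILED_SETUP):
        if status in JobStatus.user_code_failure_states():
            print(f'{status.value}: treated as a user-code failure')
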
27 changes: 27 additions & 0 deletions tests/test_smoke.py
@@ -3283,6 +3283,33 @@ def test_managed_jobs_recovery_multi_node_gcp():
run_one_test(test)


@pytest.mark.managed_jobs
def test_managed_jobs_retry_logs():
"""Test managed job retry logs are properly displayed when a task fails."""
name = _get_cluster_name()
yaml_path = 'tests/test_yamls/test_managed_jobs_retry.yaml'

with tempfile.NamedTemporaryFile(mode='w', suffix='.log') as log_file:
test = Test(
'managed_jobs_retry_logs',
[
f'sky jobs launch -n {name} {yaml_path} -y -d',
f'sky jobs logs -n {name} | tee {log_file.name}',
# First attempt
f'cat {log_file.name} | grep "Job started. Streaming logs..."',
f'cat {log_file.name} | grep "Job 1 failed"',
# Second attempt
f'cat {log_file.name} | grep "Job started. Streaming logs..." | wc -l | grep 2',
f'cat {log_file.name} | grep "Job 1 failed" | wc -l | grep 2',
# Task 2 is not reached
f'! cat {log_file.name} | grep "Job 2"',
],
f'sky jobs cancel -y -n {name}',
timeout=7 * 60, # 7 mins
)
run_one_test(test)


@pytest.mark.aws
@pytest.mark.managed_jobs
def test_managed_jobs_cancellation_aws(aws_config_region):
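
Reviewer note: a small sketch of what the grep / wc -l assertions above verify, expressed directly in Python against an invented captured log. With max_restarts_on_errors: 1, the streamed output should contain exactly two launch/failure pairs for task 1 and no output from task 2.

    captured_log = '\n'.join([
        'Job started. Streaming logs...',
        'Job 1 failed',
        'Job started. Streaming logs...',
        'Job 1 failed',
    ])

    assert captured_log.count('Job started. Streaming logs...') == 2  # first attempt + one retry
    assert captured_log.count('Job 1 failed') == 2
    assert 'Job 2' not in captured_log  # task 2 is never reached
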
14 changes: 14 additions & 0 deletions tests/test_yamls/test_managed_jobs_retry.yaml
@@ -0,0 +1,14 @@
resources:
cpus: 2+
job_recovery:
max_restarts_on_errors: 1

# Task 1: Always fails
run: |
echo "Task 1 starting"
exit 1
---
# Task 2: Never reached due to Task 1 failure
run: |
echo "Task 2 starting"
exit 0