
[Jobs] Move task retry logic to correct branch in stream_logs_by_id #4407

Merged · 17 commits · Jan 15, 2025
Changes from 11 commits
11 changes: 5 additions & 6 deletions sky/jobs/controller.py
@@ -236,9 +236,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
task.num_nodes == 1):
continue

if job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
if job_status in job_lib.JobStatus.user_code_failure_states():
# Add a grace period before the check of preemption to avoid
# false alarm for job failure.
time.sleep(5)
@@ -268,9 +266,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
if job_status is not None and not job_status.is_terminal():
# The multi-node job is still running, continue monitoring.
continue
elif job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
elif job_status in job_lib.JobStatus.user_code_failure_states():
# The user code has probably crashed, fail immediately.
end_time = managed_job_utils.get_job_timestamp(
self._backend, cluster_name, get_end_time=True)
@@ -473,6 +469,7 @@ def start(job_id, dag_yaml, retry_until_up):
"""Start the controller."""
controller_process = None
cancelling = False
task_id = None
try:
_handle_signal(job_id)
# TODO(suquark): In theory, we should make controller process a
@@ -491,6 +488,7 @@ def start(job_id, dag_yaml, retry_until_up):
except exceptions.ManagedJobUserCancelledError:
dag, _ = _get_dag_and_name(dag_yaml)
task_id, _ = managed_job_state.get_latest_task_id_status(job_id)
assert task_id is not None, job_id
logger.info(
f'Cancelling managed job, job_id: {job_id}, task_id: {task_id}')
managed_job_state.set_cancelling(
@@ -522,6 +520,7 @@ def start(job_id, dag_yaml, retry_until_up):
logger.info(f'Cluster of managed job {job_id} has been cleaned up.')

if cancelling:
assert task_id is not None, job_id # Since it's set with cancelling
managed_job_state.set_cancelled(
job_id=job_id,
callback_func=managed_job_utils.event_callback_func(
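For context on the new `task_id = None` initialization: binding the name before the `try` block lets both the cancellation handler and the final `if cancelling:` branch assert that it was actually populated on the path that set `cancelling = True`. A minimal, self-contained sketch of the same pattern (the exception class and the constant task id are hypothetical stand-ins, not the real controller code):

    from typing import Optional

    class UserCancelled(Exception):
        """Hypothetical stand-in for exceptions.ManagedJobUserCancelledError."""

    def run(job_id: int, cancel: bool) -> None:
        cancelling = False
        task_id: Optional[int] = None  # Bound before the try so cleanup can use it.
        try:
            if cancel:
                raise UserCancelled()
        except UserCancelled:
            task_id = 0  # Real controller: get_latest_task_id_status(job_id).
            assert task_id is not None, job_id
            cancelling = True
        finally:
            if cancelling:
                # Safe: task_id was assigned on the same path that set cancelling.
                assert task_id is not None, job_id
                print(f'Job {job_id}, task {task_id}: cancelled.')

    run(1, cancel=True)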
10 changes: 6 additions & 4 deletions sky/jobs/state.py
@@ -575,10 +575,12 @@ def get_latest_task_id_status(
id_statuses = _get_all_task_ids_statuses(job_id)
if len(id_statuses) == 0:
return None, None
task_id, status = id_statuses[-1]
for task_id, status in id_statuses:
if not status.is_terminal():
break
task_id, status = next(
((tid, st) for tid, st in id_statuses if not st.is_terminal()),
id_statuses[-1],
)
# Unpack the tuple first, or it triggers a Pylint's bug on recognizing
# the return type.
return task_id, status


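The rewritten lookup in `get_latest_task_id_status` is behavior-preserving: `next()` walks the `(task_id, status)` pairs, returns the first non-terminal one, and falls back to the last recorded entry when every task has finished, without leaking loop variables. A minimal sketch of the same pattern with a trimmed-down, hypothetical status enum:

    import enum
    from typing import List, Tuple

    class Status(enum.Enum):
        """Hypothetical stand-in for job_lib.JobStatus."""
        RUNNING = 'RUNNING'
        SUCCEEDED = 'SUCCEEDED'

        def is_terminal(self) -> bool:
            return self is not Status.RUNNING

    id_statuses: List[Tuple[int, Status]] = [(0, Status.SUCCEEDED),
                                             (1, Status.RUNNING),
                                             (2, Status.SUCCEEDED)]

    # First non-terminal task if any, otherwise the last recorded one.
    task_id, status = next(
        ((tid, st) for tid, st in id_statuses if not st.is_terminal()),
        id_statuses[-1])
    assert (task_id, status) == (1, Status.RUNNING)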
82 changes: 49 additions & 33 deletions sky/jobs/utils.py
@@ -14,7 +14,7 @@
import textwrap
import time
import typing
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import colorama
import filelock
@@ -384,32 +384,15 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
job_statuses = backend.get_job_status(handle, stream_logs=False)
job_status = list(job_statuses.values())[0]
assert job_status is not None, 'No job found.'
assert task_id is not None, job_id

if job_status != job_lib.JobStatus.CANCELLED:
assert task_id is not None, job_id
if task_id < num_tasks - 1 and follow:
# The log for the current job is finished. We need to
# wait until next job to be started.
logger.debug(
f'INFO: Log for the current task ({task_id}) '
'is finished. Waiting for the next task\'s log '
'to be started.')
# Add a newline to avoid the status display below
# removing the last line of the task output.
print()
status_display.update(
ux_utils.spinner_message(
f'Waiting for the next task: {task_id + 1}'))
status_display.start()
original_task_id = task_id
while True:
task_id, managed_job_status = (
managed_job_state.get_latest_task_id_status(
job_id))
if original_task_id != task_id:
break
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue
else:
if not follow:
break

# Logs for retrying failed tasks.
if job_status in job_lib.JobStatus.user_code_failure_states(
):
task_specs = managed_job_state.get_task_specs(
job_id, task_id)
if task_specs.get('max_restarts_on_errors', 0) == 0:
@@ -422,15 +405,48 @@
ux_utils.spinner_message(
'Waiting for next restart for the failed task'))
status_display.start()
while True:
_, managed_job_status = (
managed_job_state.get_latest_task_id_status(
job_id))
if (managed_job_status !=
managed_job_state.ManagedJobStatus.RUNNING):
break

# Check if local managed job status reflects remote job
# failure.
# Ensures synchronization between remote cluster failure
# detection (JobStatus.FAILED) and controller retry
# logic.
is_managed_job_status_updated: Callable[
[Optional[managed_job_state.ManagedJobStatus]],
bool] = (lambda status: status != managed_job_state.
ManagedJobStatus.RUNNING)

while not is_managed_job_status_updated(
managed_job_status :=
managed_job_state.get_status(job_id)):
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue

if task_id == num_tasks - 1:
break

# The log for the current job is finished. We need to
# wait until next job to be started.
logger.debug(
f'INFO: Log for the current task ({task_id}) '
'is finished. Waiting for the next task\'s log '
'to be started.')
# Add a newline to avoid the status display below
# removing the last line of the task output.
print()
status_display.update(
ux_utils.spinner_message(
f'Waiting for the next task: {task_id + 1}'))
status_display.start()
original_task_id = task_id
while True:
task_id, managed_job_status = (
managed_job_state.get_latest_task_id_status(job_id))
if original_task_id != task_id:
break
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue

# The job can be cancelled by the user or the controller (when
# the cluster is partially preempted).
logger.debug(
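The new waiting loop in `stream_logs_by_id` polls the controller-side managed job status until it is no longer `RUNNING`, so the log streamer does not race ahead of the retry decision after the remote job fails; the walrus operator keeps the freshly fetched status available once the predicate passes. A minimal sketch of that polling shape, using hypothetical string states and a fake `get_status` in place of `managed_job_state`:

    import itertools
    import time
    from typing import Callable, Optional

    # Hypothetical stand-ins: the real code uses managed_job_state.get_status()
    # and ManagedJobStatus; plain strings keep the sketch small.
    _states = itertools.chain(['RUNNING', 'RUNNING'], itertools.repeat('RECOVERING'))

    def get_status(job_id: int) -> Optional[str]:
        return next(_states)

    is_managed_job_status_updated: Callable[[Optional[str]], bool] = (
        lambda status: status != 'RUNNING')

    # Same shape as the walrus-operator loop in stream_logs_by_id: poll until
    # the controller has reacted to the remote failure.
    while not is_managed_job_status_updated(
            managed_job_status := get_status(job_id=1)):
        time.sleep(0.1)  # JOB_STATUS_CHECK_GAP_SECONDS in the real code.
    print(f'Controller now reports {managed_job_status}; keep following the retry.')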
4 changes: 1 addition & 3 deletions sky/serve/replica_managers.py
@@ -998,9 +998,7 @@ def _fetch_job_status(self) -> None:
# Re-raise the exception if it is not preempted.
raise
job_status = list(job_statuses.values())[0]
if job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
if job_status in job_lib.JobStatus.user_code_failure_states():
info.status_property.user_app_failed = True
serve_state.add_or_update_replica(self._service_name,
info.replica_id, info)
12 changes: 8 additions & 4 deletions sky/skylet/job_lib.py
@@ -12,7 +12,7 @@
import sqlite3
import subprocess
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence

import colorama
import filelock
@@ -162,13 +162,17 @@ class JobStatus(enum.Enum):
def nonterminal_statuses(cls) -> List['JobStatus']:
return [cls.INIT, cls.SETTING_UP, cls.PENDING, cls.RUNNING]

def is_terminal(self):
def is_terminal(self) -> bool:
return self not in self.nonterminal_statuses()

def __lt__(self, other):
@classmethod
def user_code_failure_states(cls) -> Sequence['JobStatus']:
return (cls.FAILED, cls.FAILED_SETUP)

def __lt__(self, other: 'JobStatus') -> bool:
return list(JobStatus).index(self) < list(JobStatus).index(other)

def colored_str(self):
def colored_str(self) -> str:
color = _JOB_STATUS_TO_COLOR[self]
return f'{color}{self.value}{colorama.Style.RESET_ALL}'

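`user_code_failure_states()` centralizes the pair of user-code failure statuses that were previously repeated as inline lists in controller.py, utils.py, and replica_managers.py, so call sites reduce to a single membership test. A minimal sketch of the shape with a trimmed-down enum (not the full skylet `JobStatus`):

    import enum
    from typing import Sequence

    class JobStatus(enum.Enum):
        """Trimmed-down sketch, not the full skylet enum."""
        RUNNING = 'RUNNING'
        FAILED = 'FAILED'
        FAILED_SETUP = 'FAILED_SETUP'
        CANCELLED = 'CANCELLED'

        @classmethod
        def user_code_failure_states(cls) -> Sequence['JobStatus']:
            # A tuple keeps the result immutable; callers only need `in` checks.
            return (cls.FAILED, cls.FAILED_SETUP)

    for status in (JobStatus.FAILED_SETUP, JobStatus.CANCELLED):
        print(status.value, status in JobStatus.user_code_failure_states())
    # Prints: FAILED_SETUP True, then CANCELLED False.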
46 changes: 46 additions & 0 deletions tests/test_smoke.py
@@ -3283,6 +3283,52 @@ def test_managed_jobs_recovery_multi_node_gcp():
run_one_test(test)


@pytest.mark.managed_jobs
def test_managed_jobs_retry_logs():
"""Test managed job retry logs are properly displayed when a task fails."""
name = _get_cluster_name()
# Create a temporary YAML file with two tasks - first one fails, second succeeds
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml') as f:
yaml_content = textwrap.dedent("""
resources:
cpus: 2+
job_recovery:
max_restarts_on_errors: 1

# Task 1: Always fails
run: |
echo "Task 1 starting"
exit 1
---
# Task 2: Never reached due to Task 1 failure
run: |
echo "Task 2 starting"
exit 0
""")
f.write(yaml_content)
f.flush()

with tempfile.NamedTemporaryFile(mode='w', suffix='.log') as log_file:
test = Test(
'managed_jobs_retry_logs',
[
f'sky jobs launch -n {name} {f.name} -y -d',
f'sky jobs logs -n {name} | tee {log_file.name}',
# First attempt
f'cat {log_file.name} | grep "Job started. Streaming logs..."',
f'cat {log_file.name} | grep "Job 1 failed"',
# Second attempt
f'cat {log_file.name} | grep "Job started. Streaming logs..." | wc -l | grep 2',
f'cat {log_file.name} | grep "Job 1 failed" | wc -l | grep 2',
# Task 2 is not reached
f'! cat {log_file.name} | grep "Job 2"',
],
f'sky jobs cancel -y -n {name}',
timeout=7 * 60, # 7 mins
)
run_one_test(test)


@pytest.mark.aws
@pytest.mark.managed_jobs
def test_managed_jobs_cancellation_aws(aws_config_region):