From bcf40fa878189c545f70061338b1c4faaceb53af Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Thu, 7 Nov 2024 14:19:35 -0800 Subject: [PATCH] Fix `stream_logs` Duplicate Job Handling and TypeError (#4274) fix: multiple `job_id` --- sky/jobs/utils.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 981f6d8286f..896740f6ed6 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -14,7 +14,7 @@ import textwrap import time import typing -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import colorama import filelock @@ -487,6 +487,7 @@ def stream_logs(job_id: Optional[int], job_id = managed_job_state.get_latest_job_id() if job_id is None: return 'No managed job found.' + if controller: if job_id is None: assert job_name is not None @@ -494,16 +495,22 @@ def stream_logs(job_id: Optional[int], # We manually filter the jobs by name, instead of using # get_nonterminal_job_ids_by_name, as with `controller=True`, we # should be able to show the logs for jobs in terminal states. - managed_jobs = list( - filter(lambda job: job['job_name'] == job_name, managed_jobs)) - if len(managed_jobs) == 0: + managed_job_ids: Set[int] = { + job['job_id'] + for job in managed_jobs + if job['job_name'] == job_name + } + if len(managed_job_ids) == 0: return f'No managed job found with name {job_name!r}.' - if len(managed_jobs) > 1: - job_ids_str = ', '.join(job['job_id'] for job in managed_jobs) - raise ValueError( - f'Multiple managed jobs found with name {job_name!r} (Job ' - f'IDs: {job_ids_str}). Please specify the job_id instead.') - job_id = managed_jobs[0]['job_id'] + if len(managed_job_ids) > 1: + job_ids_str = ', '.join( + str(job_id) for job_id in managed_job_ids) + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Multiple managed jobs found with name {job_name!r} ' + f'(Job IDs: {job_ids_str}). Please specify the job_id ' + 'instead.') + job_id = managed_job_ids.pop() assert job_id is not None, (job_id, job_name) # TODO: keep the following code sync with # job_lib.JobLibCodeGen.tail_logs, we do not directly call that function @@ -849,6 +856,7 @@ def stream_logs(cls, from sky.skylet import job_lib, log_lib from sky.skylet import constants + from sky.utils import ux_utils try: from sky.jobs.utils import stream_logs_by_id except ImportError: