diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index df32bb51b61..0c67ec6b328 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -3329,8 +3329,8 @@ def _exec_code_on_head( handle: CloudVmRayResourceHandle, codegen: str, job_id: int, - task: task_lib.Task, detach_run: bool = False, + managed_job_dag: Optional['dag.Dag'] = None, ) -> None: """Executes generated code on the head node.""" style = colorama.Style @@ -3378,11 +3378,11 @@ def _dump_code_to_file(codegen: str) -> None: _dump_code_to_file(codegen) job_submit_cmd = f'{mkdir_code} && {code}' - if task.managed_job_dag is not None: + if managed_job_dag is not None: # Add the managed job to job queue database. managed_job_codegen = managed_jobs.ManagedJobCodeGen() managed_job_code = managed_job_codegen.set_pending( - job_id, task.managed_job_dag, task.envs['DAG_YAML_PATH']) + job_id, managed_job_dag) # Set the managed job to PENDING state to make sure that this # managed job appears in the `sky jobs queue`, when there are # already 2x vCPU controller processes running on the controller VM, @@ -4896,7 +4896,7 @@ def _execute_task_one_node(self, handle: CloudVmRayResourceHandle, codegen.build(), job_id, detach_run=detach_run, - task=task) + managed_job_dag=task.managed_job_dag) def _execute_task_n_nodes(self, handle: CloudVmRayResourceHandle, task: task_lib.Task, job_id: int, @@ -4952,4 +4952,4 @@ def _execute_task_n_nodes(self, handle: CloudVmRayResourceHandle, codegen.build(), job_id, detach_run=detach_run, - task=task) + managed_job_dag=task.managed_job_dag) diff --git a/sky/cli.py b/sky/cli.py index 1faf0003ff9..edc60d38f01 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -3583,18 +3583,6 @@ def jobs(): is_flag=True, help=('If True, as soon as a job is submitted, return from this call ' 'and do not stream execution logs.')) -@click.option( - '--retry-until-up/--no-retry-until-up', - '-r/-no-r', - default=None, - is_flag=True, - required=False, - help=( - '(Default: True; this flag is deprecated and will be removed in a ' - 'future release.) Whether to retry provisioning infinitely until the ' - 'cluster is up, if unavailability errors are encountered. This ' # pylint: disable=bad-docstring-quotes - 'applies to launching all managed jobs (both the initial and ' - 'any recovery attempts), not the jobs controller.')) @click.option('--yes', '-y', is_flag=True, @@ -3634,7 +3622,6 @@ def jobs_launch( disk_tier: Optional[str], ports: Tuple[str], detach_run: bool, - retry_until_up: bool, yes: bool, fast: bool, ): @@ -3678,19 +3665,6 @@ def jobs_launch( ports=ports, job_recovery=job_recovery, ) - # Deprecation. We set the default behavior to be retry until up, and the - # flag `--retry-until-up` is deprecated. We can remove the flag in 0.8.0. - if retry_until_up is not None: - flag_str = '--retry-until-up' - if not retry_until_up: - flag_str = '--no-retry-until-up' - click.secho( - f'Flag {flag_str} is deprecated and will be removed in a ' - 'future release (managed jobs will always be retried). 
' - 'Please file an issue if this does not work for you.', - fg='yellow') - else: - retry_until_up = True if not isinstance(task_or_dag, sky.Dag): assert isinstance(task_or_dag, sky.Task), task_or_dag @@ -3730,11 +3704,7 @@ def jobs_launch( common_utils.check_cluster_name_is_valid(name) - managed_jobs.launch(dag, - name, - detach_run=detach_run, - retry_until_up=retry_until_up, - fast=fast) + managed_jobs.launch(dag, name, detach_run=detach_run, fast=fast) @jobs.command('queue', cls=_DocumentedCodeCommand) diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py index 0caf2fc2ae1..4cbbac34ae4 100644 --- a/sky/jobs/controller.py +++ b/sky/jobs/controller.py @@ -1,6 +1,5 @@ """Controller: handles the life cycle of a managed job.""" import argparse -import contextlib import multiprocessing import os import pathlib @@ -48,12 +47,10 @@ def _get_dag_and_name(dag_yaml: str) -> Tuple['sky.Dag', str]: class JobsController: """Each jobs controller manages the life cycle of one managed job.""" - def __init__(self, job_id: int, dag_yaml: str, - retry_until_up: bool) -> None: + def __init__(self, job_id: int, dag_yaml: str) -> None: self._job_id = job_id self._dag, self._dag_name = _get_dag_and_name(dag_yaml) logger.info(self._dag) - self._retry_until_up = retry_until_up # TODO(zhwu): this assumes the specific backend. self._backend = cloud_vm_ray_backend.CloudVmRayBackend() @@ -176,32 +173,12 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: cluster_name = managed_job_utils.generate_managed_job_cluster_name( task.name, self._job_id) self._strategy_executor = recovery_strategy.StrategyExecutor.make( - cluster_name, self._backend, task, self._retry_until_up) - - def _schedule_launch(): - if task_id == 0: - # The job is already SUBMITTED in this case, and we do not need to schedule. - return contextlib.nullcontext() - # We need to wait for a scheduling slot. - # Neither the previous task or the current task is SUBMITTED or RUNNING. - return scheduler.schedule_active_job_launch(is_running=False) - - with _schedule_launch(): - # Note: task_id 0 will already be set to submitted by the scheduler. - # However, we only call the callback func here, so keep this. 
- managed_job_state.set_submitted(self._job_id, - task_id, - callback_func=callback_func) - logger.info( - f'Submitted managed job {self._job_id} (task: {task_id}, name: ' - f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}') - - logger.info('Started monitoring.') - managed_job_state.set_starting( - job_id=self._job_id, - task_id=task_id, - run_timestamp=self._backend.run_timestamp, - submit_time=submitted_at, + cluster_name, self._backend, task, self._job_id) + managed_job_state.set_submitted( + self._job_id, + task_id, + self._backend.run_timestamp, + submitted_at, resources_str=backend_utils.get_task_resources_str( task, is_managed_job=True), specs={ @@ -209,6 +186,16 @@ def _schedule_launch(): self._strategy_executor.max_restarts_on_errors }, callback_func=callback_func) + logger.info( + f'Submitted managed job {self._job_id} (task: {task_id}, name: ' + f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}') + + scheduler.wait_until_launch_okay(self._job_id) + + logger.info('Started monitoring.') + managed_job_state.set_starting(job_id=self._job_id, + task_id=task_id, + callback_func=callback_func) remote_job_submitted_at = self._strategy_executor.launch() assert remote_job_submitted_at is not None, remote_job_submitted_at @@ -216,8 +203,9 @@ def _schedule_launch(): task_id=task_id, start_time=remote_job_submitted_at, callback_func=callback_func) - # Finished sky launch so now maybe something else can launch. - scheduler.schedule_step() + + scheduler.launch_finished(self._job_id) + while True: time.sleep(managed_job_utils.JOB_STATUS_CHECK_GAP_SECONDS) @@ -366,21 +354,18 @@ def _schedule_launch(): # Try to recover the managed jobs, when the cluster is preempted or # failed or the job status is failed to be fetched. - managed_job_state.set_pending_recovery(job_id=self._job_id, - task_id=task_id, - callback_func=callback_func) - # schedule_recovery() will block until we can recover - with scheduler.schedule_active_job_launch(is_running=True): - managed_job_state.set_recovering(job_id=self._job_id, - task_id=task_id, - callback_func=callback_func) + managed_job_state.set_recovering(job_id=self._job_id, + task_id=task_id, + callback_func=callback_func) + # Switch to LAUNCHING schedule_state here, since the entire recovery + # process is somewhat heavy. + scheduler.wait_until_launch_okay(self._job_id) recovered_time = self._strategy_executor.recover() managed_job_state.set_recovered(self._job_id, task_id, recovered_time=recovered_time, callback_func=callback_func) - # Just finished launching, maybe something was waiting to start. - scheduler.schedule_step() + scheduler.launch_finished(self._job_id) def run(self): """Run controller logic and handle exceptions.""" @@ -451,11 +436,11 @@ def _update_failed_task_state( task=self._dag.tasks[task_id])) -def _run_controller(job_id: int, dag_yaml: str, retry_until_up: bool): +def _run_controller(job_id: int, dag_yaml: str): """Runs the controller in a remote process for interruption.""" # The controller needs to be instantiated in the remote process, since # the controller is not serializable. 
- jobs_controller = JobsController(job_id, dag_yaml, retry_until_up) + jobs_controller = JobsController(job_id, dag_yaml) jobs_controller.run() @@ -512,7 +497,7 @@ def _cleanup(job_id: int, dag_yaml: str): backend.teardown_ephemeral_storage(task) -def start(job_id, dag_yaml, retry_until_up): +def start(job_id, dag_yaml): """Start the controller.""" controller_process = None cancelling = False @@ -525,8 +510,7 @@ def start(job_id, dag_yaml, retry_until_up): # So we can only enable daemon after we no longer need to # start daemon processes like Ray. controller_process = multiprocessing.Process(target=_run_controller, - args=(job_id, dag_yaml, - retry_until_up)) + args=(job_id, dag_yaml)) controller_process.start() while controller_process.is_alive(): _handle_signal(job_id) @@ -586,9 +570,7 @@ def start(job_id, dag_yaml, retry_until_up): failure_reason=('Unexpected error occurred. For details, ' f'run: sky jobs logs --controller {job_id}')) - # Run the scheduler to kick off any pending jobs that can now start. - logger.info('Running scheduler') - scheduler.schedule_step() + scheduler.job_done(job_id) if __name__ == '__main__': @@ -597,9 +579,6 @@ def start(job_id, dag_yaml, retry_until_up): required=True, type=int, help='Job id for the controller job.') - parser.add_argument('--retry-until-up', - action='store_true', - help='Retry until the cluster is up.') parser.add_argument('dag_yaml', type=str, help='The path to the user job yaml file.') @@ -607,4 +586,4 @@ def start(job_id, dag_yaml, retry_until_up): # We start process with 'spawn', because 'fork' could result in weird # behaviors; 'spawn' is also cross-platform. multiprocessing.set_start_method('spawn', force=True) - start(args.job_id, args.dag_yaml, args.retry_until_up) + start(args.job_id, args.dag_yaml) diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 9cde3443816..e675e2120d1 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -41,7 +41,6 @@ def launch( name: Optional[str] = None, stream_logs: bool = True, detach_run: bool = False, - retry_until_up: bool = False, fast: bool = False, ) -> None: # NOTE(dev): Keep the docstring consistent between the Python API and CLI. @@ -115,7 +114,6 @@ def launch( 'jobs_controller': controller_name, # Note: actual cluster name will be - 'dag_name': dag.name, - 'retry_until_up': retry_until_up, 'remote_user_config_path': remote_user_config_path, 'modified_catalogs': service_catalog_common.get_modified_catalog_file_mounts(), diff --git a/sky/jobs/recovery_strategy.py b/sky/jobs/recovery_strategy.py index 4fda1a07e08..6c4ad6af7b5 100644 --- a/sky/jobs/recovery_strategy.py +++ b/sky/jobs/recovery_strategy.py @@ -17,6 +17,7 @@ from sky import sky_logging from sky import status_lib from sky.backends import backend_utils +from sky.jobs import scheduler from sky.jobs import utils as managed_job_utils from sky.skylet import job_lib from sky.usage import usage_lib @@ -72,15 +73,14 @@ class StrategyExecutor: RETRY_INIT_GAP_SECONDS = 60 def __init__(self, cluster_name: str, backend: 'backends.Backend', - task: 'task_lib.Task', retry_until_up: bool, - max_restarts_on_errors: int) -> None: + task: 'task_lib.Task', max_restarts_on_errors: int, + job_id: int) -> None: """Initialize the strategy executor. Args: cluster_name: The name of the cluster. backend: The backend to use. Only CloudVMRayBackend is supported. task: The task to execute. - retry_until_up: Whether to retry until the cluster is up. 
""" assert isinstance(backend, backends.CloudVmRayBackend), ( 'Only CloudVMRayBackend is supported.') @@ -88,8 +88,8 @@ def __init__(self, cluster_name: str, backend: 'backends.Backend', self.dag.add(task) self.cluster_name = cluster_name self.backend = backend - self.retry_until_up = retry_until_up self.max_restarts_on_errors = max_restarts_on_errors + self.job_id = job_id self.restart_cnt_on_failure = 0 def __init_subclass__(cls, name: str, default: bool = False): @@ -102,7 +102,7 @@ def __init_subclass__(cls, name: str, default: bool = False): @classmethod def make(cls, cluster_name: str, backend: 'backends.Backend', - task: 'task_lib.Task', retry_until_up: bool) -> 'StrategyExecutor': + task: 'task_lib.Task', job_id: int) -> 'StrategyExecutor': """Create a strategy from a task.""" resource_list = list(task.resources) @@ -127,8 +127,9 @@ def make(cls, cluster_name: str, backend: 'backends.Backend', job_recovery_name = job_recovery max_restarts_on_errors = 0 return RECOVERY_STRATEGIES[job_recovery_name](cluster_name, backend, - task, retry_until_up, - max_restarts_on_errors) + task, + max_restarts_on_errors, + job_id) def launch(self) -> float: """Launch the cluster for the first time. @@ -142,10 +143,7 @@ def launch(self) -> float: Raises: Please refer to the docstring of self._launch(). """ - if self.retry_until_up: - job_submit_at = self._launch(max_retry=None) - else: - job_submit_at = self._launch() + job_submit_at = self._launch(max_retry=None) assert job_submit_at is not None return job_submit_at @@ -390,7 +388,11 @@ def _launch(self, gap_seconds = backoff.current_backoff() logger.info('Retrying to launch the cluster in ' f'{gap_seconds:.1f} seconds.') + # Transition to ALIVE during the backoff so that other jobs can + # launch. + scheduler.launch_finished(self.job_id) time.sleep(gap_seconds) + scheduler.wait_until_launch_okay(self.job_id) def should_restart_on_failure(self) -> bool: """Increments counter & checks if job should be restarted on a failure. @@ -411,10 +413,10 @@ class FailoverStrategyExecutor(StrategyExecutor, name='FAILOVER', _MAX_RETRY_CNT = 240 # Retry for 4 hours. def __init__(self, cluster_name: str, backend: 'backends.Backend', - task: 'task_lib.Task', retry_until_up: bool, - max_restarts_on_errors: int) -> None: - super().__init__(cluster_name, backend, task, retry_until_up, - max_restarts_on_errors) + task: 'task_lib.Task', max_restarts_on_errors: int, + job_id: int) -> None: + super().__init__(cluster_name, backend, task, max_restarts_on_errors, + job_id) # Note down the cloud/region of the launched cluster, so that we can # first retry in the same cloud/region. (Inside recover() we may not # rely on cluster handle, as it can be None if the cluster is @@ -478,16 +480,11 @@ def recover(self) -> float: raise_on_failure=False) if job_submitted_at is None: # Failed to launch the cluster. 
- if self.retry_until_up: - gap_seconds = self.RETRY_INIT_GAP_SECONDS - logger.info('Retrying to recover the cluster in ' - f'{gap_seconds:.1f} seconds.') - time.sleep(gap_seconds) - continue - with ux_utils.print_exception_no_traceback(): - raise exceptions.ResourcesUnavailableError( - f'Failed to recover the cluster after retrying ' - f'{self._MAX_RETRY_CNT} times.') + gap_seconds = self.RETRY_INIT_GAP_SECONDS + logger.info('Retrying to recover the cluster in ' + f'{gap_seconds:.1f} seconds.') + time.sleep(gap_seconds) + continue return job_submitted_at @@ -566,15 +563,10 @@ def recover(self) -> float: raise_on_failure=False) if job_submitted_at is None: # Failed to launch the cluster. - if self.retry_until_up: - gap_seconds = self.RETRY_INIT_GAP_SECONDS - logger.info('Retrying to recover the cluster in ' - f'{gap_seconds:.1f} seconds.') - time.sleep(gap_seconds) - continue - with ux_utils.print_exception_no_traceback(): - raise exceptions.ResourcesUnavailableError( - f'Failed to recover the cluster after retrying ' - f'{self._MAX_RETRY_CNT} times.') + gap_seconds = self.RETRY_INIT_GAP_SECONDS + logger.info('Retrying to recover the cluster in ' + f'{gap_seconds:.1f} seconds.') + time.sleep(gap_seconds) + continue return job_submitted_at diff --git a/sky/jobs/scheduler.py b/sky/jobs/scheduler.py index 02fe798e53d..959d1516806 100644 --- a/sky/jobs/scheduler.py +++ b/sky/jobs/scheduler.py @@ -1,13 +1,14 @@ """Scheduler for managed jobs. -Once managed jobs are added as PENDING to the `spot` table, the scheduler is -responsible for the business logic of deciding when they are allowed to start, -and choosing the right one to start. +Once managed jobs are submitted via submit_job, the scheduler is responsible for +the business logic of deciding when they are allowed to start, and choosing the +right one to start. -The scheduler is not its own process - instead, schedule_step() can be called -from any code running on the managed jobs controller to trigger scheduling of -new jobs if possible. This function should be called immediately after any state -change that could result in new jobs being able to start. +The scheduler is not its own process - instead, maybe_start_waiting_jobs() can +be called from any code running on the managed jobs controller instance to +trigger scheduling of new jobs if possible. This function should be called +immediately after any state change that could result in new jobs being able to +start. The scheduling logic limits the number of running jobs according to two limits: 1. The number of jobs that can be launching (that is, STARTING or RECOVERING) at @@ -19,46 +20,52 @@ little once a job starts (just checking its status periodically), the most significant resource it consumes is memory. -There are two ways to interact with the scheduler: -- Any code that could result in new jobs being able to start (that is, it - reduces the number of jobs counting towards one of the above limits) should - call schedule_step(), which will best-effort attempt to schedule new jobs. -- If a running job need to relaunch (recover), it should use schedule_recovery() - to obtain a "slot" in the number of allowed starting jobs. - -Since the scheduling state is determined by the state of jobs in the `spot` -table, we must sychronize all scheduling logic with a global lock. A per-job -lock would be insufficient, since schedule_step() could race with a job -controller trying to start recovery, "double-spending" the open slot. 
+The state of the scheduler is entirely determined by the schedule_state column +of all the jobs in the job_info table. This column should only be modified via +the functions defined in this file. We will always hold the lock while modifying +this state. See state.ManagedJobScheduleState. + +Nomenclature: +- job: same as managed job (may include multiple tasks) +- launch/launching: launching a cluster (sky.launch) as part of a job +- start/run/schedule: create the job controller process for a job +- alive: a job controller exists + """ -import contextlib +from argparse import ArgumentParser import os -import shlex -import subprocess import time -from typing import Optional, Tuple import filelock import psutil -from sky.jobs import state, constants as managed_job_constants +from sky import sky_logging +from sky.jobs import constants as managed_job_constants +from sky.jobs import state from sky.skylet import constants from sky.utils import subprocess_utils -# The _MANAGED_JOB_SUBMISSION_LOCK should be held whenever a job transitions to -# STARTING or RECOVERING, so that we can ensure correct parallelism control. -_MANAGED_JOB_SUBMISSION_LOCK = '~/.sky/locks/managed_job_submission.lock' -_ACTIVE_JOB_LAUNCH_WAIT_INTERVAL = 0.5 +logger = sky_logging.init_logger('sky.jobs.controller') + +# The _MANAGED_JOB_SCHEDULER_LOCK should be held whenever we are checking the +# parallelism control or updating the schedule_state of any job. +_MANAGED_JOB_SCHEDULER_LOCK = '~/.sky/locks/managed_job_scheduler.lock' +_ALIVE_JOB_LAUNCH_WAIT_INTERVAL = 0.5 -def schedule_step() -> None: - """Determine if any jobs can be launched, and if so, launch them. +def maybe_start_waiting_jobs() -> None: + """Determine if any managed jobs can be launched, and if so, launch them. - This function starts new job controllers for PENDING jobs on a best-effort - basis. That is, if we can start any jobs, we will, but if not, we will exit - (almost) immediately. It's expected that if some PENDING jobs cannot be - started now (either because the lock is held, or because there are not + For newly submitted jobs, this includes starting the job controller + process. For jobs that are already alive but are waiting to launch a new + task or recover, just update the state of the job to indicate that the + launch can proceed. + + This function transitions jobs into LAUNCHING on a best-effort basis. That + is, if we can start any jobs, we will, but if not, we will exit (almost) + immediately. It's expected that if some WAITING or ALIVE_WAITING jobs cannot + be started now (either because the lock is held, or because there are not enough resources), another call to schedule_step() will be made whenever that situation is resolved. (If the lock is held, the lock holder should start the jobs. If there aren't enough resources, the next controller to @@ -69,48 +76,67 @@ def schedule_step() -> None: the jobs controller. New job controller processes will be detached from the current process and there will not be a parent/child relationship - see launch_new_process_tree for more. + """ try: # We must use a global lock rather than a per-job lock to ensure correct - # parallelism control. - # The lock is not held while submitting jobs, so we use timeout=1 as a - # best-effort protection against the race between a previous - # schedule_step() releasing the lock and a job submission. Since we call - # schedule_step() after submitting the job this should capture - # essentially all cases. - # (In the worst case, the skylet event should schedule the job.) 
- with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SUBMISSION_LOCK), - timeout=1): - while _can_schedule(): - maybe_next_job = _get_next_job_to_start() + # parallelism control. If we cannot obtain the lock, exit immediately. + # The current lock holder is expected to launch any jobs it can before + # releasing the lock. + with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SCHEDULER_LOCK), + blocking=False): + while True: + maybe_next_job = state.get_waiting_job() if maybe_next_job is None: # Nothing left to schedule, break from scheduling loop break - managed_job_id, dag_yaml_path = maybe_next_job - run_cmd = ( - f'{constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV};' - f'python -u -m sky.jobs.controller {dag_yaml_path} --job-id {managed_job_id}' - ) - - state.set_submitted( - job_id=managed_job_id, - # schedule_step() only looks at the first task of each job. - task_id=0, - # We must call set_submitted now so that this job is counted - # as launching by future scheduler runs, but we don't have - # the callback_func here. We will call set_submitted again - # in the jobs controller, which will call the callback_func. - callback_func=lambda _: None) - - logs_dir = os.path.expanduser( - managed_job_constants.JOBS_CONTROLLER_LOGS_DIR) - os.makedirs(logs_dir, exist_ok=True) - log_path = os.path.join(logs_dir, f'{managed_job_id}.log') - - pid = subprocess_utils.launch_new_process_tree( - run_cmd, log_output=log_path) - state.set_job_controller_pid(managed_job_id, pid) + current_state = maybe_next_job['schedule_state'] + + assert current_state in ( + state.ManagedJobScheduleState.ALIVE_WAITING, + state.ManagedJobScheduleState.WAITING), maybe_next_job + + # Note: we expect to get ALIVE_WAITING jobs before WAITING jobs, + # since they will have been submitted and therefore started + # first. The requirements to launch in an alive job are more + # lenient, so there is no way that we wouldn't be able to launch + # an ALIVE_WAITING job, but we would be able to launch a WAITING + # job. + if current_state == state.ManagedJobScheduleState.ALIVE_WAITING: + if not _can_lauch_in_alive_job(): + # Can't schedule anything, break from scheduling loop. + break + elif current_state == state.ManagedJobScheduleState.WAITING: + if not _can_start_new_job(): + # Can't schedule anything, break from scheduling loop. + break + + logger.info(f'Scheduling job {maybe_next_job["job_id"]}') + state.scheduler_set_launching(maybe_next_job['job_id'], + current_state) + + if current_state == state.ManagedJobScheduleState.WAITING: + # The job controller has not been started yet. We must start + # it. + + job_id = maybe_next_job['job_id'] + dag_yaml_path = maybe_next_job['dag_yaml_path'] + + run_cmd = (f'{constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV};' + 'python -u -m sky.jobs.controller ' + f'{dag_yaml_path} --job-id {job_id}') + + logs_dir = os.path.expanduser( + managed_job_constants.JOBS_CONTROLLER_LOGS_DIR) + os.makedirs(logs_dir, exist_ok=True) + log_path = os.path.join(logs_dir, f'{job_id}.log') + + pid = subprocess_utils.launch_new_process_tree( + run_cmd, log_output=log_path) + state.set_job_controller_pid(job_id, pid) + + logger.info(f'Job {job_id} started with pid {pid}') except filelock.Timeout: # If we can't get the lock, just exit. The process holding the lock @@ -118,65 +144,74 @@ def schedule_step() -> None: pass -@contextlib.contextmanager -def schedule_active_job_launch(is_running: bool): - """Block until we can trigger a launch as part of an ongoing job. 
+def submit_job(job_id: int, dag_yaml_path: str) -> None: + """Submit an existing job to the scheduler. - schedule_step() will only schedule the first launch of a job. There are two - scenarios where we may need to call sky.launch again during the course of a - job controller: + This should be called after a job is created in the `spot` table as + PENDING. It will tell the scheduler to try and start the job controller, if + there are resources available. It may block to acquire the lock, so it + should not be on the critical path for `sky jobs launch -d`. + """ + with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SCHEDULER_LOCK)): + state.scheduler_set_waiting(job_id, dag_yaml_path) + maybe_start_waiting_jobs() + + +def launch_finished(job_id: int) -> None: + """Transition a job from LAUNCHING to ALIVE. + + This should be called after sky.launch finishes, whether or not it was + successful. This may cause other jobs to begin launching. + + To transition back to LAUNCHING, use wait_until_launch_okay. + """ + with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SCHEDULER_LOCK)): + state.scheduler_set_alive(job_id) + maybe_start_waiting_jobs() + + +def wait_until_launch_okay(job_id: int) -> None: + """Block until we can start a launch as part of an ongoing job. + + If a job is ongoing (ALIVE schedule_state), there are two scenarios where we + may need to call sky.launch again during the course of a job controller: - for tasks after the first task - for recovery - We must hold the lock before transitioning to STARTING or RECOVERING, for - these cases, and we have to make sure there are actually available - resources. So, this context manager will block until we have the launch and - there are available resources to schedule. + This function will mark the job as ALIVE_WAITING, which indicates to the + scheduler that it wants to transition back to LAUNCHING. Then, it will wait + until the scheduler transitions the job state. + """ + if (state.get_job_schedule_state(job_id) == + state.ManagedJobScheduleState.LAUNCHING): + # If we're already in LAUNCHING schedule_state, we don't need to wait. + # This may be the case for the first launch of a job. + return + + _set_alive_waiting(job_id) - The context manager should NOT be held for the actual sky.launch - we just - need to hold it while we transition the job state (to STARTING or - RECOVERING). + while (state.get_job_schedule_state(job_id) != + state.ManagedJobScheduleState.LAUNCHING): + time.sleep(_ALIVE_JOB_LAUNCH_WAIT_INTERVAL) + + +def job_done(job_id: int, idempotent: bool = False) -> None: + """Transition a job to DONE. + + If idempotent is True, this will not raise an error if the job is already + DONE. + + The job could be in any terminal ManagedJobStatus. However, once DONE, it + should never transition back to another state. - This function does NOT guarantee any kind of ordering if multiple processes - call it in parallel. This is why we do not use it for the first task on each - job. """ + if idempotent and (state.get_job_schedule_state(job_id) + == state.ManagedJobScheduleState.DONE): + return - def _ready_to_start(): - # If this is being run as part of a job that is already RUNNING, ignore - # the job parallelism. Comparing to state.get_num_alive_jobs() - 1 is - # deadlock-prone if we somehow have more than the max number of jobs - # running (e.g. if 2 jobs are running and _get_job_parallelism() == 1). 
- if not is_running and state.get_num_alive_jobs( - ) >= _get_job_parallelism(): - return False - if state.get_num_launching_jobs() >= _get_launch_parallelism(): - return False - return True - - # Ideally, we should prioritize launches that are part of ongoing jobs over - # scheduling new jobs. Therefore we grab the lock and wait until a slot - # opens. There is only one lock, so there is no deadlock potential from that - # perspective. We could deadlock if this is called as part of a job that is - # currently STARTING, so don't do that. This could spin forever if jobs get - # stuck as STARTING or RECOVERING, but the same risk exists for the normal - # scheduler. - with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SUBMISSION_LOCK)): - # Only check launch parallelism, since this should be called as part of - # a job that is already RUNNING. - while not _ready_to_start(): - time.sleep(_ACTIVE_JOB_LAUNCH_WAIT_INTERVAL) - # We can launch now. yield to user code, which should update the state - # of the job. DON'T ACTUALLY LAUNCH HERE, WE'RE STILL HOLDING THE LOCK! - yield - - # Release the lock. Wait for more than the lock poll_interval (0.05) in case - # other jobs are waiting to recover - they should get the lock first. - time.sleep(0.1) - - # Since we were holding the lock, other schedule_step() calls may have early - # exited. It's up to us to spawn those controllers. - schedule_step() + with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SCHEDULER_LOCK)): + state.scheduler_set_done(job_id, idempotent) + maybe_start_waiting_jobs() def _get_job_parallelism() -> int: @@ -191,28 +226,34 @@ def _get_launch_parallelism() -> int: return cpus * 4 if cpus is not None else 1 -def _can_schedule() -> bool: +def _can_start_new_job() -> bool: launching_jobs = state.get_num_launching_jobs() alive_jobs = state.get_num_alive_jobs() - print(launching_jobs, alive_jobs) - print(_get_launch_parallelism(), _get_job_parallelism()) return launching_jobs < _get_launch_parallelism( ) and alive_jobs < _get_job_parallelism() -def _get_next_job_to_start() -> Optional[Tuple[int, str]]: - """Returns tuple of job_id, yaml path""" - return state.get_first_pending_job_id_and_yaml() +def _can_lauch_in_alive_job() -> bool: + launching_jobs = state.get_num_launching_jobs() + return launching_jobs < _get_launch_parallelism() -# def _get_pending_job_ids(self) -> List[int]: -# """Returns the job ids in the pending jobs table +def _set_alive_waiting(job_id: int) -> None: + """Should use wait_until_launch_okay() to transition to this state.""" + with filelock.FileLock(os.path.expanduser(_MANAGED_JOB_SCHEDULER_LOCK)): + state.scheduler_set_alive_waiting(job_id) + maybe_start_waiting_jobs() -# The information contains job_id, run command, submit time, -# creation time. 
-# """ -# raise NotImplementedError if __name__ == '__main__': - print("main") - schedule_step() + parser = ArgumentParser() + parser.add_argument('--job-id', + required=True, + type=int, + help='Job id for the controller job.') + parser.add_argument('dag_yaml', + type=str, + help='The path to the user job yaml file.') + args = parser.parse_args() + submit_job(args.job_id, args.dag_yaml) + maybe_start_waiting_jobs() diff --git a/sky/jobs/state.py b/sky/jobs/state.py index 2ab773a4bc4..de650106f26 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -67,9 +67,7 @@ def create_table(cursor, conn): task_id INTEGER DEFAULT 0, task_name TEXT, specs TEXT, - local_log_file TEXT DEFAULT NULL, - pid INTEGER DEFAULT NULL, - dag_yaml_path TEXT)""") + local_log_file TEXT DEFAULT NULL)""") conn.commit() db_utils.add_column_to_table(cursor, conn, 'spot', 'failure_reason', 'TEXT') @@ -109,17 +107,25 @@ def create_table(cursor, conn): db_utils.add_column_to_table(cursor, conn, 'spot', 'local_log_file', 'TEXT DEFAULT NULL') - db_utils.add_column_to_table(cursor, conn, 'spot', 'pid', - 'INTEGER DEFAULT NULL') - - db_utils.add_column_to_table(cursor, conn, 'spot', 'dag_yaml_path', 'TEXT') - - # `job_info` contains the mapping from job_id to the job_name. - # In the future, it may contain more information about each job. + # `job_info` contains the mapping from job_id to the job_name, as well as + # information used by the scheduler. cursor.execute("""\ CREATE TABLE IF NOT EXISTS job_info ( spot_job_id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT)""") + name TEXT, + schedule_state TEXT, + pid INTEGER DEFAULT NULL, + dag_yaml_path TEXT)""") + + db_utils.add_column_to_table(cursor, conn, 'job_info', 'schedule_state', + 'TEXT') + + db_utils.add_column_to_table(cursor, conn, 'job_info', 'pid', + 'INTEGER DEFAULT NULL') + + db_utils.add_column_to_table(cursor, conn, 'job_info', 'dag_yaml_path', + 'TEXT') + conn.commit() @@ -168,11 +174,12 @@ def _get_db_path() -> str: 'task_name', 'specs', 'local_log_file', - 'pid', - 'dag_yaml_path', # columns from the job_info table '_job_info_job_id', # This should be the same as job_id 'job_name', + 'schedule_state', + 'pid', + 'dag_yaml_path', ] @@ -220,9 +227,6 @@ class ManagedJobStatus(enum.Enum): # The start_at timestamp of the managed job in the 'spot' table will be set # to the time when the job is firstly transitioned to RUNNING. RUNNING = 'RUNNING' - # PENDING_RECOVERY: The cluster is preempted, and the controller process is - # waiting (e.g. for available resources) to recover the cluster. - PENDING_RECOVERY = 'PENDING_RECOVERY' # RECOVERING: The cluster is preempted, and the controller process is # recovering the cluster (relaunching/failover). RECOVERING = 'RECOVERING' @@ -294,7 +298,6 @@ def failure_statuses(cls) -> List['ManagedJobStatus']: ManagedJobStatus.SUBMITTED: colorama.Fore.BLUE, ManagedJobStatus.STARTING: colorama.Fore.BLUE, ManagedJobStatus.RUNNING: colorama.Fore.GREEN, - ManagedJobStatus.PENDING_RECOVERY: colorama.Fore.CYAN, ManagedJobStatus.RECOVERING: colorama.Fore.CYAN, ManagedJobStatus.SUCCEEDED: colorama.Fore.GREEN, ManagedJobStatus.FAILED: colorama.Fore.RED, @@ -307,53 +310,77 @@ def failure_statuses(cls) -> List['ManagedJobStatus']: } +class ManagedJobScheduleState(enum.Enum): + """Captures the state of the job from the scheduler's perspective. + + A newly created job will be INACTIVE. The following transitions are valid: + - INACTIVE -> WAITING: The job is "submitted" to the scheduler, and its job + controller can be started. 
+ - WAITING -> LAUNCHING: The job controller is starting by the scheduler and + may proceed to sky.launch. + - LAUNCHING -> ALIVE: The launch attempt was completed. It may have + succeeded or failed. The job controller is not allowed to sky.launch again + without transitioning to ALIVE_WAITING and then LAUNCHING. + - ALIVE -> ALIVE_WAITING: The job controller wants to sky.launch again, + either for recovery or to launch a subsequent task. + - ALIVE_WAITING -> LAUNCHING: The scheduler has determined that the job + controller may launch again. + - LAUNCHING, ALIVE, or ALIVE_WAITING -> DONE: The job controller is exiting + and the job is in some terminal status. In the future it may be possible + to transition directly from WAITING or even INACTIVE to DONE if the job is + cancelled. + + There is no well-defined mapping from the managed job status to schedule + state or vice versa. (In fact, schedule state is defined on the job and + status on the task.) + """ + # The job should be ignored by the scheduler. + INACTIVE = 'INACTIVE' + # The job is waiting to transition to LAUNCHING. The scheduler should try to + # transition it. + WAITING = 'WAITING' + # The job is already alive, but wants to transition back to LAUNCHING, + # e.g. for recovery, or launching later tasks in the DAG. The scheduler + # should try to transition it to LAUNCHING. + ALIVE_WAITING = 'ALIVE_WAITING' + # The job is running sky.launch, or soon will, using a limited number of + # allowed launch slots. + LAUNCHING = 'LAUNCHING' + # The controller for the job is running, but it's not currently launching. + ALIVE = 'ALIVE' + # The job is in a terminal state. (Not necessarily SUCCEEDED.) + DONE = 'DONE' + + # === Status transition functions === -def set_job_name(job_id: int, name: str): +def set_job_info(job_id: int, name: str): with db_utils.safe_cursor(_DB_PATH) as cursor: cursor.execute( """\ INSERT INTO job_info - (spot_job_id, name) - VALUES (?, ?)""", (job_id, name)) + (spot_job_id, name, schedule_state) + VALUES (?, ?, ?)""", + (job_id, name, ManagedJobScheduleState.INACTIVE.value)) -def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str, - dag_yaml_path: str): +def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str): """Set the task to pending state.""" with db_utils.safe_cursor(_DB_PATH) as cursor: cursor.execute( """\ INSERT INTO spot - (spot_job_id, task_id, task_name, resources, status, dag_yaml_path) - VALUES (?, ?, ?, ?, ?, ?)""", + (spot_job_id, task_id, task_name, resources, status) + VALUES (?, ?, ?, ?, ?)""", (job_id, task_id, task_name, resources_str, - ManagedJobStatus.PENDING.value, dag_yaml_path)) + ManagedJobStatus.PENDING.value)) -def set_submitted(job_id: int, task_id: int, callback_func: CallbackType): +def set_submitted(job_id: int, task_id: int, run_timestamp: str, + submit_time: float, resources_str: str, + specs: Dict[str, Union[str, + int]], callback_func: CallbackType): """Set the task to submitted. - Args: - job_id: The managed job ID. - task_id: The task ID. - callback_func: The callback function. - """ - with db_utils.safe_cursor(_DB_PATH) as cursor: - cursor.execute( - """\ - UPDATE spot SET - status=(?) - WHERE spot_job_id=(?) AND - task_id=(?)""", (ManagedJobStatus.SUBMITTED.value, job_id, task_id)) - callback_func('SUBMITTED') - - -def set_starting(job_id: int, task_id: int, run_timestamp: str, - submit_time: float, resources_str: str, - specs: Dict[str, Union[str, - int]], callback_func: CallbackType): - """Set the task to starting state. 
- Args: job_id: The managed job ID. task_id: The task ID. @@ -369,20 +396,31 @@ def set_starting(job_id: int, task_id: int, run_timestamp: str, # make it easier to find them based on one of the values. # Also, using the earlier timestamp should be closer to the term # `submit_at`, which represents the time the managed task is submitted. - logger.info('Launching the spot cluster...') with db_utils.safe_cursor(_DB_PATH) as cursor: cursor.execute( """\ UPDATE spot SET - status=(?), resources=(?), submitted_at=(?), + status=(?), run_timestamp=(?), specs=(?) WHERE spot_job_id=(?) AND task_id=(?)""", - (ManagedJobStatus.STARTING.value, resources_str, submit_time, + (resources_str, submit_time, ManagedJobStatus.SUBMITTED.value, run_timestamp, json.dumps(specs), job_id, task_id)) + callback_func('SUBMITTED') + + +def set_starting(job_id: int, task_id: int, callback_func: CallbackType): + """Set the task to starting state.""" + logger.info('Launching the spot cluster...') + with db_utils.safe_cursor(_DB_PATH) as cursor: + cursor.execute( + """\ + UPDATE spot SET status=(?) + WHERE spot_job_id=(?) AND + task_id=(?)""", (ManagedJobStatus.STARTING.value, job_id, task_id)) callback_func('STARTING') @@ -407,21 +445,6 @@ def set_started(job_id: int, task_id: int, start_time: float, callback_func('STARTED') -def set_pending_recovery(job_id: int, task_id: int, - callback_func: CallbackType): - """Set the task to pending recovery state, and update the job duration.""" - logger.info('=== Recovering... ===') - with db_utils.safe_cursor(_DB_PATH) as cursor: - cursor.execute( - """\ - UPDATE spot SET - status=(?), job_duration=job_duration+(?)-last_recovered_at - WHERE spot_job_id=(?) AND - task_id=(?)""", (ManagedJobStatus.PENDING_RECOVERY.value, - time.time(), job_id, task_id)) - callback_func('PENDING_RECOVERY') - - def set_recovering(job_id: int, task_id: int, callback_func: CallbackType): """Set the task to recovering state.""" logger.info('=== Recovering... ===') @@ -498,13 +521,10 @@ def set_failed( 'SELECT status FROM spot WHERE spot_job_id=(?)', (job_id,)).fetchone()[0] previous_status = ManagedJobStatus(previous_status) - if previous_status in [ - ManagedJobStatus.PENDING_RECOVERY, ManagedJobStatus.RECOVERING - ]: - # If the job is recovering, we should set the - # last_recovered_at to the end_time, so that the - # end_at - last_recovered_at will not be affect the job duration - # calculation. + if previous_status == ManagedJobStatus.RECOVERING: + # If the job is recovering, we should set the last_recovered_at to + # the end_time, so that the end_at - last_recovered_at will not be + # affect the job duration calculation. 
fields_to_set['last_recovered_at'] = end_time set_str = ', '.join(f'{k}=(?)' for k in fields_to_set) task_str = '' if task_id is None else f' AND task_id={task_id}' @@ -570,12 +590,6 @@ def set_local_log_file(job_id: int, task_id: Optional[int], f'WHERE {filter_str}', filter_args) -def set_job_controller_pid(job_id: int, pid: int): - with db_utils.safe_cursor(_DB_PATH) as cursor: - cursor.execute( - f'UPDATE spot SET pid={pid} WHERE spot_job_id={job_id!r}') - - # ======== utility functions ======== def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]: """Get non-terminal job ids by name.""" @@ -689,6 +703,8 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]: for row in rows: job_dict = dict(zip(columns, row)) job_dict['status'] = ManagedJobStatus(job_dict['status']) + job_dict['schedule_state'] = ManagedJobScheduleState( + job_dict['schedule_state']) if job_dict['job_name'] is None: job_dict['job_name'] = job_dict['task_name'] jobs.append(job_dict) @@ -742,40 +758,124 @@ def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]: return local_log_file[-1] if local_log_file else None +# === Scheduler state functions === +# Only the scheduler should call these functions. They may require holding the +# scheduler lock to work correctly. + + +def scheduler_set_waiting(job_id: int, dag_yaml_path: str) -> None: + """Do not call without holding the scheduler lock.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + updated_count = cursor.execute( + 'UPDATE job_info SET ' + 'schedule_state = (?), dag_yaml_path = (?) ' + 'WHERE spot_job_id = (?) AND schedule_state = (?)', + (ManagedJobScheduleState.WAITING.value, dag_yaml_path, job_id, + ManagedJobScheduleState.INACTIVE.value)).rowcount + assert updated_count == 1, (job_id, updated_count) + + +def scheduler_set_launching(job_id: int, + current_state: ManagedJobScheduleState) -> None: + """Do not call without holding the scheduler lock.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + updated_count = cursor.execute( + 'UPDATE job_info SET ' + 'schedule_state = (?) ' + 'WHERE spot_job_id = (?) AND schedule_state = (?)', + (ManagedJobScheduleState.LAUNCHING.value, job_id, + current_state.value)).rowcount + assert updated_count == 1, (job_id, updated_count) + + +def scheduler_set_alive(job_id: int) -> None: + """Do not call without holding the scheduler lock.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + updated_count = cursor.execute( + 'UPDATE job_info SET ' + 'schedule_state = (?) ' + 'WHERE spot_job_id = (?) AND schedule_state = (?)', + (ManagedJobScheduleState.ALIVE.value, job_id, + ManagedJobScheduleState.LAUNCHING.value)).rowcount + assert updated_count == 1, (job_id, updated_count) + + +def scheduler_set_alive_waiting(job_id: int) -> None: + """Do not call without holding the scheduler lock.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + updated_count = cursor.execute( + 'UPDATE job_info SET ' + 'schedule_state = (?) ' + 'WHERE spot_job_id = (?) AND schedule_state = (?)', + (ManagedJobScheduleState.ALIVE_WAITING.value, job_id, + ManagedJobScheduleState.ALIVE.value)).rowcount + assert updated_count == 1, (job_id, updated_count) + + +def scheduler_set_done(job_id: int, idempotent: bool = False) -> None: + """Do not call without holding the scheduler lock.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + updated_count = cursor.execute( + 'UPDATE job_info SET ' + 'schedule_state = (?) ' + 'WHERE spot_job_id = (?) 
AND schedule_state != (?)', + (ManagedJobScheduleState.DONE.value, job_id, + ManagedJobScheduleState.DONE.value)).rowcount + if not idempotent: + assert updated_count == 1, (job_id, updated_count) + + +def set_job_controller_pid(job_id: int, pid: int): + with db_utils.safe_cursor(_DB_PATH) as cursor: + # XXX cooperc + cursor.execute( + f'UPDATE job_info SET pid={pid} WHERE spot_job_id={job_id!r}') + + +def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState: + with db_utils.safe_cursor(_DB_PATH) as cursor: + state = cursor.execute( + 'SELECT schedule_state FROM job_info WHERE spot_job_id = (?)', + (job_id,)).fetchone()[0] + return ManagedJobScheduleState(state) + + def get_num_launching_jobs() -> int: with db_utils.safe_cursor(_DB_PATH) as cursor: return cursor.execute( 'SELECT COUNT(*) ' - 'FROM spot ' - 'WHERE status IN (?, ?, ?)', ( - ManagedJobStatus.SUBMITTED.value, - ManagedJobStatus.STARTING.value, - ManagedJobStatus.RECOVERING.value, - )).fetchone()[0] + 'FROM job_info ' + 'WHERE schedule_state = (?)', + (ManagedJobScheduleState.LAUNCHING.value,)).fetchone()[0] def get_num_alive_jobs() -> int: - terminal_status_fields = ', '.join( - ['?'] * len(ManagedJobStatus.terminal_statuses())) - terminal_status_field_values = [ - status.value for status in ManagedJobStatus.terminal_statuses() - ] with db_utils.safe_cursor(_DB_PATH) as cursor: return cursor.execute( 'SELECT COUNT(*) ' - 'FROM spot ' - f'WHERE status NOT IN (?, {terminal_status_fields})', - (ManagedJobStatus.PENDING.value, - *terminal_status_field_values)).fetchone()[0] + 'FROM job_info ' + 'WHERE schedule_state IN (?, ?, ?)', + (ManagedJobScheduleState.ALIVE_WAITING.value, + ManagedJobScheduleState.LAUNCHING.value, + ManagedJobScheduleState.ALIVE.value)).fetchone()[0] + +def get_waiting_job() -> Optional[Dict[str, Any]]: + """Get the next job that should transition to LAUNCHING. -def get_first_pending_job_id_and_yaml() -> Optional[Tuple[int, str]]: + Backwards compatibility note: jobs submitted before #4485 will have no + schedule_state and will be ignored by this SQL query. + """ with db_utils.safe_cursor(_DB_PATH) as cursor: - # Only consider the first task in the job dag. If it is not pending, the controller for the whole job has already been started. - return cursor.execute( - 'SELECT spot_job_id, dag_yaml_path ' - 'FROM spot ' - 'WHERE task_id = 0 ' - 'AND status = (?) ' + row = cursor.execute( + 'SELECT spot_job_id, schedule_state, dag_yaml_path ' + 'FROM job_info ' + 'WHERE schedule_state in (?, ?) 
' 'ORDER BY spot_job_id LIMIT 1', - (ManagedJobStatus.PENDING.value,)).fetchone() + (ManagedJobScheduleState.WAITING.value, + ManagedJobScheduleState.ALIVE_WAITING.value)).fetchone() + return { + 'job_id': row[0], + 'schedule_state': ManagedJobScheduleState(row[1]), + 'dag_yaml_path': row[2], + } if row is not None else None diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index b1c63b5fc6d..ec4735e1887 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -27,6 +27,7 @@ from sky import sky_logging from sky.backends import backend_utils from sky.jobs import constants as managed_job_constants +from sky.jobs import scheduler from sky.jobs import state as managed_job_state from sky.skylet import constants from sky.skylet import job_lib @@ -128,7 +129,8 @@ def update_managed_job_status(job_id: Optional[int] = None): for job_id_ in job_ids: tasks = managed_job_state.get_managed_jobs(job_id_) - if tasks[0]['dag_yaml_path'] is None: + schedule_state = tasks[0]['schedule_state'] + if schedule_state is None: # Backwards compatibility: this job was submitted when ray was still # used for managing the parallelism of job controllers. This code # path can be removed before 0.11.0. @@ -143,16 +145,17 @@ def update_managed_job_status(job_id: Optional[int] = None): else: pid = tasks[0]['pid'] if pid is None: - first_task_status: managed_job_state.ManagedJobStatus = tasks[ - 0]['status'] - if first_task_status == managed_job_state.ManagedJobStatus.PENDING: + if schedule_state in ( + managed_job_state.ManagedJobScheduleState.INACTIVE, + managed_job_state.ManagedJobScheduleState.WAITING): # Job has not been scheduled yet. continue - elif first_task_status == managed_job_state.ManagedJobStatus.SUBMITTED: - # This should only be the case for a very short period of time - # between marking the job as submitted and writing the launched - # controller process pid back to the database (see - # scheduler.schedule_step). + elif (schedule_state == + managed_job_state.ManagedJobScheduleState.LAUNCHING): + # This should only be the case for a very short period of + # time between marking the job as submitted and writing the + # launched controller process pid back to the database (see + # scheduler.maybe_start_waiting_jobs). # TODO(cooperc): Find a way to detect if we get stuck in # this state. continue @@ -200,6 +203,7 @@ def update_managed_job_status(job_id: Optional[int] = None): failure_reason= 'Controller process has exited abnormally. For more details, run: ' f'sky jobs logs --controller {job_id_}') + scheduler.job_done(job_id_, idempotent=True) def get_job_timestamp(backend: 'backends.CloudVmRayBackend', cluster_name: str, @@ -587,7 +591,10 @@ def stream_logs(job_id: Optional[int], # We know that the job is present in the state table because of # earlier checks, so it should not be None. assert job_status is not None, (job_id, job_name) - if job_status.is_terminal(): + # We shouldn't count CANCELLING as terminal here, the controller is + # still cleaning up. + if (job_status.is_terminal() and not job_status + == managed_job_state.ManagedJobStatus.CANCELLING): # Don't keep waiting. If the log file is not created by this # point, it never will be. 
This job may have been submitted # using an old version that did not create the log file, so this @@ -661,6 +668,7 @@ def dump_managed_job_queue() -> str: job_duration = 0 job['job_duration'] = job_duration job['status'] = job['status'].value + job['schedule_state'] = job['schedule_state'].value cluster_name = generate_managed_job_cluster_name( job['task_name'], job['job_id']) @@ -762,11 +770,18 @@ def get_hash(task): status_counts[managed_job_status.value] += 1 columns = [ - 'ID', 'TASK', 'NAME', 'RESOURCES', 'SUBMITTED', 'TOT. DURATION', - 'JOB DURATION', '#RECOVERIES', 'STATUS' + 'ID', + 'TASK', + 'NAME', + 'RESOURCES', + 'SUBMITTED', + 'TOT. DURATION', + 'JOB DURATION', + '#RECOVERIES', + 'STATUS', ] if show_all: - columns += ['STARTED', 'CLUSTER', 'REGION', 'FAILURE'] + columns += ['STARTED', 'CLUSTER', 'REGION', 'FAILURE', 'SCHED. STATE'] if tasks_have_user: columns.insert(0, 'USER') job_table = log_utils.create_table(columns) @@ -834,11 +849,13 @@ def get_hash(task): status_str, ] if show_all: + schedule_state = job_tasks[0]['schedule_state'] job_values.extend([ '-', '-', '-', failure_reason if failure_reason is not None else '-', + schedule_state, ]) if tasks_have_user: job_values.insert(0, job_tasks[0].get('user', '-')) @@ -866,6 +883,10 @@ def get_hash(task): task['status'].colored_str(), ] if show_all: + # schedule_state is only set at the job level, so if we have + # more than one task, only display on the aggregated row. + schedule_state = task['schedule_state'] if (len(job_tasks) + == 1) else '-' values.extend([ # STARTED log_utils.readable_time_duration(task['start_at']), @@ -873,6 +894,7 @@ def get_hash(task): task['region'], task['failure_reason'] if task['failure_reason'] is not None else '-', + schedule_state, ]) if tasks_have_user: values.insert(0, task.get('user', '-')) @@ -979,20 +1001,18 @@ def stream_logs(cls, return cls._build(code) @classmethod - def set_pending(cls, job_id: int, managed_job_dag: 'dag_lib.Dag', - dag_yaml_path: str) -> str: + def set_pending(cls, job_id: int, managed_job_dag: 'dag_lib.Dag') -> str: dag_name = managed_job_dag.name # Add the managed job to queue table. 
code = textwrap.dedent(f"""\ - managed_job_state.set_job_name({job_id}, {dag_name!r}) + managed_job_state.set_job_info({job_id}, {dag_name!r}) """) for task_id, task in enumerate(managed_job_dag.tasks): resources_str = backend_utils.get_task_resources_str( task, is_managed_job=True) code += textwrap.dedent(f"""\ managed_job_state.set_pending({job_id}, {task_id}, - {task.name!r}, {resources_str!r}, - {dag_yaml_path!r}) + {task.name!r}, {resources_str!r}) """) return cls._build(code) diff --git a/sky/skylet/events.py b/sky/skylet/events.py index 13f2818d943..b0c141baa8a 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -75,7 +75,7 @@ class ManagedJobEvent(SkyletEvent): def _run(self): managed_job_utils.update_managed_job_status() - managed_job_scheduler.schedule_step() + managed_job_scheduler.maybe_start_waiting_jobs() class ServiceUpdateEvent(SkyletEvent): diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py index 27929a05a70..f222b7f42a7 100644 --- a/sky/skylet/job_lib.py +++ b/sky/skylet/job_lib.py @@ -10,7 +10,6 @@ import shlex import signal import sqlite3 -import subprocess import time from typing import Any, Dict, List, Optional diff --git a/sky/templates/jobs-controller.yaml.j2 b/sky/templates/jobs-controller.yaml.j2 index ca15252f413..b61f2afe6f9 100644 --- a/sky/templates/jobs-controller.yaml.j2 +++ b/sky/templates/jobs-controller.yaml.j2 @@ -33,13 +33,15 @@ setup: | run: | {{ sky_activate_python_env }} - # Call schedule_step to kick off the job if there are sufficient available resources. - # Note: job is submitted by CloudVmRayBackend._exec_code_on_head() calling managed_job_codegen.set_pending(). - python -u -m sky.jobs.scheduler + # Submit the job to the scheduler. + # Note: The job is already in the `spot` table, marked as PENDING. + # CloudVmRayBackend._exec_code_on_head() calls + # managed_job_codegen.set_pending() before we get here. + python -u -m sky.jobs.scheduler {{remote_user_yaml_path}} \ + --job-id $SKYPILOT_INTERNAL_JOB_ID envs: - DAG_YAML_PATH: {{remote_user_yaml_path}} {%- for env_name, env_value in controller_envs.items() %} {{env_name}}: {{env_value}} {%- endfor %}
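
Reviewer-facing summary of the new scheduling flow: a job's `schedule_state` in `job_info` moves INACTIVE -> WAITING -> LAUNCHING -> ALIVE (-> ALIVE_WAITING -> LAUNCHING again for recovery or later tasks) -> DONE, and `maybe_start_waiting_jobs()` greedily promotes the lowest-id waiting job while respecting the launch and job parallelism limits. The sketch below is a self-contained illustration of those rules only, not code from this patch: the `ScheduleState` names mirror `ManagedJobScheduleState`, but the `jobs` dict, the fixed `LAUNCH_PARALLELISM`/`JOB_PARALLELISM` constants, and the simplified loop are stand-ins for the real `job_info` table, `_get_launch_parallelism()`/`_get_job_parallelism()`, and the lock-protected implementation in `sky/jobs/scheduler.py`.

```python
"""Illustrative model of the scheduling rules introduced by this patch.

Nothing below is SkyPilot code: the `jobs` dict stands in for the job_info
table, the two parallelism constants stand in for _get_job_parallelism() and
_get_launch_parallelism(), and the file lock, database, and controller
processes are not modeled.
"""
import enum
from typing import Dict


class ScheduleState(enum.Enum):
    """Mirrors ManagedJobScheduleState (sky/jobs/state.py) in this patch."""
    INACTIVE = 'INACTIVE'            # in the DB, not yet handed to the scheduler
    WAITING = 'WAITING'              # wants its controller process started
    ALIVE_WAITING = 'ALIVE_WAITING'  # controller alive, wants to launch again
    LAUNCHING = 'LAUNCHING'          # holds one of the limited launch slots
    ALIVE = 'ALIVE'                  # controller alive, not currently launching
    DONE = 'DONE'                    # terminal


# Placeholder limits. The patch derives these from CPU count (4 launches per
# CPU) and from total memory, respectively.
LAUNCH_PARALLELISM = 4
JOB_PARALLELISM = 8


def maybe_start_waiting_jobs(jobs: Dict[int, ScheduleState]) -> None:
    """Greedily move waiting jobs to LAUNCHING, lowest job id first."""
    while True:
        launching = sum(s is ScheduleState.LAUNCHING for s in jobs.values())
        alive = sum(s in (ScheduleState.LAUNCHING, ScheduleState.ALIVE,
                          ScheduleState.ALIVE_WAITING)
                    for s in jobs.values())
        waiting = sorted(job_id for job_id, s in jobs.items()
                         if s in (ScheduleState.WAITING,
                                  ScheduleState.ALIVE_WAITING))
        if not waiting:
            return
        job_id = waiting[0]
        if jobs[job_id] is ScheduleState.ALIVE_WAITING:
            # Already counted as alive; it only needs a free launch slot.
            if launching >= LAUNCH_PARALLELISM:
                return
        else:
            # A brand-new job needs both a launch slot and a job slot.
            if (launching >= LAUNCH_PARALLELISM or
                    alive >= JOB_PARALLELISM):
                return
        # The real scheduler also spawns the controller process here for
        # WAITING jobs and records its pid in job_info.
        jobs[job_id] = ScheduleState.LAUNCHING


if __name__ == '__main__':
    # Job 1 is recovering (ALIVE_WAITING); jobs 2 and 3 are newly submitted.
    jobs = {
        1: ScheduleState.ALIVE_WAITING,
        2: ScheduleState.WAITING,
        3: ScheduleState.WAITING,
    }
    maybe_start_waiting_jobs(jobs)
    print({job_id: s.value for job_id, s in jobs.items()})
    # With the placeholder limits, all three jobs transition to LAUNCHING.
```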