diff --git a/sky/jobs/constants.py b/sky/jobs/constants.py index d5f32908317..873822e9291 100644 --- a/sky/jobs/constants.py +++ b/sky/jobs/constants.py @@ -2,6 +2,8 @@ JOBS_CONTROLLER_TEMPLATE = 'jobs-controller.yaml.j2' JOBS_CONTROLLER_YAML_PREFIX = '~/.sky/jobs_controller' +JOBS_CONTROLLER_PID_FILE_DIR = '~/.sky/jobs_controller_pids' +JOBS_CONTROLLER_LOGS_DIR = '~/sky_controller_logs' JOBS_TASK_YAML_PREFIX = '~/.sky/managed_jobs' diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py index 72dce3e50d7..49d87fb4822 100644 --- a/sky/jobs/controller.py +++ b/sky/jobs/controller.py @@ -9,6 +9,7 @@ from typing import Optional, Tuple import filelock +import psutil from sky import exceptions from sky import sky_logging @@ -18,6 +19,7 @@ from sky.jobs import recovery_strategy from sky.jobs import state as managed_job_state from sky.jobs import utils as managed_job_utils +from sky.jobs import semaphore from sky.skylet import constants from sky.skylet import job_lib from sky.usage import usage_lib @@ -30,6 +32,9 @@ if typing.TYPE_CHECKING: import sky +_JOB_SEMAPHORE_LOCK_DIR = os.path.expanduser('~/.sky/job_semaphore') +_JOB_LAUNCH_SEMAPHORE_LOCK_DIR = os.path.expanduser('~/.sky/job_launch_semaphore') + # Use the explicit logger name so that the logger is under the # `sky.jobs.controller` namespace when executed directly, so as # to inherit the setup from the `sky` logger. @@ -191,17 +196,19 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: f'Submitted managed job {self._job_id} (task: {task_id}, name: ' f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}') - logger.info('Started monitoring.') - managed_job_state.set_starting(job_id=self._job_id, - task_id=task_id, - callback_func=callback_func) - remote_job_submitted_at = self._strategy_executor.launch() - assert remote_job_submitted_at is not None, remote_job_submitted_at - - managed_job_state.set_started(job_id=self._job_id, - task_id=task_id, - start_time=remote_job_submitted_at, - callback_func=callback_func) + with semaphore.FileLockSemaphore(lock_dir_path=_JOB_LAUNCH_SEMAPHORE_LOCK_DIR, lock_count=_get_launch_parallelism()): + logger.info('Started monitoring.') + managed_job_state.set_starting(job_id=self._job_id, + task_id=task_id, + callback_func=callback_func) + remote_job_submitted_at = self._strategy_executor.launch() + assert remote_job_submitted_at is not None, remote_job_submitted_at + + managed_job_state.set_started(job_id=self._job_id, + task_id=task_id, + start_time=remote_job_submitted_at, + callback_func=callback_func) + while True: time.sleep(managed_job_utils.JOB_STATUS_CHECK_GAP_SECONDS) @@ -426,7 +433,14 @@ def _update_failed_task_state( job_id=self._job_id, task_id=task_id, task=self._dag.tasks[task_id])) + +def _get_job_parallelism() -> int: + # Assume a running job uses 400MB memory. + job_memory = 400 * 1024 * 1024 + return psutil.virtual_memory().total // job_memory +def _get_launch_parallelism() -> int: + return os.cpu_count() * 4 def _run_controller(job_id: int, dag_yaml: str, retry_until_up: bool): """Runs the controller in a remote process for interruption.""" @@ -563,8 +577,7 @@ def start(job_id, dag_yaml, retry_until_up): failure_reason=('Unexpected error occurred. For details, ' f'run: sky jobs logs --controller {job_id}')) - -if __name__ == '__main__': +def main(): parser = argparse.ArgumentParser() parser.add_argument('--job-id', required=True, @@ -580,4 +593,9 @@ def start(job_id, dag_yaml, retry_until_up): # We start process with 'spawn', because 'fork' could result in weird # behaviors; 'spawn' is also cross-platform. multiprocessing.set_start_method('spawn', force=True) - start(args.job_id, args.dag_yaml, args.retry_until_up) + + with semaphore.FileLockSemaphore(lock_dir_path=_JOB_SEMAPHORE_LOCK_DIR, lock_count=_get_job_parallelism()): + start(args.job_id, args.dag_yaml, args.retry_until_up) + +if __name__ == '__main__': + main() diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 9cde3443816..50b04bac3bb 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -119,6 +119,10 @@ def launch( 'remote_user_config_path': remote_user_config_path, 'modified_catalogs': service_catalog_common.get_modified_catalog_file_mounts(), + 'controller_pid_file_dir': + managed_job_constants.JOBS_CONTROLLER_PID_FILE_DIR, + 'controller_logs_dir': + managed_job_constants.JOBS_CONTROLLER_LOGS_DIR, **controller_utils.shared_controller_vars_to_fill( controller_utils.Controllers.JOBS_CONTROLLER, remote_user_config_path=remote_user_config_path, diff --git a/sky/jobs/semaphore.py b/sky/jobs/semaphore.py new file mode 100644 index 00000000000..961eaad2d56 --- /dev/null +++ b/sky/jobs/semaphore.py @@ -0,0 +1,46 @@ +"""A file-lock based semaphore to limit parallelism of the jobs controller.""" + +import time +import filelock +import os +from typing import List + +class FileLockSemaphore: + """A cross-process semaphore-like mechanism using file locks. + + Some semaphore uses are unsupported: + - Each release() call must have a corresponding acquire(), that is, the + semaphore value cannot go above the initial value (lock_count). + - All processes must use the same lock_count. This is not enforced by the + FileLockSemaphore class. + """ + def __init__(self, lock_dir_path: str, lock_count: int): + self.lock_dir_path = lock_dir_path + self.locks = [filelock.FileLock(os.path.join(lock_dir_path, f"{i}.lock")) for i in range(lock_count)] + self.acquired_locks: List[filelock.FileLock] = [] + + def acquire(self): + while True: + for lock in self.locks: + try: + lock.acquire(blocking=False) + self.acquired_locks.append(lock) + return + except filelock.Timeout: + pass + time.sleep(0.05) + + def release(self): + if self.acquired_locks: + self.acquired_locks.pop().release() + + def __enter__(self): + self.acquire() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.release() + + def __del__(self): + for lock in self.acquired_locks: + lock.release() diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 267c205285b..bef25f821c9 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -18,6 +18,7 @@ import colorama import filelock +import psutil from typing_extensions import Literal from sky import backends @@ -119,8 +120,38 @@ def update_managed_job_status(job_id: Optional[int] = None): else: job_ids = [job_id] for job_id_ in job_ids: - controller_status = job_lib.get_status(job_id_) - if controller_status is None or controller_status.is_terminal(): + submission_job_status = job_lib.get_status(job_id_) + if submission_job_status is None or submission_job_status.is_terminal(): + if submission_job_status == job_lib.JobStatus.SUCCEEDED: + logger.debug( + f'Job {job_id_} is already {submission_job_status}.') + # This is expected, since the submitted job will detach the + # controller and succeed, even if the controller is still + # running. Check the controller status directly. + pid_file = os.path.join( + os.path.expanduser( + managed_job_constants.JOBS_CONTROLLER_PID_FILE_DIR), + str(job_id_)) + try: + with open(pid_file, 'r', encoding='utf-8') as f: + pid = int(f.read()) + logger.debug(f'Checking controller pid {pid}') + if psutil.Process(pid).is_running(): + # The controller is still running. + continue + # Otherwise, proceed to mark the job as failed. + except FileNotFoundError: + logger.debug('Submission succeeded but controller pid ' + f'file {pid_file} not found.') + # Proceed to mark the job as failed. + except ValueError: + logger.debug(f'Failed to parse the controller pid from ' + f'{pid_file}.') + # Proceed to mark the job as failed. + except psutil.NoSuchProcess: + logger.debug('Controller process not found.') + # Proceed to mark the job as failed. + logger.error(f'Controller for job {job_id_} has exited abnormally. ' 'Setting the job status to FAILED_CONTROLLER.') tasks = managed_job_state.get_managed_jobs(job_id_) @@ -527,6 +558,7 @@ def stream_logs(job_id: Optional[int], 'instead.') job_id = managed_job_ids.pop() assert job_id is not None, (job_id, job_name) + # TODO: keep the following code sync with # job_lib.JobLibCodeGen.tail_logs, we do not directly call that function # as the following code need to be run in the current machine, instead @@ -536,6 +568,59 @@ def stream_logs(job_id: Optional[int], return f'No managed job contrller log found with job_id {job_id}.' log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp) log_lib.tail_logs(job_id=job_id, log_dir=log_dir, follow=follow) + + controller_log_path = os.path.join( + os.path.expanduser(managed_job_constants.JOBS_CONTROLLER_LOGS_DIR), + f'{job_id}.log') + + # Wait for the log file to be written + while not os.path.exists(controller_log_path): + if not follow: + # Assume that the log file hasn't been written yet. Since we + # aren't following, just return. + return '' + + job_status = managed_job_state.get_status(job_id) + # We know that the job is present in the state table because of + # earlier checks, so it should not be None. + assert job_status is not None, (job_id, job_name) + if job_status.is_terminal(): + # Don't keep waiting. If the log file is not created by this + # point, it never will be. This job may have been submitted + # using an old version that did not create the log file, so this + # is not considered an exceptional case. + return '' + + time.sleep(log_lib.SKY_LOG_WAITING_GAP_SECONDS) + + # See also log_lib.tail_logs. + with open(controller_log_path, 'r', newline='', encoding='utf-8') as f: + # Note: we do not need to care about start_stream_at here, since + # that should be in the job log printed above. + for line in f: + print(line, end='') + # Flush. + print(end='', flush=True) + + if follow: + while True: + line = f.readline() + if line is not None and line != '': + print(line, end='', flush=True) + else: + job_status = managed_job_state.get_status(job_id) + assert job_status is not None, (job_id, job_name) + if job_status.is_terminal(): + break + + time.sleep(log_lib.SKY_LOG_TAILING_GAP_SECONDS) + + # Wait for final logs to be written. + time.sleep(1 + log_lib.SKY_LOG_TAILING_GAP_SECONDS) + + # Print any remaining logs including incomplete line. + print(f.read(), end='', flush=True) + return '' if job_id is None: @@ -868,6 +953,7 @@ def stream_logs(cls, # should be removed in v0.8.0. code = textwrap.dedent("""\ import os + import time from sky.skylet import job_lib, log_lib from sky.skylet import constants diff --git a/sky/skylet/log_lib.py b/sky/skylet/log_lib.py index 8a40982972a..ac2b488baf0 100644 --- a/sky/skylet/log_lib.py +++ b/sky/skylet/log_lib.py @@ -25,9 +25,9 @@ from sky.utils import subprocess_utils from sky.utils import ux_utils -_SKY_LOG_WAITING_GAP_SECONDS = 1 -_SKY_LOG_WAITING_MAX_RETRY = 5 -_SKY_LOG_TAILING_GAP_SECONDS = 0.2 +SKY_LOG_WAITING_GAP_SECONDS = 1 +SKY_LOG_WAITING_MAX_RETRY = 5 +SKY_LOG_TAILING_GAP_SECONDS = 0.2 # Peek the head of the lines to check if we need to start # streaming when tail > 0. PEEK_HEAD_LINES_FOR_START_STREAM = 20 @@ -336,7 +336,7 @@ def _follow_job_logs(file, ]: if wait_last_logs: # Wait all the logs are printed before exit. - time.sleep(1 + _SKY_LOG_TAILING_GAP_SECONDS) + time.sleep(1 + SKY_LOG_TAILING_GAP_SECONDS) wait_last_logs = False continue status_str = status.value if status is not None else 'None' @@ -345,7 +345,7 @@ def _follow_job_logs(file, f'Job finished (status: {status_str}).')) return - time.sleep(_SKY_LOG_TAILING_GAP_SECONDS) + time.sleep(SKY_LOG_TAILING_GAP_SECONDS) status = job_lib.get_status_no_lock(job_id) @@ -426,15 +426,15 @@ def tail_logs(job_id: Optional[int], retry_cnt += 1 if os.path.exists(log_path) and status != job_lib.JobStatus.INIT: break - if retry_cnt >= _SKY_LOG_WAITING_MAX_RETRY: + if retry_cnt >= SKY_LOG_WAITING_MAX_RETRY: print( f'{colorama.Fore.RED}ERROR: Logs for ' f'{job_str} (status: {status.value}) does not exist ' f'after retrying {retry_cnt} times.{colorama.Style.RESET_ALL}') return - print(f'INFO: Waiting {_SKY_LOG_WAITING_GAP_SECONDS}s for the logs ' + print(f'INFO: Waiting {SKY_LOG_WAITING_GAP_SECONDS}s for the logs ' 'to be written...') - time.sleep(_SKY_LOG_WAITING_GAP_SECONDS) + time.sleep(SKY_LOG_WAITING_GAP_SECONDS) status = job_lib.update_job_status([job_id], silent=True)[0] start_stream_at = LOG_FILE_START_STREAMING_AT diff --git a/sky/skylet/log_lib.pyi b/sky/skylet/log_lib.pyi index 89d1628ec11..c7028e121aa 100644 --- a/sky/skylet/log_lib.pyi +++ b/sky/skylet/log_lib.pyi @@ -13,6 +13,9 @@ from sky.skylet import constants as constants from sky.skylet import job_lib as job_lib from sky.utils import log_utils as log_utils +SKY_LOG_WAITING_GAP_SECONDS: int = ... +SKY_LOG_WAITING_MAX_RETRY: int = ... +SKY_LOG_TAILING_GAP_SECONDS: float = ... LOG_FILE_START_STREAMING_AT: str = ... diff --git a/sky/templates/jobs-controller.yaml.j2 b/sky/templates/jobs-controller.yaml.j2 index 45cdb5141d4..7367bc9a55e 100644 --- a/sky/templates/jobs-controller.yaml.j2 +++ b/sky/templates/jobs-controller.yaml.j2 @@ -33,9 +33,14 @@ setup: | run: | {{ sky_activate_python_env }} + mkdir -p {{controller_logs_dir}} + mkdir -p {{controller_pid_file_dir}} # Start the controller for the current managed job. - python -u -m sky.jobs.controller {{remote_user_yaml_path}} \ - --job-id $SKYPILOT_INTERNAL_JOB_ID {% if retry_until_up %}--retry-until-up{% endif %} + nohup python -u -m sky.jobs.controller {{remote_user_yaml_path}} \ + --job-id $SKYPILOT_INTERNAL_JOB_ID {% if retry_until_up %}--retry-until-up{% endif %} \ + > {{controller_logs_dir}}/$SKYPILOT_INTERNAL_JOB_ID.log 2>&1 {{controller_pid_file_dir}}/$SKYPILOT_INTERNAL_JOB_ID envs: {%- for env_name, env_value in controller_envs.items() %}