detach the managed job controller from job submission #4458

Closed
wants to merge 3 commits into from
2 changes: 2 additions & 0 deletions sky/jobs/constants.py
@@ -2,6 +2,8 @@

JOBS_CONTROLLER_TEMPLATE = 'jobs-controller.yaml.j2'
JOBS_CONTROLLER_YAML_PREFIX = '~/.sky/jobs_controller'
JOBS_CONTROLLER_PID_FILE_DIR = '~/.sky/jobs_controller_pids'
JOBS_CONTROLLER_LOGS_DIR = '~/sky_controller_logs'

JOBS_TASK_YAML_PREFIX = '~/.sky/managed_jobs'

46 changes: 32 additions & 14 deletions sky/jobs/controller.py
@@ -9,6 +9,7 @@
from typing import Optional, Tuple

import filelock
import psutil

from sky import exceptions
from sky import sky_logging
@@ -18,6 +19,7 @@
from sky.jobs import recovery_strategy
from sky.jobs import semaphore
from sky.jobs import state as managed_job_state
from sky.jobs import utils as managed_job_utils
from sky.skylet import constants
from sky.skylet import job_lib
from sky.usage import usage_lib
@@ -30,6 +32,9 @@
if typing.TYPE_CHECKING:
import sky

_JOB_SEMAPHORE_LOCK_DIR = os.path.expanduser('~/.sky/job_semaphore')
_JOB_LAUNCH_SEMAPHORE_LOCK_DIR = os.path.expanduser('~/.sky/job_launch_semaphore')

# Use the explicit logger name so that the logger is under the
# `sky.jobs.controller` namespace when executed directly, so as
# to inherit the setup from the `sky` logger.
@@ -191,17 +196,19 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
f'Submitted managed job {self._job_id} (task: {task_id}, name: '
f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}')

logger.info('Started monitoring.')
managed_job_state.set_starting(job_id=self._job_id,
task_id=task_id,
callback_func=callback_func)
remote_job_submitted_at = self._strategy_executor.launch()
assert remote_job_submitted_at is not None, remote_job_submitted_at

managed_job_state.set_started(job_id=self._job_id,
task_id=task_id,
start_time=remote_job_submitted_at,
callback_func=callback_func)
with semaphore.FileLockSemaphore(lock_dir_path=_JOB_LAUNCH_SEMAPHORE_LOCK_DIR, lock_count=_get_launch_parallelism()):
logger.info('Started monitoring.')
managed_job_state.set_starting(job_id=self._job_id,
task_id=task_id,
callback_func=callback_func)
remote_job_submitted_at = self._strategy_executor.launch()
assert remote_job_submitted_at is not None, remote_job_submitted_at

managed_job_state.set_started(job_id=self._job_id,
task_id=task_id,
start_time=remote_job_submitted_at,
callback_func=callback_func)

while True:
time.sleep(managed_job_utils.JOB_STATUS_CHECK_GAP_SECONDS)

@@ -426,7 +433,14 @@ def _update_failed_task_state(
job_id=self._job_id,
task_id=task_id,
task=self._dag.tasks[task_id]))

def _get_job_parallelism() -> int:
    # Assume each running managed job uses roughly 400 MB of memory.
    job_memory = 400 * 1024 * 1024
    return max(psutil.virtual_memory().total // job_memory, 1)


def _get_launch_parallelism() -> int:
    # os.cpu_count() can return None; fall back to a single launch slot.
    cpus = os.cpu_count()
    return cpus * 4 if cpus is not None else 1

def _run_controller(job_id: int, dag_yaml: str, retry_until_up: bool):
"""Runs the controller in a remote process for interruption."""
@@ -563,8 +577,7 @@ def start(job_id, dag_yaml, retry_until_up):
failure_reason=('Unexpected error occurred. For details, '
f'run: sky jobs logs --controller {job_id}'))


if __name__ == '__main__':
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--job-id',
required=True,
@@ -580,4 +593,9 @@ def start(job_id, dag_yaml, retry_until_up):
# We start process with 'spawn', because 'fork' could result in weird
# behaviors; 'spawn' is also cross-platform.
multiprocessing.set_start_method('spawn', force=True)
start(args.job_id, args.dag_yaml, args.retry_until_up)

with semaphore.FileLockSemaphore(lock_dir_path=_JOB_SEMAPHORE_LOCK_DIR, lock_count=_get_job_parallelism()):
start(args.job_id, args.dag_yaml, args.retry_until_up)

if __name__ == '__main__':
main()
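Note on the two limits above: the job-level semaphore (acquired in main()) is held for the controller's entire lifetime and bounds how many controller processes run at once, while the launch-level semaphore is held only around provisioning and bounds concurrent launches. A minimal, hedged sketch of that nesting follows; the lock counts and the work bodies are illustrative assumptions, and only FileLockSemaphore comes from this PR.

# Hedged sketch of how the two semaphores nest (illustrative, not PR code).
import os
import time

from sky.jobs import semaphore

_JOB_LOCK_DIR = os.path.expanduser('~/.sky/job_semaphore')
_LAUNCH_LOCK_DIR = os.path.expanduser('~/.sky/job_launch_semaphore')


def run_controller_sketch(job_id: int) -> None:
    # Held for the controller's entire lifetime: caps concurrent controllers.
    with semaphore.FileLockSemaphore(lock_dir_path=_JOB_LOCK_DIR,
                                     lock_count=8):
        # Held only while provisioning: caps concurrent launches.
        with semaphore.FileLockSemaphore(lock_dir_path=_LAUNCH_LOCK_DIR,
                                         lock_count=4):
            print(f'job {job_id}: launching the cluster...')
            time.sleep(1)  # stand-in for the recovery strategy's launch()
        # Launch slot released; keep monitoring while holding the job slot.
        print(f'job {job_id}: monitoring until the managed job finishes')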
4 changes: 4 additions & 0 deletions sky/jobs/core.py
@@ -119,6 +119,10 @@ def launch(
'remote_user_config_path': remote_user_config_path,
'modified_catalogs':
service_catalog_common.get_modified_catalog_file_mounts(),
'controller_pid_file_dir':
managed_job_constants.JOBS_CONTROLLER_PID_FILE_DIR,
'controller_logs_dir':
managed_job_constants.JOBS_CONTROLLER_LOGS_DIR,
**controller_utils.shared_controller_vars_to_fill(
controller_utils.Controllers.JOBS_CONTROLLER,
remote_user_config_path=remote_user_config_path,
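For context, the two new entries above are plain Jinja variables consumed by jobs-controller.yaml.j2. A hedged sketch of the substitution, using generic jinja2 rather than SkyPilot's own template-filling helper, and only a fragment of the real template:

# Illustrative rendering of the two new template variables (not PR code).
import jinja2

fragment = ('run: |\n'
            '  mkdir -p {{controller_logs_dir}}\n'
            '  mkdir -p {{controller_pid_file_dir}}\n')

print(jinja2.Template(fragment).render(
    controller_logs_dir='~/sky_controller_logs',
    controller_pid_file_dir='~/.sky/jobs_controller_pids'))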
46 changes: 46 additions & 0 deletions sky/jobs/semaphore.py
@@ -0,0 +1,46 @@
"""A file-lock based semaphore to limit parallelism of the jobs controller."""

import os
import time
from typing import List

import filelock


class FileLockSemaphore:
    """A cross-process semaphore-like mechanism using file locks.

    Some semaphore uses are unsupported:
    - Each release() call must have a corresponding acquire(), that is, the
      semaphore value cannot go above the initial value (lock_count).
    - All processes must use the same lock_count. This is not enforced by the
      FileLockSemaphore class.
    """

    def __init__(self, lock_dir_path: str, lock_count: int):
        self.lock_dir_path = lock_dir_path
        # Make sure the lock directory exists before creating the lock files.
        os.makedirs(lock_dir_path, exist_ok=True)
        self.locks = [
            filelock.FileLock(os.path.join(lock_dir_path, f'{i}.lock'))
            for i in range(lock_count)
        ]
        self.acquired_locks: List[filelock.FileLock] = []

    def acquire(self):
        # Spin until any one of the lock files can be grabbed.
        while True:
            for lock in self.locks:
                try:
                    lock.acquire(blocking=False)
                    self.acquired_locks.append(lock)
                    return
                except filelock.Timeout:
                    pass
            time.sleep(0.05)

    def release(self):
        if self.acquired_locks:
            self.acquired_locks.pop().release()

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def __del__(self):
        # Best-effort cleanup if the process exits while still holding locks.
        for lock in self.acquired_locks:
            lock.release()
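A short usage sketch of the class above: every process points at the same lock directory with the same lock_count, and acquire() spins until one of the N lock files can be taken non-blockingly. The worker body and the count of 2 below are illustrative assumptions, not part of this PR.

# Illustrative cross-process usage of FileLockSemaphore (not PR code).
import multiprocessing
import os
import time

from sky.jobs import semaphore

_DEMO_LOCK_DIR = os.path.expanduser('~/.sky/demo_semaphore')


def worker(idx: int) -> None:
    # At most two workers hold a slot at any time; the rest wait in acquire().
    with semaphore.FileLockSemaphore(lock_dir_path=_DEMO_LOCK_DIR,
                                     lock_count=2):
        print(f'worker {idx} acquired a slot')
        time.sleep(1)  # stand-in for real work
    print(f'worker {idx} released its slot')


if __name__ == '__main__':
    os.makedirs(_DEMO_LOCK_DIR, exist_ok=True)
    procs = [
        multiprocessing.Process(target=worker, args=(i,)) for i in range(5)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()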
90 changes: 88 additions & 2 deletions sky/jobs/utils.py
@@ -18,6 +18,7 @@

import colorama
import filelock
import psutil
from typing_extensions import Literal

from sky import backends
@@ -119,8 +120,38 @@ def update_managed_job_status(job_id: Optional[int] = None):
else:
job_ids = [job_id]
for job_id_ in job_ids:
controller_status = job_lib.get_status(job_id_)
if controller_status is None or controller_status.is_terminal():
submission_job_status = job_lib.get_status(job_id_)
if submission_job_status is None or submission_job_status.is_terminal():
if submission_job_status == job_lib.JobStatus.SUCCEEDED:
logger.debug(
f'Job {job_id_} is already {submission_job_status}.')
# This is expected, since the submitted job will detach the
# controller and succeed, even if the controller is still
# running. Check the controller status directly.
pid_file = os.path.join(
os.path.expanduser(
managed_job_constants.JOBS_CONTROLLER_PID_FILE_DIR),
str(job_id_))
try:
with open(pid_file, 'r', encoding='utf-8') as f:
pid = int(f.read())
logger.debug(f'Checking controller pid {pid}')
if psutil.Process(pid).is_running():
# The controller is still running.
continue
# Otherwise, proceed to mark the job as failed.
except FileNotFoundError:
logger.debug('Submission succeeded but controller pid '
f'file {pid_file} not found.')
# Proceed to mark the job as failed.
except ValueError:
logger.debug(f'Failed to parse the controller pid from '
f'{pid_file}.')
# Proceed to mark the job as failed.
except psutil.NoSuchProcess:
logger.debug('Controller process not found.')
# Proceed to mark the job as failed.

logger.error(f'Controller for job {job_id_} has exited abnormally. '
'Setting the job status to FAILED_CONTROLLER.')
tasks = managed_job_state.get_managed_jobs(job_id_)
@@ -527,6 +558,7 @@ def stream_logs(job_id: Optional[int],
'instead.')
job_id = managed_job_ids.pop()
assert job_id is not None, (job_id, job_name)

# TODO: keep the following code sync with
# job_lib.JobLibCodeGen.tail_logs, we do not directly call that function
# as the following code need to be run in the current machine, instead
@@ -536,6 +568,59 @@
return f'No managed job controller log found with job_id {job_id}.'
log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp)
log_lib.tail_logs(job_id=job_id, log_dir=log_dir, follow=follow)

controller_log_path = os.path.join(
os.path.expanduser(managed_job_constants.JOBS_CONTROLLER_LOGS_DIR),
f'{job_id}.log')

# Wait for the log file to be written
while not os.path.exists(controller_log_path):
if not follow:
# Assume that the log file hasn't been written yet. Since we
# aren't following, just return.
return ''

job_status = managed_job_state.get_status(job_id)
# We know that the job is present in the state table because of
# earlier checks, so it should not be None.
assert job_status is not None, (job_id, job_name)
if job_status.is_terminal():
# Don't keep waiting. If the log file is not created by this
# point, it never will be. This job may have been submitted
# using an old version that did not create the log file, so this
# is not considered an exceptional case.
return ''

time.sleep(log_lib.SKY_LOG_WAITING_GAP_SECONDS)

# See also log_lib.tail_logs.
with open(controller_log_path, 'r', newline='', encoding='utf-8') as f:
# Note: we do not need to care about start_stream_at here, since
# that should be in the job log printed above.
for line in f:
print(line, end='')
# Flush.
print(end='', flush=True)

if follow:
while True:
line = f.readline()
if line is not None and line != '':
print(line, end='', flush=True)
else:
job_status = managed_job_state.get_status(job_id)
assert job_status is not None, (job_id, job_name)
if job_status.is_terminal():
break

time.sleep(log_lib.SKY_LOG_TAILING_GAP_SECONDS)

# Wait for final logs to be written.
time.sleep(1 + log_lib.SKY_LOG_TAILING_GAP_SECONDS)

# Print any remaining logs including incomplete line.
print(f.read(), end='', flush=True)

return ''

if job_id is None:
Expand Down Expand Up @@ -868,6 +953,7 @@ def stream_logs(cls,
# should be removed in v0.8.0.
code = textwrap.dedent("""\
import os
import time

from sky.skylet import job_lib, log_lib
from sky.skylet import constants
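Summary of the liveness check added to update_managed_job_status above: once the controller is detached, the submission job can legitimately report SUCCEEDED while the controller is still running, so the real signal is whether the pid recorded under ~/.sky/jobs_controller_pids/&lt;job_id&gt; still belongs to a live process. A hedged, standalone sketch follows; the function name is illustrative, not the PR's exact code.

# Hedged sketch of the controller-liveness check (not the PR's exact code).
import os

import psutil

_PID_FILE_DIR = os.path.expanduser('~/.sky/jobs_controller_pids')


def controller_alive(job_id: int) -> bool:
    pid_file = os.path.join(_PID_FILE_DIR, str(job_id))
    try:
        with open(pid_file, 'r', encoding='utf-8') as f:
            pid = int(f.read())
        # psutil.Process raises NoSuchProcess if the pid no longer exists.
        return psutil.Process(pid).is_running()
    except (FileNotFoundError, ValueError, psutil.NoSuchProcess):
        # Missing pid file, unparsable pid, or dead process: not alive.
        return False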
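The stream_logs additions boil down to a tail-and-follow loop over the detached controller's own log file: wait for the file to appear, print what is there, and keep polling until the job reaches a terminal state. A simplified, hedged sketch is below, with the job-status lookups replaced by an is_done() callable that is an assumption for illustration.

# Simplified sketch of following the controller log file (illustrative only).
import os
import time


def follow_controller_log(path: str, is_done, poll_seconds: float = 0.2):
    # Wait for the detached controller to create its log file.
    while not os.path.exists(path):
        if is_done():
            # Terminal job and still no log file: nothing will be written.
            return
        time.sleep(1)
    with open(path, 'r', newline='', encoding='utf-8') as f:
        while True:
            line = f.readline()
            if line:
                print(line, end='', flush=True)
            else:
                if is_done():
                    break
                time.sleep(poll_seconds)
        # Give the controller a moment to flush, then drain the remainder.
        time.sleep(1 + poll_seconds)
        print(f.read(), end='', flush=True)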
16 changes: 8 additions & 8 deletions sky/skylet/log_lib.py
@@ -25,9 +25,9 @@
from sky.utils import subprocess_utils
from sky.utils import ux_utils

_SKY_LOG_WAITING_GAP_SECONDS = 1
_SKY_LOG_WAITING_MAX_RETRY = 5
_SKY_LOG_TAILING_GAP_SECONDS = 0.2
SKY_LOG_WAITING_GAP_SECONDS = 1
SKY_LOG_WAITING_MAX_RETRY = 5
SKY_LOG_TAILING_GAP_SECONDS = 0.2
# Peek the head of the lines to check if we need to start
# streaming when tail > 0.
PEEK_HEAD_LINES_FOR_START_STREAM = 20
@@ -336,7 +336,7 @@ def _follow_job_logs(file,
]:
if wait_last_logs:
# Wait all the logs are printed before exit.
time.sleep(1 + _SKY_LOG_TAILING_GAP_SECONDS)
time.sleep(1 + SKY_LOG_TAILING_GAP_SECONDS)
wait_last_logs = False
continue
status_str = status.value if status is not None else 'None'
Expand All @@ -345,7 +345,7 @@ def _follow_job_logs(file,
f'Job finished (status: {status_str}).'))
return

time.sleep(_SKY_LOG_TAILING_GAP_SECONDS)
time.sleep(SKY_LOG_TAILING_GAP_SECONDS)
status = job_lib.get_status_no_lock(job_id)


@@ -426,15 +426,15 @@ def tail_logs(job_id: Optional[int],
retry_cnt += 1
if os.path.exists(log_path) and status != job_lib.JobStatus.INIT:
break
if retry_cnt >= _SKY_LOG_WAITING_MAX_RETRY:
if retry_cnt >= SKY_LOG_WAITING_MAX_RETRY:
print(
f'{colorama.Fore.RED}ERROR: Logs for '
f'{job_str} (status: {status.value}) does not exist '
f'after retrying {retry_cnt} times.{colorama.Style.RESET_ALL}')
return
print(f'INFO: Waiting {_SKY_LOG_WAITING_GAP_SECONDS}s for the logs '
print(f'INFO: Waiting {SKY_LOG_WAITING_GAP_SECONDS}s for the logs '
'to be written...')
time.sleep(_SKY_LOG_WAITING_GAP_SECONDS)
time.sleep(SKY_LOG_WAITING_GAP_SECONDS)
status = job_lib.update_job_status([job_id], silent=True)[0]

start_stream_at = LOG_FILE_START_STREAMING_AT
3 changes: 3 additions & 0 deletions sky/skylet/log_lib.pyi
@@ -13,6 +13,9 @@ from sky.skylet import constants as constants
from sky.skylet import job_lib as job_lib
from sky.utils import log_utils as log_utils

SKY_LOG_WAITING_GAP_SECONDS: int = ...
SKY_LOG_WAITING_MAX_RETRY: int = ...
SKY_LOG_TAILING_GAP_SECONDS: float = ...
LOG_FILE_START_STREAMING_AT: str = ...


9 changes: 7 additions & 2 deletions sky/templates/jobs-controller.yaml.j2
@@ -33,9 +33,14 @@ setup: |

run: |
{{ sky_activate_python_env }}
mkdir -p {{controller_logs_dir}}
mkdir -p {{controller_pid_file_dir}}
# Start the controller for the current managed job.
python -u -m sky.jobs.controller {{remote_user_yaml_path}} \
--job-id $SKYPILOT_INTERNAL_JOB_ID {% if retry_until_up %}--retry-until-up{% endif %}
nohup python -u -m sky.jobs.controller {{remote_user_yaml_path}} \
--job-id $SKYPILOT_INTERNAL_JOB_ID {% if retry_until_up %}--retry-until-up{% endif %} \
> {{controller_logs_dir}}/$SKYPILOT_INTERNAL_JOB_ID.log 2>&1 </dev/null &
# Note: nohup calls exec, so the pid will stay the same.
echo $! > {{controller_pid_file_dir}}/$SKYPILOT_INTERNAL_JOB_ID

envs:
{%- for env_name, env_value in controller_envs.items() %}
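On the run block above: nohup ... &amp; detaches the controller from the submission job, $! captures the background pid, and because nohup execs into Python that pid stays the controller's real pid. A rough Python equivalent is sketched below for illustration only; the dag.yaml path and the way the template is actually executed are assumptions, not part of this PR.

# Illustrative Python equivalent of the template's detach-and-record-pid step.
import os
import subprocess

job_id = os.environ.get('SKYPILOT_INTERNAL_JOB_ID', '0')
logs_dir = os.path.expanduser('~/sky_controller_logs')
pid_dir = os.path.expanduser('~/.sky/jobs_controller_pids')
os.makedirs(logs_dir, exist_ok=True)
os.makedirs(pid_dir, exist_ok=True)

with open(os.path.join(logs_dir, f'{job_id}.log'), 'w') as log_file:
    # start_new_session detaches the child from this process group, similar
    # to what nohup plus '&' achieves in the template's run block.
    proc = subprocess.Popen(
        ['python', '-u', '-m', 'sky.jobs.controller', 'dag.yaml',
         '--job-id', job_id],
        stdout=log_file,
        stderr=subprocess.STDOUT,
        stdin=subprocess.DEVNULL,
        start_new_session=True)

with open(os.path.join(pid_dir, job_id), 'w') as pid_file:
    pid_file.write(str(proc.pid))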