simplify locking mechanism

cg505 committed Dec 20, 2024
1 parent 78eef52 commit 4c54642
Showing 11 changed files with 486 additions and 385 deletions.
sky/backends/cloud_vm_ray_backend.py (5 additions, 5 deletions)
@@ -3329,8 +3329,8 @@ def _exec_code_on_head(
         handle: CloudVmRayResourceHandle,
         codegen: str,
         job_id: int,
-        task: task_lib.Task,
         detach_run: bool = False,
+        managed_job_dag: Optional['dag.Dag'] = None,
     ) -> None:
         """Executes generated code on the head node."""
         style = colorama.Style
@@ -3378,11 +3378,11 @@ def _dump_code_to_file(codegen: str) -> None:
         _dump_code_to_file(codegen)
         job_submit_cmd = f'{mkdir_code} && {code}'

-        if task.managed_job_dag is not None:
+        if managed_job_dag is not None:
             # Add the managed job to job queue database.
             managed_job_codegen = managed_jobs.ManagedJobCodeGen()
             managed_job_code = managed_job_codegen.set_pending(
-                job_id, task.managed_job_dag, task.envs['DAG_YAML_PATH'])
+                job_id, managed_job_dag)
             # Set the managed job to PENDING state to make sure that this
             # managed job appears in the `sky jobs queue`, when there are
             # already 2x vCPU controller processes running on the controller VM,
@@ -4896,7 +4896,7 @@ def _execute_task_one_node(self, handle: CloudVmRayResourceHandle,
                                 codegen.build(),
                                 job_id,
                                 detach_run=detach_run,
-                                task=task)
+                                managed_job_dag=task.managed_job_dag)

     def _execute_task_n_nodes(self, handle: CloudVmRayResourceHandle,
                               task: task_lib.Task, job_id: int,
@@ -4952,4 +4952,4 @@ def _execute_task_n_nodes(self, handle: CloudVmRayResourceHandle,
                                 codegen.build(),
                                 job_id,
                                 detach_run=detach_run,
-                                task=task)
+                                managed_job_dag=task.managed_job_dag)
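The two call sites above are the whole interface change on the backend side: callers no longer hand the full `Task` to `_exec_code_on_head`, only its optional `managed_job_dag`. A condensed sketch of the updated call, with `handle` and `codegen` prepared earlier in `_execute_task_one_node` as in the surrounding code (the leading `handle` argument is inferred from the method signature, not shown in the hunk):

```python
# Sketch of the call site after this commit. For a plain (unmanaged) job,
# task.managed_job_dag is None, so _exec_code_on_head simply skips the
# ManagedJobCodeGen.set_pending() registration shown in the hunk above.
self._exec_code_on_head(handle,
                        codegen.build(),
                        job_id,
                        detach_run=detach_run,
                        managed_job_dag=task.managed_job_dag)
```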
sky/cli.py (1 addition, 31 deletions)
@@ -3583,18 +3583,6 @@ def jobs():
               is_flag=True,
               help=('If True, as soon as a job is submitted, return from this call '
                     'and do not stream execution logs.'))
-@click.option(
-    '--retry-until-up/--no-retry-until-up',
-    '-r/-no-r',
-    default=None,
-    is_flag=True,
-    required=False,
-    help=(
-        '(Default: True; this flag is deprecated and will be removed in a '
-        'future release.) Whether to retry provisioning infinitely until the '
-        'cluster is up, if unavailability errors are encountered. This '  # pylint: disable=bad-docstring-quotes
-        'applies to launching all managed jobs (both the initial and '
-        'any recovery attempts), not the jobs controller.'))
 @click.option('--yes',
               '-y',
               is_flag=True,
@@ -3634,7 +3622,6 @@ def jobs_launch(
     disk_tier: Optional[str],
     ports: Tuple[str],
     detach_run: bool,
-    retry_until_up: bool,
     yes: bool,
     fast: bool,
 ):
@@ -3678,19 +3665,6 @@ def jobs_launch(
         ports=ports,
         job_recovery=job_recovery,
     )
-    # Deprecation. We set the default behavior to be retry until up, and the
-    # flag `--retry-until-up` is deprecated. We can remove the flag in 0.8.0.
-    if retry_until_up is not None:
-        flag_str = '--retry-until-up'
-        if not retry_until_up:
-            flag_str = '--no-retry-until-up'
-        click.secho(
-            f'Flag {flag_str} is deprecated and will be removed in a '
-            'future release (managed jobs will always be retried). '
-            'Please file an issue if this does not work for you.',
-            fg='yellow')
-    else:
-        retry_until_up = True

     if not isinstance(task_or_dag, sky.Dag):
         assert isinstance(task_or_dag, sky.Task), task_or_dag
@@ -3730,11 +3704,7 @@ def jobs_launch(

     common_utils.check_cluster_name_is_valid(name)

-    managed_jobs.launch(dag,
-                        name,
-                        detach_run=detach_run,
-                        retry_until_up=retry_until_up,
-                        fast=fast)
+    managed_jobs.launch(dag, name, detach_run=detach_run, fast=fast)


 @jobs.command('queue', cls=_DocumentedCodeCommand)
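With the deprecation shim gone, `sky jobs launch` no longer accepts `--retry-until-up`; managed jobs always retry provisioning until the cluster is up, as the removed warning message already stated. A minimal usage sketch of the corresponding programmatic path, assuming the `managed_jobs.launch(dag, name, detach_run=..., fast=...)` signature used above (the task and names are placeholders):

```python
import sky
from sky import jobs as managed_jobs

# Build a one-task DAG; tasks created inside the context are added to it.
with sky.Dag() as dag:
    sky.Task(name='hello', run='echo hello')

# No retry_until_up argument anymore: retries are always on for managed jobs.
managed_jobs.launch(dag, 'hello-job', detach_run=True, fast=False)
```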
sky/jobs/controller.py (33 additions, 54 deletions)
@@ -1,6 +1,5 @@
 """Controller: handles the life cycle of a managed job."""
 import argparse
-import contextlib
 import multiprocessing
 import os
 import pathlib
@@ -48,12 +47,10 @@ def _get_dag_and_name(dag_yaml: str) -> Tuple['sky.Dag', str]:
 class JobsController:
     """Each jobs controller manages the life cycle of one managed job."""

-    def __init__(self, job_id: int, dag_yaml: str,
-                 retry_until_up: bool) -> None:
+    def __init__(self, job_id: int, dag_yaml: str) -> None:
         self._job_id = job_id
         self._dag, self._dag_name = _get_dag_and_name(dag_yaml)
         logger.info(self._dag)
-        self._retry_until_up = retry_until_up
         # TODO(zhwu): this assumes the specific backend.
         self._backend = cloud_vm_ray_backend.CloudVmRayBackend()
@@ -176,48 +173,39 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
         cluster_name = managed_job_utils.generate_managed_job_cluster_name(
             task.name, self._job_id)
         self._strategy_executor = recovery_strategy.StrategyExecutor.make(
-            cluster_name, self._backend, task, self._retry_until_up)
-
-        def _schedule_launch():
-            if task_id == 0:
-                # The job is already SUBMITTED in this case, and we do not need to schedule.
-                return contextlib.nullcontext()
-            # We need to wait for a scheduling slot.
-            # Neither the previous task or the current task is SUBMITTED or RUNNING.
-            return scheduler.schedule_active_job_launch(is_running=False)
-
-        with _schedule_launch():
-            # Note: task_id 0 will already be set to submitted by the scheduler.
-            # However, we only call the callback func here, so keep this.
-            managed_job_state.set_submitted(self._job_id,
-                                            task_id,
-                                            callback_func=callback_func)
-        logger.info(
-            f'Submitted managed job {self._job_id} (task: {task_id}, name: '
-            f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}')
-
-        logger.info('Started monitoring.')
-        managed_job_state.set_starting(
-            job_id=self._job_id,
-            task_id=task_id,
-            run_timestamp=self._backend.run_timestamp,
-            submit_time=submitted_at,
+            cluster_name, self._backend, task, self._job_id)
+        managed_job_state.set_submitted(
+            self._job_id,
+            task_id,
+            self._backend.run_timestamp,
+            submitted_at,
             resources_str=backend_utils.get_task_resources_str(
                 task, is_managed_job=True),
             specs={
                 'max_restarts_on_errors':
                     self._strategy_executor.max_restarts_on_errors
             },
             callback_func=callback_func)
+        logger.info(
+            f'Submitted managed job {self._job_id} (task: {task_id}, name: '
+            f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}')
+
+        scheduler.wait_until_launch_okay(self._job_id)
+
+        logger.info('Started monitoring.')
+        managed_job_state.set_starting(job_id=self._job_id,
+                                       task_id=task_id,
+                                       callback_func=callback_func)
         remote_job_submitted_at = self._strategy_executor.launch()
         assert remote_job_submitted_at is not None, remote_job_submitted_at

         managed_job_state.set_started(job_id=self._job_id,
                                       task_id=task_id,
                                       start_time=remote_job_submitted_at,
                                       callback_func=callback_func)
-        # Finished sky launch so now maybe something else can launch.
-        scheduler.schedule_step()
+
+        scheduler.launch_finished(self._job_id)

         while True:
             time.sleep(managed_job_utils.JOB_STATUS_CHECK_GAP_SECONDS)
@@ -366,21 +354,18 @@ def _schedule_launch():

             # Try to recover the managed jobs, when the cluster is preempted or
             # failed or the job status is failed to be fetched.
-            managed_job_state.set_pending_recovery(job_id=self._job_id,
-                                                   task_id=task_id,
-                                                   callback_func=callback_func)
-            # schedule_recovery() will block until we can recover
-            with scheduler.schedule_active_job_launch(is_running=True):
-                managed_job_state.set_recovering(job_id=self._job_id,
-                                                 task_id=task_id,
-                                                 callback_func=callback_func)
+            managed_job_state.set_recovering(job_id=self._job_id,
+                                             task_id=task_id,
+                                             callback_func=callback_func)
+            # Switch to LAUNCHING schedule_state here, since the entire recovery
+            # process is somewhat heavy.
+            scheduler.wait_until_launch_okay(self._job_id)
             recovered_time = self._strategy_executor.recover()
             managed_job_state.set_recovered(self._job_id,
                                             task_id,
                                             recovered_time=recovered_time,
                                             callback_func=callback_func)
-            # Just finished launching, maybe something was waiting to start.
-            scheduler.schedule_step()
+            scheduler.launch_finished(self._job_id)

     def run(self):
         """Run controller logic and handle exceptions."""
@@ -451,11 +436,11 @@ def _update_failed_task_state(
                 task=self._dag.tasks[task_id]))


-def _run_controller(job_id: int, dag_yaml: str, retry_until_up: bool):
+def _run_controller(job_id: int, dag_yaml: str):
     """Runs the controller in a remote process for interruption."""
     # The controller needs to be instantiated in the remote process, since
     # the controller is not serializable.
-    jobs_controller = JobsController(job_id, dag_yaml, retry_until_up)
+    jobs_controller = JobsController(job_id, dag_yaml)
     jobs_controller.run()

@@ -512,7 +497,7 @@ def _cleanup(job_id: int, dag_yaml: str):
         backend.teardown_ephemeral_storage(task)


-def start(job_id, dag_yaml, retry_until_up):
+def start(job_id, dag_yaml):
     """Start the controller."""
     controller_process = None
     cancelling = False
@@ -525,8 +510,7 @@ def start(job_id, dag_yaml, retry_until_up):
         # So we can only enable daemon after we no longer need to
         # start daemon processes like Ray.
         controller_process = multiprocessing.Process(target=_run_controller,
-                                                     args=(job_id, dag_yaml,
-                                                           retry_until_up))
+                                                     args=(job_id, dag_yaml))
         controller_process.start()
         while controller_process.is_alive():
             _handle_signal(job_id)
@@ -586,9 +570,7 @@ def start(job_id, dag_yaml, retry_until_up):
                 failure_reason=('Unexpected error occurred. For details, '
                                 f'run: sky jobs logs --controller {job_id}'))

-        # Run the scheduler to kick off any pending jobs that can now start.
-        logger.info('Running scheduler')
-        scheduler.schedule_step()
+        scheduler.job_done(job_id)


 if __name__ == '__main__':
@@ -597,14 +579,11 @@ def start(job_id, dag_yaml, retry_until_up):
                         required=True,
                         type=int,
                         help='Job id for the controller job.')
-    parser.add_argument('--retry-until-up',
-                        action='store_true',
-                        help='Retry until the cluster is up.')
     parser.add_argument('dag_yaml',
                         type=str,
                         help='The path to the user job yaml file.')
     args = parser.parse_args()
     # We start process with 'spawn', because 'fork' could result in weird
     # behaviors; 'spawn' is also cross-platform.
     multiprocessing.set_start_method('spawn', force=True)
-    start(args.job_id, args.dag_yaml, args.retry_until_up)
+    start(args.job_id, args.dag_yaml)
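Taken together, the controller changes replace the `_schedule_launch()` context manager and `schedule_step()` calls with three explicit scheduler calls that appear in the hunks above: `wait_until_launch_okay()` before any heavy launch or recovery, `launch_finished()` once the launch slot can be released, and `job_done()` when the controller exits. A condensed sketch of that life cycle under those assumptions; the import path is assumed from the package layout, and the `do_launch`/`monitor` callables are hypothetical stand-ins for the real launch and monitoring logic:

```python
from typing import Callable

from sky.jobs import scheduler  # assumed import path, as used in controller.py


def controller_lifecycle_sketch(job_id: int,
                                do_launch: Callable[[int], None],
                                monitor: Callable[[int], None]) -> None:
    """Hypothetical outline of the scheduling handshake after this commit."""
    try:
        # Block until the scheduler grants a launch slot.
        scheduler.wait_until_launch_okay(job_id)
        do_launch(job_id)  # the heavy `sky launch` work
        # Release the launch slot so other pending jobs may start launching.
        scheduler.launch_finished(job_id)

        # Recovery after preemption re-enters the same handshake:
        # wait_until_launch_okay() -> recover -> launch_finished().
        monitor(job_id)  # cheap status polling; no launch slot held
    finally:
        # The controller is done with this job; free its slot so the
        # scheduler can start another waiting controller.
        scheduler.job_done(job_id)
```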
sky/jobs/core.py (0 additions, 2 deletions)
@@ -41,7 +41,6 @@ def launch(
     name: Optional[str] = None,
     stream_logs: bool = True,
     detach_run: bool = False,
-    retry_until_up: bool = False,
     fast: bool = False,
 ) -> None:
     # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
@@ -115,7 +114,6 @@ def launch(
         'jobs_controller': controller_name,
         # Note: actual cluster name will be <task.name>-<managed job ID>
         'dag_name': dag.name,
-        'retry_until_up': retry_until_up,
         'remote_user_config_path': remote_user_config_path,
         'modified_catalogs':
             service_catalog_common.get_modified_catalog_file_mounts(),
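This pairs with the argparse change in controller.py above: the rendered controller command no longer needs a `retry_until_up` template variable, only the job id and the DAG YAML path. A tiny sketch of the entrypoint with the new signature (the id and path are placeholders; in practice the command is generated and run on the jobs controller VM):

```python
from sky.jobs import controller

# Roughly equivalent to: python -m sky.jobs.controller --job-id 42 <dag_yaml>
controller.start(job_id=42, dag_yaml='/path/to/managed-job.yaml')
```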