[Serve/Spot] Allow spot queue/cancel/logs during controller INIT state #3288

Merged · 23 commits · Mar 12, 2024

Changes from 2 commits
53 changes: 38 additions & 15 deletions sky/backends/backend_utils.py
@@ -8,6 +8,7 @@
import re
import shlex
import subprocess
import sys
import tempfile
import textwrap
import time
@@ -2004,7 +2005,7 @@ def refresh_cluster_record(
*,
force_refresh_statuses: Optional[Set[status_lib.ClusterStatus]] = None,
acquire_per_cluster_status_lock: bool = True,
acquire_lock_timeout: int = CLUSTER_FILE_MOUNTS_LOCK_TIMEOUT_SECONDS
acquire_lock_timeout: int = CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS
) -> Optional[Dict[str, Any]]:
"""Refresh the cluster, and return the possibly updated record.

@@ -2226,31 +2227,31 @@ def is_controller_up(
controller_type: controller_utils.Controllers,
stopped_message: str,
non_existent_message: Optional[str] = None,
) -> Tuple[Optional[status_lib.ClusterStatus],
Optional['backends.CloudVmRayResourceHandle']]:
exit_on_error: bool = False,
) -> 'backends.CloudVmRayResourceHandle':
"""Check if the spot/serve controller is up.

It can be used to check the actual controller status (since the autostop is
set for the controller) before the spot/serve commands interact with the
controller.

ClusterNotUpError will be raised whenever the controller cannot be accessed.

Args:
controller_type: Type of the controller.
stopped_message: Message to print if the controller is STOPPED.
non_existent_message: Message to show if the controller does not exist.

Returns:
controller_status: The status of the controller. If it fails during
refreshing the status, it will be the cached status. None if the
controller does not exist.
handle: The ResourceHandle of the controller. None if the
controller is not UP or does not exist.
handle: The ResourceHandle of the controller.

Raises:
exceptions.ClusterOwnerIdentityMismatchError: if the current user is not
the same as the user who created the cluster.
exceptions.CloudUserIdentityError: if we fail to get the current user
identity.
exceptions.ClusterNotUpError: if the controller is not UP or cannot be
connected.
"""
if non_existent_message is None:
non_existent_message = (
@@ -2264,9 +2265,10 @@ def is_controller_up(
# unnecessary costly refresh when the controller is already stopped.
# This optimization is based on the assumption that the user will not
# start the controller manually from the cloud console.
# The acquire_lock_timeout is set to 1 second to avoid hanging the
# command when multiple spot_launch commands are running at the same
# time. It should be safe to set it to 0 (try once to get the lock).
#
# The acquire_lock_timeout is set to 0 to avoid hanging the command when
# multiple spot_launch commands are running at the same time. It should
# be safe to set it to 0 (try once to get the lock).
controller_status, handle = refresh_cluster_status_handle(
cluster_name, force_refresh_statuses=None, acquire_lock_timeout=0)
except exceptions.ClusterStatusFetchingError as e:
@@ -2283,11 +2285,32 @@ def is_controller_up(
if record is not None:
controller_status, handle = record['status'], record['handle']

if controller_status is None:
sky_logging.print(non_existent_message)
error_msg = None
if controller_status is None or handle.head_ip is None:
error_msg = non_existent_message
elif controller_status == status_lib.ClusterStatus.STOPPED:
sky_logging.print(stopped_message)
return controller_status, handle
error_msg = stopped_message
elif controller_status == status_lib.ClusterStatus.INIT:
ssh_credentials = ssh_credential_from_yaml(handle.cluster_yaml,
handle.docker_user,
handle.ssh_user)

runner = command_runner.SSHCommandRunner(handle.head_ip,
**ssh_credentials,
port=handle.head_ssh_port)
if not runner.check_connection():
error_msg = controller_type.value.hint_for_connection_error

if error_msg is not None:
if exit_on_error:
sky_logging.print(error_msg)
sys.exit(1)
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterNotUpError(error_msg,
cluster_status=controller_status,
handle=handle)

return handle


class CloudFilter(enum.Enum):
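For context, a minimal sketch of the new calling contract — not part of the diff, and assuming the module layout visible in this PR (`sky.backends.backend_utils`, `sky.utils.controller_utils`):

from sky import exceptions
from sky.backends import backend_utils
from sky.utils import controller_utils

def get_spot_controller_handle():
    # is_controller_up() now returns only the handle; every inaccessible
    # state (non-existent, STOPPED, or INIT with SSH unreachable) raises
    # ClusterNotUpError instead of returning a status tuple.
    try:
        return backend_utils.is_controller_up(
            controller_type=controller_utils.Controllers.SPOT_CONTROLLER,
            stopped_message='All managed spot jobs should have finished.')
    except exceptions.ClusterNotUpError as e:
        # The hint message, the last known cluster_status, and the handle
        # now travel inside the exception.
        print(f'Controller not accessible: {e.cluster_status}')
        return None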
24 changes: 10 additions & 14 deletions sky/cli.py
@@ -4081,12 +4081,10 @@ def spot_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
# Cancel managed spot jobs with IDs 1, 2, 3
$ sky spot cancel 1 2 3
"""
controller_status, _ = backend_utils.is_controller_up(
backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SPOT_CONTROLLER,
stopped_message='All managed spot jobs should have finished.')
if controller_status in [status_lib.ClusterStatus.STOPPED, None]:
# Hint messages already printed by the call above.
sys.exit(1)
stopped_message='All managed spot jobs should have finished.',
exit_on_error=True)

job_id_str = ','.join(map(str, job_ids))
if sum([len(job_ids) > 0, name is not None, all]) != 1:
@@ -4166,12 +4164,12 @@ def spot_dashboard(port: Optional[int]):
hint = (
'Dashboard is not available if spot controller is not up. Run a spot '
'job first.')
controller_status, _ = backend_utils.is_controller_up(
backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SPOT_CONTROLLER,
stopped_message=hint,
non_existent_message=hint)
if controller_status in [status_lib.ClusterStatus.STOPPED, None]:
sys.exit(1)
non_existent_message=hint,
exit_on_error=True)

# SSH forward a free local port to remote's dashboard port.
remote_port = constants.SPOT_DASHBOARD_REMOTE_PORT
if port is None:
@@ -4674,12 +4672,10 @@ def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool):
'Can only specify one of SERVICE_NAMES or --all. '
f'Provided {argument_str!r}.')

controller_status, _ = backend_utils.is_controller_up(
backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
stopped_message='All services should have been terminated.')
if controller_status in [status_lib.ClusterStatus.STOPPED, None]:
# Hint messages already printed by the call above.
sys.exit(1)
stopped_message='All services should have been terminated.',
exit_on_error=True)

if not yes:
quoted_service_names = [f'{name!r}' for name in service_names]
36 changes: 13 additions & 23 deletions sky/core.py
@@ -800,13 +800,17 @@ def spot_queue(refresh: bool,
stop_msg = ''
if not refresh:
stop_msg = 'To view the latest job table: sky spot queue --refresh'
controller_status, handle = backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SPOT_CONTROLLER,
stopped_message=stop_msg)
try:
handle = backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SPOT_CONTROLLER,
stopped_message=stop_msg)
except exceptions.ClusterNotUpError as e:
if not refresh:
raise
handle = None
controller_status = e.cluster_status

if (refresh and controller_status in [
status_lib.ClusterStatus.STOPPED, status_lib.ClusterStatus.INIT
]):
if refresh and handle is None:
sky_logging.print(f'{colorama.Fore.YELLOW}'
'Restarting controller for latest status...'
f'{colorama.Style.RESET_ALL}')
@@ -817,7 +821,7 @@
controller_status = status_lib.ClusterStatus.UP
rich_utils.force_update_status('[cyan] Checking spot jobs[/]')

if handle is None or handle.head_ip is None:
if handle is None:
# When the controller is STOPPED, head_ip will be None, as it is set
# to None in global_user_state.remove_cluster().
# We do not directly check for UP because the controller may be
@@ -873,16 +877,9 @@ def spot_cancel(name: Optional[str] = None,
RuntimeError: failed to cancel the job.
"""
job_ids = [] if job_ids is None else job_ids
cluster_status, handle = backend_utils.is_controller_up(
handle = backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SPOT_CONTROLLER,
stopped_message='All managed spot jobs should have finished.')
if handle is None or handle.head_ip is None:
# The error message is already printed in
# backend_utils.is_controller_up
# TODO(zhwu): Move the error message into the exception.
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterNotUpError(message='',
cluster_status=cluster_status)

job_id_str = ','.join(map(str, job_ids))
if sum([len(job_ids) > 0, name is not None, all]) != 1:
@@ -933,17 +930,10 @@ def spot_tail_logs(name: Optional[str], job_id: Optional[int],
sky.exceptions.ClusterNotUpError: the spot controller is not up.
"""
# TODO(zhwu): Automatically restart the spot controller
controller_status, handle = backend_utils.is_controller_up(
handle = backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SPOT_CONTROLLER,
stopped_message=('Please restart the spot controller with '
f'`sky start {spot.SPOT_CONTROLLER_NAME}`.'))
if handle is None or handle.head_ip is None:
msg = 'All jobs finished.'
if controller_status == status_lib.ClusterStatus.INIT:
msg = ''
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterNotUpError(msg,
cluster_status=controller_status)

if name is not None and job_id is not None:
raise ValueError('Cannot specify both name and job_id.')
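Because the sky.core APIs above now raise instead of returning (status, handle), external callers need only one error path. A hedged usage sketch, relying on the spot_cancel signature shown in this diff:

from sky import core, exceptions

def try_cancel_all_spot_jobs() -> bool:
    # Returns True if cancellation was issued, False if the controller is
    # not up (in which case no spot jobs can be running anyway).
    try:
        core.spot_cancel(all=True)
    except exceptions.ClusterNotUpError as e:
        print(e)  # carries the stopped/non-existent hint from is_controller_up
        return False
    return True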
47 changes: 5 additions & 42 deletions sky/serve/core.py
@@ -1,7 +1,6 @@
"""SkyServe core APIs."""
import re
import tempfile
import typing
from typing import Any, Dict, List, Optional, Union

import colorama
@@ -11,7 +10,6 @@
from sky import exceptions
from sky import global_user_state
from sky import sky_logging
from sky import status_lib
from sky import task as task_lib
from sky.backends import backend_utils
from sky.serve import constants as serve_constants
@@ -26,9 +24,6 @@
from sky.utils import subprocess_utils
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from sky import clouds

logger = sky_logging.init_logger(__name__)


@@ -310,7 +305,7 @@ def update(task: 'sky.Task', service_name: str) -> None:
service_name: Name of the service.
"""
_validate_service_task(task)
cluster_status, handle = backend_utils.is_controller_up(
handle = backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
stopped_message=
'Service controller is stopped. There is no service to update. '
@@ -321,14 +316,6 @@
f'use {backend_utils.BOLD}sky serve up{backend_utils.RESET_BOLD}',
)

if handle is None or handle.head_ip is None:
# The error message is already printed in
# backend_utils.is_controller_up
# TODO(zhwu): Move the error message into the exception.
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterNotUpError(message='',
cluster_status=cluster_status)

backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend)

@@ -461,16 +448,9 @@ def down(
service_names = []
if isinstance(service_names, str):
service_names = [service_names]
cluster_status, handle = backend_utils.is_controller_up(
handle = backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
stopped_message='All services should have terminated.')
if handle is None or handle.head_ip is None:
# The error message is already printed in
# backend_utils.is_controller_up
# TODO(zhwu): Move the error message into the exception.
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterNotUpError(message='',
cluster_status=cluster_status)

service_names_str = ','.join(service_names)
if sum([len(service_names) > 0, all]) != 1:
@@ -573,22 +553,10 @@ def status(
raise RuntimeError(
'Failed to refresh service status due to network error.') from e

# TODO(tian): This is so slow... It will take ~10s to refresh the status
# of controller. Can we optimize this?
controller_status, handle = backend_utils.is_controller_up(
handle = backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
stopped_message='No service is found.')

if handle is None or handle.head_ip is None:
# When the controller is STOPPED, the head_ip will be None, as
# it will be set in global_user_state.remove_cluster().
# We do not directly check for UP because the controller may be
# in INIT state during another `sky serve up`, but still have
# head_ip available. In this case, we can still try to ssh
# into the controller and fetch the job table.
raise exceptions.ClusterNotUpError('Sky serve controller is not up.',
cluster_status=controller_status)

backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend)

@@ -669,15 +637,10 @@ def tail_logs(
with ux_utils.print_exception_no_traceback():
raise ValueError('`replica_id` must be None when using '
'target=CONTROLLER/LOAD_BALANCER.')
controller_status, handle = backend_utils.is_controller_up(
handle = backend_utils.is_controller_up(
controller_type=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
stopped_message='No service is found.')
if handle is None or handle.head_ip is None:
msg = 'No service is found.'
if controller_status == status_lib.ClusterStatus.INIT:
msg = ''
raise exceptions.ClusterNotUpError(msg,
cluster_status=controller_status)

backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend), backend
backend.tail_serve_logs(handle,
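The serve APIs get the same consolidation; a small sketch of consuming them, assuming status() keeps the (service_names, ...) signature implied by the collapsed hunk above:

from sky import exceptions
from sky.serve import core as serve_core

def safe_service_status(service_names=None):
    # Returns the service records, or an empty list when the controller is
    # stopped or was never launched.
    try:
        return serve_core.status(service_names)
    except exceptions.ClusterNotUpError as e:
        print(e)  # e.g. 'No service is found.'
        return []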
18 changes: 15 additions & 3 deletions sky/utils/command_runner.py
@@ -218,7 +218,8 @@ def make_runner_list(
]

def _ssh_base_command(self, *, ssh_mode: SshMode,
port_forward: Optional[List[int]]) -> List[str]:
port_forward: Optional[List[int]],
connection_timeout: int) -> List[str]:
ssh = ['ssh']
if ssh_mode == SshMode.NON_INTERACTIVE:
# Disable pseudo-terminal allocation. Otherwise, the output of
@@ -243,6 +244,7 @@ def _ssh_base_command(self, *, ssh_mode: SshMode,
ssh_proxy_command=self._ssh_proxy_command,
docker_ssh_proxy_command=docker_ssh_proxy_command,
port=self.port,
timeout=connection_timeout,
disable_control_master=self.disable_control_master) + [
f'{self.ssh_user}@{self.ip}'
]
@@ -260,6 +262,7 @@ def run(
stream_logs: bool = True,
ssh_mode: SshMode = SshMode.NON_INTERACTIVE,
separate_stderr: bool = False,
connection_timeout: int = 30,
**kwargs) -> Union[int, Tuple[int, str, str]]:
"""Uses 'ssh' to run 'cmd' on a node with ip.

@@ -285,8 +288,10 @@
or
A tuple of (returncode, stdout, stderr).
"""
base_ssh_command = self._ssh_base_command(ssh_mode=ssh_mode,
port_forward=port_forward)
base_ssh_command = self._ssh_base_command(
ssh_mode=ssh_mode,
port_forward=port_forward,
connection_timeout=connection_timeout)
if ssh_mode == SshMode.LOGIN:
assert isinstance(cmd, list), 'cmd must be a list for login mode.'
command = base_ssh_command + cmd
@@ -449,3 +454,10 @@ def rsync(
error_msg,
stderr=stderr,
stream_logs=stream_logs)

def check_connection(self) -> bool:
"""Check if the connection to the remote machine is successful."""
returncode = self.run('true', connection_timeout=5)
return returncode == 0
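And a sketch of the new reachability probe in isolation — the exact SSHCommandRunner constructor arguments are assumptions based on the backend_utils.py call site above:

from sky.utils import command_runner

def ssh_reachable(head_ip: str, ssh_credentials: dict, port: int = 22) -> bool:
    runner = command_runner.SSHCommandRunner(head_ip, **ssh_credentials,
                                             port=port)
    # check_connection() runs `true` on the remote with connection_timeout=5,
    # which is threaded through _ssh_base_command() into the ssh invocation,
    # so an unreachable controller fails within seconds instead of hanging.
    return runner.check_connection()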
3 changes: 3 additions & 0 deletions sky/utils/command_runner.pyi
@@ -123,3 +123,6 @@ class SSHCommandRunner:
log_path: str = ...,
stream_logs: bool = ...) -> None:
...

def check_connection(self) -> bool:
...