Skip to content

Commit

Permalink
Fixes for CLIs
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Aug 1, 2024
1 parent 86eeb0e commit ea06a6f
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 71 deletions.
8 changes: 8 additions & 0 deletions docs/source/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ Spec: ``~/.sky/config.yaml``
Available fields and semantics:

.. code-block:: yaml
# Endpoint of the SkyPilot API server (optional).
#
# This is used to connect to the SkyPilot API server.
#
# Default: null (use the local endpoint, which will be started by SkyPilot
# automatically).
api_server:
endpoint: http://xx.xx.xx.xx:8000
# Custom managed jobs controller resources (optional).
#
Expand Down
143 changes: 85 additions & 58 deletions sky/api/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from sky import sky_logging
from sky.adaptors import common as adaptors_common
from sky.api import sdk as sdk_lib
from sky.api import common as api_common
from sky.backends import backend_utils
from sky.benchmark import benchmark_state
from sky.benchmark import benchmark_utils
Expand Down Expand Up @@ -113,12 +114,13 @@
sdk = sdk_lib


def _get_cluster_records(
    clusters: List[str],
    refresh: api_common.StatusRefreshMode = api_common.StatusRefreshMode.NONE
) -> List[dict]:
    """Returns cluster records for the clusters matching the glob patterns.

    Args:
        clusters: Glob patterns (or exact names) of clusters to query.
        refresh: How to refresh cluster status before returning; defaults to
            no refresh (use the cached status from the API server).

    Returns:
        A list of cluster record dicts, one per matching cluster.
    """
    # TODO(zhwu): this additional RTT makes CLIs slow. We should optimize this.
    request_id = sdk.status(clusters, refresh=refresh)
    cluster_records = sdk.get(request_id)
    return cluster_records


def _get_glob_storages(storages: List[str]) -> List[str]:
Expand Down Expand Up @@ -147,6 +149,12 @@ def _parse_env_var(env_var: str) -> Tuple[str, str]:
'or KEY.')
return ret[0], ret[1]

def _async_call_or_wait(request_id: str, async_call: bool, request_name: str) -> None:
    """Either block until a request completes or report its ID for async mode.

    Args:
        request_id: The ID of the submitted API request.
        async_call: If True, do not wait; just print the request ID.
        request_name: Human-readable name of the request, used in the message.
    """
    if async_call:
        # Async mode: surface the request ID so the user can track it later.
        click.secho(f'Submitted {request_name} request: {request_id}', fg='green')
    else:
        # Sync mode: stream logs and block until the request finishes.
        sdk.stream_and_get(request_id)


def _merge_env_vars(env_dict: Optional[Dict[str, str]],
env_list: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
Expand Down Expand Up @@ -1087,18 +1095,15 @@ def launch(
backend,
cluster,
dryrun=dryrun,
detach_setup=detach_setup,
detach_run=detach_run,
detach_setup=detach_setup or async_call,
detach_run=detach_run or async_call,
no_confirm=yes,
idle_minutes_to_autostop=idle_minutes_to_autostop,
down=down,
retry_until_up=retry_until_up,
no_setup=no_setup,
clone_disk_from=clone_disk_from)
if not async_call:
sdk.stream_and_get(request_id)
else:
click.secho(f'Submitted Launch request: {request_id}', fg='green')
_async_call_or_wait(request_id, async_call, 'Launch')


@cli.command(cls=_DocumentedCodeCommand)
Expand Down Expand Up @@ -1241,11 +1246,8 @@ def exec(cluster: Optional[str], cluster_option: Optional[str],
click.secho(f'Executing task on cluster {cluster}...', fg='yellow')
request_id = sdk.exec(task,
cluster_name=cluster,
detach_run=detach_run)
if not async_call:
sdk.stream_and_get(request_id)
else:
click.secho(f'Submitted Exec request: {request_id}', fg='green')
detach_run=detach_run or async_call)
_async_call_or_wait(request_id, async_call, 'Exec')


def _get_managed_jobs(
Expand Down Expand Up @@ -1593,7 +1595,10 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool,
click.echo(f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}Clusters'
f'{colorama.Style.RESET_ALL}')
query_clusters: Optional[List[str]] = None if not clusters else clusters
request = sdk.status(cluster_names=query_clusters, refresh=refresh)
refresh_mode = api_common.StatusRefreshMode.NONE
if refresh:
refresh_mode = api_common.StatusRefreshMode.FORCE
request = sdk.status(cluster_names=query_clusters, refresh=refresh_mode)
cluster_records = sdk.stream_and_get(request)
# TODO(zhwu): setup the ssh config for status
if ip or show_endpoints:
Expand Down Expand Up @@ -1859,7 +1864,8 @@ def queue(clusters: List[str], skip_finished: bool, all_users: bool):
click.secho('Fetching and parsing job queue...', fg='yellow')
if not clusters:
clusters = ['*']
clusters = _get_glob_clusters(clusters)
cluster_records = _get_cluster_records(clusters)
clusters = [cluster['name'] for cluster in cluster_records]

unsupported_clusters = []
logger.info(f'Fetching job queue for {clusters}')
Expand Down Expand Up @@ -2013,9 +2019,10 @@ def logs(
default=False,
required=False,
help='Skip confirmation prompt.')
@_add_click_options(_COMMON_OPTIONS)
@click.argument('jobs', required=False, type=int, nargs=-1)
@usage_lib.entrypoint
def cancel(cluster: str, all: bool, jobs: List[int], yes: bool): # pylint: disable=redefined-builtin, redefined-outer-name
def cancel(cluster: str, all: bool, jobs: List[int], yes: bool, async_call: bool): # pylint: disable=redefined-builtin, redefined-outer-name
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Cancel job(s).
Expand Down Expand Up @@ -2061,7 +2068,8 @@ def cancel(cluster: str, all: bool, jobs: List[int], yes: bool): # pylint: disa
show_default=True)

try:
sdk.cancel(cluster, all=all, job_ids=job_ids_to_cancel)
request_id = sdk.cancel(cluster, all=all, job_ids=job_ids_to_cancel)
_async_call_or_wait(request_id, async_call, 'Cancel')
except exceptions.NotSupportedError as e:
controller = controller_utils.Controllers.from_name(cluster)
assert controller is not None, cluster
Expand Down Expand Up @@ -2090,11 +2098,13 @@ def cancel(cluster: str, all: bool, jobs: List[int], yes: bool): # pylint: disa
default=False,
required=False,
help='Skip confirmation prompt.')
@_add_click_options(_COMMON_OPTIONS)
@usage_lib.entrypoint
def stop(
clusters: List[str],
all: Optional[bool], # pylint: disable=redefined-builtin
yes: bool,
async_call: bool,
):
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Stop cluster(s).
Expand Down Expand Up @@ -2128,7 +2138,8 @@ def stop(
_down_or_stop_clusters(clusters,
apply_to_all=all,
down=False,
no_confirm=yes)
no_confirm=yes,
async_call=async_call)


@cli.command(cls=_DocumentedCodeCommand)
Expand Down Expand Up @@ -2168,6 +2179,7 @@ def stop(
default=False,
required=False,
help='Skip confirmation prompt.')
@_add_click_options(_COMMON_OPTIONS)
@usage_lib.entrypoint
def autostop(
clusters: List[str],
Expand All @@ -2176,6 +2188,7 @@ def autostop(
cancel: bool, # pylint: disable=redefined-outer-name
down: bool, # pylint: disable=redefined-outer-name
yes: bool,
async_call: bool,
):
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Schedule an autostop or autodown for cluster(s).
Expand Down Expand Up @@ -2230,7 +2243,8 @@ def autostop(
apply_to_all=all,
down=down,
no_confirm=yes,
idle_minutes_to_autostop=idle_minutes)
idle_minutes_to_autostop=idle_minutes,
async_call=async_call)


@cli.command(cls=_DocumentedCodeCommand)
Expand Down Expand Up @@ -2293,6 +2307,7 @@ def autostop(
required=False,
help=('Force start the cluster even if it is already UP. Useful for '
'upgrading the SkyPilot runtime on the cluster.'))
@_add_click_options(_COMMON_OPTIONS)
@usage_lib.entrypoint
# pylint: disable=redefined-builtin
def start(
Expand All @@ -2302,7 +2317,8 @@ def start(
idle_minutes_to_autostop: Optional[int],
down: bool, # pylint: disable=redefined-outer-name
retry_until_up: bool,
force: bool):
force: bool,
async_call: bool,):
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Restart cluster(s).
Expand Down Expand Up @@ -2336,12 +2352,13 @@ def start(
'--idle-minutes-to-autostop must be set if --down is set.')
to_start = []

cluster_records = None
if not clusters and not all:
# UX: frequently users may have only 1 cluster. In this case, be smart
# and default to that unique choice.
all_cluster_names = global_user_state.get_cluster_names_start_with('')
if len(all_cluster_names) <= 1:
clusters = all_cluster_names
all_clusters = _get_cluster_records(['*'], refresh=api_common.StatusRefreshMode.AUTO)
if len(all_clusters) <= 1:
cluster_records = all_clusters
else:
raise click.UsageError(
'`sky start` requires either a cluster name or glob '
Expand All @@ -2352,10 +2369,12 @@ def start(
click.echo('Both --all and cluster(s) specified for sky start. '
'Letting --all take effect.')

all_clusters = _get_cluster_records(['*'], refresh=api_common.StatusRefreshMode.AUTO)

# Get all clusters that are not controllers.
clusters = [
cluster['name']
for cluster in global_user_state.get_clusters()
cluster_records = [
cluster
for cluster in all_clusters
if controller_utils.Controllers.from_name(cluster['name']) is None
]

Expand All @@ -2364,12 +2383,12 @@ def start(
'mean to use `sky launch` to provision a new cluster?')
return
else:
# Get GLOB cluster names
clusters = _get_glob_clusters(clusters)

for name in clusters:
cluster_status, _ = backend_utils.refresh_cluster_status_handle(
name)
if cluster_records is None:
# Get GLOB cluster names
cluster_records = _get_cluster_records(clusters, refresh=api_common.StatusRefreshMode.AUTO)
for cluster in cluster_records:
name = cluster['name']
cluster_status = cluster['status']
# A cluster may have one of the following states:
#
# STOPPED - ok to restart
Expand Down Expand Up @@ -2449,18 +2468,24 @@ def start(
abort=True,
show_default=True)

for name in to_start:
try:
sdk.start(name,
request_ids = subprocess_utils.run_in_parallel(
lambda name: sdk.start(name,
idle_minutes_to_autostop,
retry_until_up,
down=down,
force=force)
force=force),
to_start
)

for name, request_id in zip(to_start, request_ids):
try:
_async_call_or_wait(request_id, async_call, 'Start')
except (exceptions.NotSupportedError,
exceptions.ClusterOwnerIdentityMismatchError) as e:
click.echo(str(e))
else:
click.secho(f'Cluster {name} started.', fg='green')
if not async_call:
click.secho(f'Cluster {name} started.', fg='green')


@cli.command(cls=_DocumentedCodeCommand)
Expand Down Expand Up @@ -2643,7 +2668,8 @@ def _down_or_stop_clusters(
down: bool, # pylint: disable=redefined-outer-name
no_confirm: bool,
purge: bool = False,
idle_minutes_to_autostop: Optional[int] = None) -> None:
idle_minutes_to_autostop: Optional[int] = None,
async_call: bool=False) -> None:
"""Tears down or (auto-)stops a cluster (or all clusters).
Controllers (jobs controller and sky serve controller) can only be
Expand All @@ -2660,9 +2686,9 @@ def _down_or_stop_clusters(
# UX: frequently users may have only 1 cluster. In this case, 'sky
# stop/down' without args should be smart and default to that unique
# choice.
all_cluster_names = global_user_state.get_cluster_names_start_with('')
if len(all_cluster_names) <= 1:
names = all_cluster_names
all_clusters = _get_cluster_records(['*'])
if len(all_clusters) <= 1:
names = [cluster['name'] for cluster in all_clusters]
else:
raise click.UsageError(
f'`sky {command}` requires either a cluster name or glob '
Expand All @@ -2684,8 +2710,8 @@ def _down_or_stop_clusters(
]
controllers_str = ', '.join(map(repr, controllers))
names = [
name for name in _get_glob_clusters(names)
if controller_utils.Controllers.from_name(name) is None
cluster['name'] for cluster in _get_cluster_records(names)
if controller_utils.Controllers.from_name(cluster['name']) is None
]

# Make sure the controllers are explicitly specified without other
Expand Down Expand Up @@ -2743,7 +2769,7 @@ def _down_or_stop_clusters(
names += controllers

if apply_to_all:
all_clusters = global_user_state.get_clusters()
all_clusters = _get_cluster_records(['*'])
if len(names) > 0:
click.echo(
f'Both --all and cluster(s) specified for `sky {command}`. '
Expand All @@ -2756,15 +2782,7 @@ def _down_or_stop_clusters(
if controller_utils.Controllers.from_name(record['name']) is None
]

clusters = []
for name in names:
handle = global_user_state.get_handle_from_cluster_name(name)
if handle is None:
# This codepath is used for 'sky down -p <controller>' when the
# controller is not in 'sky status'. Cluster-not-found message
# should've been printed by _get_glob_clusters() above.
continue
clusters.append(name)
clusters = names
usage_lib.record_cluster_name_for_current_operation(clusters)

if not clusters:
Expand All @@ -2789,11 +2807,14 @@ def _down_or_stop_clusters(
f'[bold cyan]{operation} {len(clusters)} cluster{plural}[/]',
total=len(clusters))

request_ids = []
def _down_or_stop(name: str):
success_progress = False
if idle_minutes_to_autostop is not None:
try:
sdk.autostop(name, idle_minutes_to_autostop, down)
request_id = sdk.autostop(name, idle_minutes_to_autostop, down)
request_ids.append(request_id)
_async_call_or_wait(request_id, async_call, operation.capitalize())
except (exceptions.NotSupportedError,
exceptions.ClusterNotUpError) as e:
message = str(e)
Expand All @@ -2816,9 +2837,11 @@ def _down_or_stop(name: str):
else:
try:
if down:
sdk.get(sdk.down(name, purge=purge))
request_id = sdk.down(name, purge=purge)
else:
sdk.get(sdk.stop(name, purge=purge))
request_id = sdk.stop(name, purge=purge)
request_ids.append(request_id)
_async_call_or_wait(request_id, async_call, operation.capitalize())
except RuntimeError as e:
message = (
f'{colorama.Fore.RED}{operation} cluster {name}...failed. '
Expand Down Expand Up @@ -2849,6 +2872,10 @@ def _down_or_stop(name: str):
# Make sure the progress bar not mess up the terminal.
progress.refresh()

if async_call:
click.secho(f'--async is passed, and {operation} requests are sent, '
'but some may fail at the background. Check the requests '
'with their IP')

@cli.command(cls=_DocumentedCodeCommand)
@click.argument('clouds', required=False, type=str, nargs=-1)
Expand Down
13 changes: 13 additions & 0 deletions sky/api/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Common data structures and constants used in the API."""
import enum

class StatusRefreshMode(enum.Enum):
    """How the API server should refresh cluster status for a query."""

    # Do not refresh; serve the cached status as-is.
    NONE = 'NONE'
    # Refresh only when needed, e.g., when autostop is set or the cluster
    # uses spot instances.
    AUTO = 'AUTO'
    # Always refresh the status, regardless of the cluster's configuration.
    FORCE = 'FORCE'


Loading

0 comments on commit ea06a6f

Please sign in to comment.