From ea06a6faeeff2e01810e770a6593024f9a8ae763 Mon Sep 17 00:00:00 2001
From: Zhanghao Wu
Date: Thu, 1 Aug 2024 22:49:06 +0000
Subject: [PATCH] Fixes for CLIs

---
 docs/source/reference/config.rst |   8 ++
 sky/api/cli.py                   | 143 ++++++++++++++++++-------------
 sky/api/common.py                |  13 +++
 sky/api/requests/decoders.py     |   4 +
 sky/api/requests/encoders.py     |   5 ++
 sky/api/requests/payloads.py     |   5 +-
 sky/api/rest.py                  |   2 +-
 sky/api/sdk.py                   |  20 +++--
 sky/skylet/constants.py          |   2 +-
 sky/utils/schemas.py             |  14 +++
 10 files changed, 145 insertions(+), 71 deletions(-)
 create mode 100644 sky/api/common.py

diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst
index fc5eddd6a47..9f4481536d1 100644
--- a/docs/source/reference/config.rst
+++ b/docs/source/reference/config.rst
@@ -17,6 +17,14 @@ Spec: ``~/.sky/config.yaml``
 Available fields and semantics:
 
 .. code-block:: yaml
+  # Endpoint of the SkyPilot API server (optional).
+  #
+  # The CLI and SDK connect to this endpoint instead of a local API server.
+  #
+  # Default: null (use the local endpoint, which SkyPilot starts
+  # automatically).
+  api_server:
+    endpoint: http://xx.xx.xx.xx:8000
 
   # Custom managed jobs controller resources (optional).
   #
diff --git a/sky/api/cli.py b/sky/api/cli.py
index e55b71ed781..73530baf29b 100644
--- a/sky/api/cli.py
+++ b/sky/api/cli.py
@@ -55,6 +55,7 @@
 from sky import sky_logging
 from sky.adaptors import common as adaptors_common
+from sky.api import common as api_common
 from sky.api import sdk as sdk_lib
 from sky.backends import backend_utils
 from sky.benchmark import benchmark_state
 from sky.benchmark import benchmark_utils
@@ -113,12 +114,13 @@
 sdk = sdk_lib
 
 
-def _get_glob_clusters(clusters: List[str], silent: bool = False) -> List[str]:
-    """Returns a list of clusters that match the glob pattern."""
-    request_id = sdk.status(clusters)
+def _get_cluster_records(
+    clusters: List[str],
+    refresh: api_common.StatusRefreshMode = api_common.StatusRefreshMode.NONE
+) -> List[dict]:
+    """Returns the full records of clusters matching the glob patterns."""
+    # TODO(zhwu): this additional RTT makes CLIs slow. We should optimize this.
+    request_id = sdk.status(clusters, refresh=refresh)
     cluster_records = sdk.get(request_id)
-    clusters = [record['name'] for record in cluster_records]
-    return clusters
+    return cluster_records
 
 
 def _get_glob_storages(storages: List[str]) -> List[str]:
@@ -147,6 +149,12 @@ def _parse_env_var(env_var: str) -> Tuple[str, str]:
                          'or KEY.')
     return ret[0], ret[1]
 
+
+def _async_call_or_wait(request_id: str, async_call: bool,
+                        request_name: str) -> None:
+    if async_call:
+        click.secho(f'Submitted {request_name} request: {request_id}',
+                    fg='green')
+    else:
+        sdk.stream_and_get(request_id)
+
 
 def _merge_env_vars(env_dict: Optional[Dict[str, str]],
                     env_list: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
@@ -1087,18 +1095,15 @@ def launch(
         backend,
         cluster,
         dryrun=dryrun,
-        detach_setup=detach_setup,
-        detach_run=detach_run,
+        detach_setup=detach_setup or async_call,
+        detach_run=detach_run or async_call,
         no_confirm=yes,
         idle_minutes_to_autostop=idle_minutes_to_autostop,
         down=down,
         retry_until_up=retry_until_up,
         no_setup=no_setup,
         clone_disk_from=clone_disk_from)
-    if not async_call:
-        sdk.stream_and_get(request_id)
-    else:
-        click.secho(f'Submitted Launch request: {request_id}', fg='green')
+    _async_call_or_wait(request_id, async_call, 'Launch')
 
 
 @cli.command(cls=_DocumentedCodeCommand)
@@ -1241,11 +1246,8 @@ def exec(cluster: Optional[str], cluster_option: Optional[str],
     click.secho(f'Executing task on cluster {cluster}...', fg='yellow')
     request_id = sdk.exec(task,
                           cluster_name=cluster,
-                          detach_run=detach_run)
-    if not async_call:
-        sdk.stream_and_get(request_id)
-    else:
-        click.secho(f'Submitted Exec request: {request_id}', fg='green')
+                          detach_run=detach_run or async_call)
+    _async_call_or_wait(request_id, async_call, 'Exec')
 
 
 def _get_managed_jobs(
@@ -1593,7 +1595,10 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool,
         click.echo(f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}Clusters'
                    f'{colorama.Style.RESET_ALL}')
     query_clusters: Optional[List[str]] = None if not clusters else clusters
-    request = sdk.status(cluster_names=query_clusters, refresh=refresh)
+    refresh_mode = api_common.StatusRefreshMode.NONE
+    if refresh:
+        refresh_mode = api_common.StatusRefreshMode.FORCE
+    request = sdk.status(cluster_names=query_clusters, refresh=refresh_mode)
     cluster_records = sdk.stream_and_get(request)
-    # TOOD(zhwu): setup the ssh config for status
+    # TODO(zhwu): set up the ssh config for status.
     if ip or show_endpoints:
@@ -1859,7 +1864,8 @@ def queue(clusters: List[str], skip_finished: bool, all_users: bool):
     click.secho('Fetching and parsing job queue...', fg='yellow')
     if not clusters:
         clusters = ['*']
-    clusters = _get_glob_clusters(clusters)
+    cluster_records = _get_cluster_records(clusters)
+    clusters = [cluster['name'] for cluster in cluster_records]
 
     unsupported_clusters = []
     logger.info(f'Fetching job queue for {clusters}')
@@ -2013,9 +2019,10 @@ def logs(
               default=False,
               required=False,
               help='Skip confirmation prompt.')
+@_add_click_options(_COMMON_OPTIONS)
 @click.argument('jobs', required=False, type=int, nargs=-1)
 @usage_lib.entrypoint
-def cancel(cluster: str, all: bool, jobs: List[int], yes: bool):  # pylint: disable=redefined-builtin, redefined-outer-name
+def cancel(cluster: str, all: bool, jobs: List[int], yes: bool,
+           async_call: bool):  # pylint: disable=redefined-builtin, redefined-outer-name
     # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
     """Cancel job(s).
@@ -2061,7 +2068,8 @@ def cancel(cluster: str, all: bool, jobs: List[int], yes: bool):  # pylint: disa
                       show_default=True)
 
     try:
-        sdk.cancel(cluster, all=all, job_ids=job_ids_to_cancel)
+        request_id = sdk.cancel(cluster, all=all, job_ids=job_ids_to_cancel)
+        _async_call_or_wait(request_id, async_call, 'Cancel')
     except exceptions.NotSupportedError as e:
         controller = controller_utils.Controllers.from_name(cluster)
         assert controller is not None, cluster
@@ -2090,11 +2098,13 @@ def cancel(cluster: str, all: bool, jobs: List[int], yes: bool):  # pylint: disa
               default=False,
               required=False,
               help='Skip confirmation prompt.')
+@_add_click_options(_COMMON_OPTIONS)
 @usage_lib.entrypoint
 def stop(
         clusters: List[str],
         all: Optional[bool],  # pylint: disable=redefined-builtin
         yes: bool,
+        async_call: bool,
 ):
     # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
     """Stop cluster(s).
@@ -2128,7 +2138,8 @@ def stop(
     _down_or_stop_clusters(clusters,
                            apply_to_all=all,
                            down=False,
-                           no_confirm=yes)
+                           no_confirm=yes,
+                           async_call=async_call)
 
 
 @cli.command(cls=_DocumentedCodeCommand)
@@ -2168,6 +2179,7 @@ def stop(
               default=False,
               required=False,
               help='Skip confirmation prompt.')
+@_add_click_options(_COMMON_OPTIONS)
 @usage_lib.entrypoint
 def autostop(
         clusters: List[str],
@@ -2176,6 +2188,7 @@ def autostop(
         cancel: bool,  # pylint: disable=redefined-outer-name
         down: bool,  # pylint: disable=redefined-outer-name
         yes: bool,
+        async_call: bool,
 ):
     # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
     """Schedule an autostop or autodown for cluster(s).
@@ -2230,7 +2243,8 @@ def autostop(
                            apply_to_all=all,
                            down=down,
                            no_confirm=yes,
-                           idle_minutes_to_autostop=idle_minutes)
+                           idle_minutes_to_autostop=idle_minutes,
+                           async_call=async_call)
 
 
 @cli.command(cls=_DocumentedCodeCommand)
@@ -2293,6 +2307,7 @@ def autostop(
               required=False,
               help=('Force start the cluster even if it is already UP. Useful for '
                     'upgrading the SkyPilot runtime on the cluster.'))
+@_add_click_options(_COMMON_OPTIONS)
 @usage_lib.entrypoint
 # pylint: disable=redefined-builtin
 def start(
@@ -2302,7 +2317,8 @@ def start(
         clusters: List[str],
         all: bool,
         yes: bool,
         idle_minutes_to_autostop: Optional[int],
         down: bool,  # pylint: disable=redefined-outer-name
         retry_until_up: bool,
-        force: bool):
+        force: bool,
+        async_call: bool):
     # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
     """Restart cluster(s).
@@ -2336,12 +2352,13 @@ def start(
             '--idle-minutes-to-autostop must be set if --down is set.')
 
     to_start = []
+    cluster_records = None
     if not clusters and not all:
         # UX: frequently users may have only 1 cluster. In this case, be smart
         # and default to that unique choice.
-        all_cluster_names = global_user_state.get_cluster_names_start_with('')
-        if len(all_cluster_names) <= 1:
-            clusters = all_cluster_names
+        all_clusters = _get_cluster_records(
+            ['*'], refresh=api_common.StatusRefreshMode.AUTO)
+        if len(all_clusters) <= 1:
+            cluster_records = all_clusters
         else:
             raise click.UsageError(
                 '`sky start` requires either a cluster name or glob '
@@ -2352,10 +2369,12 @@ def start(
             click.echo('Both --all and cluster(s) specified for sky start. '
                        'Letting --all take effect.')
 
+        all_clusters = _get_cluster_records(
+            ['*'], refresh=api_common.StatusRefreshMode.AUTO)
+
+        # Get all clusters that are not controllers.
-        clusters = [
-            cluster['name']
-            for cluster in global_user_state.get_clusters()
+        cluster_records = [
+            cluster for cluster in all_clusters
             if controller_utils.Controllers.from_name(cluster['name']) is None
         ]
 
@@ -2364,12 +2383,12 @@ def start(
               'mean to use `sky launch` to provision a new cluster?')
         return
     else:
-        # Get GLOB cluster names
-        clusters = _get_glob_clusters(clusters)
-
-    for name in clusters:
-        cluster_status, _ = backend_utils.refresh_cluster_status_handle(
-            name)
+        if cluster_records is None:
+            # Get GLOB cluster names.
+            cluster_records = _get_cluster_records(
+                clusters, refresh=api_common.StatusRefreshMode.AUTO)
+
+    for cluster in cluster_records:
+        name = cluster['name']
+        cluster_status = cluster['status']
         # A cluster may have one of the following states:
         #
         #   STOPPED - ok to restart
@@ -2449,18 +2468,24 @@ def start(
                       abort=True,
                       show_default=True)
 
-    for name in to_start:
-        try:
-            sdk.start(name,
-                      idle_minutes_to_autostop,
-                      retry_until_up,
-                      down=down,
-                      force=force)
+    request_ids = subprocess_utils.run_in_parallel(
+        lambda name: sdk.start(name,
+                               idle_minutes_to_autostop,
+                               retry_until_up,
+                               down=down,
+                               force=force), to_start)
+
+    for name, request_id in zip(to_start, request_ids):
+        try:
+            _async_call_or_wait(request_id, async_call, 'Start')
         except (exceptions.NotSupportedError,
                 exceptions.ClusterOwnerIdentityMismatchError) as e:
             click.echo(str(e))
         else:
-            click.secho(f'Cluster {name} started.', fg='green')
+            if not async_call:
+                click.secho(f'Cluster {name} started.', fg='green')
 
 
 @cli.command(cls=_DocumentedCodeCommand)
@@ -2643,7 +2668,8 @@ def _down_or_stop_clusters(
         down: bool,  # pylint: disable=redefined-outer-name
         no_confirm: bool,
         purge: bool = False,
-        idle_minutes_to_autostop: Optional[int] = None) -> None:
+        idle_minutes_to_autostop: Optional[int] = None,
+        async_call: bool = False) -> None:
     """Tears down or (auto-)stops a cluster (or all clusters).
 
     Controllers (jobs controller and sky serve controller) can only be
@@ -2660,9 +2686,9 @@ def _down_or_stop_clusters(
         # UX: frequently users may have only 1 cluster. In this case, 'sky
         # stop/down' without args should be smart and default to that unique
         # choice.
-        all_cluster_names = global_user_state.get_cluster_names_start_with('')
-        if len(all_cluster_names) <= 1:
-            names = all_cluster_names
+        all_clusters = _get_cluster_records(['*'])
+        if len(all_clusters) <= 1:
+            names = [cluster['name'] for cluster in all_clusters]
         else:
             raise click.UsageError(
                 f'`sky {command}` requires either a cluster name or glob '
@@ -2684,8 +2710,8 @@ def _down_or_stop_clusters(
         ]
         controllers_str = ', '.join(map(repr, controllers))
         names = [
-            name for name in _get_glob_clusters(names)
-            if controller_utils.Controllers.from_name(name) is None
+            cluster['name'] for cluster in _get_cluster_records(names)
+            if controller_utils.Controllers.from_name(cluster['name']) is None
         ]
 
     # Make sure the controllers are explicitly specified without other
@@ -2743,7 +2769,7 @@ def _down_or_stop_clusters(
         names += controllers
 
     if apply_to_all:
-        all_clusters = global_user_state.get_clusters()
+        all_clusters = _get_cluster_records(['*'])
         if len(names) > 0:
             click.echo(
                 f'Both --all and cluster(s) specified for `sky {command}`. '
@@ -2756,15 +2782,7 @@ def _down_or_stop_clusters(
         if controller_utils.Controllers.from_name(record['name']) is None
     ]
 
-    clusters = []
-    for name in names:
-        handle = global_user_state.get_handle_from_cluster_name(name)
-        if handle is None:
-            # This codepath is used for 'sky down -p <controller>' when the
-            # controller is not in 'sky status'. Cluster-not-found message
-            # should've been printed by _get_glob_clusters() above.
-            continue
-        clusters.append(name)
+    clusters = names
     usage_lib.record_cluster_name_for_current_operation(clusters)
 
     if not clusters:
@@ -2789,11 +2807,14 @@ def _down_or_stop_clusters(
             f'[bold cyan]{operation} {len(clusters)} cluster{plural}[/]',
             total=len(clusters))
 
     def _down_or_stop(name: str):
         success_progress = False
         if idle_minutes_to_autostop is not None:
             try:
-                sdk.autostop(name, idle_minutes_to_autostop, down)
+                request_id = sdk.autostop(name, idle_minutes_to_autostop,
+                                          down)
+                _async_call_or_wait(request_id, async_call,
+                                    operation.capitalize())
             except (exceptions.NotSupportedError,
                     exceptions.ClusterNotUpError) as e:
                 message = str(e)
@@ -2816,9 +2837,11 @@ def _down_or_stop(name: str):
         else:
             try:
                 if down:
-                    sdk.get(sdk.down(name, purge=purge))
+                    request_id = sdk.down(name, purge=purge)
                 else:
-                    sdk.get(sdk.stop(name, purge=purge))
+                    request_id = sdk.stop(name, purge=purge)
+                _async_call_or_wait(request_id, async_call,
+                                    operation.capitalize())
             except RuntimeError as e:
                 message = (
                     f'{colorama.Fore.RED}{operation} cluster {name}...failed. '
@@ -2849,6 +2872,10 @@ def _down_or_stop(name: str):
-        # Make sure the progress bar not mess up the terminal.
+        # Make sure the progress bar does not mess up the terminal.
         progress.refresh()
 
+    if async_call:
+        click.secho(f'--async was passed; {operation} requests have been '
+                    'submitted, but they may still fail in the background. '
+                    'Check the requests using the request IDs printed above.')
+
 
 @cli.command(cls=_DocumentedCodeCommand)
 @click.argument('clouds', required=False, type=str, nargs=-1)
diff --git a/sky/api/common.py b/sky/api/common.py
new file mode 100644
index 00000000000..b3d8a2fd28b
--- /dev/null
+++ b/sky/api/common.py
@@ -0,0 +1,13 @@
+"""Common data structures and constants used in the API."""
+import enum
+
+
+class StatusRefreshMode(enum.Enum):
+    """The mode of refreshing the status of a cluster."""
+
+    # Do not refresh any cluster status.
+    NONE = 'NONE'
+    # Automatically refresh when needed, e.g., autostop is set or the cluster
+    # is a spot instance.
+    AUTO = 'AUTO'
+    FORCE = 'FORCE'
diff --git a/sky/api/requests/decoders.py b/sky/api/requests/decoders.py
index 9d8d07dc3c1..4803eadfa85 100644
--- a/sky/api/requests/decoders.py
+++ b/sky/api/requests/decoders.py
@@ -52,6 +52,10 @@ def decode_launch(return_value: Dict[str, Any]) -> Dict[str, Any]:
         'handle': decode_and_unpickle(return_value['handle']),
     }
 
+
+@register_handler('start')
+def decode_start(return_value: str) -> 'backends.CloudVmRayResourceHandle':
+    return decode_and_unpickle(return_value)
+
 
 @register_handler('queue')
 def decode_queue(
     return_value: List[dict],
diff --git a/sky/api/requests/encoders.py b/sky/api/requests/encoders.py
index 9060f3413fb..13b3d362b89 100644
--- a/sky/api/requests/encoders.py
+++ b/sky/api/requests/encoders.py
@@ -60,6 +60,11 @@ def encode_launch(
     }
 
 
+@register_handler('start')
+def encode_start(resource_handle: 'backends.CloudVmRayResourceHandle') -> str:
+    return pickle_and_encode(resource_handle)
+
+
 @register_handler('queue')
 def encode_queue(
     jobs: List[dict],
diff --git a/sky/api/requests/payloads.py b/sky/api/requests/payloads.py
index 848c854aca8..af593417ed4 100644
--- a/sky/api/requests/payloads.py
+++ b/sky/api/requests/payloads.py
@@ -4,6 +4,7 @@
 import pydantic
 
 from sky import optimizer
+from sky.api import common
 
 
 class RequestBody(pydantic.BaseModel):
@@ -58,7 +59,7 @@ class StopOrDownBody(pydantic.BaseModel):
 
 class StatusBody(pydantic.BaseModel):
     cluster_names: Optional[List[str]] = None
-    refresh: bool = False
+    refresh: common.StatusRefreshMode = common.StatusRefreshMode.NONE
 
 
 class StartBody(RequestBody):
@@ -83,7 +84,7 @@ class QueueBody(pydantic.BaseModel):
 
 class CancelBody(pydantic.BaseModel):
     cluster_name: str
-    job_ids: List[int]
+    job_ids: Optional[List[int]] = None
     all: bool = False
     # Internal only:
     try_cancel_if_cluster_is_init: bool = False
diff --git a/sky/api/rest.py b/sky/api/rest.py
index 866820a504a..926f9c927c3 100644
--- a/sky/api/rest.py
+++ b/sky/api/rest.py
@@ -482,7 +482,7 @@ async def job_status(request: fastapi.Request,
 
 @app.post('/cancel')
 async def cancel(request: fastapi.Request,
-                 cancel_body: payloads.QueueBody) -> None:
+                 cancel_body: payloads.CancelBody) -> None:
     _start_background_request(
         request_id=request.state.request_id,
         request_name='cancel',
diff --git a/sky/api/sdk.py b/sky/api/sdk.py
index 1e89a3deac7..6906fceba44 100644
--- a/sky/api/sdk.py
+++ b/sky/api/sdk.py
@@ -28,6 +28,8 @@
 from sky import sky_logging
+from sky import skypilot_config
+from sky.api import common
 from sky.api.requests import payloads
 from sky.api.requests import tasks
 from sky.api.requests import constants as requests_constants
 from sky.backends import backend_utils
 from sky.data import data_utils
@@ -53,7 +55,8 @@
 @functools.lru_cache()
 def _get_server_url():
-    return os.environ.get(constants.SKY_API_SERVER_URL_ENV_VAR,
-                          DEFAULT_SERVER_URL)
+    return os.environ.get(
+        constants.SKY_API_SERVER_URL_ENV_VAR,
+        skypilot_config.get_nested(('api_server', 'endpoint'),
+                                   DEFAULT_SERVER_URL))
 
 
 @functools.lru_cache()
@@ -129,14 +132,14 @@ def wrapper(*args, api_server_reload: bool = False, **kwargs):
 
     return wrapper
 
 
-def _add_env_vars_to_body(body: payloads.RequestBody):
+@functools.lru_cache()
+def _request_body_env_vars() -> dict:
     env_vars = {}
     for env_var in os.environ:
         if env_var.startswith('SKYPILOT_'):
             env_vars[env_var] = os.environ[env_var]
     env_vars[constants.USER_ID_ENV_VAR] = common_utils.get_user_hash()
-    body.env_vars = env_vars
+    return env_vars
 
 
 @usage_lib.entrypoint
@@ -156,7 +159,6 @@ def optimize(dag: 'sky.Dag') -> str:
         dag_str = f.read()
 
     body = payloads.OptimizeBody(dag=dag_str)
-    # _add_env_vars_to_body(body)
     response = requests.get(f'{_get_server_url()}/optimize',
                             json=json.loads(body.model_dump_json()))
     return _get_request_id(response)
@@ -296,9 +298,8 @@ def launch(
         is_launched_by_jobs_controller=_is_launched_by_jobs_controller,
         is_launched_by_sky_serve_controller=_is_launched_by_sky_serve_controller,
         disable_controller_check=_disable_controller_check,
+        env_vars=_request_body_env_vars(),
     )
-
-    _add_env_vars_to_body(body)
     response = requests.post(
         f'{_get_server_url()}/launch',
         json=json.loads(body.model_dump_json()),
@@ -329,6 +330,7 @@ def exec(  # pylint: disable=redefined-builtin
         down=down,
         backend=backend.NAME if backend else None,
         detach_run=detach_run,
+        env_vars=_request_body_env_vars(),
     )
 
     response = requests.post(
@@ -380,6 +382,7 @@ def start(
         retry_until_up=retry_until_up,
         down=down,
         force=force,
+        env_vars=_request_body_env_vars(),
     )
     response = requests.post(
         f'{_get_server_url()}/start',
@@ -516,11 +519,10 @@ def cancel(
                              json=json.loads(body.model_dump_json()))
     return _get_request_id(response)
 
 
 @usage_lib.entrypoint
 @_check_health
 def status(cluster_names: Optional[List[str]] = None,
-           refresh: bool = False) -> str:
+           refresh: common.StatusRefreshMode = common.StatusRefreshMode.NONE
+          ) -> str:
     """Get the status of clusters.
 
     Args:
diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py
index 4d8d6e8c68e..48164e473f9 100644
--- a/sky/skylet/constants.py
+++ b/sky/skylet/constants.py
@@ -258,7 +258,7 @@
 
-# The name for the environment variable that stores the URL of the SkyPilot
-# API server.
-SKY_API_SERVER_URL_ENV_VAR = 'SKYPILOT_API_SERVER_URL'
+# The name for the environment variable that stores the endpoint of the
+# SkyPilot API server.
+SKY_API_SERVER_URL_ENV_VAR = 'SKYPILOT_API_SERVER_ENDPOINT'
 
 # SkyPilot environment variables
 SKYPILOT_NUM_NODES = 'SKYPILOT_NUM_NODES'
diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py
index 860e7dad431..fe2b5a9f197 100644
--- a/sky/utils/schemas.py
+++ b/sky/utils/schemas.py
@@ -880,6 +880,19 @@ def get_config_schema():
         }
     }
 
+    api_server = {
+        'type': 'object',
+        'required': [],
+        'additionalProperties': False,
+        'properties': {
+            'endpoint': {
+                'type': 'string',
+                # Validate that the endpoint is an HTTP(S) URL.
+                'pattern': r'^https?://.*$',
+            },
+        }
+    }
+
     for cloud, config in cloud_configs.items():
         if cloud == 'aws':
             config['properties'].update({
@@ -901,6 +914,7 @@ def get_config_schema():
             'allowed_clouds': allowed_clouds,
             'docker': docker_configs,
             'nvidia_gpus': gpu_configs,
+            'api_server': api_server,
             **cloud_configs,
         },
         # Avoid spot and jobs being present at the same time.
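---

Note for reviewers: below is a minimal sketch (not part of the patch) of how
the new `StatusRefreshMode` enum and the request-ID flow introduced here are
expected to be used from the SDK. The cluster name is a placeholder, and the
sketch assumes this patch is applied.

    from sky.api import common
    from sky.api import sdk

    # Ask the server to force-refresh cluster status, mirroring
    # `sky status --refresh`. The SDK call returns a request ID immediately.
    request_id = sdk.status(cluster_names=['my-cluster'],
                            refresh=common.StatusRefreshMode.FORCE)

    # Synchronous path (default CLI behavior): stream logs and wait for the
    # result, as _async_call_or_wait() does when --async is not passed.
    cluster_records = sdk.stream_and_get(request_id)
    print([record['name'] for record in cluster_records])

    # Asynchronous path (--async): keep only the request ID and fetch the
    # result later.
    stop_request_id = sdk.stop('my-cluster')
    print(f'Submitted Stop request: {stop_request_id}')
    sdk.get(stop_request_id)  # Blocks until the request completes.

The same pattern applies to `sdk.launch`, `sdk.exec`, `sdk.start`,
`sdk.autostop`, and `sdk.cancel`: every entrypoint returns a request ID, and
the caller chooses whether to wait on it.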