diff --git a/sky/api/cli.py b/sky/api/cli.py index 7ca0c8eeb8e..1152e1cce76 100644 --- a/sky/api/cli.py +++ b/sky/api/cli.py @@ -1613,7 +1613,8 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool, if clusters: query_clusters = _get_glob_clusters(clusters, silent=ip) request = sdk.status(cluster_names=query_clusters, refresh=refresh) - cluster_records = sdk.get(request) + cluster_records = sdk.stream_and_get(request) + # TOOD(zhwu): setup the ssh config for status if ip or show_endpoints: if len(cluster_records) != 1: with ux_utils.print_exception_no_traceback(): @@ -1807,7 +1808,7 @@ def cost_report(all: bool): # pylint: disable=redefined-builtin - Clusters that were terminated/stopped on the cloud console. """ - cluster_records = sdk.cost_report() + cluster_records = sdk.get(sdk.cost_report()) normal_cluster_records = [] controllers = dict() @@ -1971,16 +1972,16 @@ def logs( return assert job_ids is None or len(job_ids) <= 1, job_ids - job_id: Optional[int] = None + job_id: int job_ids_to_query: Optional[List[int]] = None if job_ids: # Already check that len(job_ids) <= 1. This variable is used later # in sdk.tail_logs. - job_id = job_ids[0] - if not job_id.isdigit(): - raise click.UsageError(f'Invalid job ID {job_id}. ' + cur_job_id = job_ids[0] + if not cur_job_id.isdigit(): + raise click.UsageError(f'Invalid job ID {cur_job_id}. ' 'Job ID must be integers.') - job_id = int(job_ids[0]) + job_id = int(cur_job_id) job_ids_to_query = [int(job_ids[0])] else: # job_ids is either None or empty list, so it is safe to cast it here. @@ -3328,7 +3329,7 @@ def storage(): # pylint: disable=redefined-builtin def storage_ls(all: bool): """List storage objects managed by SkyPilot.""" - storages = sky.storage_ls() + storages = sdk.storage_ls() storage_table = storage_utils.format_storage_table(storages, show_all=all) click.echo(storage_table) @@ -3371,7 +3372,7 @@ def storage_delete(names: List[str], all: bool, yes: bool): # pylint: disable=r if sum([len(names) > 0, all]) != 1: raise click.UsageError('Either --all or a name must be specified.') if all: - storages = sky.storage_ls() + storages = sdk.get(sdk.storage_ls()) if not storages: click.echo('No storage(s) to delete.') return @@ -3389,7 +3390,7 @@ def storage_delete(names: List[str], all: bool, yes: bool): # pylint: disable=r abort=True, show_default=True) - subprocess_utils.run_in_parallel(sky.storage_delete, names) + subprocess_utils.run_in_parallel(sdk.storage_delete, names) @cli.group(cls=_NaturalOrderGroup) diff --git a/sky/api/requests/payloads.py b/sky/api/requests/payloads.py index a8f60f4d1a1..5ecce0fe9b8 100644 --- a/sky/api/requests/payloads.py +++ b/sky/api/requests/payloads.py @@ -99,3 +99,7 @@ class RequestIdBody(pydantic.BaseModel): class EndpointBody(pydantic.BaseModel): cluster_name: str port: Optional[Union[int, str]] = None + + +class costReportBody(pydantic.BaseModel): + all: bool diff --git a/sky/api/rest.py b/sky/api/rest.py index c9f19cef696..f80151d7e8f 100644 --- a/sky/api/rest.py +++ b/sky/api/rest.py @@ -475,6 +475,17 @@ async def logs(request: fastapi.Request, ) +@app.get('/cost-report') +async def cost_report(request: fastapi.Request, + cost_report_body: payloads.CostReportBody) -> None: + _start_background_request( + request_id=request.state.request_id, + request_name='cost_report', + request_body=cost_report_body.model_dump(), + func=core.cost_report, + all=cost_report_body.all, + ) + @app.get('/storage/ls') async def storage_ls(request: fastapi.Request): _start_background_request( diff --git a/sky/api/sdk.py b/sky/api/sdk.py index b800a008282..71d5badecd1 100644 --- a/sky/api/sdk.py +++ b/sky/api/sdk.py @@ -500,6 +500,34 @@ def endpoints(cluster_name: str, port: Optional[Union[int, str]] = None) -> str: json=body.model_dump()) return _get_request_id(response) +@usage_lib.entrypoint +@_check_health +def cost_report(all: bool) -> str: + body = payloads.CostReportBody(all=all) + response = requests.get(f'{_get_server_url()}/cost_report', + json=body.model_dump()) + return _get_request_id(response) + + + +# === Storage APIs === +@usage_lib.entrypoint +@_check_health +def storage_ls() -> str: + response = requests.get(f'{_get_server_url()}/storage/ls') + return _get_request_id(response) + + +@usage_lib.entrypoint +@_check_health +def storage_delete(name: str) -> str: + body = payloads.StorageBody(name=name) + response = requests.post(f'{_get_server_url()}/storage/delete', + json=body.model_dump()) + return _get_request_id(response) + + +# === API request API === @usage_lib.entrypoint @_check_health