Skip to content

Commit

Permalink
Add more CLIs
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Jul 31, 2024
1 parent b2e9654 commit 4f44f15
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 10 deletions.
21 changes: 11 additions & 10 deletions sky/api/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1613,7 +1613,8 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool,
if clusters:
query_clusters = _get_glob_clusters(clusters, silent=ip)
request = sdk.status(cluster_names=query_clusters, refresh=refresh)
cluster_records = sdk.get(request)
cluster_records = sdk.stream_and_get(request)
# TOOD(zhwu): setup the ssh config for status
if ip or show_endpoints:
if len(cluster_records) != 1:
with ux_utils.print_exception_no_traceback():
Expand Down Expand Up @@ -1807,7 +1808,7 @@ def cost_report(all: bool): # pylint: disable=redefined-builtin
- Clusters that were terminated/stopped on the cloud console.
"""
cluster_records = sdk.cost_report()
cluster_records = sdk.get(sdk.cost_report())

normal_cluster_records = []
controllers = dict()
Expand Down Expand Up @@ -1971,16 +1972,16 @@ def logs(
return

assert job_ids is None or len(job_ids) <= 1, job_ids
job_id: Optional[int] = None
job_id: int
job_ids_to_query: Optional[List[int]] = None
if job_ids:
# Already check that len(job_ids) <= 1. This variable is used later
# in sdk.tail_logs.
job_id = job_ids[0]
if not job_id.isdigit():
raise click.UsageError(f'Invalid job ID {job_id}. '
cur_job_id = job_ids[0]
if not cur_job_id.isdigit():
raise click.UsageError(f'Invalid job ID {cur_job_id}. '
'Job ID must be integers.')
job_id = int(job_ids[0])
job_id = int(cur_job_id)
job_ids_to_query = [int(job_ids[0])]
else:
# job_ids is either None or empty list, so it is safe to cast it here.
Expand Down Expand Up @@ -3328,7 +3329,7 @@ def storage():
# pylint: disable=redefined-builtin
def storage_ls(all: bool):
"""List storage objects managed by SkyPilot."""
storages = sky.storage_ls()
storages = sdk.storage_ls()
storage_table = storage_utils.format_storage_table(storages, show_all=all)
click.echo(storage_table)

Expand Down Expand Up @@ -3371,7 +3372,7 @@ def storage_delete(names: List[str], all: bool, yes: bool): # pylint: disable=r
if sum([len(names) > 0, all]) != 1:
raise click.UsageError('Either --all or a name must be specified.')
if all:
storages = sky.storage_ls()
storages = sdk.get(sdk.storage_ls())
if not storages:
click.echo('No storage(s) to delete.')
return
Expand All @@ -3389,7 +3390,7 @@ def storage_delete(names: List[str], all: bool, yes: bool): # pylint: disable=r
abort=True,
show_default=True)

subprocess_utils.run_in_parallel(sky.storage_delete, names)
subprocess_utils.run_in_parallel(sdk.storage_delete, names)


@cli.group(cls=_NaturalOrderGroup)
Expand Down
4 changes: 4 additions & 0 deletions sky/api/requests/payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,7 @@ class RequestIdBody(pydantic.BaseModel):
class EndpointBody(pydantic.BaseModel):
cluster_name: str
port: Optional[Union[int, str]] = None


class costReportBody(pydantic.BaseModel):
all: bool
11 changes: 11 additions & 0 deletions sky/api/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,17 @@ async def logs(request: fastapi.Request,
)


@app.get('/cost-report')
async def cost_report(request: fastapi.Request,
cost_report_body: payloads.CostReportBody) -> None:
_start_background_request(
request_id=request.state.request_id,
request_name='cost_report',
request_body=cost_report_body.model_dump(),
func=core.cost_report,
all=cost_report_body.all,
)

@app.get('/storage/ls')
async def storage_ls(request: fastapi.Request):
_start_background_request(
Expand Down
28 changes: 28 additions & 0 deletions sky/api/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,34 @@ def endpoints(cluster_name: str, port: Optional[Union[int, str]] = None) -> str:
json=body.model_dump())
return _get_request_id(response)

@usage_lib.entrypoint
@_check_health
def cost_report(all: bool) -> str:
body = payloads.CostReportBody(all=all)
response = requests.get(f'{_get_server_url()}/cost_report',
json=body.model_dump())
return _get_request_id(response)



# === Storage APIs ===
@usage_lib.entrypoint
@_check_health
def storage_ls() -> str:
response = requests.get(f'{_get_server_url()}/storage/ls')
return _get_request_id(response)


@usage_lib.entrypoint
@_check_health
def storage_delete(name: str) -> str:
body = payloads.StorageBody(name=name)
response = requests.post(f'{_get_server_url()}/storage/delete',
json=body.model_dump())
return _get_request_id(response)


# === API request API ===

@usage_lib.entrypoint
@_check_health
Expand Down

0 comments on commit 4f44f15

Please sign in to comment.