Skip to content

Commit

Permalink
Fix queue
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Aug 1, 2024
1 parent e6867a9 commit d8f5698
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 34 deletions.
39 changes: 19 additions & 20 deletions sky/api/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,10 @@

def _get_glob_clusters(clusters: List[str], silent: bool = False) -> List[str]:
"""Returns a list of clusters that match the glob pattern."""
glob_clusters = []
for cluster in clusters:
glob_cluster = global_user_state.get_glob_cluster_names(cluster)
if len(glob_cluster) == 0 and not silent:
click.echo(f'Cluster {cluster} not found.')
glob_clusters.extend(glob_cluster)
return list(set(glob_clusters))
request_id = sdk.status(clusters)
cluster_records = sdk.get(request_id)
clusters = [record['name'] for record in cluster_records]
return clusters


def _get_glob_storages(storages: List[str]) -> List[str]:
Expand Down Expand Up @@ -1601,9 +1598,7 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool,
else:
click.echo(f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}Clusters'
f'{colorama.Style.RESET_ALL}')
query_clusters: Optional[List[str]] = None
if clusters:
query_clusters = _get_glob_clusters(clusters, silent=ip)
query_clusters: Optional[List[str]] = None if not clusters else clusters
request = sdk.status(cluster_names=query_clusters, refresh=refresh)
cluster_records = sdk.stream_and_get(request)
# TODO(zhwu): set up the SSH config for status
Expand Down Expand Up @@ -1868,16 +1863,16 @@ def queue(clusters: List[str], skip_finished: bool, all_users: bool):
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Show the job queue for cluster(s)."""
click.secho('Fetching and parsing job queue...', fg='yellow')
if clusters:
clusters = _get_glob_clusters(clusters)
else:
cluster_infos = global_user_state.get_clusters()
clusters = [c['name'] for c in cluster_infos]
if not clusters:
clusters = ['*']
clusters = _get_glob_clusters(clusters)

unsupported_clusters = []
for cluster in clusters:
logger.info(f'Fetching job queue for {clusters}')
job_tables = {}
def _get_job_queue(cluster):
try:
job_table = sdk.queue(cluster, skip_finished, all_users)
job_table = sdk.stream_and_get(sdk.queue(cluster, skip_finished, all_users))
except (RuntimeError, exceptions.CommandError, ValueError,
exceptions.NotSupportedError, exceptions.ClusterNotUpError,
exceptions.CloudUserIdentityError,
Expand All @@ -1887,8 +1882,11 @@ def queue(clusters: List[str], skip_finished: bool, all_users: bool):
click.echo(f'{colorama.Fore.YELLOW}Failed to get the job queue for '
f'cluster {cluster!r}.{colorama.Style.RESET_ALL}\n'
f' {common_utils.format_exception(e)}')
continue
job_table = job_lib.format_job_queue(job_table)
return
job_tables[cluster] = job_lib.format_job_queue(job_table)

subprocess_utils.run_in_parallel(_get_job_queue, clusters)
for cluster, job_table in job_tables.items():
click.echo(f'\nJob queue of cluster {cluster}\n{job_table}')

if unsupported_clusters:
Expand Down Expand Up @@ -2890,7 +2888,8 @@ def check(clouds: Tuple[str], verbose: bool):
sky check aws gcp
"""
clouds_arg = clouds if len(clouds) > 0 else None
sky_check.check(verbose=verbose, clouds=clouds_arg)
request_id = sdk.check(clouds=clouds_arg, verbose=verbose)
sdk.stream_and_get(request_id)


@cli.command()
Expand Down
10 changes: 10 additions & 0 deletions sky/api/requests/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Dict, List

from sky.utils import status_lib
from sky.skylet import job_lib

handlers: Dict[str, Any] = {}

Expand Down Expand Up @@ -50,3 +51,12 @@ def decode_launch(return_value: Dict[str, Any]) -> Dict[str, Any]:
'job_id': return_value['job_id'],
'handle': decode_and_unpickle(return_value['handle']),
}

@register_handler('queue')
def decode_queue(return_value: List[dict]) -> List[Dict[str, Any]]:
    """Decode the return value of a 'queue' request.

    Each job record arrives with its 'status' in serialized form; this
    converts it back into a ``job_lib.JobStatus`` member.

    Args:
        return_value: Job records as returned by the API server. Every
            record must contain a 'status' key whose value is accepted by
            the ``job_lib.JobStatus`` constructor.

    Returns:
        The same list of job records (updated in place), with each 'status'
        replaced by the corresponding ``job_lib.JobStatus``.
    """
    jobs = return_value
    for job in jobs:
        job['status'] = job_lib.JobStatus(job['status'])
    return jobs
12 changes: 12 additions & 0 deletions sky/api/requests/encoders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Handlers for the REST API return values."""
# TODO(zhwu): evaluate whether we can move our return values to pydantic
# models, so that we can take advantage of pydantic's model_dump_json
# instead of implementing our own handlers.
import base64
import pickle
import typing
Expand Down Expand Up @@ -55,3 +58,12 @@ def encode_launch(
'job_id': job_id,
'handle': pickle_and_encode(handle),
}


@register_handler('queue')
def encode_queue(jobs: List[dict]) -> List[Dict[str, Any]]:
    """Encode the return value of a 'queue' request.

    Replaces each job's 'status' object with its ``.value`` attribute.

    Args:
        jobs: Job records; every record must contain a 'status' key whose
            value exposes a ``.value`` attribute.

    Returns:
        A new list of job records with 'status' replaced by its ``.value``.
        The input records are left untouched so that the caller's objects
        are not modified as a side effect of encoding.
    """
    # Build shallow copies instead of overwriting 'status' in the caller's
    # dicts; all other fields are carried over unchanged.
    return [{**job, 'status': job['status'].value} for job in jobs]
7 changes: 6 additions & 1 deletion sky/api/requests/payloads.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Payloads for the Sky API requests."""
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import pydantic

Expand All @@ -10,6 +10,11 @@ class RequestBody(pydantic.BaseModel):
env_vars: Dict[str, str] = {}


class CheckBody(RequestBody):
    """Request body for the check request."""
    # Names of the clouds to check, or None. This must be a variable-length
    # tuple: ``Tuple[str]`` denotes a tuple of exactly one string, so it
    # would not describe a request naming two or more clouds.
    clouds: Optional[Tuple[str, ...]]
    # Verbosity flag carried along with the request.
    verbose: bool


class OptimizeBody(pydantic.BaseModel):
dag: str
minimize: optimizer.OptimizeTarget = optimizer.OptimizeTarget.COST
Expand Down
14 changes: 14 additions & 0 deletions sky/api/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import fastapi
import starlette.middleware.base

from sky import check as sky_check
from sky import core
from sky import execution
from sky import optimizer
Expand Down Expand Up @@ -178,6 +179,19 @@ async def startup():
func=event)


@app.get('/check')
async def check(request: fastapi.Request, check_body: payloads.CheckBody):
    """Check enabled clouds."""
    request_id = request.state.request_id
    # Dump the pydantic model to JSON and parse it back, which yields plain
    # JSON-compatible Python values for the stored request body.
    serialized_body = json.loads(check_body.model_dump_json())
    # Hand the actual check over to the background request machinery; the
    # endpoint itself returns nothing.
    _start_background_request(
        request_id=request_id,
        request_name='check',
        request_body=serialized_body,
        func=sky_check.check,
        clouds=check_body.clouds,
        verbose=check_body.verbose,
    )


@app.get('/optimize')
async def optimize(optimize_body: payloads.OptimizeBody,
request: fastapi.Request):
Expand Down
23 changes: 20 additions & 3 deletions sky/api/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import tempfile
import time
import typing
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Tuple, Union

import click
import colorama
Expand Down Expand Up @@ -138,6 +138,15 @@ def _add_env_vars_to_body(body: payloads.RequestBody):
body.env_vars = env_vars


@usage_lib.entrypoint
@_check_health
def check(clouds: Optional[Tuple[str, ...]], verbose: bool) -> str:
    """Send a check request to the API server.

    Args:
        clouds: names of the clouds to check (any number of them), or None
            when no specific clouds were requested.
        verbose: verbosity flag forwarded in the request body.

    Returns:
        The request ID extracted from the server's response.
    """
    body = payloads.CheckBody(clouds=clouds, verbose=verbose)
    # Dump the pydantic model to JSON and parse it back so that `requests`
    # receives plain JSON-compatible values for the request body.
    response = requests.get(f'{_get_server_url()}/check',
                            json=json.loads(body.model_dump_json()))
    return _get_request_id(response)


@usage_lib.entrypoint
@_check_health
def optimize(dag: 'sky.Dag') -> str:
Expand Down Expand Up @@ -462,7 +471,7 @@ def autostop(cluster_name: str, idle_minutes: int, down: bool = False) -> str:

@usage_lib.entrypoint
@_check_health
def queue(cluster_name: str,
def queue(cluster_name: List[str],
skip_finished: bool = False,
all_users: bool = False) -> str:
body = payloads.QueueBody(
Expand Down Expand Up @@ -511,8 +520,16 @@ def cancel(
@_check_health
def status(cluster_names: Optional[List[str]] = None,
refresh: bool = False) -> str:
"""Get the status of clusters.
Args:
cluster_names: names of clusters to get status for. If None, get status
for all clusters. The cluster names specified can be in glob pattern
(e.g., 'my-cluster-*').
refresh: whether to refresh the status of the clusters.
"""
# TODO(zhwu): this does not stream the logs output by logger back to the
# user
# user, due to the rich progress implementation.
body = payloads.StatusBody(
cluster_names=cluster_names,
refresh=refresh,
Expand Down
29 changes: 21 additions & 8 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2407,6 +2407,17 @@ class CloudFilter(enum.Enum):
LOCAL = 'local'


def _get_glob_clusters(clusters: List[str], silent: bool = False) -> List[str]:
    """Returns a list of clusters that match the glob pattern.

    Args:
        clusters: cluster names or glob patterns to look up in the global
            user state.
        silent: if True, do not log a message for a pattern that matches
            no cluster.

    Returns:
        The matching cluster names with duplicates removed, ordered by
        first match.
    """
    glob_clusters: List[str] = []
    for cluster in clusters:
        glob_cluster = global_user_state.get_glob_cluster_names(cluster)
        if not glob_cluster and not silent:
            logger.info(f'Cluster {cluster} not found.')
        glob_clusters.extend(glob_cluster)
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is deterministic (a set would yield an arbitrary order).
    return list(dict.fromkeys(glob_clusters))


def get_clusters(
include_controller: bool,
refresh: bool,
Expand Down Expand Up @@ -2449,6 +2460,9 @@ def get_clusters(
if cluster_names is not None:
if isinstance(cluster_names, str):
cluster_names = [cluster_names]
print(f'zhwu debug: get glob cluster {cluster_names}')
cluster_names = _get_glob_clusters(cluster_names, silent=True)
print(f'zhwu debug: got glob cluster {cluster_names}')
new_records = []
not_exist_cluster_names = []
for cluster_name in cluster_names:
Expand Down Expand Up @@ -2720,7 +2734,7 @@ def check_stale_runtime_on_remote(returncode: int, stderr: str,
f'\n--- Details ---\n{stderr.strip()}\n')


def get_endpoints(cluster_name: str,
def get_endpoints(cluster: str,
port: Optional[Union[int, str]] = None,
skip_status_check: bool = False) -> Dict[int, str]:
"""Gets the endpoint for a given cluster and port number (endpoint).
Expand Down Expand Up @@ -2756,7 +2770,7 @@ def get_endpoints(cluster_name: str,
raise ValueError(f'Invalid endpoint {port!r}.') from None
cluster_records = get_clusters(include_controller=True,
refresh=False,
cluster_names=[cluster_name])
cluster_names=[cluster])
cluster_record = cluster_records[0]
if (not skip_status_check and
cluster_record['status'] != status_lib.ClusterStatus.UP):
Expand All @@ -2768,7 +2782,7 @@ def get_endpoints(cluster_name: str,
if not isinstance(handle, backends.CloudVmRayResourceHandle):
with ux_utils.print_exception_no_traceback():
raise ValueError('Querying IP address is not supported '
f'for cluster {cluster_name!r} with backend '
f'for cluster {cluster!r} with backend '
f'{get_backend_from_handle(handle).NAME}.')

launched_resources = handle.launched_resources
Expand All @@ -2778,9 +2792,8 @@ def get_endpoints(cluster_name: str,
launched_resources, {clouds.CloudImplementationFeatures.OPEN_PORTS})
except exceptions.NotSupportedError:
with ux_utils.print_exception_no_traceback():
raise ValueError(
'Querying endpoints is not supported '
f'for cluster {cluster_name!r} on {cloud}.') from None
raise ValueError('Querying endpoints is not supported '
f'for {cluster!r} on {cloud}.') from None

config = common_utils.read_yaml(handle.cluster_yaml)
port_details = provision_lib.query_ports(repr(cloud),
Expand All @@ -2796,7 +2809,7 @@ def get_endpoints(cluster_name: str,
handle.launched_resources.ports)
if port not in port_set:
logger.warning(f'Port {port} is not exposed on '
f'cluster {cluster_name!r}.')
f'cluster {cluster!r}.')
return {}
# If the user requested a specific port endpoint, check if it is exposed
if port not in port_details:
Expand All @@ -2813,7 +2826,7 @@ def get_endpoints(cluster_name: str,
if not port_details:
# If cluster had no ports to be exposed
if handle.launched_resources.ports is None:
logger.warning(f'Cluster {cluster_name!r} does not have any '
logger.warning(f'Cluster {cluster!r} does not have any '
'ports to be exposed.')
return {}
# Else ports have not been exposed even though they exist.
Expand Down
3 changes: 2 additions & 1 deletion sky/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def check(
quiet: bool = False,
verbose: bool = False,
clouds: Optional[Iterable[str]] = None,
) -> None:
) -> List[str]:
echo = (lambda *_args, **_kwargs: None) if quiet else click.echo
echo('Checking credentials to enable clouds for SkyPilot.')
enabled_clouds = []
Expand Down Expand Up @@ -158,6 +158,7 @@ def get_all_clouds():
[''] + sorted(all_enabled_clouds))
rich.print('\n[green]:tada: Enabled clouds :tada:'
f'{enabled_clouds_str}[/green]')
return enabled_clouds


def get_cached_enabled_clouds_or_refresh(
Expand Down
2 changes: 1 addition & 1 deletion sky/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def autostop(


@usage_lib.entrypoint
def queue(cluster_name: str,
def queue(cluster_name: List[str],
skip_finished: bool = False,
all_users: bool = False) -> List[dict]:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
Expand Down

0 comments on commit d8f5698

Please sign in to comment.