Skip to content

Commit

Permalink
Handle SSH config
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Aug 7, 2024
1 parent 3d06494 commit e454cd1
Show file tree
Hide file tree
Showing 7 changed files with 431 additions and 363 deletions.
52 changes: 35 additions & 17 deletions sky/api/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from sky.skylet import job_lib
from sky.skylet import log_lib
from sky.usage import usage_lib
from sky.utils import cluster_utils
from sky.utils import common
from sky.utils import common_utils
from sky.utils import controller_utils
Expand Down Expand Up @@ -114,14 +115,28 @@
sdk = sdk_lib


def _get_cluster_records(
clusters: List[str],
refresh: common.StatusRefreshMode = common.StatusRefreshMode.NONE
def _get_cluster_records_and_set_ssh_config(
clusters: Optional[List[str]],
refresh: common.StatusRefreshMode = common.StatusRefreshMode.NONE,
) -> List[dict]:
"""Returns a list of clusters that match the glob pattern."""
# TODO(zhwu): this additional RTT makes CLIs slow. We should optimize this.
request_id = sdk.status(clusters, refresh=refresh)
cluster_records = sdk.get(request_id)
cluster_records = sdk.stream_and_get(request_id)
# Update the SSH config for all clusters
for record in cluster_records:
handle = record['handle']
if handle is not None and handle.cached_external_ips is not None:
crednetials = record['credentials']
cluster_utils.SSHConfigHelper.add_cluster(
handle.cluster_name,
handle.cached_external_ips,
crednetials,
handle.cached_external_ssh_ports,
handle.docker_user,
handle.ssh_user,
)

return cluster_records


Expand Down Expand Up @@ -1077,6 +1092,9 @@ def launch(
need_confirmation=not yes,
)
_async_call_or_wait(request_id, async_call, 'Launch')
if async_call:
# Add ssh config for the cluster
_get_cluster_records_and_set_ssh_config(clusters=[cluster])


@cli.command(cls=_DocumentedCodeCommand)
Expand Down Expand Up @@ -1570,8 +1588,9 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool,
refresh_mode = common.StatusRefreshMode.NONE
if refresh:
refresh_mode = common.StatusRefreshMode.FORCE
request = sdk.status(cluster_names=query_clusters, refresh=refresh_mode)
cluster_records = sdk.stream_and_get(request)
cluster_records = _get_cluster_records_and_set_ssh_config(
query_clusters, refresh_mode)

# TOOD(zhwu): setup the ssh config for status
if ip or show_endpoints:
if len(cluster_records) != 1:
Expand Down Expand Up @@ -1834,9 +1853,8 @@ def queue(clusters: List[str], skip_finished: bool, all_users: bool):
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Show the job queue for cluster(s)."""
click.secho('Fetching and parsing job queue...', fg='yellow')
if not clusters:
clusters = ['*']
cluster_records = _get_cluster_records(clusters)
query_clusters = None if not clusters else clusters
cluster_records = _get_cluster_records_and_set_ssh_config(query_clusters)
clusters = [cluster['name'] for cluster in cluster_records]

unsupported_clusters = []
Expand Down Expand Up @@ -2337,8 +2355,8 @@ def start(
if not clusters and not all:
# UX: frequently users may have only 1 cluster. In this case, be smart
# and default to that unique choice.
all_clusters = _get_cluster_records(
['*'], refresh=common.StatusRefreshMode.AUTO)
all_clusters = _get_cluster_records_and_set_ssh_config(
clusters=None, refresh=common.StatusRefreshMode.AUTO)
if len(all_clusters) <= 1:
cluster_records = all_clusters
else:
Expand All @@ -2351,8 +2369,8 @@ def start(
click.echo('Both --all and cluster(s) specified for sky start. '
'Letting --all take effect.')

all_clusters = _get_cluster_records(
['*'], refresh=common.StatusRefreshMode.AUTO)
all_clusters = _get_cluster_records_and_set_ssh_config(
clusters=None, refresh=common.StatusRefreshMode.AUTO)

# Get all clusters that are not controllers.
cluster_records = [
Expand All @@ -2361,7 +2379,7 @@ def start(
]
if cluster_records is None:
# Get GLOB cluster names
cluster_records = _get_cluster_records(
cluster_records = _get_cluster_records_and_set_ssh_config(
clusters, refresh=common.StatusRefreshMode.AUTO)

if not cluster_records:
Expand Down Expand Up @@ -2671,7 +2689,7 @@ def _down_or_stop_clusters(
# UX: frequently users may have only 1 cluster. In this case, 'sky
# stop/down' without args should be smart and default to that unique
# choice.
all_clusters = _get_cluster_records(['*'])
all_clusters = _get_cluster_records_and_set_ssh_config(['*'])
if len(all_clusters) <= 1:
names = [cluster['name'] for cluster in all_clusters]
else:
Expand All @@ -2696,7 +2714,7 @@ def _down_or_stop_clusters(
controllers_str = ', '.join(map(repr, controllers))
names = [
cluster['name']
for cluster in _get_cluster_records(names)
for cluster in _get_cluster_records_and_set_ssh_config(names)
if controller_utils.Controllers.from_name(cluster['name']) is None
]

Expand Down Expand Up @@ -2755,7 +2773,7 @@ def _down_or_stop_clusters(
names += controllers

if apply_to_all:
all_clusters = _get_cluster_records(['*'])
all_clusters = _get_cluster_records_and_set_ssh_config(clusters=None)
if len(names) > 0:
click.echo(
f'Both --all and cluster(s) specified for `sky {command}`. '
Expand Down
Loading

0 comments on commit e454cd1

Please sign in to comment.