diff --git a/sky/api/cli.py b/sky/api/cli.py
index c670476e1f4..b9a9c981ee5 100644
--- a/sky/api/cli.py
+++ b/sky/api/cli.py
@@ -65,6 +65,7 @@ from sky.skylet import job_lib
 from sky.skylet import log_lib
 from sky.usage import usage_lib
+from sky.utils import cluster_utils
 from sky.utils import common
 from sky.utils import common_utils
 from sky.utils import controller_utils
@@ -114,14 +115,28 @@ sdk = sdk_lib
-def _get_cluster_records(
-    clusters: List[str],
-    refresh: common.StatusRefreshMode = common.StatusRefreshMode.NONE
+def _get_cluster_records_and_set_ssh_config(
+    clusters: Optional[List[str]],
+    refresh: common.StatusRefreshMode = common.StatusRefreshMode.NONE,
 ) -> List[dict]:
-    """Returns a list of clusters that match the glob pattern."""
+    """Returns records for matching clusters and updates local SSH config."""
     # TODO(zhwu): this additional RTT makes CLIs slow. We should optimize this.
     request_id = sdk.status(clusters, refresh=refresh)
-    cluster_records = sdk.get(request_id)
+    cluster_records = sdk.stream_and_get(request_id)
+    # Update the SSH config for all clusters.
+    for record in cluster_records:
+        handle = record['handle']
+        if handle is not None and handle.cached_external_ips is not None:
+            credentials = record['credentials']
+            cluster_utils.SSHConfigHelper.add_cluster(
+                handle.cluster_name,
+                handle.cached_external_ips,
+                credentials,
+                handle.cached_external_ssh_ports,
+                handle.docker_user,
+                handle.ssh_user,
+            )
+    return cluster_records
@@ -1077,6 +1092,9 @@ def launch(
         need_confirmation=not yes,
     )
     _async_call_or_wait(request_id, async_call, 'Launch')
+    if async_call:
+        # Add the cluster to the local SSH config.
+        _get_cluster_records_and_set_ssh_config(clusters=[cluster])
@@ -1570,8 +1588,8 @@ def status(all: bool, refresh: bool, ip: bool, endpoints: bool,
     refresh_mode = common.StatusRefreshMode.NONE
     if refresh:
         refresh_mode = common.StatusRefreshMode.FORCE
-    request = sdk.status(cluster_names=query_clusters, refresh=refresh_mode)
-    cluster_records = sdk.stream_and_get(request)
+    cluster_records = _get_cluster_records_and_set_ssh_config(
+        query_clusters, refresh_mode)
 
     if ip or show_endpoints:
         if len(cluster_records) != 1:
@@ -1834,9 +1853,8 @@ def queue(clusters: List[str], skip_finished: bool, all_users: bool):
     # NOTE(dev): Keep the docstring consistent between the Python API and CLI.
     """Show the job queue for cluster(s)."""
     click.secho('Fetching and parsing job queue...', fg='yellow')
-    if not clusters:
-        clusters = ['*']
-    cluster_records = _get_cluster_records(clusters)
+    query_clusters = None if not clusters else clusters
+    cluster_records = _get_cluster_records_and_set_ssh_config(query_clusters)
     clusters = [cluster['name'] for cluster in cluster_records]
 
     unsupported_clusters = []
@@ -2337,8 +2355,8 @@ def start(
     if not clusters and not all:
         # UX: frequently users may have only 1 cluster. In this case, be smart
         # and default to that unique choice.
-        all_clusters = _get_cluster_records(
-            ['*'], refresh=common.StatusRefreshMode.AUTO)
+        all_clusters = _get_cluster_records_and_set_ssh_config(
+            clusters=None, refresh=common.StatusRefreshMode.AUTO)
         if len(all_clusters) <= 1:
             cluster_records = all_clusters
         else:
@@ -2351,8 +2369,8 @@
             click.echo('Both --all and cluster(s) specified for sky start. '
                       'Letting --all take effect.')
-        all_clusters = _get_cluster_records(
-            ['*'], refresh=common.StatusRefreshMode.AUTO)
+        all_clusters = _get_cluster_records_and_set_ssh_config(
+            clusters=None, refresh=common.StatusRefreshMode.AUTO)
 
         # Get all clusters that are not controllers.
         cluster_records = [
@@ -2361,7 +2379,7 @@
         ]
     if cluster_records is None:
         # Get GLOB cluster names
-        cluster_records = _get_cluster_records(
+        cluster_records = _get_cluster_records_and_set_ssh_config(
             clusters, refresh=common.StatusRefreshMode.AUTO)
 
     if not cluster_records:
@@ -2671,7 +2689,7 @@ def _down_or_stop_clusters(
         # UX: frequently users may have only 1 cluster. In this case, 'sky
         # stop/down' without args should be smart and default to that unique
         # choice.
-        all_clusters = _get_cluster_records(['*'])
+        all_clusters = _get_cluster_records_and_set_ssh_config(clusters=None)
         if len(all_clusters) <= 1:
             names = [cluster['name'] for cluster in all_clusters]
         else:
@@ -2696,7 +2714,7 @@
             controllers_str = ', '.join(map(repr, controllers))
             names = [
                 cluster['name']
-                for cluster in _get_cluster_records(names)
+                for cluster in _get_cluster_records_and_set_ssh_config(names)
                 if controller_utils.Controllers.from_name(cluster['name'])
                 is None
             ]
@@ -2755,7 +2773,7 @@
         names += controllers
 
     if apply_to_all:
-        all_clusters = _get_cluster_records(['*'])
+        all_clusters = _get_cluster_records_and_set_ssh_config(clusters=None)
         if len(names) > 0:
             click.echo(
                 f'Both --all and cluster(s) specified for `sky {command}`. '
diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py
index ae5a1c9c1ae..4f90eb3cacd 100644
--- a/sky/backends/backend_utils.py
+++ b/sky/backends/backend_utils.py
@@ -2,7 +2,6 @@
 from datetime import datetime
 import enum
 import fnmatch
-import functools
 import os
 import pathlib
 import pprint
@@ -11,7 +10,6 @@
 import subprocess
 import sys
 import tempfile
-import textwrap
 import time
 import typing
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
@@ -41,7 +39,7 @@
 from sky.provision.kubernetes import utils as kubernetes_utils
 from sky.skylet import constants
 from sky.usage import usage_lib
-from sky.utils import cluster_yaml_utils
+from sky.utils import cluster_utils
 from sky.utils import command_runner
 from sky.utils import common
 from sky.utils import common_utils
@@ -69,7 +67,6 @@
 # Exclude subnet mask from IP address regex.
IP_ADDR_REGEX = r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}(?!/\d{1,2})\b' SKY_REMOTE_PATH = '~/.sky/wheels' -SKY_USER_FILE_PATH = '~/.sky/generated' BOLD = '\033[1m' RESET_BOLD = '\033[0m' @@ -170,7 +167,8 @@ def is_ip(s: str) -> bool: def _get_yaml_path_from_cluster_name(cluster_name: str, - prefix: str = SKY_USER_FILE_PATH) -> str: + prefix: str = constants.SKY_USER_FILE_PATH + ) -> str: output_path = pathlib.Path( prefix).expanduser().resolve() / f'{cluster_name}.yml' os.makedirs(output_path.parents[0], exist_ok=True) @@ -396,304 +394,6 @@ def make_safe_symlink_command(cls, *, source: str, target: str) -> str: return ' && '.join(commands) -class SSHConfigHelper(object): - """Helper for handling local SSH configuration.""" - - ssh_conf_path = '~/.ssh/config' - ssh_conf_lock_path = os.path.expanduser('~/.sky/ssh_config.lock') - ssh_cluster_path = SKY_USER_FILE_PATH + '/ssh/{}' - - @classmethod - def _get_generated_config(cls, autogen_comment: str, host_name: str, - ip: str, username: str, ssh_key_path: str, - proxy_command: Optional[str], port: int, - docker_proxy_command: Optional[str]): - if proxy_command is not None: - # Already checked in resources - assert docker_proxy_command is None, ( - 'Cannot specify both proxy_command and docker_proxy_command.') - proxy = f'ProxyCommand {proxy_command}' - elif docker_proxy_command is not None: - proxy = f'ProxyCommand {docker_proxy_command}' - else: - proxy = '' - # StrictHostKeyChecking=no skips the host key check for the first - # time. UserKnownHostsFile=/dev/null and GlobalKnownHostsFile/dev/null - # prevent the host key from being added to the known_hosts file and - # always return an empty file for known hosts, making the ssh think - # this is a first-time connection, and thus skipping the host key - # check. - codegen = textwrap.dedent(f"""\ - {autogen_comment} - Host {host_name} - HostName {ip} - User {username} - IdentityFile {ssh_key_path} - IdentitiesOnly yes - ForwardAgent yes - StrictHostKeyChecking no - UserKnownHostsFile=/dev/null - GlobalKnownHostsFile=/dev/null - Port {port} - {proxy} - """.rstrip()) - codegen = codegen + '\n' - return codegen - - @classmethod - @timeline.FileLockEvent(ssh_conf_lock_path) - def add_cluster( - cls, - cluster_name: str, - ips: List[str], - auth_config: Dict[str, str], - ports: List[int], - docker_user: Optional[str] = None, - ssh_user: Optional[str] = None, - ): - """Add authentication information for cluster to local SSH config file. - - If a host with `cluster_name` already exists and the configuration was - not added by sky, then `ip` is used to identify the host instead in the - file. - - If a host with `cluster_name` already exists and the configuration was - added by sky (e.g. a spot instance), then the configuration is - overwritten. - - Args: - cluster_name: Cluster name (see `sky status`) - ips: List of public IP addresses in the cluster. First IP is head - node. 
- auth_config: read_yaml(handle.cluster_yaml)['auth'] - ports: List of port numbers for SSH corresponding to ips - docker_user: If not None, use this user to ssh into the docker - ssh_user: Override the ssh_user in auth_config - """ - if ssh_user is None: - username = auth_config['ssh_user'] - else: - username = ssh_user - if docker_user is not None: - username = docker_user - key_path = os.path.expanduser(auth_config['ssh_private_key']) - sky_autogen_comment = ('# Added by sky (use `sky stop/down ' - f'{cluster_name}` to remove)') - ip = ips[0] - if docker_user is not None: - ip = 'localhost' - - config_path = os.path.expanduser(cls.ssh_conf_path) - - # For backward compatibility: before #2706, we wrote the config of SkyPilot clusters - # directly in ~/.ssh/config. For these clusters, we remove the config in ~/.ssh/config - # and write/overwrite the config in ~/.sky/ssh/ instead. - cls._remove_stale_cluster_config_for_backward_compatibility( - cluster_name, ip, auth_config, docker_user) - - if not os.path.exists(config_path): - config = ['\n'] - with open(config_path, - 'w', - encoding='utf-8', - opener=functools.partial(os.open, mode=0o644)) as f: - f.writelines(config) - - with open(config_path, 'r', encoding='utf-8') as f: - config = f.readlines() - - ssh_dir = cls.ssh_cluster_path.format('') - os.makedirs(os.path.expanduser(ssh_dir), exist_ok=True, mode=0o700) - - # Handle Include on top of Config file - include_str = f'Include {cls.ssh_cluster_path.format("*")}' - found = False - for i, line in enumerate(config): - config_str = line.strip() - if config_str == include_str: - found = True - break - if 'Host' in config_str: - break - if not found: - # Did not find Include string. Insert `Include` lines. - with open(config_path, 'w', encoding='utf-8') as f: - config.insert( - 0, - f'# Added by SkyPilot for ssh config of all clusters\n{include_str}\n' - ) - f.write(''.join(config).strip()) - f.write('\n' * 2) - - proxy_command = auth_config.get('ssh_proxy_command', None) - - docker_proxy_command_generator = None - if docker_user is not None: - docker_proxy_command_generator = lambda ip, port: ' '.join( - ['ssh'] + command_runner.ssh_options_list( - key_path, ssh_control_name=None, port=port) + - ['-W', '%h:%p', f'{auth_config["ssh_user"]}@{ip}']) - - codegen = '' - # Add the nodes to the codegen - for i, ip in enumerate(ips): - docker_proxy_command = None - port = ports[i] - if docker_proxy_command_generator is not None: - docker_proxy_command = docker_proxy_command_generator(ip, port) - ip = 'localhost' - port = constants.DEFAULT_DOCKER_PORT - node_name = cluster_name if i == 0 else cluster_name + f'-worker{i}' - # TODO(romilb): Update port number when k8s supports multinode - codegen += cls._get_generated_config( - sky_autogen_comment, node_name, ip, username, key_path, - proxy_command, port, docker_proxy_command) + '\n' - - cluster_config_path = os.path.expanduser( - cls.ssh_cluster_path.format(cluster_name)) - - with open(cluster_config_path, - 'w', - encoding='utf-8', - opener=functools.partial(os.open, mode=0o644)) as f: - f.write(codegen) - - @classmethod - def _remove_stale_cluster_config_for_backward_compatibility( - cls, - cluster_name: str, - ip: str, - auth_config: Dict[str, str], - docker_user: Optional[str] = None, - ): - """Remove authentication information for cluster from local SSH config. - - If no existing host matching the provided specification is found, then - nothing is removed. - - Args: - ip: Head node's IP address. 
- auth_config: read_yaml(handle.cluster_yaml)['auth'] - docker_user: If not None, use this user to ssh into the docker - """ - username = auth_config['ssh_user'] - config_path = os.path.expanduser(cls.ssh_conf_path) - cluster_config_path = os.path.expanduser( - cls.ssh_cluster_path.format(cluster_name)) - if not os.path.exists(config_path): - return - - with open(config_path, 'r', encoding='utf-8') as f: - config = f.readlines() - - start_line_idx = None - - # Scan the config for the cluster name. - for i, line in enumerate(config): - next_line = config[i + 1] if i + 1 < len(config) else '' - if docker_user is None: - found = (line.strip() == f'HostName {ip}' and - next_line.strip() == f'User {username}') - else: - found = (line.strip() == 'HostName localhost' and - next_line.strip() == f'User {docker_user}') - if found: - # Find the line starting with ProxyCommand and contains the ip - found = False - for idx in range(i, len(config)): - # Stop if we reach an empty line, which means a new host - if not config[idx].strip(): - break - if config[idx].strip().startswith('ProxyCommand'): - proxy_command_line = config[idx].strip() - if proxy_command_line.endswith(f'@{ip}'): - found = True - break - if found: - start_line_idx = i - 1 - break - - if start_line_idx is not None: - # Scan for end of previous config. - cursor = start_line_idx - while cursor > 0 and len(config[cursor].strip()) > 0: - cursor -= 1 - prev_end_line_idx = cursor - - # Scan for end of the cluster config. - end_line_idx = None - cursor = start_line_idx + 1 - start_line_idx -= 1 # remove auto-generated comment - while cursor < len(config): - if config[cursor].strip().startswith( - '# ') or config[cursor].strip().startswith('Host '): - end_line_idx = cursor - break - cursor += 1 - - # Remove sky-generated config and update the file. - config[prev_end_line_idx:end_line_idx] = [ - '\n' - ] if end_line_idx is not None else [] - with open(config_path, 'w', encoding='utf-8') as f: - f.write(''.join(config).strip()) - f.write('\n' * 2) - - # Delete include statement if it exists in the config. - sky_autogen_comment = ('# Added by sky (use `sky stop/down ' - f'{cluster_name}` to remove)') - with open(config_path, 'r', encoding='utf-8') as f: - config = f.readlines() - - for i, line in enumerate(config): - config_str = line.strip() - if f'Include {cluster_config_path}' in config_str: - with open(config_path, 'w', encoding='utf-8') as f: - if i < len(config) - 1 and config[i + 1] == '\n': - del config[i + 1] - # Delete Include string - del config[i] - # Delete Sky Autogen Comment - if i > 0 and sky_autogen_comment in config[i - 1].strip(): - del config[i - 1] - f.write(''.join(config)) - break - if 'Host' in config_str: - break - - @classmethod - # TODO: We can remove this after 0.6.0 and have a lock only per cluster. - @timeline.FileLockEvent(ssh_conf_lock_path) - def remove_cluster( - cls, - cluster_name: str, - ip: str, - auth_config: Dict[str, str], - docker_user: Optional[str] = None, - ): - """Remove authentication information for cluster from ~/.sky/ssh/. - - For backward compatibility also remove the config from ~/.ssh/config if it exists. - - If no existing host matching the provided specification is found, then - nothing is removed. - - Args: - ip: Head node's IP address. 
-        auth_config: read_yaml(handle.cluster_yaml)['auth']
-        docker_user: If not None, use this user to ssh into the docker
-        """
-        cluster_config_path = os.path.expanduser(
-            cls.ssh_cluster_path.format(cluster_name))
-        common_utils.remove_file_if_exists(cluster_config_path)
-
-        # Ensures backward compatibility: before #2706, we wrote the config of SkyPilot clusters
-        # directly in ~/.ssh/config. For these clusters, we should clean up the config.
-        # TODO: Remove this after 0.6.0
-        cls._remove_stale_cluster_config_for_backward_compatibility(
-            cluster_name, ip, auth_config, docker_user)
-
-
 def _replace_yaml_dicts(
         new_yaml: str, old_yaml: str, restore_key_names: Set[str],
         restore_key_names_exceptions: Sequence[Tuple[str, ...]]) -> str:
@@ -955,7 +655,7 @@ def write_cluster_config(
             'sky_local_path': str(local_wheel_path),
             # Add yaml file path to the template variables.
             'sky_ray_yaml_remote_path':
-                cluster_yaml_utils.SKY_CLUSTER_YAML_REMOTE_PATH,
+                cluster_utils.SKY_CLUSTER_YAML_REMOTE_PATH,
             'sky_ray_yaml_local_path': tmp_yaml_path,
             'sky_version': str(version.parse(sky.__version__)),
             'sky_wheel_hash': wheel_hash,
@@ -1434,7 +1134,7 @@ def get_node_ips(cluster_yaml: str,
     """
     ray_config = common_utils.read_yaml(cluster_yaml)
     # Use the new provisioner for AWS.
-    provider_name = cluster_yaml_utils.get_provider_name(ray_config)
+    provider_name = cluster_utils.get_provider_name(ray_config)
     cloud = registry.CLOUD_REGISTRY.from_str(provider_name)
     assert cloud is not None, provider_name
@@ -2475,6 +2175,26 @@ def get_clusters(
         clusters_str = ', '.join(not_exist_cluster_names)
         logger.info(f'Cluster(s) not found: {bright}{clusters_str}{reset}.')
     records = new_records
+    # Attach SSH credentials to the records for client-side SSH config setup.
+    for record in records:
+        handle = record['handle']
+        if handle is None:
+            continue
+        credentials = ssh_credential_from_yaml(handle.cluster_yaml,
+                                               handle.docker_user,
+                                               handle.ssh_user)
+        ssh_private_key_path = credentials.pop('ssh_private_key', None)
+        if ssh_private_key_path is not None:
+            with open(os.path.expanduser(ssh_private_key_path),
+                      'r',
+                      encoding='utf-8') as f:
+                credentials['ssh_private_key_content'] = f.read()
+        else:
+            with open(os.path.expanduser(auth.PRIVATE_SSH_KEY_PATH),
+                      'r',
+                      encoding='utf-8') as f:
+                credentials['ssh_private_key_content'] = f.read()
+        record['credentials'] = credentials
 
     if refresh == common.StatusRefreshMode.NONE:
         return records
@@ -2548,6 +2268,7 @@ def _refresh_cluster(cluster_name):
                        f'{len(failed_clusters)} cluster{plural}:{reset}')
         for cluster_name, e in failed_clusters:
             logger.warning(f'  {bright}{cluster_name}{reset}: {e}')
+    return kept_records
 
diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index 1cebb8f314d..8e40ca2d974 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -53,6 +53,7 @@ from sky.skylet import log_lib
 from sky.usage import usage_lib
 from sky.utils import accelerator_registry
+from sky.utils import cluster_utils
 from sky.utils import command_runner
 from sky.utils import common
 from sky.utils import common_utils
@@ -2994,15 +2995,10 @@ def _update_after_cluster_provisioned(
             )
             usage_lib.messages.usage.update_final_cluster_status(
                 status_lib.ClusterStatus.UP)
-            auth_config = backend_utils.ssh_credential_from_yaml(
-                handle.cluster_yaml,
-                ssh_user=handle.ssh_user,
-                docker_user=handle.docker_user)
-            backend_utils.SSHConfigHelper.add_cluster(handle.cluster_name,
-                                                      ip_list, auth_config,
-                                                      ssh_port_list,
-                                                      handle.docker_user,
-                                                      handle.ssh_user)
+            # No need to add the cluster to the local SSH config file on the
+            # API server.
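+            # The client sets up its own SSH config when it fetches cluster
+            # records (see cli._get_cluster_records_and_set_ssh_config).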
 
             common_utils.remove_file_if_exists(lock_path)
 
@@ -4056,10 +4058,9 @@ def post_teardown_cleanup(self,
             # be removed after the cluster entry in the database is removed.
             config = common_utils.read_yaml(handle.cluster_yaml)
             auth_config = config['auth']
-            backend_utils.SSHConfigHelper.remove_cluster(handle.cluster_name,
-                                                         handle.head_ip,
-                                                         auth_config,
-                                                         handle.docker_user)
+            cluster_utils.SSHConfigHelper.remove_cluster(
+                handle.cluster_name, handle.head_ip, auth_config,
+                handle.docker_user)
 
             global_user_state.remove_cluster(handle.cluster_name,
                                              terminate=terminate)
diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py
index 30ef416806d..dc0cca80653 100644
--- a/sky/skylet/constants.py
+++ b/sky/skylet/constants.py
@@ -294,3 +294,6 @@
 # The placeholder for the local skypilot config path in file mounts for
 # controllers.
 LOCAL_SKYPILOT_CONFIG_PATH_PLACEHOLDER = 'skypilot:local_skypilot_config_path'
+
+# Path to the generated cluster config yamls and ssh configs.
+SKY_USER_FILE_PATH = '~/.sky/generated'
diff --git a/sky/skylet/events.py b/sky/skylet/events.py
index 810dd0fb213..4fe02c14656 100644
--- a/sky/skylet/events.py
+++ b/sky/skylet/events.py
@@ -17,7 +17,7 @@ from sky.skylet import autostop_lib
 from sky.skylet import constants
 from sky.skylet import job_lib
-from sky.utils import cluster_yaml_utils
+from sky.utils import cluster_utils
 from sky.utils import common_utils
 from sky.utils import registry
 from sky.utils import ux_utils
@@ -140,10 +140,9 @@ def _stop_cluster(self, autostop_config):
         autostop_lib.set_autostopping_started()
 
         config_path = os.path.abspath(
-            os.path.expanduser(
-                cluster_yaml_utils.SKY_CLUSTER_YAML_REMOTE_PATH))
+            os.path.expanduser(cluster_utils.SKY_CLUSTER_YAML_REMOTE_PATH))
         config = common_utils.read_yaml(config_path)
-        provider_name = cluster_yaml_utils.get_provider_name(config)
+        provider_name = cluster_utils.get_provider_name(config)
         cloud = registry.CLOUD_REGISTRY.from_str(provider_name)
         assert cloud is not None, f'Unknown cloud: {provider_name}'
 
diff --git a/sky/utils/cluster_utils.py b/sky/utils/cluster_utils.py
new file mode 100644
index 00000000000..bc216ef5692
--- /dev/null
+++ b/sky/utils/cluster_utils.py
@@ -0,0 +1,366 @@
+"""Utility functions for cluster yaml files and local SSH configuration."""
+
+import functools
+import os
+import re
+import textwrap
+from typing import Dict, List, Optional
+
+from sky.skylet import constants
+from sky.utils import command_runner
+from sky.utils import common_utils
+from sky.utils import timeline
+
+# The cluster yaml used to create the current cluster where the module is
+# called.
+SKY_CLUSTER_YAML_REMOTE_PATH = '~/.sky/sky_ray.yml'
+
+
+def get_provider_name(config: dict) -> str:
+    """Return the name of the provider."""
+
+    provider_module = config['provider']['module']
+    # Examples:
+    #   'sky.skylet.providers.aws.AWSNodeProviderV2' -> 'aws'
+    #   'sky.provision.aws' -> 'aws'
+    provider_search = re.search(r'(?:providers|provision)\.(\w+)\.?',
+                                provider_module)
+    assert provider_search is not None, config
+    provider_name = provider_search.group(1).lower()
+    # Special handling for lambda_cloud as Lambda cloud is registered as
+    # lambda.
+    if provider_name == 'lambda_cloud':
+        provider_name = 'lambda'
+    return provider_name
+
+
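+# NOTE: SSHConfigHelper is moved here (mostly unchanged) from
+# sky/backends/backend_utils.py; the main addition is materializing the
+# private-key content shipped by the API server (see add_cluster below).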
+class SSHConfigHelper(object):
+    """Helper for handling local SSH configuration."""
+
+    ssh_conf_path = '~/.ssh/config'
+    ssh_conf_lock_path = os.path.expanduser('~/.sky/ssh_config.lock')
+    ssh_cluster_path = constants.SKY_USER_FILE_PATH + '/ssh/{}'
+    ssh_cluster_key_path = constants.SKY_USER_FILE_PATH + '/ssh-keys/{}.key'
+
+    @classmethod
+    def _get_generated_config(cls, autogen_comment: str, host_name: str,
+                              ip: str, username: str, ssh_key_path: str,
+                              proxy_command: Optional[str], port: int,
+                              docker_proxy_command: Optional[str]):
+        if proxy_command is not None:
+            # Already checked in resources
+            assert docker_proxy_command is None, (
+                'Cannot specify both proxy_command and docker_proxy_command.')
+            proxy = f'ProxyCommand {proxy_command}'
+        elif docker_proxy_command is not None:
+            proxy = f'ProxyCommand {docker_proxy_command}'
+        else:
+            proxy = ''
+        # StrictHostKeyChecking=no skips the host key check for the first
+        # time. UserKnownHostsFile=/dev/null and GlobalKnownHostsFile=/dev/null
+        # prevent the host key from being added to the known_hosts file and
+        # always return an empty file for known hosts, making the ssh think
+        # this is a first-time connection, and thus skipping the host key
+        # check.
+        codegen = textwrap.dedent(f"""\
+            {autogen_comment}
+            Host {host_name}
+              HostName {ip}
+              User {username}
+              IdentityFile {ssh_key_path}
+              IdentitiesOnly yes
+              ForwardAgent yes
+              StrictHostKeyChecking no
+              UserKnownHostsFile=/dev/null
+              GlobalKnownHostsFile=/dev/null
+              Port {port}
+              {proxy}
+            """.rstrip())
+        codegen = codegen + '\n'
+        return codegen
+
+    @classmethod
+    @timeline.FileLockEvent(ssh_conf_lock_path)
+    def add_cluster(
+        cls,
+        cluster_name: str,
+        ips: List[str],
+        auth_config: Dict[str, str],
+        ports: List[int],
+        docker_user: Optional[str] = None,
+        ssh_user: Optional[str] = None,
+    ):
+        """Add authentication information for cluster to local SSH config file.
+
+        If a host with `cluster_name` already exists and the configuration was
+        not added by sky, then `ip` is used to identify the host instead in the
+        file.
+
+        If a host with `cluster_name` already exists and the configuration was
+        added by sky (e.g. a spot instance), then the configuration is
+        overwritten.
+
+        Args:
+            cluster_name: Cluster name (see `sky status`)
+            ips: List of public IP addresses in the cluster. First IP is head
+                node.
+            auth_config: read_yaml(handle.cluster_yaml)['auth']
+            ports: List of port numbers for SSH corresponding to ips
+            docker_user: If not None, use this user to ssh into the docker
+            ssh_user: Override the ssh_user in auth_config
+        """
+        if ssh_user is None:
+            username = auth_config['ssh_user']
+        else:
+            username = ssh_user
+        if docker_user is not None:
+            username = docker_user
+
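+        # The API server ships the private key *content* with the cluster
+        # records; materialize it under ~/.sky/generated/ssh-keys/ so the
+        # generated SSH config can reference it as a local IdentityFile.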
+        key_content = auth_config.pop('ssh_private_key_content', None)
+        if key_content is not None:
+            cluster_private_key_path = cls.ssh_cluster_key_path.format(
+                cluster_name)
+            expanded_cluster_private_key_path = os.path.expanduser(
+                cluster_private_key_path)
+            os.makedirs(os.path.dirname(expanded_cluster_private_key_path),
+                        exist_ok=True)
+            with open(expanded_cluster_private_key_path,
+                      'w',
+                      encoding='utf-8',
+                      opener=functools.partial(os.open, mode=0o600)) as f:
+                f.write(key_content)
+            # chmod in case the key file already existed with different
+            # permissions.
+            os.chmod(expanded_cluster_private_key_path, 0o600)
+            auth_config['ssh_private_key'] = cluster_private_key_path
+        key_path = os.path.expanduser(auth_config['ssh_private_key'])
+        sky_autogen_comment = ('# Added by sky (use `sky stop/down '
+                               f'{cluster_name}` to remove)')
+        ip = ips[0]
+        if docker_user is not None:
+            ip = 'localhost'
+
+        config_path = os.path.expanduser(cls.ssh_conf_path)
+
+        # For backward compatibility: before #2706, we wrote the config of
+        # SkyPilot clusters directly in ~/.ssh/config. For these clusters, we
+        # remove the config in ~/.ssh/config and write/overwrite the config in
+        # ~/.sky/ssh/ instead.
+        cls._remove_stale_cluster_config_for_backward_compatibility(
+            cluster_name, ip, auth_config, docker_user)
+
+        if not os.path.exists(config_path):
+            config = ['\n']
+            with open(config_path,
+                      'w',
+                      encoding='utf-8',
+                      opener=functools.partial(os.open, mode=0o644)) as f:
+                f.writelines(config)
+
+        with open(config_path, 'r', encoding='utf-8') as f:
+            config = f.readlines()
+
+        ssh_dir = cls.ssh_cluster_path.format('')
+        os.makedirs(os.path.expanduser(ssh_dir), exist_ok=True, mode=0o700)
+
+        # Handle Include on top of Config file
+        include_str = f'Include {cls.ssh_cluster_path.format("*")}'
+        found = False
+        for i, line in enumerate(config):
+            config_str = line.strip()
+            if config_str == include_str:
+                found = True
+                break
+            if 'Host' in config_str:
+                break
+        if not found:
+            # Did not find Include string. Insert `Include` lines.
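+            # The Include directive only applies globally when it appears
+            # before the first Host block, hence the insertion at the top.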
+            with open(config_path, 'w', encoding='utf-8') as f:
+                config.insert(
+                    0, '# Added by SkyPilot for ssh config of all clusters\n'
+                    f'{include_str}\n')
+                f.write(''.join(config).strip())
+                f.write('\n' * 2)
+
+        proxy_command = auth_config.get('ssh_proxy_command', None)
+
+        docker_proxy_command_generator = None
+        if docker_user is not None:
+            docker_proxy_command_generator = lambda ip, port: ' '.join(
+                ['ssh'] + command_runner.ssh_options_list(
+                    key_path, ssh_control_name=None, port=port) +
+                ['-W', '%h:%p', f'{auth_config["ssh_user"]}@{ip}'])
+
+        codegen = ''
+        # Add the nodes to the codegen
+        for i, ip in enumerate(ips):
+            docker_proxy_command = None
+            port = ports[i]
+            if docker_proxy_command_generator is not None:
+                docker_proxy_command = docker_proxy_command_generator(ip, port)
+                ip = 'localhost'
+                port = constants.DEFAULT_DOCKER_PORT
+            node_name = cluster_name if i == 0 else cluster_name + f'-worker{i}'
+            # TODO(romilb): Update port number when k8s supports multinode
+            codegen += cls._get_generated_config(
+                sky_autogen_comment, node_name, ip, username, key_path,
+                proxy_command, port, docker_proxy_command) + '\n'
+
+        cluster_config_path = os.path.expanduser(
+            cls.ssh_cluster_path.format(cluster_name))
+
+        with open(cluster_config_path,
+                  'w',
+                  encoding='utf-8',
+                  opener=functools.partial(os.open, mode=0o644)) as f:
+            f.write(codegen)
+
+    @classmethod
+    def _remove_stale_cluster_config_for_backward_compatibility(
+        cls,
+        cluster_name: str,
+        ip: str,
+        auth_config: Dict[str, str],
+        docker_user: Optional[str] = None,
+    ):
+        """Remove authentication information for cluster from local SSH config.
+
+        If no existing host matching the provided specification is found, then
+        nothing is removed.
+
+        Args:
+            ip: Head node's IP address.
+            auth_config: read_yaml(handle.cluster_yaml)['auth']
+            docker_user: If not None, use this user to ssh into the docker
+        """
+        username = auth_config['ssh_user']
+        config_path = os.path.expanduser(cls.ssh_conf_path)
+        cluster_config_path = os.path.expanduser(
+            cls.ssh_cluster_path.format(cluster_name))
+        if not os.path.exists(config_path):
+            return
+
+        with open(config_path, 'r', encoding='utf-8') as f:
+            config = f.readlines()
+
+        start_line_idx = None
+
+        # Scan the config for the cluster name.
+        for i, line in enumerate(config):
+            next_line = config[i + 1] if i + 1 < len(config) else ''
+            if docker_user is None:
+                found = (line.strip() == f'HostName {ip}' and
+                         next_line.strip() == f'User {username}')
+            else:
+                found = (line.strip() == 'HostName localhost' and
+                         next_line.strip() == f'User {docker_user}')
+            if found:
+                # Find the ProxyCommand line that contains the ip.
+                found = False
+                for idx in range(i, len(config)):
+                    # Stop if we reach an empty line, which means a new host
+                    if not config[idx].strip():
+                        break
+                    if config[idx].strip().startswith('ProxyCommand'):
+                        proxy_command_line = config[idx].strip()
+                        if proxy_command_line.endswith(f'@{ip}'):
+                            found = True
+                            break
+                if found:
+                    start_line_idx = i - 1
+                    break
+
+        if start_line_idx is not None:
+            # Scan for end of previous config.
+            cursor = start_line_idx
+            while cursor > 0 and len(config[cursor].strip()) > 0:
+                cursor -= 1
+            prev_end_line_idx = cursor
+
+            # Scan for end of the cluster config.
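+            # A sky-generated block ends at the next comment ('# ') or the
+            # next 'Host ' entry; if neither is found, it runs to EOF.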
+            end_line_idx = None
+            cursor = start_line_idx + 1
+            start_line_idx -= 1  # remove auto-generated comment
+            while cursor < len(config):
+                if config[cursor].strip().startswith(
+                        '# ') or config[cursor].strip().startswith('Host '):
+                    end_line_idx = cursor
+                    break
+                cursor += 1
+
+            # Remove sky-generated config and update the file.
+            config[prev_end_line_idx:end_line_idx] = [
+                '\n'
+            ] if end_line_idx is not None else []
+            with open(config_path, 'w', encoding='utf-8') as f:
+                f.write(''.join(config).strip())
+                f.write('\n' * 2)
+
+        # Delete include statement if it exists in the config.
+        sky_autogen_comment = ('# Added by sky (use `sky stop/down '
+                               f'{cluster_name}` to remove)')
+        with open(config_path, 'r', encoding='utf-8') as f:
+            config = f.readlines()
+
+        for i, line in enumerate(config):
+            config_str = line.strip()
+            if f'Include {cluster_config_path}' in config_str:
+                with open(config_path, 'w', encoding='utf-8') as f:
+                    if i < len(config) - 1 and config[i + 1] == '\n':
+                        del config[i + 1]
+                    # Delete Include string
+                    del config[i]
+                    # Delete Sky Autogen Comment
+                    if i > 0 and sky_autogen_comment in config[i - 1].strip():
+                        del config[i - 1]
+                    f.write(''.join(config))
+                break
+            if 'Host' in config_str:
+                break
+
+    @classmethod
+    # TODO: We can remove this after 0.6.0 and have a lock only per cluster.
+    @timeline.FileLockEvent(ssh_conf_lock_path)
+    def remove_cluster(
+        cls,
+        cluster_name: str,
+        ip: str,
+        auth_config: Dict[str, str],
+        docker_user: Optional[str] = None,
+    ):
+        """Remove auth info for cluster from ~/.sky/generated/ssh/.
+
+        For backward compatibility also remove the config from ~/.ssh/config if
+        it exists.
+
+        If no existing host matching the provided specification is found, then
+        nothing is removed.
+
+        Args:
+            ip: Head node's IP address.
+            auth_config: read_yaml(handle.cluster_yaml)['auth']
+            docker_user: If not None, use this user to ssh into the docker
+        """
+        cluster_config_path = os.path.expanduser(
+            cls.ssh_cluster_path.format(cluster_name))
+        common_utils.remove_file_if_exists(cluster_config_path)
+        cluster_private_key_path = os.path.expanduser(
+            cls.ssh_cluster_key_path.format(cluster_name))
+        common_utils.remove_file_if_exists(cluster_private_key_path)
+
+        # Ensures backward compatibility: before #2706, we wrote the config of
+        # SkyPilot clusters directly in ~/.ssh/config. For these clusters, we
+        # should clean up the config.
+        # TODO: Remove this after 0.6.0
+        cls._remove_stale_cluster_config_for_backward_compatibility(
+            cluster_name, ip, auth_config, docker_user)
diff --git a/sky/utils/cluster_yaml_utils.py b/sky/utils/cluster_yaml_utils.py
deleted file mode 100644
index 50410c274fb..00000000000
--- a/sky/utils/cluster_yaml_utils.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""Utility functions for cluster yaml file."""
-
-import re
-
-# The cluster yaml used to create the current cluster where the module is
-# called.
-SKY_CLUSTER_YAML_REMOTE_PATH = '~/.sky/sky_ray.yml'
-
-
-def get_provider_name(config: dict) -> str:
-    """Return the name of the provider."""
-
-    provider_module = config['provider']['module']
-    # Examples:
-    #   'sky.skylet.providers.aws.AWSNodeProviderV2' -> 'aws'
-    #   'sky.provision.aws' -> 'aws'
-    provider_search = re.search(r'(?:providers|provision)\.(\w+)\.?',
-                                provider_module)
-    assert provider_search is not None, config
-    provider_name = provider_search.group(1).lower()
-    # Special handling for lambda_cloud as Lambda cloud is registered as lambda.
-    if provider_name == 'lambda_cloud':
-        provider_name = 'lambda'
-    return provider_name