diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py
index fecbcaad0b8..65447890037 100644
--- a/sky/backends/backend_utils.py
+++ b/sky/backends/backend_utils.py
@@ -861,6 +861,9 @@ def write_cluster_config(
         f'open(os.path.expanduser("{constants.SKY_REMOTE_RAY_PORT_FILE}"), "w", encoding="utf-8"))\''
     )
 
+    # TODO(tian): Hack. Reformat here.
+    default_use_internal_ips = 'vpn_config' in resources_vars
+
     # Use a tmp file path to avoid incomplete YAML file being re-used in the
     # future.
     tmp_yaml_path = yaml_path + '.tmp'
@@ -881,7 +884,8 @@ def write_cluster_config(
 
             # Networking configs
             'use_internal_ips': skypilot_config.get_nested(
-                (str(cloud).lower(), 'use_internal_ips'), False),
+                (str(cloud).lower(), 'use_internal_ips'),
+                default_use_internal_ips),
             'ssh_proxy_command': ssh_proxy_command,
             'vpc_name': skypilot_config.get_nested(
                 (str(cloud).lower(), 'vpc_name'), None),
diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index f916d931b5f..54d20b00694 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -8,6 +8,7 @@
 import math
 import os
 import pathlib
+import pprint
 import re
 import signal
 import subprocess
@@ -34,6 +35,7 @@
 from sky import resources as resources_lib
 from sky import serve as serve_lib
 from sky import sky_logging
+from sky import skypilot_config
 from sky import status_lib
 from sky import task as task_lib
 from sky.backends import backend_utils
@@ -2110,11 +2112,12 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle):
     - (optional) Launched num nodes
     - (optional) Launched resources
     - (optional) Docker user name
+    - (optional) The cluster VPN configuration (if used)
     - (optional) If TPU(s) are managed, a path to a deletion script.
     """
     # Bump if any fields get added/removed/changed, and add backward
-    # compaitibility logic in __setstate__.
-    _VERSION = 7
+    # compatibility logic in __setstate__.
+    _VERSION = 8
 
     def __init__(
             self,
@@ -2147,6 +2150,8 @@ def __init__(
         self.launched_resources = launched_resources
         self.docker_user: Optional[str] = None
         self.ssh_user: Optional[str] = None
+        # TODO(tian): Should we store the APIs in the config YAML?
+        self.vpn_config: Optional[Dict[str, Any]] = self._get_vpn_config()
         # Deprecated. SkyPilot new provisioner API handles the TPU node
         # creation/deletion.
         # Backward compatibility for TPU nodes created before #2943.
@@ -2154,6 +2159,14 @@ def __init__(
         self.tpu_create_script = tpu_create_script
         self.tpu_delete_script = tpu_delete_script
 
+    def _get_vpn_config(self) -> Optional[Dict[str, Any]]:
+        """Returns the VPN config used by the cluster."""
+        # Directly load the VPN config from the cluster
+        # yaml instead of `skypilot_config` as the latter
+        # can be changed after the cluster is UP.
+        return common_utils.read_yaml(self.cluster_yaml).get(
+            'provider', {}).get('vpn_config', None)
+
     def __repr__(self):
         return (f'ResourceHandle('
                 f'\n\tcluster_name={self.cluster_name},'
@@ -2169,6 +2182,7 @@ def __repr__(self):
                 f'{self.launched_resources}, '
                 f'\n\tdocker_user={self.docker_user},'
                 f'\n\tssh_user={self.ssh_user},'
+                f'\n\tvpn_config={self.vpn_config},'
                 # TODO (zhwu): Remove this after 0.6.0.
                 f'\n\ttpu_create_script={self.tpu_create_script}, '
                 f'\n\ttpu_delete_script={self.tpu_delete_script})')
@@ -2440,6 +2454,9 @@ def __setstate__(self, state):
         if version < 7:
             self.ssh_user = None
 
+        if version < 8:
+            self.vpn_config = None
+
         self.__dict__.update(state)
 
         # Because the update_cluster_ips and update_ssh_ports
@@ -2533,23 +2550,68 @@ def check_resources_fit_cluster(
         # was handled by ResourceHandle._update_cluster_region.
         assert launched_resources.region is not None, handle
 
+        def _check_vpn_unchanged(
+                resource: resources_lib.Resources) -> Optional[str]:
+            """Check if the VPN configuration is unchanged.
+
+            This function should only be called after checking that the
+            cloud is the same. The current VPN configuration is per-cloud,
+            so we only need to check within the same cloud.
+
+            Returns:
+                None if the VPN configuration is unchanged, otherwise a string
+                indicating the mismatch.
+            """
+            assert resource.cloud is None or resource.cloud.is_same_cloud(
+                launched_resources.cloud)
+            # Use launched_resources.cloud here when resource.cloud is None.
+            now_vpn_config = skypilot_config.get_nested(
+                (str(launched_resources.cloud).lower(), 'vpn', 'tailscale'),
+                None)
+            use_or_not_mismatch_str = (
+                '{} VPN, but current config requires the opposite')
+            if handle.vpn_config is None:
+                if now_vpn_config is None:
+                    return None
+                return use_or_not_mismatch_str.format('without')
+            if now_vpn_config is None:
+                return use_or_not_mismatch_str.format('with')
+            if now_vpn_config == handle.vpn_config:
+                return None
+            return (f'with VPN config\n{pprint.pformat(handle.vpn_config)}, '
+                    f'but current config is\n{pprint.pformat(now_vpn_config)}')
+
         mismatch_str = (f'To fix: specify a new cluster name, or down the '
                         f'existing cluster first: sky down {cluster_name}')
         valid_resource = None
         requested_resource_list = []
+        resource_failure_reason: Dict[resources_lib.Resources, str] = {}
         for resource in task.resources:
             if (task.num_nodes <= handle.launched_nodes and
                     resource.less_demanding_than(
                         launched_resources,
                         requested_num_nodes=task.num_nodes,
                         check_ports=check_ports)):
-                valid_resource = resource
-                break
+                reason = _check_vpn_unchanged(resource)
+                if reason is None:
+                    valid_resource = resource
+                    break
+                else:
+                    # TODO(tian): Maybe refactor the following into this dict
+                    resource_failure_reason[resource] = (
+                        f'Cloud {launched_resources.cloud} VPN config '
+                        f'mismatch. Cluster {handle.cluster_name} was '
+                        f'launched {reason}. Please update the VPN '
+                        'configuration in skypilot_config.')
             else:
                 requested_resource_list.append(f'{task.num_nodes}x {resource}')
 
         if valid_resource is None:
             for example_resource in task.resources:
+                if example_resource in resource_failure_reason:
+                    with ux_utils.print_exception_no_traceback():
+                        raise exceptions.ResourcesMismatchError(
+                            resource_failure_reason[example_resource])
                 if (example_resource.region is not None and
                         example_resource.region != launched_resources.region):
                     with ux_utils.print_exception_no_traceback():
@@ -2849,6 +2911,9 @@ def _get_zone(runner):
         return handle
 
     def _open_ports(self, handle: CloudVmRayResourceHandle) -> None:
+        if handle.vpn_config is not None:
+            # Skip opening any ports if VPN is used.
+            return
         cloud = handle.launched_resources.cloud
         logger.debug(
             f'Opening ports {handle.launched_resources.ports} for {cloud}')
@@ -3963,6 +4028,8 @@ def post_teardown_cleanup(self,
                     f'Failed to delete cloned image {image_id}. Please '
                     'remove it manually to avoid image leakage. Details: '
                    f'{common_utils.format_exception(e, use_bracket=True)}')
+        # We don't need to explicitly skip cleaning up ports if VPN is used,
+        # as the cluster uses the default security group and the cleanup is
+        # automatically skipped.
         if terminate:
             cloud = handle.launched_resources.cloud
             config = common_utils.read_yaml(handle.cluster_yaml)
diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py
index 1fef481d8d0..930f0b327bf 100644
--- a/sky/clouds/aws.py
+++ b/sky/clouds/aws.py
@@ -398,9 +398,13 @@ def make_deploy_resources_variables(self,
             image_id = self._get_image_id(image_id_to_use, region_name,
                                           r.instance_type)
 
+        tailscale_config = skypilot_config.get_nested(
+            ('aws', 'vpn', 'tailscale'), None)
+
         user_security_group = skypilot_config.get_nested(
             ('aws', 'security_group_name'), None)
-        if resources.ports is not None:
+        # Only open ports if VPN is not enabled.
+        if resources.ports is not None and tailscale_config is None:
             # Already checked in Resources._try_validate_ports
             assert user_security_group is None
             security_group = USER_PORTS_SECURITY_GROUP_NAME.format(
@@ -411,7 +415,7 @@ def make_deploy_resources_variables(self,
         else:
             security_group = DEFAULT_SECURITY_GROUP_NAME
 
-        return {
+        resources_vars = {
             'instance_type': r.instance_type,
             'custom_resources': custom_resources,
             'use_spot': r.use_spot,
@@ -423,6 +427,31 @@ def make_deploy_resources_variables(self,
                 str(security_group != user_security_group).lower(),
             **AWS._get_disk_specs(r.disk_tier)
         }
+        resources_vars['vpn_config'] = tailscale_config
+        if tailscale_config is not None:
+            unique_id = cluster_name_on_cloud
+            resources_vars['vpn_unique_id'] = unique_id
+            resources_vars['vpn_cloud_init_commands'] = [
+                [
+                    'sh', '-c',
+                    'curl -fsSL https://tailscale.com/install.sh | sh'
+                ],
+                [
+                    'sh', '-c',
+                    ('echo \'net.ipv4.ip_forward = 1\' | '
+                     'sudo tee -a /etc/sysctl.d/99-tailscale.conf && '
+                     'echo \'net.ipv6.conf.all.forwarding = 1\' | '
+                     'sudo tee -a /etc/sysctl.d/99-tailscale.conf && '
+                     'sudo sysctl -p /etc/sysctl.d/99-tailscale.conf')
+                ],
+                [
+                    'tailscale', 'up',
+                    f'--authkey={tailscale_config["auth_key"]}',
+                    f'--hostname={unique_id}'
+                ],
+            ]
+
+        return resources_vars
 
     def _get_feasible_launchable_resources(
             self, resources: 'resources_lib.Resources'
diff --git a/sky/provision/aws/instance.py b/sky/provision/aws/instance.py
index b9fdf80326d..e838640a69d 100644
--- a/sky/provision/aws/instance.py
+++ b/sky/provision/aws/instance.py
@@ -11,6 +11,8 @@
 import time
 from typing import Any, Callable, Dict, List, Optional, Set, TypeVar
 
+import requests
+
 from sky import sky_logging
 from sky import status_lib
 from sky.adaptors import aws
@@ -19,6 +21,7 @@
 from sky.provision.aws import utils
 from sky.utils import common_utils
 from sky.utils import resources_utils
+from sky.utils import subprocess_utils
 from sky.utils import ux_utils
 
 logger = sky_logging.init_logger(__name__)
@@ -610,6 +613,40 @@ def terminate_instances(
         included_instances=None,
         excluded_instances=None)
     instances.terminate()
+    # Clean up the VPN record.
+    vpn_config = provider_config.get('vpn_config', None)
+    if vpn_config is not None:
+        auth_headers = {'Authorization': f'Bearer {vpn_config["api_key"]}'}
+
+        def _get_node_id_from_hostname(network_name: str,
+                                       hostname: str) -> Optional[str]:
+            # TODO(tian): Refactor to a dedicated file for all
+            # VPN related functions and constants.
+            url_to_query = ('https://api.tailscale.com/api/v2/'
+                            f'tailnet/{network_name}/devices')
+            # TODO(tian): Error handling if api key is wrong.
+            resp = requests.get(url_to_query, headers=auth_headers)
+            all_devices_in_network = resp.json().get('devices', [])
+            for device_info in all_devices_in_network:
+                if device_info.get('hostname') == hostname:
+                    return device_info.get('nodeId')
+            return None
+
+        node_id_in_vpn = _get_node_id_from_hostname(
+            vpn_config['tailnet'], provider_config['vpn_unique_id'])
+        if node_id_in_vpn is None:
+            logger.warning('Cannot find node id for '
+                           f'{provider_config["vpn_unique_id"]}. '
+                           f'Skipping VPN record deletion.')
+        else:
+            url_to_delete = ('https://api.tailscale.com/api/v2/'
+                             f'device/{node_id_in_vpn}')
+            resp = requests.delete(url_to_delete, headers=auth_headers)
+            if resp.status_code != 200:
+                logger.warning('Failed to delete VPN record for '
+                               f'{provider_config["vpn_unique_id"]}. '
+                               f'Status code: {resp.status_code}, '
+                               f'Response: {resp.text}')
     if (sg_name == aws_cloud.DEFAULT_SECURITY_GROUP_NAME or
             not managed_by_skypilot):
         # Using default AWS SG or user specified security group. We don't need
@@ -843,7 +880,7 @@ def get_cluster_info(
     cluster_name_on_cloud: str,
     provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo:
     """See sky/provision/__init__.py"""
-    del provider_config  # unused
+    assert provider_config is not None
     ec2 = _default_ec2_resource(region)
     filters = [
         {
@@ -863,10 +900,33 @@ def get_cluster_info(
             tags = [(t['Key'], t['Value']) for t in inst.tags]
             # sort tags by key to support deterministic unit test stubbing
             tags.sort(key=lambda x: x[0])
+            vpn_unique_id = provider_config.get('vpn_unique_id', None)
+            if vpn_unique_id is None:
+                private_ip = inst.private_ip_address
+            else:
+                # TODO(tian): Using cluster name as hostname is problematic
+                # for a multi-node cluster. Should use f'{unique_id}-{node_id}'
+                # TODO(tian): max_retry=1000 ==> infinite retry.
+                # TODO(tian): Check cloud status and set a timeout after the
+                # instance is ready on the cloud.
+                query_cmd = f'tailscale ip -4 {vpn_unique_id}'
+                rc, stdout, stderr = subprocess_utils.run_with_retries(
+                    query_cmd,
+                    max_retry=1000,
+                    retry_wait_time=5,
+                    retry_stderrs=['no such host', 'server misbehaving'])
+                subprocess_utils.handle_returncode(
+                    rc,
+                    query_cmd,
+                    error_msg=('Failed to query private IP in VPN '
+                               f'for cluster {cluster_name_on_cloud} '
+                               f'with unique id {vpn_unique_id}'),
+                    stderr=stdout + stderr)
+                private_ip = stdout.strip()
             instances[inst.id] = [
                 common.InstanceInfo(
                     instance_id=inst.id,
-                    internal_ip=inst.private_ip_address,
+                    internal_ip=private_ip,
                     external_ip=inst.public_ip_address,
                     tags=dict(tags),
                 )
diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py
index 764d197493a..f8ba4dca79a 100644
--- a/sky/provision/provisioner.py
+++ b/sky/provision/provisioner.py
@@ -417,10 +417,18 @@ def _post_provision_setup(
         custom_resource: Optional[str]) -> provision_common.ClusterInfo:
     config_from_yaml = common_utils.read_yaml(cluster_yaml)
     provider_config = config_from_yaml.get('provider')
-    cluster_info = provision.get_cluster_info(cloud_name,
-                                              provision_record.region,
-                                              cluster_name.name_on_cloud,
-                                              provider_config=provider_config)
+    if (provider_config is not None and
+            provider_config.get('vpn_config', None) is not None):
+        get_info_status = rich_utils.safe_status(
+            '[bold cyan]Launching - Waiting for VPN setup[/]')
+    else:
+        get_info_status = rich_utils.empty_status()
+    with get_info_status:
+        cluster_info = provision.get_cluster_info(
+            cloud_name,
+            provision_record.region,
+            cluster_name.name_on_cloud,
+            provider_config=provider_config)
 
     if cluster_info.num_instances > 1:
         # Only worker nodes have logs in the per-instance log directory. Head
@@ -456,7 +464,7 @@ def _post_provision_setup(
     logger.debug(
         f'\nWaiting for SSH to be available for {cluster_name!r} ...')
     wait_for_ssh(cluster_info, ssh_credentials)
-    logger.debug(f'SSH Conection ready for {cluster_name!r}')
+    logger.debug(f'SSH Connection ready for {cluster_name!r}')
     plural = '' if len(cluster_info.instances) == 1 else 's'
     logger.info(f'{colorama.Fore.GREEN}Successfully provisioned '
                 f'or found existing instance{plural}.'
diff --git a/sky/templates/aws-ray.yml.j2 b/sky/templates/aws-ray.yml.j2
index 6f1df43cfd5..31194327f60 100644
--- a/sky/templates/aws-ray.yml.j2
+++ b/sky/templates/aws-ray.yml.j2
@@ -42,6 +42,13 @@ provider:
   vpc_name: {{vpc_name}}
 {% endif %}
   use_internal_ips: {{use_internal_ips}}
+{%- if vpn_config is not none %}
+  vpn_unique_id: {{vpn_unique_id}}
+  vpn_config:
+  {%- for key, value in vpn_config.items() %}
+    {{key}}: {{value}}
+  {%- endfor %}
+{%- endif %}
   # Disable launch config check for worker nodes as it can cause resource
   # leakage.
   # Reference: https://github.com/ray-project/ray/blob/cd1ba65e239360c8a7b130f991ed414eccc063ce/python/ray/autoscaler/_private/autoscaler.py#L1115
@@ -90,6 +97,7 @@ available_node_types:
       # The bootcmd is to disable automatic APT updates, to avoid the lock
       # when user call `apt install` on the node.
      # Reference: https://askubuntu.com/questions/1322292/how-do-i-turn-off-automatic-updates-completely-and-for-real
+      # TODO(tian): Have a class for different VPN providers and return different setup commands.
      UserData: |
        #cloud-config
        users:
@@ -109,6 +117,12 @@ available_node_types:
          - path: /etc/apt/apt.conf.d/10cloudinit-disable
            content: |
              APT::Periodic::Enable "0";
+        {%- if vpn_cloud_init_commands is not none %}
+        runcmd:
+        {%- for cmd in vpn_cloud_init_commands %}
+          - {{cmd}}
+        {%- endfor %}
+        {%- endif %}
       TagSpecifications:
         - ResourceType: instance
           Tags:
diff --git a/sky/utils/rich_utils.py b/sky/utils/rich_utils.py
index 4b3dd07257e..4012381687b 100644
--- a/sky/utils/rich_utils.py
+++ b/sky/utils/rich_utils.py
@@ -43,6 +43,11 @@ def safe_status(msg: str) -> Union['rich_console.Status', _NoOpConsoleStatus]:
     return _NoOpConsoleStatus()
 
 
+def empty_status() -> _NoOpConsoleStatus:
+    """An empty status spinner."""
+    return _NoOpConsoleStatus()
+
+
 def force_update_status(msg: str):
     """Update the status message even if sky_logging.is_silent() is true."""
     if (threading.current_thread() is threading.main_thread() and
diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py
index d02436619c3..5c805a99824 100644
--- a/sky/utils/schemas.py
+++ b/sky/utils/schemas.py
@@ -470,6 +470,32 @@ def get_cluster_schema():
 }
 
 
+_VPN_CONFIG_SCHEMA = {
+    'vpn': {
+        'type': 'object',
+        'required': [],
+        'additionalProperties': False,
+        'properties': {
+            'tailscale': {
+                'type': 'object',
+                'required': ['auth_key', 'api_key', 'tailnet'],
+                'additionalProperties': False,
+                'properties': {
+                    'auth_key': {
+                        'type': 'string',
+                    },
+                    'api_key': {
+                        'type': 'string',
+                    },
+                    'tailnet': {
+                        'type': 'string',
+                    },
+                },
+            },
+        }
+    }
+}
+
+
 _NETWORK_CONFIG_SCHEMA = {
     'vpc_name': {
         'oneOf': [{
@@ -565,6 +591,7 @@ def get_config_schema():
             'security_group_name': {
                 'type': 'string',
             },
+            **_VPN_CONFIG_SCHEMA,
             **_LABELS_SCHEMA,
             **_NETWORK_CONFIG_SCHEMA,
         },
diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py
index bd48a91a796..4cc17ec35f7 100644
--- a/sky/utils/subprocess_utils.py
+++ b/sky/utils/subprocess_utils.py
@@ -144,6 +144,7 @@ def _kill_processes(processes: List[psutil.Process]) -> None:
 def run_with_retries(
         cmd: str,
         max_retry: int = 3,
+        retry_wait_time: float = 2.0,
         retry_returncode: Optional[List[int]] = None,
         retry_stderrs: Optional[List[str]] = None) -> Tuple[int, str, str]:
     """Run a command and retry if it fails due to the specified reasons.
@@ -151,6 +152,8 @@ def run_with_retries(
     Args:
         cmd: The command to run.
         max_retry: The maximum number of retries.
+        retry_wait_time: The time to wait between retries. The actual wait
+            time will be a random value between 0 and this value.
         retry_returncode: The returncodes that should be retried.
         retry_stderr: The cmd needs to be retried if the stderr contains any
             of the strings in this list.
@@ -170,7 +173,7 @@ def run_with_retries(
             logger.debug(
                 f'Retrying command due to returncode {returncode}: {cmd}')
             retry_cnt += 1
-            time.sleep(random.uniform(0, 1) * 2)
+            time.sleep(random.uniform(0, 1) * retry_wait_time)
             continue
 
         if retry_stderrs is None:
@@ -179,8 +182,11 @@ def run_with_retries(
         need_retry = False
         for retry_err in retry_stderrs:
             if retry_err in stderr:
+                logger.debug(
+                    f'Retrying command due to retry err {retry_err!r}: '
+                    f'{cmd}, stderr: {stderr}')
                 retry_cnt += 1
-                time.sleep(random.uniform(0, 1) * 2)
+                time.sleep(random.uniform(0, 1) * retry_wait_time)
                 need_retry = True
                 break
         if need_retry:
diff --git a/tests/test_jobs.py b/tests/test_jobs.py
index e4e18d30120..9b8b927afb4 100644
--- a/tests/test_jobs.py
+++ b/tests/test_jobs.py
@@ -4,6 +4,7 @@
 from sky import backends
 from sky import exceptions
 from sky import global_user_state
+from sky import skypilot_config
 from sky.utils import db_utils
 from sky.utils import resources_utils
 
@@ -18,6 +19,10 @@ def _mock_db_conn(self, monkeypatch, tmp_path):
         monkeypatch.setattr(
             global_user_state, '_DB',
             db_utils.SQLiteConn(str(db_path), global_user_state.create_table))
+        monkeypatch.setattr(backends.CloudVmRayResourceHandle,
+                            '_get_vpn_config', lambda *args, **kwargs: None)
+        # Disable any configuration in the unit test
+        monkeypatch.setattr(skypilot_config, '_dict', None)
 
     @pytest.fixture
     def _mock_cluster_state(self, _mock_db_conn, enable_all_clouds):
diff --git a/tests/test_jobs_and_serve.py b/tests/test_jobs_and_serve.py
index 61d8a9f0a98..93cbc4e185a 100644
--- a/tests/test_jobs_and_serve.py
+++ b/tests/test_jobs_and_serve.py
@@ -44,6 +44,8 @@ def _mock_db_conn(monkeypatch, tmp_path):
     monkeypatch.setattr(
         global_user_state, '_DB',
         db_utils.SQLiteConn(str(db_path), global_user_state.create_table))
+    monkeypatch.setattr(backends.CloudVmRayResourceHandle, '_get_vpn_config',
+                        lambda *args, **kwargs: None)
 

 @pytest.fixture
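
Reviewer note (not part of the diff): a minimal sketch of the user-facing config that the new _VPN_CONFIG_SCHEMA accepts, assuming it sits under the existing aws section of ~/.sky/config.yaml as the get_config_schema() change suggests. The key names come from the schema; the values are placeholders, not real keys.

    aws:
      vpn:
        tailscale:
          # Pre-auth key passed to `tailscale up --authkey=...` by cloud-init (placeholder).
          auth_key: tskey-auth-xxxxxxxxxxxx
          # API key sent as the Bearer token when deleting the device record on teardown (placeholder).
          api_key: tskey-api-xxxxxxxxxxxx
          # Tailnet name used in the api.tailscale.com tailnet/<name>/devices query (placeholder).
          tailnet: example.com

With such a config, write_cluster_config defaults use_internal_ips to true (because vpn_config is set in the deploy variables), AWS skips opening extra ports, and get_cluster_info reports each node's Tailscale IPv4 address as its internal IP.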