diff --git a/docs/source/cloud-setup/cloud-permissions/gcp.rst b/docs/source/cloud-setup/cloud-permissions/gcp.rst index 379ba7a672a..064aeb7e8d5 100644 --- a/docs/source/cloud-setup/cloud-permissions/gcp.rst +++ b/docs/source/cloud-setup/cloud-permissions/gcp.rst @@ -267,3 +267,60 @@ See details in :ref:`config-yaml`. Example use cases include using a private VP VPC with fine-grained constraints, typically created via Terraform or manually. The custom VPC should contain the :ref:`required firewall rules `. + + +.. _gcp-use-internal-ips: + + +Using Internal IPs +----------------------- +For security reason, users may only want to use internal IPs for SkyPilot instances. +To do so, you can use SkyPilot's global config file ``~/.sky/config.yaml`` to specify the ``gcp.use_internal_ips`` and ``gcp.ssh_proxy_command`` fields (to see the detailed syntax, see :ref:`config-yaml`): + +.. code-block:: yaml + + gcp: + use_internal_ips: true + # VPC with NAT setup, see below + vpc_name: my-vpc-name + ssh_proxy_command: ssh -W %h:%p -o StrictHostKeyChecking=no myself@my.proxy + +The ``gcp.ssh_proxy_command`` field is optional. If SkyPilot is run on a machine that can directly access the internal IPs of the instances, it can be omitted. Otherwise, it should be set to a command that can be used to proxy SSH connections to the internal IPs of the instances. + + +Cloud NAT Setup +~~~~~~~~~~~~~~~~ + +Instances created with internal IPs only on GCP cannot access public internet by default. To make sure SkyPilot can install the dependencies correctly on the instances, +cloud NAT needs to be setup for the VPC (see `GCP's documentation `__ for details). + + +Cloud NAT is a regional resource, so it will need to be created in each region that SkyPilot will be used in. + + +.. 
image:: ../../images/screenshots/gcp/cloud-nat.png + :width: 80% + :align: center + :alt: GCP Cloud NAT + +To limit SkyPilot to use some specific regions only, you can specify the ``gcp.ssh_proxy_command`` to be a dict mapping from region to the SSH proxy command for that region (see :ref:`config-yaml` for details): + +.. code-block:: yaml + + gcp: + use_internal_ips: true + vpc_name: my-vpc-name + ssh_proxy_command: + us-west1: ssh -W %h:%p -o StrictHostKeyChecking=no myself@my.us-west1.proxy + us-east1: ssh -W %h:%p -o StrictHostKeyChecking=no myself@my.us-west2.proxy + +If proxy is not needed, but the regions need to be limited, you can set the ``gcp.ssh_proxy_command`` to be a dict mapping from region to ``null``: + +.. code-block:: yaml + + gcp: + use_internal_ips: true + vpc_name: my-vpc-name + ssh_proxy_command: + us-west1: null + us-east1: null diff --git a/docs/source/images/screenshots/gcp/cloud-nat.png b/docs/source/images/screenshots/gcp/cloud-nat.png new file mode 100644 index 00000000000..951056d5659 Binary files /dev/null and b/docs/source/images/screenshots/gcp/cloud-nat.png differ diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 51f8ef92c10..16a22350209 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -121,6 +121,31 @@ Available fields and semantics: # will be added. vpc_name: skypilot-vpc + # Should instances be assigned private IPs only? (optional) + # + # Set to true to use private IPs to communicate between the local client and + # any SkyPilot nodes. This requires the networking stack be properly set up. + # + # This flag is typically set together with 'vpc_name' above and + # 'ssh_proxy_command' below. + # + # Default: false. + use_internal_ips: true + # SSH proxy command (optional). + # + # Please refer to the aws.ssh_proxy_command section above for more details. + ### Format 1 ### + # A string; the same proxy command is used for all regions. 
+ ssh_proxy_command: ssh -W %h:%p -i ~/.ssh/sky-key -o StrictHostKeyChecking=no gcpuser@ + ### Format 2 ### + # A dict mapping region names to region-specific proxy commands. + # NOTE: This restricts SkyPilot's search space for this cloud to only use + # the specified regions and not any other regions in this cloud. + ssh_proxy_command: + us-central1: ssh -W %h:%p -p 1234 -o StrictHostKeyChecking=no myself@my.us-central1.proxy + us-west1: ssh -W %h:%p -i ~/.ssh/sky-key -o StrictHostKeyChecking=no gcpuser@ + + # Reserved capacity (optional). # # The specific reservation to be considered when provisioning clusters on GCP. diff --git a/examples/job_queue/job.yaml b/examples/job_queue/job.yaml index c54e3d9a173..aa9c3502247 100644 --- a/examples/job_queue/job.yaml +++ b/examples/job_queue/job.yaml @@ -17,7 +17,7 @@ setup: | run: | timestamp=$(date +%s) conda env list - for i in {1..120}; do + for i in {1..140}; do echo "$timestamp $i" sleep 1 done diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index d226828bba8..59de2aba731 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -48,6 +48,7 @@ from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import env_options +from sky.utils import remote_cluster_yaml_utils from sky.utils import rich_utils from sky.utils import subprocess_utils from sky.utils import timeline @@ -64,7 +65,6 @@ # NOTE: keep in sync with the cluster template 'file_mounts'. SKY_REMOTE_APP_DIR = '~/.sky/sky_app' -SKY_RAY_YAML_REMOTE_PATH = '~/.sky/sky_ray.yml' # Exclude subnet mask from IP address regex. IP_ADDR_REGEX = r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}(?!/\d{1,2})\b' SKY_REMOTE_PATH = '~/.sky/wheels' @@ -1013,18 +1013,17 @@ def write_cluster_config( # If the current code is run by controller, propagate the real # calling user which should've been passed in as the # SKYPILOT_USER env var (see - # execution.py::_shared_controller_env_vars). 
+ # controller_utils.shared_controller_vars_to_fill(). 'user': get_cleaned_username( os.environ.get(constants.USER_ENV_VAR, '')), - # AWS only: - 'aws_vpc_name': skypilot_config.get_nested(('aws', 'vpc_name'), - None), + # Networking configs 'use_internal_ips': skypilot_config.get_nested( - ('aws', 'use_internal_ips'), False), - # Not exactly AWS only, but we only test it's supported on AWS - # for now: + (str(cloud).lower(), 'use_internal_ips'), False), 'ssh_proxy_command': ssh_proxy_command, + 'vpc_name': skypilot_config.get_nested( + (str(cloud).lower(), 'vpc_name'), None), + # User-supplied instance tags. 'instance_tags': instance_tags, @@ -1033,8 +1032,6 @@ def write_cluster_config( 'resource_group': f'{cluster_name}-{region_name}', # GCP only: - 'gcp_vpc_name': skypilot_config.get_nested(('gcp', 'vpc_name'), - None), 'gcp_project_id': gcp_project_id, 'specific_reservations': filtered_specific_reservations, 'num_specific_reserved_workers': num_specific_reserved_workers, @@ -1057,7 +1054,8 @@ def write_cluster_config( 'sky_remote_path': SKY_REMOTE_PATH, 'sky_local_path': str(local_wheel_path), # Add yaml file path to the template variables. 
- 'sky_ray_yaml_remote_path': SKY_RAY_YAML_REMOTE_PATH, + 'sky_ray_yaml_remote_path': + remote_cluster_yaml_utils.SKY_CLUSTER_YAML_REMOTE_PATH, 'sky_ray_yaml_local_path': tmp_yaml_path if not isinstance(cloud, clouds.Local) else yaml_path, diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 7fc596d8457..6007bf37726 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -17,7 +17,7 @@ import threading import time import typing -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import colorama import filelock @@ -33,7 +33,6 @@ from sky import resources as resources_lib from sky import serve as serve_lib from sky import sky_logging -from sky import skypilot_config from sky import spot as spot_lib from sky import status_lib from sky import task as task_lib @@ -1460,7 +1459,7 @@ def _retry_zones( cloud_user_identity: Optional[List[str]], prev_cluster_status: Optional[status_lib.ClusterStatus], prev_handle: Optional['CloudVmRayResourceHandle'], - ): + ) -> Dict[str, Any]: """The provision retry loop.""" style = colorama.Style fore = colorama.Fore @@ -2180,7 +2179,7 @@ def provision_with_retries( to_provision_config: ToProvisionConfig, dryrun: bool, stream_logs: bool, - ): + ) -> Dict[str, Any]: """Provision with retries for all launchable resources.""" cluster_name = to_provision_config.cluster_name to_provision = to_provision_config.resources @@ -2218,7 +2217,7 @@ def provision_with_retries( prev_cluster_status=prev_cluster_status, prev_handle=prev_handle) if dryrun: - return + return config_dict except (exceptions.InvalidClusterNameError, exceptions.NotSupportedError, exceptions.CloudUserIdentityError) as e: @@ -2343,8 +2342,9 @@ def __init__(self, self.cluster_name_on_cloud = cluster_name_on_cloud self._cluster_yaml = cluster_yaml.replace(os.path.expanduser('~'), '~', 1) - # List of 
(internal_ip, external_ip) tuples for all the nodes - # in the cluster, sorted by the external ips. + # List of (internal_ip, feasible_ip) tuples for all the nodes in the + # cluster, sorted by the feasible ips. The feasible ips can be either + # internal or external ips, depending on the use_internal_ips flag. self.stable_internal_external_ips = stable_internal_external_ips self.stable_ssh_ports = stable_ssh_ports self.launched_nodes = launched_nodes @@ -2374,6 +2374,14 @@ def __repr__(self): def get_cluster_name(self): return self.cluster_name + def _use_internal_ips(self): + """Returns whether to use internal IPs for SSH connections.""" + # Directly load the `use_internal_ips` flag from the cluster yaml + # instead of `skypilot_config` as the latter can be changed after the + # cluster is UP. + return common_utils.read_yaml(self.cluster_yaml).get( + 'provider', {}).get('use_internal_ips', False) + def _maybe_make_local_handle(self): """Adds local handle for the local cloud case. @@ -2481,40 +2489,43 @@ def is_provided_ips_valid(ips: Optional[List[Optional[str]]]) -> bool: return (ips is not None and len(ips) == self.num_node_ips and all(ip is not None for ip in ips)) + use_internal_ips = self._use_internal_ips() + + # cluster_feasible_ips is the list of IPs of the nodes in the cluster + # which can be used to connect to the cluster. It is a list of external + # IPs if the cluster is assigned public IPs, otherwise it is a list of + # internal IPs. 
+ cluster_feasible_ips: List[str] if is_provided_ips_valid(external_ips): logger.debug(f'Using provided external IPs: {external_ips}') - cluster_external_ips = typing.cast(List[str], external_ips) + cluster_feasible_ips = typing.cast(List[str], external_ips) else: - cluster_external_ips = backend_utils.get_node_ips( + cluster_feasible_ips = backend_utils.get_node_ips( self.cluster_yaml, self.launched_nodes, handle=self, head_ip_max_attempts=max_attempts, worker_ip_max_attempts=max_attempts, - get_internal_ips=False) + get_internal_ips=use_internal_ips) - if self.cached_external_ips == cluster_external_ips: + if self.cached_external_ips == cluster_feasible_ips: logger.debug('Skipping the fetching of internal IPs as the cached ' 'external IPs matches the newly fetched ones.') # Optimization: If the cached external IPs are the same as the - # retrieved external IPs, then we can skip retrieving internal + # retrieved feasible IPs, then we can skip retrieving internal # IPs since the cached IPs are up-to-date. return logger.debug( 'Cached external IPs do not match with the newly fetched ones: ' - f'cached ({self.cached_external_ips}), new ({cluster_external_ips})' + f'cached ({self.cached_external_ips}), new ({cluster_feasible_ips})' ) - is_cluster_aws = (self.launched_resources is not None and - isinstance(self.launched_resources.cloud, clouds.AWS)) - if is_cluster_aws and skypilot_config.get_nested( - keys=('aws', 'use_internal_ips'), default_value=False): + if use_internal_ips: # Optimization: if we know use_internal_ips is True (currently - # only exposed for AWS), then our AWS NodeProvider is - # guaranteed to pick subnets that will not assign public IPs, - # thus the first list of IPs returned above are already private - # IPs. So skip the second query. 
- cluster_internal_ips = list(cluster_external_ips) + # only exposed for AWS and GCP), then our provisioner is guaranteed + # to not assign public IPs, thus the first list of IPs returned + # above are already private IPs. So skip the second query. + cluster_internal_ips = list(cluster_feasible_ips) elif is_provided_ips_valid(internal_ips): logger.debug(f'Using provided internal IPs: {internal_ips}') cluster_internal_ips = typing.cast(List[str], internal_ips) @@ -2527,13 +2538,16 @@ def is_provided_ips_valid(ips: Optional[List[Optional[str]]]) -> bool: worker_ip_max_attempts=max_attempts, get_internal_ips=True) - assert len(cluster_external_ips) == len(cluster_internal_ips), ( + assert len(cluster_feasible_ips) == len(cluster_internal_ips), ( f'Cluster {self.cluster_name!r}:' f'Expected same number of internal IPs {cluster_internal_ips}' - f' and external IPs {cluster_external_ips}.') + f' and external IPs {cluster_feasible_ips}.') + # List of (internal_ip, feasible_ip) tuples for all the nodes in the + # cluster, sorted by the feasible ips. The feasible ips can be either + # internal or external ips, depending on the use_internal_ips flag. 
internal_external_ips: List[Tuple[str, str]] = list( - zip(cluster_internal_ips, cluster_external_ips)) + zip(cluster_internal_ips, cluster_feasible_ips)) # Ensure head node is the first element, then sort based on the # external IPs for stableness @@ -3204,6 +3218,8 @@ def _sync_file_mounts( storage_mounts: Optional[Dict[Path, storage_lib.Storage]], ) -> None: """Mounts all user files to the remote nodes.""" + controller_utils.replace_skypilot_config_path_in_file_mounts( + handle.launched_resources.cloud, all_file_mounts) self._execute_file_mounts(handle, all_file_mounts) self._execute_storage_mounts(handle, storage_mounts) self._set_storage_mounts_metadata(handle.cluster_name, storage_mounts) diff --git a/sky/execution.py b/sky/execution.py index b473df4a95b..9a6f5e98cd3 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -22,7 +22,6 @@ from sky.backends import backend_utils from sky.skylet import constants from sky.usage import usage_lib -from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import dag_utils from sky.utils import env_options @@ -666,21 +665,10 @@ def spot_launch( prefix = spot.SPOT_TASK_YAML_PREFIX remote_user_yaml_path = f'{prefix}/{dag.name}-{dag_uuid}.yaml' remote_user_config_path = f'{prefix}/{dag.name}-{dag_uuid}.config_yaml' - extra_vars, controller_resources_config = ( - controller_utils.skypilot_config_setup( - controller_type='spot', - controller_resources_config=spot.constants.CONTROLLER_RESOURCES, - remote_user_config_path=remote_user_config_path)) - try: - controller_resources = sky.Resources.from_yaml_config( - controller_resources_config) - except ValueError as e: - with ux_utils.print_exception_no_traceback(): - raise ValueError( - controller_utils.CONTROLLER_RESOURCES_NOT_VALID_MESSAGE. 
- format(controller_type='spot', - err=common_utils.format_exception( - e, use_bracket=True))) from e + controller_resources = (controller_utils.get_controller_resources( + controller_type='spot', + controller_resources_config=spot.constants.CONTROLLER_RESOURCES)) + vars_to_fill = { 'remote_user_yaml_path': remote_user_yaml_path, 'user_yaml_path': f.name, @@ -688,7 +676,8 @@ def spot_launch( # Note: actual spot cluster name will be - 'dag_name': dag.name, 'retry_until_up': retry_until_up, - **extra_vars, + 'remote_user_config_path': remote_user_config_path, + **controller_utils.shared_controller_vars_to_fill('spot'), } yaml_path = os.path.join(spot.SPOT_CONTROLLER_YAML_PREFIX, diff --git a/sky/serve/core.py b/sky/serve/core.py index d1c9e1914fc..e285492987f 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -83,28 +83,17 @@ def up( serve_utils.generate_remote_config_yaml_file_name(service_name)) controller_log_file = ( serve_utils.generate_remote_controller_log_file_name(service_name)) - extra_vars, controller_resources_config = ( - controller_utils.skypilot_config_setup( - controller_type='serve', - controller_resources_config=serve_constants. - CONTROLLER_RESOURCES, - remote_user_config_path=remote_config_yaml_path)) - try: - controller_resources = sky.Resources.from_yaml_config( - controller_resources_config) - except ValueError as e: - with ux_utils.print_exception_no_traceback(): - raise ValueError( - controller_utils.CONTROLLER_RESOURCES_NOT_VALID_MESSAGE. 
- format(controller_type='serve', - err=common_utils.format_exception( - e, use_bracket=True))) from e + controller_resources = (controller_utils.get_controller_resources( + controller_type='serve', + controller_resources_config=serve_constants.CONTROLLER_RESOURCES)) + vars_to_fill = { 'remote_task_yaml_path': remote_tmp_task_yaml_path, 'local_task_yaml_path': service_file.name, 'service_name': service_name, 'controller_log_file': controller_log_file, - **extra_vars, + 'remote_user_config_path': remote_config_yaml_path, + **controller_utils.shared_controller_vars_to_fill('serve'), } backend_utils.fill_template(serve_constants.CONTROLLER_TEMPLATE, vars_to_fill, diff --git a/sky/skylet/events.py b/sky/skylet/events.py index ade43104048..4287acd394b 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -11,13 +11,13 @@ import yaml from sky import sky_logging -from sky.backends import backend_utils from sky.backends import cloud_vm_ray_backend from sky.serve import serve_utils from sky.skylet import autostop_lib from sky.skylet import job_lib from sky.spot import spot_utils from sky.utils import common_utils +from sky.utils import remote_cluster_yaml_utils from sky.utils import ux_utils # Seconds of sleep between the processing of skylet events. 
@@ -104,8 +104,8 @@ class AutostopEvent(SkyletEvent): def __init__(self): super().__init__() autostop_lib.set_last_active_time_to_now() - self._ray_yaml_path = os.path.abspath( - os.path.expanduser(backend_utils.SKY_RAY_YAML_REMOTE_PATH)) + self._ray_yaml_path = ( + remote_cluster_yaml_utils.get_cluster_yaml_absolute_path()) def _run(self): autostop_config = autostop_lib.get_autostop_config() @@ -139,16 +139,9 @@ def _stop_cluster(self, autostop_config): cloud_vm_ray_backend.CloudVmRayBackend.NAME): autostop_lib.set_autostopping_started() - config = common_utils.read_yaml(self._ray_yaml_path) + config = remote_cluster_yaml_utils.load_cluster_yaml() + provider_name = remote_cluster_yaml_utils.get_provider_name(config) - provider_module = config['provider']['module'] - # Examples: - # 'sky.skylet.providers.aws.AWSNodeProviderV2' -> 'aws' - # 'sky.provision.aws' -> 'aws' - provider_search = re.search(r'(?:providers|provision)\.(\w+)\.?', - provider_module) - assert provider_search is not None, config - provider_name = provider_search.group(1).lower() if provider_name in ('aws', 'gcp'): logger.info('Using new provisioner to stop the cluster.') self._stop_cluster_with_new_provisioner(autostop_config, config, diff --git a/sky/skylet/providers/gcp/config.py b/sky/skylet/providers/gcp/config.py index 68c535d73fd..781737c6cfc 100644 --- a/sky/skylet/providers/gcp/config.py +++ b/sky/skylet/providers/gcp/config.py @@ -910,6 +910,9 @@ def _configure_subnet(config, compute): ], } ] + if config["provider"].get("use_internal_ips", False): + # Removing this key means the VM will not be assigned an external IP. 
+ default_interfaces[0].pop("accessConfigs") for node_config in node_configs: # The not applicable key will be removed during node creation @@ -920,7 +923,10 @@ def _configure_subnet(config, compute): # TPU if "networkConfig" not in node_config: node_config["networkConfig"] = copy.deepcopy(default_interfaces)[0] - node_config["networkConfig"].pop("accessConfigs") + # TPU doesn't have accessConfigs + node_config["networkConfig"].pop("accessConfigs", None) + if config["provider"].get("use_internal_ips", False): + node_config["networkConfig"]["enableExternalIps"] = False return config diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 8c9c6fd5e5a..78fe4cf6ca5 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -49,7 +49,6 @@ import yaml from sky import sky_logging -from sky.clouds import cloud_registry from sky.utils import common_utils from sky.utils import schemas @@ -74,6 +73,7 @@ # The loaded config. _dict = None +_loaded_config_path = None def get_nested(keys: Iterable[str], default_value: Any) -> Any: @@ -127,27 +127,8 @@ def to_dict() -> Dict[str, Any]: return {} -def _syntax_check_for_ssh_proxy_command(cloud: str) -> None: - ssh_proxy_command_config = get_nested((cloud.lower(), 'ssh_proxy_command'), - None) - if ssh_proxy_command_config is None or isinstance(ssh_proxy_command_config, - str): - return - - if isinstance(ssh_proxy_command_config, dict): - for region, cmd in ssh_proxy_command_config.items(): - if cmd and not isinstance(cmd, str): - raise ValueError( - f'Invalid ssh_proxy_command config for region {region!r} ' - f'(expected a str): {cmd!r}') - return - raise ValueError( - 'Invalid ssh_proxy_command config (expected a str or a dict with ' - f'region names as keys): {ssh_proxy_command_config!r}') - - def _try_load_config() -> None: - global _dict + global _dict, _loaded_config_path config_path_via_env_var = os.environ.get(ENV_VAR_SKYPILOT_CONFIG) if config_path_via_env_var is not None: config_path = 
config_path_via_env_var @@ -156,6 +137,7 @@ def _try_load_config() -> None: config_path = os.path.expanduser(config_path) if os.path.exists(config_path): logger.debug(f'Using config path: {config_path}') + _loaded_config_path = config_path try: _dict = common_utils.read_yaml(config_path) logger.debug(f'Config loaded:\n{pprint.pformat(_dict)}') @@ -168,8 +150,6 @@ def _try_load_config() -> None: f'Invalid config YAML ({config_path}): ', skip_none=False) - for cloud in cloud_registry.CLOUD_REGISTRY: - _syntax_check_for_ssh_proxy_command(cloud) logger.debug('Config syntax check passed.') diff --git a/sky/task.py b/sky/task.py index b2af3b1ae24..d7e22323aa1 100644 --- a/sky/task.py +++ b/sky/task.py @@ -759,8 +759,9 @@ def set_file_mounts(self, file_mounts: Optional[Dict[str, str]]) -> 'Task': raise ValueError( 'File mount destination paths cannot be cloud storage') if not data_utils.is_cloud_store_url(source): - if not os.path.exists( - os.path.abspath(os.path.expanduser(source))): + if (not os.path.exists( + os.path.abspath(os.path.expanduser(source))) and + not source.startswith('skypilot:')): with ux_utils.print_exception_no_traceback(): raise ValueError( f'File mount source {source!r} does not exist ' diff --git a/sky/templates/aws-ray.yml.j2 b/sky/templates/aws-ray.yml.j2 index 1494a2c7060..efb95799366 100644 --- a/sky/templates/aws-ray.yml.j2 +++ b/sky/templates/aws-ray.yml.j2 @@ -36,9 +36,9 @@ provider: security_group: # AWS config file must include security group name GroupName: {{security_group}} -{% if aws_vpc_name is not none %} +{% if vpc_name is not none %} # NOTE: This is a new field added by SkyPilot to force use a specific VPC. 
- vpc_name: {{aws_vpc_name}} + vpc_name: {{vpc_name}} {% endif %} {%- if docker_login_config is not none %} # We put docker login config in provider section because ray's schema disabled diff --git a/sky/templates/gcp-ray.yml.j2 b/sky/templates/gcp-ray.yml.j2 index ca73199e317..ab62e4f5413 100644 --- a/sky/templates/gcp-ray.yml.j2 +++ b/sky/templates/gcp-ray.yml.j2 @@ -28,9 +28,9 @@ provider: cache_stopped_nodes: True # The GCP project ID. project_id: {{gcp_project_id}} -{% if gcp_vpc_name is not none %} +{% if vpc_name is not none %} # NOTE: This is a new field added by SkyPilot to force use a specific VPC. - vpc_name: {{gcp_vpc_name}} + vpc_name: {{vpc_name}} {% endif %} # The firewall rule name for customized firewall rules. Only enabled # if we have ports requirement. @@ -46,6 +46,7 @@ provider: password: {{docker_login_config.password}} server: {{docker_login_config.server}} {%- endif %} + use_internal_ips: {{use_internal_ips}} {%- if tpu_vm %} _has_tpus: True {%- endif %} @@ -59,6 +60,9 @@ provider: auth: ssh_user: gcpuser ssh_private_key: {{ssh_private_key}} +{% if ssh_proxy_command is not none %} + ssh_proxy_command: {{ssh_proxy_command}} +{% endif %} available_node_types: ray_head_default: diff --git a/sky/templates/sky-serve-controller.yaml.j2 b/sky/templates/sky-serve-controller.yaml.j2 index 362bd4a93da..d49412fb9cd 100644 --- a/sky/templates/sky-serve-controller.yaml.j2 +++ b/sky/templates/sky-serve-controller.yaml.j2 @@ -16,9 +16,7 @@ setup: | file_mounts: {{remote_task_yaml_path}}: {{local_task_yaml_path}} -{% if user_config_path is not none %} - {{remote_user_config_path}}: {{user_config_path}} -{% endif %} + {{remote_user_config_path}}: skypilot:local_skypilot_config_path run: | # Start sky serve service. 
diff --git a/sky/templates/spot-controller.yaml.j2 b/sky/templates/spot-controller.yaml.j2 index ee9d863e83f..5181f9d4544 100644 --- a/sky/templates/spot-controller.yaml.j2 +++ b/sky/templates/spot-controller.yaml.j2 @@ -4,9 +4,7 @@ name: {{dag_name}} file_mounts: {{remote_user_yaml_path}}: {{user_yaml_path}} -{% if user_config_path is not none %} - {{remote_user_config_path}}: {{user_config_path}} -{% endif %} + {{remote_user_config_path}}: skypilot:local_skypilot_config_path setup: | {%- for cmd in cloud_dependencies_installation_commands %} @@ -23,6 +21,7 @@ setup: | ((ps aux | grep -v nohup | grep -v grep | grep -q -- "python3 -m sky.spot.dashboard.dashboard") || (nohup python3 -m sky.spot.dashboard.dashboard >> ~/.sky/spot-dashboard.log 2>&1 &)); run: | + # Start the controller for the current spot job. python -u -m sky.spot.controller {{remote_user_yaml_path}} \ --job-id $SKYPILOT_INTERNAL_JOB_ID {% if retry_until_up %}--retry-until-up{% endif %} diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 00f39e3756a..bdf001965ff 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -6,11 +6,12 @@ import os import tempfile import typing -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import colorama from sky import exceptions +from sky import resources from sky import sky_logging from sky import skypilot_config from sky.clouds import gcp @@ -24,6 +25,7 @@ from sky.utils import ux_utils if typing.TYPE_CHECKING: + from sky import clouds from sky import task as task_lib from sky.backends import cloud_vm_ray_backend @@ -37,6 +39,9 @@ '{controller_type}.controller.resources is a valid resources spec. ' 'Details:\n {err}') +# The placeholder for the local skypilot config path in file mounts. 
+LOCAL_SKYPILOT_CONFIG_PATH_PLACEHOLDER = 'skypilot:local_skypilot_config_path' + @dataclasses.dataclass class _ControllerSpec: @@ -220,7 +225,11 @@ def download_and_stream_latest_job_log( return log_file -def _shared_controller_env_vars() -> Dict[str, str]: +def shared_controller_vars_to_fill(controller_type: str) -> Dict[str, str]: + vars_to_fill: Dict[str, Any] = { + 'cloud_dependencies_installation_commands': + _get_cloud_dependencies_installation_commands(controller_type) + } env_vars: Dict[str, str] = { env.value: '1' for env in env_options.Options if env.get() } @@ -232,14 +241,14 @@ def _shared_controller_env_vars() -> Dict[str, str]: # Skip cloud identity check to avoid the overhead. env_options.Options.SKIP_CLOUD_IDENTITY_CHECK.value: '1', }) - return env_vars + vars_to_fill['controller_envs'] = env_vars + return vars_to_fill -def skypilot_config_setup( +def get_controller_resources( controller_type: str, controller_resources_config: Dict[str, Any], - remote_user_config_path: str, -) -> Tuple[Dict[str, Any], Dict[str, Any]]: +) -> resources.Resources: """Read the skypilot config and setup the controller resources. Returns: @@ -248,67 +257,9 @@ def skypilot_config_setup( The controller_resources_config is the resources config that will be used to launch the controller. """ - vars_to_fill: Dict[str, Any] = { - 'cloud_dependencies_installation_commands': - _get_cloud_dependencies_installation_commands(controller_type) - } - controller_envs = _shared_controller_env_vars() controller_resources_config_copied: Dict[str, Any] = copy.copy( controller_resources_config) if skypilot_config.loaded(): - # Look up the contents of the already loaded configs via the - # 'skypilot_config' module. Don't simply read the on-disk file as - # it may have changed since this process started. - # - # Set any proxy command to None, because the controller would've - # been launched behind the proxy, and in general any nodes we - # launch may not have or need the proxy setup. 
(If the controller - # needs to launch mew clusters in another region/VPC, the user - # should properly set up VPC peering, which will allow the - # cross-region/VPC communication. The proxy command is orthogonal - # to this scenario.) - # - # This file will be uploaded to the controller node and will be - # used throughout the spot job's / service's recovery attempts - # (i.e., if it relaunches due to preemption, we make sure the - # same config is used). - # - # NOTE: suppose that we have a controller in old VPC, then user - # changes 'vpc_name' in the config and does a 'spot launch' / - # 'serve up'. In general, the old controller may not successfully - # launch the job in the new VPC. This happens if the two VPCs don’t - # have peering set up. Like other places in the code, we assume - # properly setting up networking is user's responsibilities. - # TODO(zongheng): consider adding a basic check that checks - # controller VPC (or name) == the spot job's / service's VPC - # (or name). It may not be a sufficient check (as it's always - # possible that peering is not set up), but it may catch some - # obvious errors. - # TODO(zhwu): hacky. We should only set the proxy command of the - # cloud where the controller is launched (currently, only aws user - # uses proxy_command). - proxy_command_key = ('aws', 'ssh_proxy_command') - ssh_proxy_command = skypilot_config.get_nested(proxy_command_key, None) - config_dict = skypilot_config.to_dict() - if isinstance(ssh_proxy_command, str): - config_dict = skypilot_config.set_nested(proxy_command_key, None) - elif isinstance(ssh_proxy_command, dict): - # Instead of removing the key, we set the value to empty string - # so that the controller will only try the regions specified by - # the keys. 
- ssh_proxy_command = {k: None for k in ssh_proxy_command} - config_dict = skypilot_config.set_nested(proxy_command_key, - ssh_proxy_command) - - with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmpfile: - common_utils.dump_yaml(tmpfile.name, config_dict) - controller_envs[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = ( - remote_user_config_path) - vars_to_fill.update({ - 'user_config_path': tmpfile.name, - 'remote_user_config_path': remote_user_config_path, - }) - # Override the controller resources with the ones specified in the # config. custom_controller_resources_config = skypilot_config.get_nested( @@ -316,13 +267,95 @@ def skypilot_config_setup( if custom_controller_resources_config is not None: controller_resources_config_copied.update( custom_controller_resources_config) - else: - # If the user config is not loaded, manually set this to None - # so that the template won't render this. - vars_to_fill['user_config_path'] = None - vars_to_fill['controller_envs'] = controller_envs - return vars_to_fill, controller_resources_config_copied + try: + controller_resources = resources.Resources.from_yaml_config( + controller_resources_config_copied) + except ValueError as e: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + CONTROLLER_RESOURCES_NOT_VALID_MESSAGE.format( + controller_type=controller_type, + err=common_utils.format_exception(e, + use_bracket=True))) from e + + return controller_resources + + +def _setup_proxy_command_on_controller( + controller_launched_cloud: 'clouds.Cloud') -> Dict[str, Any]: + """Sets up proxy command on the controller. + + This function should be called on the controller (remote cluster), which + has the `~/.sky/sky_ray.yaml` file. + """ + # Look up the contents of the already loaded configs via the + # 'skypilot_config' module. Don't simply read the on-disk file as + # it may have changed since this process started. 
+    #
+    # Set any proxy command to None, because the controller would've
+    # been launched behind the proxy, and in general any nodes we
+    # launch may not have or need the proxy setup. (If the controller
+    # needs to launch new clusters in another region/VPC, the user
+    # should properly set up VPC peering, which will allow the
+    # cross-region/VPC communication. The proxy command is orthogonal
+    # to this scenario.)
+    #
+    # This file will be uploaded to the controller node and will be
+    # used throughout the spot job's / service's recovery attempts
+    # (i.e., if it relaunches due to preemption, we make sure the
+    # same config is used).
+    #
+    # NOTE: suppose that we have a controller in old VPC, then user
+    # changes 'vpc_name' in the config and does a 'spot launch' /
+    # 'serve up'. In general, the old controller may not successfully
+    # launch the job in the new VPC. This happens if the two VPCs don’t
+    # have peering set up. Like other places in the code, we assume
+    # properly setting up networking is the user's responsibility.
+    # TODO(zongheng): consider adding a basic check that checks
+    # controller VPC (or name) == the spot job's / service's VPC
+    # (or name). It may not be a sufficient check (as it's always
+    # possible that peering is not set up), but it may catch some
+    # obvious errors.
+    proxy_command_key = (str(controller_launched_cloud).lower(),
+                         'ssh_proxy_command')
+    ssh_proxy_command = skypilot_config.get_nested(proxy_command_key, None)
+    config_dict = skypilot_config.to_dict()
+    if isinstance(ssh_proxy_command, str):
+        config_dict = skypilot_config.set_nested(proxy_command_key, None)
+    elif isinstance(ssh_proxy_command, dict):
+        # Instead of removing the key, we set the value to None
+        # so that the controller will only try the regions specified by
+        # the keys.
+ ssh_proxy_command = {k: None for k in ssh_proxy_command} + config_dict = skypilot_config.set_nested(proxy_command_key, + ssh_proxy_command) + + return config_dict + + +def replace_skypilot_config_path_in_file_mounts( + cloud: 'clouds.Cloud', file_mounts: Optional[Dict[str, str]]): + """Replaces the SkyPilot config path in file mounts with the real path.""" + # TODO(zhwu): This function can be moved to `backend_utils` once we have + # more predefined file mounts that needs to be replaced after the cluster + # is provisioned, e.g., we may need to decide which cloud to create a bucket + # to be mounted to the cluster based on the cloud the cluster is actually + # launched on (after failover). + if file_mounts is None or not skypilot_config.loaded(): + return + replaced = False + with tempfile.NamedTemporaryFile('w', delete=False) as f: + new_skypilot_config = _setup_proxy_command_on_controller(cloud) + if new_skypilot_config is not None: + common_utils.dump_yaml(f.name, new_skypilot_config) + for remote_path, local_path in file_mounts.items(): + if local_path == LOCAL_SKYPILOT_CONFIG_PATH_PLACEHOLDER: + file_mounts[remote_path] = f.name + replaced = True + if replaced: + logger.debug(f'Replaced {LOCAL_SKYPILOT_CONFIG_PATH_PLACEHOLDER} with ' + f'the real path in file mounts: {file_mounts}') def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', diff --git a/sky/utils/remote_cluster_yaml_utils.py b/sky/utils/remote_cluster_yaml_utils.py new file mode 100644 index 00000000000..6f1a044953e --- /dev/null +++ b/sky/utils/remote_cluster_yaml_utils.py @@ -0,0 +1,37 @@ +"""Utility functions for cluster yaml file on remote cluster. + +This module should only be used on the remote cluster. +""" + +import os +import re + +from sky.utils import common_utils + +# The cluster yaml used to create the current cluster where the module is +# called. 
+SKY_CLUSTER_YAML_REMOTE_PATH = '~/.sky/sky_ray.yml' + + +def get_cluster_yaml_absolute_path() -> str: + """Return the absolute path of the cluster yaml file.""" + return os.path.abspath(os.path.expanduser(SKY_CLUSTER_YAML_REMOTE_PATH)) + + +def load_cluster_yaml() -> dict: + """Load the cluster yaml file.""" + return common_utils.read_yaml(get_cluster_yaml_absolute_path()) + + +def get_provider_name(config: dict) -> str: + """Return the name of the provider.""" + + provider_module = config['provider']['module'] + # Examples: + # 'sky.skylet.providers.aws.AWSNodeProviderV2' -> 'aws' + # 'sky.provision.aws' -> 'aws' + provider_search = re.search(r'(?:providers|provision)\.(\w+)\.?', + provider_module) + assert provider_search is not None, config + provider_name = provider_search.group(1).lower() + return provider_name diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 2fa7614e614..c1646b3fa77 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -458,6 +458,40 @@ def get_cluster_schema(): } +_NETWORK_CONFIG_SCHEMA = { + 'vpc_name': { + 'oneOf': [{ + 'type': 'string', + }, { + 'type': 'null', + }], + }, + 'use_internal_ips': { + 'type': 'boolean', + }, + 'ssh_proxy_command': { + 'oneOf': [{ + 'type': 'string', + }, { + 'type': 'null', + }, { + 'type': 'object', + 'required': [], + 'additionalProperties': { + 'anyOf': [ + { + 'type': 'string' + }, + { + 'type': 'null' + }, + ] + } + }] + }, +} + + def get_config_schema(): # pylint: disable=import-outside-toplevel from sky.utils import kubernetes_enums @@ -505,36 +539,7 @@ def get_config_schema(): 'type': 'string', }, }, - 'vpc_name': { - 'oneOf': [{ - 'type': 'string', - }, { - 'type': 'null', - }], - }, - 'use_internal_ips': { - 'type': 'boolean', - }, - 'ssh_proxy_command': { - 'oneOf': [{ - 'type': 'string', - }, { - 'type': 'null', - }, { - 'type': 'object', - 'required': [], - 'additionalProperties': { - 'anyOf': [ - { - 'type': 'string' - }, - { - 'type': 'null' - }, - ] - } - }] - }, + 
**_NETWORK_CONFIG_SCHEMA, } }, 'gcp': { @@ -550,13 +555,7 @@ def get_config_schema(): 'minItems': 1, 'maxItems': 1, }, - 'vpc_name': { - 'oneOf': [{ - 'type': 'string', - }, { - 'type': 'null', - }], - }, + **_NETWORK_CONFIG_SCHEMA, } }, 'kubernetes': { diff --git a/tests/test_config.py b/tests/test_config.py index dfe64f77f06..afa85cedf29 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -40,6 +40,7 @@ def _create_config_file(config_file_path: pathlib.Path) -> None: gcp: vpc_name: {VPC_NAME} + use_internal_ips: true kubernetes: networking: {NODEPORT_MODE_NAME} @@ -195,6 +196,7 @@ def test_config_get_set_nested(monkeypatch, tmp_path) -> None: assert skypilot_config.get_nested( ('aws', 'ssh_proxy_command'), None) is None assert skypilot_config.get_nested(('gcp', 'vpc_name'), None) == VPC_NAME + assert skypilot_config.get_nested(('gcp', 'use_internal_ips'), None) # Check config with only partial keys still works new_config3 = copy.copy(new_config2) @@ -230,3 +232,4 @@ def test_config_with_env(monkeypatch, tmp_path) -> None: assert skypilot_config.get_nested(('aws', 'ssh_proxy_command'), None) == PROXY_COMMAND assert skypilot_config.get_nested(('gcp', 'vpc_name'), None) == VPC_NAME + assert skypilot_config.get_nested(('gcp', 'use_internal_ips'), None) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 0ca34066f72..3c6df59faf1 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -194,12 +194,19 @@ def get_aws_region_for_quota_failover() -> Optional[str]: use_spot=True, region=None, zone=None) + original_resources = sky.Resources(cloud=sky.AWS(), + instance_type='p3.16xlarge', + use_spot=True) + + # Filter the regions with proxy command in ~/.sky/config.yaml. 
+ filtered_regions = original_resources.get_valid_regions_for_launchable() + candidate_regions = [ + region for region in candidate_regions + if region.name in filtered_regions + ] for region in candidate_regions: - resources = sky.Resources(cloud=sky.AWS(), - instance_type='p3.16xlarge', - region=region.name, - use_spot=True) + resources = original_resources.copy(region=region.name) if not AWS.check_quota_available(resources): return region.name @@ -214,12 +221,21 @@ def get_gcp_region_for_quota_failover() -> Optional[str]: region=None, zone=None) + original_resources = sky.Resources(cloud=sky.GCP(), + instance_type='a2-ultragpu-1g', + accelerators={'A100-80GB': 1}, + use_spot=True) + + # Filter the regions with proxy command in ~/.sky/config.yaml. + filtered_regions = original_resources.get_valid_regions_for_launchable() + candidate_regions = [ + region for region in candidate_regions + if region.name in filtered_regions + ] + for region in candidate_regions: if not GCP.check_quota_available( - sky.Resources(cloud=sky.GCP(), - region=region.name, - accelerators={'A100-80GB': 1}, - use_spot=True)): + original_resources.copy(region=region.name)): return region.name return None @@ -1226,7 +1242,7 @@ def test_large_job_queue(generic_cloud: str): ], ], f'sky down -y {name}', - timeout=20 * 60, + timeout=25 * 60, ) run_one_test(test) @@ -1611,7 +1627,7 @@ def test_autostop(generic_cloud: str): f'sky status | grep {name} | grep "1m"', # Ensure the cluster is not stopped early. - 'sleep 45', + 'sleep 30', f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', # Ensure the cluster is STOPPED. @@ -1669,7 +1685,7 @@ def test_autodown(generic_cloud: str): # Ensure autostop is set. f'sky status | grep {name} | grep "1m (down)"', # Ensure the cluster is not terminated early. - 'sleep 45', + 'sleep 30', f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', # Ensure the cluster is terminated. 
f'sleep {autodown_timeout}', diff --git a/tests/test_yamls/test_multiple_accelerators_unordered.yaml b/tests/test_yamls/test_multiple_accelerators_unordered.yaml index db0fc9c5f7c..3bb26c197ce 100644 --- a/tests/test_yamls/test_multiple_accelerators_unordered.yaml +++ b/tests/test_yamls/test_multiple_accelerators_unordered.yaml @@ -1,7 +1,7 @@ name: multi-accelerators-unordered resources: - accelerators: {'A100-40GB:1', 'T4:1', 'V100:1', 'K80:1'} + accelerators: {'A100-40GB:1', 'T4:1', 'V100:1'} run: | nvidia-smi