From 4f5b1073e1280b68bf649d8103b08deb994a926a Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Wed, 4 Oct 2023 15:34:33 -0700 Subject: [PATCH] UX: Allow inferring cloud from region or zone. (#2632) * UX: Allow infering cloud from region or zone. * format * minor fix * Remove Local cloud from registry. * UX for 1-cloud cases * Format * Fix test fixtures. * isort --- sky/clouds/cloud.py | 13 +++- sky/clouds/cloud_registry.py | 9 +++ sky/clouds/kubernetes.py | 11 ++- sky/clouds/local.py | 7 +- sky/clouds/service_catalog/common.py | 9 ++- sky/resources.py | 60 ++++++++++++--- tests/common.py | 62 +++++++++++++++ tests/conftest.py | 62 ++------------- tests/test_optimizer_dryruns.py | 110 ++++++++++++++++----------- 9 files changed, 227 insertions(+), 116 deletions(-) create mode 100644 tests/common.py diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index d9aeca0cca2..29eb36de54f 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -410,8 +410,17 @@ def instance_type_exists(self, instance_type): """Returns whether the instance type exists for this cloud.""" raise NotImplementedError - def validate_region_zone(self, region: Optional[str], zone: Optional[str]): - """Validates the region and zone.""" + def validate_region_zone( + self, region: Optional[str], + zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]: + """Validates whether region and zone exist in the catalog. + + Returns: + A tuple of region and zone, if validated. + + Raises: + ValueError: If region or zone is invalid or not supported. + """ return service_catalog.validate_region_zone(region, zone, clouds=self._REPR.lower()) diff --git a/sky/clouds/cloud_registry.py b/sky/clouds/cloud_registry.py index 5c4b10b9fd4..f3b8ad8f6d2 100644 --- a/sky/clouds/cloud_registry.py +++ b/sky/clouds/cloud_registry.py @@ -15,6 +15,15 @@ class _CloudRegistry(dict): def from_str(self, name: Optional[str]) -> Optional['cloud.Cloud']: if name is None: return None + if name.lower() == 'local': + # Backward compatibility. global_user_state's DB may have recorded + # Local cloud, and we've just removed it from the registry, and + # global_user_state.get_enabled_clouds() would call into this func + # and fail. + # + # TODO(skypilot): have a better way to handle clouds removed from + # registry if needed. + return None if name.lower() not in self: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Cloud {name!r} is not a valid cloud among ' diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index e4355f4bd30..fede03c9044 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -46,7 +46,8 @@ class Kubernetes(clouds.Cloud): _DEFAULT_MEMORY_CPU_RATIO = 1 _DEFAULT_MEMORY_CPU_RATIO_WITH_GPU = 4 # Allocate more memory for GPU tasks _REPR = 'Kubernetes' - _regions: List[clouds.Region] = [clouds.Region('kubernetes')] + _SINGLETON_REGION = 'kubernetes' + _regions: List[clouds.Region] = [clouds.Region(_SINGLETON_REGION)] _CLOUD_UNSUPPORTED_FEATURES = { # TODO(romilb): Stopping might be possible to implement with # container checkpointing introduced in Kubernetes v1.25. See: @@ -325,7 +326,13 @@ def instance_type_exists(self, instance_type: str) -> bool: instance_type) def validate_region_zone(self, region: Optional[str], zone: Optional[str]): - # Kubernetes doesn't have regions or zones, so we don't need to validate + if region != self._SINGLETON_REGION: + raise ValueError( + 'Kubernetes support does not support setting region.' + ' Cluster used is determined by the kubeconfig.') + if zone is not None: + raise ValueError('Kubernetes support does not support setting zone.' + ' Cluster used is determined by the kubeconfig.') return region, zone def accelerator_in_region_or_zone(self, diff --git a/sky/clouds/local.py b/sky/clouds/local.py index 3773583702b..4cd5ab02639 100644 --- a/sky/clouds/local.py +++ b/sky/clouds/local.py @@ -10,7 +10,7 @@ from sky import resources as resources_lib -@clouds.CLOUD_REGISTRY.register +# TODO(skypilot): remove Local now that we're using Kubernetes. class Local(clouds.Cloud): """Local/on-premise cloud. @@ -191,10 +191,11 @@ def instance_type_exists(self, instance_type: str) -> bool: def validate_region_zone(self, region: Optional[str], zone: Optional[str]): # Returns true if the region name is same as Local cloud's # one and only region: 'Local'. - assert zone is None + if zone is not None: + raise ValueError('Local cloud does not support zones.') if region is None or region != Local.LOCAL_REGION.name: raise ValueError(f'Region {region!r} does not match the Local' - ' cloud region {Local.LOCAL_REGION.name!r}.') + f' cloud region {Local.LOCAL_REGION.name!r}.') return region, zone @classmethod diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index a3a668bbf3b..2fea4e79690 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -166,7 +166,14 @@ def instance_type_exists_impl(df: pd.DataFrame, instance_type: str) -> bool: def validate_region_zone_impl( cloud_name: str, df: pd.DataFrame, region: Optional[str], zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]: - """Validates whether region and zone exist in the catalog.""" + """Validates whether region and zone exist in the catalog. + + Returns: + A tuple of region and zone, if validated. + + Raises: + ValueError: If region or zone is invalid or not supported. + """ def _get_candidate_str(loc: str, all_loc: List[str]) -> str: candidate_loc = difflib.get_close_matches(loc, all_loc, n=5, cutoff=0.9) diff --git a/sky/resources.py b/sky/resources.py index 764731a35fc..38c677485c9 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -1,5 +1,6 @@ """Resources: compute requirements of Tasks.""" import functools +import textwrap from typing import Dict, List, Optional, Set, Tuple, Union import colorama @@ -15,6 +16,7 @@ from sky.provision import docker_utils from sky.skylet import constants from sky.utils import accelerator_registry +from sky.utils import log_utils from sky.utils import resources_utils from sky.utils import schemas from sky.utils import tpu_utils @@ -134,7 +136,7 @@ def __init__( self._cloud = cloud self._region: Optional[str] = None self._zone: Optional[str] = None - self._set_region_zone(region, zone) + self._validate_and_set_region_zone(region, zone) self._instance_type = instance_type @@ -537,22 +539,62 @@ def is_launchable(self) -> bool: return self.cloud is not None and self._instance_type is not None def need_cleanup_after_preemption(self) -> bool: - """Returns whether a spot resource needs cleanup after preeemption.""" + """Returns whether a spot resource needs cleanup after preemption.""" assert self.is_launchable(), self return self.cloud.need_cleanup_after_preemption(self) - def _set_region_zone(self, region: Optional[str], - zone: Optional[str]) -> None: + def _validate_and_set_region_zone(self, region: Optional[str], + zone: Optional[str]) -> None: if region is None and zone is None: return if self._cloud is None: - with ux_utils.print_exception_no_traceback(): - raise ValueError( - 'Cloud must be specified when region/zone are specified.') + # Try to infer the cloud from region/zone, if unique. If 0 or >1 + # cloud corresponds to region/zone, errors out. + valid_clouds = [] + enabled_clouds = global_user_state.get_enabled_clouds() + cloud_to_errors = {} + for cloud in enabled_clouds: + try: + cloud.validate_region_zone(region, zone) + except ValueError as e: + cloud_to_errors[repr(cloud)] = e + continue + valid_clouds.append(cloud) + + if len(valid_clouds) == 0: + if len(enabled_clouds) == 1: + cloud_str = f'for cloud {enabled_clouds[0]}' + else: + cloud_str = f'for any cloud among {enabled_clouds}' + with ux_utils.print_exception_no_traceback(): + if len(cloud_to_errors) == 1: + # UX: if 1 cloud, don't print a table. + hint = list(cloud_to_errors.items())[0][-1] + else: + table = log_utils.create_table(['Cloud', 'Hint']) + table.add_row(['-----', '----']) + for cloud, error in cloud_to_errors.items(): + reason_str = '\n'.join(textwrap.wrap( + str(error), 80)) + table.add_row([str(cloud), reason_str]) + hint = table.get_string() + raise ValueError( + f'Invalid (region {region!r}, zone {zone!r}) ' + f'{cloud_str}. Details:\n{hint}') + elif len(valid_clouds) > 1: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Cannot infer cloud from (region {region!r}, zone ' + f'{zone!r}). Multiple enabled clouds have region/zone ' + f'of the same names: {valid_clouds}. ' + f'To fix: explicitly specify `cloud`.') + logger.debug(f'Cloud is not specified, using {valid_clouds[0]} ' + f'inferred from region {region!r} and zone {zone!r}') + self._cloud = valid_clouds[0] - # Validate whether region and zone exist in the catalog, and set the - # region if zone is specified. + # Validate if region and zone exist in the catalog, and set the region + # if zone is specified. self._region, self._zone = self._cloud.validate_region_zone( region, zone) diff --git a/tests/common.py b/tests/common.py new file mode 100644 index 00000000000..835f53ef251 --- /dev/null +++ b/tests/common.py @@ -0,0 +1,62 @@ +import tempfile +from typing import List, Optional + +import pandas as pd +import pytest + +from sky import clouds +from sky.utils import kubernetes_utils + + +def enable_all_clouds_in_monkeypatch( + monkeypatch: pytest.MonkeyPatch, + enabled_clouds: Optional[List[str]] = None, +) -> None: + # Monkey-patching is required because in the test environment, no cloud is + # enabled. The optimizer checks the environment to find enabled clouds, and + # only generates plans within these clouds. The tests assume that all three + # clouds are enabled, so we monkeypatch the `sky.global_user_state` module + # to return all three clouds. We also monkeypatch `sky.check.check` so that + # when the optimizer tries calling it to update enabled_clouds, it does not + # raise exceptions. + if enabled_clouds is None: + enabled_clouds = list(clouds.CLOUD_REGISTRY.values()) + monkeypatch.setattr( + 'sky.global_user_state.get_enabled_clouds', + lambda: enabled_clouds, + ) + monkeypatch.setattr('sky.check.check', lambda *_args, **_kwargs: None) + config_file_backup = tempfile.NamedTemporaryFile( + prefix='tmp_backup_config_default', delete=False) + monkeypatch.setattr('sky.clouds.gcp.GCP_CONFIG_SKY_BACKUP_PATH', + config_file_backup.name) + monkeypatch.setattr( + 'sky.clouds.gcp.DEFAULT_GCP_APPLICATION_CREDENTIAL_PATH', + config_file_backup.name) + monkeypatch.setenv('OCI_CONFIG', config_file_backup.name) + + az_mappings = pd.read_csv('tests/default_aws_az_mappings.csv') + + def _get_az_mappings(_): + return az_mappings + + monkeypatch.setattr( + 'sky.clouds.service_catalog.aws_catalog._get_az_mappings', + _get_az_mappings) + + monkeypatch.setattr('sky.backends.backend_utils.check_owner_identity', + lambda _: None) + + monkeypatch.setattr( + 'sky.clouds.gcp.GCP._list_reservations_for_instance_type', + lambda *_args, **_kwargs: []) + + # Monkey patch Kubernetes resource detection since it queries + # the cluster to detect available cluster resources. + monkeypatch.setattr( + 'sky.utils.kubernetes_utils.detect_gpu_label_formatter', + lambda *_args, **_kwargs: [kubernetes_utils.SkyPilotLabelFormatter, []]) + monkeypatch.setattr('sky.utils.kubernetes_utils.detect_gpu_resource', + lambda *_args, **_kwargs: [True, []]) + monkeypatch.setattr('sky.utils.kubernetes_utils.check_instance_fits', + lambda *_args, **_kwargs: [True, '']) diff --git a/tests/conftest.py b/tests/conftest.py index 11b86a33f20..6f3fb89d083 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,6 @@ -import tempfile from typing import List -from unittest.mock import patch -import pandas as pd +import common # TODO(zongheng): for some reason isort places it here. import pytest # Usage: use @@ -20,7 +18,8 @@ # To only run tests for a specific cloud (as well as generic tests), use # --aws, --gcp, --azure, or --lambda. # -# To only run tests for managed spot (without generic tests), use --managed-spot. +# To only run tests for managed spot (without generic tests), use +# --managed-spot. all_clouds_in_smoke_tests = [ 'aws', 'gcp', 'azure', 'lambda', 'cloudflare', 'ibm', 'scp', 'oci', 'kubernetes' @@ -180,61 +179,12 @@ def generic_cloud(request) -> str: @pytest.fixture -def enable_all_clouds(monkeypatch): - from sky import clouds - from sky.utils import kubernetes_utils - - # Monkey-patching is required because in the test environment, no cloud is - # enabled. The optimizer checks the environment to find enabled clouds, and - # only generates plans within these clouds. The tests assume that all three - # clouds are enabled, so we monkeypatch the `sky.global_user_state` module - # to return all three clouds. We also monkeypatch `sky.check.check` so that - # when the optimizer tries calling it to update enabled_clouds, it does not - # raise exceptions. - enabled_clouds = list(clouds.CLOUD_REGISTRY.values()) - monkeypatch.setattr( - 'sky.global_user_state.get_enabled_clouds', - lambda: enabled_clouds, - ) - monkeypatch.setattr('sky.check.check', lambda *_args, **_kwargs: None) - config_file_backup = tempfile.NamedTemporaryFile( - prefix='tmp_backup_config_default', delete=False) - monkeypatch.setattr('sky.clouds.gcp.GCP_CONFIG_SKY_BACKUP_PATH', - config_file_backup.name) - monkeypatch.setattr( - 'sky.clouds.gcp.DEFAULT_GCP_APPLICATION_CREDENTIAL_PATH', - config_file_backup.name) - monkeypatch.setenv('OCI_CONFIG', config_file_backup.name) - - az_mappings = pd.read_csv('tests/default_aws_az_mappings.csv') - - def _get_az_mappings(_): - return az_mappings - - monkeypatch.setattr( - 'sky.clouds.service_catalog.aws_catalog._get_az_mappings', - _get_az_mappings) - - monkeypatch.setattr('sky.backends.backend_utils.check_owner_identity', - lambda _: None) - - monkeypatch.setattr( - 'sky.clouds.gcp.GCP._list_reservations_for_instance_type', - lambda *_args, **_kwargs: []) - - # Monkey patch Kubernetes resource detection since it queries - # the cluster to detect available cluster resources. - monkeypatch.setattr( - 'sky.utils.kubernetes_utils.detect_gpu_label_formatter', - lambda *_args, **_kwargs: [kubernetes_utils.SkyPilotLabelFormatter, []]) - monkeypatch.setattr('sky.utils.kubernetes_utils.detect_gpu_resource', - lambda *_args, **_kwargs: [True, []]) - monkeypatch.setattr('sky.utils.kubernetes_utils.check_instance_fits', - lambda *_args, **_kwargs: [True, '']) +def enable_all_clouds(monkeypatch: pytest.MonkeyPatch): + common.enable_all_clouds_in_monkeypatch(monkeypatch) @pytest.fixture -def aws_config_region(monkeypatch) -> str: +def aws_config_region(monkeypatch: pytest.MonkeyPatch) -> str: from sky import skypilot_config region = 'us-west-2' if skypilot_config.loaded(): diff --git a/tests/test_optimizer_dryruns.py b/tests/test_optimizer_dryruns.py index 3e7aeacfc06..c6ad64f9f20 100644 --- a/tests/test_optimizer_dryruns.py +++ b/tests/test_optimizer_dryruns.py @@ -3,12 +3,12 @@ import time from typing import Callable, List, Optional +import common # TODO(zongheng): for some reason isort places it here. import pytest import sky from sky import clouds from sky import exceptions -from sky.utils import kubernetes_utils def _test_parse_task_yaml(spec: str, test_fn: Optional[Callable] = None): @@ -46,52 +46,14 @@ def test_fn(task): _test_parse_task_yaml(spec, test_fn) -# Monkey-patching is required because in the test environment, no cloud is -# enabled. The optimizer checks the environment to find enabled clouds, and -# only generates plans within these clouds. The tests assume that all three -# clouds are enabled, so we monkeypatch the `sky.global_user_state` module -# to return all three clouds. We also monkeypatch `sky.check.check` so that -# when the optimizer tries calling it to update enabled_clouds, it does not -# TODO: Keep the cloud enabling in sync with the fixture enable_all_clouds -# in tests/conftest.py -# raise exceptions. def _make_resources( monkeypatch, *resources_args, - enabled_clouds: List[str] = None, + enabled_clouds: Optional[List[str]] = None, **resources_kwargs, ): - if enabled_clouds is None: - enabled_clouds = list(clouds.CLOUD_REGISTRY.values()) - monkeypatch.setattr( - 'sky.global_user_state.get_enabled_clouds', - lambda: enabled_clouds, - ) - monkeypatch.setattr('sky.check.check', lambda *_args, **_kwargs: None) - - config_file_backup = tempfile.NamedTemporaryFile( - prefix='tmp_backup_config_default', delete=False) - monkeypatch.setattr('sky.clouds.gcp.GCP_CONFIG_SKY_BACKUP_PATH', - config_file_backup.name) - monkeypatch.setattr( - 'sky.clouds.gcp.DEFAULT_GCP_APPLICATION_CREDENTIAL_PATH', - config_file_backup.name) - monkeypatch.setenv('OCI_CONFIG', config_file_backup.name) - - monkeypatch.setattr( - 'sky.clouds.gcp.GCP._list_reservations_for_instance_type', - lambda *_args, **_kwargs: []) - - # Monkey patch Kubernetes resource detection since it queries - # the cluster to detect available cluster resources. - monkeypatch.setattr( - 'sky.utils.kubernetes_utils.detect_gpu_label_formatter', - lambda *_args, **_kwargs: [kubernetes_utils.SkyPilotLabelFormatter, []]) - monkeypatch.setattr('sky.utils.kubernetes_utils.detect_gpu_resource', - lambda *_args, **_kwargs: [True, []]) - monkeypatch.setattr('sky.utils.kubernetes_utils.check_instance_fits', - lambda *_args, **_kwargs: [True, '']) - + # See comments inside to see why we monkey patch: + common.enable_all_clouds_in_monkeypatch(monkeypatch, enabled_clouds) # Should create Resources here, since it uses the enabled clouds. return sky.Resources(*resources_args, **resources_kwargs) @@ -674,7 +636,8 @@ def _test_optimize_speed(resources: sky.Resources): start = time.time() sky.optimize(dag) end = time.time() - assert end - start < 5.0, (f'optimize took too long for {resources}, ' + # 5.0 seconds = somewhat flaky. + assert end - start < 6.0, (f'optimize took too long for {resources}, ' f'{end - start} seconds') @@ -691,3 +654,64 @@ def test_optimize_speed(enable_all_clouds, monkeypatch): sky.Resources(cpus='4+', memory='4+', accelerators='A100-80GB:8')) _test_optimize_speed( sky.Resources(cpus='4+', memory='4+', accelerators='tpu-v3-32')) + + +def test_infer_cloud_from_region_or_zone(monkeypatch): + # Maps to GCP. + _test_resources_launch(monkeypatch, region='us-east1') + _test_resources_launch(monkeypatch, zone='us-west2-a') + + # Maps to AWS. + _test_resources_launch(monkeypatch, region='us-east-2') + _test_resources_launch(monkeypatch, zone='us-west-2a') + + # `sky launch` + _test_resources_launch(monkeypatch) + + # Same-named regions need `cloud`. + _test_resources_launch(monkeypatch, region='us-east-1', cloud=sky.AWS()) + _test_resources_launch(monkeypatch, region='us-east-1', cloud=sky.Lambda()) + + # Cases below: cannot infer cloud. + + # Same-named region: AWS and Lambda. + with pytest.raises(ValueError) as e: + _test_resources_launch(monkeypatch, region='us-east-1') + assert ('Multiple enabled clouds have region/zone of the same names' + in str(e)) + + # Typo, fuzzy hint. + with pytest.raises(ValueError) as e: + _test_resources_launch(monkeypatch, zone='us-west-2-a', cloud=sky.AWS()) + assert ('Did you mean one of these: \'us-west-2a\'?' in str(e)) + + # Detailed hints. + # ValueError: Invalid (region None, zone 'us-west-2-a') for any cloud among + # [AWS, Azure, GCP, IBM, Lambda, Local, OCI, SCP]. Details: + # Cloud Hint + # ----- ---- + # AWS Invalid zone 'us-west-2-a' Did you mean one of these: 'us-west-2a'? + # Azure Azure does not support zones. + # GCP Invalid zone 'us-west-2-a' Did you mean one of these: 'us-west2-a'? + # IBM Invalid zone 'us-west-2-a' + # Lambda Lambda Cloud does not support zones. + # Local Local cloud does not support zones. + # OCI Invalid zone 'us-west-2-a' + # SCP SCP Cloud does not support zones. + with pytest.raises(ValueError) as e: + _test_resources_launch(monkeypatch, zone='us-west-2-a') + assert ('Invalid (region None, zone \'us-west-2-a\') for any cloud among' + in str(e)) + + with pytest.raises(ValueError) as e: + _test_resources_launch(monkeypatch, zone='us-west-2z') + assert ('Invalid (region None, zone \'us-west-2z\') for any cloud among' + in str(e)) + + with pytest.raises(ValueError) as e: + _test_resources_launch(monkeypatch, + region='us-east1', + zone='us-west2-a') + assert ( + 'Invalid (region \'us-east1\', zone \'us-west2-a\') for any cloud among' + in str(e))