Skip to content

Commit

Permalink
UX: Allow inferring cloud from region or zone. (#2632)
Browse files Browse the repository at this point in the history
* UX: Allow inferring cloud from region or zone.

* format

* minor fix

* Remove Local cloud from registry.

* UX for 1-cloud cases

* Format

* Fix test fixtures.

* isort
  • Loading branch information
concretevitamin authored Oct 4, 2023
1 parent 734cf68 commit 4f5b107
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 116 deletions.
13 changes: 11 additions & 2 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,17 @@ def instance_type_exists(self, instance_type):
"""Returns whether the instance type exists for this cloud."""
raise NotImplementedError

def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
"""Validates the region and zone."""
def validate_region_zone(
self, region: Optional[str],
zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
"""Validates whether region and zone exist in the catalog.
Returns:
A tuple of region and zone, if validated.
Raises:
ValueError: If region or zone is invalid or not supported.
"""
return service_catalog.validate_region_zone(region,
zone,
clouds=self._REPR.lower())
Expand Down
9 changes: 9 additions & 0 deletions sky/clouds/cloud_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ class _CloudRegistry(dict):
def from_str(self, name: Optional[str]) -> Optional['cloud.Cloud']:
if name is None:
return None
if name.lower() == 'local':
# Backward compatibility. global_user_state's DB may have recorded
# Local cloud, and we've just removed it from the registry, and
# global_user_state.get_enabled_clouds() would call into this func
# and fail.
#
# TODO(skypilot): have a better way to handle clouds removed from
# registry if needed.
return None
if name.lower() not in self:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Cloud {name!r} is not a valid cloud among '
Expand Down
11 changes: 9 additions & 2 deletions sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class Kubernetes(clouds.Cloud):
_DEFAULT_MEMORY_CPU_RATIO = 1
_DEFAULT_MEMORY_CPU_RATIO_WITH_GPU = 4 # Allocate more memory for GPU tasks
_REPR = 'Kubernetes'
_regions: List[clouds.Region] = [clouds.Region('kubernetes')]
_SINGLETON_REGION = 'kubernetes'
_regions: List[clouds.Region] = [clouds.Region(_SINGLETON_REGION)]
_CLOUD_UNSUPPORTED_FEATURES = {
# TODO(romilb): Stopping might be possible to implement with
# container checkpointing introduced in Kubernetes v1.25. See:
Expand Down Expand Up @@ -325,7 +326,13 @@ def instance_type_exists(self, instance_type: str) -> bool:
instance_type)

def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
# Kubernetes doesn't have regions or zones, so we don't need to validate
if region != self._SINGLETON_REGION:
raise ValueError(
'Kubernetes support does not support setting region.'
' Cluster used is determined by the kubeconfig.')
if zone is not None:
raise ValueError('Kubernetes support does not support setting zone.'
' Cluster used is determined by the kubeconfig.')
return region, zone

def accelerator_in_region_or_zone(self,
Expand Down
7 changes: 4 additions & 3 deletions sky/clouds/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sky import resources as resources_lib


@clouds.CLOUD_REGISTRY.register
# TODO(skypilot): remove Local now that we're using Kubernetes.
class Local(clouds.Cloud):
"""Local/on-premise cloud.
Expand Down Expand Up @@ -191,10 +191,11 @@ def instance_type_exists(self, instance_type: str) -> bool:
def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
# Returns true if the region name is same as Local cloud's
# one and only region: 'Local'.
assert zone is None
if zone is not None:
raise ValueError('Local cloud does not support zones.')
if region is None or region != Local.LOCAL_REGION.name:
raise ValueError(f'Region {region!r} does not match the Local'
' cloud region {Local.LOCAL_REGION.name!r}.')
f' cloud region {Local.LOCAL_REGION.name!r}.')
return region, zone

@classmethod
Expand Down
9 changes: 8 additions & 1 deletion sky/clouds/service_catalog/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,14 @@ def instance_type_exists_impl(df: pd.DataFrame, instance_type: str) -> bool:
def validate_region_zone_impl(
cloud_name: str, df: pd.DataFrame, region: Optional[str],
zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
"""Validates whether region and zone exist in the catalog."""
"""Validates whether region and zone exist in the catalog.
Returns:
A tuple of region and zone, if validated.
Raises:
ValueError: If region or zone is invalid or not supported.
"""

def _get_candidate_str(loc: str, all_loc: List[str]) -> str:
candidate_loc = difflib.get_close_matches(loc, all_loc, n=5, cutoff=0.9)
Expand Down
60 changes: 51 additions & 9 deletions sky/resources.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Resources: compute requirements of Tasks."""
import functools
import textwrap
from typing import Dict, List, Optional, Set, Tuple, Union

import colorama
Expand All @@ -15,6 +16,7 @@
from sky.provision import docker_utils
from sky.skylet import constants
from sky.utils import accelerator_registry
from sky.utils import log_utils
from sky.utils import resources_utils
from sky.utils import schemas
from sky.utils import tpu_utils
Expand Down Expand Up @@ -134,7 +136,7 @@ def __init__(
self._cloud = cloud
self._region: Optional[str] = None
self._zone: Optional[str] = None
self._set_region_zone(region, zone)
self._validate_and_set_region_zone(region, zone)

self._instance_type = instance_type

Expand Down Expand Up @@ -537,22 +539,62 @@ def is_launchable(self) -> bool:
return self.cloud is not None and self._instance_type is not None

def need_cleanup_after_preemption(self) -> bool:
"""Returns whether a spot resource needs cleanup after preeemption."""
"""Returns whether a spot resource needs cleanup after preemption."""
assert self.is_launchable(), self
return self.cloud.need_cleanup_after_preemption(self)

def _set_region_zone(self, region: Optional[str],
zone: Optional[str]) -> None:
def _validate_and_set_region_zone(self, region: Optional[str],
zone: Optional[str]) -> None:
if region is None and zone is None:
return

if self._cloud is None:
with ux_utils.print_exception_no_traceback():
raise ValueError(
'Cloud must be specified when region/zone are specified.')
# Try to infer the cloud from region/zone, if unique. If 0 or >1
# cloud corresponds to region/zone, errors out.
valid_clouds = []
enabled_clouds = global_user_state.get_enabled_clouds()
cloud_to_errors = {}
for cloud in enabled_clouds:
try:
cloud.validate_region_zone(region, zone)
except ValueError as e:
cloud_to_errors[repr(cloud)] = e
continue
valid_clouds.append(cloud)

if len(valid_clouds) == 0:
if len(enabled_clouds) == 1:
cloud_str = f'for cloud {enabled_clouds[0]}'
else:
cloud_str = f'for any cloud among {enabled_clouds}'
with ux_utils.print_exception_no_traceback():
if len(cloud_to_errors) == 1:
# UX: if 1 cloud, don't print a table.
hint = list(cloud_to_errors.items())[0][-1]
else:
table = log_utils.create_table(['Cloud', 'Hint'])
table.add_row(['-----', '----'])
for cloud, error in cloud_to_errors.items():
reason_str = '\n'.join(textwrap.wrap(
str(error), 80))
table.add_row([str(cloud), reason_str])
hint = table.get_string()
raise ValueError(
f'Invalid (region {region!r}, zone {zone!r}) '
f'{cloud_str}. Details:\n{hint}')
elif len(valid_clouds) > 1:
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Cannot infer cloud from (region {region!r}, zone '
f'{zone!r}). Multiple enabled clouds have region/zone '
f'of the same names: {valid_clouds}. '
f'To fix: explicitly specify `cloud`.')
logger.debug(f'Cloud is not specified, using {valid_clouds[0]} '
f'inferred from region {region!r} and zone {zone!r}')
self._cloud = valid_clouds[0]

# Validate whether region and zone exist in the catalog, and set the
# region if zone is specified.
# Validate if region and zone exist in the catalog, and set the region
# if zone is specified.
self._region, self._zone = self._cloud.validate_region_zone(
region, zone)

Expand Down
62 changes: 62 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import tempfile
from typing import List, Optional

import pandas as pd
import pytest

from sky import clouds
from sky.utils import kubernetes_utils


def enable_all_clouds_in_monkeypatch(
    monkeypatch: pytest.MonkeyPatch,
    enabled_clouds: Optional[List[str]] = None,
) -> None:
    """Patches SkyPilot so tests run as if the given clouds are enabled.

    Monkey-patching is required because in the test environment, no cloud is
    enabled. The optimizer checks the environment to find enabled clouds, and
    only generates plans within these clouds.

    Args:
        monkeypatch: The pytest monkeypatch fixture used to install (and
            later automatically undo) every patch below.
        enabled_clouds: Value that the patched
            `sky.global_user_state.get_enabled_clouds` returns. If None, all
            values of `clouds.CLOUD_REGISTRY` are used.
            NOTE(review): the default is built from the registry's values,
            not from strings; confirm the `List[str]` annotation with callers.
    """
    if enabled_clouds is None:
        enabled_clouds = list(clouds.CLOUD_REGISTRY.values())
    monkeypatch.setattr(
        'sky.global_user_state.get_enabled_clouds',
        lambda: enabled_clouds,
    )
    # `sky.check.check` is replaced by a no-op so that when the optimizer
    # calls it to update enabled clouds, it does not raise exceptions.
    monkeypatch.setattr('sky.check.check', lambda *_args, **_kwargs: None)

    # Only the path of this placeholder file is needed below. Close the handle
    # right away instead of leaking it; `delete=False` keeps the file on disk
    # after it is closed.
    with tempfile.NamedTemporaryFile(prefix='tmp_backup_config_default',
                                     delete=False) as config_file_backup:
        config_file_backup_path = config_file_backup.name
    monkeypatch.setattr('sky.clouds.gcp.GCP_CONFIG_SKY_BACKUP_PATH',
                        config_file_backup_path)
    monkeypatch.setattr(
        'sky.clouds.gcp.DEFAULT_GCP_APPLICATION_CREDENTIAL_PATH',
        config_file_backup_path)
    monkeypatch.setenv('OCI_CONFIG', config_file_backup_path)

    # Serve the AWS availability-zone mappings from the CSV checked into the
    # tests directory (path is relative to the current working directory).
    az_mappings = pd.read_csv('tests/default_aws_az_mappings.csv')

    def _get_az_mappings(_):
        return az_mappings

    monkeypatch.setattr(
        'sky.clouds.service_catalog.aws_catalog._get_az_mappings',
        _get_az_mappings)

    monkeypatch.setattr('sky.backends.backend_utils.check_owner_identity',
                        lambda _: None)

    monkeypatch.setattr(
        'sky.clouds.gcp.GCP._list_reservations_for_instance_type',
        lambda *_args, **_kwargs: [])

    # Monkey patch Kubernetes resource detection since it queries
    # the cluster to detect available cluster resources.
    monkeypatch.setattr(
        'sky.utils.kubernetes_utils.detect_gpu_label_formatter',
        lambda *_args, **_kwargs: [kubernetes_utils.SkyPilotLabelFormatter, []])
    monkeypatch.setattr('sky.utils.kubernetes_utils.detect_gpu_resource',
                        lambda *_args, **_kwargs: [True, []])
    monkeypatch.setattr('sky.utils.kubernetes_utils.check_instance_fits',
                        lambda *_args, **_kwargs: [True, ''])
62 changes: 6 additions & 56 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import tempfile
from typing import List
from unittest.mock import patch

import pandas as pd
import common # TODO(zongheng): for some reason isort places it here.
import pytest

# Usage: use
Expand All @@ -20,7 +18,8 @@
# To only run tests for a specific cloud (as well as generic tests), use
# --aws, --gcp, --azure, or --lambda.
#
# To only run tests for managed spot (without generic tests), use --managed-spot.
# To only run tests for managed spot (without generic tests), use
# --managed-spot.
all_clouds_in_smoke_tests = [
'aws', 'gcp', 'azure', 'lambda', 'cloudflare', 'ibm', 'scp', 'oci',
'kubernetes'
Expand Down Expand Up @@ -180,61 +179,12 @@ def generic_cloud(request) -> str:


@pytest.fixture
def enable_all_clouds(monkeypatch):
from sky import clouds
from sky.utils import kubernetes_utils

# Monkey-patching is required because in the test environment, no cloud is
# enabled. The optimizer checks the environment to find enabled clouds, and
# only generates plans within these clouds. The tests assume that all three
# clouds are enabled, so we monkeypatch the `sky.global_user_state` module
# to return all three clouds. We also monkeypatch `sky.check.check` so that
# when the optimizer tries calling it to update enabled_clouds, it does not
# raise exceptions.
enabled_clouds = list(clouds.CLOUD_REGISTRY.values())
monkeypatch.setattr(
'sky.global_user_state.get_enabled_clouds',
lambda: enabled_clouds,
)
monkeypatch.setattr('sky.check.check', lambda *_args, **_kwargs: None)
config_file_backup = tempfile.NamedTemporaryFile(
prefix='tmp_backup_config_default', delete=False)
monkeypatch.setattr('sky.clouds.gcp.GCP_CONFIG_SKY_BACKUP_PATH',
config_file_backup.name)
monkeypatch.setattr(
'sky.clouds.gcp.DEFAULT_GCP_APPLICATION_CREDENTIAL_PATH',
config_file_backup.name)
monkeypatch.setenv('OCI_CONFIG', config_file_backup.name)

az_mappings = pd.read_csv('tests/default_aws_az_mappings.csv')

def _get_az_mappings(_):
return az_mappings

monkeypatch.setattr(
'sky.clouds.service_catalog.aws_catalog._get_az_mappings',
_get_az_mappings)

monkeypatch.setattr('sky.backends.backend_utils.check_owner_identity',
lambda _: None)

monkeypatch.setattr(
'sky.clouds.gcp.GCP._list_reservations_for_instance_type',
lambda *_args, **_kwargs: [])

# Monkey patch Kubernetes resource detection since it queries
# the cluster to detect available cluster resources.
monkeypatch.setattr(
'sky.utils.kubernetes_utils.detect_gpu_label_formatter',
lambda *_args, **_kwargs: [kubernetes_utils.SkyPilotLabelFormatter, []])
monkeypatch.setattr('sky.utils.kubernetes_utils.detect_gpu_resource',
lambda *_args, **_kwargs: [True, []])
monkeypatch.setattr('sky.utils.kubernetes_utils.check_instance_fits',
lambda *_args, **_kwargs: [True, ''])
def enable_all_clouds(monkeypatch: pytest.MonkeyPatch):
common.enable_all_clouds_in_monkeypatch(monkeypatch)


@pytest.fixture
def aws_config_region(monkeypatch) -> str:
def aws_config_region(monkeypatch: pytest.MonkeyPatch) -> str:
from sky import skypilot_config
region = 'us-west-2'
if skypilot_config.loaded():
Expand Down
Loading

0 comments on commit 4f5b107

Please sign in to comment.