diff --git a/sky/api/cli.py b/sky/api/cli.py index 7ab0fcd697e..d646aff650d 100644 --- a/sky/api/cli.py +++ b/sky/api/cli.py @@ -3361,7 +3361,8 @@ def storage(): # pylint: disable=redefined-builtin def storage_ls(all: bool): """List storage objects managed by SkyPilot.""" - storages = sdk.storage_ls() + request_id = sdk.storage_ls() + storages = sdk.stream_and_get(request_id) storage_table = storage_utils.format_storage_table(storages, show_all=all) click.echo(storage_table) @@ -3384,8 +3385,9 @@ def storage_ls(all: bool): is_flag=True, required=False, help='Skip confirmation prompt.') +@_add_click_options(_COMMON_OPTIONS) @usage_lib.entrypoint -def storage_delete(names: List[str], all: bool, yes: bool): # pylint: disable=redefined-builtin +def storage_delete(names: List[str], all: bool, yes: bool, async_call: bool): # pylint: disable=redefined-builtin """Delete storage objects. Examples: @@ -3422,7 +3424,10 @@ def storage_delete(names: List[str], all: bool, yes: bool): # pylint: disable=r abort=True, show_default=True) - subprocess_utils.run_in_parallel(sdk.storage_delete, names) + request_ids = subprocess_utils.run_in_parallel(sdk.storage_delete, names) + for request_id in request_ids: + _async_call_or_wait(request_id, async_call, 'storage') + @cli.group(cls=_NaturalOrderGroup) diff --git a/sky/api/sdk.py b/sky/api/sdk.py index 1d15a27a2c6..246dd1f43c7 100644 --- a/sky/api/sdk.py +++ b/sky/api/sdk.py @@ -27,6 +27,7 @@ from sky import optimizer from sky import sky_logging from sky import skypilot_config +from sky.api import common as api_common from sky.api.requests import payloads from sky.api.requests import tasks from sky.backends import backend_utils @@ -674,7 +675,7 @@ def api_stop(): # Remove the database for requests including any files starting with # common.API_SERVER_REQUEST_DB_PATH - db_path = os.path.expanduser(common.API_SERVER_REQUEST_DB_PATH) + db_path = os.path.expanduser(api_common.API_SERVER_REQUEST_DB_PATH) for extension in ['', '-shm', '-wal']: try: os.remove(f'{db_path}{extension}') diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index d5989d586e6..e369c4c9d0b 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -37,7 +37,6 @@ from sky import provision as provision_lib from sky import sky_logging from sky import skypilot_config -from sky.clouds import cloud_registry from sky.provision import instance_setup from sky.provision.kubernetes import utils as kubernetes_utils from sky.skylet import constants @@ -48,6 +47,7 @@ from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import env_options +from sky.utils import registry from sky.utils import resources_utils from sky.utils import rich_utils from sky.utils import schemas @@ -1435,7 +1435,7 @@ def get_node_ips(cluster_yaml: str, ray_config = common_utils.read_yaml(cluster_yaml) # Use the new provisioner for AWS. provider_name = cluster_yaml_utils.get_provider_name(ray_config) - cloud = cloud_registry.CLOUD_REGISTRY.from_str(provider_name) + cloud = registry.CLOUD_REGISTRY.from_str(provider_name) assert cloud is not None, provider_name if cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.SKYPILOT: @@ -2493,6 +2493,7 @@ def get_clusters( force_refresh_statuses = set(status_lib.ClusterStatus) else: force_refresh_statuses = None + logger.info(f'zhwu debug: force refresh statuses {force_refresh_statuses}') def _refresh_cluster(cluster_name): try: diff --git a/sky/clouds/cloud_registry.py b/sky/clouds/cloud_registry.py deleted file mode 100644 index 5c4b10b9fd4..00000000000 --- a/sky/clouds/cloud_registry.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Clouds need to be registered in CLOUD_REGISTRY to be discovered""" - -import typing -from typing import Optional, Type - -from sky.utils import ux_utils - -if typing.TYPE_CHECKING: - from sky.clouds import cloud - - -class _CloudRegistry(dict): - """Registry of clouds.""" - - def from_str(self, name: Optional[str]) -> Optional['cloud.Cloud']: - if name is None: - return None - if name.lower() not in self: - with ux_utils.print_exception_no_traceback(): - raise ValueError(f'Cloud {name!r} is not a valid cloud among ' - f'{list(self.keys())}') - return self.get(name.lower()) - - def register(self, cloud_cls: Type['cloud.Cloud']) -> Type['cloud.Cloud']: - name = cloud_cls.__name__.lower() - assert name not in self, f'{name} already registered' - self[name] = cloud_cls() - return cloud_cls - - -CLOUD_REGISTRY: _CloudRegistry = _CloudRegistry() diff --git a/sky/skylet/events.py b/sky/skylet/events.py index b6e99707dab..810dd0fb213 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -12,7 +12,6 @@ from sky import clouds from sky import sky_logging from sky.backends import cloud_vm_ray_backend -from sky.clouds import cloud_registry from sky.jobs import utils as managed_job_utils from sky.serve import serve_utils from sky.skylet import autostop_lib @@ -20,6 +19,7 @@ from sky.skylet import job_lib from sky.utils import cluster_yaml_utils from sky.utils import common_utils +from sky.utils import registry from sky.utils import ux_utils # Seconds of sleep between the processing of skylet events. @@ -144,7 +144,7 @@ def _stop_cluster(self, autostop_config): cluster_yaml_utils.SKY_CLUSTER_YAML_REMOTE_PATH)) config = common_utils.read_yaml(config_path) provider_name = cluster_yaml_utils.get_provider_name(config) - cloud = cloud_registry.CLOUD_REGISTRY.from_str(provider_name) + cloud = registry.CLOUD_REGISTRY.from_str(provider_name) assert cloud is not None, f'Unknown cloud: {provider_name}' if (cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion. diff --git a/tests/common.py b/tests/common.py index b6cefda22b8..9ee5d9fc7af 100644 --- a/tests/common.py +++ b/tests/common.py @@ -4,8 +4,8 @@ import pandas as pd import pytest -from sky import clouds from sky.provision.kubernetes import utils as kubernetes_utils +from sky.utils import registry def enable_all_clouds_in_monkeypatch( @@ -20,7 +20,7 @@ def enable_all_clouds_in_monkeypatch( # when the optimizer tries calling it to update enabled_clouds, it does not # raise exceptions. if enabled_clouds is None: - enabled_clouds = list(clouds.CLOUD_REGISTRY.values()) + enabled_clouds = list(registry.CLOUD_REGISTRY.values()) monkeypatch.setattr( 'sky.check.get_cached_enabled_clouds_or_refresh', lambda *_args, **_kwargs: enabled_clouds, diff --git a/tests/test_optimizer_random_dag.py b/tests/test_optimizer_random_dag.py index e19199023be..3b88f9a912e 100644 --- a/tests/test_optimizer_random_dag.py +++ b/tests/test_optimizer_random_dag.py @@ -10,6 +10,7 @@ from sky import clouds from sky import exceptions from sky.clouds import service_catalog +from sky.utils import registry ALL_INSTANCE_TYPE_INFOS = sum( sky.list_accelerators(gpus_only=True).values(), []) @@ -82,7 +83,7 @@ def generate_random_dag( if 'tpu' in candidate.accelerator_name: instance_type = 'TPU-VM' resources = sky.Resources( - cloud=clouds.CLOUD_REGISTRY.from_str(candidate.cloud), + cloud=registry.CLOUD_REGISTRY.from_str(candidate.cloud), instance_type=instance_type, accelerators={ candidate.accelerator_name: candidate.accelerator_count