Skip to content

Commit

Permalink
fix autostop
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Aug 2, 2024
1 parent 123e5af commit 8cd02e8
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 42 deletions.
11 changes: 8 additions & 3 deletions sky/api/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3361,7 +3361,8 @@ def storage():
# pylint: disable=redefined-builtin
def storage_ls(all: bool):
"""List storage objects managed by SkyPilot."""
storages = sdk.storage_ls()
request_id = sdk.storage_ls()
storages = sdk.stream_and_get(request_id)
storage_table = storage_utils.format_storage_table(storages, show_all=all)
click.echo(storage_table)

Expand All @@ -3384,8 +3385,9 @@ def storage_ls(all: bool):
is_flag=True,
required=False,
help='Skip confirmation prompt.')
@_add_click_options(_COMMON_OPTIONS)
@usage_lib.entrypoint
def storage_delete(names: List[str], all: bool, yes: bool): # pylint: disable=redefined-builtin
def storage_delete(names: List[str], all: bool, yes: bool, async_call: bool): # pylint: disable=redefined-builtin
"""Delete storage objects.
Examples:
Expand Down Expand Up @@ -3422,7 +3424,10 @@ def storage_delete(names: List[str], all: bool, yes: bool): # pylint: disable=r
abort=True,
show_default=True)

subprocess_utils.run_in_parallel(sdk.storage_delete, names)
request_ids = subprocess_utils.run_in_parallel(sdk.storage_delete, names)
for request_id in request_ids:
_async_call_or_wait(request_id, async_call, 'storage')



@cli.group(cls=_NaturalOrderGroup)
Expand Down
3 changes: 2 additions & 1 deletion sky/api/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from sky import optimizer
from sky import sky_logging
from sky import skypilot_config
from sky.api import common as api_common
from sky.api.requests import payloads
from sky.api.requests import tasks
from sky.backends import backend_utils
Expand Down Expand Up @@ -674,7 +675,7 @@ def api_stop():

# Remove the database for requests including any files starting with
# common.API_SERVER_REQUEST_DB_PATH
db_path = os.path.expanduser(common.API_SERVER_REQUEST_DB_PATH)
db_path = os.path.expanduser(api_common.API_SERVER_REQUEST_DB_PATH)
for extension in ['', '-shm', '-wal']:
try:
os.remove(f'{db_path}{extension}')
Expand Down
5 changes: 3 additions & 2 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from sky import provision as provision_lib
from sky import sky_logging
from sky import skypilot_config
from sky.clouds import cloud_registry
from sky.provision import instance_setup
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.skylet import constants
Expand All @@ -48,6 +47,7 @@
from sky.utils import common_utils
from sky.utils import controller_utils
from sky.utils import env_options
from sky.utils import registry
from sky.utils import resources_utils
from sky.utils import rich_utils
from sky.utils import schemas
Expand Down Expand Up @@ -1435,7 +1435,7 @@ def get_node_ips(cluster_yaml: str,
ray_config = common_utils.read_yaml(cluster_yaml)
# Use the new provisioner for AWS.
provider_name = cluster_yaml_utils.get_provider_name(ray_config)
cloud = cloud_registry.CLOUD_REGISTRY.from_str(provider_name)
cloud = registry.CLOUD_REGISTRY.from_str(provider_name)
assert cloud is not None, provider_name

if cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.SKYPILOT:
Expand Down Expand Up @@ -2493,6 +2493,7 @@ def get_clusters(
force_refresh_statuses = set(status_lib.ClusterStatus)
else:
force_refresh_statuses = None
logger.info(f'zhwu debug: force refresh statuses {force_refresh_statuses}')

def _refresh_cluster(cluster_name):
try:
Expand Down
31 changes: 0 additions & 31 deletions sky/clouds/cloud_registry.py

This file was deleted.

4 changes: 2 additions & 2 deletions sky/skylet/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from sky import clouds
from sky import sky_logging
from sky.backends import cloud_vm_ray_backend
from sky.clouds import cloud_registry
from sky.jobs import utils as managed_job_utils
from sky.serve import serve_utils
from sky.skylet import autostop_lib
from sky.skylet import constants
from sky.skylet import job_lib
from sky.utils import cluster_yaml_utils
from sky.utils import common_utils
from sky.utils import registry
from sky.utils import ux_utils

# Seconds of sleep between the processing of skylet events.
Expand Down Expand Up @@ -144,7 +144,7 @@ def _stop_cluster(self, autostop_config):
cluster_yaml_utils.SKY_CLUSTER_YAML_REMOTE_PATH))
config = common_utils.read_yaml(config_path)
provider_name = cluster_yaml_utils.get_provider_name(config)
cloud = cloud_registry.CLOUD_REGISTRY.from_str(provider_name)
cloud = registry.CLOUD_REGISTRY.from_str(provider_name)
assert cloud is not None, f'Unknown cloud: {provider_name}'

if (cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.
Expand Down
4 changes: 2 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pandas as pd
import pytest

from sky import clouds
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.utils import registry


def enable_all_clouds_in_monkeypatch(
Expand All @@ -20,7 +20,7 @@ def enable_all_clouds_in_monkeypatch(
# when the optimizer tries calling it to update enabled_clouds, it does not
# raise exceptions.
if enabled_clouds is None:
enabled_clouds = list(clouds.CLOUD_REGISTRY.values())
enabled_clouds = list(registry.CLOUD_REGISTRY.values())
monkeypatch.setattr(
'sky.check.get_cached_enabled_clouds_or_refresh',
lambda *_args, **_kwargs: enabled_clouds,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_optimizer_random_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sky import clouds
from sky import exceptions
from sky.clouds import service_catalog
from sky.utils import registry

ALL_INSTANCE_TYPE_INFOS = sum(
sky.list_accelerators(gpus_only=True).values(), [])
Expand Down Expand Up @@ -82,7 +83,7 @@ def generate_random_dag(
if 'tpu' in candidate.accelerator_name:
instance_type = 'TPU-VM'
resources = sky.Resources(
cloud=clouds.CLOUD_REGISTRY.from_str(candidate.cloud),
cloud=registry.CLOUD_REGISTRY.from_str(candidate.cloud),
instance_type=instance_type,
accelerators={
candidate.accelerator_name: candidate.accelerator_count
Expand Down

0 comments on commit 8cd02e8

Please sign in to comment.