Skip to content

Commit

Permalink
[Azure] Avoid azure reconfig everytime, speed up launch by up to 5.8x (
Browse files Browse the repository at this point in the history
…#3697)

* Avoid azure reconfig everytime

* Add debug message

* format

* Fix error handling

* format

* skip deployment recreation when deployment exist

* Add retry for subscription ID

* fix logging

* format

* comment
  • Loading branch information
Michaelvll authored Jun 29, 2024
1 parent a51b507 commit 4821f70
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 34 deletions.
23 changes: 21 additions & 2 deletions sky/adaptors/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
# pylint: disable=import-outside-toplevel
import functools
import threading
import time

from sky.adaptors import common
from sky.utils import common_utils

azure = common.LazyImport(
'azure',
Expand All @@ -13,13 +15,30 @@
_LAZY_MODULES = (azure,)

_session_creation_lock = threading.RLock()
_MAX_RETRY_FOR_GET_SUBSCRIPTION_ID = 5


@common.load_lazy_modules(modules=_LAZY_MODULES)
@functools.lru_cache()
def get_subscription_id() -> str:
"""Get the default subscription id."""
from azure.common import credentials
return credentials.get_cli_profile().get_subscription_id()
retry = 0
backoff = common_utils.Backoff(initial_backoff=0.5, max_backoff_factor=4)
while True:
try:
return credentials.get_cli_profile().get_subscription_id()
except Exception as e:
if ('Please run \'az login\' to setup account.' in str(e) and
retry < _MAX_RETRY_FOR_GET_SUBSCRIPTION_ID):
# When there are multiple processes trying to get the
# subscription id, it may fail with the above error message.
# Retry will fix the issue.
retry += 1

time.sleep(backoff.current_backoff())
continue
raise


@common.load_lazy_modules(modules=_LAZY_MODULES)
Expand All @@ -36,8 +55,8 @@ def exceptions():
return azure_exceptions


@functools.lru_cache()
@common.load_lazy_modules(modules=_LAZY_MODULES)
@functools.lru_cache()
def get_client(name: str, subscription_id: str):
# Sky only supports Azure CLI credential for now.
# Increase the timeout to fix the Azure get-access-token timeout issue.
Expand Down
40 changes: 30 additions & 10 deletions sky/skylet/providers/azure/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.resource.resources.models import DeploymentMode

from sky.adaptors import azure
from sky.utils import common_utils

UNIQUE_ID_LEN = 4
Expand Down Expand Up @@ -120,17 +121,36 @@ def _configure_resource_group(config):
create_or_update = get_azure_sdk_function(
client=resource_client.deployments, function_name="create_or_update"
)
# TODO (skypilot): this takes a long time (> 40 seconds) for stopping an
# azure VM, and this can be called twice during ray down.
outputs = (
create_or_update(
resource_group_name=resource_group,
deployment_name="ray-config",
parameters=parameters,
)
.result()
.properties.outputs
# Skip creating or updating the deployment if the deployment already exists
# and the cluster name is the same.
get_deployment = get_azure_sdk_function(
client=resource_client.deployments, function_name="get"
)
deployment_exists = False
try:
deployment = get_deployment(
resource_group_name=resource_group, deployment_name="ray-config"
)
logger.info("Deployment already exists. Skipping deployment creation.")

outputs = deployment.properties.outputs
if outputs is not None:
deployment_exists = True
except azure.exceptions().ResourceNotFoundError:
deployment_exists = False

if not deployment_exists:
# This takes a long time (> 40 seconds), we should be careful calling
# this function.
outputs = (
create_or_update(
resource_group_name=resource_group,
deployment_name="ray-config",
parameters=parameters,
)
.result()
.properties.outputs
)

# We should wait for the NSG to be created before opening any ports
# to avoid overriding the newly-added NSG rules.
Expand Down
37 changes: 16 additions & 21 deletions sky/skylet/providers/azure/node_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.resource.resources.models import DeploymentMode

from sky.adaptors import azure
from sky.skylet.providers.azure.config import (
bootstrap_azure,
get_azure_sdk_function,
)
from sky.skylet import autostop_lib
from sky.skylet.providers.command_runner import SkyDockerCommandRunner
from sky.provision import docker_utils

Expand Down Expand Up @@ -62,23 +62,7 @@ class AzureNodeProvider(NodeProvider):

def __init__(self, provider_config, cluster_name):
NodeProvider.__init__(self, provider_config, cluster_name)
if not autostop_lib.get_is_autostopping():
# TODO(suquark): This is a temporary patch for resource group.
# By default, Ray autoscaler assumes the resource group is still
# here even after the whole cluster is destroyed. However, now we
# deletes the resource group after tearing down the cluster. To
# comfort the autoscaler, we need to create/update it here, so the
# resource group always exists.
#
# We should not re-configure the resource group again, when it is
# running on the remote VM and the autostopping is in progress,
# because the VM is running which guarantees the resource group
# exists.
from sky.skylet.providers.azure.config import _configure_resource_group

_configure_resource_group(
{"cluster_name": cluster_name, "provider": provider_config}
)

subscription_id = provider_config["subscription_id"]
self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", True)
# Sky only supports Azure CLI credential for now.
Expand Down Expand Up @@ -106,9 +90,20 @@ def match_tags(vm):
return False
return True

vms = self.compute_client.virtual_machines.list(
resource_group_name=self.provider_config["resource_group"]
)
try:
vms = list(
self.compute_client.virtual_machines.list(
resource_group_name=self.provider_config["resource_group"]
)
)
except azure.exceptions().ResourceNotFoundError as e:
if "Code: ResourceGroupNotFound" in e.exc_msg:
logger.debug(
"Resource group not found. VMs should have been terminated."
)
vms = []
else:
raise

nodes = [self._extract_metadata(vm) for vm in filter(match_tags, vms)]
self.cached_nodes = {node["name"]: node for node in nodes}
Expand Down
2 changes: 1 addition & 1 deletion sky/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class Backoff:
MULTIPLIER = 1.6
JITTER = 0.4

def __init__(self, initial_backoff: int = 5, max_backoff_factor: int = 5):
def __init__(self, initial_backoff: float = 5, max_backoff_factor: int = 5):
self._initial = True
self._backoff = 0.0
self._initial_backoff = initial_backoff
Expand Down

0 comments on commit 4821f70

Please sign in to comment.