Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Azure] Avoid Azure reconfig every time, speed up launch by up to 5.8x #3697

Merged
merged 10 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions sky/adaptors/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
# pylint: disable=import-outside-toplevel
import functools
import threading
import time

from sky.adaptors import common
from sky.utils import common_utils

azure = common.LazyImport(
'azure',
Expand All @@ -13,13 +15,30 @@
# Lazily-imported modules handed to `common.load_lazy_modules` by the
# decorated functions in this module.
_LAZY_MODULES = (azure,)

# Re-entrant lock. NOTE(review): presumably serializes Azure session/client
# creation across threads -- its users are not visible here, confirm.
_session_creation_lock = threading.RLock()
# Upper bound on how many times get_subscription_id() retries after the
# Azure CLI profile lookup fails with the "Please run 'az login'" error.
_MAX_RETRY_FOR_GET_SUBSCRIPTION_ID = 5


@common.load_lazy_modules(modules=_LAZY_MODULES)
@functools.lru_cache()
def get_subscription_id() -> str:
    """Get the default subscription id from the Azure CLI profile.

    The lookup is retried, sleeping for a backoff interval between attempts,
    when it fails with the "Please run 'az login' to setup account." error.
    At most _MAX_RETRY_FOR_GET_SUBSCRIPTION_ID retries are made. A successful
    result is memoized by `functools.lru_cache`, so later calls in the same
    process do not query the CLI profile again.

    Returns:
        The subscription id reported by the Azure CLI profile.

    Raises:
        Exception: Re-raises whatever the lookup raised when the error is not
            the 'az login' error, or when the retries are exhausted.
    """
    from azure.common import credentials

    retry = 0
    backoff = common_utils.Backoff(initial_backoff=0.5, max_backoff_factor=4)
    while True:
        try:
            return credentials.get_cli_profile().get_subscription_id()
        except Exception as e:
            retriable = ('Please run \'az login\' to setup account.' in str(e)
                         and retry < _MAX_RETRY_FOR_GET_SUBSCRIPTION_ID)
            if not retriable:
                # Unrelated failure, or the retry budget is spent: surface
                # the original exception to the caller unchanged.
                raise
            # When there are multiple processes trying to get the
            # subscription id, it may fail with the above error message.
            # Retry will fix the issue.
            retry += 1
            time.sleep(backoff.current_backoff())


@common.load_lazy_modules(modules=_LAZY_MODULES)
Expand All @@ -36,8 +55,8 @@ def exceptions():
return azure_exceptions


@functools.lru_cache()
@common.load_lazy_modules(modules=_LAZY_MODULES)
@functools.lru_cache()
def get_client(name: str, subscription_id: str):
# Sky only supports Azure CLI credential for now.
# Increase the timeout to fix the Azure get-access-token timeout issue.
Expand Down
40 changes: 30 additions & 10 deletions sky/skylet/providers/azure/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.resource.resources.models import DeploymentMode

from sky.adaptors import azure
from sky.utils import common_utils

UNIQUE_ID_LEN = 4
Expand Down Expand Up @@ -120,17 +121,36 @@ def _configure_resource_group(config):
create_or_update = get_azure_sdk_function(
client=resource_client.deployments, function_name="create_or_update"
)
# TODO (skypilot): this takes a long time (> 40 seconds) for stopping an
# azure VM, and this can be called twice during ray down.
outputs = (
create_or_update(
resource_group_name=resource_group,
deployment_name="ray-config",
parameters=parameters,
)
.result()
.properties.outputs
# Skip creating or updating the deployment if the deployment already exists
# and the cluster name is the same.
get_deployment = get_azure_sdk_function(
client=resource_client.deployments, function_name="get"
)
deployment_exists = False
try:
deployment = get_deployment(
resource_group_name=resource_group, deployment_name="ray-config"
)
logger.info("Deployment already exists. Skipping deployment creation.")

outputs = deployment.properties.outputs
if outputs is not None:
deployment_exists = True
except azure.exceptions().ResourceNotFoundError:
deployment_exists = False

if not deployment_exists:
# TODO (skypilot): this takes a long time (> 40 seconds) for stopping an
# azure VM, and this can be called twice during ray down.
Michaelvll marked this conversation as resolved.
Show resolved Hide resolved
outputs = (
create_or_update(
resource_group_name=resource_group,
deployment_name="ray-config",
parameters=parameters,
)
.result()
.properties.outputs
)

# We should wait for the NSG to be created before opening any ports
# to avoid overriding the newly-added NSG rules.
Expand Down
35 changes: 14 additions & 21 deletions sky/skylet/providers/azure/node_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.resource.resources.models import DeploymentMode

from sky.adaptors import azure
from sky.skylet.providers.azure.config import (
bootstrap_azure,
get_azure_sdk_function,
)
from sky.skylet import autostop_lib
from sky.skylet.providers.command_runner import SkyDockerCommandRunner
from sky.provision import docker_utils

Expand Down Expand Up @@ -62,23 +62,7 @@ class AzureNodeProvider(NodeProvider):

def __init__(self, provider_config, cluster_name):
NodeProvider.__init__(self, provider_config, cluster_name)
if not autostop_lib.get_is_autostopping():
# TODO(suquark): This is a temporary patch for resource group.
# By default, Ray autoscaler assumes the resource group is still
# here even after the whole cluster is destroyed. However, now we
# deletes the resource group after tearing down the cluster. To
# comfort the autoscaler, we need to create/update it here, so the
# resource group always exists.
#
# We should not re-configure the resource group again, when it is
# running on the remote VM and the autostopping is in progress,
# because the VM is running which guarantees the resource group
# exists.
from sky.skylet.providers.azure.config import _configure_resource_group

_configure_resource_group(
{"cluster_name": cluster_name, "provider": provider_config}
)

subscription_id = provider_config["subscription_id"]
self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", True)
# Sky only supports Azure CLI credential for now.
Expand Down Expand Up @@ -106,9 +90,18 @@ def match_tags(vm):
return False
return True

vms = self.compute_client.virtual_machines.list(
resource_group_name=self.provider_config["resource_group"]
)
try:
vms = list(
self.compute_client.virtual_machines.list(
resource_group_name=self.provider_config["resource_group"]
)
)
except azure.exceptions().ResourceNotFoundError as e:
if "Code: ResourceGroupNotFound" in e.exc_msg:
logger.debug("Resource group not found. VMs should be terminated.")
vms = []
else:
raise

nodes = [self._extract_metadata(vm) for vm in filter(match_tags, vms)]
self.cached_nodes = {node["name"]: node for node in nodes}
Expand Down
Loading