Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
landscapepainter committed Aug 17, 2024
1 parent 916143b commit 9d3a414
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
17 changes: 10 additions & 7 deletions sky/data/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1985,7 +1985,7 @@ def __init__(self,
name: str,
source: str,
storage_account_name: str = '',
region: Optional[str] = None,
region: Optional[str] = 'eastus',
is_sky_managed: Optional[bool] = None,
sync_on_reconstruction: bool = True):
self.storage_client: 'storage.Client'
Expand Down Expand Up @@ -2165,9 +2165,9 @@ def initialize(self):
self.is_sky_managed = is_new_bucket

@staticmethod
def get_default_storage_account_name(region: str) -> str:
def get_default_storage_account_name(region: Optional[str]) -> str:
"""Generates a default storage account name.
The subscription ID is included to avoid conflicts when user switches
subscriptions. The region value is hashed to ensure the storage account
name adheres to the 24-character limit, as some region names can be
Expand All @@ -2181,18 +2181,21 @@ def get_default_storage_account_name(region: str) -> str:
Returns:
Name of the default storage account.
"""
assert region is not None
subscription_id = azure.get_subscription_id()
subscription_hash_obj = hashlib.md5(subscription_id.encode('utf-8'))
subscription_hash = subscription_hash_obj.hexdigest()[:AzureBlobStore._SUBSCRIPTION_HASH_LENGTH]
subscription_hash = subscription_hash_obj.hexdigest(
)[:AzureBlobStore._SUBSCRIPTION_HASH_LENGTH]
region_hash_obj = hashlib.md5(region.encode('utf-8'))
region_hash = region_hash_obj.hexdigest()[:AzureBlobStore._SUBSCRIPTION_HASH_LENGTH]

region_hash = region_hash_obj.hexdigest()[:AzureBlobStore.
_SUBSCRIPTION_HASH_LENGTH]

storage_account_name = (
AzureBlobStore.DEFAULT_STORAGE_ACCOUNT_NAME.format(
region_hash=region_hash,
user_hash=common_utils.get_user_hash(),
subscription_hash=subscription_hash))

return storage_account_name

def _get_storage_account_and_resource_group(
Expand Down
14 changes: 6 additions & 8 deletions tests/test_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,8 +1104,8 @@ def test_azure_storage_mounts_with_stop():
cloud = 'azure'
storage_name = f'sky-test-{int(time.time())}'
default_region = 'eastus'
storage_account_name = (
storage_lib.AzureBlobStore.get_default_storage_account_name(default_region))
storage_account_name = (storage_lib.AzureBlobStore.
get_default_storage_account_name(default_region))
storage_account_key = data_utils.get_az_storage_account_key(
storage_account_name)
template_str = pathlib.Path(
Expand Down Expand Up @@ -2973,8 +2973,7 @@ def test_managed_jobs_storage(generic_cloud: str):
region = 'westus2'
region_flag = f' --region {region}'
storage_account_name = (
storage_lib.AzureBlobStore.get_default_storage_account_name(
region))
storage_lib.AzureBlobStore.get_default_storage_account_name(region))
region_cmd = TestStorageWithCredentials.cli_region_cmd(
storage_lib.StoreType.AZURE,
storage_account_name=storage_account_name)
Expand Down Expand Up @@ -4306,10 +4305,9 @@ def cli_ls_cmd(store_type, bucket_name, suffix=''):
config_storage_account = skypilot_config.get_nested(
('azure', 'storage_account'), None)
storage_account_name = config_storage_account if (
config_storage_account is not None
) else (
storage_lib.AzureBlobStore.get_default_storage_account_name(
default_region))
config_storage_account is not None) else (
storage_lib.AzureBlobStore.get_default_storage_account_name(
default_region))
storage_account_key = data_utils.get_az_storage_account_key(
storage_account_name)
list_cmd = ('az storage blob list '
Expand Down

0 comments on commit 9d3a414

Please sign in to comment.