From 9d3a41426b331baad237f5127c4ec038f914eb60 Mon Sep 17 00:00:00 2001 From: Doyoung Kim Date: Sat, 17 Aug 2024 01:55:03 +0000 Subject: [PATCH] format --- sky/data/storage.py | 17 ++++++++++------- tests/test_smoke.py | 14 ++++++-------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/sky/data/storage.py b/sky/data/storage.py index 53aaf67b40d..beca625f874 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -1985,7 +1985,7 @@ def __init__(self, name: str, source: str, storage_account_name: str = '', - region: Optional[str] = None, + region: Optional[str] = 'eastus', is_sky_managed: Optional[bool] = None, sync_on_reconstruction: bool = True): self.storage_client: 'storage.Client' @@ -2165,9 +2165,9 @@ def initialize(self): self.is_sky_managed = is_new_bucket @staticmethod - def get_default_storage_account_name(region: str) -> str: + def get_default_storage_account_name(region: Optional[str]) -> str: """Generates a default storage account name. - + The subscription ID is included to avoid conflicts when user switches subscriptions. The region value is hashed to ensure the storage account name adheres to the 24-character limit, as some region names can be @@ -2181,18 +2181,21 @@ def get_default_storage_account_name(region: str) -> str: Returns: Name of the default storage account. """ + assert region is not None subscription_id = azure.get_subscription_id() subscription_hash_obj = hashlib.md5(subscription_id.encode('utf-8')) - subscription_hash = subscription_hash_obj.hexdigest()[:AzureBlobStore._SUBSCRIPTION_HASH_LENGTH] + subscription_hash = subscription_hash_obj.hexdigest( + )[:AzureBlobStore._SUBSCRIPTION_HASH_LENGTH] region_hash_obj = hashlib.md5(region.encode('utf-8')) - region_hash = region_hash_obj.hexdigest()[:AzureBlobStore._SUBSCRIPTION_HASH_LENGTH] - + region_hash = region_hash_obj.hexdigest()[:AzureBlobStore. + _SUBSCRIPTION_HASH_LENGTH] + storage_account_name = ( AzureBlobStore.DEFAULT_STORAGE_ACCOUNT_NAME.format( region_hash=region_hash, user_hash=common_utils.get_user_hash(), subscription_hash=subscription_hash)) - + return storage_account_name def _get_storage_account_and_resource_group( diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 33e6db52912..d5fcf159111 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -1104,8 +1104,8 @@ def test_azure_storage_mounts_with_stop(): cloud = 'azure' storage_name = f'sky-test-{int(time.time())}' default_region = 'eastus' - storage_account_name = ( - storage_lib.AzureBlobStore.get_default_storage_account_name(default_region)) + storage_account_name = (storage_lib.AzureBlobStore. + get_default_storage_account_name(default_region)) storage_account_key = data_utils.get_az_storage_account_key( storage_account_name) template_str = pathlib.Path( @@ -2973,8 +2973,7 @@ def test_managed_jobs_storage(generic_cloud: str): region = 'westus2' region_flag = f' --region {region}' storage_account_name = ( - storage_lib.AzureBlobStore.get_default_storage_account_name( - region)) + storage_lib.AzureBlobStore.get_default_storage_account_name(region)) region_cmd = TestStorageWithCredentials.cli_region_cmd( storage_lib.StoreType.AZURE, storage_account_name=storage_account_name) @@ -4306,10 +4305,9 @@ def cli_ls_cmd(store_type, bucket_name, suffix=''): config_storage_account = skypilot_config.get_nested( ('azure', 'storage_account'), None) storage_account_name = config_storage_account if ( - config_storage_account is not None - ) else ( - storage_lib.AzureBlobStore.get_default_storage_account_name( - default_region)) + config_storage_account is not None) else ( + storage_lib.AzureBlobStore.get_default_storage_account_name( + default_region)) storage_account_key = data_utils.get_az_storage_account_key( storage_account_name) list_cmd = ('az storage blob list '