diff --git a/tests/test_smoke.py b/tests/test_smoke.py
index 5752a5fa067..72800396283 100644
--- a/tests/test_smoke.py
+++ b/tests/test_smoke.py
@@ -2362,7 +2362,6 @@ def test_cancel_ibm():
 
 # ---------- Testing use-spot option ----------
 @pytest.mark.no_fluidstack  # FluidStack does not support spot instances
-@pytest.mark.no_azure  # Azure does not support spot instances
 @pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 @pytest.mark.no_paperspace  # Paperspace does not support spot instances
 @pytest.mark.no_ibm  # IBM Cloud does not support spot instances
@@ -2447,7 +2446,6 @@ def test_managed_jobs(generic_cloud: str):
 
 
 @pytest.mark.no_fluidstack  #fluidstack does not support spot instances
-@pytest.mark.no_azure  # Azure does not support spot instances
 @pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 @pytest.mark.no_ibm  # IBM Cloud does not support spot instances
 @pytest.mark.no_scp  # SCP does not support spot instances
@@ -2489,7 +2487,6 @@ def test_job_pipeline(generic_cloud: str):
 
 
 @pytest.mark.no_fluidstack  #fluidstack does not support spot instances
-@pytest.mark.no_azure  # Azure does not support spot instances
 @pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 @pytest.mark.no_ibm  # IBM Cloud does not support spot instances
 @pytest.mark.no_scp  # SCP does not support spot instances
@@ -2515,7 +2512,6 @@ def test_managed_jobs_failed_setup(generic_cloud: str):
 
 
 @pytest.mark.no_fluidstack  #fluidstack does not support spot instances
-@pytest.mark.no_azure  # Azure does not support spot instances
 @pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 @pytest.mark.no_ibm  # IBM Cloud does not support spot instances
 @pytest.mark.no_scp  # SCP does not support spot instances
@@ -2711,7 +2707,6 @@ def test_managed_jobs_pipeline_recovery_gcp():
 
 
 @pytest.mark.no_fluidstack  # Fluidstack does not support spot instances
-@pytest.mark.no_azure  # Azure does not support spot instances
 @pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 @pytest.mark.no_ibm  # IBM Cloud does not support spot instances
 @pytest.mark.no_scp  # SCP does not support spot instances
@@ -2935,7 +2930,6 @@ def test_managed_jobs_cancellation_gcp():
 
 # ---------- Testing storage for managed job ----------
 @pytest.mark.no_fluidstack  # Fluidstack does not support spot instances
-@pytest.mark.no_azure  # Azure does not support spot instances
 @pytest.mark.no_lambda_cloud  # Lambda Cloud does not support spot instances
 @pytest.mark.no_ibm  # IBM Cloud does not support spot instances
 @pytest.mark.no_paperspace  # Paperspace does not support spot instances
@@ -2961,7 +2955,7 @@ def test_managed_jobs_storage(generic_cloud: str):
         region = 'eu-central-1'
         region_flag = f' --region {region}'
         region_cmd = TestStorageWithCredentials.cli_region_cmd(
-            storage_lib.StoreType.S3, storage_name)
+            storage_lib.StoreType.S3, bucket_name=storage_name)
         region_validation_cmd = f'{region_cmd} | grep {region}'
         s3_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
             storage_lib.StoreType.S3, output_storage_name, 'output.txt')
@@ -2970,11 +2964,27 @@
         region = 'us-west2'
         region_flag = f' --region {region}'
         region_cmd = TestStorageWithCredentials.cli_region_cmd(
-            storage_lib.StoreType.GCS, storage_name)
+            storage_lib.StoreType.GCS, bucket_name=storage_name)
         region_validation_cmd = f'{region_cmd} | grep {region}'
         gcs_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
             storage_lib.StoreType.GCS, output_storage_name, 'output.txt')
         output_check_cmd = f'{gcs_check_file_count} | grep 1'
+    elif generic_cloud == 'azure':
+        region = 'westus2'
+        region_flag = f' --region {region}'
+        storage_account_name = (
+            storage_lib.AzureBlobStore.DEFAULT_STORAGE_ACCOUNT_NAME.format(
+                region=region, user_hash=common_utils.get_user_hash()))
+        region_cmd = TestStorageWithCredentials.cli_region_cmd(
+            storage_lib.StoreType.AZURE,
+            storage_account_name=storage_account_name)
+        region_validation_cmd = f'{region_cmd} | grep {region}'
+        az_check_file_count = TestStorageWithCredentials.cli_count_name_in_bucket(
+            storage_lib.StoreType.AZURE,
+            output_storage_name,
+            'output.txt',
+            storage_account_name=storage_account_name)
+        output_check_cmd = f'{az_check_file_count} | grep 1'
     elif generic_cloud == 'kubernetes':
         # With Kubernetes, we don't know which object storage provider is used.
         # Check both S3 and GCS if bucket exists in either.
@@ -4323,20 +4333,32 @@ def cli_ls_cmd(store_type, bucket_name, suffix=''):
             return f'rclone ls {bucket_rclone_profile}:{bucket_name}/{suffix}'
 
     @staticmethod
-    def cli_region_cmd(store_type, bucket_name):
+    def cli_region_cmd(store_type, bucket_name=None, storage_account_name=None):
         if store_type == storage_lib.StoreType.S3:
+            assert bucket_name is not None
             return ('aws s3api get-bucket-location '
                     f'--bucket {bucket_name} --output text')
         elif store_type == storage_lib.StoreType.GCS:
+            assert bucket_name is not None
             return (f'gsutil ls -L -b gs://{bucket_name}/ | '
                     'grep "Location constraint" | '
                     'awk \'{print tolower($NF)}\'')
+        elif store_type == storage_lib.StoreType.AZURE:
+            # For Azure Blob Storage, a container's location is determined
+            # by the location of its storage account.
+            assert storage_account_name is not None
+            return (f'az storage account show --name {storage_account_name} '
+                    '--query "primaryLocation" --output tsv')
         else:
             raise NotImplementedError(f'Region command not implemented for '
                                       f'{store_type}')
 
     @staticmethod
-    def cli_count_name_in_bucket(store_type, bucket_name, file_name, suffix=''):
+    def cli_count_name_in_bucket(store_type,
+                                 bucket_name,
+                                 file_name,
+                                 suffix='',
+                                 storage_account_name=None):
         if store_type == storage_lib.StoreType.S3:
             if suffix:
                 return f'aws s3api list-objects --bucket "{bucket_name}" --prefix {suffix} --query "length(Contents[?contains(Key,\'{file_name}\')].Key)"'
@@ -4348,11 +4370,12 @@ def cli_count_name_in_bucket(store_type, bucket_name, file_name, suffix=''):
             else:
                 return f'gsutil ls -r gs://{bucket_name} | grep "{file_name}" | wc -l'
         elif store_type == storage_lib.StoreType.AZURE:
-            default_region = 'eastus'
-            storage_account_name = (
-                storage_lib.AzureBlobStore.DEFAULT_STORAGE_ACCOUNT_NAME.format(
-                    region=default_region,
-                    user_hash=common_utils.get_user_hash()))
+            if storage_account_name is None:
+                default_region = 'eastus'
+                storage_account_name = (
+                    storage_lib.AzureBlobStore.DEFAULT_STORAGE_ACCOUNT_NAME.
+                    format(region=default_region,
+                           user_hash=common_utils.get_user_hash()))
             storage_account_key = data_utils.get_az_storage_account_key(
                 storage_account_name)
             return ('az storage blob list '
@@ -5101,7 +5124,7 @@ def test_aws_regions(self, tmp_local_storage_obj, region):
         bucket_name = tmp_local_storage_obj.name
 
         # Confirm that the bucket was created in the correct region
-        region_cmd = self.cli_region_cmd(store_type, bucket_name)
+        region_cmd = self.cli_region_cmd(store_type, bucket_name=bucket_name)
         out = subprocess.check_output(region_cmd, shell=True)
         output = out.decode('utf-8')
         expected_output_region = region
@@ -5139,7 +5162,7 @@ def test_gcs_regions(self, tmp_local_storage_obj, region):
         bucket_name = tmp_local_storage_obj.name
 
         # Confirm that the bucket was created in the correct region
-        region_cmd = self.cli_region_cmd(store_type, bucket_name)
+        region_cmd = self.cli_region_cmd(store_type, bucket_name=bucket_name)
         out = subprocess.check_output(region_cmd, shell=True)
         output = out.decode('utf-8')
         assert region in out.decode('utf-8'), (
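
For a quick out-of-band sanity check of the new Azure region plumbing, the command emitted by cli_region_cmd for StoreType.AZURE can be run directly. The sketch below assumes an authenticated az CLI; the account name in the usage line is a hypothetical placeholder, whereas the smoke test derives the real one from storage_lib.AzureBlobStore.DEFAULT_STORAGE_ACCOUNT_NAME and the current user hash:

    import subprocess


    def get_azure_primary_location(storage_account_name: str) -> str:
        """Return the primary region of an Azure storage account.

        Mirrors cli_region_cmd(StoreType.AZURE, ...): a container's
        location is determined by its storage account, so the account
        itself is queried.
        """
        cmd = (f'az storage account show --name {storage_account_name} '
               '--query "primaryLocation" --output tsv')
        return subprocess.check_output(cmd, shell=True).decode('utf-8').strip()


    # Usage with a hypothetical account name:
    assert get_azure_primary_location('mysmoketestaccount') == 'westus2'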