diff --git a/test/unload_databricks_data_to_s3_tests.py b/test/unload_databricks_data_to_s3_tests.py
index c540be3..fa54164 100644
--- a/test/unload_databricks_data_to_s3_tests.py
+++ b/test/unload_databricks_data_to_s3_tests.py
@@ -1,8 +1,11 @@
 import unittest
+from unittest.mock import MagicMock
+
 from unload_databricks_data_to_s3 import parse_table_versions_map_arg, build_sql_to_query_table_of_version, \
     build_sql_to_query_table_between_versions, get_partition_count, normalize_sql_query, \
     determine_first_table_name_in_sql, copy_and_inject_cdf_metadata_column_names, \
-    determine_id_column_name_for_mutation_row_type, generate_sql_to_unload_mutation_data
+    determine_id_column_name_for_mutation_row_type, generate_sql_to_unload_mutation_data, \
+    replace_double_slashes_with_single_slash, move_spark_metadata_to_separate_s3_folder
 
 
 class TestStringMethods(unittest.TestCase):
@@ -161,3 +164,41 @@ def test_generate_sql_to_unload_mutation_data(self):
         self.assertEqual(normalize_sql_query(expected_output),
                          normalize_sql_query(generate_sql_to_unload_mutation_data(
                              sql_query, mutation_row_type, is_initial_sync)))
+
+    def test_replace_double_slashes_with_single_slash(self):
+        input_string = '///path/to//file////with//double//////slashes/end/'
+        expected_output = '/path/to/file/with/double/slashes/end/'
+        self.assertEqual(expected_output, replace_double_slashes_with_single_slash(input_string))
+
+    def test_move_spark_metadata_to_separate_s3_folder(self):
+        # given
+        bucket = 'mybucket'
+        prefix = '/myprefix/with/subdir'
+        s3_uri = f's3://{bucket}/{prefix}'
+
+        mock_s3 = MagicMock()
+        mock_s3.list_objects_v2.return_value = {
+            'Contents': [
+                {'Key': f'{prefix}/file1'},
+                {'Key': f'{prefix}/part-file2'},
+                {'Key': f'{prefix}/file3'},
+                {'Key': f'{prefix}/part-file4'}
+            ]
+        }
+
+        # when
+        move_spark_metadata_to_separate_s3_folder(mock_s3, s3_uri)
+
+        # then
+        # Check that the list_objects_v2 method was called with the correct bucket and prefix
+        mock_s3.list_objects_v2.assert_called_with(Bucket='mybucket', Prefix='/myprefix/with/subdir', Delimiter='/')
+
+        # Check that the copy_object and delete_object methods were called for the correct files
+        mock_s3.copy_object.assert_any_call(Bucket='mybucket', CopySource='mybucket/myprefix/with/subdir/file1', Key='/myprefix/with/subdir/spark_metadata/file1')
+        mock_s3.delete_object.assert_any_call(Bucket='mybucket', Key='/myprefix/with/subdir/file1')
+        mock_s3.copy_object.assert_any_call(Bucket='mybucket', CopySource='mybucket/myprefix/with/subdir/file3', Key='/myprefix/with/subdir/spark_metadata/file3')
+        mock_s3.delete_object.assert_any_call(Bucket='mybucket', Key='/myprefix/with/subdir/file3')
+
+        # Check copy_object and delete_object were called the correct number of times
+        self.assertEqual(2, mock_s3.copy_object.call_count)
+        self.assertEqual(2, mock_s3.delete_object.call_count)
diff --git a/unload_databricks_data_to_s3.py b/unload_databricks_data_to_s3.py
index c704a64..7db83fe 100644
--- a/unload_databricks_data_to_s3.py
+++ b/unload_databricks_data_to_s3.py
@@ -4,6 +4,7 @@
 import re
 import time
 
+import boto3
 from pyspark.sql import SparkSession
 from pyspark.sql import DataFrame
 from pyspark.sql.functions import col
@@ -213,6 +214,55 @@ def get_partition_count(event_count: int, max_event_count_per_output_file: int)
     return max(1, math.ceil(event_count / max_event_count_per_output_file))
 
 
+def replace_double_slashes_with_single_slash(string: str) -> str:
+    while '//' in string:
+        string = string.replace('//', '/')
+    return string
+
+
+def move_spark_metadata_to_separate_s3_folder(s3_client: boto3.client, s3_uri_with_spark_metadata: str):
+    """
+    Lists all files in the s3_uri_with_spark_metadata directory and moves all spark metadata files (files without
+    '/part-' in the name) to a separate subdirectory under s3_uri_with_spark_metadata.
+
+    NOTE: Expects spark metadata files to be in the same directory as the data files and available at the time of the
+    function call.
+
+    :param s3_client: boto3 client for s3
+    :param s3_uri_with_spark_metadata: s3 URI with spark metadata files
+    """
+    print(f'Moving spark metadata files to a separate subdirectory in {s3_uri_with_spark_metadata}')
+
+    if '://' not in s3_uri_with_spark_metadata:
+        raise ValueError(f'Invalid s3 URI: {s3_uri_with_spark_metadata}. Expected to contain "://".')
+    bucket, prefix = s3_uri_with_spark_metadata.split('://')[1].split('/', 1)
+    bucket = replace_double_slashes_with_single_slash(bucket)
+    prefix = replace_double_slashes_with_single_slash(prefix)
+    print(f'Identified bucket: {bucket}, prefix: {prefix}')
+
+    # List all files in the s3_path directory
+    response = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix, Delimiter='/')
+    print(f'Found {len(response["Contents"])} files in bucket {bucket} with prefix {prefix}:\n{response["Contents"]}')
+
+    for file in response['Contents']:
+        key = file['Key']
+        if '/part-' in key:
+            # Skip data files
+            continue
+
+        # Construct the destination file name
+        new_key = replace_double_slashes_with_single_slash(key.replace(prefix, f'{prefix}/spark_metadata/'))
+        copy_source = replace_double_slashes_with_single_slash(f'{bucket}/{key}')
+        print(f'Moving copy_source: {copy_source} to new_key: {new_key} in bucket: {bucket}')
+
+        # Copy the file to the new location
+        s3_client.copy_object(Bucket=bucket, CopySource=copy_source, Key=new_key)
+        # Delete the original file
+        s3_client.delete_object(Bucket=bucket, Key=key)
+
+        print(f'Successfully moved {key} to {new_key}')
+
+
 def export_meta_data(event_count: int, partition_count: int):
     meta_data: list = [{'event_count': event_count, 'partition_count': partition_count}]
     spark.createDataFrame(meta_data).write.mode("overwrite").json(args.s3_path + "/meta")
@@ -311,6 +361,14 @@ def export_meta_data(event_count: int, partition_count: int):
     export_data = export_data.repartition(partition_count)
     # export data
     export_data.write.mode("overwrite").json(args.s3_path)
+    # move spark metadata files to a separate subdirectory
+    # s3 client is initialized after the data is exported to avoid missing files in the destination s3 path due to
+    # the eventual consistency of s3 listings
+    s3_client = boto3.client('s3',
+                             aws_access_key_id=aws_access_key,
+                             aws_secret_access_key=aws_secret_key,
+                             aws_session_token=aws_session_token)
+    move_spark_metadata_to_separate_s3_folder(s3_client, args.s3_path)
     print("Unloaded {event_count} events.".format(event_count=event_count))
 else:
     print("No events were exported.")
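Side note for review: below is a minimal, standalone sketch of the URI parsing and key rewriting that the new move_spark_metadata_to_separate_s3_folder helper performs. It makes no AWS calls; the bucket, prefix, and _SUCCESS key are illustrative assumptions rather than values taken from this PR, and collapse_slashes simply mirrors replace_double_slashes_with_single_slash.

# Standalone sketch: mirrors the string handling in move_spark_metadata_to_separate_s3_folder,
# without any AWS calls. Example values only.

def collapse_slashes(s: str) -> str:
    # Same effect as replace_double_slashes_with_single_slash in the diff.
    while '//' in s:
        s = s.replace('//', '/')
    return s

s3_uri = 's3://mybucket/myprefix/with/subdir'   # example URI (assumed, not from the PR)
bucket, prefix = s3_uri.split('://')[1].split('/', 1)

key = f'{prefix}/_SUCCESS'                      # a typical Spark metadata file name
new_key = collapse_slashes(key.replace(prefix, f'{prefix}/spark_metadata/'))
copy_source = collapse_slashes(f'{bucket}/{key}')

print(bucket)        # mybucket
print(new_key)       # myprefix/with/subdir/spark_metadata/_SUCCESS
print(copy_source)   # mybucket/myprefix/with/subdir/_SUCCESS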