AMP-97070 Move Spark metadata upon transfer completion #8

Open · wants to merge 1 commit into base: AMP-96980
43 changes: 42 additions & 1 deletion test/unload_databricks_data_to_s3_tests.py
@@ -1,8 +1,11 @@
import unittest
from unittest.mock import MagicMock

from unload_databricks_data_to_s3 import parse_table_versions_map_arg, build_sql_to_query_table_of_version, \
    build_sql_to_query_table_between_versions, get_partition_count, normalize_sql_query, \
    determine_first_table_name_in_sql, copy_and_inject_cdf_metadata_column_names, \
-   determine_id_column_name_for_mutation_row_type, generate_sql_to_unload_mutation_data
+   determine_id_column_name_for_mutation_row_type, generate_sql_to_unload_mutation_data, \
+   replace_double_slashes_with_single_slash, move_spark_metadata_to_separate_s3_folder


class TestStringMethods(unittest.TestCase):
@@ -161,3 +164,41 @@ def test_generate_sql_to_unload_mutation_data(self):
        self.assertEqual(normalize_sql_query(expected_output),
                         normalize_sql_query(generate_sql_to_unload_mutation_data(
                             sql_query, mutation_row_type, is_initial_sync)))

    def test_replace_double_slashes_with_single_slash(self):
        input_string = '///path/to//file////with//double//////slashes/end/'
        expected_output = '/path/to/file/with/double/slashes/end/'
        self.assertEqual(expected_output, replace_double_slashes_with_single_slash(input_string))

    def test_move_spark_metadata_to_separate_s3_folder(self):

Collaborator: thanks for adding tests!

        # given
        bucket = 'mybucket'
        prefix = '/myprefix/with/subdir'
        s3_uri = f's3://{bucket}/{prefix}'

        mock_s3 = MagicMock()
        mock_s3.list_objects_v2.return_value = {
            'Contents': [
                {'Key': f'{prefix}/file1'},
                {'Key': f'{prefix}/part-file2'},
                {'Key': f'{prefix}/file3'},
                {'Key': f'{prefix}/part-file4'}
            ]
        }

        # when
        move_spark_metadata_to_separate_s3_folder(mock_s3, s3_uri)

        # then
        # Check that the list_objects_v2 method was called with the correct bucket and prefix
        mock_s3.list_objects_v2.assert_called_with(Bucket='mybucket', Prefix='/myprefix/with/subdir', Delimiter='/')

        # Check that the copy_object and delete_object methods were called for the correct files
        mock_s3.copy_object.assert_any_call(Bucket='mybucket', CopySource='mybucket/myprefix/with/subdir/file1', Key='/myprefix/with/subdir/spark_metadata/file1')
        mock_s3.delete_object.assert_any_call(Bucket='mybucket', Key='/myprefix/with/subdir/file1')
        mock_s3.copy_object.assert_any_call(Bucket='mybucket', CopySource='mybucket/myprefix/with/subdir/file3', Key='/myprefix/with/subdir/spark_metadata/file3')
        mock_s3.delete_object.assert_any_call(Bucket='mybucket', Key='/myprefix/with/subdir/file3')

        # Check copy_object and delete_object were called the correct number of times
        self.assertEqual(2, mock_s3.copy_object.call_count)
        self.assertEqual(2, mock_s3.delete_object.call_count)
58 changes: 58 additions & 0 deletions unload_databricks_data_to_s3.py
@@ -4,6 +4,7 @@
import re
import time

import boto3
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.functions import col
@@ -213,6 +214,55 @@ def get_partition_count(event_count: int, max_event_count_per_output_file: int)
    return max(1, math.ceil(event_count / max_event_count_per_output_file))


def replace_double_slashes_with_single_slash(string: str) -> str:
    while '//' in string:
        string = string.replace('//', '/')
    return string


def move_spark_metadata_to_separate_s3_folder(s3_client: boto3.client, s3_uri_with_spark_metadata: str):
"""
Lists all files in the s3_uri_with_spark_metadata directory and moves all spark metadata files (files without
'/part-' in the name) to a separate subdirectory under s3_uri_with_spark_metadata.

NOTE: Expects spark metadata files to be in the same directory as the data files and available at the time of the
function call

:param s3_client: boto3 client for s3
:param s3_uri_with_spark_metadata: s3 URI with spark metadata files
"""
print(f'Moving spark metadata files to a separate subdirectory in {s3_uri_with_spark_metadata}')

if '://' not in s3_uri_with_spark_metadata:
raise ValueError(f'Invalid s3 URI: {s3_uri_with_spark_metadata}. Expected to contain "://".')
bucket, prefix = s3_uri_with_spark_metadata.split('://')[1].split('/', 1)
bucket = replace_double_slashes_with_single_slash(bucket)

Collaborator: curious where double slashes are coming from?

Collaborator Author: for bucket, it shouldn't. It's just a sanitization in case some unnormalized input for the s3 URI is provided (e.g. s3:////bucket////prefix////).
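
As a quick illustration (not part of the PR; values are made up), the helper collapses any run of slashes, which is presumably also why it is applied to the bucket and prefix separately after splitting on '://', since applying it to the full URI would collapse the scheme separator as well:

# Illustration only -- how replace_double_slashes_with_single_slash handles unnormalized input.
print(replace_double_slashes_with_single_slash('//myprefix///with////subdir/'))
# -> '/myprefix/with/subdir/'
print(replace_double_slashes_with_single_slash('s3:////bucket////prefix////'))
# -> 's3:/bucket/prefix/' (note the '://' collapses too)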

    prefix = replace_double_slashes_with_single_slash(prefix)
    print(f'Identified bucket: {bucket}, prefix: {prefix}')

    # List all files in the s3_path directory
    response = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix, Delimiter='/')

Collaborator: Will this return folders as well? If only files are returned, then we are good.

Collaborator Author (@LeontiBrechko, Apr 4, 2024): only files. Same for listing objects in our Java repo using Amplitude's S3 wrapper.

The example job mentioned in the description has the logs of what was discovered (note that the /meta folder exists at the point of execution of this method and is not listed).
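
For reference, a rough sketch (illustrative keys, not taken from the actual job) of the response shape this code relies on: with Delimiter='/', ListObjectsV2 returns only the objects directly under the prefix in 'Contents', while keys with a further delimiter past the prefix are rolled up into 'CommonPrefixes', which matches the /meta folder not being listed:

# Illustrative ListObjectsV2 response shape when Delimiter='/' is used (keys are made up)
response = {
    'Contents': [                       # objects directly under the prefix
        {'Key': 'myprefix/with/subdir/_SUCCESS'},
        {'Key': 'myprefix/with/subdir/part-00000'},
    ],
    'CommonPrefixes': [                 # "subfolders" rolled up at the delimiter
        {'Prefix': 'myprefix/with/subdir/meta/'},
    ],
}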

    print(f'Found {len(response["Contents"])} files in bucket {bucket} with prefix {prefix}:\n{response["Contents"]}')

    for file in response['Contents']:
        key = file['Key']
        if '/part-' in key:
            # Skip data files
            continue

        # Construct the destination file name
        new_key = replace_double_slashes_with_single_slash(key.replace(prefix, f'{prefix}/spark_metadata/'))
        copy_source = replace_double_slashes_with_single_slash(f'{bucket}/{key}')
        print(f'Moving copy_source: {copy_source} to new_key: {new_key} in bucket: {bucket}')

        # Copy the file to the new location
        s3_client.copy_object(Bucket=bucket, CopySource=copy_source, Key=new_key)
        # Delete the original file
        s3_client.delete_object(Bucket=bucket, Key=key)

        print(f'Successfully moved {key} to {new_key}')


def export_meta_data(event_count: int, partition_count: int):
    meta_data: list = [{'event_count': event_count, 'partition_count': partition_count}]
    spark.createDataFrame(meta_data).write.mode("overwrite").json(args.s3_path + "/meta")
@@ -311,6 +361,14 @@ def export_meta_data(event_count: int, partition_count: int):
    export_data = export_data.repartition(partition_count)
    # export data
    export_data.write.mode("overwrite").json(args.s3_path)
    # move spark metadata files to a separate subdirectory
    # the s3 client is initialized after the data is exported to avoid missing files in the destination s3 path
    # due to the eventual consistency of S3 listings
    s3_client = boto3.client('s3',
                             aws_access_key_id=aws_access_key,
                             aws_secret_access_key=aws_secret_key,
                             aws_session_token=aws_session_token)
    move_spark_metadata_to_separate_s3_folder(s3_client, args.s3_path)
    print("Unloaded {event_count} events.".format(event_count=event_count))
else:
    print("No events were exported.")