Modify dataframe_to_mds to accept streaming DF #478

Open · wants to merge 2 commits into base: main
74 changes: 58 additions & 16 deletions streaming/base/converters/dataframe_to_mds.py
@@ -6,6 +6,7 @@
 import logging
 import os
 import shutil
+import time
 from collections.abc import Iterable
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple

@@ -224,8 +225,8 @@ def write_mds(iterator: Iterable):
             ],
             axis=1)
 
-    if dataframe is None or dataframe.isEmpty():
-        raise ValueError(f'Input dataframe is None or Empty!')
+    if dataframe is None:
+        raise ValueError('Input dataframe must be provided')
 
     if not mds_kwargs:
         mds_kwargs = {}
@@ -261,6 +262,9 @@ def write_mds(iterator: Iterable):
     if cu.remote is None:
         mds_path = (cu.local, '')
     else:
+        if dataframe.isStreaming:

Review comment: Not sure why only outputting to local is possible here? What is the challenge in supporting the original out (meaning local, remote, or (local, remote))?

Reply (Contributor Author): We want to disallow concurrent writes to the top-level index.json, so we need a lock file that is shared among all workers. This is implemented here using os.open with flags that make opening the lock file fail if it already exists.
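
For context, a minimal self-contained sketch of the exclusive-creation locking pattern the author describes; the lock path, retry interval, and critical-section body are illustrative assumptions, not taken from the PR:

```python
import os
import time

lock_path = '/tmp/example.merge.lock'  # illustrative path, not from the PR

# os.O_CREAT | os.O_EXCL makes file creation atomic: os.open raises
# FileExistsError if the file already exists, so at most one process
# can hold the lock at any time.
while True:
    try:
        fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        break
    except FileExistsError:
        time.sleep(1)  # another worker holds the lock; wait and retry

try:
    print('critical section: e.g. merging partial index files')
finally:
    os.close(fd)
    os.remove(lock_path)  # deleting the file is what releases the lock
```

Note that with this scheme the lock is held as long as the file exists, so the holder must delete the file when done, not merely close the descriptor.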

+            raise ValueError(
+                'dataframe_to_mds currently only supports outputting to a local directory')
         mds_path = (cu.local, cu.remote)
 
     # Prepare partition schema
@@ -269,21 +273,59 @@
         StructField('mds_path_remote', StringType(), False),
         StructField('fail_count', IntegerType(), False)
     ])
-    partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect()
+    mapped_df = dataframe.mapInPandas(func=write_mds, schema=result_schema)
+
+    if mapped_df.isStreaming:
+
+        def merge_and_log(df: DataFrame, batch_id: int):
+            partitions = df.collect()
+            if len(partitions) == 0:
+                return
+
+            if merge_index:
+                index_files = [
+                    (row['mds_path_local'], row['mds_path_remote']) for row in partitions
+                ]
+                lock_file_path = os.path.join(out, '.merge.lock')
+                # Acquire the lock: creation with O_EXCL fails if the file already exists.
+                while True:
+                    try:
+                        fd = os.open(lock_file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
+                    except OSError:
+                        time.sleep(1)  # File already exists, wait and try again.
+                    else:
+                        break
+                do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)
+                # Release the lock: the file must be removed, not just closed,
+                # so that later batches and other workers can acquire it.
+                os.close(fd)
+                os.remove(lock_file_path)
+
+            sum_fail_count = 0
+            for row in partitions:
+                sum_fail_count += row['fail_count']
+
+            if sum_fail_count > 0:
+                # count() is unsupported on a streaming DataFrame, so log failures only.
+                logger.warning(f'[Batch #{batch_id}] Total failed records = {sum_fail_count}')
+
+        mapped_df.writeStream.foreachBatch(merge_and_log).start()
+        return None, 0
+    else:
+        partitions = mapped_df.collect()
 
-    if merge_index:
-        index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions]
-        do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)
+        if merge_index:
+            index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions]
+            do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)
 
-    if cu.remote is not None:
-        if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False:
-            shutil.rmtree(cu.local, ignore_errors=True)
+        if cu.remote is not None:
+            if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False:
+                shutil.rmtree(cu.local, ignore_errors=True)
 
-    sum_fail_count = 0
-    for row in partitions:
-        sum_fail_count += row['fail_count']
+        sum_fail_count = 0
+        for row in partitions:
+            sum_fail_count += row['fail_count']
 
-    if sum_fail_count > 0:
-        logger.warning(
-            f'Total failed records = {sum_fail_count}\nOverall records {dataframe.count()}')
-    return mds_path, sum_fail_count
+        if sum_fail_count > 0:
+            logger.warning(
+                f'Total failed records = {sum_fail_count}\nOverall records {dataframe.count()}')
+        return mds_path, sum_fail_count
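
For reference, a hedged sketch of how the streaming path added by this PR might be exercised; the source path, schema, column encodings, and output directory are illustrative assumptions, not taken from the PR:

```python
# Hypothetical usage sketch for the streaming code path of dataframe_to_mds.
from pyspark.sql import SparkSession
from streaming.base.converters import dataframe_to_mds

spark = SparkSession.builder.getOrCreate()

# A file-based streaming source; any streaming DataFrame takes the same path.
stream_df = (spark.readStream
             .schema('id STRING, text STRING')
             .json('/tmp/incoming_json'))  # illustrative input path

# With a streaming input, the PR routes per-batch index merging through
# writeStream.foreachBatch(...) and returns (None, 0) immediately; note
# that 'out' must be a local directory on this code path.
mds_path, fail_count = dataframe_to_mds(stream_df,
                                        merge_index=True,
                                        mds_kwargs={
                                            'out': '/tmp/mds_out',
                                            'columns': {'id': 'str', 'text': 'str'},
                                        })

spark.streams.awaitAnyTermination()  # keep the streaming query alive
```

Because the streaming branch starts the query and returns right away, the caller is responsible for keeping the application alive (as with `awaitAnyTermination()` above) for as long as batches should be converted.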