From e11b76c2bbe7eaa7dbcbf012e5b3337ac1b2f4e9 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Mon, 23 Oct 2023 13:47:08 -0700 Subject: [PATCH 1/4] Add HF ingestion script --- scripts/ingestion/ingest_from_hf.py | 185 ++++++++++++++++++++++++++++ setup.py | 5 + 2 files changed, 190 insertions(+) create mode 100644 scripts/ingestion/ingest_from_hf.py diff --git a/scripts/ingestion/ingest_from_hf.py b/scripts/ingestion/ingest_from_hf.py new file mode 100644 index 000000000..7888cb79e --- /dev/null +++ b/scripts/ingestion/ingest_from_hf.py @@ -0,0 +1,185 @@ +# Improvements on snapshot_download: +# 1. Enable resume = True. retry when bad network happens +# 2. Disable progress_bar to prevent browser/terminal crash +# 3. Add a monitor to print out file stats every 15 seconds + +import os +import time +from huggingface_hub import snapshot_download +from pyspark.sql.functions import concat_ws +from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars +import asyncio +import threading + +from watchdog.observers import Observer +from watchdog.events import PatternMatchingEventHandler +from streaming.base.util import retry + + +class FolderObserver: + def __init__(self, directory): + patterns = ["*"] + ignore_patterns = None + ignore_directories = False + case_sensitive = True + self.average_file_size = 0 + + if not os.path.exists(directory): + os.makedirs(directory) + + self.directory = directory + self.get_directory_info() + + self.my_event_handler = PatternMatchingEventHandler(patterns, ignore_patterns, ignore_directories, case_sensitive) + self.my_event_handler.on_created = self.on_created + self.my_event_handler.on_deleted = self.on_deleted + self.my_event_handler.on_modified = self.on_modified + self.my_event_handler.on_moved = self.on_moved + + go_recursively = True + self.observer = Observer() + self.observer.schedule(self.my_event_handler, directory, recursive=go_recursively) + self.tik = time.time() + + def start(self): + return self.observer.start() + + def stop(self): + return self.observer.stop() + + def join(self): + return self.observer.join() + + def get_directory_info(self): + file_count = 0 + total_file_size = 0 + for root, _, files in os.walk(self.directory): + for file in files: + file_count += 1 + file_path = os.path.join(root, file) + total_file_size += os.path.getsize(file_path) + self.file_count, self.file_size = file_count, total_file_size + + def on_created(self, event): + self.file_count += 1 + + def on_deleted(self, event): + self.file_count -= 1 + + def on_modified(self, event): + pass + + def on_moved(self, event): + pass + + +def monitor_directory_changes(interval=5): + def decorator(func): + def wrapper(repo_id, local_dir, max_workers, token, allow_patterns, *args, **kwargs): + event = threading.Event() + start_time = time.time() # Capture the start time + observer = FolderObserver(local_dir) + + def beautify(kb): + mb = kb //(1024) + gb = mb //(1024) + return str(mb)+'MB' if mb >= 1 else str(kb) + 'KB' + + def monitor_directory(): + observer.start() + while not event.is_set(): + try: + elapsed_time = int(time.time() - observer.tik) + if observer.file_size > 1e9: # too large to keep an accurate count of the file size + if observer.average_file_size == 0: + observer.average_file_size = observer.file_size // observer.file_count + print(f"approximately: average file size = {beautify(observer.average_file_size//1024)}") + kb = observer.average_file_size * observer.file_count // 1024 + else: + observer.get_directory_info() + b = 
observer.file_size + kb = b // 1024 + + sz = beautify(kb) + cnt = observer.file_count + + if elapsed_time % 10 == 0 : + print(f"Downloaded {cnt} files, Total approx file size = {sz}, Time Elapsed: {elapsed_time} seconds.") + + if elapsed_time > 0 and elapsed_time % 120 == 0: + observer.get_directory_info() # Get the actual stats by walking through the directory + observer.average_file_size = observer.file_size // observer.file_count + print(f"update average file size to {beautify(observer.average_file_size//1024)}") + + time.sleep(1) + except Exception as exc: + # raise RuntimeError("Something bad happened") from exc + print(str(exc)) + time.sleep(1) + continue + + monitor_thread = threading.Thread(target=monitor_directory) + monitor_thread.start() + + try: + result = func(repo_id, local_dir, max_workers, token, allow_patterns, *args, **kwargs) + return result + finally: + observer.get_directory_info() # Get the actual stats by walking through the directory + print(f"Done! Downloaded {observer.file_count} files, Total file size = {beautify(observer.file_size//1024)}, Time Elapsed: {int(time.time() - observer.tik)} seconds.") + observer.stop() + observer.join() + + event.set() + monitor_thread.join() + + return wrapper + + return decorator + +def retry(max_retries=3, retry_delay=5): + def decorator(func): + def wrapper(*args, **kwargs): + for _ in range(max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + print(f"An exception occurred: {str(e)}") + print(f"Retrying in {retry_delay} seconds...") + time.sleep(retry_delay) + raise Exception(f"Function {func.__name__} failed after {max_retries} retries.") + return wrapper + return decorator + +@monitor_directory_changes() +@retry(max_retries=10, retry_delay=10) +def hf_snapshot(repo_id, local_dir, max_workers, token, allow_patterns): + print(f"Now start to download {repo_id} to {local_dir}, with allow_patterns = {allow_patterns}") + output = snapshot_download(repo_id, repo_type="dataset", local_dir=local_dir, local_dir_use_symlinks=False, max_workers=max_workers, resume_download=True, token=token, allow_patterns=allow_patterns) + return output + +def download_hf_dataset(local_cache_directory, prefix='', submixes =[], max_workers=32, token="", allow_patterns=None): + disable_progress_bars() + for submix in submixes: + repo_id = os.path.join(prefix, submix) + local_dir = os.path.join(local_cache_directory, submix) + + output = hf_snapshot( + repo_id, + local_dir, + max_workers, + token, + allow_patterns=allow_patterns, + ) + +if __name__ == "__main__": + #download_hf_dataset(local_cache_directory="/tmp/xiaohan/cifar10_1233", prefix="", submixes=["cifar10"], max_workers=32) + download_hf_dataset( + local_cache_directory="/tmp/xiaohan/c4_1316", + prefix = "allenai/", + submixes = [ + "c4", + ], + allow_patterns=["en/*"], + max_workers=1, + token = "MY_HUGGINGFACE_ACCESS_TOKEN") # 32 seems to be a sweet point, beyond 32 downloading is not smooth diff --git a/setup.py b/setup.py index 9f0010a3e..e83252217 100644 --- a/setup.py +++ b/setup.py @@ -101,6 +101,11 @@ 'databricks-sdk==0.8.0', ] +extra_deps['ingestion'] = [ + 'huggingface_hub[cli,torch]>=0.16.0,<0.17.0', + 'watchdog>=3,<4' +] + extra_deps['all'] = sorted({dep for deps in extra_deps.values() for dep in deps}) package_name = os.environ.get('MOSAIC_PACKAGE_NAME', 'mosaicml-streaming') From 8362549c775ed0a932c8cc4aaacdd4a561c4cc68 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Mon, 23 Oct 2023 14:30:40 -0700 Subject: [PATCH 2/4] Fix lints --- 
 scripts/ingestion/ingest_from_hf.py | 181 ++++++++++++++++++----------
 setup.py                            |   5 +-
 2 files changed, 120 insertions(+), 66 deletions(-)

diff --git a/scripts/ingestion/ingest_from_hf.py b/scripts/ingestion/ingest_from_hf.py
index 7888cb79e..abf2f3220 100644
--- a/scripts/ingestion/ingest_from_hf.py
+++ b/scripts/ingestion/ingest_from_hf.py
@@ -1,36 +1,51 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Utility for HF dataset ingestion."""
+
 # Improvements on snapshot_download:
 # 1. Enable resume = True. retry when bad network happens
 # 2. Disable progress_bar to prevent browser/terminal crash
 # 3. Add a monitor to print out file stats every 15 seconds
 
+import logging
 import os
-import time
-from huggingface_hub import snapshot_download
-from pyspark.sql.functions import concat_ws
-from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars
-import asyncio
 import threading
+import time
+from typing import Any, List, Optional
 
-from watchdog.observers import Observer
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import disable_progress_bars
 from watchdog.events import PatternMatchingEventHandler
+from watchdog.observers import Observer
+
 from streaming.base.util import retry
 
+logger = logging.getLogger(__name__)
+
 
 class FolderObserver:
-    def __init__(self, directory):
-        patterns = ["*"]
+    """A thin wrapper around a watchdog directory observer."""
+
+    def __init__(self, directory: str):
+        """Specify the download directory to monitor."""
+        patterns = ['*']
         ignore_patterns = None
         ignore_directories = False
         case_sensitive = True
         self.average_file_size = 0
 
+        self.file_count = 0
+        self.file_size = 0
+
         if not os.path.exists(directory):
             os.makedirs(directory)
 
         self.directory = directory
         self.get_directory_info()
 
-        self.my_event_handler = PatternMatchingEventHandler(patterns, ignore_patterns, ignore_directories, case_sensitive)
+        self.my_event_handler = PatternMatchingEventHandler(patterns, ignore_patterns,
+                                                            ignore_directories, case_sensitive)
         self.my_event_handler.on_created = self.on_created
         self.my_event_handler.on_deleted = self.on_deleted
         self.my_event_handler.on_modified = self.on_modified
@@ -51,8 +66,8 @@ def join(self):
         return self.observer.join()
 
     def get_directory_info(self):
-        file_count = 0
-        total_file_size = 0
+        self.file_count = file_count = 0
+        self.file_size = total_file_size = 0
         for root, _, files in os.walk(self.directory):
             for file in files:
                 file_count += 1
@@ -60,40 +75,56 @@ def get_directory_info(self):
                 total_file_size += os.path.getsize(file_path)
         self.file_count, self.file_size = file_count, total_file_size
 
-    def on_created(self, event):
+    def on_created(self, event: Any):
         self.file_count += 1
 
-    def on_deleted(self, event):
+    def on_deleted(self, event: Any):
+        print(type(event))
         self.file_count -= 1
 
-    def on_modified(self, event):
+    def on_modified(self, event: Any):
+        print(type(event))
         pass
 
-    def on_moved(self, event):
+    def on_moved(self, event: Any):
+        print(type(event))
        pass
 
 
-def monitor_directory_changes(interval=5):
-    def decorator(func):
-        def wrapper(repo_id, local_dir, max_workers, token, allow_patterns, *args, **kwargs):
+def monitor_directory_changes(interval: int = 5):
+    """Monitor dataset downloading: track the file count and accumulated file size.
+
+    Once the directory grows large, approximate its total size as file count * average file size.
+ """ + + def decorator(func: Any): + + def wrapper(repo_id: str, local_dir: str, max_workers: int, token: str, + allow_patterns: Optional[List[str]], *args: Any, **kwargs: Any): event = threading.Event() - start_time = time.time() # Capture the start time observer = FolderObserver(local_dir) - def beautify(kb): - mb = kb //(1024) - gb = mb //(1024) - return str(mb)+'MB' if mb >= 1 else str(kb) + 'KB' + def beautify(kb: int): + mb = kb // (1024) + gb = mb // (1024) + if gb >= 1: + return str(gb) + 'GB' + elif mb >= 1: + return str(mb) + 'MB' + else: + return str(kb) + 'KB' def monitor_directory(): observer.start() while not event.is_set(): try: elapsed_time = int(time.time() - observer.tik) - if observer.file_size > 1e9: # too large to keep an accurate count of the file size + if observer.file_size > 1e9: # too large to keep an accurate count of the file size if observer.average_file_size == 0: observer.average_file_size = observer.file_size // observer.file_count - print(f"approximately: average file size = {beautify(observer.average_file_size//1024)}") + logger.warning( + f'approximately: average file size = {beautify(observer.average_file_size//1024)}' + ) kb = observer.average_file_size * observer.file_count // 1024 else: observer.get_directory_info() @@ -103,18 +134,23 @@ def monitor_directory(): sz = beautify(kb) cnt = observer.file_count - if elapsed_time % 10 == 0 : - print(f"Downloaded {cnt} files, Total approx file size = {sz}, Time Elapsed: {elapsed_time} seconds.") + if elapsed_time % 10 == 0: + logger.warning( + f'Downloaded {cnt} files, Total approx file size = {sz}, Time Elapsed: {elapsed_time} seconds.' + ) if elapsed_time > 0 and elapsed_time % 120 == 0: - observer.get_directory_info() # Get the actual stats by walking through the directory + observer.get_directory_info( + ) # Get the actual stats by walking through the directory observer.average_file_size = observer.file_size // observer.file_count - print(f"update average file size to {beautify(observer.average_file_size//1024)}") + logger.warning( + f'update average file size to {beautify(observer.average_file_size//1024)}' + ) time.sleep(1) except Exception as exc: # raise RuntimeError("Something bad happened") from exc - print(str(exc)) + logger.warning(str(exc)) time.sleep(1) continue @@ -122,11 +158,15 @@ def monitor_directory(): monitor_thread.start() try: - result = func(repo_id, local_dir, max_workers, token, allow_patterns, *args, **kwargs) + result = func(repo_id, local_dir, max_workers, token, allow_patterns, *args, + **kwargs) return result finally: - observer.get_directory_info() # Get the actual stats by walking through the directory - print(f"Done! Downloaded {observer.file_count} files, Total file size = {beautify(observer.file_size//1024)}, Time Elapsed: {int(time.time() - observer.tik)} seconds.") + observer.get_directory_info( + ) # Get the actual stats by walking through the directory + logger.warning( + f'Done! Downloaded {observer.file_count} files, Total file size = {beautify(observer.file_size//1024)}, Time Elapsed: {int(time.time() - observer.tik)} seconds.' 
+                )
                 observer.stop()
                 observer.join()
 
@@ -137,34 +177,50 @@ def monitor_directory():
 
     return decorator
 
-def retry(max_retries=3, retry_delay=5):
-    def decorator(func):
-        def wrapper(*args, **kwargs):
-            for _ in range(max_retries):
-                try:
-                    return func(*args, **kwargs)
-                except Exception as e:
-                    print(f"An exception occurred: {str(e)}")
-                    print(f"Retrying in {retry_delay} seconds...")
-                    time.sleep(retry_delay)
-            raise Exception(f"Function {func.__name__} failed after {max_retries} retries.")
-        return wrapper
-    return decorator
 
 @monitor_directory_changes()
-@retry(max_retries=10, retry_delay=10)
-def hf_snapshot(repo_id, local_dir, max_workers, token, allow_patterns):
-    print(f"Now start to download {repo_id} to {local_dir}, with allow_patterns = {allow_patterns}")
-    output = snapshot_download(repo_id, repo_type="dataset", local_dir=local_dir, local_dir_use_symlinks=False, max_workers=max_workers, resume_download=True, token=token, allow_patterns=allow_patterns)
+@retry([Exception, RuntimeError], num_attempts=10, initial_backoff=10)
+def hf_snapshot(repo_id: str, local_dir: str, max_workers: int, token: str,
+                allow_patterns: Optional[List[str]]):
+    """API call to HF snapshot_download.
+
+    It internally uses hf_hub_download.
+    """
+    print(
+        f'Starting to download {repo_id} to {local_dir}, with allow_patterns = {allow_patterns}')
+    output = snapshot_download(repo_id,
+                               repo_type='dataset',
+                               local_dir=local_dir,
+                               local_dir_use_symlinks=False,
+                               max_workers=max_workers,
+                               resume_download=True,
+                               token=token,
+                               allow_patterns=allow_patterns)
     return output
 
-def download_hf_dataset(local_cache_directory, prefix='', submixes =[], max_workers=32, token="", allow_patterns=None):
+
+def download_hf_dataset(local_cache_directory: str,
+                        prefix: str,
+                        submixes: List[str],
+                        token: str,
+                        max_workers: int = 32,
+                        allow_patterns: Optional[List[str]] = None) -> None:
+    """Disable the progress bar and call hf_snapshot.
+
+    Args:
+        local_cache_directory (str): local output directory the dataset will be written to.
+        prefix (str): HF namespace, allenai for example.
+        submixes (List): a list of repos within the HF namespace, c4 for example.
+        token (str): HF access token.
+        max_workers (int): number of workers used to parallelize downloading.
+        allow_patterns (List): only files matching the pattern will be downloaded. E.g., "en/*" along with allenai/c4 means to download the allenai/c4/en folder only.
+    """
     disable_progress_bars()
     for submix in submixes:
         repo_id = os.path.join(prefix, submix)
         local_dir = os.path.join(local_cache_directory, submix)
 
-        output = hf_snapshot(
+        _ = hf_snapshot(
             repo_id,
             local_dir,
             max_workers,
@@ -172,14 +228,15 @@ def download_hf_dataset(local_cache_directory, prefix='', submixes =[], max_work
             allow_patterns=allow_patterns,
         )
 
-if __name__ == "__main__":
+
+if __name__ == '__main__':
     #download_hf_dataset(local_cache_directory="/tmp/xiaohan/cifar10_1233", prefix="", submixes=["cifar10"], max_workers=32)
-    download_hf_dataset(
-        local_cache_directory="/tmp/xiaohan/c4_1316",
-        prefix = "allenai/",
-        submixes = [
-            "c4",
-        ],
-        allow_patterns=["en/*"],
-        max_workers=1,
-        token = "MY_HUGGINGFACE_ACCESS_TOKEN") # 32 seems to be a sweet point, beyond 32 downloading is not smooth
+    download_hf_dataset(local_cache_directory='/tmp/xiaohan/c4_1316',
+                        prefix='allenai/',
+                        submixes=[
+                            'c4',
+                        ],
+                        allow_patterns=['en/*'],
+                        max_workers=1,
+                        token='MY_HUGGINGFACE_ACCESS_TOKEN'
+                        )  # 32 seems to be a sweet spot; beyond 32, downloading is not smooth
diff --git a/setup.py b/setup.py
index e83252217..381bb2717 100644
--- a/setup.py
+++ b/setup.py
@@ -101,10 +101,7 @@
     'databricks-sdk==0.8.0',
 ]
 
-extra_deps['ingestion'] = [
-    'huggingface_hub[cli,torch]>=0.16.0,<0.17.0',
-    'watchdog>=3,<4'
-]
+extra_deps['ingestion'] = ['huggingface_hub[cli,torch]>=0.16.0,<0.17.0', 'watchdog>=3,<4']
 
 extra_deps['all'] = sorted({dep for deps in extra_deps.values() for dep in deps})
 

From 65bdfe7ed469f10394decdb23e7ad2f62b6c3a52 Mon Sep 17 00:00:00 2001
From: Xiaohan Zhang
Date: Mon, 23 Oct 2023 14:51:56 -0700
Subject: [PATCH 3/4] move scripts into streaming

---
 .../ingestion/ingest_from_hf.py => streaming/base/ingest_util.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename scripts/ingestion/ingest_from_hf.py => streaming/base/ingest_util.py (100%)

diff --git a/scripts/ingestion/ingest_from_hf.py b/streaming/base/ingest_util.py
similarity index 100%
rename from scripts/ingestion/ingest_from_hf.py
rename to streaming/base/ingest_util.py

From 7c3d22285b0712e0f3438a5cdefbcd1789d5655a Mon Sep 17 00:00:00 2001
From: Xiaohan Zhang
Date: Mon, 23 Oct 2023 14:59:00 -0700
Subject: [PATCH 4/4] Remove prints

---
 streaming/base/ingest_util.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/streaming/base/ingest_util.py b/streaming/base/ingest_util.py
index abf2f3220..a4baf49ec 100644
--- a/streaming/base/ingest_util.py
+++ b/streaming/base/ingest_util.py
@@ -79,15 +79,12 @@ def on_created(self, event: Any):
         self.file_count += 1
 
     def on_deleted(self, event: Any):
-        print(type(event))
         self.file_count -= 1
 
     def on_modified(self, event: Any):
-        print(type(event))
         pass
 
     def on_moved(self, event: Any):
-        print(type(event))
         pass
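
Usage note: after PATCH 3/4 the utility lives at streaming/base/ingest_util.py. Below is a minimal sketch of how it might be driven, assuming the file is importable as streaming.base.ingest_util and the `ingestion` extra from setup.py is installed; the cache directory, worker count, and HF_TOKEN environment variable are illustrative placeholders, not values taken from the patches.

    # Sketch: drive the ingestion utility added by this patch series.
    # Assumes `pip install 'mosaicml-streaming[ingestion]'` and that
    # streaming/base/ingest_util.py is importable as streaming.base.ingest_util.
    import os

    from streaming.base.ingest_util import download_hf_dataset

    download_hf_dataset(
        local_cache_directory='/tmp/hf_cache/c4',  # placeholder output directory
        prefix='allenai/',  # HF namespace
        submixes=['c4'],  # repos under that namespace
        token=os.environ.get('HF_TOKEN', ''),  # placeholder: supply a real HF access token
        max_workers=8,  # per the patch comment, up to 32 workers stays smooth
        allow_patterns=['en/*'],  # fetch only the English subset
    )

Because hf_snapshot is decorated with monitor_directory_changes and retry, a call like this logs directory stats roughly every 10 seconds and retries failed downloads for up to 10 attempts with an initial 10-second backoff.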