From 1c16396eac87e17b5a2d8039d83563c9ddad7172 Mon Sep 17 00:00:00 2001
From: Dean Jackson <57651082+deanja@users.noreply.github.com>
Date: Tue, 16 Jan 2024 17:55:51 +1100
Subject: [PATCH] Support fsspec keyword parameters

kwargs or client_kwargs can be passed when constructing fsspec filesystem
instances. See https://github.com/dlt-hub/dlt/pull/869
---
 sources/filesystem/__init__.py      |  33 +++++--
 tests/filesystem/settings.py        |  40 +++++----
 tests/filesystem/test_filesystem.py | 134 +++++++++++++++++++---------
 tests/filesystem/utils.py           |   6 ++
 4 files changed, 150 insertions(+), 63 deletions(-)
 create mode 100644 tests/filesystem/utils.py

diff --git a/sources/filesystem/__init__.py b/sources/filesystem/__init__.py
index 4683a5d45..cf6ec8bcb 100644
--- a/sources/filesystem/__init__.py
+++ b/sources/filesystem/__init__.py
@@ -21,6 +21,8 @@ def readers(
     bucket_url: str = dlt.secrets.value,
     credentials: Union[FileSystemCredentials, AbstractFileSystem] = dlt.secrets.value,
     file_glob: Optional[str] = "*",
+    kwargs: Optional[DictStrAny] = None,
+    client_kwargs: Optional[DictStrAny] = None,
 ) -> Tuple[DltResource, ...]:
     """This source provides a few resources that are chunked file readers. Readers can be further parametrized before use
     read_csv(chunksize, **pandas_kwargs)
@@ -33,11 +35,29 @@ def readers(
         file_glob (str, optional): The filter to apply to the files in glob format. by default lists all files in bucket_url non-recursively
     """
     return (
-        filesystem(bucket_url, credentials, file_glob=file_glob)
+        filesystem(
+            bucket_url,
+            credentials,
+            file_glob=file_glob,
+            kwargs=kwargs,
+            client_kwargs=client_kwargs,
+        )
         | dlt.transformer(name="read_csv")(_read_csv),
-        filesystem(bucket_url, credentials, file_glob=file_glob)
+        filesystem(
+            bucket_url,
+            credentials,
+            file_glob=file_glob,
+            kwargs=kwargs,
+            client_kwargs=client_kwargs,
+        )
         | dlt.transformer(name="read_jsonl")(_read_jsonl),
-        filesystem(bucket_url, credentials, file_glob=file_glob)
+        filesystem(
+            bucket_url,
+            credentials,
+            file_glob=file_glob,
+            kwargs=kwargs,
+            client_kwargs=client_kwargs,
+        )
         | dlt.transformer(name="read_parquet")(_read_parquet),
     )
@@ -53,7 +73,6 @@ def filesystem(
     extract_content: bool = False,
     kwargs: Optional[DictStrAny] = None,
     client_kwargs: Optional[DictStrAny] = None,
-
 ) -> Iterator[List[FileItem]]:
     """This resource lists files in `bucket_url` using `file_glob` pattern. The files are yielded as FileItem which also provide methods to open and read file data. It should be combined with transformers that further process (ie. load files)
@@ -72,11 +91,13 @@ def filesystem(
     if isinstance(credentials, AbstractFileSystem):
         fs_client = credentials
     else:
-        fs_client = fsspec_filesystem(bucket_url, credentials, kwargs=kwargs, client_kwargs=client_kwargs)[0]
+        fs_client = fsspec_filesystem(
+            bucket_url, credentials, kwargs=kwargs, client_kwargs=client_kwargs
+        )[0]

     files_chunk: List[FileItem] = []
     for file_model in glob_files(fs_client, bucket_url, file_glob):
-        file_dict = FileItemDict(file_model, credentials)
+        file_dict = FileItemDict(file_model, fs_client)
         if extract_content:
             file_dict["file_content"] = file_dict.read_bytes()
         files_chunk.append(file_dict)  # type: ignore
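The two new parameters are forwarded untouched to `fsspec_filesystem`, so anything the chosen fsspec implementation accepts can be supplied from pipeline code or configuration. A minimal usage sketch (the bucket name and option values below are hypothetical, not taken from this patch):

    import dlt
    from sources.filesystem import filesystem

    files = filesystem(
        bucket_url="s3://example-bucket/standard_source/samples",  # hypothetical bucket
        file_glob="*.csv",
        kwargs={"use_ssl": True},  # constructor arguments for the fsspec filesystem class
        client_kwargs={"region_name": "eu-central-1"},  # hypothetical option for the underlying client
    )
    pipeline = dlt.pipeline(pipeline_name="file_data", destination="duckdb")
    pipeline.run(files)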
diff --git a/tests/filesystem/settings.py b/tests/filesystem/settings.py
index d15a3eadd..da0f677c5 100644
--- a/tests/filesystem/settings.py
+++ b/tests/filesystem/settings.py
@@ -1,23 +1,31 @@
 import os

-TESTS_BUCKET_URLS = [
-    os.path.abspath("tests/filesystem/samples"),
-    # Original:
-    # "s3://dlt-ci-test-bucket/standard_source/samples",
+FACTORY_ARGS = [
+    {"bucket_url": os.path.abspath("tests/filesystem/samples")},
+    # Original:
+    # {
+    #     "bucket_url": "s3://dlt-ci-test-bucket/standard_source/samples",
+    #     "kwargs": {"use_ssl": True}
+    # },
     # deanja dev:
-    "s3://flyingfish-dlt-ci-test-bucket/standard_source/samples",
-
-    # "gs://ci-test-bucket/standard_source/samples",
-    # "az://dlt-ci-test-bucket/standard_source/samples",
-
+    {
+        "bucket_url": "s3://flyingfish-dlt-ci-test-bucket/standard_source/samples",
+        "kwargs": {"use_ssl": True},
+    },
+    # {"bucket_url": "gs://ci-test-bucket/standard_source/samples"},
+    # {"bucket_url": "az://dlt-ci-test-bucket/standard_source/samples"},
     # gitpythonfs variations:
-    # For dlt.common.storages with no support for params in url netloc. If no
-    # function args provided it defaults to repo in working directory and ref HEAD
-    "gitpythonfs://samples",
-    # with separate bare-ish repo in `cases` and repo_path and ref specified in url netloc:
-    # "gitpythonfs://tests/filesystem/cases/git:unmodified-samples@samples",
-
-
+    # For dlt.common.storages with no support for params in url netloc. If no
+    # function args provided it defaults to repo in working directory and ref HEAD
+    {
+        "bucket_url": "gitpythonfs://samples",
+        "kwargs": {
+            "repo_path": "tests/filesystem/cases/git",
+            "ref": "unmodified-samples",
+        },
+    },
+    # with dedicated test repo in `cases` and repo_path and ref specified in url netloc:
+    # {"bucket_url": "gitpythonfs://tests/filesystem/cases/git:unmodified-samples@samples"},
 ]

 GLOB_RESULTS = [
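Each FACTORY_ARGS entry is a dict of keyword arguments for the filesystem factory; keys that are absent default to None via the test helper added below. The gitpythonfs entry above, for example, is roughly equivalent to this call (a sketch):

    from sources.filesystem import filesystem

    res = filesystem(
        bucket_url="gitpythonfs://samples",
        # repo_path and ref reach the gitpythonfs implementation via the kwargs passthrough
        kwargs={"repo_path": "tests/filesystem/cases/git", "ref": "unmodified-samples"},
    )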
diff --git a/tests/filesystem/test_filesystem.py b/tests/filesystem/test_filesystem.py
index 7e6c2d11f..15142e2fc 100644
--- a/tests/filesystem/test_filesystem.py
+++ b/tests/filesystem/test_filesystem.py
@@ -21,35 +21,38 @@
     assert_query_data,
     TEST_STORAGE_ROOT,
 )
+from tests.filesystem.utils import unpack_factory_args

-from .settings import GLOB_RESULTS, TESTS_BUCKET_URLS
+from .settings import GLOB_RESULTS, FACTORY_ARGS


-@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS)
+@pytest.mark.parametrize("factory_args", FACTORY_ARGS)
 @pytest.mark.parametrize("glob_params", GLOB_RESULTS)
-def test_file_list(bucket_url: str, glob_params: Dict[str, Any]) -> None:
+def test_file_list(factory_args: Dict[str, Any], glob_params: Dict[str, Any]) -> None:
+    bucket_url, kwargs, client_kwargs = unpack_factory_args(factory_args)
+
     @dlt.transformer
     def bypass(items) -> str:
         return items

-    #301 hacked to try out params with gitpythonfs. ToDo: factor into pytest parameters.
-    if bucket_url.startswith("gitpythonfs"):
-        # we need to pass repo_path and ref to the resource
-        repo_args = {
-            "repo_path": "tests/filesystem/cases/git",
-            "ref": "unmodified-samples"
-        }
-        # we just pass the glob parameter to the resource if it is not None
-        if file_glob := glob_params["glob"]:
-            filesystem_res = filesystem(bucket_url=bucket_url, file_glob=file_glob, kwargs=repo_args) | bypass
-        else:
-            filesystem_res = filesystem(bucket_url=bucket_url, kwargs=repo_args) | bypass
+    # we only pass the glob parameter to the resource if it is not None
+    if file_glob := glob_params["glob"]:
+        filesystem_res = (
+            filesystem(
+                bucket_url=bucket_url,
+                file_glob=file_glob,
+                kwargs=kwargs,
+                client_kwargs=client_kwargs,
+            )
+            | bypass
+        )
     else:
-        # we just pass the glob parameter to the resource if it is not None
-        if file_glob := glob_params["glob"]:
-            filesystem_res = filesystem(bucket_url=bucket_url, file_glob=file_glob) | bypass
-        else:
-            filesystem_res = filesystem(bucket_url=bucket_url) | bypass
+        filesystem_res = (
+            filesystem(
+                bucket_url=bucket_url, kwargs=kwargs, client_kwargs=client_kwargs
+            )
+            | bypass
+        )

     all_files = list(filesystem_res)
     file_count = len(all_files)
@@ -59,8 +62,12 @@ def bypass(items) -> str:


 @pytest.mark.parametrize("extract_content", [True, False])
-@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS)
-def test_load_content_resources(bucket_url: str, extract_content: bool) -> None:
+@pytest.mark.parametrize("factory_args", FACTORY_ARGS)
+def test_load_content_resources(
+    factory_args: Dict[str, Any], extract_content: bool
+) -> None:
+    bucket_url, kwargs, client_kwargs = unpack_factory_args(factory_args)
+
     @dlt.transformer
     def assert_sample_content(items: List[FileItem]):
         # expect just one file
@@ -75,12 +82,13 @@ def assert_sample_content(items: List[FileItem]):

         yield items

-    # use transformer to test files
     sample_file = (
         filesystem(
             bucket_url=bucket_url,
             file_glob="sample.txt",
             extract_content=extract_content,
+            kwargs=kwargs,
+            client_kwargs=client_kwargs,
         )
         | assert_sample_content
     )
@@ -99,7 +107,12 @@ def assert_csv_file(item: FileItem):
         # print(item)
         return item

-    nested_file = filesystem(bucket_url, file_glob="met_csv/A801/A881_20230920.csv")
+    nested_file = filesystem(
+        bucket_url,
+        file_glob="met_csv/A801/A881_20230920.csv",
+        kwargs=kwargs,
+        client_kwargs=client_kwargs,
+    )

     assert len(list(nested_file | assert_csv_file)) == 1

@@ -117,10 +130,12 @@ def test_fsspec_as_credentials():
     print(list(gs_resource))


-@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS)
-def test_csv_transformers(bucket_url: str) -> None:
+@pytest.mark.parametrize("factory_args", FACTORY_ARGS)
+def test_csv_transformers(factory_args: Dict[str, Any]) -> None:
     from sources.filesystem_pipeline import read_csv

+    bucket_url, kwargs, client_kwargs = unpack_factory_args(factory_args)
+
     pipeline = dlt.pipeline(
         pipeline_name="file_data",
         destination="duckdb",
@@ -130,7 +145,13 @@
     # load all csvs merging data on a date column
     met_files = (
-        filesystem(bucket_url=bucket_url, file_glob="met_csv/A801/*.csv") | read_csv()
+        filesystem(
+            bucket_url=bucket_url,
+            file_glob="met_csv/A801/*.csv",
+            kwargs=kwargs,
+            client_kwargs=client_kwargs,
+        )
+        | read_csv()
     )
     met_files.apply_hints(write_disposition="merge", merge_key="date")
     load_info = pipeline.run(met_files.with_name("met_csv"))
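Outside the test harness, the merge pattern exercised by this hunk looks roughly like this (the bucket URL is hypothetical):

    import dlt
    from sources.filesystem import filesystem
    from sources.filesystem_pipeline import read_csv

    met_files = (
        filesystem(bucket_url="s3://example-bucket/data", file_glob="met_csv/A801/*.csv")
        | read_csv()
    )
    # rows with an already-loaded merge key are replaced instead of duplicated on re-runs
    met_files.apply_hints(write_disposition="merge", merge_key="date")
    dlt.pipeline(pipeline_name="file_data", destination="duckdb").run(
        met_files.with_name("met_csv")
    )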
@@ -142,7 +163,13 @@ def test_csv_transformers(bucket_url: str) -> None:
     # load the other folder that contains data for the same day + one other day
     # the previous data will be replaced
     met_files = (
-        filesystem(bucket_url=bucket_url, file_glob="met_csv/A803/*.csv") | read_csv()
+        filesystem(
+            bucket_url=bucket_url,
+            file_glob="met_csv/A803/*.csv",
+            kwargs=kwargs,
+            client_kwargs=client_kwargs,
+        )
+        | read_csv()
     )
     met_files.apply_hints(write_disposition="merge", merge_key="date")
     load_info = pipeline.run(met_files.with_name("met_csv"))
@@ -155,14 +182,20 @@ def test_csv_transformers(bucket_url: str) -> None:
     assert load_table_counts(pipeline, "met_csv") == {"met_csv": 48}


-@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS)
-def test_standard_readers(bucket_url: str) -> None:
+@pytest.mark.parametrize("factory_args", FACTORY_ARGS)
+def test_standard_readers(factory_args: Dict[str, Any]) -> None:
+    bucket_url, kwargs, client_kwargs = unpack_factory_args(factory_args)
+
     # extract pipes with standard readers
-    jsonl_reader = readers(bucket_url, file_glob="**/*.jsonl").read_jsonl()
-    parquet_reader = readers(bucket_url, file_glob="**/*.parquet").read_parquet()
-    csv_reader = readers(bucket_url, file_glob="**/*.csv").read_csv(
-        float_precision="high"
-    )
+    jsonl_reader = readers(
+        bucket_url, file_glob="**/*.jsonl", kwargs=kwargs, client_kwargs=client_kwargs
+    ).read_jsonl()
+    parquet_reader = readers(
+        bucket_url, file_glob="**/*.parquet", kwargs=kwargs, client_kwargs=client_kwargs
+    ).read_parquet()
+    csv_reader = readers(
+        bucket_url, file_glob="**/*.csv", kwargs=kwargs, client_kwargs=client_kwargs
+    ).read_csv(float_precision="high")

     # a step that copies files into test storage
     def _copy(item: FileItemDict):
@@ -175,7 +208,9 @@ def _copy(item: FileItemDict):
         # return file item unchanged
         return item

-    downloader = filesystem(bucket_url, file_glob="**").add_map(_copy)
+    downloader = filesystem(
+        bucket_url, file_glob="**", kwargs=kwargs, client_kwargs=client_kwargs
+    ).add_map(_copy)

     # load in single pipeline
     pipeline = dlt.pipeline(
@@ -205,12 +240,14 @@ def _copy(item: FileItemDict):
     # print(pipeline.default_schema.to_pretty_yaml())


-@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS)
-def test_incremental_load(bucket_url: str) -> None:
+@pytest.mark.parametrize("factory_args", FACTORY_ARGS)
+def test_incremental_load(factory_args: Dict[str, Any]) -> None:
     @dlt.transformer
     def bypass(items) -> str:
         return items

+    bucket_url, kwargs, client_kwargs = unpack_factory_args(factory_args)
+
     pipeline = dlt.pipeline(
         pipeline_name="file_data",
         destination="duckdb",
@@ -219,7 +256,12 @@ def bypass(items) -> str:
     )

     # Load all files
-    all_files = filesystem(bucket_url=bucket_url, file_glob="csv/*")
+    all_files = filesystem(
+        bucket_url=bucket_url,
+        file_glob="csv/*",
+        kwargs=kwargs,
+        client_kwargs=client_kwargs,
+    )
     # add incremental on modification time
     all_files.apply_hints(incremental=dlt.sources.incremental("modification_date"))
     load_info = pipeline.run((all_files | bypass).with_name("csv_files"))
@@ -230,7 +272,12 @@ def bypass(items) -> str:
     assert table_counts["csv_files"] == 4

     # load again
-    all_files = filesystem(bucket_url=bucket_url, file_glob="csv/*")
+    all_files = filesystem(
+        bucket_url=bucket_url,
+        file_glob="csv/*",
+        kwargs=kwargs,
+        client_kwargs=client_kwargs,
+    )
     all_files.apply_hints(incremental=dlt.sources.incremental("modification_date"))
     load_info = pipeline.run((all_files | bypass).with_name("csv_files"))
     # nothing into csv_files
@@ -239,7 +286,12 @@ def bypass(items) -> str:
     assert table_counts["csv_files"] == 4

     # load again into different table
-    all_files = filesystem(bucket_url=bucket_url, file_glob="csv/*")
+    all_files = filesystem(
+        bucket_url=bucket_url,
+        file_glob="csv/*",
+        kwargs=kwargs,
+        client_kwargs=client_kwargs,
+    )
     all_files.apply_hints(incremental=dlt.sources.incremental("modification_date"))
     load_info = pipeline.run((all_files | bypass).with_name("csv_files_2"))
     assert_load_info(load_info)
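The incremental hint used throughout test_incremental_load is plain dlt behaviour rather than anything added by this patch; in isolation it looks like this (the bucket URL is hypothetical):

    import dlt
    from sources.filesystem import filesystem

    all_files = filesystem(bucket_url="s3://example-bucket/data", file_glob="csv/*")
    # on subsequent runs, only files with a newer modification_date are yielded
    all_files.apply_hints(incremental=dlt.sources.incremental("modification_date"))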
diff --git a/tests/filesystem/utils.py b/tests/filesystem/utils.py
new file mode 100644
index 000000000..74cce4f79
--- /dev/null
+++ b/tests/filesystem/utils.py
@@ -0,0 +1,6 @@
+from typing import Any, Dict, List
+
+
+def unpack_factory_args(factory_args: Dict[str, Any]) -> List[Any]:
+    """Unpacks filesystem factory arguments from pytest parameters."""
+    return [factory_args.get(k) for k in ("bucket_url", "kwargs", "client_kwargs")]
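Because the helper uses dict.get, keys absent from a FACTORY_ARGS entry come back as None and can be passed straight through to the factory. A usage sketch:

    from tests.filesystem.utils import unpack_factory_args

    bucket_url, kwargs, client_kwargs = unpack_factory_args(
        {"bucket_url": "gitpythonfs://samples"}
    )
    assert bucket_url == "gitpythonfs://samples"
    assert kwargs is None and client_kwargs is None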