Support fsspec keyword parameters
kwargs or client_kwargs can be passed when constructing fsspec filesystem instances.
See dlt-hub/dlt#869
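A minimal usage sketch of the new parameters (not part of this commit): the bucket name and the s3fs options use_ssl / region_name are illustrative assumptions only; the filesystem signature itself comes from the diff below.

import dlt
from sources.filesystem import filesystem

# Hypothetical bucket and fsspec options, shown only to illustrate the new parameters.
files = filesystem(
    bucket_url="s3://example-bucket/standard_source/samples",
    file_glob="**/*.csv",
    kwargs={"use_ssl": True},  # forwarded to the fsspec filesystem constructor
    client_kwargs={"region_name": "eu-central-1"},  # forwarded to the underlying client session
)

pipeline = dlt.pipeline(pipeline_name="file_data", destination="duckdb")
pipeline.run(files)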
deanja committed Jan 16, 2024
1 parent 2947665 commit 1c16396
Showing 4 changed files with 150 additions and 63 deletions.
33 changes: 27 additions & 6 deletions sources/filesystem/__init__.py
@@ -21,6 +21,8 @@ def readers(
bucket_url: str = dlt.secrets.value,
credentials: Union[FileSystemCredentials, AbstractFileSystem] = dlt.secrets.value,
file_glob: Optional[str] = "*",
kwargs: Optional[DictStrAny] = None,
client_kwargs: Optional[DictStrAny] = None,
) -> Tuple[DltResource, ...]:
"""This source provides a few resources that are chunked file readers. Readers can be further parametrized before use
read_csv(chunksize, **pandas_kwargs)
@@ -33,11 +35,29 @@ def readers(
file_glob (str, optional): The filter to apply to the files in glob format. By default lists all files in bucket_url non-recursively
"""
return (
filesystem(bucket_url, credentials, file_glob=file_glob)
filesystem(
bucket_url,
credentials,
file_glob=file_glob,
kwargs=kwargs,
client_kwargs=client_kwargs,
)
| dlt.transformer(name="read_csv")(_read_csv),
filesystem(bucket_url, credentials, file_glob=file_glob)
filesystem(
bucket_url,
credentials,
file_glob=file_glob,
kwargs=kwargs,
client_kwargs=client_kwargs,
)
| dlt.transformer(name="read_jsonl")(_read_jsonl),
filesystem(bucket_url, credentials, file_glob=file_glob)
filesystem(
bucket_url,
credentials,
file_glob=file_glob,
kwargs=kwargs,
client_kwargs=client_kwargs,
)
| dlt.transformer(name="read_parquet")(_read_parquet),
)

@@ -53,7 +73,6 @@ def filesystem(
extract_content: bool = False,
kwargs: Optional[DictStrAny] = None,
client_kwargs: Optional[DictStrAny] = None,

) -> Iterator[List[FileItem]]:
"""This resource lists files in `bucket_url` using `file_glob` pattern. The files are yielded as FileItem which also
provide methods to open and read file data. It should be combined with transformers that further process (i.e. load files)
@@ -72,11 +91,13 @@ def filesystem(
if isinstance(credentials, AbstractFileSystem):
fs_client = credentials
else:
fs_client = fsspec_filesystem(bucket_url, credentials, kwargs=kwargs, client_kwargs=client_kwargs)[0]
fs_client = fsspec_filesystem(
bucket_url, credentials, kwargs=kwargs, client_kwargs=client_kwargs
)[0]

files_chunk: List[FileItem] = []
for file_model in glob_files(fs_client, bucket_url, file_glob):
file_dict = FileItemDict(file_model, credentials)
file_dict = FileItemDict(file_model, fs_client)
if extract_content:
file_dict["file_content"] = file_dict.read_bytes()
files_chunk.append(file_dict) # type: ignore
40 changes: 24 additions & 16 deletions tests/filesystem/settings.py
@@ -1,23 +1,31 @@
import os

TESTS_BUCKET_URLS = [
os.path.abspath("tests/filesystem/samples"),
# Original:
# "s3://dlt-ci-test-bucket/standard_source/samples",
FACTORY_ARGS = [
{"bucket_url": os.path.abspath("tests/filesystem/samples")},
# Original:
# {
# "bucket_url": "s3://dlt-ci-test-bucket/standard_source/samples",
# "kwargs": {"use_ssl": True}
# },
# deanja dev:
"s3://flyingfish-dlt-ci-test-bucket/standard_source/samples",

# "gs://ci-test-bucket/standard_source/samples",
# "az://dlt-ci-test-bucket/standard_source/samples",

{
"bucket_url": "s3://flyingfish-dlt-ci-test-bucket/standard_source/samples",
"kwargs": {"use_ssl": True}
},
# {"bucket_url": "gs://ci-test-bucket/standard_source/samples"},
# {"bucket_url": "az://dlt-ci-test-bucket/standard_source/samples"},
# gitpythonfs variations:
# For dlt.common.storages with no support for params in url netloc. If no
# function args provided it defaults to repo in working directory and ref HEAD
"gitpythonfs://samples",
# with separate bare-ish repo in `cases` and repo_path and ref specified in url netloc:
# "gitpythonfs://tests/filesystem/cases/git:unmodified-samples@samples",


# For dlt.common.storages with no support for params in url netloc. If no
# function args provided it defaults to repo in working directory and ref HEAD
{
"bucket_url": "gitpythonfs://samples",
"kwargs": {
"repo_path": "tests/filesystem/cases/git",
"ref": "unmodified-samples",
},
}
# with dedicated test repo in `cases` and repo_path and ref specified in url netloc:
# ["bucket_url":"gitpythonfs://tests/filesystem/cases/git:unmodified-samples@samples"],
]

GLOB_RESULTS = [
134 changes: 93 additions & 41 deletions tests/filesystem/test_filesystem.py
@@ -21,35 +21,38 @@
assert_query_data,
TEST_STORAGE_ROOT,
)
from tests.filesystem.utils import unpack_factory_args

from .settings import GLOB_RESULTS, TESTS_BUCKET_URLS
from .settings import GLOB_RESULTS, FACTORY_ARGS


@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS)
@pytest.mark.parametrize("factory_args", FACTORY_ARGS)
@pytest.mark.parametrize("glob_params", GLOB_RESULTS)
def test_file_list(bucket_url: str, glob_params: Dict[str, Any]) -> None:
def test_file_list(factory_args: Dict[str, Any], glob_params: Dict[str, Any]) -> None:
bucket_url, kwargs, client_kwargs = unpack_factory_args(factory_args)

@dlt.transformer
def bypass(items) -> str:
return items

#301 hacked to try out params with gitpythonfs. ToDo: factor into pytest parameters.
if bucket_url.startswith("gitpythonfs"):
# we need to pass repo_path and ref to the resource
repo_args = {
"repo_path": "tests/filesystem/cases/git",
"ref": "unmodified-samples"
}
# we just pass the glob parameter to the resource if it is not None
if file_glob := glob_params["glob"]:
filesystem_res = filesystem(bucket_url=bucket_url, file_glob=file_glob, kwargs=repo_args) | bypass
else:
filesystem_res = filesystem(bucket_url=bucket_url, kwargs=repo_args) | bypass
# we only pass the glob parameter to the resource if it is not None
if file_glob := glob_params["glob"]:
filesystem_res = (
filesystem(
bucket_url=bucket_url,
file_glob=file_glob,
kwargs=kwargs,
client_kwargs=client_kwargs,
)
| bypass
)
else:
# we just pass the glob parameter to the resource if it is not None
if file_glob := glob_params["glob"]:
filesystem_res = filesystem(bucket_url=bucket_url, file_glob=file_glob) | bypass
else:
filesystem_res = filesystem(bucket_url=bucket_url) | bypass
filesystem_res = (
filesystem(
bucket_url=bucket_url, kwargs=kwargs, client_kwargs=client_kwargs
)
| bypass
)

all_files = list(filesystem_res)
file_count = len(all_files)
@@ -59,8 +62,12 @@ def bypass(items) -> str:


@pytest.mark.parametrize("extract_content", [True, False])
@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS)
def test_load_content_resources(bucket_url: str, extract_content: bool) -> None:
@pytest.mark.parametrize("factory_args", FACTORY_ARGS)
def test_load_content_resources(
factory_args: Dict[str, Any], extract_content: bool
) -> None:
bucket_url, kwargs, client_kwargs = unpack_factory_args(factory_args)

@dlt.transformer
def assert_sample_content(items: List[FileItem]):
# expect just one file
@@ -75,12 +82,13 @@ def assert_sample_content(items: List[FileItem]):

yield items

# use transformer to test files
sample_file = (
filesystem(
bucket_url=bucket_url,
file_glob="sample.txt",
extract_content=extract_content,
kwargs=kwargs,
client_kwargs=client_kwargs,
)
| assert_sample_content
)
@@ -99,7 +107,12 @@ def assert_csv_file(item: FileItem):
# print(item)
return item

nested_file = filesystem(bucket_url, file_glob="met_csv/A801/A881_20230920.csv")
nested_file = filesystem(
bucket_url,
file_glob="met_csv/A801/A881_20230920.csv",
kwargs=kwargs,
client_kwargs=client_kwargs,
)

assert len(list(nested_file | assert_csv_file)) == 1

@@ -117,10 +130,12 @@ def test_fsspec_as_credentials():
print(list(gs_resource))


@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS)
def test_csv_transformers(bucket_url: str) -> None:
@pytest.mark.parametrize("factory_args", FACTORY_ARGS)
def test_csv_transformers(factory_args: Dict[str, Any]) -> None:
from sources.filesystem_pipeline import read_csv

bucket_url, kwargs, client_kwargs = unpack_factory_args(factory_args)

pipeline = dlt.pipeline(
pipeline_name="file_data",
destination="duckdb",
@@ -130,7 +145,13 @@ def test_csv_transformers(bucket_url: str) -> None:

# load all csvs merging data on a date column
met_files = (
filesystem(bucket_url=bucket_url, file_glob="met_csv/A801/*.csv") | read_csv()
filesystem(
bucket_url=bucket_url,
file_glob="met_csv/A801/*.csv",
kwargs=kwargs,
client_kwargs=client_kwargs,
)
| read_csv()
)
met_files.apply_hints(write_disposition="merge", merge_key="date")
load_info = pipeline.run(met_files.with_name("met_csv"))
@@ -142,7 +163,13 @@ def test_csv_transformers(bucket_url: str) -> None:
# load the other folder that contains data for the same day + one other day
# the previous data will be replaced
met_files = (
filesystem(bucket_url=bucket_url, file_glob="met_csv/A803/*.csv") | read_csv()
filesystem(
bucket_url=bucket_url,
file_glob="met_csv/A803/*.csv",
kwargs=kwargs,
client_kwargs=client_kwargs,
)
| read_csv()
)
met_files.apply_hints(write_disposition="merge", merge_key="date")
load_info = pipeline.run(met_files.with_name("met_csv"))
@@ -155,14 +182,20 @@ def test_csv_transformers(bucket_url: str) -> None:
assert load_table_counts(pipeline, "met_csv") == {"met_csv": 48}


@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS)
def test_standard_readers(bucket_url: str) -> None:
@pytest.mark.parametrize("factory_args", FACTORY_ARGS)
def test_standard_readers(factory_args: Dict[str, Any]) -> None:
bucket_url, kwargs, client_kwargs = unpack_factory_args(factory_args)

# extract pipes with standard readers
jsonl_reader = readers(bucket_url, file_glob="**/*.jsonl").read_jsonl()
parquet_reader = readers(bucket_url, file_glob="**/*.parquet").read_parquet()
csv_reader = readers(bucket_url, file_glob="**/*.csv").read_csv(
float_precision="high"
)
jsonl_reader = readers(
bucket_url, file_glob="**/*.jsonl", kwargs=kwargs, client_kwargs=client_kwargs
).read_jsonl()
parquet_reader = readers(
bucket_url, file_glob="**/*.parquet", kwargs=kwargs, client_kwargs=client_kwargs
).read_parquet()
csv_reader = readers(
bucket_url, file_glob="**/*.csv", kwargs=kwargs, client_kwargs=client_kwargs
).read_csv(float_precision="high")

# a step that copies files into test storage
def _copy(item: FileItemDict):
@@ -175,7 +208,9 @@ def _copy(item: FileItemDict):
# return file item unchanged
return item

downloader = filesystem(bucket_url, file_glob="**").add_map(_copy)
downloader = filesystem(
bucket_url, file_glob="**", kwargs=kwargs, client_kwargs=client_kwargs
).add_map(_copy)

# load in single pipeline
pipeline = dlt.pipeline(
@@ -205,12 +240,14 @@ def _copy(item: FileItemDict):
# print(pipeline.default_schema.to_pretty_yaml())


@pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS)
def test_incremental_load(bucket_url: str) -> None:
@pytest.mark.parametrize("factory_args", FACTORY_ARGS)
def test_incremental_load(factory_args: Dict[str, Any]) -> None:
@dlt.transformer
def bypass(items) -> str:
return items

bucket_url, kwargs, client_kwargs = unpack_factory_args(factory_args)

pipeline = dlt.pipeline(
pipeline_name="file_data",
destination="duckdb",
@@ -219,7 +256,12 @@ def bypass(items) -> str:
)

# Load all files
all_files = filesystem(bucket_url=bucket_url, file_glob="csv/*")
all_files = filesystem(
bucket_url=bucket_url,
file_glob="csv/*",
kwargs=kwargs,
client_kwargs=client_kwargs,
)
# add incremental on modification time
all_files.apply_hints(incremental=dlt.sources.incremental("modification_date"))
load_info = pipeline.run((all_files | bypass).with_name("csv_files"))
@@ -230,7 +272,12 @@ def bypass(items) -> str:
assert table_counts["csv_files"] == 4

# load again
all_files = filesystem(bucket_url=bucket_url, file_glob="csv/*")
all_files = filesystem(
bucket_url=bucket_url,
file_glob="csv/*",
kwargs=kwargs,
client_kwargs=client_kwargs,
)
all_files.apply_hints(incremental=dlt.sources.incremental("modification_date"))
load_info = pipeline.run((all_files | bypass).with_name("csv_files"))
# nothing into csv_files
@@ -239,7 +286,12 @@ def bypass(items) -> str:
assert table_counts["csv_files"] == 4

# load again into different table
all_files = filesystem(bucket_url=bucket_url, file_glob="csv/*")
all_files = filesystem(
bucket_url=bucket_url,
file_glob="csv/*",
kwargs=kwargs,
client_kwargs=client_kwargs,
)
all_files.apply_hints(incremental=dlt.sources.incremental("modification_date"))
load_info = pipeline.run((all_files | bypass).with_name("csv_files_2"))
assert_load_info(load_info)
6 changes: 6 additions & 0 deletions tests/filesystem/utils.py
@@ -0,0 +1,6 @@
from typing import Any, Dict, List


def unpack_factory_args(factory_args: Dict[str, Any]) -> List[Any]:
"""Unpacks filesystem factory arguments from pytest parameters."""
return [factory_args.get(k) for k in ("bucket_url", "kwargs", "client_kwargs")]
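For illustration, a hypothetical FACTORY_ARGS entry (values are examples only, not from the changeset) unpacks like this, with missing keys falling back to None:

example = {"bucket_url": "s3://example-bucket/samples", "kwargs": {"use_ssl": True}}
bucket_url, kwargs, client_kwargs = unpack_factory_args(example)
# bucket_url -> "s3://example-bucket/samples"; kwargs -> {"use_ssl": True}; client_kwargs -> None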
