Skip to content

Commit

Permalink
Merge pull request #626 from dlt-hub/rfix/moves-filesystem-common
Browse files Browse the repository at this point in the history
moves fsspec support to common
  • Loading branch information
rudolfix authored Sep 13, 2023
2 parents 839e5bb + a225fbc commit 474a420
Show file tree
Hide file tree
Showing 15 changed files with 253 additions and 128 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_local_destinations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ jobs:
- name: Install dependencies
run: poetry install --no-interaction -E postgres -E duckdb -E parquet -E filesystem -E cli -E weaviate

- run: poetry run pytest tests/load tests/cli
- run: poetry run pytest tests/load && poetry run pytest tests/cli
name: Run tests Linux
env:
DESTINATION__POSTGRES__CREDENTIALS: postgresql://loader:loader@localhost:5432/dlt_data
Expand Down
1 change: 1 addition & 0 deletions dlt/common/configuration/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:


def last_config(**kwargs: Any) -> Any:
    """Get the configuration instance used to inject function arguments.

    Reads the resolved configuration stored under the `_LAST_DLT_CONFIG` key in
    `kwargs` — presumably placed there by the injection wrapper (`_wrap`) above;
    confirm callers only use this inside injected functions.
    Raises `KeyError` when the key is not present.
    """
    return kwargs[_LAST_DLT_CONFIG]


Expand Down
1 change: 1 addition & 0 deletions dlt/common/configuration/specs/azure_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def on_partial(self) -> None:
if not self.is_partial():
self.resolve()


@configspec
class AzureCredentials(AzureCredentialsWithoutDefaults, CredentialsWithDefault):
def on_partial(self) -> None:
Expand Down
15 changes: 10 additions & 5 deletions dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]:
cls = type(cls.__name__, (cls, _F_BaseConfiguration), fields)
# get all annotations without corresponding attributes and set them to None
for ann in cls.__annotations__:
if not hasattr(cls, ann) and not ann.startswith(("__", "_abc_impl")):
if not hasattr(cls, ann) and not ann.startswith(("__", "_abc_")):
setattr(cls, ann, None)
# get all attributes without corresponding annotations
for att_name, att_value in list(cls.__dict__.items()):
Expand All @@ -129,7 +129,7 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]:
except NameError:
# Dealing with BaseConfiguration itself before it is defined
continue
if not att_name.startswith(("__", "_abc_impl")) and not isinstance(att_value, (staticmethod, classmethod, property)):
if not att_name.startswith(("__", "_abc_")) and not isinstance(att_value, (staticmethod, classmethod, property)):
if att_name not in cls.__annotations__:
raise ConfigFieldMissingTypeHintException(att_name, cls)
hint = cls.__annotations__[att_name]
Expand Down Expand Up @@ -211,7 +211,7 @@ def _get_resolvable_dataclass_fields(cls) -> Iterator[TDtcField]:
"""Yields all resolvable dataclass fields in the order they should be resolved"""
# Sort dynamic type hint fields last because they depend on other values
yield from sorted(
(f for f in cls.__dataclass_fields__.values() if not f.name.startswith("__")),
(f for f in cls.__dataclass_fields__.values() if cls.__is_valid_field(f)),
key=lambda f: f.name in cls.__hint_resolvers__
)

Expand Down Expand Up @@ -264,7 +264,8 @@ def __delitem__(self, __key: str) -> None:
raise KeyError("Configuration fields cannot be deleted")

def __iter__(self) -> Iterator[str]:
return filter(lambda k: not k.startswith("__"), self.__dataclass_fields__.__iter__())
"""Iterator or valid key names"""
return map(lambda field: field.name, filter(lambda val: self.__is_valid_field(val), self.__dataclass_fields__.values()))

def __len__(self) -> int:
return sum(1 for _ in self.__iter__())
Expand All @@ -279,7 +280,11 @@ def update(self, other: Any = (), /, **kwds: Any) -> None:
# helper functions

def __has_attr(self, __key: str) -> bool:
return __key in self.__dataclass_fields__ and not __key.startswith("__")
return __key in self.__dataclass_fields__ and self.__is_valid_field(self.__dataclass_fields__[__key])

    @staticmethod
    def __is_valid_field(field: TDtcField) -> bool:
        """True for real dataclass fields whose names are not dunder-prefixed."""
        # _field_type is a private dataclasses marker; comparing it to _FIELD
        # filters out pseudo-fields (ClassVar / InitVar entries) that also
        # appear in __dataclass_fields__
        return not field.name.startswith("__") and field._field_type is dataclasses._FIELD  # type: ignore

def call_method_in_mro(config, method_name: str) -> None:
# python multi-inheritance is cooperative and this would require that all configurations cooperatively
Expand Down
3 changes: 2 additions & 1 deletion dlt/common/storages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
from .normalize_storage import NormalizeStorage # noqa: F401
from .load_storage import LoadStorage # noqa: F401
from .data_item_storage import DataItemStorage # noqa: F401
from .configuration import LoadStorageConfiguration, NormalizeStorageConfiguration, SchemaStorageConfiguration, TSchemaFileFormat # noqa: F401
from .configuration import LoadStorageConfiguration, NormalizeStorageConfiguration, SchemaStorageConfiguration, TSchemaFileFormat, FilesystemConfiguration # noqa: F401
from .filesystem import filesystem_from_config, filesystem # noqa: F401
82 changes: 80 additions & 2 deletions dlt/common/storages/configuration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import TYPE_CHECKING, Literal, Optional, get_args
from urllib.parse import urlparse
from typing import TYPE_CHECKING, Any, Literal, Optional, Type, get_args, ClassVar, Dict, Union

from dlt.common.configuration.specs import BaseConfiguration, configspec
from dlt.common.configuration.specs import BaseConfiguration, configspec, CredentialsConfiguration
from dlt.common.configuration import configspec, resolve_type
from dlt.common.configuration.specs import GcpServiceAccountCredentials, AwsCredentials, GcpOAuthCredentials, AzureCredentials, AzureCredentialsWithoutDefaults, BaseConfiguration
from dlt.common.utils import digest128
from dlt.common.configuration.exceptions import ConfigurationValueError

TSchemaFileFormat = Literal["json", "yaml"]
SchemaFileExtensions = get_args(TSchemaFileFormat)
Expand Down Expand Up @@ -36,3 +41,76 @@ class LoadStorageConfiguration(BaseConfiguration):
if TYPE_CHECKING:
def __init__(self, load_volume_path: str = None, delete_completed_jobs: bool = None) -> None:
...


FileSystemCredentials = Union[AwsCredentials, GcpServiceAccountCredentials, AzureCredentials, GcpOAuthCredentials]

@configspec
class FilesystemConfiguration(BaseConfiguration):
    """A configuration defining filesystem location and access credentials.
    When configuration is resolved, `bucket_url` is used to extract a protocol and request corresponding credentials class.
    * s3
    * gs, gcs
    * az, abfs, adl
    * file, memory
    * gdrive
    """
    # maps a bucket_url scheme to the credentials spec(s) able to authenticate it;
    # protocols not listed here (e.g. file, memory) resolve to optional/empty credentials
    PROTOCOL_CREDENTIALS: ClassVar[Dict[str, Any]] = {
        "gs": Union[GcpServiceAccountCredentials, GcpOAuthCredentials],
        "gcs": Union[GcpServiceAccountCredentials, GcpOAuthCredentials],
        "gdrive": GcpOAuthCredentials,
        "s3": AwsCredentials,
        "az": Union[AzureCredentialsWithoutDefaults, AzureCredentials],
        "abfs": Union[AzureCredentialsWithoutDefaults, AzureCredentials],
        "adl": Union[AzureCredentialsWithoutDefaults, AzureCredentials],
    }

    # url of the bucket or filesystem location, e.g. "s3://my-bucket/path" or a local path
    bucket_url: str = None
    # should be an union of all possible credentials as found in PROTOCOL_CREDENTIALS
    credentials: FileSystemCredentials

    @property
    def protocol(self) -> str:
        """`bucket_url` protocol"""
        url = urlparse(self.bucket_url)
        # a bare local path produces no scheme; treat it as the "file" protocol
        return url.scheme or "file"

    def on_resolved(self) -> None:
        # validate that bucket_url carries at least a path or a network location
        url = urlparse(self.bucket_url)
        if not url.path and not url.netloc:
            # NOTE(review): message references the old class name
            # "FilesystemClientConfiguration" — consider updating it to
            # "FilesystemConfiguration" (runtime string, left unchanged here)
            raise ConfigurationValueError("File path or netloc missing. Field bucket_url of FilesystemClientConfiguration must contain valid url with a path or host:password component.")
        # this is just a path in local file system
        if url.path == self.bucket_url:
            # normalize a bare local path into an explicit file:// url
            url = url._replace(scheme="file")
            self.bucket_url = url.geturl()

    @resolve_type('credentials')
    def resolve_credentials_type(self) -> Type[CredentialsConfiguration]:
        # use known credentials or empty credentials for unknown protocol
        return self.PROTOCOL_CREDENTIALS.get(self.protocol) or Optional[CredentialsConfiguration]  # type: ignore[return-value]

    def fingerprint(self) -> str:
        """Returns a fingerprint of bucket_url"""
        if self.bucket_url:
            return digest128(self.bucket_url)
        return ""

    def __str__(self) -> str:
        """Return displayable destination location"""
        url = urlparse(self.bucket_url)
        # do not show passwords
        if url.password:
            new_netloc = f"{url.username}:****@{url.hostname}"
            if url.port:
                new_netloc += f":{url.port}"
            return url._replace(netloc=new_netloc).geturl()
        return self.bucket_url

    if TYPE_CHECKING:
        # signature stub for type checkers only; actual __init__ is synthesized by @configspec
        def __init__(
            self,
            bucket_url: str,
            credentials: FileSystemCredentials = None
        ) -> None:
            ...
80 changes: 80 additions & 0 deletions dlt/common/storages/filesystem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import cast, Tuple, TypedDict, Optional, Union

from fsspec.core import url_to_fs
from fsspec import AbstractFileSystem

from dlt.common import pendulum
from dlt.common.exceptions import MissingDependencyException
from dlt.common.time import ensure_pendulum_datetime
from dlt.common.typing import DictStrAny
from dlt.common.configuration.specs import CredentialsWithDefault, GcpCredentials, AwsCredentials, AzureCredentials
from dlt.common.storages.configuration import FileSystemCredentials, FilesystemConfiguration

from dlt import version


class FileItem(TypedDict):
    """A DataItem representing a file"""
    # url of the file — presumably includes the protocol/scheme; confirm against producers
    file_url: str
    # file name without the path
    file_name: str
    # mime type of the file content
    mime_type: str
    # last modification time as resolved by the protocol's MTIME_DISPATCH entry
    modification_date: pendulum.DateTime
    # file payload as text or raw bytes; may be None when content is not loaded
    file_content: Optional[Union[str, bytes]]


# Map of protocol to mtime resolver
# we only need to support a small finite set of protocols; each fsspec backend
# reports the modification time under a different key in its info dict
_MTIME_FIELD_BY_PROTOCOL = {
    "s3": "LastModified",
    "adl": "LastModified",
    "az": "last_modified",
    "gcs": "updated",
    "file": "mtime",
    "memory": "created",
}
# bind the field name as a default argument so each lambda captures its own key
MTIME_DISPATCH = {
    protocol: (lambda f, _field=field_name: ensure_pendulum_datetime(f[_field]))
    for protocol, field_name in _MTIME_FIELD_BY_PROTOCOL.items()
}
# Support aliases
MTIME_DISPATCH["gs"] = MTIME_DISPATCH["gcs"]
MTIME_DISPATCH["s3a"] = MTIME_DISPATCH["s3"]
MTIME_DISPATCH["abfs"] = MTIME_DISPATCH["az"]


def filesystem(protocol: str, credentials: FileSystemCredentials = None) -> Tuple[AbstractFileSystem, str]:
    """Instantiates an authenticated fsspec `FileSystem` for a given `protocol` and credentials.
    Please supply credentials instance corresponding to the protocol
    """
    # NOTE(review): `protocol` is passed as FilesystemConfiguration.bucket_url; for a bare
    # protocol name (e.g. "s3") urlparse() finds no scheme, so config.protocol falls back
    # to "file" and protocol-specific credentials would never be applied in
    # filesystem_from_config. Confirm callers pass a full url (e.g. "s3://bucket") or that
    # this helper is only used for local/"file" filesystems.
    return filesystem_from_config(FilesystemConfiguration(protocol, credentials))



def filesystem_from_config(config: FilesystemConfiguration) -> Tuple[AbstractFileSystem, str]:
    """Instantiates an authenticated fsspec `FileSystem` from `config` argument.
    Authenticates following filesystems:
    * s3
    * az, abfs
    * gcs, gs
    All other filesystems are not authenticated
    Returns: (fsspec filesystem, normalized url)
    """
    protocol = config.protocol
    # kwargs forwarded to the fsspec filesystem constructor
    auth_kwargs: DictStrAny = {}
    if protocol == "s3":
        s3_credentials = cast(AwsCredentials, config.credentials)
        auth_kwargs.update(s3_credentials.to_s3fs_credentials())
    elif protocol in ["az", "abfs", "adl", "azure"]:
        azure_credentials = cast(AzureCredentials, config.credentials)
        auth_kwargs.update(azure_credentials.to_adlfs_credentials())
    elif protocol in ['gcs', 'gs']:
        assert isinstance(config.credentials, GcpCredentials)
        # Default credentials are handled by gcsfs: a None token tells it to resolve them
        has_defaults = isinstance(config.credentials, CredentialsWithDefault) and config.credentials.has_default_credentials()
        auth_kwargs['token'] = None if has_defaults else dict(config.credentials)
        auth_kwargs['project'] = config.credentials.project_id
    try:
        return url_to_fs(config.bucket_url, use_listings_cache=False, **auth_kwargs)  # type: ignore[no-any-return]
    except ModuleNotFoundError as mod_err:
        # the protocol-specific fsspec backend is an optional dlt extra
        raise MissingDependencyException("filesystem", [f"{version.DLT_PKG_NAME}[{protocol}]"]) from mod_err
20 changes: 4 additions & 16 deletions dlt/common/storages/transactional_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from pathlib import Path
import posixpath
from contextlib import contextmanager
from dlt.common.pendulum import pendulum, timedelta
from threading import Timer

import fsspec

from dlt.common.pendulum import pendulum, timedelta
from dlt.common.storages.filesystem import MTIME_DISPATCH


def lock_id(k: int = 4) -> str:
"""Generate a time based random id.
Expand Down Expand Up @@ -44,19 +45,6 @@ def run(self) -> None:
class TransactionalFile:
"""A transaction handler which wraps a file path."""

# Map of protocol to mtime resolver
# we only need to support a small finite set of protocols
_mtime_dispatch = {
"s3": lambda f: pendulum.parser.parse(f["LastModified"]),
"adl": lambda f: pendulum.parser.parse(f["LastModified"]),
"gcs": lambda f: pendulum.parser.parse(f["updated"]),
"file": lambda f: pendulum.from_timestamp(f["mtime"]),
}
# Support aliases
_mtime_dispatch["gs"] = _mtime_dispatch["gcs"]
_mtime_dispatch["s3a"] = _mtime_dispatch["s3"]
_mtime_dispatch["azure"] = _mtime_dispatch["adl"]

POLLING_INTERVAL = 0.5
LOCK_TTL_SECONDS = 30.0

Expand All @@ -68,7 +56,7 @@ def __init__(self, path: str, fs: fsspec.AbstractFileSystem) -> None:
fs: The fsspec file system.
"""
proto = fs.protocol[0] if isinstance(fs.protocol, (list, tuple)) else fs.protocol
self.extract_mtime = self._mtime_dispatch.get(proto, self._mtime_dispatch["file"])
self.extract_mtime = MTIME_DISPATCH.get(proto, MTIME_DISPATCH["file"])

parsed_path = Path(path)
if not parsed_path.is_absolute():
Expand Down
10 changes: 5 additions & 5 deletions dlt/destinations/filesystem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import JobClientBase, DestinationClientDwhWithStagingConfiguration

from dlt.destinations.filesystem.configuration import FilesystemClientConfiguration
from dlt.destinations.filesystem.configuration import FilesystemDestinationClientConfiguration


@with_config(spec=FilesystemClientConfiguration, sections=(known_sections.DESTINATION, "filesystem",))
def _configure(config: FilesystemClientConfiguration = config.value) -> FilesystemClientConfiguration:
@with_config(spec=FilesystemDestinationClientConfiguration, sections=(known_sections.DESTINATION, "filesystem",))
def _configure(config: FilesystemDestinationClientConfiguration = config.value) -> FilesystemDestinationClientConfiguration:
    """Resolves and returns the filesystem destination configuration from the `destination.filesystem` config sections."""
    return config


Expand All @@ -25,5 +25,5 @@ def client(schema: Schema, initial_config: DestinationClientDwhWithStagingConfig
return FilesystemClient(schema, _configure(initial_config)) # type: ignore


def spec() -> Type[FilesystemClientConfiguration]:
return FilesystemClientConfiguration
def spec() -> Type[FilesystemDestinationClientConfiguration]:
    """Returns the configuration spec class of the filesystem destination."""
    return FilesystemDestinationClientConfiguration
55 changes: 5 additions & 50 deletions dlt/destinations/filesystem/configuration.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,26 @@
from urllib.parse import urlparse

from typing import Final, Type, Optional, Union, TYPE_CHECKING
from typing import Final, Type, Optional, Any, TYPE_CHECKING

from dlt.common.configuration import configspec, resolve_type
from dlt.common.destination.reference import CredentialsConfiguration, DestinationClientStagingConfiguration
from dlt.common.configuration.specs import GcpServiceAccountCredentials, AwsCredentials, GcpOAuthCredentials, AzureCredentials, AzureCredentialsWithoutDefaults
from dlt.common.utils import digest128
from dlt.common.configuration.exceptions import ConfigurationValueError


PROTOCOL_CREDENTIALS = {
"gs": Union[GcpServiceAccountCredentials, GcpOAuthCredentials],
"gcs": Union[GcpServiceAccountCredentials, GcpOAuthCredentials],
"gdrive": GcpOAuthCredentials,
"s3": AwsCredentials,
"az": Union[AzureCredentialsWithoutDefaults, AzureCredentials],
"abfs": Union[AzureCredentialsWithoutDefaults, AzureCredentials],
}
from dlt.common.storages import FilesystemConfiguration


@configspec
class FilesystemClientConfiguration(DestinationClientStagingConfiguration):
class FilesystemDestinationClientConfiguration(FilesystemConfiguration, DestinationClientStagingConfiguration): # type: ignore[misc]
destination_name: Final[str] = "filesystem" # type: ignore
# should be an union of all possible credentials as found in PROTOCOL_CREDENTIALS
credentials: Union[AwsCredentials, GcpServiceAccountCredentials, AzureCredentials, GcpOAuthCredentials]

@property
def protocol(self) -> str:
url = urlparse(self.bucket_url)
return url.scheme or "file"

def on_resolved(self) -> None:
url = urlparse(self.bucket_url)
if not url.path and not url.netloc:
raise ConfigurationValueError("File path or netloc missing. Field bucket_url of FilesystemClientConfiguration must contain valid url with a path or host:password component.")
# this is just a path in local file system
if url.path == self.bucket_url:
url = url._replace(scheme="file")
self.bucket_url = url.geturl()

@resolve_type('credentials')
def resolve_credentials_type(self) -> Type[CredentialsConfiguration]:
# use known credentials or empty credentials for unknown protocol
return PROTOCOL_CREDENTIALS.get(self.protocol) or Optional[CredentialsConfiguration] # type: ignore[return-value]

def fingerprint(self) -> str:
"""Returns a fingerprint of bucket_url"""
if self.bucket_url:
return digest128(self.bucket_url)
return ""

def __str__(self) -> str:
"""Return displayable destination location"""
url = urlparse(self.bucket_url)
# do not show passwords
if url.password:
new_netloc = f"{url.username}:****@{url.hostname}"
if url.port:
new_netloc += f":{url.port}"
return url._replace(netloc=new_netloc).geturl()
return self.bucket_url
return self.PROTOCOL_CREDENTIALS.get(self.protocol) or Optional[CredentialsConfiguration] # type: ignore[return-value]

if TYPE_CHECKING:
def __init__(
self,
destination_name: str = None,
credentials: Optional[GcpServiceAccountCredentials] = None,
credentials: Optional[Any] = None,
dataset_name: str = None,
default_schema_name: Optional[str] = None,
bucket_url: str = None,
Expand Down
Loading

0 comments on commit 474a420

Please sign in to comment.