From c5385fbf3be88df4d991765e3756b0cd4e8bf17a Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 8 Nov 2023 21:59:30 -0500 Subject: [PATCH] Mockup destination factory --- dlt/common/configuration/inject.py | 11 +++--- dlt/common/destination/reference.py | 37 ++++++++++++++++++++- dlt/destinations/__init__.py | 10 ++++++ dlt/destinations/impl/filesystem/factory.py | 26 +++++++++++++++ dlt/destinations/impl/postgres/factory.py | 25 ++++++++++++++ dlt/destinations/impl/snowflake/factory.py | 26 +++++++++++++++ 6 files changed, 130 insertions(+), 5 deletions(-) create mode 100644 dlt/destinations/impl/filesystem/factory.py create mode 100644 dlt/destinations/impl/postgres/factory.py create mode 100644 dlt/destinations/impl/snowflake/factory.py diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index 1880727a0f..4e214695f2 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -32,7 +32,8 @@ def with_config( sections: Tuple[str, ...] = (), sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming, auto_pipeline_section: bool = False, - include_defaults: bool = True + include_defaults: bool = True, + accept_partial: bool = False, ) -> TFun: ... @@ -45,7 +46,8 @@ def with_config( sections: Tuple[str, ...] = (), sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming, auto_pipeline_section: bool = False, - include_defaults: bool = True + include_defaults: bool = True, + accept_partial: bool = False, ) -> Callable[[TFun], TFun]: ... @@ -57,7 +59,8 @@ def with_config( sections: Tuple[str, ...] = (), sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming, auto_pipeline_section: bool = False, - include_defaults: bool = True + include_defaults: bool = True, + accept_partial: bool = False, ) -> Callable[[TFun], TFun]: """Injects values into decorated function arguments following the specification in `spec` or by deriving one from function's signature. @@ -139,7 +142,7 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: with _RESOLVE_LOCK: with inject_section(section_context): # print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}") - config = resolve_configuration(config or SPEC(), explicit_value=bound_args.arguments) + config = resolve_configuration(config or SPEC(), explicit_value=bound_args.arguments, accept_partial=accept_partial) resolved_params = dict(config) # overwrite or add resolved params for p in sig.parameters.values(): diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index ded654e965..cb6e02c8db 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -344,7 +344,7 @@ def should_truncate_table_before_load_on_staging_destination(self, table: TTable # the default is to truncate the tables on the staging destination... return True -TDestinationReferenceArg = Union["DestinationReference", ModuleType, None, str] +TDestinationReferenceArg = Union["DestinationReference", ModuleType, None, str, "DestinationFactory"] class DestinationReference(Protocol): @@ -397,6 +397,41 @@ def from_name(destination: TDestinationReferenceArg) -> "DestinationReference": @staticmethod def to_name(destination: TDestinationReferenceArg) -> str: + if isinstance(destination, DestinationFactory): + return destination.__name__ if isinstance(destination, ModuleType): return get_module_name(destination) return destination.split(".")[-1] # type: ignore + + +class DestinationFactory(ABC): + """A destination factory that can be partially pre-configured + with credentials and other config params. + """ + credentials: Optional[CredentialsConfiguration] = None + config_params: Optional[Dict[str, Any]] = None + + @property + @abstractmethod + def destination(self) -> DestinationReference: + """Returns the destination module""" + ... + + @property + def __name__(self) -> str: + return self.destination.__name__ + + def client(self, schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> "JobClientBase": + # TODO: Raise error somewhere if both DestinationFactory and credentials argument are used together in pipeline + cfg = initial_config.copy() + for key, value in self.config_params.items(): + setattr(cfg, key, value) + if self.credentials: + cfg.credentials = self.credentials + return self.destination.client(schema, cfg) + + def capabilities(self) -> DestinationCapabilitiesContext: + return self.destination.capabilities() + + def spec(self) -> Type[DestinationClientConfiguration]: + return self.destination.spec() diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index e69de29bb2..cd8d0dc265 100644 --- a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -0,0 +1,10 @@ +from dlt.destinations.impl.postgres.factory import postgres +from dlt.destinations.impl.snowflake.factory import snowflake +from dlt.destinations.impl.filesystem.factory import filesystem + + +__all__ = [ + "postgres", + "snowflake", + "filesystem", +] diff --git a/dlt/destinations/impl/filesystem/factory.py b/dlt/destinations/impl/filesystem/factory.py new file mode 100644 index 0000000000..2e49c8a6f1 --- /dev/null +++ b/dlt/destinations/impl/filesystem/factory.py @@ -0,0 +1,26 @@ +import typing as t + +from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration +from dlt.destinations.impl import filesystem as _filesystem +from dlt.common.configuration import with_config, known_sections +from dlt.common.destination.reference import DestinationClientConfiguration, DestinationFactory +from dlt.common.storages.configuration import FileSystemCredentials + + +class filesystem(DestinationFactory): + + destination = _filesystem + + @with_config(spec=FilesystemDestinationClientConfiguration, sections=(known_sections.DESTINATION, 'filesystem'), accept_partial=True) + def __init__( + self, + bucket_url: str = None, + credentials: FileSystemCredentials = None, + **kwargs: t.Any, + ) -> None: + cfg: FilesystemDestinationClientConfiguration = kwargs['_dlt_config'] + self.credentials = cfg.credentials + self.config_params = { + "credentials": cfg.credentials, + "bucket_url": cfg.bucket_url, + } diff --git a/dlt/destinations/impl/postgres/factory.py b/dlt/destinations/impl/postgres/factory.py new file mode 100644 index 0000000000..eb686a1216 --- /dev/null +++ b/dlt/destinations/impl/postgres/factory.py @@ -0,0 +1,25 @@ +import typing as t + +from dlt.common.configuration import with_config, known_sections +from dlt.common.destination.reference import DestinationClientConfiguration, DestinationFactory + +from dlt.destinations.impl.postgres.configuration import PostgresCredentials, PostgresClientConfiguration +from dlt.destinations.impl import postgres as _postgres + + +class postgres(DestinationFactory): + + destination = _postgres + + @with_config(spec=PostgresClientConfiguration, sections=(known_sections.DESTINATION, 'postgres'), accept_partial=True) + def __init__( + self, + credentials: PostgresCredentials = None, + create_indexes: bool = True, + **kwargs: t.Any, + ) -> None: + cfg: PostgresClientConfiguration = kwargs['_dlt_config'] + self.credentials = cfg.credentials + self.config_params = { + "created_indexes": cfg.create_indexes, + } diff --git a/dlt/destinations/impl/snowflake/factory.py b/dlt/destinations/impl/snowflake/factory.py new file mode 100644 index 0000000000..c1bc915704 --- /dev/null +++ b/dlt/destinations/impl/snowflake/factory.py @@ -0,0 +1,26 @@ +import typing as t + +from dlt.destinations.impl.snowflake.configuration import SnowflakeCredentials, SnowflakeClientConfiguration +from dlt.destinations.impl import snowflake as _snowflake +from dlt.common.configuration import with_config, known_sections +from dlt.common.destination.reference import DestinationClientConfiguration, DestinationFactory + + +class snowflake(DestinationFactory): + + destination = _snowflake + + @with_config(spec=SnowflakeClientConfiguration, sections=(known_sections.DESTINATION, 'snowflake'), accept_partial=True) + def __init__( + self, + credentials: SnowflakeCredentials = None, + stage_name: t.Optional[str] = None, + keep_staged_files: bool = True, + **kwargs: t.Any, + ) -> None: + cfg: SnowflakeClientConfiguration = kwargs['_dlt_config'] + self.credentials = cfg.credentials + self.config_params = { + "stage_name": cfg.stage_name, + "keep_staged_files": cfg.keep_staged_files, + }