Skip to content

Commit

Permalink
Mockup destination factory
Browse files Browse the repository at this point in the history
  • Loading branch information
steinitzu committed Nov 9, 2023
1 parent 03e3995 commit c5385fb
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 5 deletions.
11 changes: 7 additions & 4 deletions dlt/common/configuration/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def with_config(
sections: Tuple[str, ...] = (),
sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
auto_pipeline_section: bool = False,
include_defaults: bool = True
include_defaults: bool = True,
accept_partial: bool = False,
) -> TFun:
...

Expand All @@ -45,7 +46,8 @@ def with_config(
sections: Tuple[str, ...] = (),
sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
auto_pipeline_section: bool = False,
include_defaults: bool = True
include_defaults: bool = True,
accept_partial: bool = False,
) -> Callable[[TFun], TFun]:
...

Expand All @@ -57,7 +59,8 @@ def with_config(
sections: Tuple[str, ...] = (),
sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
auto_pipeline_section: bool = False,
include_defaults: bool = True
include_defaults: bool = True,
accept_partial: bool = False,
) -> Callable[[TFun], TFun]:
"""Injects values into decorated function arguments following the specification in `spec` or by deriving one from function's signature.
Expand Down Expand Up @@ -139,7 +142,7 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
with _RESOLVE_LOCK:
with inject_section(section_context):
# print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}")
config = resolve_configuration(config or SPEC(), explicit_value=bound_args.arguments)
config = resolve_configuration(config or SPEC(), explicit_value=bound_args.arguments, accept_partial=accept_partial)
resolved_params = dict(config)
# overwrite or add resolved params
for p in sig.parameters.values():
Expand Down
37 changes: 36 additions & 1 deletion dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def should_truncate_table_before_load_on_staging_destination(self, table: TTable
# the default is to truncate the tables on the staging destination...
return True

TDestinationReferenceArg = Union["DestinationReference", ModuleType, None, str]
TDestinationReferenceArg = Union["DestinationReference", ModuleType, None, str, "DestinationFactory"]


class DestinationReference(Protocol):
Expand Down Expand Up @@ -397,6 +397,41 @@ def from_name(destination: TDestinationReferenceArg) -> "DestinationReference":

@staticmethod
def to_name(destination: TDestinationReferenceArg) -> str:
if isinstance(destination, DestinationFactory):
return destination.__name__
if isinstance(destination, ModuleType):
return get_module_name(destination)
return destination.split(".")[-1] # type: ignore


class DestinationFactory(ABC):
"""A destination factory that can be partially pre-configured
with credentials and other config params.
"""
credentials: Optional[CredentialsConfiguration] = None
config_params: Optional[Dict[str, Any]] = None

@property
@abstractmethod
def destination(self) -> DestinationReference:
"""Returns the destination module"""
...

@property
def __name__(self) -> str:
return self.destination.__name__

def client(self, schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> "JobClientBase":
# TODO: Raise error somewhere if both DestinationFactory and credentials argument are used together in pipeline
cfg = initial_config.copy()
for key, value in self.config_params.items():
setattr(cfg, key, value)
if self.credentials:
cfg.credentials = self.credentials
return self.destination.client(schema, cfg)

def capabilities(self) -> DestinationCapabilitiesContext:
return self.destination.capabilities()

def spec(self) -> Type[DestinationClientConfiguration]:
return self.destination.spec()
10 changes: 10 additions & 0 deletions dlt/destinations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dlt.destinations.impl.postgres.factory import postgres
from dlt.destinations.impl.snowflake.factory import snowflake
from dlt.destinations.impl.filesystem.factory import filesystem


__all__ = [
"postgres",
"snowflake",
"filesystem",
]
26 changes: 26 additions & 0 deletions dlt/destinations/impl/filesystem/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import typing as t

from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration
from dlt.destinations.impl import filesystem as _filesystem
from dlt.common.configuration import with_config, known_sections
from dlt.common.destination.reference import DestinationClientConfiguration, DestinationFactory
from dlt.common.storages.configuration import FileSystemCredentials


class filesystem(DestinationFactory):

destination = _filesystem

@with_config(spec=FilesystemDestinationClientConfiguration, sections=(known_sections.DESTINATION, 'filesystem'), accept_partial=True)
def __init__(
self,
bucket_url: str = None,
credentials: FileSystemCredentials = None,
**kwargs: t.Any,
) -> None:
cfg: FilesystemDestinationClientConfiguration = kwargs['_dlt_config']
self.credentials = cfg.credentials
self.config_params = {
"credentials": cfg.credentials,
"bucket_url": cfg.bucket_url,
}
25 changes: 25 additions & 0 deletions dlt/destinations/impl/postgres/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import typing as t

from dlt.common.configuration import with_config, known_sections
from dlt.common.destination.reference import DestinationClientConfiguration, DestinationFactory

from dlt.destinations.impl.postgres.configuration import PostgresCredentials, PostgresClientConfiguration
from dlt.destinations.impl import postgres as _postgres


class postgres(DestinationFactory):

destination = _postgres

@with_config(spec=PostgresClientConfiguration, sections=(known_sections.DESTINATION, 'postgres'), accept_partial=True)
def __init__(
self,
credentials: PostgresCredentials = None,
create_indexes: bool = True,
**kwargs: t.Any,
) -> None:
cfg: PostgresClientConfiguration = kwargs['_dlt_config']
self.credentials = cfg.credentials
self.config_params = {
"created_indexes": cfg.create_indexes,
}
26 changes: 26 additions & 0 deletions dlt/destinations/impl/snowflake/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import typing as t

from dlt.destinations.impl.snowflake.configuration import SnowflakeCredentials, SnowflakeClientConfiguration
from dlt.destinations.impl import snowflake as _snowflake
from dlt.common.configuration import with_config, known_sections
from dlt.common.destination.reference import DestinationClientConfiguration, DestinationFactory


class snowflake(DestinationFactory):

destination = _snowflake

@with_config(spec=SnowflakeClientConfiguration, sections=(known_sections.DESTINATION, 'snowflake'), accept_partial=True)
def __init__(
self,
credentials: SnowflakeCredentials = None,
stage_name: t.Optional[str] = None,
keep_staged_files: bool = True,
**kwargs: t.Any,
) -> None:
cfg: SnowflakeClientConfiguration = kwargs['_dlt_config']
self.credentials = cfg.credentials
self.config_params = {
"stage_name": cfg.stage_name,
"keep_staged_files": cfg.keep_staged_files,
}

0 comments on commit c5385fb

Please sign in to comment.