From d9cd06d38b931e99b0f2df539a558530972d5b8b Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Sat, 11 Nov 2023 17:02:59 -0500 Subject: [PATCH] Update factories --- dlt/common/destination/reference.py | 7 ++++- dlt/destinations/__init__.py | 4 +++ dlt/destinations/impl/duckdb/__init__.py | 24 --------------- dlt/destinations/impl/dummy/factory.py | 33 +++++++++++++++++++++ dlt/destinations/impl/mssql/__init__.py | 24 --------------- dlt/destinations/impl/mssql/factory.py | 31 +++++++++++++++++++ dlt/destinations/impl/postgres/__init__.py | 20 ------------- dlt/destinations/impl/snowflake/__init__.py | 23 -------------- tests/load/pipeline/test_pipelines.py | 6 ++-- tests/load/pipeline/utils.py | 2 +- tests/load/test_dummy_client.py | 6 ++-- tests/load/utils.py | 2 +- 12 files changed, 82 insertions(+), 100 deletions(-) create mode 100644 dlt/destinations/impl/dummy/factory.py create mode 100644 dlt/destinations/impl/mssql/factory.py diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index dd7af6e586..018a2e11c0 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -448,7 +448,12 @@ def to_name(ref: TDestinationReferenceArg) -> str: return ref.name @staticmethod - def from_reference(ref: TDestinationReferenceArg, *args, **kwargs) -> "Destination": + def from_reference(ref: TDestinationReferenceArg, *args, **kwargs) -> Optional["Destination"]: + """Instantiate destination from str reference. + The ref can be a destination name or import path pointing to a destination class (e.g. 
`dlt.destinations.postgres`) + """ + if ref is None: + return None if isinstance(ref, Destination): return ref if not isinstance(ref, str): diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index 0abba830ab..04bc43bc1a 100644 --- a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -2,6 +2,8 @@ from dlt.destinations.impl.snowflake.factory import snowflake from dlt.destinations.impl.filesystem.factory import filesystem from dlt.destinations.impl.duckdb.factory import duckdb +from dlt.destinations.impl.dummy.factory import dummy +from dlt.destinations.impl.mssql.factory import mssql __all__ = [ @@ -9,4 +11,6 @@ "snowflake", "filesystem", "duckdb", + "dummy", + "mssql", ] diff --git a/dlt/destinations/impl/duckdb/__init__.py b/dlt/destinations/impl/duckdb/__init__.py index b2a57d0788..5cbc8dea53 100644 --- a/dlt/destinations/impl/duckdb/__init__.py +++ b/dlt/destinations/impl/duckdb/__init__.py @@ -1,20 +1,7 @@ -from typing import Type - -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections -from dlt.common.configuration.accessors import config from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration - - -@with_config(spec=DuckDbClientConfiguration, sections=(known_sections.DESTINATION, "duckdb",)) -def _configure(config: DuckDbClientConfiguration = config.value) -> DuckDbClientConfiguration: - return config - def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() @@ -37,14 +24,3 @@ def capabilities() -> DestinationCapabilitiesContext: caps.supports_truncate_command = False return 
caps - - -def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: - # import client when creating instance so capabilities and config specs can be accessed without dependencies installed - from dlt.destinations.impl.duckdb.duck import DuckDbClient - - return DuckDbClient(schema, _configure(initial_config)) # type: ignore - - -def spec() -> Type[DestinationClientConfiguration]: - return DuckDbClientConfiguration diff --git a/dlt/destinations/impl/dummy/factory.py b/dlt/destinations/impl/dummy/factory.py new file mode 100644 index 0000000000..90136385e1 --- /dev/null +++ b/dlt/destinations/impl/dummy/factory.py @@ -0,0 +1,33 @@ +import typing as t + +from dlt.common.configuration import with_config, known_sections +from dlt.common.destination.reference import DestinationClientConfiguration, Destination, DestinationCapabilitiesContext + +from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration, DummyClientCredentials +from dlt.destinations.impl.dummy import capabilities + +if t.TYPE_CHECKING: + from dlt.destinations.impl.dummy.dummy import DummyClient + + +class dummy(Destination): + + spec = DummyClientConfiguration + + @property + def capabilities(self) -> DestinationCapabilitiesContext: + return capabilities() + + @property + def client_class(self) -> t.Type["DummyClient"]: + from dlt.destinations.impl.dummy.dummy import DummyClient + + return DummyClient + + @with_config(spec=DummyClientConfiguration, sections=(known_sections.DESTINATION, 'dummy'), accept_partial=True) + def __init__( + self, + credentials: DummyClientCredentials = None, + **kwargs: t.Any, + ) -> None: + super().__init__(kwargs['_dlt_config']) diff --git a/dlt/destinations/impl/mssql/__init__.py b/dlt/destinations/impl/mssql/__init__.py index 8f9f92d4eb..40e971cacf 100644 --- a/dlt/destinations/impl/mssql/__init__.py +++ b/dlt/destinations/impl/mssql/__init__.py @@ -1,21 +1,8 @@ -from typing import Type - -from
dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections -from dlt.common.configuration.accessors import config from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.wei import EVM_DECIMAL_PRECISION -from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration - - -@with_config(spec=MsSqlClientConfiguration, sections=(known_sections.DESTINATION, "mssql",)) -def _configure(config: MsSqlClientConfiguration = config.value) -> MsSqlClientConfiguration: - return config - def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() @@ -39,14 +26,3 @@ def capabilities() -> DestinationCapabilitiesContext: caps.timestamp_precision = 7 return caps - - -def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: - # import client when creating instance so capabilities and config specs can be accessed without dependencies installed - from dlt.destinations.impl.mssql.mssql import MsSqlClient - - return MsSqlClient(schema, _configure(initial_config)) # type: ignore[arg-type] - - -def spec() -> Type[DestinationClientConfiguration]: - return MsSqlClientConfiguration diff --git a/dlt/destinations/impl/mssql/factory.py b/dlt/destinations/impl/mssql/factory.py new file mode 100644 index 0000000000..ee49978fb6 --- /dev/null +++ b/dlt/destinations/impl/mssql/factory.py @@ -0,0 +1,31 @@ +import typing as t + +from dlt.common.configuration import with_config, known_sections +from dlt.common.destination.reference import DestinationClientConfiguration, Destination + +from dlt.destinations.impl.mssql.configuration import MsSqlCredentials, 
MsSqlClientConfiguration +from dlt.destinations.impl.mssql import capabilities + +if t.TYPE_CHECKING: + from dlt.destinations.impl.mssql.mssql import MsSqlClient + + +class mssql(Destination): + + capabilities = capabilities() + spec = MsSqlClientConfiguration + + @property + def client_class(self) -> t.Type["MsSqlClient"]: + from dlt.destinations.impl.mssql.mssql import MsSqlClient + + return MsSqlClient + + @with_config(spec=MsSqlClientConfiguration, sections=(known_sections.DESTINATION, 'mssql'), accept_partial=True) + def __init__( + self, + credentials: MsSqlCredentials = None, + create_indexes: bool = True, + **kwargs: t.Any, + ) -> None: + super().__init__(kwargs['_dlt_config']) diff --git a/dlt/destinations/impl/postgres/__init__.py b/dlt/destinations/impl/postgres/__init__.py index 54bc3297b1..009174ecc9 100644 --- a/dlt/destinations/impl/postgres/__init__.py +++ b/dlt/destinations/impl/postgres/__init__.py @@ -1,20 +1,9 @@ -from typing import Type - -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections -from dlt.common.configuration.accessors import config from dlt.common.data_writers.escape import escape_postgres_identifier, escape_postgres_literal from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.wei import EVM_DECIMAL_PRECISION -from dlt.destinations.impl.postgres.configuration import PostgresClientConfiguration - - -@with_config(spec=PostgresClientConfiguration, sections=(known_sections.DESTINATION, "postgres",)) -def _configure(config: PostgresClientConfiguration = config.value) -> PostgresClientConfiguration: - return config def capabilities() -> DestinationCapabilitiesContext: @@ -39,12 +28,3 @@ def capabilities() -> DestinationCapabilitiesContext: return caps -def client(schema: Schema, 
initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: - # import client when creating instance so capabilities and config specs can be accessed without dependencies installed - from dlt.destinations.impl.postgres.postgres import PostgresClient - - return PostgresClient(schema, _configure(initial_config)) # type: ignore - - -def spec() -> Type[DestinationClientConfiguration]: - return PostgresClientConfiguration diff --git a/dlt/destinations/impl/snowflake/__init__.py b/dlt/destinations/impl/snowflake/__init__.py index 8476ceb318..12e118eeab 100644 --- a/dlt/destinations/impl/snowflake/__init__.py +++ b/dlt/destinations/impl/snowflake/__init__.py @@ -1,20 +1,8 @@ -from typing import Type from dlt.common.data_writers.escape import escape_bigquery_identifier - -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections -from dlt.common.configuration.accessors import config from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration from dlt.common.data_writers.escape import escape_snowflake_identifier from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration - - -# @with_config(spec=SnowflakeClientConfiguration, sections=(known_sections.DESTINATION, "snowflake",)) -# def _configure(config: SnowflakeClientConfiguration = config.value) -> SnowflakeClientConfiguration: -# return config def capabilities() -> DestinationCapabilitiesContext: @@ -35,14 +23,3 @@ def capabilities() -> DestinationCapabilitiesContext: caps.supports_ddl_transactions = True caps.alter_add_multi_column = True return caps - - -# def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: -# # import client when creating instance so capabilities and config specs can be 
accessed without dependencies installed -# from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient - -# return SnowflakeClient(schema, _configure(initial_config)) # type: ignore - - -# def spec() -> Type[DestinationClientConfiguration]: -# return SnowflakeClientConfiguration diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 99071a7ac6..2fc4aad1a8 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -8,7 +8,7 @@ from dlt.common.pipeline import SupportsPipeline from dlt.common import json, sleep -from dlt.common.destination.reference import DestinationReference +from dlt.common.destination import Destination from dlt.common.schema.schema import Schema from dlt.common.schema.typing import VERSION_TABLE_NAME from dlt.common.typing import TDataItem @@ -66,8 +66,8 @@ def data_fun() -> Iterator[Any]: # mock the correct destinations (never do that in normal code) with p.managed_state(): p._set_destinations( - DestinationReference.from_name(destination_config.destination), - DestinationReference.from_name(destination_config.staging) if destination_config.staging else None + Destination.from_reference(destination_config.destination), + Destination.from_reference(destination_config.staging) if destination_config.staging else None ) # does not reset the dataset name assert p.dataset_name in possible_dataset_names diff --git a/tests/load/pipeline/utils.py b/tests/load/pipeline/utils.py index 7ed71fe27a..113585f669 100644 --- a/tests/load/pipeline/utils.py +++ b/tests/load/pipeline/utils.py @@ -67,7 +67,7 @@ def _drop_dataset(schema_name: str) -> None: def _is_filesystem(p: dlt.Pipeline) -> bool: if not p.destination: return False - return p.destination.__name__.rsplit('.', 1)[-1] == 'filesystem' + return p.destination.name == 'filesystem' def assert_table(p: dlt.Pipeline, table_name: str, table_data: List[Any], schema_name: str = None, info: LoadInfo = None) -> None: diff 
--git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index bb4b76c0b7..a959f6d960 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -11,12 +11,12 @@ from dlt.common.storages import FileStorage, LoadStorage from dlt.common.storages.load_storage import JobWithUnsupportedWriterException from dlt.common.utils import uniq_id -from dlt.common.destination.reference import DestinationReference, LoadJob +from dlt.common.destination.reference import Destination, LoadJob from dlt.load import Load from dlt.destinations.job_impl import EmptyLoadJob -from dlt.destinations.impl import dummy +from dlt.destinations import dummy from dlt.destinations.impl.dummy import dummy as dummy_impl from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration from dlt.load.exceptions import LoadClientJobFailed, LoadClientJobRetry @@ -445,7 +445,7 @@ def run_all(load: Load) -> None: def setup_loader(delete_completed_jobs: bool = False, client_config: DummyClientConfiguration = None) -> Load: # reset jobs for a test dummy_impl.JOBS = {} - destination: DestinationReference = dummy # type: ignore[assignment] + destination: Destination = dummy() client_config = client_config or DummyClientConfiguration(loader_file_format="jsonl") # patch destination to provide client_config # destination.client = lambda schema: dummy_impl.DummyClient(schema, client_config) diff --git a/tests/load/utils.py b/tests/load/utils.py index f8680b3885..4b1cfc2f1a 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -12,7 +12,7 @@ from dlt.common.configuration import resolve_configuration from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.destination.reference import DestinationClientDwhConfiguration, DestinationReference, JobClientBase, LoadJob, DestinationClientStagingConfiguration, WithStagingDataset, TDestinationReferenceArg +from 
dlt.common.destination.reference import DestinationClientDwhConfiguration, JobClientBase, LoadJob, DestinationClientStagingConfiguration, WithStagingDataset, TDestinationReferenceArg from dlt.common.destination import TLoaderFileFormat from dlt.common.data_writers import DataWriter from dlt.common.schema import TColumnSchema, TTableSchemaColumns, Schema