From ed55b0d36d90ec63d4af54cb3c1472b5f23ca8de Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Sun, 12 Nov 2023 16:02:57 -0500 Subject: [PATCH] Defer duckdb credentials resolving in pipeline context --- dlt/common/destination/reference.py | 20 ++++++------- dlt/destinations/impl/duckdb/configuration.py | 30 ++++++++++--------- dlt/destinations/impl/dummy/factory.py | 2 +- dlt/pipeline/pipeline.py | 3 ++ tests/common/test_destination.py | 8 ++--- tests/load/duckdb/test_duckdb_client.py | 23 +++++++------- tests/load/utils.py | 8 ++--- 7 files changed, 49 insertions(+), 45 deletions(-) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 018a2e11c0..f232622c52 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -441,14 +441,14 @@ def client_class(self) -> Type[JobClientBase]: @staticmethod def to_name(ref: TDestinationReferenceArg) -> str: - if not ref: + if ref is None: raise InvalidDestinationReference(ref) if isinstance(ref, str): return ref.rsplit(".", 1)[-1] return ref.name @staticmethod - def from_reference(ref: TDestinationReferenceArg, *args, **kwargs) -> Optional["Destination"]: + def from_reference(ref: TDestinationReferenceArg, credentials: Optional[CredentialsConfiguration] = None, **kwargs: Any) -> Optional["Destination"]: """Instantiate destination from str reference. The ref can be a destination name or import path pointing to a destination class (e.g. `dlt.destinations.postgres`) """ @@ -463,24 +463,22 @@ def from_reference(ref: TDestinationReferenceArg, *args, **kwargs) -> Optional[" module_path, attr_name = ref.rsplit(".", 1) dest_module = import_module(module_path) else: - from dlt import destinations as dest_module + from dlt import destinations as dest_module attr_name = ref - except ImportError as e: + except ModuleNotFoundError as e: raise UnknownDestinationModule(ref) from e try: factory: Type[Destination] = getattr(dest_module, attr_name) except AttributeError as e: - raise UnknownDestinationModule(ref) from e - return factory(*args, **kwargs) - + raise InvalidDestinationReference(ref) from e + if credentials: + kwargs["credentials"] = credentials + return factory(**kwargs) def client(self, schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> "JobClientBase": - # TODO: Raise error somewhere if both DestinationFactory and credentials argument are used together in pipeline - cfg = initial_config.copy() + cfg = initial_config cfg.update(self.config_params) - # for key, value in self.config_params.items(): - # setattr(cfg, key, value) if self.credentials: cfg.credentials = self.credentials return self.client_class(schema, cfg) diff --git a/dlt/destinations/impl/duckdb/configuration.py b/dlt/destinations/impl/duckdb/configuration.py index 82ee325ed3..556a7c9829 100644 --- a/dlt/destinations/impl/duckdb/configuration.py +++ b/dlt/destinations/impl/duckdb/configuration.py @@ -25,6 +25,7 @@ class DuckDbBaseCredentials(ConnectionStringCredentials): read_only: bool = False # open database read/write def borrow_conn(self, read_only: bool) -> Any: + # TODO: Can this be done in sql client instead? import duckdb if not hasattr(self, "_conn_lock"): @@ -95,14 +96,10 @@ class DuckDbCredentials(DuckDbBaseCredentials): __config_gen_annotations__: ClassVar[List[str]] = [] - def on_resolved(self) -> None: - # do not set any paths for external database - if self.database == ":external:": - return - # try the pipeline context + def _database_path(self) -> str: is_default_path = False if self.database == ":pipeline:": - self.database = self._path_in_pipeline(DEFAULT_DUCK_DB_NAME) + return self._path_in_pipeline(DEFAULT_DUCK_DB_NAME) else: # maybe get database maybe_database, maybe_is_default_path = self._path_from_pipeline(DEFAULT_DUCK_DB_NAME) @@ -110,13 +107,14 @@ def on_resolved(self) -> None: if not self.database or not maybe_is_default_path: # create database locally is_default_path = maybe_is_default_path - self.database = maybe_database + path = maybe_database + else: + path = self.database - # always make database an abs path - self.database = os.path.abspath(self.database) - # do not save the default path into pipeline's local state + path = os.path.abspath(path) if not is_default_path: - self._path_to_pipeline(self.database) + self._path_to_pipeline(path) + return path def _path_in_pipeline(self, rel_path: str) -> str: from dlt.common.configuration.container import Container @@ -125,9 +123,10 @@ def _path_in_pipeline(self, rel_path: str) -> str: context = Container()[PipelineContext] if context.is_active(): # pipeline is active, get the working directory - return os.path.join(context.pipeline().working_dir, rel_path) - return None - + abs_path = os.path.abspath(os.path.join(context.pipeline().working_dir, rel_path)) + context.pipeline().set_local_state_val(LOCAL_STATE_KEY, abs_path) + return abs_path + raise RuntimeError("Attempting to use special duckdb database :pipeline: outside of pipeline context.") def _path_to_pipeline(self, abspath: str) -> None: from dlt.common.configuration.container import Container @@ -173,6 +172,9 @@ def _path_from_pipeline(self, default_path: str) -> Tuple[str, bool]: return default_path, True + def _conn_str(self) -> str: + return self._database_path() + @configspec class DuckDbClientConfiguration(DestinationClientDwhWithStagingConfiguration): diff --git a/dlt/destinations/impl/dummy/factory.py b/dlt/destinations/impl/dummy/factory.py index 90136385e1..7a79ddd0a1 100644 --- a/dlt/destinations/impl/dummy/factory.py +++ b/dlt/destinations/impl/dummy/factory.py @@ -15,7 +15,7 @@ class dummy(Destination): spec = DummyClientConfiguration @property - def capabilitites(self) -> DestinationCapabilitiesContext: + def capabilities(self) -> DestinationCapabilitiesContext: return capabilities() @property diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 0a4bc78889..25910235a2 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -900,6 +900,9 @@ def _get_destination_client_initial_config(self, destination: Destination = None client_spec.get_resolvable_fields()["credentials"], credentials ) + + if credentials and not as_staging: + # Explicit pipeline credentials always supersede other credentials destination.credentials = credentials # this client support many schemas and datasets diff --git a/tests/common/test_destination.py b/tests/common/test_destination.py index 00a8480ef4..b1c85bb91f 100644 --- a/tests/common/test_destination.py +++ b/tests/common/test_destination.py @@ -11,7 +11,7 @@ def test_import_unknown_destination() -> None: # standard destination - with pytest.raises(UnknownDestinationModule): + with pytest.raises(InvalidDestinationReference): Destination.from_reference("meltdb") # custom module with pytest.raises(UnknownDestinationModule): @@ -25,9 +25,9 @@ def test_invalid_destination_reference() -> None: def test_import_all_destinations() -> None: # this must pass without the client dependencies being imported - for module in ACTIVE_DESTINATIONS: - dest = Destination.from_reference(module) - assert dest.name == "dlt.destinations." + module + for dest_name in ACTIVE_DESTINATIONS: + dest = Destination.from_reference(dest_name) + assert dest.name == dest_name dest.spec() assert isinstance(dest.capabilities, DestinationCapabilitiesContext) diff --git a/tests/load/duckdb/test_duckdb_client.py b/tests/load/duckdb/test_duckdb_client.py index 9d3faa3881..ace46ebd5e 100644 --- a/tests/load/duckdb/test_duckdb_client.py +++ b/tests/load/duckdb/test_duckdb_client.py @@ -7,6 +7,7 @@ from dlt.common.configuration.utils import get_resolved_traces from dlt.destinations.impl.duckdb.configuration import DUCK_DB_NAME, DuckDbClientConfiguration, DuckDbCredentials, DEFAULT_DUCK_DB_NAME +from dlt.destinations import duckdb from tests.load.pipeline.utils import drop_pipeline, assert_table from tests.utils import patch_home_dir, autouse_test_storage, preserve_environ, TEST_STORAGE_ROOT @@ -46,13 +47,13 @@ def test_duckdb_open_conn_default() -> None: def test_duckdb_database_path() -> None: # resolve without any path provided c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset")) - assert c.credentials.database.lower() == os.path.abspath("quack.duckdb").lower() + assert c.credentials._database_path().lower() == os.path.abspath("quack.duckdb").lower() # resolve without any path but with pipeline context p = dlt.pipeline(pipeline_name="quack_pipeline") c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset")) # still cwd db_path = os.path.abspath(os.path.join(".", "quack_pipeline.duckdb")) - assert c.credentials.database.lower() == db_path.lower() + assert c.credentials._database_path().lower() == db_path.lower() # we do not keep default duckdb path in the local state with pytest.raises(KeyError): p.get_local_state_val("duckdb_database") @@ -69,7 +70,7 @@ def test_duckdb_database_path() -> None: # test special :pipeline: path to create in pipeline folder c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=":pipeline:")) db_path = os.path.abspath(os.path.join(p.working_dir, DEFAULT_DUCK_DB_NAME)) - assert c.credentials.database.lower() == db_path.lower() + assert c.credentials._database_path().lower() == db_path.lower() # connect conn = c.credentials.borrow_conn(read_only=False) c.credentials.return_conn(conn) @@ -80,7 +81,7 @@ def test_duckdb_database_path() -> None: # provide relative path db_path = "_storage/test_quack.duckdb" c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials="duckdb:///_storage/test_quack.duckdb")) - assert c.credentials.database.lower() == os.path.abspath(db_path).lower() + assert c.credentials._database_path().lower() == os.path.abspath(db_path).lower() conn = c.credentials.borrow_conn(read_only=False) c.credentials.return_conn(conn) assert os.path.isfile(db_path) @@ -90,7 +91,7 @@ def test_duckdb_database_path() -> None: db_path = os.path.abspath("_storage/abs_test_quack.duckdb") c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=f"duckdb:///{db_path}")) assert os.path.isabs(c.credentials.database) - assert c.credentials.database.lower() == db_path.lower() + assert c.credentials._database_path().lower() == db_path.lower() conn = c.credentials.borrow_conn(read_only=False) c.credentials.return_conn(conn) assert os.path.isfile(db_path) @@ -99,7 +100,7 @@ def test_duckdb_database_path() -> None: # set just path as credentials db_path = "_storage/path_test_quack.duckdb" c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path)) - assert c.credentials.database.lower() == os.path.abspath(db_path).lower() + assert c.credentials._database_path().lower() == os.path.abspath(db_path).lower() conn = c.credentials.borrow_conn(read_only=False) c.credentials.return_conn(conn) assert os.path.isfile(db_path) @@ -108,7 +109,7 @@ def test_duckdb_database_path() -> None: db_path = os.path.abspath("_storage/abs_path_test_quack.duckdb") c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path)) assert os.path.isabs(c.credentials.database) - assert c.credentials.database.lower() == db_path.lower() + assert c.credentials._database_path().lower() == db_path.lower() conn = c.credentials.borrow_conn(read_only=False) c.credentials.return_conn(conn) assert os.path.isfile(db_path) @@ -128,7 +129,7 @@ def test_keeps_initial_db_path() -> None: print(p.pipelines_dir) with p.sql_client() as conn: # still cwd - assert conn.credentials.database.lower() == os.path.abspath(db_path).lower() + assert conn.credentials._database_path().lower() == os.path.abspath(db_path).lower() # but it is kept in the local state assert p.get_local_state_val("duckdb_database").lower() == os.path.abspath(db_path).lower() @@ -138,7 +139,7 @@ def test_keeps_initial_db_path() -> None: with p.sql_client() as conn: # still cwd assert p.get_local_state_val("duckdb_database").lower() == os.path.abspath(db_path).lower() - assert conn.credentials.database.lower() == os.path.abspath(db_path).lower() + assert conn.credentials._database_path().lower() == os.path.abspath(db_path).lower() # now create a new pipeline dlt.pipeline(pipeline_name="not_quack", destination="dummy") @@ -147,12 +148,12 @@ def test_keeps_initial_db_path() -> None: assert p.get_local_state_val("duckdb_database").lower() == os.path.abspath(db_path).lower() # new pipeline context took over # TODO: restore pipeline context on each call - assert conn.credentials.database.lower() != os.path.abspath(db_path).lower() + assert conn.credentials._database_path().lower() != os.path.abspath(db_path).lower() def test_duckdb_database_delete() -> None: db_path = "_storage/path_test_quack.duckdb" - p = dlt.pipeline(pipeline_name="quack_pipeline", credentials=db_path, destination="duckdb") + p = dlt.pipeline(pipeline_name="quack_pipeline", destination=duckdb(credentials=DuckDbCredentials(db_path))) p.run([1, 2, 3], table_name="table", dataset_name="dataset") # attach the pipeline p = dlt.attach(pipeline_name="quack_pipeline") diff --git a/tests/load/utils.py b/tests/load/utils.py index 4b1cfc2f1a..098c5a5509 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -13,7 +13,7 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.destination.reference import DestinationClientDwhConfiguration, JobClientBase, LoadJob, DestinationClientStagingConfiguration, WithStagingDataset, TDestinationReferenceArg -from dlt.common.destination import TLoaderFileFormat +from dlt.common.destination import TLoaderFileFormat, Destination from dlt.common.data_writers import DataWriter from dlt.common.schema import TColumnSchema, TTableSchemaColumns, Schema from dlt.common.storages import SchemaStorage, FileStorage, SchemaStorageConfiguration @@ -229,10 +229,10 @@ def yield_client( ) -> Iterator[SqlJobClientBase]: os.environ.pop("DATASET_NAME", None) # import destination reference by name - destination = import_module(f"dlt.destinations.impl.{destination_name}") + destination = Destination.from_reference(destination_name) # create initial config dest_config: DestinationClientDwhConfiguration = None - dest_config = destination.spec()() + dest_config = destination.spec() # type: ignore[assignment] dest_config.dataset_name = dataset_name # type: ignore[misc] # TODO: Why is dataset_name final? if default_config_values is not None: @@ -261,7 +261,7 @@ def yield_client( # lookup for credentials in the section that is destination name with Container().injectable_context(ConfigSectionContext(sections=("destination", destination_name,))): - with destination.client(schema, dest_config) as client: + with destination.client(schema, dest_config) as client: # type: ignore[assignment] yield client @contextlib.contextmanager