Skip to content

Commit

Permalink
Defer duckdb credentials resolving in pipeline context
Browse files Browse the repository at this point in the history
  • Loading branch information
steinitzu committed Nov 13, 2023
1 parent d9cd06d commit ed55b0d
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 45 deletions.
20 changes: 9 additions & 11 deletions dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,14 +441,14 @@ def client_class(self) -> Type[JobClientBase]:

@staticmethod
def to_name(ref: TDestinationReferenceArg) -> str:
if not ref:
if ref is None:
raise InvalidDestinationReference(ref)
if isinstance(ref, str):
return ref.rsplit(".", 1)[-1]
return ref.name

@staticmethod
def from_reference(ref: TDestinationReferenceArg, *args, **kwargs) -> Optional["Destination"]:
def from_reference(ref: TDestinationReferenceArg, credentials: Optional[CredentialsConfiguration] = None, **kwargs: Any) -> Optional["Destination"]:
"""Instantiate destination from str reference.
The ref can be a destination name or import path pointing to a destination class (e.g. `dlt.destinations.postgres`)
"""
Expand All @@ -463,24 +463,22 @@ def from_reference(ref: TDestinationReferenceArg, *args, **kwargs) -> Optional["
module_path, attr_name = ref.rsplit(".", 1)
dest_module = import_module(module_path)
else:
from dlt import destinations as dest_module
from dlt import destinations as dest_module
attr_name = ref
except ImportError as e:
except ModuleNotFoundError as e:
raise UnknownDestinationModule(ref) from e

try:
factory: Type[Destination] = getattr(dest_module, attr_name)
except AttributeError as e:
raise UnknownDestinationModule(ref) from e
return factory(*args, **kwargs)

raise InvalidDestinationReference(ref) from e
if credentials:
kwargs["credentials"] = credentials
return factory(**kwargs)

def client(self, schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> "JobClientBase":
# TODO: Raise error somewhere if both DestinationFactory and credentials argument are used together in pipeline
cfg = initial_config.copy()
cfg = initial_config
cfg.update(self.config_params)
# for key, value in self.config_params.items():
# setattr(cfg, key, value)
if self.credentials:
cfg.credentials = self.credentials
return self.client_class(schema, cfg)
30 changes: 16 additions & 14 deletions dlt/destinations/impl/duckdb/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class DuckDbBaseCredentials(ConnectionStringCredentials):
read_only: bool = False # open database read/write

def borrow_conn(self, read_only: bool) -> Any:
# TODO: Can this be done in sql client instead?
import duckdb

if not hasattr(self, "_conn_lock"):
Expand Down Expand Up @@ -95,28 +96,25 @@ class DuckDbCredentials(DuckDbBaseCredentials):

__config_gen_annotations__: ClassVar[List[str]] = []

def on_resolved(self) -> None:
# do not set any paths for external database
if self.database == ":external:":
return
# try the pipeline context
def _database_path(self) -> str:
is_default_path = False
if self.database == ":pipeline:":
self.database = self._path_in_pipeline(DEFAULT_DUCK_DB_NAME)
return self._path_in_pipeline(DEFAULT_DUCK_DB_NAME)
else:
# maybe get database
maybe_database, maybe_is_default_path = self._path_from_pipeline(DEFAULT_DUCK_DB_NAME)
# if pipeline context was not present or database was not set
if not self.database or not maybe_is_default_path:
# create database locally
is_default_path = maybe_is_default_path
self.database = maybe_database
path = maybe_database
else:
path = self.database

# always make database an abs path
self.database = os.path.abspath(self.database)
# do not save the default path into pipeline's local state
path = os.path.abspath(path)
if not is_default_path:
self._path_to_pipeline(self.database)
self._path_to_pipeline(path)
return path

def _path_in_pipeline(self, rel_path: str) -> str:
from dlt.common.configuration.container import Container
Expand All @@ -125,9 +123,10 @@ def _path_in_pipeline(self, rel_path: str) -> str:
context = Container()[PipelineContext]
if context.is_active():
# pipeline is active, get the working directory
return os.path.join(context.pipeline().working_dir, rel_path)
return None

abs_path = os.path.abspath(os.path.join(context.pipeline().working_dir, rel_path))
context.pipeline().set_local_state_val(LOCAL_STATE_KEY, abs_path)
return abs_path
raise RuntimeError("Attempting to use special duckdb database :pipeline: outside of pipeline context.")

def _path_to_pipeline(self, abspath: str) -> None:
from dlt.common.configuration.container import Container
Expand Down Expand Up @@ -173,6 +172,9 @@ def _path_from_pipeline(self, default_path: str) -> Tuple[str, bool]:

return default_path, True

def _conn_str(self) -> str:
return self._database_path()


@configspec
class DuckDbClientConfiguration(DestinationClientDwhWithStagingConfiguration):
Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/impl/dummy/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class dummy(Destination):
spec = DummyClientConfiguration

@property
def capabilitites(self) -> DestinationCapabilitiesContext:
def capabilities(self) -> DestinationCapabilitiesContext:
return capabilities()

@property
Expand Down
3 changes: 3 additions & 0 deletions dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,9 @@ def _get_destination_client_initial_config(self, destination: Destination = None
client_spec.get_resolvable_fields()["credentials"],
credentials
)

if credentials and not as_staging:
# Explicit pipeline credentials always supersede other credentials
destination.credentials = credentials

# this client supports many schemas and datasets
Expand Down
8 changes: 4 additions & 4 deletions tests/common/test_destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def test_import_unknown_destination() -> None:
# standard destination
with pytest.raises(UnknownDestinationModule):
with pytest.raises(InvalidDestinationReference):
Destination.from_reference("meltdb")
# custom module
with pytest.raises(UnknownDestinationModule):
Expand All @@ -25,9 +25,9 @@ def test_invalid_destination_reference() -> None:

def test_import_all_destinations() -> None:
# this must pass without the client dependencies being imported
for module in ACTIVE_DESTINATIONS:
dest = Destination.from_reference(module)
assert dest.name == "dlt.destinations." + module
for dest_name in ACTIVE_DESTINATIONS:
dest = Destination.from_reference(dest_name)
assert dest.name == dest_name
dest.spec()
assert isinstance(dest.capabilities, DestinationCapabilitiesContext)

Expand Down
23 changes: 12 additions & 11 deletions tests/load/duckdb/test_duckdb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dlt.common.configuration.utils import get_resolved_traces

from dlt.destinations.impl.duckdb.configuration import DUCK_DB_NAME, DuckDbClientConfiguration, DuckDbCredentials, DEFAULT_DUCK_DB_NAME
from dlt.destinations import duckdb

from tests.load.pipeline.utils import drop_pipeline, assert_table
from tests.utils import patch_home_dir, autouse_test_storage, preserve_environ, TEST_STORAGE_ROOT
Expand Down Expand Up @@ -46,13 +47,13 @@ def test_duckdb_open_conn_default() -> None:
def test_duckdb_database_path() -> None:
# resolve without any path provided
c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset"))
assert c.credentials.database.lower() == os.path.abspath("quack.duckdb").lower()
assert c.credentials._database_path().lower() == os.path.abspath("quack.duckdb").lower()
# resolve without any path but with pipeline context
p = dlt.pipeline(pipeline_name="quack_pipeline")
c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset"))
# still cwd
db_path = os.path.abspath(os.path.join(".", "quack_pipeline.duckdb"))
assert c.credentials.database.lower() == db_path.lower()
assert c.credentials._database_path().lower() == db_path.lower()
# we do not keep default duckdb path in the local state
with pytest.raises(KeyError):
p.get_local_state_val("duckdb_database")
Expand All @@ -69,7 +70,7 @@ def test_duckdb_database_path() -> None:
# test special :pipeline: path to create in pipeline folder
c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=":pipeline:"))
db_path = os.path.abspath(os.path.join(p.working_dir, DEFAULT_DUCK_DB_NAME))
assert c.credentials.database.lower() == db_path.lower()
assert c.credentials._database_path().lower() == db_path.lower()
# connect
conn = c.credentials.borrow_conn(read_only=False)
c.credentials.return_conn(conn)
Expand All @@ -80,7 +81,7 @@ def test_duckdb_database_path() -> None:
# provide relative path
db_path = "_storage/test_quack.duckdb"
c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials="duckdb:///_storage/test_quack.duckdb"))
assert c.credentials.database.lower() == os.path.abspath(db_path).lower()
assert c.credentials._database_path().lower() == os.path.abspath(db_path).lower()
conn = c.credentials.borrow_conn(read_only=False)
c.credentials.return_conn(conn)
assert os.path.isfile(db_path)
Expand All @@ -90,7 +91,7 @@ def test_duckdb_database_path() -> None:
db_path = os.path.abspath("_storage/abs_test_quack.duckdb")
c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=f"duckdb:///{db_path}"))
assert os.path.isabs(c.credentials.database)
assert c.credentials.database.lower() == db_path.lower()
assert c.credentials._database_path().lower() == db_path.lower()
conn = c.credentials.borrow_conn(read_only=False)
c.credentials.return_conn(conn)
assert os.path.isfile(db_path)
Expand All @@ -99,7 +100,7 @@ def test_duckdb_database_path() -> None:
# set just path as credentials
db_path = "_storage/path_test_quack.duckdb"
c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path))
assert c.credentials.database.lower() == os.path.abspath(db_path).lower()
assert c.credentials._database_path().lower() == os.path.abspath(db_path).lower()
conn = c.credentials.borrow_conn(read_only=False)
c.credentials.return_conn(conn)
assert os.path.isfile(db_path)
Expand All @@ -108,7 +109,7 @@ def test_duckdb_database_path() -> None:
db_path = os.path.abspath("_storage/abs_path_test_quack.duckdb")
c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path))
assert os.path.isabs(c.credentials.database)
assert c.credentials.database.lower() == db_path.lower()
assert c.credentials._database_path().lower() == db_path.lower()
conn = c.credentials.borrow_conn(read_only=False)
c.credentials.return_conn(conn)
assert os.path.isfile(db_path)
Expand All @@ -128,7 +129,7 @@ def test_keeps_initial_db_path() -> None:
print(p.pipelines_dir)
with p.sql_client() as conn:
# still cwd
assert conn.credentials.database.lower() == os.path.abspath(db_path).lower()
assert conn.credentials._database_path().lower() == os.path.abspath(db_path).lower()
# but it is kept in the local state
assert p.get_local_state_val("duckdb_database").lower() == os.path.abspath(db_path).lower()

Expand All @@ -138,7 +139,7 @@ def test_keeps_initial_db_path() -> None:
with p.sql_client() as conn:
# still cwd
assert p.get_local_state_val("duckdb_database").lower() == os.path.abspath(db_path).lower()
assert conn.credentials.database.lower() == os.path.abspath(db_path).lower()
assert conn.credentials._database_path().lower() == os.path.abspath(db_path).lower()

# now create a new pipeline
dlt.pipeline(pipeline_name="not_quack", destination="dummy")
Expand All @@ -147,12 +148,12 @@ def test_keeps_initial_db_path() -> None:
assert p.get_local_state_val("duckdb_database").lower() == os.path.abspath(db_path).lower()
# new pipeline context took over
# TODO: restore pipeline context on each call
assert conn.credentials.database.lower() != os.path.abspath(db_path).lower()
assert conn.credentials._database_path().lower() != os.path.abspath(db_path).lower()


def test_duckdb_database_delete() -> None:
db_path = "_storage/path_test_quack.duckdb"
p = dlt.pipeline(pipeline_name="quack_pipeline", credentials=db_path, destination="duckdb")
p = dlt.pipeline(pipeline_name="quack_pipeline", destination=duckdb(credentials=DuckDbCredentials(db_path)))
p.run([1, 2, 3], table_name="table", dataset_name="dataset")
# attach the pipeline
p = dlt.attach(pipeline_name="quack_pipeline")
Expand Down
8 changes: 4 additions & 4 deletions tests/load/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dlt.common.configuration.container import Container
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.destination.reference import DestinationClientDwhConfiguration, JobClientBase, LoadJob, DestinationClientStagingConfiguration, WithStagingDataset, TDestinationReferenceArg
from dlt.common.destination import TLoaderFileFormat
from dlt.common.destination import TLoaderFileFormat, Destination
from dlt.common.data_writers import DataWriter
from dlt.common.schema import TColumnSchema, TTableSchemaColumns, Schema
from dlt.common.storages import SchemaStorage, FileStorage, SchemaStorageConfiguration
Expand Down Expand Up @@ -229,10 +229,10 @@ def yield_client(
) -> Iterator[SqlJobClientBase]:
os.environ.pop("DATASET_NAME", None)
# import destination reference by name
destination = import_module(f"dlt.destinations.impl.{destination_name}")
destination = Destination.from_reference(destination_name)
# create initial config
dest_config: DestinationClientDwhConfiguration = None
dest_config = destination.spec()()
dest_config = destination.spec() # type: ignore[assignment]
dest_config.dataset_name = dataset_name # type: ignore[misc] # TODO: Why is dataset_name final?

if default_config_values is not None:
Expand Down Expand Up @@ -261,7 +261,7 @@ def yield_client(

# look up credentials in the config section named after the destination
with Container().injectable_context(ConfigSectionContext(sections=("destination", destination_name,))):
with destination.client(schema, dest_config) as client:
with destination.client(schema, dest_config) as client: # type: ignore[assignment]
yield client

@contextlib.contextmanager
Expand Down

0 comments on commit ed55b0d

Please sign in to comment.