Skip to content

Commit

Permalink
Fix/1465 fixes snowflake auth credentials (#1489)
Browse files Browse the repository at this point in the history
* correctly handles explicit initial values, still allowing optional args to be resolved

* allows pure authenticator, allows to specify token right in the credentials

* adds base method to generate typed query params in connection string credentials, serializes to str for to_url

* drops resolve() from __init__ and parse_native_value methods

* updates snowflake docs

* runs native value parsing for side effects
  • Loading branch information
rudolfix authored Jun 20, 2024
1 parent d4b0bd0 commit 37f64a1
Show file tree
Hide file tree
Showing 15 changed files with 352 additions and 137 deletions.
17 changes: 14 additions & 3 deletions dlt/common/configuration/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,21 @@ def _maybe_parse_native_value(
not isinstance(explicit_value, C_Mapping) or isinstance(explicit_value, BaseConfiguration)
):
try:
# parse the native value anyway because there are configs with side effects
config.parse_native_representation(explicit_value)
default_value = config.__class__()
# parse native value and convert it into dict, extract the diff and use it as exact value
# NOTE: as those are the same dataclasses, the set of keys must be the same
explicit_value = {
k: v
for k, v in config.__class__.from_init_value(explicit_value).items()
if default_value[k] != v
}
except ValueError as v_err:
# provide generic exception
raise InvalidNativeValue(type(config), type(explicit_value), embedded_sections, v_err)
except NotImplementedError:
pass
# explicit value was consumed
explicit_value = None
return explicit_value


Expand Down Expand Up @@ -336,7 +343,11 @@ def _resolve_config_field(
# print(f"{embedded_config} IS RESOLVED with VALUE {value}")
# injected context will be resolved
if value is not None:
_maybe_parse_native_value(embedded_config, value, embedded_sections + (key,))
from_native_explicit = _maybe_parse_native_value(
embedded_config, value, embedded_sections + (key,)
)
if from_native_explicit is not value:
embedded_config.update(from_native_explicit)
value = embedded_config
else:
# only config with sections may look for initial values
Expand Down
57 changes: 29 additions & 28 deletions dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@
# forward class declaration
_F_BaseConfiguration: Any = type(object)
_F_ContainerInjectableContext: Any = type(object)
_T = TypeVar("_T", bound="BaseConfiguration")
_C = TypeVar("_C", bound="CredentialsConfiguration")
_B = TypeVar("_B", bound="BaseConfiguration")


class NotResolved:
Expand Down Expand Up @@ -289,6 +288,33 @@ class BaseConfiguration(MutableMapping[str, Any]):
"""Typing for dataclass fields"""
__hint_resolvers__: ClassVar[Dict[str, Callable[["BaseConfiguration"], Type[Any]]]] = {}

@classmethod
def from_init_value(cls: Type[_B], init_value: Any = None) -> _B:
"""Initializes credentials from `init_value`
Init value may be a native representation of the credentials or a dict. In case of native representation (for example a connection string or JSON with service account credentials)
a `parse_native_representation` method will be used to parse it. In case of a dict, the credentials object will be updated with key: values of the dict.
Unexpected values in the dict will be ignored.
Credentials will be marked as resolved if all required fields are set and the `resolve()` method is successful.
"""
# create an instance
self = cls()
self._apply_init_value(init_value)
if not self.is_partial():
# let it fail gracefully
with contextlib.suppress(Exception):
self.resolve()
return self

def _apply_init_value(self, init_value: Any = None) -> None:
if isinstance(init_value, C_Mapping):
self.update(init_value)
elif init_value is not None:
self.parse_native_representation(init_value)
else:
return

def parse_native_representation(self, native_value: Any) -> None:
"""Initialize the configuration fields by parsing the `native_value` which should be a native representation of the configuration
or credentials, for example database connection string or JSON serialized GCP service credentials file.
Expand Down Expand Up @@ -348,7 +374,7 @@ def resolve(self) -> None:
self.call_method_in_mro("on_resolved")
self.__is_resolved__ = True

def copy(self: _T) -> _T:
def copy(self: _B) -> _B:
"""Returns a deep copy of the configuration instance"""
return copy.deepcopy(self)

Expand Down Expand Up @@ -426,38 +452,13 @@ class CredentialsConfiguration(BaseConfiguration):

__section__: ClassVar[str] = "credentials"

@classmethod
def from_init_value(cls: Type[_C], init_value: Any = None) -> _C:
"""Initializes credentials from `init_value`
Init value may be a native representation of the credentials or a dict. In case of native representation (for example a connection string or JSON with service account credentials)
a `parse_native_representation` method will be used to parse it. In case of a dict, the credentials object will be updated with key: values of the dict.
Unexpected values in the dict will be ignored.
Credentials will be marked as resolved if all required fields are set.
"""
# create an instance
self = cls()
self._apply_init_value(init_value)
return self

def to_native_credentials(self) -> Any:
"""Returns native credentials object.
By default calls `to_native_representation` method.
"""
return self.to_native_representation()

def _apply_init_value(self, init_value: Any = None) -> None:
if isinstance(init_value, C_Mapping):
self.update(init_value)
elif init_value is not None:
self.parse_native_representation(init_value)
else:
return
if not self.is_partial():
self.resolve()

def __str__(self) -> str:
"""Get string representation of credentials to be displayed, with all secret parts removed"""
return super().__str__()
Expand Down
25 changes: 22 additions & 3 deletions dlt/common/configuration/specs/connection_string_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class ConnectionStringCredentials(CredentialsConfiguration):
username: str = None
host: Optional[str] = None
port: Optional[int] = None
query: Optional[Dict[str, str]] = None
query: Optional[Dict[str, Any]] = None

__config_gen_annotations__: ClassVar[List[str]] = ["port", "password", "host"]

Expand Down Expand Up @@ -44,16 +44,35 @@ def on_resolved(self) -> None:
def to_native_representation(self) -> str:
return self.to_url().render_as_string(hide_password=False)

def get_query(self) -> Dict[str, Any]:
"""Gets query preserving parameter types. Mostly used internally to export connection params"""
return {} if self.query is None else self.query

def to_url(self) -> URL:
"""Creates SQLAlchemy compatible URL object, computes current query via `get_query` and serializes its values to str"""
# circular dependencies here
from dlt.common.configuration.utils import serialize_value

def _serialize_value(v_: Any) -> str:
if v_ is None:
return None
return serialize_value(v_)

# query must be str -> str
query = {k: _serialize_value(v) for k, v in self.get_query().items()}
return URL.create(
self.drivername,
self.username,
self.password,
self.host,
self.port,
self.database,
self.query,
query,
)

def __str__(self) -> str:
return self.to_url().render_as_string(hide_password=True)
url = self.to_url()
# do not display query. it often contains secret values
url = url._replace(query=None)
# we only have control over netloc/path
return url.render_as_string(hide_password=True)
6 changes: 3 additions & 3 deletions dlt/common/configuration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,21 @@ def deserialize_value(key: str, value: Any, hint: Type[TAny]) -> TAny:
raise ConfigValueCannotBeCoercedException(key, value, hint) from exc


def serialize_value(value: Any) -> Any:
def serialize_value(value: Any) -> str:
if value is None:
raise ValueError(value)
# return literal for tuples
if isinstance(value, tuple):
return str(value)
if isinstance(value, BaseConfiguration):
try:
return value.to_native_representation()
return str(value.to_native_representation())
except NotImplementedError:
# no native representation: use dict
value = dict(value)
# coerce type to text which will use json for mapping and sequences
value_dt = py_type_to_sc_type(type(value))
return coerce_value("text", value_dt, value)
return coerce_value("text", value_dt, value) # type: ignore[no-any-return]


def auto_cast(value: str) -> Any:
Expand Down
28 changes: 13 additions & 15 deletions dlt/destinations/impl/clickhouse/configuration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import ClassVar, List, Any, Final, Literal, cast, Optional
from typing import ClassVar, Dict, List, Any, Final, Literal, cast, Optional

from dlt.common.configuration import configspec
from dlt.common.configuration.specs import ConnectionStringCredentials
Expand Down Expand Up @@ -59,23 +59,21 @@ def parse_native_representation(self, native_value: Any) -> None:
self.query.get("send_receive_timeout", self.send_receive_timeout)
)
self.secure = cast(TSecureConnection, int(self.query.get("secure", self.secure)))
if not self.is_partial():
self.resolve()

def to_url(self) -> URL:
url = super().to_url()
url = url.update_query_pairs(
[
("connect_timeout", str(self.connect_timeout)),
("send_receive_timeout", str(self.send_receive_timeout)),
("secure", str(1) if self.secure else str(0)),
def get_query(self) -> Dict[str, Any]:
query = dict(super().get_query())
query.update(
{
"connect_timeout": str(self.connect_timeout),
"send_receive_timeout": str(self.send_receive_timeout),
"secure": 1 if self.secure else 0,
# Toggle experimental settings. These are necessary for certain datatypes and not optional.
("allow_experimental_lightweight_delete", "1"),
# ("allow_experimental_object_type", "1"),
("enable_http_compression", "1"),
]
"allow_experimental_lightweight_delete": 1,
# "allow_experimental_object_type": 1,
"enable_http_compression": 1,
}
)
return url
return query


@configspec
Expand Down
10 changes: 4 additions & 6 deletions dlt/destinations/impl/mssql/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def parse_native_representation(self, native_value: Any) -> None:
self.query = {k.lower(): v for k, v in self.query.items()} # Make case-insensitive.
self.driver = self.query.get("driver", self.driver)
self.connect_timeout = int(self.query.get("connect_timeout", self.connect_timeout))
if not self.is_partial():
self.resolve()

def on_resolved(self) -> None:
if self.driver not in self.SUPPORTED_DRIVERS:
Expand All @@ -45,10 +43,10 @@ def on_resolved(self) -> None:
)
self.database = self.database.lower()

def to_url(self) -> URL:
url = super().to_url()
url.update_query_pairs([("connect_timeout", str(self.connect_timeout))])
return url
def get_query(self) -> Dict[str, Any]:
query = dict(super().get_query())
query["connect_timeout"] = self.connect_timeout
return query

def on_partial(self) -> None:
self.driver = self._get_driver()
Expand Down
12 changes: 5 additions & 7 deletions dlt/destinations/impl/postgres/configuration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Final, ClassVar, Any, List, TYPE_CHECKING, Union
from typing import Dict, Final, ClassVar, Any, List, TYPE_CHECKING, Union

from dlt.common.libs.sql_alchemy import URL
from dlt.common.configuration import configspec
Expand All @@ -23,13 +23,11 @@ class PostgresCredentials(ConnectionStringCredentials):
def parse_native_representation(self, native_value: Any) -> None:
super().parse_native_representation(native_value)
self.connect_timeout = int(self.query.get("connect_timeout", self.connect_timeout))
if not self.is_partial():
self.resolve()

def to_url(self) -> URL:
url = super().to_url()
url.update_query_pairs([("connect_timeout", str(self.connect_timeout))])
return url
def get_query(self) -> Dict[str, Any]:
query = dict(super().get_query())
query["connect_timeout"] = self.connect_timeout
return query


@configspec
Expand Down
Loading

0 comments on commit 37f64a1

Please sign in to comment.