Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/1465 fixes snowflake auth credentials #1489

Merged
merged 6 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions dlt/common/configuration/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,21 @@ def _maybe_parse_native_value(
not isinstance(explicit_value, C_Mapping) or isinstance(explicit_value, BaseConfiguration)
):
try:
# parse the native value anyway because there are configs with side effects
config.parse_native_representation(explicit_value)
default_value = config.__class__()
# parse native value and convert it into dict, extract the diff and use it as exact value
# NOTE: as those are the same dataclasses, the set of keys must be the same
explicit_value = {
k: v
for k, v in config.__class__.from_init_value(explicit_value).items()
if default_value[k] != v
}
except ValueError as v_err:
# provide generic exception
raise InvalidNativeValue(type(config), type(explicit_value), embedded_sections, v_err)
except NotImplementedError:
pass
# explicit value was consumed
explicit_value = None
return explicit_value


Expand Down Expand Up @@ -336,7 +343,11 @@ def _resolve_config_field(
# print(f"{embedded_config} IS RESOLVED with VALUE {value}")
# injected context will be resolved
if value is not None:
_maybe_parse_native_value(embedded_config, value, embedded_sections + (key,))
from_native_explicit = _maybe_parse_native_value(
embedded_config, value, embedded_sections + (key,)
)
if from_native_explicit is not value:
embedded_config.update(from_native_explicit)
value = embedded_config
else:
# only config with sections may look for initial values
Expand Down
57 changes: 29 additions & 28 deletions dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@
# forward class declaration
_F_BaseConfiguration: Any = type(object)
_F_ContainerInjectableContext: Any = type(object)
_T = TypeVar("_T", bound="BaseConfiguration")
_C = TypeVar("_C", bound="CredentialsConfiguration")
_B = TypeVar("_B", bound="BaseConfiguration")


class NotResolved:
Expand Down Expand Up @@ -289,6 +288,33 @@ class BaseConfiguration(MutableMapping[str, Any]):
"""Typing for dataclass fields"""
__hint_resolvers__: ClassVar[Dict[str, Callable[["BaseConfiguration"], Type[Any]]]] = {}

@classmethod
def from_init_value(cls: Type[_B], init_value: Any = None) -> _B:
"""Initializes credentials from `init_value`

Init value may be a native representation of the credentials or a dict. In case of native representation (for example a connection string or JSON with service account credentials)
a `parse_native_representation` method will be used to parse it. In case of a dict, the credentials object will be updated with key: values of the dict.
Unexpected values in the dict will be ignored.

Credentials will be marked as resolved if all required fields are set resolve() method is successful
"""
# create an instance
self = cls()
self._apply_init_value(init_value)
if not self.is_partial():
# let it fail gracefully
with contextlib.suppress(Exception):
self.resolve()
return self

def _apply_init_value(self, init_value: Any = None) -> None:
if isinstance(init_value, C_Mapping):
self.update(init_value)
elif init_value is not None:
self.parse_native_representation(init_value)
else:
return

def parse_native_representation(self, native_value: Any) -> None:
"""Initialize the configuration fields by parsing the `native_value` which should be a native representation of the configuration
or credentials, for example database connection string or JSON serialized GCP service credentials file.
Expand Down Expand Up @@ -348,7 +374,7 @@ def resolve(self) -> None:
self.call_method_in_mro("on_resolved")
self.__is_resolved__ = True

def copy(self: _T) -> _T:
def copy(self: _B) -> _B:
"""Returns a deep copy of the configuration instance"""
return copy.deepcopy(self)

Expand Down Expand Up @@ -426,38 +452,13 @@ class CredentialsConfiguration(BaseConfiguration):

__section__: ClassVar[str] = "credentials"

@classmethod
def from_init_value(cls: Type[_C], init_value: Any = None) -> _C:
"""Initializes credentials from `init_value`

Init value may be a native representation of the credentials or a dict. In case of native representation (for example a connection string or JSON with service account credentials)
a `parse_native_representation` method will be used to parse it. In case of a dict, the credentials object will be updated with key: values of the dict.
Unexpected values in the dict will be ignored.

Credentials will be marked as resolved if all required fields are set.
"""
# create an instance
self = cls()
self._apply_init_value(init_value)
return self

def to_native_credentials(self) -> Any:
"""Returns native credentials object.

By default calls `to_native_representation` method.
"""
return self.to_native_representation()

def _apply_init_value(self, init_value: Any = None) -> None:
if isinstance(init_value, C_Mapping):
self.update(init_value)
elif init_value is not None:
self.parse_native_representation(init_value)
else:
return
if not self.is_partial():
self.resolve()

def __str__(self) -> str:
"""Get string representation of credentials to be displayed, with all secret parts removed"""
return super().__str__()
Expand Down
25 changes: 22 additions & 3 deletions dlt/common/configuration/specs/connection_string_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class ConnectionStringCredentials(CredentialsConfiguration):
username: str = None
host: Optional[str] = None
port: Optional[int] = None
query: Optional[Dict[str, str]] = None
query: Optional[Dict[str, Any]] = None

__config_gen_annotations__: ClassVar[List[str]] = ["port", "password", "host"]

Expand Down Expand Up @@ -44,16 +44,35 @@ def on_resolved(self) -> None:
def to_native_representation(self) -> str:
return self.to_url().render_as_string(hide_password=False)

def get_query(self) -> Dict[str, Any]:
"""Gets query preserving parameter types. Mostly used internally to export connection params"""
return {} if self.query is None else self.query

def to_url(self) -> URL:
"""Creates SQLAlchemy compatible URL object, computes current query via `get_query` and serializes its values to str"""
# circular dependencies here
from dlt.common.configuration.utils import serialize_value

def _serialize_value(v_: Any) -> str:
if v_ is None:
return None
return serialize_value(v_)

# query must be str -> str
query = {k: _serialize_value(v) for k, v in self.get_query().items()}
return URL.create(
self.drivername,
self.username,
self.password,
self.host,
self.port,
self.database,
self.query,
query,
)

def __str__(self) -> str:
return self.to_url().render_as_string(hide_password=True)
url = self.to_url()
# do not display query. it often contains secret values
url = url._replace(query=None)
# we only have control over netloc/path
return url.render_as_string(hide_password=True)
6 changes: 3 additions & 3 deletions dlt/common/configuration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,21 @@ def deserialize_value(key: str, value: Any, hint: Type[TAny]) -> TAny:
raise ConfigValueCannotBeCoercedException(key, value, hint) from exc


def serialize_value(value: Any) -> Any:
def serialize_value(value: Any) -> str:
if value is None:
raise ValueError(value)
# return literal for tuples
if isinstance(value, tuple):
return str(value)
if isinstance(value, BaseConfiguration):
try:
return value.to_native_representation()
return str(value.to_native_representation())
except NotImplementedError:
# no native representation: use dict
value = dict(value)
# coerce type to text which will use json for mapping and sequences
value_dt = py_type_to_sc_type(type(value))
return coerce_value("text", value_dt, value)
return coerce_value("text", value_dt, value) # type: ignore[no-any-return]


def auto_cast(value: str) -> Any:
Expand Down
28 changes: 13 additions & 15 deletions dlt/destinations/impl/clickhouse/configuration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import ClassVar, List, Any, Final, Literal, cast, Optional
from typing import ClassVar, Dict, List, Any, Final, Literal, cast, Optional

from dlt.common.configuration import configspec
from dlt.common.configuration.specs import ConnectionStringCredentials
Expand Down Expand Up @@ -59,23 +59,21 @@ def parse_native_representation(self, native_value: Any) -> None:
self.query.get("send_receive_timeout", self.send_receive_timeout)
)
self.secure = cast(TSecureConnection, int(self.query.get("secure", self.secure)))
if not self.is_partial():
self.resolve()

def to_url(self) -> URL:
url = super().to_url()
url = url.update_query_pairs(
[
("connect_timeout", str(self.connect_timeout)),
("send_receive_timeout", str(self.send_receive_timeout)),
("secure", str(1) if self.secure else str(0)),
def get_query(self) -> Dict[str, Any]:
query = dict(super().get_query())
query.update(
{
"connect_timeout": str(self.connect_timeout),
"send_receive_timeout": str(self.send_receive_timeout),
"secure": 1 if self.secure else 0,
# Toggle experimental settings. These are necessary for certain datatypes and not optional.
("allow_experimental_lightweight_delete", "1"),
# ("allow_experimental_object_type", "1"),
("enable_http_compression", "1"),
]
"allow_experimental_lightweight_delete": 1,
# "allow_experimental_object_type": 1,
"enable_http_compression": 1,
}
)
return url
return query


@configspec
Expand Down
10 changes: 4 additions & 6 deletions dlt/destinations/impl/mssql/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def parse_native_representation(self, native_value: Any) -> None:
self.query = {k.lower(): v for k, v in self.query.items()} # Make case-insensitive.
self.driver = self.query.get("driver", self.driver)
self.connect_timeout = int(self.query.get("connect_timeout", self.connect_timeout))
if not self.is_partial():
self.resolve()

def on_resolved(self) -> None:
if self.driver not in self.SUPPORTED_DRIVERS:
Expand All @@ -45,10 +43,10 @@ def on_resolved(self) -> None:
)
self.database = self.database.lower()

def to_url(self) -> URL:
url = super().to_url()
url.update_query_pairs([("connect_timeout", str(self.connect_timeout))])
return url
def get_query(self) -> Dict[str, Any]:
query = dict(super().get_query())
query["connect_timeout"] = self.connect_timeout
return query

def on_partial(self) -> None:
self.driver = self._get_driver()
Expand Down
12 changes: 5 additions & 7 deletions dlt/destinations/impl/postgres/configuration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Final, ClassVar, Any, List, TYPE_CHECKING, Union
from typing import Dict, Final, ClassVar, Any, List, TYPE_CHECKING, Union

from dlt.common.libs.sql_alchemy import URL
from dlt.common.configuration import configspec
Expand All @@ -23,13 +23,11 @@ class PostgresCredentials(ConnectionStringCredentials):
def parse_native_representation(self, native_value: Any) -> None:
super().parse_native_representation(native_value)
self.connect_timeout = int(self.query.get("connect_timeout", self.connect_timeout))
if not self.is_partial():
self.resolve()

def to_url(self) -> URL:
url = super().to_url()
url.update_query_pairs([("connect_timeout", str(self.connect_timeout))])
return url
def get_query(self) -> Dict[str, Any]:
query = dict(super().get_query())
query["connect_timeout"] = self.connect_timeout
return query


@configspec
Expand Down
Loading
Loading