Skip to content

Commit

Permalink
regular initializers for credentials (#1142)
Browse files Browse the repository at this point in the history
* removes all dlt dependencies from logger

* uses dataclass_transform to generate init methods for configspec, warnings on inconsistent settings

* changes all configspecs to conform with new init methods, drops special init for credentials

* fixes setting native value None
  • Loading branch information
rudolfix authored Mar 25, 2024
1 parent cf3ac9f commit e0774cc
Show file tree
Hide file tree
Showing 97 changed files with 570 additions and 764 deletions.
2 changes: 1 addition & 1 deletion dlt/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dlt.common import logger
from dlt.common.arithmetics import Decimal
from dlt.common.wei import Wei
from dlt.common.pendulum import pendulum
from dlt.common.json import json
from dlt.common.runtime.signals import sleep
from dlt.common.runtime import logger

__all__ = ["Decimal", "Wei", "pendulum", "json", "sleep", "logger"]
4 changes: 2 additions & 2 deletions dlt/common/configuration/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def initialize_credentials(hint: Any, initial_value: Any) -> CredentialsConfigur
first_credentials: CredentialsConfiguration = None
for idx, spec in enumerate(specs_in_union):
try:
credentials = spec(initial_value)
credentials = spec.from_init_value(initial_value)
if credentials.is_resolved():
return credentials
# keep first credentials in the union to return in case all of the match but not resolve
Expand All @@ -88,7 +88,7 @@ def initialize_credentials(hint: Any, initial_value: Any) -> CredentialsConfigur
return first_credentials
else:
assert issubclass(hint, CredentialsConfiguration)
return hint(initial_value) # type: ignore
return hint.from_init_value(initial_value) # type: ignore


def inject_section(
Expand Down
6 changes: 3 additions & 3 deletions dlt/common/configuration/specs/api_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

@configspec
class OAuth2Credentials(CredentialsConfiguration):
client_id: str
client_secret: TSecretValue
refresh_token: Optional[TSecretValue]
client_id: str = None
client_secret: TSecretValue = None
refresh_token: Optional[TSecretValue] = None
scopes: Optional[List[str]] = None

token: Optional[TSecretValue] = None
Expand Down
6 changes: 6 additions & 0 deletions dlt/common/configuration/specs/aws_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,9 @@ def parse_native_representation(self, native_value: Any) -> None:
self.__is_resolved__ = True
except Exception:
raise InvalidBoto3Session(self.__class__, native_value)

@classmethod
def from_session(cls, botocore_session: Any) -> "AwsCredentials":
self = cls()
self.parse_native_representation(botocore_session)
return self
80 changes: 52 additions & 28 deletions dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import contextlib
import dataclasses
import warnings

from collections.abc import Mapping as C_Mapping
from typing import (
Expand All @@ -19,7 +20,7 @@
ClassVar,
TypeVar,
)
from typing_extensions import get_args, get_origin
from typing_extensions import get_args, get_origin, dataclass_transform
from functools import wraps

if TYPE_CHECKING:
Expand All @@ -44,6 +45,7 @@
_F_BaseConfiguration: Any = type(object)
_F_ContainerInjectableContext: Any = type(object)
_T = TypeVar("_T", bound="BaseConfiguration")
_C = TypeVar("_C", bound="CredentialsConfiguration")


def is_base_configuration_inner_hint(inner_hint: Type[Any]) -> bool:
Expand Down Expand Up @@ -106,18 +108,26 @@ def is_secret_hint(hint: Type[Any]) -> bool:


@overload
def configspec(cls: Type[TAnyClass]) -> Type[TAnyClass]: ...
def configspec(cls: Type[TAnyClass], init: bool = True) -> Type[TAnyClass]: ...


@overload
def configspec(cls: None = ...) -> Callable[[Type[TAnyClass]], Type[TAnyClass]]: ...
def configspec(
cls: None = ..., init: bool = True
) -> Callable[[Type[TAnyClass]], Type[TAnyClass]]: ...


@dataclass_transform(eq_default=False, field_specifiers=(dataclasses.Field, dataclasses.field))
def configspec(
cls: Optional[Type[Any]] = None,
cls: Optional[Type[Any]] = None, init: bool = True
) -> Union[Type[TAnyClass], Callable[[Type[TAnyClass]], Type[TAnyClass]]]:
"""Converts (via derivation) any decorated class to a Python dataclass that may be used as a spec to resolve configurations
__init__ method is synthesized by default. `init` flag is ignored if the decorated class implements custom __init__ as well as
when any of base classes has no synthesized __init__
All fields must have default values. This decorator will add `None` default values that miss one.
In comparison the Python dataclass, a spec implements full dictionary interface for its attributes, allows instance creation from ie. strings
or other types (parsing, deserialization) and control over configuration resolution process. See `BaseConfiguration` and CredentialsConfiguration` for
more information.
Expand All @@ -142,6 +152,10 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]:
# get all annotations without corresponding attributes and set them to None
for ann in cls.__annotations__:
if not hasattr(cls, ann) and not ann.startswith(("__", "_abc_")):
warnings.warn(
f"Missing default value for field {ann} on {cls.__name__}. None assumed. All"
" fields in configspec must have default."
)
setattr(cls, ann, None)
# get all attributes without corresponding annotations
for att_name, att_value in list(cls.__dict__.items()):
Expand Down Expand Up @@ -177,17 +191,18 @@ def default_factory(att_value=att_value): # type: ignore[no-untyped-def]

# We don't want to overwrite user's __init__ method
# Create dataclass init only when not defined in the class
# (never put init on BaseConfiguration itself)
try:
is_base = cls is BaseConfiguration
except NameError:
is_base = True
init = False
base_params = getattr(cls, "__dataclass_params__", None)
if not is_base and (base_params and base_params.init or cls.__init__ is object.__init__):
init = True
# NOTE: any class without synthesized __init__ breaks the creation chain
has_default_init = super(cls, cls).__init__ == cls.__init__ # type: ignore[misc]
base_params = getattr(cls, "__dataclass_params__", None) # cls.__init__ is object.__init__
synth_init = init and ((not base_params or base_params.init) and has_default_init)
if synth_init != init and has_default_init:
warnings.warn(
f"__init__ method will not be generated on {cls.__name__} because bas class didn't"
" synthesize __init__. Please correct `init` flag in confispec decorator. You are"
" probably receiving incorrect __init__ signature for type checking"
)
# do not generate repr as it may contain secret values
return dataclasses.dataclass(cls, init=init, eq=False, repr=False) # type: ignore
return dataclasses.dataclass(cls, init=synth_init, eq=False, repr=False) # type: ignore

# called with parenthesis
if cls is None:
Expand All @@ -198,12 +213,14 @@ def default_factory(att_value=att_value): # type: ignore[no-untyped-def]

@configspec
class BaseConfiguration(MutableMapping[str, Any]):
__is_resolved__: bool = dataclasses.field(default=False, init=False, repr=False)
__is_resolved__: bool = dataclasses.field(default=False, init=False, repr=False, compare=False)
"""True when all config fields were resolved and have a specified value type"""
__section__: str = dataclasses.field(default=None, init=False, repr=False)
"""Obligatory section used by config providers when searching for keys, always present in the search path"""
__exception__: Exception = dataclasses.field(default=None, init=False, repr=False)
__exception__: Exception = dataclasses.field(
default=None, init=False, repr=False, compare=False
)
"""Holds the exception that prevented the full resolution"""
__section__: ClassVar[str] = None
"""Obligatory section used by config providers when searching for keys, always present in the search path"""
__config_gen_annotations__: ClassVar[List[str]] = []
"""Additional annotations for config generator, currently holds a list of fields of interest that have defaults"""
__dataclass_fields__: ClassVar[Dict[str, TDtcField]]
Expand Down Expand Up @@ -342,9 +359,10 @@ def call_method_in_mro(config, method_name: str) -> None:
class CredentialsConfiguration(BaseConfiguration):
"""Base class for all credentials. Credentials are configurations that may be stored only by providers supporting secrets."""

__section__: str = "credentials"
__section__: ClassVar[str] = "credentials"

def __init__(self, init_value: Any = None) -> None:
@classmethod
def from_init_value(cls: Type[_C], init_value: Any = None) -> _C:
"""Initializes credentials from `init_value`
Init value may be a native representation of the credentials or a dict. In case of native representation (for example a connection string or JSON with service account credentials)
Expand All @@ -353,14 +371,10 @@ def __init__(self, init_value: Any = None) -> None:
Credentials will be marked as resolved if all required fields are set.
"""
if init_value is None:
return
elif isinstance(init_value, C_Mapping):
self.update(init_value)
else:
self.parse_native_representation(init_value)
if not self.is_partial():
self.resolve()
# create an instance
self = cls()
self._apply_init_value(init_value)
return self

def to_native_credentials(self) -> Any:
"""Returns native credentials object.
Expand All @@ -369,6 +383,16 @@ def to_native_credentials(self) -> Any:
"""
return self.to_native_representation()

def _apply_init_value(self, init_value: Any = None) -> None:
if isinstance(init_value, C_Mapping):
self.update(init_value)
elif init_value is not None:
self.parse_native_representation(init_value)
else:
return
if not self.is_partial():
self.resolve()

def __str__(self) -> str:
"""Get string representation of credentials to be displayed, with all secret parts removed"""
return super().__str__()
Expand Down
11 changes: 8 additions & 3 deletions dlt/common/configuration/specs/config_providers_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import dataclasses
import io
from typing import ClassVar, List

Expand Down Expand Up @@ -28,7 +29,7 @@ class ConfigProvidersConfiguration(BaseConfiguration):
only_toml_fragments: bool = True

# always look in providers
__section__ = known_sections.PROVIDERS
__section__: ClassVar[str] = known_sections.PROVIDERS


@configspec
Expand All @@ -37,8 +38,12 @@ class ConfigProvidersContext(ContainerInjectableContext):

global_affinity: ClassVar[bool] = True

providers: List[ConfigProvider]
context_provider: ConfigProvider
providers: List[ConfigProvider] = dataclasses.field(
default=None, init=False, repr=False, compare=False
)
context_provider: ConfigProvider = dataclasses.field(
default=None, init=False, repr=False, compare=False
)

def __init__(self) -> None:
super().__init__()
Expand Down
12 changes: 1 addition & 11 deletions dlt/common/configuration/specs/config_section_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class ConfigSectionContext(ContainerInjectableContext):
TMergeFunc = Callable[["ConfigSectionContext", "ConfigSectionContext"], None]

pipeline_name: Optional[str]
pipeline_name: Optional[str] = None
sections: Tuple[str, ...] = ()
merge_style: TMergeFunc = None
source_state_key: str = None
Expand Down Expand Up @@ -70,13 +70,3 @@ def __str__(self) -> str:
super().__str__()
+ f": {self.pipeline_name} {self.sections}@{self.merge_style} state['{self.source_state_key}']"
)

if TYPE_CHECKING:
# provide __init__ signature when type checking
def __init__(
self,
pipeline_name: str = None,
sections: Tuple[str, ...] = (),
merge_style: TMergeFunc = None,
source_state_key: str = None,
) -> None: ...
12 changes: 9 additions & 3 deletions dlt/common/configuration/specs/connection_string_credentials.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from typing import Any, ClassVar, Dict, List, Optional
import dataclasses
from typing import Any, ClassVar, Dict, List, Optional, Union

from dlt.common.libs.sql_alchemy import URL, make_url
from dlt.common.configuration.specs.exceptions import InvalidConnectionString

from dlt.common.typing import TSecretValue
from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec


@configspec
class ConnectionStringCredentials(CredentialsConfiguration):
drivername: str = None
drivername: str = dataclasses.field(default=None, init=False, repr=False, compare=False)
database: str = None
password: Optional[TSecretValue] = None
username: str = None
Expand All @@ -18,6 +19,11 @@ class ConnectionStringCredentials(CredentialsConfiguration):

__config_gen_annotations__: ClassVar[List[str]] = ["port", "password", "host"]

def __init__(self, connection_string: Union[str, Dict[str, Any]] = None) -> None:
"""Initializes the credentials from SQLAlchemy like connection string or from dict holding connection string elements"""
super().__init__()
self._apply_init_value(connection_string)

def parse_native_representation(self, native_value: Any) -> None:
if not isinstance(native_value, str):
raise InvalidConnectionString(self.__class__, native_value, self.drivername)
Expand Down
23 changes: 16 additions & 7 deletions dlt/common/configuration/specs/gcp_credentials.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import sys
from typing import Any, Final, List, Tuple, Union, Dict
from typing import Any, ClassVar, Final, List, Tuple, Union, Dict

from dlt.common import json, pendulum
from dlt.common.configuration.specs.api_credentials import OAuth2Credentials
Expand All @@ -22,8 +23,12 @@

@configspec
class GcpCredentials(CredentialsConfiguration):
token_uri: Final[str] = "https://oauth2.googleapis.com/token"
auth_uri: Final[str] = "https://accounts.google.com/o/oauth2/auth"
token_uri: Final[str] = dataclasses.field(
default="https://oauth2.googleapis.com/token", init=False, repr=False, compare=False
)
auth_uri: Final[str] = dataclasses.field(
default="https://accounts.google.com/o/oauth2/auth", init=False, repr=False, compare=False
)

project_id: str = None

Expand Down Expand Up @@ -69,7 +74,9 @@ def to_gcs_credentials(self) -> Dict[str, Any]:
class GcpServiceAccountCredentialsWithoutDefaults(GcpCredentials):
private_key: TSecretValue = None
client_email: str = None
type: Final[str] = "service_account" # noqa: A003
type: Final[str] = dataclasses.field( # noqa: A003
default="service_account", init=False, repr=False, compare=False
)

def parse_native_representation(self, native_value: Any) -> None:
"""Accepts ServiceAccountCredentials as native value. In other case reverts to serialized services.json"""
Expand Down Expand Up @@ -121,8 +128,10 @@ def __str__(self) -> str:
@configspec
class GcpOAuthCredentialsWithoutDefaults(GcpCredentials, OAuth2Credentials):
# only desktop app supported
refresh_token: TSecretValue
client_type: Final[str] = "installed"
refresh_token: TSecretValue = None
client_type: Final[str] = dataclasses.field(
default="installed", init=False, repr=False, compare=False
)

def parse_native_representation(self, native_value: Any) -> None:
"""Accepts Google OAuth2 credentials as native value. In other case reverts to serialized oauth client secret json"""
Expand Down Expand Up @@ -237,7 +246,7 @@ def __str__(self) -> str:

@configspec
class GcpDefaultCredentials(CredentialsWithDefault, GcpCredentials):
_LAST_FAILED_DEFAULT: float = 0.0
_LAST_FAILED_DEFAULT: ClassVar[float] = 0.0

def parse_native_representation(self, native_value: Any) -> None:
"""Accepts google credentials as native value"""
Expand Down
3 changes: 3 additions & 0 deletions dlt/common/configuration/specs/known_sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
EXTRACT = "extract"
"""extract stage of the pipeline"""

SCHEMA = "schema"
"""schema configuration, ie. normalizers"""

PROVIDERS = "providers"
"""secrets and config providers"""

Expand Down
4 changes: 2 additions & 2 deletions dlt/common/configuration/specs/run_configuration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import binascii
from os.path import isfile, join
from pathlib import Path
from typing import Any, Optional, Tuple, IO
from typing import Any, ClassVar, Optional, IO
from dlt.common.typing import TSecretStrValue

from dlt.common.utils import encoding_for_mode, main_module_file_path, reveal_pseudo_secret
Expand Down Expand Up @@ -30,7 +30,7 @@ class RunConfiguration(BaseConfiguration):
"""Platform connection"""
dlthub_dsn: Optional[TSecretStrValue] = None

__section__ = "runtime"
__section__: ClassVar[str] = "runtime"

def on_resolved(self) -> None:
# generate pipeline name from the entry point script name
Expand Down
4 changes: 2 additions & 2 deletions dlt/common/data_writers/buffered.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gzip
import time
from typing import List, IO, Any, Optional, Type, TypeVar, Generic
from typing import ClassVar, List, IO, Any, Optional, Type, TypeVar, Generic

from dlt.common.typing import TDataItem, TDataItems
from dlt.common.data_writers import TLoaderFileFormat
Expand Down Expand Up @@ -33,7 +33,7 @@ class BufferedDataWriterConfiguration(BaseConfiguration):
disable_compression: bool = False
_caps: Optional[DestinationCapabilitiesContext] = None

__section__ = known_sections.DATA_WRITER
__section__: ClassVar[str] = known_sections.DATA_WRITER

@with_config(spec=BufferedDataWriterConfiguration)
def __init__(
Expand Down
Loading

0 comments on commit e0774cc

Please sign in to comment.