From 4a02300947572407dc0f7391d0630432aa2d1d88 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Jun 2024 22:31:46 +0200 Subject: [PATCH 01/15] runs motherduck init on ci --- .github/workflows/test_destinations.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_destinations.yml b/.github/workflows/test_destinations.yml index fed5c99fe1..037f9da3e5 100644 --- a/.github/workflows/test_destinations.yml +++ b/.github/workflows/test_destinations.yml @@ -28,7 +28,7 @@ env: RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} # Test redshift and filesystem with all buckets # postgres runs again here so we can test on mac/windows - ACTIVE_DESTINATIONS: "[\"redshift\", \"postgres\", \"duckdb\", \"filesystem\", \"dummy\"]" + ACTIVE_DESTINATIONS: "[\"redshift\", \"postgres\", \"duckdb\", \"filesystem\", \"dummy\", \"motherduck\"]" jobs: get_docs_changes: From d3c3f3da9bce12585a0ab98d5c5b89227a652751 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Jun 2024 22:32:56 +0200 Subject: [PATCH 02/15] fixes edge cases for optional new types, new types of optional types and literal detection --- dlt/common/typing.py | 74 +++++++++++++++++++++++++++++-------- tests/common/test_typing.py | 10 +++++ 2 files changed, 69 insertions(+), 15 deletions(-) diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 7490dc6e53..2575dd15a4 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -27,6 +27,7 @@ IO, Iterator, Generator, + NamedTuple, ) from typing_extensions import ( @@ -88,10 +89,6 @@ "A single data item or a list as extracted from the data source" TAnyDateTime = Union[pendulum.DateTime, pendulum.Date, datetime, date, str, float, int] """DateTime represented as pendulum/python object, ISO string or unix timestamp""" - -ConfigValue: None = None -"""value of type None indicating argument that may be injected by config provider""" - TVariantBase = TypeVar("TVariantBase", covariant=True) TVariantRV = Tuple[str, Any] VARIANT_FIELD_FORMAT = "v_%s" @@ -99,6 +96,30 @@ TSortOrder = Literal["asc", "desc"] +class ConfigValueSentinel(NamedTuple): + """Class to create singleton sentinel for config and secret injected value""" + + default_literal: str + default_type: AnyType + + def __str__(self) -> str: + return self.__repr__() + + def __repr__(self) -> str: + if self.default_literal == "dlt.config.value": + inst_ = "ConfigValue" + else: + inst_ = "SecretValue" + return f"{inst_}({self.default_literal}) awaiting injection" + + +ConfigValue: None = ConfigValueSentinel("dlt.config.value", AnyType) # type: ignore[assignment] +"""Config value indicating argument that may be injected by config provider. Evaluates to None when type checking""" + +SecretValue: None = ConfigValueSentinel("dlt.secrets.value", TSecretValue) # type: ignore[assignment] +"""Secret value indicating argument that may be injected by config provider. 
Evaluates to None when type checking""" + + @runtime_checkable class SupportsVariant(Protocol, Generic[TVariantBase]): """Defines variant type protocol that should be recognized by normalizers @@ -157,6 +178,10 @@ def extract_type_if_modifier(t: Type[Any]) -> Optional[Type[Any]]: return None +def extract_supertype(t: Type[Any]) -> Optional[Type[Any]]: + return getattr(t, "__supertype__", None) # type: ignore[no-any-return] + + def is_union_type(hint: Type[Any]) -> bool: # We need to handle UnionType because with Python>=3.10 # new Optional syntax was introduced which treats Optionals @@ -172,8 +197,8 @@ def is_union_type(hint: Type[Any]) -> bool: if origin is Union or origin is UnionType: return True - if hint := extract_type_if_modifier(hint): - return is_union_type(hint) + if inner_t := extract_type_if_modifier(hint): + return is_union_type(inner_t) return False @@ -184,8 +209,13 @@ def is_optional_type(t: Type[Any]) -> bool: if is_union and type(None) in get_args(t): return True - if t := extract_type_if_modifier(t): - return is_optional_type(t) + if inner_t := extract_type_if_modifier(t): + if is_optional_type(inner_t): + return True + else: + t = inner_t + if super_t := extract_supertype(t): + return is_optional_type(super_t) return False @@ -203,24 +233,37 @@ def extract_union_types(t: Type[Any], no_none: bool = False) -> List[Any]: def is_literal_type(hint: Type[Any]) -> bool: if get_origin(hint) is Literal: return True - if hint := extract_type_if_modifier(hint): - return is_literal_type(hint) + if inner_t := extract_type_if_modifier(hint): + if is_literal_type(inner_t): + return True + else: + hint = inner_t + if super_t := extract_supertype(hint): + return is_literal_type(super_t) + if is_union_type(hint) and is_optional_type(hint): + return is_literal_type(get_args(hint)[0]) + return False def is_newtype_type(t: Type[Any]) -> bool: if hasattr(t, "__supertype__"): return True - if t := extract_type_if_modifier(t): - return is_newtype_type(t) + if inner_t := extract_type_if_modifier(t): + if is_newtype_type(inner_t): + return True + else: + t = inner_t + if is_union_type(t) and is_optional_type(t): + return is_newtype_type(get_args(t)[0]) return False def is_typeddict(t: Type[Any]) -> bool: if isinstance(t, _TypedDict): return True - if t := extract_type_if_modifier(t): - return is_typeddict(t) + if inner_t := extract_type_if_modifier(t): + return is_typeddict(inner_t) return False @@ -257,7 +300,8 @@ def extract_inner_type(hint: Type[Any], preserve_new_types: bool = False) -> Typ """ if maybe_modified := extract_type_if_modifier(hint): return extract_inner_type(maybe_modified, preserve_new_types) - if is_optional_type(hint): + # make sure we deal with optional directly + if is_union_type(hint) and is_optional_type(hint): return extract_inner_type(get_args(hint)[0], preserve_new_types) if is_literal_type(hint): # assume that all literals are of the same type diff --git a/tests/common/test_typing.py b/tests/common/test_typing.py index cc319619e6..5bdd308566 100644 --- a/tests/common/test_typing.py +++ b/tests/common/test_typing.py @@ -138,6 +138,9 @@ def test_is_literal() -> None: assert is_literal_type(Final[TTestLi]) is True # type: ignore[arg-type] assert is_literal_type("a") is False # type: ignore[arg-type] assert is_literal_type(List[str]) is False + NT1 = NewType("NT1", Optional[TTestLi]) # type: ignore[valid-newtype] + assert is_literal_type(NT1) is True + assert is_literal_type(NewType("NT2", NT1)) is True def test_optional() -> None: @@ -151,6 +154,11 @@ def 
test_optional() -> None: assert is_optional_type(Final[Annotated[Union[str, int], None]]) is False # type: ignore[arg-type] assert is_optional_type(Annotated[Union[str, int], type(None)]) is False # type: ignore[arg-type] assert is_optional_type(TOptionalTyDi) is True # type: ignore[arg-type] + NT1 = NewType("NT1", Optional[str]) # type: ignore[valid-newtype] + assert is_optional_type(NT1) is True + assert is_optional_type(ClassVar[NT1]) is True # type: ignore[arg-type] + assert is_optional_type(NewType("NT2", NT1)) is True + assert is_optional_type(NewType("NT2", Annotated[NT1, 1])) is True assert is_optional_type(TTestTyDi) is False assert extract_union_types(TOptionalLi) == [TTestLi, type(None)] # type: ignore[arg-type] assert extract_union_types(TOptionalTyDi) == [TTestTyDi, type(None)] # type: ignore[arg-type] @@ -173,6 +181,7 @@ def test_is_newtype() -> None: assert is_newtype_type(ClassVar[NT1]) is True # type: ignore[arg-type] assert is_newtype_type(TypeVar("TV1", bound=str)) is False # type: ignore[arg-type] assert is_newtype_type(1) is False # type: ignore[arg-type] + assert is_newtype_type(Optional[NT1]) is True # type: ignore[arg-type] def test_is_annotated() -> None: @@ -195,6 +204,7 @@ def test_extract_inner_type() -> None: assert extract_inner_type(NTL2, preserve_new_types=True) is NTL2 l_2 = Literal[NTL2(1.238), NTL2(2.343)] # type: ignore[valid-type] assert extract_inner_type(l_2) is float # type: ignore[arg-type] + assert extract_inner_type(NewType("NT1", Optional[str])) is str def test_get_config_if_union() -> None: From e94ff9571d641680bf6196ae306cfe9adf0752b4 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Jun 2024 22:33:14 +0200 Subject: [PATCH 03/15] skips streamlit tests if not installed --- tests/helpers/streamlit_tests/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/helpers/streamlit_tests/__init__.py b/tests/helpers/streamlit_tests/__init__.py index e69de29bb2..61132f214d 100644 --- a/tests/helpers/streamlit_tests/__init__.py +++ b/tests/helpers/streamlit_tests/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("streamlit") From 6622fd45ae918f6b4748ff21e0cae39edded0e59 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Jun 2024 22:41:37 +0200 Subject: [PATCH 04/15] defines singleton sentinels for dlt.config.value and dlt.secrets.value --- dlt/common/configuration/accessors.py | 8 ++------ tests/common/configuration/test_accessors.py | 6 +++--- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/dlt/common/configuration/accessors.py b/dlt/common/configuration/accessors.py index dfadc97fa3..1b32ae96f4 100644 --- a/dlt/common/configuration/accessors.py +++ b/dlt/common/configuration/accessors.py @@ -1,6 +1,4 @@ import abc -import contextlib -import tomlkit from typing import Any, ClassVar, List, Sequence, Tuple, Type, TypeVar from dlt.common.configuration.container import Container @@ -9,10 +7,8 @@ from dlt.common.configuration.specs import BaseConfiguration, is_base_configuration_inner_hint from dlt.common.configuration.utils import deserialize_value, log_traces, auto_cast from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext -from dlt.common.typing import AnyType, ConfigValue, TSecretValue +from dlt.common.typing import AnyType, ConfigValue, SecretValue, TSecretValue -DLT_SECRETS_VALUE = "secrets.value" -DLT_CONFIG_VALUE = "config.value" TConfigAny = TypeVar("TConfigAny", bound=Any) @@ -129,7 +125,7 @@ def writable_provider(self) -> ConfigProvider: p for p in 
self._get_providers_from_context()
            if p.is_writable and p.supports_secrets
        )

-    value: ClassVar[Any] = ConfigValue
+    value: ClassVar[Any] = SecretValue
    "A placeholder that tells dlt to replace it with actual secret during the call to a source or resource decorated function."

diff --git a/tests/common/configuration/test_accessors.py b/tests/common/configuration/test_accessors.py
index 147d56abec..dc8761110f 100644
--- a/tests/common/configuration/test_accessors.py
+++ b/tests/common/configuration/test_accessors.py
@@ -19,7 +19,7 @@
 from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext
 from dlt.common.configuration.utils import get_resolved_traces, ResolvedValueTrace
 from dlt.common.runners.configuration import PoolRunnerConfiguration
-from dlt.common.typing import AnyType, TSecretValue
+from dlt.common.typing import AnyType, ConfigValue, SecretValue, TSecretValue

 from tests.utils import preserve_environ

@@ -29,8 +29,8 @@


 def test_accessor_singletons() -> None:
-    assert dlt.config.value is None
-    assert dlt.secrets.value is None
+    assert dlt.config.value is ConfigValue
+    assert dlt.secrets.value is SecretValue


 def test_getter_accessor(toml_providers: ConfigProvidersContext, environment: Any) -> None:

From 7d42533849ead4482621fda7e677ea2575b087d1 Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Sat, 1 Jun 2024 22:42:52 +0200
Subject: [PATCH 05/15] uses sentinels to detect config and secret values,
 removes source code unparsing

---
 dlt/common/reflection/spec.py | 63 ++++++++++++-----------------
 1 file changed, 21 insertions(+), 42 deletions(-)

diff --git a/dlt/common/reflection/spec.py b/dlt/common/reflection/spec.py
index 5c39199f63..db791c60cd 100644
--- a/dlt/common/reflection/spec.py
+++ b/dlt/common/reflection/spec.py
@@ -3,11 +3,9 @@
 from typing import Dict, List, Tuple, Type, Any, Optional, NewType
 from inspect import Signature, Parameter

-from dlt.common.typing import AnyType, AnyFun, TSecretValue
+from dlt.common.typing import AnyType, AnyFun, ConfigValueSentinel, NoneType, TSecretValue
 from dlt.common.configuration import configspec, is_valid_hint, is_secret_hint
 from dlt.common.configuration.specs import BaseConfiguration
-from dlt.common.configuration.accessors import DLT_CONFIG_VALUE, DLT_SECRETS_VALUE
-from dlt.common.reflection.utils import get_func_def_node, get_literal_defaults
 from dlt.common.utils import get_callable_name

 # [^.^_]+ splits by . or _
@@ -34,8 +32,8 @@ def spec_from_signature(
     """Creates a SPEC on base `base` for a function `f` with signature `sig`.

     All the arguments in `sig` that are valid SPEC hints and have defaults will be part of the SPEC.
-    Special markers for required SPEC fields `dlt.secrets.value` and `dlt.config.value` are parsed using
-    module source code, which is a hack and will not work for modules not imported from a file.
+    Special default markers for required SPEC fields `dlt.secrets.value` and `dlt.config.value` are sentinel
+    values that evaluate to None during type checking. The sentinels are defined in the dlt.common.typing module.

    The name of a SPEC type is inferred from qualname of `f` and type will refer to `f` module and is unique for a module.
    NOTE: the SPECS are cached in the module by using name as an id. 
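
For orientation, a minimal sketch of what the sentinel-based detection means for a decorated function (the function and its argument names are hypothetical, not part of this patch):

    import dlt

    def read_table(
        table: str,                        # no default: not part of the SPEC
        api_url: str = dlt.config.value,   # ConfigValueSentinel: required config field
        api_key: str = dlt.secrets.value,  # ConfigValueSentinel: required, typed as secret
        limit: int = 1000,                 # concrete default: optional SPEC field
    ):
        ...

The synthesized SPEC keeps `limit` with its default, while the sentinel defaults are replaced with None and mark `api_url` and `api_key` as required; `api_key` additionally gets a TSecretValue-based hint.
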
@@ -52,28 +50,6 @@ def spec_from_signature( MOD_SPEC: Type[BaseConfiguration] = getattr(module, spec_id) return MOD_SPEC, MOD_SPEC.get_resolvable_fields() - # find all the arguments that have following defaults - literal_defaults: Dict[str, str] = None - - def dlt_config_literal_to_type(arg_name: str) -> AnyType: - nonlocal literal_defaults - - if literal_defaults is None: - try: - node = get_func_def_node(f) - literal_defaults = get_literal_defaults(node) - except Exception: - # ignore exception during parsing. it is almost impossible to test all cases of function definitions - literal_defaults = {} - - if arg_name in literal_defaults: - literal_default = literal_defaults[arg_name] - if literal_default.endswith(DLT_CONFIG_VALUE): - return AnyType - if literal_default.endswith(DLT_SECRETS_VALUE): - return TSecretValue - return None - # synthesize configuration from the signature new_fields: Dict[str, Any] = {} sig_base_fields: Dict[str, Any] = {} @@ -87,40 +63,43 @@ def dlt_config_literal_to_type(arg_name: str) -> AnyType: ]: field_type = AnyType if p.annotation == Parameter.empty else p.annotation # keep the base fields if sig not annotated - if p.name in base_fields and field_type is AnyType and p.default is None: + if ( + p.name in base_fields + and field_type is AnyType + and isinstance(p.default, (NoneType, ConfigValueSentinel)) + ): sig_base_fields[p.name] = base_fields[p.name] continue # only valid hints and parameters with defaults are eligible if is_valid_hint(field_type) and p.default != Parameter.empty: - # try to get type from default - if field_type is AnyType and p.default is not None: - field_type = type(p.default) - # make type optional if explicit None is provided as default type_from_literal: AnyType = None + # make type optional if explicit None is provided as default if p.default is None: + # optional type + field_type = Optional[field_type] + elif isinstance(p.default, ConfigValueSentinel): # check if the defaults were attributes of the form .config.value or .secrets.value - type_from_literal = dlt_config_literal_to_type(p.name) - if type_from_literal is None: - # optional type - field_type = Optional[field_type] - elif type_from_literal is TSecretValue: + type_from_literal = p.default.default_type + if type_from_literal is TSecretValue: # override type with secret value if secrets.value - # print(f"Param {p.name} is REQUIRED: secrets literal") if not is_secret_hint(field_type): if field_type is AnyType: field_type = TSecretValue else: # generate typed SecretValue field_type = NewType("TSecretValue", field_type) # type: ignore - else: - # keep type mandatory if config.value - # print(f"Param {p.name} is REQUIRED: config literal") - pass + # remove sentinel from default + p = p.replace(default=None) + elif field_type is AnyType: + # try to get type from default + field_type = type(p.default) + if include_defaults or type_from_literal is not None: # set annotations annotations[p.name] = field_type # set field with default value new_fields[p.name] = p.default + # print(f"Param {p.name} is {field_type}: {p.default} due to {include_defaults} or {type_from_literal}") signature_fields = {**sig_base_fields, **new_fields} From e27962f770db2b84c92ed460704888d8a26c1920 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Jun 2024 22:44:21 +0200 Subject: [PATCH 06/15] converts sentinels if used in configspec to right defaults --- .../configuration/specs/base_configuration.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git 
a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 006cde8dce..12ff1d8a5d 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -20,7 +20,7 @@ ClassVar, TypeVar, ) -from typing_extensions import get_args, get_origin, dataclass_transform, Annotated, TypeAlias +from typing_extensions import get_args, get_origin, dataclass_transform from functools import wraps if TYPE_CHECKING: @@ -30,6 +30,7 @@ from dlt.common.typing import ( AnyType, + ConfigValueSentinel, TAnyClass, extract_inner_type, is_annotated, @@ -61,7 +62,7 @@ def __bool__(self) -> bool: return self.not_resolved -def is_hint_not_resolved(hint: AnyType) -> bool: +def is_hint_not_resolvable(hint: AnyType) -> bool: """Checks if hint should NOT be resolved. Final and types annotated like >>> Annotated[str, NotResolved()] @@ -102,7 +103,7 @@ def is_valid_hint(hint: Type[Any]) -> bool: # class vars are skipped by dataclass return True - if is_hint_not_resolved(hint): + if is_hint_not_resolvable(hint): # all hints that are not resolved are valid return True @@ -190,7 +191,7 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: if not hasattr(cls, ann) and not ann.startswith(("__", "_abc_")): warnings.warn( f"Missing default value for field {ann} on {cls.__name__}. None assumed. All" - " fields in configspec must have default." + " fields in configspec must have defaults." ) setattr(cls, ann, None) # get all attributes without corresponding annotations @@ -217,6 +218,20 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: # context can have any type if not is_valid_hint(hint) and not is_context: raise ConfigFieldTypeHintNotSupported(att_name, cls, hint) + # replace config / secret sentinels + if isinstance(att_value, ConfigValueSentinel): + if is_secret_hint(att_value.default_type) and not is_secret_hint(hint): + warnings.warn( + f"You indicated {att_name} to be {att_value.default_literal} but type" + " hint is not a secret" + ) + if not is_secret_hint(att_value.default_type) and is_secret_hint(hint): + warnings.warn( + f"You typed {att_name} to be a secret but" + f" {att_value.default_literal} indicates it is not" + ) + setattr(cls, att_name, None) + if isinstance(att_value, BaseConfiguration): # Wrap config defaults in default_factory to work around dataclass # blocking mutable defaults @@ -292,7 +307,7 @@ def _get_resolvable_dataclass_fields(cls) -> Iterator[TDtcField]: """Yields all resolvable dataclass fields in the order they should be resolved""" # Sort dynamic type hint fields last because they depend on other values yield from sorted( - (f for f in cls.__dataclass_fields__.values() if cls.__is_valid_field(f)), + (f for f in cls.__dataclass_fields__.values() if is_valid_configspec_field(f)), key=lambda f: f.name in cls.__hint_resolvers__, ) @@ -350,7 +365,7 @@ def __iter__(self) -> Iterator[str]: """Iterator or valid key names""" return map( lambda field: field.name, - filter(lambda val: self.__is_valid_field(val), self.__dataclass_fields__.values()), + filter(lambda val: is_valid_configspec_field(val), self.__dataclass_fields__.values()), ) def __len__(self) -> int: @@ -366,14 +381,10 @@ def update(self, other: Any = (), /, **kwds: Any) -> None: # helper functions def __has_attr(self, __key: str) -> bool: - return __key in self.__dataclass_fields__ and self.__is_valid_field( + return __key in self.__dataclass_fields__ and is_valid_configspec_field( self.__dataclass_fields__[__key] ) - 
@staticmethod - def __is_valid_field(field: TDtcField) -> bool: - return not field.name.startswith("__") and field._field_type is dataclasses._FIELD # type: ignore - def call_method_in_mro(config, method_name: str) -> None: # python multi-inheritance is cooperative and this would require that all configurations cooperatively # call each other class_method_name. this is not at all possible as we do not know which configs in the end will @@ -391,6 +402,10 @@ def call_method_in_mro(config, method_name: str) -> None: _F_BaseConfiguration = BaseConfiguration +def is_valid_configspec_field(field: TDtcField) -> bool: + return not field.name.startswith("__") and field._field_type is dataclasses._FIELD # type: ignore + + @configspec class CredentialsConfiguration(BaseConfiguration): """Base class for all credentials. Credentials are configurations that may be stored only by providers supporting secrets.""" From fd1096d4c4619a7319a80b42d4207d7875205600 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Jun 2024 22:45:41 +0200 Subject: [PATCH 07/15] adds SPECs to callables as attributes --- dlt/common/configuration/inject.py | 93 +++++++++++++++++------------- 1 file changed, 54 insertions(+), 39 deletions(-) diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index 6699826ec8..c6ec5d4ddc 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -1,11 +1,10 @@ import inspect from functools import wraps -from typing import Callable, Dict, Type, Any, Optional, Tuple, TypeVar, overload, cast +from typing import Callable, Dict, Type, Any, Optional, Union, Tuple, TypeVar, overload, cast from inspect import Signature, Parameter -from contextlib import nullcontext -from dlt.common.typing import DictStrAny, StrAny, TFun, AnyFun +from dlt.common.typing import DictStrAny, TFun, AnyFun from dlt.common.configuration.resolve import resolve_configuration, inject_section from dlt.common.configuration.specs.base_configuration import BaseConfiguration from dlt.common.configuration.specs.config_section_context import ConfigSectionContext @@ -15,14 +14,16 @@ _LAST_DLT_CONFIG = "_dlt_config" _ORIGINAL_ARGS = "_dlt_orig_args" -# keep a registry of all the decorated functions -_FUNC_SPECS: Dict[int, Type[BaseConfiguration]] = {} - TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) def get_fun_spec(f: AnyFun) -> Type[BaseConfiguration]: - return _FUNC_SPECS.get(id(f)) + return getattr(f, "__SPEC__", None) # type: ignore[no-any-return] + + +def set_fun_spec(f: AnyFun, spec: Type[BaseConfiguration]) -> None: + """Assigns a spec to a callable from which it was inferred""" + setattr(f, "__SPEC__", spec) # noqa: B010 @overload @@ -30,7 +31,7 @@ def with_config( func: TFun, /, spec: Type[BaseConfiguration] = None, - sections: Tuple[str, ...] = (), + sections: Union[str, Tuple[str, ...]] = (), sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming, auto_pipeline_section: bool = False, include_defaults: bool = True, @@ -46,7 +47,7 @@ def with_config( func: None = ..., /, spec: Type[BaseConfiguration] = None, - sections: Tuple[str, ...] = (), + sections: Union[str, Tuple[str, ...]] = (), sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming, auto_pipeline_section: bool = False, include_defaults: bool = True, @@ -61,7 +62,7 @@ def with_config( func: Optional[AnyFun] = None, /, spec: Type[BaseConfiguration] = None, - sections: Tuple[str, ...] 
= (), + sections: Union[str, Tuple[str, ...]] = (), sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming, auto_pipeline_section: bool = False, include_defaults: bool = True, @@ -88,17 +89,18 @@ def with_config( Callable[[TFun], TFun]: A decorated function """ - section_f: Callable[[StrAny], str] = None - # section may be a function from function arguments to section - if callable(sections): - section_f = sections - def decorator(f: TFun) -> TFun: SPEC: Type[BaseConfiguration] = None sig: Signature = inspect.signature(f) signature_fields: Dict[str, Any] + # find variadic kwargs to which additional arguments and injection context can be injected kwargs_arg = next( - (p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None + ( + p + for p in sig.parameters.values() + if p.kind == Parameter.VAR_KEYWORD and p.name == "injection_kwargs" + ), + None, ) if spec is None: SPEC, signature_fields = spec_from_signature(f, sig, include_defaults, base=base) @@ -109,7 +111,7 @@ def decorator(f: TFun) -> TFun: # if no signature fields were added we will not wrap `f` for injection if len(signature_fields) == 0: # always register new function - _FUNC_SPECS[id(f)] = SPEC + set_fun_spec(f, SPEC) return f spec_arg: Parameter = None @@ -127,20 +129,23 @@ def decorator(f: TFun) -> TFun: pipeline_name_arg = p pipeline_name_arg_default = None if p.default == Parameter.empty else p.default - def resolve_config(bound_args: inspect.BoundArguments) -> BaseConfiguration: + def resolve_config( + bound_args: inspect.BoundArguments, accept_partial_: bool + ) -> BaseConfiguration: """Resolve arguments using the provided spec""" # bind parameters to signature # for calls containing resolved spec in the kwargs, we do not need to resolve again config: BaseConfiguration = None - # if section derivation function was provided then call it - if section_f: - curr_sections: Tuple[str, ...] 
= (section_f(bound_args.arguments),) - # sections may be a string - elif isinstance(sections, str): - curr_sections = (sections,) + curr_sections: Union[str, Tuple[str, ...]] = None + # section may be a function from function arguments to section + if callable(sections): + curr_sections = sections(bound_args.arguments) else: curr_sections = sections + # sections may be a string + if isinstance(curr_sections, str): + curr_sections = (curr_sections,) # if one of arguments is spec the use it as initial value if initial_config: @@ -162,11 +167,11 @@ def resolve_config(bound_args: inspect.BoundArguments) -> BaseConfiguration: # this may be called from many threads so section_context is thread affine with inject_section(section_context, lock_context=lock_context_on_injection): - # print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}") + # print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections} in {bound_args.arguments}") return resolve_configuration( config or SPEC(), explicit_value=bound_args.arguments, - accept_partial=accept_partial, + accept_partial=accept_partial_, ) def update_bound_args( @@ -174,6 +179,7 @@ def update_bound_args( ) -> None: # overwrite or add resolved params resolved_params = dict(config) + # print("resolved_params", resolved_params) # overwrite or add resolved params for p in sig.parameters.values(): if p.name in resolved_params: @@ -191,11 +197,18 @@ def update_bound_args( def with_partially_resolved_config(config: Optional[BaseConfiguration] = None) -> Any: # creates a pre-resolved partial of the decorated function - empty_bound_args = sig.bind_partial() if not config: - config = resolve_config(empty_bound_args) - - def wrapped(*args: Any, **kwargs: Any) -> Any: + # TODO: this will not work if correct config is not provided + # esp. in case of parameters in _wrap being ConfigurationBase + # at least we should implement re-resolve with explicit parameters + # so we can merge partial we get here to combine a full config + empty_bound_args = sig.bind_partial() + # TODO: resolve partial here that will be updated in _wrap + config = resolve_config(empty_bound_args, accept_partial_=False) + + @wraps(f) + def _wrap(*args: Any, **kwargs: Any) -> Any: + # TODO: we should not change the outer config but deepcopy it nonlocal config # Do we need an exception here? 
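
The callable `sections` path above is easiest to see with a small sketch; it mirrors `test_inject_with_func_section` added later in this series (the env var name follows from the derived sections):

    import dlt
    from dlt.common.configuration.inject import with_config

    @with_config(sections=lambda args: ("dlt", args["name"]))  # type: ignore[call-overload]
    def table_info(name, password=dlt.secrets.value):
        return password

    # with DLT__USERS__PASSWORD=pass set in the environment:
    table_info("users")  # -> "pass"

Because `sections` may now be a plain string, a tuple or a callable over the bound arguments, `resolve_config` normalizes all three forms to a tuple before building the section context.
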
@@ -213,27 +226,28 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:

            # call the function with the pre-resolved config
            bound_args = sig.bind(*args, **kwargs)
+            # TODO: update partial config with bound_args (to cover edge cases with embedded configs)
            update_bound_args(bound_args, config, args, kwargs)
            return f(*bound_args.args, **bound_args.kwargs)

-        return wrapped
+        return _wrap

    @wraps(f)
    def _wrap(*args: Any, **kwargs: Any) -> Any:
        # Resolve config
        config: BaseConfiguration = None
-        bound_args = sig.bind(*args, **kwargs)
+        bound_args = sig.bind_partial(*args, **kwargs)
        if _LAST_DLT_CONFIG in kwargs:
            config = last_config(**kwargs)
        else:
-            config = resolve_config(bound_args)
+            config = resolve_config(bound_args, accept_partial_=accept_partial)

        # call the function with resolved config
        update_bound_args(bound_args, config, args, kwargs)
        return f(*bound_args.args, **bound_args.kwargs)

    # register the spec for a wrapped function
-    _FUNC_SPECS[id(_wrap)] = SPEC
+    set_fun_spec(_wrap, SPEC)
    # add a method to create a pre-resolved partial
    setattr(_wrap, "__RESOLVED_PARTIAL_FUNC__", with_partially_resolved_config)  # noqa: B010

@@ -255,13 +269,14 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
    return decorator(func)


-def last_config(**kwargs: Any) -> Any:
-    """Get configuration instance used to inject function arguments"""
-    return kwargs[_LAST_DLT_CONFIG]
+def last_config(**injection_kwargs: Any) -> Any:
+    """Get configuration instance used to inject function kwargs"""
+    return injection_kwargs[_LAST_DLT_CONFIG]


-def get_orig_args(**kwargs: Any) -> Tuple[Tuple[Any], DictStrAny]:
-    return kwargs[_ORIGINAL_ARGS]  # type: ignore
+def get_orig_args(**injection_kwargs: Any) -> Tuple[Tuple[Any], DictStrAny]:
+    """Get original arguments with which the injectable function was called"""
+    return injection_kwargs[_ORIGINAL_ARGS]  # type: ignore


 def create_resolved_partial(f: AnyFun, config: Optional[BaseConfiguration] = None) -> AnyFun:

From 65d2aca679c7765711c19ba4314de67a2065191c Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Sat, 1 Jun 2024 22:47:37 +0200
Subject: [PATCH 08/15] simplifies and fixes nested dict update, adds dict
 clone in utils

---
 dlt/common/utils.py        | 44 +++++++++----------
 dlt/extract/hints.py       | 13 ++++---
 tests/common/test_utils.py | 77 ++++++++++++++++++++++++++++++++++++--
 3 files changed, 103 insertions(+), 31 deletions(-)

diff --git a/dlt/common/utils.py b/dlt/common/utils.py
index 1d3020f4dd..cb2ec4c3d9 100644
--- a/dlt/common/utils.py
+++ b/dlt/common/utils.py
@@ -15,6 +15,7 @@
     Any,
     ContextManager,
     Dict,
+    MutableMapping,
     Iterator,
     Optional,
     Sequence,
@@ -24,17 +25,15 @@
     Mapping,
     List,
     Union,
-    Counter,
     Iterable,
 )
-from collections.abc import Mapping as C_Mapping

 from dlt.common.exceptions import DltException, ExceptionTrace, TerminalException
 from dlt.common.typing import AnyFun, StrAny, DictStrAny, StrStr, TAny, TFun

 T = TypeVar("T")
-TDict = TypeVar("TDict", bound=DictStrAny)
+TDict = TypeVar("TDict", bound=MutableMapping[Any, Any])

 TKey = TypeVar("TKey")
 TValue = TypeVar("TValue")
@@ -281,35 +280,36 @@ def update_dict_with_prune(dest: DictStrAny, update: StrAny) -> None:
             del dest[k]


-def update_dict_nested(dst: TDict, src: StrAny, keep_dst_values: bool = False) -> TDict:
+def update_dict_nested(dst: TDict, src: TDict, copy_src_dicts: bool = False) -> TDict:
     """Merges `src` into `dst` key wise. Does not recur into lists. Values in `src` overwrite `dst` if both keys exist.
-    Optionally (`keep_dst_values`) you can keep the `dst` value on conflict
+    Only `dict` and its subclasses are updated recursively. With `copy_src_dicts`, dict values are deep copied;
+    otherwise both `dst` and `src` keep references to the same nested objects.
     """
-    # based on https://github.com/clarketm/mergedeep/blob/master/mergedeep/mergedeep.py
-
-    def _is_recursive_merge(a: StrAny, b: StrAny) -> bool:
-        both_mapping = isinstance(a, C_Mapping) and isinstance(b, C_Mapping)
-        both_counter = isinstance(a, Counter) and isinstance(b, Counter)
-        return both_mapping and not both_counter

     for key in src:
+        src_val = src[key]
         if key in dst:
-            if _is_recursive_merge(dst[key], src[key]):
+            dst_val = dst[key]
+            if isinstance(src_val, dict) and isinstance(dst_val, dict):
                 # If the key for both `dst` and `src` are both Mapping types (e.g. dict), then recurse.
-                update_dict_nested(dst[key], src[key], keep_dst_values=keep_dst_values)
-            elif dst[key] is src[key]:
-                # If a key exists in both objects and the values are `same`, the value from the `dst` object will be used.
-                pass
-            else:
-                if not keep_dst_values:
-                    # if not keep then overwrite
-                    dst[key] = src[key]
+                update_dict_nested(dst_val, src_val, copy_src_dicts=copy_src_dicts)
+                continue
+
+        if copy_src_dicts and isinstance(src_val, dict):
+            dst[key] = update_dict_nested({}, src_val, True)
         else:
-            # If the key exists only in `src`, the value from the `src` object will be used.
-            dst[key] = src[key]
+            dst[key] = src_val
+
     return dst


+def clone_dict_nested(src: TDict) -> TDict:
+    """Clones `src` structure descending into nested dicts. Does not descend into mappings that are not dicts, i.e. spec instances.
+    Compared to `deepcopy`, it does not clone any other objects. Uses `update_dict_nested` internally.
+    """
+    return update_dict_nested({}, src, copy_src_dicts=True)  # type: ignore[return-value]
+
+
 def map_nested_in_place(func: AnyFun, _complex: TAny) -> TAny:
     """Applies `func` to all elements in `_dict` recursively, replacing elements in nested dictionaries and lists in place."""
     if isinstance(_complex, tuple):

diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py
index 70ff0cc29d..287474c82c 100644
--- a/dlt/extract/hints.py
+++ b/dlt/extract/hints.py
@@ -24,6 +24,7 @@
     new_table,
 )
 from dlt.common.typing import TDataItem
+from dlt.common.utils import clone_dict_nested
 from dlt.common.validation import validate_dict_ignoring_xkeys
 from dlt.extract.exceptions import (
     DataItemRequiredForDynamicTableHints,
@@ -318,7 +319,7 @@ def apply_hints(

         # set properties that can't be passed to make_hints
         if incremental is not None:
-            t["incremental"] = None if incremental is Incremental.EMPTY else incremental
+            t["incremental"] = incremental

         self._set_hints(t, create_table_variant)

@@ -375,11 +376,11 @@ def merge_hints(

     @staticmethod
     def _clone_hints(hints_template: TResourceHints) -> TResourceHints:
-        t_ = copy(hints_template)
-        t_["columns"] = deepcopy(hints_template["columns"])
-        if "schema_contract" in hints_template:
-            t_["schema_contract"] = deepcopy(hints_template["schema_contract"])
-        return t_
+        if hints_template is None:
+            return None
+        # creates a deep copy of dict structure without actually copying the objects
+        # deepcopy(hints_template) #
+        return clone_dict_nested(hints_template)  # type: ignore[type-var]

     @staticmethod
     def _resolve_hint(item: TDataItem, hint: TTableHintTemplate[Any]) -> Any:

diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py
index 229ce17085..e08c1cdf01 100644
--- a/tests/common/test_utils.py
+++ b/tests/common/test_utils.py 
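
The new merge semantics, as exercised by the tests below, in a minimal sketch: nested dicts merge recursively, non-dict objects are copied atomically by reference, and `copy_src_dicts` controls whether merged-in dicts are cloned or shared:

    from dlt.common.utils import update_dict_nested, clone_dict_nested

    dst = {"a": 1, "nested": {"x": 1}}
    src = {"nested": {"y": 2}}
    update_dict_nested(dst, src)
    assert dst == {"a": 1, "nested": {"x": 1, "y": 2}}
    # without copy_src_dicts, dst keeps its own nested dict to merge into
    assert dst["nested"] is not src["nested"]

    clone = clone_dict_nested(dst)
    assert clone == dst
    assert clone["nested"] is not dst["nested"]  # dicts copied by value
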
@@ -1,12 +1,14 @@
+from copy import deepcopy
 import itertools
 import inspect
 import binascii
 import pytest
-from typing import Dict
+from typing import Any, Dict

 from dlt.common.exceptions import PipelineException, TerminalValueError
 from dlt.common.runners import Venv
 from dlt.common.utils import (
+    clone_dict_nested,
     graph_find_scc_nodes,
     flatten_list_of_str_or_dicts,
     digest128,
@@ -293,6 +295,75 @@ def test_nested_dict_merge() -> None:

     assert update_dict_nested(dict(dict_1), dict_2) == {"a": 2, "b": 2, "c": 4}
     assert update_dict_nested(dict(dict_2), dict_1) == {"a": 1, "b": 2, "c": 4}
-    assert update_dict_nested(dict(dict_1), dict_2, keep_dst_values=True) == update_dict_nested(
-        dict_2, dict_1
+    assert update_dict_nested(dict(dict_1), dict_2, copy_src_dicts=True) == {"a": 2, "b": 2, "c": 4}
+    assert update_dict_nested(dict(dict_2), dict_1, copy_src_dicts=True) == {"a": 1, "b": 2, "c": 4}
+    dict_1_update = update_dict_nested({}, dict_1)
+    assert dict_1_update == dict_1
+    assert dict_1_update is not dict_1
+    dict_1_update = clone_dict_nested(dict_1)
+    assert dict_1_update == dict_1
+    assert dict_1_update is not dict_1
+
+    dict_1_deep = {"a": 3, "b": dict_1}
+    dict_1_deep_clone = update_dict_nested({}, dict_1_deep)
+    assert dict_1_deep_clone == dict_1_deep
+    # reference got copied
+    assert dict_1_deep_clone["b"] is dict_1
+    # update with copy
+    dict_1_deep_clone = clone_dict_nested(dict_1_deep)
+    assert dict_1_deep_clone == dict_1_deep
+    # nested dict got cloned
+    assert dict_1_deep_clone["b"] is not dict_1
+
+    # make sure that mappings that are not dicts are copied atomically
+    from dlt.common.configuration.specs import ConnectionStringCredentials
+
+    dsn = ConnectionStringCredentials("postgres://loader:loader@localhost:5432/dlt_data")
+    dict_1_mappings: Dict[str, Any] = {
+        "_tuple": (1, 2),
+        "_config": {"key": "str", "_dsn": dsn, "_dict": dict_1_deep},
+    }
+    # make a clone
+    dict_1_mappings_clone = clone_dict_nested(dict_1_mappings)
+    # values are the same
+    assert dict_1_mappings == dict_1_mappings_clone
+    # all objects and non-dict mappings are copied by reference
+    assert dict_1_mappings["_tuple"] is dict_1_mappings_clone["_tuple"]
+    assert dict_1_mappings["_config"]["_dsn"] is dict_1_mappings_clone["_config"]["_dsn"]
+    # dicts are copied by value
+    assert dict_1_mappings["_config"] is not dict_1_mappings_clone["_config"]
+    assert dict_1_mappings["_config"]["_dict"] is not dict_1_mappings_clone["_config"]["_dict"]
+    assert (
+        dict_1_mappings["_config"]["_dict"]["b"]
+        is not dict_1_mappings_clone["_config"]["_dict"]["b"]
+    )
+
+    # make a copy using references
+    dict_1_mappings_clone = update_dict_nested({}, dict_1_mappings)
+    assert dict_1_mappings["_config"] is dict_1_mappings_clone["_config"]
+    assert dict_1_mappings["_config"]["_dict"] is dict_1_mappings_clone["_config"]["_dict"]
+    assert (
+        dict_1_mappings["_config"]["_dict"]["b"] is dict_1_mappings_clone["_config"]["_dict"]["b"]
+    )
+
+    # replace a few keys
+    print(dict_1_mappings)
+    # this should be non-destructive for the dst
+    deep_clone_dict_1_mappings = deepcopy(dict_1_mappings)
+    mappings_update = update_dict_nested(
+        dict_1_mappings, {"_config": {"_dsn": ConnectionStringCredentials(), "_dict": {"a": "X"}}}
+    )
+    # assert deep_clone_dict_1_mappings == dict_1_mappings
+    # things overwritten
+    assert dict_1_mappings["_config"]["_dsn"] is mappings_update["_config"]["_dsn"]
+    # this one is empty
+    assert mappings_update["_config"]["_dsn"].username is None
+    assert dict_1_mappings["_config"]["_dsn"].username is None
+    assert 
mappings_update["_config"]["_dict"]["a"] == "X" + assert dict_1_mappings["_config"]["_dict"]["a"] == "X" + + # restore original values + mappings_update = update_dict_nested( + mappings_update, {"_config": {"_dsn": dsn, "_dict": {"a": 3}}} ) + assert mappings_update == deep_clone_dict_1_mappings From 02780a61bc39524dad61421fb7d4c77310f4a341 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Jun 2024 22:48:33 +0200 Subject: [PATCH 09/15] adds several missing config injects tests --- tests/common/configuration/test_inject.py | 250 +++++++++++++++++++--- 1 file changed, 223 insertions(+), 27 deletions(-) diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index 1aa52c1919..f0494e9898 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -6,7 +6,10 @@ from dlt.common.configuration.exceptions import ConfigFieldMissingException from dlt.common.configuration.inject import ( + _LAST_DLT_CONFIG, + _ORIGINAL_ARGS, get_fun_spec, + get_orig_args, last_config, with_config, create_resolved_partial, @@ -20,11 +23,16 @@ GcpServiceAccountCredentialsWithoutDefaults, ConnectionStringCredentials, ) -from dlt.common.configuration.specs.base_configuration import configspec, is_secret_hint +from dlt.common.configuration.specs.base_configuration import ( + CredentialsConfiguration, + configspec, + is_secret_hint, + is_valid_configspec_field, +) from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.reflection.spec import _get_spec_name_from_f -from dlt.common.typing import StrAny, TSecretValue, is_newtype_type +from dlt.common.typing import StrAny, TSecretStrValue, TSecretValue, is_newtype_type from tests.utils import preserve_environ from tests.common.configuration.utils import environment, toml_providers @@ -47,8 +55,28 @@ def f_var_env(user=dlt.config.value, path=dlt.config.value): assert path == "explicit path" # user will be injected - f_var_env(None, path="explicit path") - f_var_env(path="explicit path", user=None) + f_var_env(dlt.config.value, path="explicit path") + f_var_env(path="explicit path", user=dlt.secrets.value) + + # none will be passed and trigger config missing + with pytest.raises(ConfigFieldMissingException) as cfg_ex: + f_var_env(None, path="explicit path") + assert "user" in cfg_ex.value.traces + assert cfg_ex.value.traces["user"][0].provider == "ExplicitValues" + + +def test_explicit_none(environment: Any) -> None: + @with_config + def f_var(user: Optional[str] = "default"): + return user + + assert f_var(None) is None + assert f_var() == "default" + assert f_var(dlt.config.value) == "default" + environment["USER"] = "env user" + assert f_var() == "env user" + assert f_var(None) is None + assert f_var(dlt.config.value) == "env user" def test_default_values_are_resolved(environment: Any) -> None: @@ -83,11 +111,64 @@ def f_secret(password=dlt.secrets.value): environment["USER"] = "user" assert f_config() == "user" - assert f_config(None) == "user" + assert f_config(dlt.config.value) == "user" environment["PASSWORD"] = "password" assert f_secret() == "password" - assert f_secret(None) == "password" + assert f_secret(dlt.secrets.value) == "password" + + +def test_dlt_literals_in_spec() -> None: + @configspec + class LiteralsConfiguration(BaseConfiguration): + required_str: str = dlt.config.value + required_int: int = dlt.config.value + 
required_secret: TSecretStrValue = dlt.secrets.value + credentials: CredentialsConfiguration = dlt.secrets.value + optional_default: float = 1.2 + + fields = { + k: f.default + for k, f in LiteralsConfiguration.__dataclass_fields__.items() + if is_valid_configspec_field(f) + } + # make sure all special values are evaluated to None which indicate required params + assert fields == { + "required_str": None, + "required_int": None, + "required_secret": None, + "credentials": None, + "optional_default": 1.2, + } + c = LiteralsConfiguration() + assert dict(c) == fields + + # instantiate to make sure linter does not complain + c = LiteralsConfiguration("R", 0, TSecretStrValue("A"), ConnectionStringCredentials()) + assert dict(c) == { + "required_str": "R", + "required_int": 0, + "required_secret": TSecretStrValue("A"), + "credentials": ConnectionStringCredentials(), + "optional_default": 1.2, + } + + # this generates warnings + @configspec + class WrongLiteralsConfiguration(BaseConfiguration): + required_int: int = dlt.secrets.value + required_secret: TSecretStrValue = dlt.config.value + credentials: CredentialsConfiguration = dlt.config.value + + +def test_dlt_literals_defaults_none() -> None: + @with_config + def with_optional_none( + level: Optional[int] = dlt.config.value, aux: Optional[str] = dlt.secrets.value + ): + return (level, aux) + + assert with_optional_none() == (None, None) def test_inject_from_argument_section(toml_providers: ConfigProvidersContext) -> None: @@ -104,12 +185,14 @@ def f_credentials(gcp_storage: GcpServiceAccountCredentialsWithoutDefaults = dlt def test_inject_secret_value_secret_type(environment: Any) -> None: @with_config def f_custom_secret_type( - _dict: Dict[str, Any] = dlt.secrets.value, _int: int = dlt.secrets.value, **kwargs: Any + _dict: Dict[str, Any] = dlt.secrets.value, + _int: int = dlt.secrets.value, + **injection_kwargs: Any, ): # secret values were coerced into types assert _dict == {"a": 1} assert _int == 1234 - cfg = last_config(**kwargs) + cfg = last_config(**injection_kwargs) spec: Type[BaseConfiguration] = cfg.__class__ # assert that types are secret for f in ["_dict", "_int"]: @@ -124,6 +207,22 @@ def f_custom_secret_type( f_custom_secret_type() +def test_aux_not_injected_into_kwargs() -> None: + # only kwargs with name injection_kwargs receive aux info + + @configspec + class AuxTest(BaseConfiguration): + aux: str = "INFO" + + @with_config(spec=AuxTest) + def f_no_aux(**kwargs: Any): + assert "aux" not in kwargs + assert _LAST_DLT_CONFIG not in kwargs + assert _ORIGINAL_ARGS not in kwargs + + f_no_aux() + + @pytest.mark.skip("not implemented") def test_inject_with_non_injectable_param() -> None: # one of parameters in signature has not valid hint and is skipped (ie. 
from_pipe) @@ -356,15 +455,50 @@ def test_inject_with_str_sections() -> None: pass -@pytest.mark.skip("not implemented") -def test_inject_with_func_section() -> None: +def test_inject_with_func_section(environment: Any) -> None: # function to get sections from the arguments is provided - pass + @with_config(sections=lambda args: "dlt_" + args["name"]) # type: ignore[call-overload] + def table_info(name, password=dlt.secrets.value): + return password + + environment["DLT_USERS__PASSWORD"] = "pass" + assert table_info("users") == "pass" + + @with_config(sections=lambda args: ("dlt", args["name"])) # type: ignore[call-overload] + def table_info_2(name, password=dlt.secrets.value): + return password + + environment["DLT__CONTACTS__PASSWORD"] = "pass_x" + assert table_info_2("contacts") == "pass_x" -@pytest.mark.skip("not implemented") -def test_inject_on_class_and_methods() -> None: - pass + +def test_inject_on_class_and_methods(environment: Any) -> None: + environment["AUX"] = "DEBUG" + environment["LEVEL"] = "1" + + class AuxCallReceiver: + @with_config + def __call__(self, level: int = dlt.config.value, aux: str = dlt.config.value) -> Any: + return (level, aux) + + assert AuxCallReceiver()() == (1, "DEBUG") + + class AuxReceiver: + @with_config + def __init__(self, level: int = dlt.config.value, aux: str = dlt.config.value) -> None: + self.level = level + self.aux = aux + + @with_config + def resolve(self, level: int = dlt.config.value, aux: str = dlt.config.value) -> Any: + return (level, aux) + + kl_ = AuxReceiver() + assert kl_.level == 1 + assert kl_.aux == "DEBUG" + + assert kl_.resolve() == (1, "DEBUG") @pytest.mark.skip("not implemented") @@ -374,34 +508,96 @@ def test_set_defaults_for_positional_args() -> None: pass -@pytest.mark.skip("not implemented") def test_inject_spec_remainder_in_kwargs() -> None: # if the wrapped func contains kwargs then all the fields from spec without matching func args must be injected in kwargs - pass + @configspec + class AuxTest(BaseConfiguration): + level: int = None + aux: str = "INFO" + + @with_config(spec=AuxTest) + def f_aux(level, **injection_kwargs: Any): + # level is in args so not added to kwargs + assert level == 1 + assert "level" not in injection_kwargs + # remainder in kwargs + assert injection_kwargs["aux"] == "INFO" + # assert _LAST_DLT_CONFIG not in kwargs + # assert _ORIGINAL_ARGS not in kwargs + + f_aux(1) -@pytest.mark.skip("not implemented") def test_inject_spec_in_kwargs() -> None: - # the resolved spec is injected in kwargs - pass + @configspec + class AuxTest(BaseConfiguration): + aux: str = "INFO" + @with_config(spec=AuxTest) + def f_kw_spec(**injection_kwargs: Any): + c = last_config(**injection_kwargs) + assert c.aux == "INFO" + # no args, no kwargs + assert get_orig_args(**injection_kwargs) == ((), {}) -@pytest.mark.skip("not implemented") -def test_resolved_spec_in_kwargs_pass_through() -> None: + f_kw_spec() + + +def test_resolved_spec_in_kwargs_pass_through(environment: Any) -> None: # if last_config is in kwargs then use it and do not resolve it anew - pass + @configspec + class AuxTest(BaseConfiguration): + aux: str = "INFO" + + @with_config(spec=AuxTest) + def init_cf(aux: str = dlt.config.value, **injection_kwargs: Any): + assert aux == "DEBUG" + return last_config(**injection_kwargs) + + environment["AUX"] = "DEBUG" + c = init_cf() + + @with_config(spec=AuxTest) + def get_cf(aux: str = dlt.config.value, last_config: AuxTest = None): + assert aux == "DEBUG" + assert last_config.aux == "DEBUG" + return last_config + + 
# this will be ignored, last_config is regarded as resolved + environment["AUX"] = "ERROR" + assert get_cf(last_config=c) is c -@pytest.mark.skip("not implemented") def test_inject_spec_into_argument_with_spec_type() -> None: # if signature contains argument with type of SPEC, it gets injected there - pass + from dlt.destinations.impl.dummy import _configure, DummyClientConfiguration + # _configure has argument of type DummyClientConfiguration that it returns + # this type holds resolved configuration + c = _configure() + assert isinstance(c, DummyClientConfiguration) -@pytest.mark.skip("not implemented") -def test_initial_spec_from_arg_with_spec_type() -> None: + +def test_initial_spec_from_arg_with_spec_type(environment: Any) -> None: # if signature contains argument with type of SPEC, get its value to init SPEC (instead of calling the constructor()) - pass + @configspec + class AuxTest(BaseConfiguration): + level: int = None + aux: str = "INFO" + + @with_config(spec=AuxTest) + def init_cf( + level: int = dlt.config.value, aux: str = dlt.config.value, init_cf: AuxTest = None + ): + assert level == -1 + assert aux == "DEBUG" + # init_cf was used as init but also got resolved + assert init_cf.aux == "DEBUG" + return init_cf + + init_c = AuxTest(level=-1) + environment["AUX"] = "DEBUG" + assert init_cf(init_cf=init_c) is init_c def test_use_most_specific_union_type( From 1ce128f3e46ab52794d62f2de555e48b1f07bcf0 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Jun 2024 22:50:48 +0200 Subject: [PATCH 10/15] gives precedence to apply_hints when setting incremental, fixes resolve and merge configs, detect EMPTY incremental on resolve --- dlt/extract/incremental/__init__.py | 61 +++++--- tests/extract/test_incremental.py | 222 +++++++++++++++++++++++----- 2 files changed, 226 insertions(+), 57 deletions(-) diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index b6ecd2d3db..c30466e9bd 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -6,8 +6,6 @@ import inspect from functools import wraps - -import dlt from dlt.common import logger from dlt.common.exceptions import MissingDependencyException from dlt.common.pendulum import pendulum @@ -111,7 +109,7 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa def __init__( self, - cursor_path: str = dlt.config.value, + cursor_path: str = None, initial_value: Optional[TCursorValue] = None, last_value_func: Optional[LastValueFunc[TCursorValue]] = max, primary_key: Optional[TTableHintTemplate[TColumnNames]] = None, @@ -254,14 +252,22 @@ def on_resolved(self) -> None: def parse_native_representation(self, native_value: Any) -> None: if isinstance(native_value, Incremental): - self.cursor_path = native_value.cursor_path - self.initial_value = native_value.initial_value - self.last_value_func = native_value.last_value_func - self.end_value = native_value.end_value - self.resource_name = native_value.resource_name - self._primary_key = native_value._primary_key - self.allow_external_schedulers = native_value.allow_external_schedulers - self.row_order = native_value.row_order + if self is self.EMPTY: + raise ValueError("Trying to resolve EMPTY Incremental") + if native_value is self.EMPTY: + raise ValueError( + "Do not use EMPTY Incremental as default or explicit values. Pass None to reset" + " an incremental." 
+ ) + merged = self.merge(native_value) + self.cursor_path = merged.cursor_path + self.initial_value = merged.initial_value + self.last_value_func = merged.last_value_func + self.end_value = merged.end_value + self.resource_name = merged.resource_name + self._primary_key = merged._primary_key + self.allow_external_schedulers = merged.allow_external_schedulers + self.row_order = merged.row_order else: # TODO: Maybe check if callable(getattr(native_value, '__lt__', None)) # Passing bare value `incremental=44` gets parsed as initial_value self.initial_value = native_value @@ -440,7 +446,7 @@ def can_close(self) -> bool: def __str__(self) -> str: return ( - f"Incremental at {id(self)} for resource {self.resource_name} with cursor path:" + f"Incremental at 0x{id(self):x} for resource {self.resource_name} with cursor path:" f" {self.cursor_path} initial {self.initial_value} - {self.end_value} lv_func" f" {self.last_value_func}" ) @@ -490,6 +496,8 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]: class IncrementalResourceWrapper(ItemTransform[TDataItem]): _incremental: Optional[Incremental[Any]] = None """Keeps the injectable incremental""" + _from_hints: bool = False + """If True, incremental was set explicitly from_hints""" _resource_name: str = None def __init__(self, primary_key: Optional[TTableHintTemplate[TColumnNames]] = None) -> None: @@ -539,8 +547,10 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: if p.name in bound_args.arguments: explicit_value = bound_args.arguments[p.name] if explicit_value is Incremental.EMPTY or p.default is Incremental.EMPTY: - # drop incremental - pass + raise ValueError( + "Do not use EMPTY Incremental as default or explicit values. Pass None to" + " reset an incremental." + ) elif isinstance(explicit_value, Incremental): # Explicit Incremental instance is merged with default # allowing e.g. to only update initial_value/end_value but keeping default cursor_path @@ -573,14 +583,9 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: new_incremental.__orig_class__ = p.annotation # type: ignore # set the incremental only if not yet set or if it was passed explicitly - # NOTE: if new incremental is resolved, it was passed via config injection # NOTE: the _incremental may be also set by applying hints to the resource see `set_template` in `DltResource` - if ( - new_incremental - and p.name in bound_args.arguments - and not new_incremental.is_resolved() - ) or not self._incremental: - self._incremental = new_incremental + if (new_incremental and p.name in bound_args.arguments) or not self._incremental: + self.set_incremental(new_incremental) if not self._incremental.is_resolved(): self._incremental.resolve() # in case of transformers the bind will be called before this wrapper is set: because transformer is called for a first time late in the pipe @@ -593,6 +598,20 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: return _wrap # type: ignore + @property + def incremental(self) -> Optional[Incremental[Any]]: + return self._incremental + + def set_incremental( + self, incremental: Optional[Incremental[Any]], from_hints: bool = False + ) -> None: + """Sets the incremental. 
If incremental was set from_hints, it can only be changed in the same manner""" + if self._from_hints and not from_hints: + # do not accept incremental if apply hints were used + return + self._from_hints = from_hints + self._incremental = incremental + @property def allow_external_schedulers(self) -> bool: """Allows the Incremental instance to get its initial and end values from external schedulers like Airflow""" diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index 7fb9c39194..69091c2f28 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -12,6 +12,7 @@ import dlt from dlt.common.configuration.container import Container +from dlt.common.configuration.exceptions import InvalidNativeValue from dlt.common.configuration.specs.base_configuration import configspec, BaseConfiguration from dlt.common.configuration import ConfigurationValueError from dlt.common.pendulum import pendulum, timedelta @@ -332,7 +333,7 @@ def test_optional_arg_from_spec_not_passed(item_type: TestDataItemFormat) -> Non @configspec class SomeDataOverrideConfiguration(BaseConfiguration): - created_at: dlt.sources.incremental = dlt.sources.incremental("created_at", initial_value="2022-02-03T00:00:00Z") # type: ignore[type-arg] + created_at: dlt.sources.incremental = dlt.sources.incremental("updated_at", initial_value="2022-02-03T00:00:00Z") # type: ignore[type-arg] # provide what to inject via spec. the spec contain the default @@ -351,6 +352,7 @@ def some_data_override_config( def test_override_initial_value_from_config(item_type: TestDataItemFormat) -> None: # use the shortest possible config version # os.environ['SOURCES__TEST_INCREMENTAL__SOME_DATA_OVERRIDE_CONFIG__CREATED_AT__INITIAL_VALUE'] = '2000-02-03T00:00:00Z' + os.environ["CREATED_AT__CURSOR_PATH"] = "created_at" os.environ["CREATED_AT__INITIAL_VALUE"] = "2000-02-03T00:00:00Z" p = dlt.pipeline(pipeline_name=uniq_id()) @@ -598,7 +600,7 @@ def some_data(last_timestamp=dlt.sources.incremental("item.timestamp|modifiedAt" def test_remove_incremental_with_explicit_none() -> None: @dlt.resource(standalone=True) def some_data( - last_timestamp: dlt.sources.incremental[float] = dlt.sources.incremental( + last_timestamp: Optional[dlt.sources.incremental[float]] = dlt.sources.incremental( "id", initial_value=9 ), ): @@ -623,9 +625,9 @@ def some_data_optional( assert last_timestamp is None yield 1 - # we disable incremental by typing the argument as optional - # if not disabled it would fail on "item.timestamp" not found - assert list(some_data_optional(last_timestamp=dlt.sources.incremental.EMPTY)) == [1] + # can't use EMPTY to reset incremental + with pytest.raises(ValueError): + list(some_data_optional(last_timestamp=dlt.sources.incremental.EMPTY)) @dlt.resource(standalone=True) def some_data( @@ -635,8 +637,8 @@ def some_data( yield 1 # we'll get the value error - with pytest.raises(ValueError): - assert list(some_data(last_timestamp=dlt.sources.incremental.EMPTY)) == [1] + with pytest.raises(InvalidNativeValue): + list(some_data(last_timestamp=dlt.sources.incremental.EMPTY)) @pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) @@ -860,20 +862,23 @@ def some_data(last_timestamp=dlt.sources.incremental("ts", primary_key=())): @pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) def test_apply_hints_incremental(item_type: TestDataItemFormat) -> None: - p = dlt.pipeline(pipeline_name=uniq_id()) + p = dlt.pipeline(pipeline_name=uniq_id(), destination="dummy") data = 
[{"created_at": 1}, {"created_at": 2}, {"created_at": 3}] source_items = data_to_item_format(item_type, data) @dlt.resource def some_data(created_at: Optional[dlt.sources.incremental[int]] = None): # make sure that incremental from apply_hints is here - assert created_at is not None - assert created_at.last_value_func is max + if created_at is not None: + assert created_at.cursor_path == "created_at" + assert created_at.last_value_func is max yield source_items # the incremental wrapper is created for a resource and the incremental value is provided via apply hints r = some_data() assert r is not some_data + assert r.incremental is not None + assert r.incremental.incremental is None r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=max)) if item_type == "pandas": assert list(r)[0].equals(source_items[0]) @@ -881,6 +886,9 @@ def some_data(created_at: Optional[dlt.sources.incremental[int]] = None): assert list(r) == source_items p.extract(r) assert "incremental" in r.state + assert r.incremental.incremental is not None + assert len(r._pipe) == 2 + # no more elements assert list(r) == [] # same thing with explicit None @@ -894,11 +902,20 @@ def some_data(created_at: Optional[dlt.sources.incremental[int]] = None): assert "incremental" in r.state assert list(r) == [] + # remove incremental + r.apply_hints(incremental=dlt.sources.incremental.EMPTY) + assert r.incremental is not None + assert r.incremental.incremental is None + if item_type == "pandas": + assert list(r)[0].equals(source_items[0]) + else: + assert list(r) == source_items + # as above but we provide explicit incremental when creating resource p = p.drop() - r = some_data(created_at=dlt.sources.incremental("created_at", last_value_func=max)) - # explicit has precedence here and hints will be ignored - r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=min)) + r = some_data(created_at=dlt.sources.incremental("created_at", last_value_func=min)) + # hints have precedence, as expected + r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=max)) p.extract(r) assert "incremental" in r.state # max value @@ -927,10 +944,128 @@ def some_data_no_incremental(): # we add incremental as a step p = p.drop() r = some_data_no_incremental() - r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=max)) - assert r.incremental is not None + print(r._pipe) + incr_instance = dlt.sources.incremental("created_at", last_value_func=max) + r.apply_hints(incremental=incr_instance) + print(r._pipe) + assert r.incremental is incr_instance p.extract(r) assert "incremental" in r.state + info = p.normalize() + assert info.row_counts["some_data_no_incremental"] == 3 + # make sure we can override incremental + incr_instance = dlt.sources.incremental("created_at", last_value_func=max, row_order="desc") + r.apply_hints(incremental=incr_instance) + assert r.incremental is incr_instance + p.extract(r) + info = p.normalize() + assert "some_data_no_incremental" not in info.row_counts + # we switch last value func to min + incr_instance = dlt.sources.incremental( + "created_at", last_value_func=min, row_order="desc", primary_key=() + ) + r.apply_hints(incremental=incr_instance) + assert r.incremental is incr_instance + p.extract(r) + info = p.normalize() + # we have three elements due to min function (equal element NOT is eliminated due to primary_key==()) + assert info.row_counts["some_data_no_incremental"] == 3 + + # remove incremental + 
r.apply_hints(incremental=dlt.sources.incremental.EMPTY) + assert r.incremental is None + + +def test_incremental_wrapper_on_clone_standalone_incremental() -> None: + @dlt.resource(standalone=True) + def standalone_incremental(created_at: Optional[dlt.sources.incremental[int]] = None): + yield [{"created_at": 1}, {"created_at": 2}, {"created_at": 3}] + + s_r_1 = standalone_incremental() + s_r_i_1 = dlt.sources.incremental[int]("created_at") + s_r_2 = standalone_incremental() + s_r_i_2 = dlt.sources.incremental[int]("created_at", initial_value=3) + s_r_i_3 = dlt.sources.incremental[int]("created_at", initial_value=1, last_value_func=min) + s_r_3 = standalone_incremental(created_at=s_r_i_3) + + # different wrappers + assert s_r_1.incremental is not s_r_2.incremental + s_r_1.apply_hints(incremental=s_r_i_1) + s_r_2.apply_hints(incremental=s_r_i_2) + assert s_r_1.incremental.incremental is s_r_i_1 + assert s_r_2.incremental.incremental is s_r_i_2 + + # evaluate s r 3 + assert list(s_r_3) == [{"created_at": 1}] + # incremental is set after evaluation but the instance is different (wrapper is merging instances) + assert s_r_3.incremental.incremental.last_value_func is min + + # standalone resources are bound so clone does not re-wrap + s_r_3_clone = s_r_3._clone() + assert s_r_3_clone.incremental is s_r_3.incremental + assert s_r_3_clone.incremental.incremental is s_r_3.incremental.incremental + + # evaluate others + assert len(list(s_r_1)) == 3 + assert len(list(s_r_2)) == 1 + + +def test_incremental_wrapper_on_clone_standalone_no_incremental() -> None: + @dlt.resource(standalone=True) + def standalone(): + yield [{"created_at": 1}, {"created_at": 2}, {"created_at": 3}] + + s_r_1 = standalone() + s_r_i_1 = dlt.sources.incremental[int]("created_at", row_order="desc") + s_r_2 = standalone() + s_r_i_2 = dlt.sources.incremental[int]("created_at", initial_value=3) + + # clone keeps the incremental step + s_r_1.apply_hints(incremental=s_r_i_1) + assert s_r_1.incremental is s_r_i_1 + + s_r_1_clone = s_r_1._clone() + assert s_r_1_clone.incremental is s_r_i_1 + + assert len(list(s_r_1)) == 3 + s_r_2.apply_hints(incremental=s_r_i_2) + assert len(list(s_r_2)) == 1 + + +def test_incremental_wrapper_on_clone_incremental() -> None: + @dlt.resource + def regular_incremental(created_at: Optional[dlt.sources.incremental[int]] = None): + yield [{"created_at": 1}, {"created_at": 2}, {"created_at": 3}] + + assert regular_incremental.incremental is not None + assert regular_incremental.incremental.incremental is None + + # separate incremental + r_1 = regular_incremental() + assert r_1.args_bound is True + r_2 = regular_incremental.with_name("cloned_regular") + assert r_1.incremental is not None + assert r_2.incremental is not None + assert r_1.incremental is not r_2.incremental is not regular_incremental.incremental + + # evaluate and compare incrementals + assert len(list(r_1)) == 3 + assert len(list(r_2)) == 3 + assert r_1.incremental.incremental is None + assert r_2.incremental.incremental is None + + # now bind some real incrementals + r_3 = regular_incremental(dlt.sources.incremental[int]("created_at", initial_value=3)) + r_4 = regular_incremental( + dlt.sources.incremental[int]("created_at", initial_value=1, last_value_func=min) + ) + r_4_clone = r_4._clone("r_4_clone") + # evaluate + assert len(list(r_3)) == 1 + assert len(list(r_4)) == 1 + assert r_3.incremental.incremental is not r_4.incremental.incremental + # now the clone should share the incremental because it was done after parameters were bound + 
assert r_4_clone.incremental is r_4.incremental def test_last_value_func_on_dict() -> None: @@ -991,6 +1126,7 @@ def some_data( max_hours: int = 2, tz: str = None, ): + print("some_data", updated_at, dict(updated_at)) data = [ {"updated_at": start_dt + timedelta(hours=hour), "hour": hour} for hour in range(1, max_hours + 1) @@ -1034,6 +1170,8 @@ def some_data( "updated_at", initial_value=pendulum_start_dt, end_value=pendulum_start_dt.add(hours=3) ) ) + print(resource.incremental.incremental, dict(resource.incremental.incremental)) + pipeline = pipeline.drop() extract_info = pipeline.extract(resource) assert ( extract_info.metrics[extract_info.loads_ids[0]][0]["resource_metrics"][ @@ -1621,7 +1759,7 @@ def test_type( r = test_type() list(r) - assert r.incremental._incremental.get_incremental_value_type() is str + assert r.incremental.incremental.get_incremental_value_type() is str # use annotation @dlt.resource @@ -1635,7 +1773,7 @@ def test_type_2( r = test_type_2() list(r) - assert r.incremental._incremental.get_incremental_value_type() is int + assert r.incremental.incremental.get_incremental_value_type() is int # pass in explicit value @dlt.resource @@ -1645,7 +1783,7 @@ def test_type_3(updated_at: dlt.sources.incremental[int]): r = test_type_3(dlt.sources.incremental[float]("updated_at", allow_external_schedulers=True)) list(r) - assert r.incremental._incremental.get_incremental_value_type() is float + assert r.incremental.incremental.get_incremental_value_type() is float # pass explicit value overriding default that is typed @dlt.resource @@ -1657,7 +1795,7 @@ def test_type_4( r = test_type_4(dlt.sources.incremental[str]("updated_at", allow_external_schedulers=True)) list(r) - assert r.incremental._incremental.get_incremental_value_type() is str + assert r.incremental.incremental.get_incremental_value_type() is str # no generic type information @dlt.resource @@ -1669,7 +1807,7 @@ def test_type_5( r = test_type_5(dlt.sources.incremental("updated_at")) list(r) - assert r.incremental._incremental.get_incremental_value_type() is Any + assert r.incremental.incremental.get_incremental_value_type() is Any @pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) @@ -1743,27 +1881,39 @@ def test_type_2( # does not participate os.environ["DLT_START_VALUE"] = "2" - result = data_item_to_list(item_type, list(test_type_2())) - assert len(result) == 3 - - assert test_type_2.incremental.allow_external_schedulers is False - assert test_type_2().incremental.allow_external_schedulers is False - - # allow scheduler in wrapper - r = test_type_2() - r.incremental.allow_external_schedulers = True - result = data_item_to_list(item_type, list(test_type_2())) - assert len(result) == 2 - assert r.incremental.allow_external_schedulers is True - assert r.incremental._incremental.allow_external_schedulers is True + # r = test_type_2() + # result = data_item_to_list(item_type, list(r)) + # assert len(result) == 3 + + # # incremental not bound to the wrapper + # assert test_type_2.incremental.allow_external_schedulers is None + # assert test_type_2().incremental.allow_external_schedulers is None + # # this one is bound + # assert r.incremental.allow_external_schedulers is False + + # # allow scheduler in wrapper + # r = test_type_2() + # r.incremental.allow_external_schedulers = True + # result = data_item_to_list(item_type, list(r)) + # assert len(result) == 2 + # assert r.incremental.allow_external_schedulers is True + # assert r.incremental.incremental.allow_external_schedulers is True # add incremental 
dynamically @dlt.resource() def test_type_3(): - yield [{"updated_at": d} for d in [1, 2, 3]] + data = [{"updated_at": d} for d in [1, 2, 3]] + yield data_to_item_format(item_type, data) r = test_type_3() - r.add_step(dlt.sources.incremental("updated_at")) + r.add_step(dlt.sources.incremental[int]("updated_at")) r.incremental.allow_external_schedulers = True - result = data_item_to_list(item_type, list(test_type_2())) + result = data_item_to_list(item_type, list(r)) assert len(result) == 2 + + # if type of incremental cannot be inferred, external scheduler will be ignored + r = test_type_3() + r.add_step(dlt.sources.incremental("updated_at")) + r.incremental.allow_external_schedulers = True + result = data_item_to_list(item_type, list(r)) + assert len(result) == 3 From 16c1ec100daa117175c13c116001a8b507dcccb9 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Jun 2024 22:52:39 +0200 Subject: [PATCH 11/15] moves wrapping resources in config and incremental wrappers from decorator to resources, rewraps resource on clone to separate sections and incremental instances --- dlt/extract/decorators.py | 108 +++++++++---------------- dlt/extract/resource.py | 132 ++++++++++++++++++++++++++----- tests/extract/test_decorators.py | 45 +++++++++-- tests/extract/test_sources.py | 10 +-- 4 files changed, 195 insertions(+), 100 deletions(-) diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 9c4076cfa7..a7246b6832 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -23,11 +23,13 @@ from dlt.common.configuration import with_config, get_fun_spec, known_sections, configspec from dlt.common.configuration.container import Container from dlt.common.configuration.exceptions import ContextDefaultCannotBeCreated +from dlt.common.configuration.inject import set_fun_spec from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.specs import BaseConfiguration, ContainerInjectableContext from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.exceptions import ArgumentsOverloadException from dlt.common.pipeline import PipelineContext +from dlt.common.reflection.spec import spec_from_signature from dlt.common.schema.utils import DEFAULT_WRITE_DISPOSITION from dlt.common.source import _SOURCES, SourceInfo from dlt.common.schema.schema import Schema @@ -453,9 +455,7 @@ def resource( TDltResourceImpl instance which may be loaded, iterated or combined with other resources into a pipeline. 
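
        Example:
            A minimal usage sketch (assumes `import dlt`; the resource name and data are illustrative, not part of this changeset):

            >>> @dlt.resource(name="users", write_disposition="replace")
            ... def users():
            ...     yield [{"id": 1}, {"id": 2}]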
""" - def make_resource( - _name: str, _section: str, _data: Any, incremental: IncrementalResourceWrapper = None - ) -> TDltResourceImpl: + def make_resource(_name: str, _section: str, _data: Any) -> TDltResourceImpl: table_template = make_hints( table_name, write_disposition=write_disposition or DEFAULT_WRITE_DISPOSITION, @@ -466,14 +466,6 @@ def make_resource( table_format=table_format, ) - # If custom nesting level was specified then - # we need to add it to table hints so that - # later in normalizer dlt/common/normalizers/json/relational.py - # we can override max_nesting level for the given table - if max_table_nesting is not None: - table_template.setdefault("x-normalizer", {}) # type: ignore[typeddict-item] - table_template["x-normalizer"]["max_nesting"] = max_table_nesting # type: ignore[typeddict-item] - resource = _impl_cls.from_data( _data, _name, @@ -481,8 +473,14 @@ def make_resource( table_template, selected, cast(DltResource, data_from), - incremental=incremental, + True, ) + # If custom nesting level was specified then + # we need to add it to table hints so that + # later in normalizer dlt/common/normalizers/json/relational.py + # we can override max_nesting level for the given table + if max_table_nesting is not None: + resource.max_table_nesting = max_table_nesting if parallelized: return resource.parallelize() return resource @@ -502,82 +500,52 @@ def decorator( resource_name = name if name and not callable(name) else get_callable_name(f) - # do not inject config values for inner functions, we assume that they are part of the source - SPEC: Type[BaseConfiguration] = None - # wrap source extraction function in configuration with section func_module = inspect.getmodule(f) source_section = _get_source_section_name(func_module) - - incremental: IncrementalResourceWrapper = None - sig = inspect.signature(f) - if IncrementalResourceWrapper.should_wrap(sig): - incremental = IncrementalResourceWrapper(primary_key) - incr_f = incremental.wrap(sig, f) if incremental else f - - resource_sections = (known_sections.SOURCES, source_section, resource_name) - - # standalone resource will prefer existing section context when resolving config values - # this lets the source to override those values and provide common section for all config values for resources present in that source - # for autogenerated spec do not include defaults - # NOTE: allow full config for standalone, currently some edge cases for incremental does not work - # (removing it via apply hints or explicit call) - conf_f = with_config( - incr_f, - spec=spec, - sections=resource_sections, - sections_merge_style=ConfigSectionContext.resource_merge_style, - include_defaults=spec is not None, # or standalone, - ) is_inner_resource = is_inner_callable(f) - if conf_f != incr_f and is_inner_resource and not standalone: - raise ResourceInnerCallableConfigWrapDisallowed(resource_name, source_section) - # get spec for wrapped function - SPEC = get_fun_spec(conf_f) - # store the standalone resource information + if spec is None: + # autodetect spec + SPEC, resolvable_fields = spec_from_signature( + f, inspect.signature(f), include_defaults=standalone + ) + print(SPEC, resolvable_fields, standalone) + if is_inner_resource and not standalone: + if len(resolvable_fields) > 0: + # prevent required arguments to inner functions that are not standalone + raise ResourceInnerCallableConfigWrapDisallowed(resource_name, source_section) + else: + # empty spec for inner functions - they should not be injected + SPEC = BaseConfiguration + 
else: + SPEC = spec + # assign spec to "f" + set_fun_spec(f, SPEC) + + # store the non-inner resource information if not is_inner_resource: _SOURCES[f.__qualname__] = SourceInfo(SPEC, f, func_module) if not standalone: # we return a DltResource that is callable and returns dlt resource when called # so it should match the signature - return make_resource(resource_name, source_section, conf_f, incremental) # type: ignore[return-value] - - # wrap the standalone resource - if data_from: - compat_wrapper, skip_args = wrap_compat_transformer, 1 - else: - compat_wrapper, skip_args = wrap_resource_gen, 0 + return make_resource(resource_name, source_section, f) # type: ignore[return-value] - @wraps(incr_f) + @wraps(f) def _wrap(*args: Any, **kwargs: Any) -> TDltResourceImpl: - _, mod_sig, bound_args = simulate_func_call(incr_f, skip_args, *args, **kwargs) + skip_args = 1 if data_from else 0 + _, mod_sig, bound_args = simulate_func_call(f, skip_args, *args, **kwargs) actual_resource_name = name(bound_args.arguments) if callable(name) else resource_name - # wrap again with an actual resource name - conf_f = with_config( - incr_f, - spec=SPEC, - sections=resource_sections[:-1] + (actual_resource_name,), - sections_merge_style=ConfigSectionContext.resource_merge_style, - ) - try: - r = make_resource( - actual_resource_name, - source_section, - compat_wrapper(actual_resource_name, conf_f, sig, *args, **kwargs), - incremental, - ) - except InvalidResourceDataTypeFunctionNotAGenerator: + r = make_resource(actual_resource_name, source_section, f) + # wrap the standalone resource + data_ = r._pipe.bind_gen(*args, **kwargs) + if isinstance(data_, DltResource): # we allow an edge case: resource can return another resource - # actually call the function to see if it contains DltResource - data_ = conf_f(*args, **kwargs) - if not isinstance(data_, DltResource): - raise r = data_ # type: ignore[assignment] # consider transformer arguments bound r._args_bound = True # keep explicit args passed - r._set_explicit_args(conf_f, mod_sig, *args, **kwargs) + r._set_explicit_args(f, mod_sig, *args, **kwargs) return r return _wrap diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py index a64d5070b8..eecb570375 100644 --- a/dlt/extract/resource.py +++ b/dlt/extract/resource.py @@ -14,6 +14,7 @@ ) from typing_extensions import TypeVar, Self +from dlt.common.configuration.inject import get_fun_spec, with_config from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.specs import known_sections from dlt.common.configuration.specs.config_section_context import ConfigSectionContext @@ -94,7 +95,6 @@ def __init__( pipe: Pipe, hints: TResourceHints, selected: bool, - incremental: IncrementalResourceWrapper = None, section: str = None, args_bound: bool = False, ) -> None: @@ -103,8 +103,6 @@ def __init__( self._pipe = pipe self._args_bound = args_bound self._explicit_args: DictStrAny = None - if incremental and not self.incremental: - self.add_step(incremental) self.source_name = None super().__init__(hints) @@ -117,8 +115,16 @@ def from_data( hints: TResourceHints = None, selected: bool = True, data_from: Union["DltResource", Pipe] = None, - incremental: IncrementalResourceWrapper = None, + inject_config: bool = False, ) -> Self: + """Creates an instance of DltResource from compatible `data` with a given `name` and `section`. 
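+        A rough usage sketch (the resource name is illustrative):
+
+        >>> numbers = DltResource.from_data([1, 2, 3], name="numbers")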
+
+        Internally (in the most common case) a new instance of Pipe with `name` is created from `data` and
+        optionally connected to an existing pipe `from_data` to form a transformer (dependent resource).
+
+        If `inject_config` is set to True and data is a callable, the callable is wrapped in incremental and config
+        injection wrappers.
+        """
         if data is None:
             raise InvalidResourceDataTypeIsNone(name, data, NoneType)
 
@@ -126,7 +132,10 @@ def from_data(
             return data  # type: ignore[return-value]
 
         if isinstance(data, Pipe):
-            return cls(data, hints, selected, incremental=incremental, section=section)
+            r_ = cls(data, hints, selected, section=section)
+            if inject_config:
+                r_._inject_config()
+            return r_
 
         if callable(data):
             name = name or get_callable_name(data)
@@ -155,14 +164,16 @@ def from_data(
         # create resource from iterator, iterable or generator function
         if isinstance(data, (Iterable, Iterator, AsyncIterable)) or callable(data):
             pipe = Pipe.from_data(name, data, parent=parent_pipe)
-            return cls(
+            r_ = cls(
                 pipe,
                 hints,
                 selected,
-                incremental=incremental,
                 section=section,
                 args_bound=not callable(data),
             )
+            if inject_config:
+                r_._inject_config()
+            return r_
         else:
             # some other data type that is not supported
             raise InvalidResourceDataType(
@@ -226,9 +237,12 @@ def max_table_nesting(self) -> Optional[int]:
         return max_nesting if isinstance(max_nesting, int) else None
 
     @max_table_nesting.setter
-    def max_table_nesting(self, value: int) -> None:
-        self._hints.setdefault("x-normalizer", {})  # type: ignore[typeddict-item]
-        self._hints["x-normalizer"]["max_nesting"] = value  # type: ignore[typeddict-item]
+    def max_table_nesting(self, value: Optional[int]) -> None:
+        normalizer = self._hints.setdefault("x-normalizer", {})  # type: ignore[typeddict-item]
+        if value is None:
+            normalizer.pop("max_nesting", None)
+        else:
+            normalizer["max_nesting"] = value
 
     def pipe_data_from(self: TDltResourceImpl, data_from: Union[TDltResourceImpl, Pipe]) -> None:
         """Replaces the parent in the transformer resource pipe from which the data is piped."""
@@ -420,12 +434,26 @@ def _set_hints(
             incremental = self.incremental
             # try to late assign incremental
             if table_schema_template.get("incremental") is not None:
-                if incremental:
-                    incremental._incremental = table_schema_template["incremental"]
-                else:
+                new_incremental = table_schema_template["incremental"]
+                # remove incremental if empty
+                if new_incremental is Incremental.EMPTY:
+                    new_incremental = None
+
+                if incremental is not None:
+                    if isinstance(incremental, IncrementalResourceWrapper):
+                        # replace in wrapper
+                        incremental.set_incremental(new_incremental, from_hints=True)
+                    else:
+                        step_no = self._pipe.find(Incremental)
+                        self._pipe.remove_step(step_no)
+                        # re-add the step
+                        incremental = None
+
+                if incremental is None:
                     # if there's no wrapper add incremental as a transform
-                    incremental = table_schema_template["incremental"]  # type: ignore
-                    self.add_step(incremental)
+                    incremental = new_incremental  # type: ignore
+                    if new_incremental:
+                        self.add_step(new_incremental)
 
             if incremental:
                 primary_key = table_schema_template.get("primary_key", incremental.primary_key)
@@ -461,6 +489,14 @@ def bind(self: TDltResourceImpl, *args: Any, **kwargs: Any) -> TDltResourceImpl
             self._set_explicit_args(orig_gen, None, *args, **kwargs)  # type: ignore
         return self
 
+    @property
+    def args_bound(self) -> bool:
+        """Returns True if the resource's parameters are bound to values. Such a resource cannot be called again. 
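+        Arguments get bound when the resource is called or via `bind()`.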
+        Note that resources are evaluated lazily and arguments are only formally checked;
+        configuration is not injected at this point either.
+        """
+        return self._args_bound
+
     @property
     def explicit_args(self) -> StrAny:
         """Returns a dictionary of arguments used to parametrize the resource. Does not include defaults and injected args."""
@@ -535,20 +571,80 @@ def _set_explicit_args(
         except Exception:
             pass
 
+    def _eject_config(self) -> bool:
+        """Unwraps the pipe generator step from config injection and incremental wrappers by restoring the original step.
+
+        Also removes the step with the incremental wrapper. Should be used before a subsequent _inject_config is called on the
+        same pipe to successfully wrap it with new incremental and config injection.
+        Note that resources with bound arguments cannot be ejected.
+
+        """
+        if not self._pipe.is_empty and not self._args_bound:
+            orig_gen = getattr(self._pipe.gen, "__GEN__", None)
+            if orig_gen:
+                step_no = self._pipe.find(IncrementalResourceWrapper)
+                if step_no >= 0:
+                    self._pipe.remove_step(step_no)
+                self._pipe.replace_gen(orig_gen)
+                return True
+        return False
+
+    def _inject_config(self) -> "DltResource":
+        """Wraps the pipe's generator step in incremental and config injection wrappers and adds a pipe step with the
+        Incremental transform.
+        """
+        gen = self._pipe.gen
+        if not callable(gen):
+            return self
+
+        incremental: IncrementalResourceWrapper = None
+        sig = inspect.signature(gen)
+        if IncrementalResourceWrapper.should_wrap(sig):
+            incremental = IncrementalResourceWrapper(self._hints.get("primary_key"))
+            incr_f = incremental.wrap(sig, gen)
+            self.add_step(incremental)
+        else:
+            incr_f = gen
+        resource_sections = (known_sections.SOURCES, self.section, self.name)
+        # function should have an associated SPEC
+        spec = get_fun_spec(gen)
+        # standalone resource will prefer existing section context when resolving config values
+        # this lets the source override those values and provide a common section for all config values for resources present in that source
+        # for autogenerated spec do not include defaults
+        conf_f = with_config(
+            incr_f,
+            spec=spec,
+            sections=resource_sections,
+            sections_merge_style=ConfigSectionContext.resource_merge_style,
+        )
+        if conf_f != gen:
+            self._pipe.replace_gen(conf_f)
+            # store the original generator to be able to eject the config and incremental wrapper
+            # when the resource is cloned
+            setattr(conf_f, "__GEN__", gen)  # noqa: B010
+        return self
+
     def _clone(
         self: TDltResourceImpl, new_name: str = None, with_parent: bool = False
     ) -> TDltResourceImpl:
-        """Creates a deep copy of a current resource, optionally renaming the resource. The clone will not be part of the source"""
+        """Creates a deep copy of a current resource, optionally renaming the resource. 
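+        Pass `with_parent=True` to also clone the parent pipe of a transformer.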
The clone will not be part of the source."""
         pipe = self._pipe
         if self._pipe and not self._pipe.is_empty:
             pipe = pipe._clone(new_name=new_name, with_parent=with_parent)
         # incremental and parent are already in the pipe (if any)
-        return self.__class__(
+        r_ = self.__class__(
             pipe,
-            deepcopy(self._hints),
+            self._clone_hints(self._hints),
             selected=self.selected,
             section=self.section,
+            args_bound=self._args_bound,
         )
+        # try to eject and then inject configuration and incremental wrapper when resource is cloned
+        # this makes sure that the clone takes config values from the right section and that the wrapper
+        # has a separate instance in the pipeline
+        if r_._eject_config():
+            r_._inject_config()
+        return r_
 
     def _get_config_section_context(self) -> ConfigSectionContext:
         container = Container()
diff --git a/tests/extract/test_decorators.py b/tests/extract/test_decorators.py
index c6a675a8d3..1cf14abe55 100644
--- a/tests/extract/test_decorators.py
+++ b/tests/extract/test_decorators.py
@@ -57,8 +57,10 @@ def resource():
         "resource": "resource",
         "write_disposition": "append",
     }
-    assert resource().name == "resource"
     assert resource._args_bound is False
+    assert resource.name == "resource"
+    assert resource().args_bound is True
+    assert resource().name == "resource"
     assert resource.incremental is None
     assert resource.write_disposition == "append"
 
@@ -630,6 +632,18 @@ def inner_resource(initial_id=dlt.config.value):
     assert "secret" in fields
     assert "config" in fields
 
+    @dlt.resource(standalone=True)
+    def inner_standalone_resource(
+        secret=dlt.secrets.value, config=dlt.config.value, opt: str = "A"
+    ):
+        yield 1
+
+    SPEC = get_fun_spec(inner_standalone_resource("TS", "CFG")._pipe.gen)  # type: ignore[arg-type]
+    fields = SPEC.get_resolvable_fields()
+    # resources marked as standalone always inject full signature
+    assert len(fields) == 3
+    assert {"secret", "config", "opt"} == set(fields.keys())
+
     @dlt.source
     def inner_source(secret=dlt.secrets.value, config=dlt.config.value, opt: str = "A"):
         return standalone_resource
@@ -734,6 +748,11 @@ def standalone_signature(init: int, secret_end: int = dlt.secrets.value):
     yield from range(init, secret_end)
 
 
+@dlt.resource
+def regular_signature(init: int, secret_end: int = dlt.secrets.value):
+    yield from range(init, secret_end)
+
+
 def test_standalone_resource() -> None:
     # wrapped flag will not create the resource but just simple function wrapper that must be called before use
     @dlt.resource(standalone=True)
     def nice_signature(init: int):
         """Has nice signature"""
         yield [1, 2, 3]
 
     assert nice_signature.__doc__ == """Has nice signature"""
 
     assert list(nice_signature(7)) == [7, 8, 9]
-    assert nice_signature(8)._args_bound is True
+    assert nice_signature(8).args_bound is True
     with pytest.raises(TypeError):
         # bound! 
nice_signature(7)() @@ -800,7 +819,7 @@ def test_standalone_transformer() -> None: bound_tx = standalone_transformer(5, 10) # this is not really true - assert bound_tx._args_bound is True + assert bound_tx.args_bound is True with pytest.raises(TypeError): bound_tx(1) assert isinstance(bound_tx, DltResource) @@ -891,16 +910,28 @@ def rv_resource(uniq_name: str = dlt.config.value): assert conf_ex.value.fields == ["uniq_name"] -def test_resource_rename_credentials_separation(): +def test_standalone_resource_rename_credentials_separation(): os.environ["SOURCES__TEST_DECORATORS__STANDALONE_SIGNATURE__SECRET_END"] = "5" assert list(standalone_signature(1)) == [1, 2, 3, 4] - # config section is not impacted by the rename - # NOTE: probably we should keep it like that - os.environ["SOURCES__TEST_DECORATORS__RENAMED_SIG__SECRET_END"] = "6" + # os.environ["SOURCES__TEST_DECORATORS__RENAMED_SIG__SECRET_END"] = "6" + # assert list(standalone_signature.with_name("renamed_sig")(1)) == [1, 2, 3, 4, 5] + + # bound resource will not allow for reconfig assert list(standalone_signature(1).with_name("renamed_sig")) == [1, 2, 3, 4] +def test_resource_rename_credentials_separation(): + os.environ["SOURCES__TEST_DECORATORS__REGULAR_SIGNATURE__SECRET_END"] = "5" + assert list(regular_signature(1)) == [1, 2, 3, 4] + + os.environ["SOURCES__TEST_DECORATORS__RENAMED_SIG__SECRET_END"] = "6" + assert list(regular_signature.with_name("renamed_sig")(1)) == [1, 2, 3, 4, 5] + + # bound resource will not allow for reconfig + assert list(regular_signature(1).with_name("renamed_sig")) == [1, 2, 3, 4] + + def test_class_source() -> None: class _Source: def __init__(self, elems: int) -> None: diff --git a/tests/extract/test_sources.py b/tests/extract/test_sources.py index 308b65bd37..39e3264aff 100644 --- a/tests/extract/test_sources.py +++ b/tests/extract/test_sources.py @@ -677,8 +677,8 @@ def test_illegal_double_bind() -> None: def _r1(): yield ["a", "b", "c"] - assert _r1._args_bound is False - assert _r1()._args_bound is True + assert _r1.args_bound is False + assert _r1().args_bound is True with pytest.raises(TypeError) as py_ex: _r1()() @@ -689,14 +689,14 @@ def _r1(): assert "Parametrized resource" in str(py_ex.value) bound_r = dlt.resource([1, 2, 3], name="rx") - assert bound_r._args_bound is True + assert bound_r.args_bound is True with pytest.raises(TypeError): _r1() def _gen(): yield from [1, 2, 3] - assert dlt.resource(_gen())._args_bound is True + assert dlt.resource(_gen()).args_bound is True @dlt.resource @@ -1292,7 +1292,7 @@ def empty_gen(): ) assert empty_r._hints == { "columns": {}, - "incremental": None, + "incremental": Incremental.EMPTY, "validator": None, "write_disposition": "append", "original_columns": {}, From 9a0f8f181a27272f4f600ac7d3d608aa4c1f8123 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 1 Jun 2024 22:56:10 +0200 Subject: [PATCH 12/15] fixes mssql examples in sql_database docs --- .../docs/dlt-ecosystem/verified-sources/sql_database.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md index 4a80de1bdf..de3e5f4c35 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md @@ -411,22 +411,22 @@ Here, we use the `mysql` and `pymysql` dialects to set up an SSL connection to a **To connect to an `mssql` server using Windows authentication**, 
include `trusted_connection=yes` in the connection string.
 
 ```toml
-destination.mssql.credentials="mssql+pyodbc://loader.database.windows.net/dlt_data?trusted_connection=yes&driver=ODBC+Driver+17+for+SQL+Server"
+sources.sql_database.credentials="mssql+pyodbc://loader.database.windows.net/dlt_data?trusted_connection=yes&driver=ODBC+Driver+17+for+SQL+Server"
 ```
 
 **To connect to a local SQL Server instance running without SSL**, pass the `encrypt=no` parameter:
 ```toml
-destination.mssql.credentials="mssql+pyodbc://loader:loader@localhost/dlt_data?encrypt=no&driver=ODBC+Driver+17+for+SQL+Server"
+sources.sql_database.credentials="mssql+pyodbc://loader:loader@localhost/dlt_data?encrypt=no&driver=ODBC+Driver+17+for+SQL+Server"
 ```
 
 **To allow a self-signed SSL certificate** when you are getting `certificate verify failed:unable to get local issuer certificate`:
 ```toml
-destination.mssql.credentials="mssql+pyodbc://loader:loader@localhost/dlt_data?TrustServerCertificate=yes&driver=ODBC+Driver+17+for+SQL+Server"
+sources.sql_database.credentials="mssql+pyodbc://loader:loader@localhost/dlt_data?TrustServerCertificate=yes&driver=ODBC+Driver+17+for+SQL+Server"
 ```
 
 **To use long strings (>8k) and avoid collation errors**:
 ```toml
-destination.mssql.credentials="mssql+pyodbc://loader:loader@localhost/dlt_data?LongAsMax=yes&driver=ODBC+Driver+17+for+SQL+Server"
+sources.sql_database.credentials="mssql+pyodbc://loader:loader@localhost/dlt_data?LongAsMax=yes&driver=ODBC+Driver+17+for+SQL+Server"
 ```
 
 ## Customizations

From d6a8fcfd06d6dedf0108798c7047e5239a6711a3 Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Sat, 1 Jun 2024 22:57:03 +0200
Subject: [PATCH 13/15] allows to use None as explicit value when resolving
 config, allows to use sentinels to request injected values

---
 dlt/common/configuration/resolve.py           | 132 ++++++++++--------
 dlt/common/destination/reference.py           |   3 +-
 dlt/common/storages/live_schema_storage.py    |   5 +-
 .../impl/destination/configuration.py         |   2 +-
 dlt/extract/extract.py                        |   8 +-
 dlt/helpers/dbt/__init__.py                   |   6 +-
 dlt/helpers/dbt/configuration.py              |   2 +-
 dlt/helpers/dbt/runner.py                     |  12 +-
 dlt/pipeline/__init__.py                      |  25 ++--
 dlt/pipeline/dbt.py                           |   6 +-
 dlt/pipeline/pipeline.py                      |   6 +-
 .../configuration/test_configuration.py       |  42 +++++-
 .../configuration/test_toml_provider.py       |   4 +-
 tests/common/storages/test_schema_storage.py  |  14 +-
 tests/extract/test_extract.py                 |   4 +-
 15 files changed, 155 insertions(+), 116 deletions(-)

diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py
index 9101cfdd9c..68634881da 100644
--- a/dlt/common/configuration/resolve.py
+++ b/dlt/common/configuration/resolve.py
@@ -5,6 +5,8 @@
 from dlt.common.configuration.providers.provider import ConfigProvider
 from dlt.common.typing import (
     AnyType,
+    ConfigValue,
+    SecretValue,
     StrAny,
     TSecretValue,
     get_all_types_of_class_in_union,
@@ -20,7 +22,7 @@
     is_context_inner_hint,
     is_base_configuration_inner_hint,
     is_valid_hint,
-    is_hint_not_resolved,
+    is_hint_not_resolvable,
 )
 from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
 from dlt.common.configuration.specs.exceptions import NativeValueError
@@ -189,67 +191,81 @@ def _resolve_config_fields(
             hint = config.__hint_resolvers__[key](config)
         # get default and explicit values
         default_value = getattr(config, key, None)
+        explicit_none = False
         traces: List[LookupTrace] = []
 
         if explicit_values:
-            explicit_value = explicit_values.get(key)
+            explicit_value = None
+            if key in explicit_values:
+                # allow None to be 
passed in explicit values + # so we are able to reset defaults like in regular function calls + explicit_value = explicit_values[key] + explicit_none = explicit_value is None + # detect dlt.config and dlt.secrets and force injection + if explicit_value in (ConfigValue, SecretValue): + explicit_value = None else: - if is_hint_not_resolved(hint): + if is_hint_not_resolvable(hint): # for final fields default value is like explicit explicit_value = default_value else: explicit_value = None - # if hint is union of configurations, any of them must be resolved - specs_in_union: List[Type[BaseConfiguration]] = [] current_value = None - if is_union_type(hint): - # if union contains a type of explicit value which is not a valid hint, return it as current value - if ( - explicit_value - and not is_valid_hint(type(explicit_value)) - and get_all_types_of_class_in_union(hint, type(explicit_value)) - ): - current_value, traces = explicit_value, [] - else: - specs_in_union = get_all_types_of_class_in_union(hint, BaseConfiguration) - if not current_value: - if len(specs_in_union) > 1: - for idx, alt_spec in enumerate(specs_in_union): - # return first resolved config from an union - try: - current_value, traces = _resolve_config_field( - key, - alt_spec, - default_value, - explicit_value, - config, - config.__section__, - explicit_sections, - embedded_sections, - accept_partial, - ) - break - except ConfigFieldMissingException as cfm_ex: - # add traces from unresolved union spec - # TODO: we should group traces per hint - currently user will see all options tried without the key info - traces.extend(list(itertools.chain(*cfm_ex.traces.values()))) - except InvalidNativeValue: - # if none of specs in union parsed - if idx == len(specs_in_union) - 1: - raise - else: - current_value, traces = _resolve_config_field( - key, - hint, - default_value, - explicit_value, - config, - config.__section__, - explicit_sections, - embedded_sections, - accept_partial, - ) + # explicit none skips resolution + if not explicit_none: + # if hint is union of configurations, any of them must be resolved + specs_in_union: List[Type[BaseConfiguration]] = [] + if is_union_type(hint): + # if union contains a type of explicit value which is not a valid hint, return it as current value + if ( + explicit_value + and not is_valid_hint(type(explicit_value)) + and get_all_types_of_class_in_union(hint, type(explicit_value)) + ): + current_value, traces = explicit_value, [] + else: + specs_in_union = get_all_types_of_class_in_union(hint, BaseConfiguration) + if not current_value: + if len(specs_in_union) > 1: + for idx, alt_spec in enumerate(specs_in_union): + # return first resolved config from an union + try: + current_value, traces = _resolve_config_field( + key, + alt_spec, + default_value, + explicit_value, + config, + config.__section__, + explicit_sections, + embedded_sections, + accept_partial, + ) + break + except ConfigFieldMissingException as cfm_ex: + # add traces from unresolved union spec + # TODO: we should group traces per hint - currently user will see all options tried without the key info + traces.extend(list(itertools.chain(*cfm_ex.traces.values()))) + except InvalidNativeValue: + # if none of specs in union parsed + if idx == len(specs_in_union) - 1: + raise + else: + current_value, traces = _resolve_config_field( + key, + hint, + default_value, + explicit_value, + config, + config.__section__, + explicit_sections, + embedded_sections, + accept_partial, + ) + else: + # set the trace for explicit none + traces = 
[LookupTrace("ExplicitValues", None, key, None)] # check if hint optional is_optional = is_optional_type(hint) @@ -258,7 +274,7 @@ def _resolve_config_fields( unresolved_fields[key] = traces # set resolved value in config if default_value != current_value: - if not is_hint_not_resolved(hint): + if not is_hint_not_resolvable(hint): # ignore final types setattr(config, key, current_value) @@ -302,15 +318,15 @@ def _resolve_config_field( pass # if inner_hint is BaseConfiguration then resolve it recursively elif is_base_configuration_inner_hint(inner_hint): - if isinstance(value, BaseConfiguration): + if isinstance(default_value, BaseConfiguration): + # if default value was instance of configuration, use it as embedded initial + embedded_config = default_value + default_value = None + elif isinstance(value, BaseConfiguration): # if resolved value is instance of configuration (typically returned by context provider) embedded_config = value default_value = None value = None - elif isinstance(default_value, BaseConfiguration): - # if default value was instance of configuration, use it - embedded_config = default_value - default_value = None else: embedded_config = inner_hint() diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index d4cdfb729d..4919711f58 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -37,7 +37,6 @@ ) from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration -from dlt.common.configuration.accessors import config from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.destination.exceptions import ( IdentifierTooLongException, @@ -624,7 +623,7 @@ def from_reference( return dest def client( - self, schema: Schema, initial_config: TDestinationConfig = config.value + self, schema: Schema, initial_config: TDestinationConfig = None ) -> TDestinationClient: """Returns a configured instance of the destination's job client""" return self.client_class(schema, self.configuration(initial_config)) diff --git a/dlt/common/storages/live_schema_storage.py b/dlt/common/storages/live_schema_storage.py index fb94a21b7a..fd4ecc968e 100644 --- a/dlt/common/storages/live_schema_storage.py +++ b/dlt/common/storages/live_schema_storage.py @@ -1,16 +1,13 @@ from typing import Dict, List, cast from dlt.common.schema.schema import Schema -from dlt.common.configuration.accessors import config from dlt.common.storages.exceptions import SchemaNotFoundError from dlt.common.storages.schema_storage import SchemaStorage from dlt.common.storages.configuration import SchemaStorageConfiguration class LiveSchemaStorage(SchemaStorage): - def __init__( - self, config: SchemaStorageConfiguration = config.value, makedirs: bool = False - ) -> None: + def __init__(self, config: SchemaStorageConfiguration, makedirs: bool = False) -> None: self.live_schemas: Dict[str, Schema] = {} super().__init__(config, makedirs) diff --git a/dlt/destinations/impl/destination/configuration.py b/dlt/destinations/impl/destination/configuration.py index bad7e4e3cc..c3b677058c 100644 --- a/dlt/destinations/impl/destination/configuration.py +++ b/dlt/destinations/impl/destination/configuration.py @@ -22,4 +22,4 @@ class CustomDestinationClientConfiguration(DestinationClientConfiguration): loader_file_format: TLoaderFileFormat = "typed-jsonl" batch_size: int = 10 skip_dlt_columns_and_tables: bool = 
True - max_table_nesting: int = 0 + max_table_nesting: Optional[int] = 0 diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index d4298f2f6b..aa5bbea09c 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -219,7 +219,7 @@ def _compute_metrics(self, load_id: str, source: DltSource) -> ExtractMetrics: if name == "incremental": # represent incremental as dictionary (it derives from BaseConfiguration) if isinstance(hint, IncrementalResourceWrapper): - hint = hint._incremental + hint = hint.incremental # sometimes internal incremental is not bound if hint: hints[name] = dict(hint) # type: ignore[call-overload] @@ -297,9 +297,8 @@ def _extract_single_source( load_id: str, source: DltSource, *, - max_parallel_items: int = None, - workers: int = None, - futures_poll_interval: float = None, + max_parallel_items: int, + workers: int, ) -> None: schema = source.schema collector = self.collector @@ -319,7 +318,6 @@ def _extract_single_source( source.resources.selected_pipes, max_parallel_items=max_parallel_items, workers=workers, - futures_poll_interval=futures_poll_interval, ) as pipes: left_gens = total_gens = len(pipes._sources) collector.update("Resources", 0, total_gens) diff --git a/dlt/helpers/dbt/__init__.py b/dlt/helpers/dbt/__init__.py index 4801dcd6b9..08d6c23ed1 100644 --- a/dlt/helpers/dbt/__init__.py +++ b/dlt/helpers/dbt/__init__.py @@ -6,7 +6,7 @@ from dlt.common.runners import Venv from dlt.common.destination.reference import DestinationClientDwhConfiguration from dlt.common.configuration.specs import CredentialsWithDefault -from dlt.common.typing import TSecretValue +from dlt.common.typing import TSecretValue, ConfigValue from dlt.version import get_installed_requirement_string from dlt.helpers.dbt.runner import create_runner, DBTPackageRunner @@ -84,9 +84,9 @@ def package_runner( destination_configuration: DestinationClientDwhConfiguration, working_dir: str, package_location: str, - package_repository_branch: str = None, + package_repository_branch: str = ConfigValue, package_repository_ssh_key: TSecretValue = TSecretValue(""), # noqa - auto_full_refresh_when_out_of_sync: bool = None, + auto_full_refresh_when_out_of_sync: bool = ConfigValue, ) -> DBTPackageRunner: default_profile_name = _default_profile_name(destination_configuration) return create_runner( diff --git a/dlt/helpers/dbt/configuration.py b/dlt/helpers/dbt/configuration.py index 70fa4d1ac5..bec0bace3c 100644 --- a/dlt/helpers/dbt/configuration.py +++ b/dlt/helpers/dbt/configuration.py @@ -10,7 +10,7 @@ class DBTRunnerConfiguration(BaseConfiguration): package_location: str = None package_repository_branch: Optional[str] = None - package_repository_ssh_key: TSecretValue = TSecretValue( + package_repository_ssh_key: Optional[TSecretValue] = TSecretValue( "" ) # the default is empty value which will disable custom SSH KEY package_profiles_dir: Optional[str] = None diff --git a/dlt/helpers/dbt/runner.py b/dlt/helpers/dbt/runner.py index 7b1f79dc77..c68931d7db 100644 --- a/dlt/helpers/dbt/runner.py +++ b/dlt/helpers/dbt/runner.py @@ -1,7 +1,7 @@ import os from subprocess import CalledProcessError import giturlparse -from typing import Sequence +from typing import Optional, Sequence import dlt from dlt.common import logger @@ -302,11 +302,11 @@ def create_runner( credentials: DestinationClientDwhConfiguration, working_dir: str, package_location: str = dlt.config.value, - package_repository_branch: str = None, - package_repository_ssh_key: TSecretValue = TSecretValue(""), # noqa - package_profiles_dir: 
str = None, - package_profile_name: str = None, - auto_full_refresh_when_out_of_sync: bool = None, + package_repository_branch: Optional[str] = None, + package_repository_ssh_key: Optional[TSecretValue] = TSecretValue(""), # noqa + package_profiles_dir: Optional[str] = None, + package_profile_name: Optional[str] = None, + auto_full_refresh_when_out_of_sync: bool = True, config: DBTRunnerConfiguration = None, ) -> DBTPackageRunner: """Creates a Python wrapper over `dbt` package present at specified location, that allows to control it (ie. run and test) from Python code. diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index c9e7b5097c..d500788bd1 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -102,11 +102,11 @@ def pipeline( credentials: Any = None, progress: TCollectorArg = _NULL_COLLECTOR, _impl_cls: Type[TPipeline] = Pipeline, # type: ignore[assignment] - **kwargs: Any, + **injection_kwargs: Any, ) -> TPipeline: - ensure_correct_pipeline_kwargs(pipeline, **kwargs) + ensure_correct_pipeline_kwargs(pipeline, **injection_kwargs) # call without arguments returns current pipeline - orig_args = get_orig_args(**kwargs) # original (*args, **kwargs) + orig_args = get_orig_args(**injection_kwargs) # original (*args, **kwargs) # is any of the arguments different from defaults has_arguments = bool(orig_args[0]) or any(orig_args[1].values()) @@ -125,11 +125,12 @@ def pipeline( pipelines_dir = get_dlt_pipelines_dir() destination = Destination.from_reference( - destination or kwargs["destination_type"], destination_name=kwargs["destination_name"] + destination or injection_kwargs["destination_type"], + destination_name=injection_kwargs["destination_name"], ) staging = Destination.from_reference( - staging or kwargs.get("staging_type", None), - destination_name=kwargs.get("staging_name", None), + staging or injection_kwargs.get("staging_type", None), + destination_name=injection_kwargs.get("staging_name", None), ) progress = collector_from_name(progress) @@ -147,8 +148,8 @@ def pipeline( full_refresh, progress, False, - last_config(**kwargs), - kwargs["runtime"], + last_config(**injection_kwargs), + injection_kwargs["runtime"], ) # set it as current pipeline p.activate() @@ -163,10 +164,10 @@ def attach( full_refresh: bool = False, credentials: Any = None, progress: TCollectorArg = _NULL_COLLECTOR, - **kwargs: Any, + **injection_kwargs: Any, ) -> Pipeline: """Attaches to the working folder of `pipeline_name` in `pipelines_dir` or in default directory. 
Requires that valid pipeline state exists in working folder.""" - ensure_correct_pipeline_kwargs(attach, **kwargs) + ensure_correct_pipeline_kwargs(attach, **injection_kwargs) # if working_dir not provided use temp folder if not pipelines_dir: pipelines_dir = get_dlt_pipelines_dir() @@ -185,8 +186,8 @@ def attach( full_refresh, progress, True, - last_config(**kwargs), - kwargs["runtime"], + last_config(**injection_kwargs), + injection_kwargs["runtime"], ) # set it as current pipeline p.activate() diff --git a/dlt/pipeline/dbt.py b/dlt/pipeline/dbt.py index e647e475ed..ee900005fd 100644 --- a/dlt/pipeline/dbt.py +++ b/dlt/pipeline/dbt.py @@ -3,7 +3,7 @@ from dlt.common.exceptions import VenvNotFound from dlt.common.runners import Venv from dlt.common.schema import Schema -from dlt.common.typing import TSecretValue +from dlt.common.typing import ConfigValue, TSecretValue from dlt.common.schema.utils import normalize_schema_name from dlt.helpers.dbt import ( @@ -52,9 +52,9 @@ def get_venv( def package( pipeline: Pipeline, package_location: str, - package_repository_branch: str = None, + package_repository_branch: str = ConfigValue, package_repository_ssh_key: TSecretValue = TSecretValue(""), # noqa - auto_full_refresh_when_out_of_sync: bool = None, + auto_full_refresh_when_out_of_sync: bool = ConfigValue, venv: Venv = None, ) -> DBTPackageRunner: """Creates a Python wrapper over `dbt` package present at specified location, that allows to control it (ie. run and test) from Python code. diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 53770f332d..af37106b54 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -48,7 +48,7 @@ ) from dlt.common.schema.utils import normalize_schema_name from dlt.common.storages.exceptions import LoadPackageNotFound -from dlt.common.typing import TFun, TSecretValue, is_optional_type +from dlt.common.typing import ConfigValue, TFun, TSecretValue, is_optional_type from dlt.common.runners import pool_runner as runner from dlt.common.storages import ( LiveSchemaStorage, @@ -400,8 +400,8 @@ def extract( columns: TAnySchemaColumns = None, primary_key: TColumnNames = None, schema: Schema = None, - max_parallel_items: int = None, - workers: int = None, + max_parallel_items: int = ConfigValue, + workers: int = ConfigValue, schema_contract: TSchemaContract = None, ) -> ExtractInfo: """Extracts the `data` and prepare it for the normalization. Does not require destination or credentials to be configured. 
See `run` method for the arguments' description.""" diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 43ccdf856c..f231b2b6ea 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -17,7 +17,7 @@ from dlt.common import json, pendulum, Decimal, Wei from dlt.common.configuration.providers.provider import ConfigProvider -from dlt.common.configuration.specs.base_configuration import NotResolved, is_hint_not_resolved +from dlt.common.configuration.specs.base_configuration import NotResolved, is_hint_not_resolvable from dlt.common.configuration.specs.gcp_credentials import ( GcpServiceAccountCredentialsWithoutDefaults, ) @@ -311,6 +311,32 @@ def test_explicit_values_false_when_bool() -> None: assert c.heels == "" +def test_explicit_embedded_config(environment: Any) -> None: + instr_explicit = InstrumentedConfiguration(head="h", tube=["tu", "be"], heels="xhe") + + environment["INSTRUMENTED__HEAD"] = "hed" + c = resolve.resolve_configuration( + EmbeddedConfiguration(default="X", sectioned=SectionedConfiguration(password="S")), + explicit_value={"instrumented": instr_explicit}, + ) + + # explicit value will be part of the resolved configuration + assert c.instrumented is instr_explicit + # configuration was injected from env + assert c.instrumented.head == "hed" + + # the same but with resolved + instr_explicit = InstrumentedConfiguration(head="h", tube=["tu", "be"], heels="xhe") + instr_explicit.resolve() + c = resolve.resolve_configuration( + EmbeddedConfiguration(default="X", sectioned=SectionedConfiguration(password="S")), + explicit_value={"instrumented": instr_explicit}, + ) + assert c.instrumented is instr_explicit + # but configuration is not injected + assert c.instrumented.head == "h" + + def test_default_values(environment: Any) -> None: # explicit values override the environment and all else environment["PIPELINE_NAME"] = "env name" @@ -925,12 +951,14 @@ def test_is_valid_hint() -> None: def test_is_not_resolved_hint() -> None: - assert is_hint_not_resolved(Final[ConfigFieldMissingException]) is True - assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved()]) is True - assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved(True)]) is True - assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved(False)]) is False - assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, "REQ"]) is False - assert is_hint_not_resolved(str) is False + assert is_hint_not_resolvable(Final[ConfigFieldMissingException]) is True + assert is_hint_not_resolvable(Annotated[ConfigFieldMissingException, NotResolved()]) is True + assert is_hint_not_resolvable(Annotated[ConfigFieldMissingException, NotResolved(True)]) is True + assert ( + is_hint_not_resolvable(Annotated[ConfigFieldMissingException, NotResolved(False)]) is False + ) + assert is_hint_not_resolvable(Annotated[ConfigFieldMissingException, "REQ"]) is False + assert is_hint_not_resolvable(str) is False def test_not_resolved_hint() -> None: diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index 4f2219716a..43bad21ece 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -89,7 +89,7 @@ def single_val(port=None): return port # secrets have api.port=1023 and this will be used - assert single_val(None) == 1023 + assert 
single_val(dlt.secrets.value) == 1023 # env will make it string, also section is optional environment["PORT"] = "UNKNOWN" @@ -110,7 +110,7 @@ def mixed_val( ): return api_type, secret_value, typecheck - _tup = mixed_val(None, None, None) + _tup = mixed_val(dlt.config.value, dlt.secrets.value, dlt.config.value) assert _tup[0] == "REST" assert _tup[1] == "2137" assert isinstance(_tup[2], dict) diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index 6cb76fba9d..e97fac8a9e 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -62,14 +62,14 @@ def ie_storage(request) -> SchemaStorage: ) -def init_storage(cls, C: SchemaStorageConfiguration) -> SchemaStorage: +def init_storage(cls, config: SchemaStorageConfiguration) -> SchemaStorage: # use live schema storage for test which must be backward compatible with schema storage - s = cls(C, makedirs=True) - assert C is s.config - if C.export_schema_path: - os.makedirs(C.export_schema_path, exist_ok=True) - if C.import_schema_path: - os.makedirs(C.import_schema_path, exist_ok=True) + s = cls(config, makedirs=True) + assert config is s.config + if config.export_schema_path: + os.makedirs(config.export_schema_path, exist_ok=True) + if config.import_schema_path: + os.makedirs(config.import_schema_path, exist_ok=True) return s diff --git a/tests/extract/test_extract.py b/tests/extract/test_extract.py index 9620e7fdfb..dc978b997a 100644 --- a/tests/extract/test_extract.py +++ b/tests/extract/test_extract.py @@ -317,7 +317,7 @@ def tx_step(item): def expect_tables(extract_step: Extract, resource: DltResource) -> dlt.Schema: source = DltSource(dlt.Schema("selectables"), "module", [resource(10)]) load_id = extract_step.extract_storage.create_load_package(source.discover_schema()) - extract_step._extract_single_source(load_id, source) + extract_step._extract_single_source(load_id, source, max_parallel_items=5, workers=1) # odd and even tables must be in the source schema assert len(source.schema.data_tables(include_incomplete=True)) == 2 assert "odd_table" in source.schema._schema_tables @@ -340,7 +340,7 @@ def expect_tables(extract_step: Extract, resource: DltResource) -> dlt.Schema: source = source.with_resources(resource.name) source.selected_resources[resource.name].bind(10).select_tables("odd_table") load_id = extract_step.extract_storage.create_load_package(source.discover_schema()) - extract_step._extract_single_source(load_id, source) + extract_step._extract_single_source(load_id, source, max_parallel_items=5, workers=1) assert len(source.schema.data_tables(include_incomplete=True)) == 1 assert "odd_table" in source.schema._schema_tables extract_step.extract_storage.commit_new_load_package(load_id, source.schema) From 8ffa6d833bf0522a3baa2c3076e9aa56d689ccdb Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 2 Jun 2024 12:27:03 +0200 Subject: [PATCH 14/15] adds is_subclass working with type aliases --- dlt/cli/config_toml_writer.py | 4 +- dlt/common/configuration/container.py | 3 +- dlt/common/configuration/providers/context.py | 3 +- dlt/common/configuration/resolve.py | 8 ++-- .../configuration/specs/base_configuration.py | 7 ++-- dlt/common/libs/pydantic.py | 5 ++- dlt/common/typing.py | 41 ++++++++++++++++++- dlt/extract/incremental/__init__.py | 6 +-- tests/common/test_typing.py | 17 ++++++++ 9 files changed, 75 insertions(+), 19 deletions(-) diff --git a/dlt/cli/config_toml_writer.py b/dlt/cli/config_toml_writer.py index 
7ff7f735eb..63396e3ebe 100644 --- a/dlt/cli/config_toml_writer.py +++ b/dlt/cli/config_toml_writer.py @@ -11,7 +11,7 @@ extract_inner_hint, ) from dlt.common.data_types import py_type_to_sc_type -from dlt.common.typing import AnyType, is_final_type, is_optional_type +from dlt.common.typing import AnyType, is_final_type, is_optional_type, is_subclass class WritableConfigValue(NamedTuple): @@ -34,7 +34,7 @@ def generate_typed_example(name: str, hint: AnyType) -> Any: if sc_type == "bool": return True if sc_type == "complex": - if issubclass(inner_hint, C_Sequence): + if is_subclass(inner_hint, C_Sequence): return ["a", "b", "c"] else: table = tomlkit.table(False) diff --git a/dlt/common/configuration/container.py b/dlt/common/configuration/container.py index 441b0e21bc..84d6194966 100644 --- a/dlt/common/configuration/container.py +++ b/dlt/common/configuration/container.py @@ -8,6 +8,7 @@ ContainerInjectableContextMangled, ContextDefaultCannotBeCreated, ) +from dlt.common.typing import is_subclass TConfiguration = TypeVar("TConfiguration", bound=ContainerInjectableContext) @@ -56,7 +57,7 @@ def __init__(self) -> None: def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration: # return existing config object or create it from spec - if not issubclass(spec, ContainerInjectableContext): + if not is_subclass(spec, ContainerInjectableContext): raise KeyError(f"{spec.__name__} is not a context") context, item = self._thread_getitem(spec) diff --git a/dlt/common/configuration/providers/context.py b/dlt/common/configuration/providers/context.py index c6c1aac644..de2290540c 100644 --- a/dlt/common/configuration/providers/context.py +++ b/dlt/common/configuration/providers/context.py @@ -3,6 +3,7 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.specs import ContainerInjectableContext +from dlt.common.typing import is_subclass from .provider import ConfigProvider @@ -24,7 +25,7 @@ def get_value( # only context is a valid hint with contextlib.suppress(KeyError, TypeError): - if issubclass(hint, ContainerInjectableContext): + if is_subclass(hint, ContainerInjectableContext): # contexts without defaults will raise ContextDefaultCannotBeCreated return self.container[hint], hint.__name__ diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index 68634881da..5b01c2b65b 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -5,12 +5,12 @@ from dlt.common.configuration.providers.provider import ConfigProvider from dlt.common.typing import ( AnyType, - ConfigValue, - SecretValue, + ConfigValueSentinel, StrAny, TSecretValue, get_all_types_of_class_in_union, is_optional_type, + is_subclass, is_union_type, ) @@ -89,7 +89,7 @@ def initialize_credentials(hint: Any, initial_value: Any) -> CredentialsConfigur raise return first_credentials else: - assert issubclass(hint, CredentialsConfiguration) + assert is_subclass(hint, CredentialsConfiguration) return hint.from_init_value(initial_value) # type: ignore @@ -202,7 +202,7 @@ def _resolve_config_fields( explicit_value = explicit_values[key] explicit_none = explicit_value is None # detect dlt.config and dlt.secrets and force injection - if explicit_value in (ConfigValue, SecretValue): + if isinstance(explicit_value, ConfigValueSentinel): explicit_value = None else: if is_hint_not_resolvable(hint): diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 12ff1d8a5d..7d2fbc0035 
100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -36,6 +36,7 @@ is_annotated, is_final_type, is_optional_type, + is_subclass, is_union_type, ) from dlt.common.data_types import py_type_to_sc_type @@ -81,15 +82,15 @@ def is_hint_not_resolvable(hint: AnyType) -> bool: def is_base_configuration_inner_hint(inner_hint: Type[Any]) -> bool: - return inspect.isclass(inner_hint) and issubclass(inner_hint, BaseConfiguration) + return is_subclass(inner_hint, BaseConfiguration) def is_context_inner_hint(inner_hint: Type[Any]) -> bool: - return inspect.isclass(inner_hint) and issubclass(inner_hint, ContainerInjectableContext) + return is_subclass(inner_hint, ContainerInjectableContext) def is_credentials_inner_hint(inner_hint: Type[Any]) -> bool: - return inspect.isclass(inner_hint) and issubclass(inner_hint, CredentialsConfiguration) + return is_subclass(inner_hint, CredentialsConfiguration) def get_config_if_union_hint(hint: Type[Any]) -> Type[Any]: diff --git a/dlt/common/libs/pydantic.py b/dlt/common/libs/pydantic.py index c4bf994cb9..e6af064b8f 100644 --- a/dlt/common/libs/pydantic.py +++ b/dlt/common/libs/pydantic.py @@ -28,6 +28,7 @@ extract_inner_type, is_list_generic_type, is_dict_generic_type, + is_subclass, is_union_type, ) @@ -124,7 +125,7 @@ def pydantic_to_table_schema_columns( try: data_type = py_type_to_sc_type(inner_type) except TypeError: - if issubclass(inner_type, BaseModel): + if is_subclass(inner_type, BaseModel): data_type = "complex" is_inner_type_pydantic_model = True else: @@ -250,7 +251,7 @@ def _process_annotation(t_: Type[Any]) -> Type[Any]: elif is_union_type(t_): u_t_s = tuple(_process_annotation(u_t) for u_t in extract_union_types(t_)) return Union[u_t_s] # type: ignore[return-value] - elif inspect.isclass(t_) and issubclass(t_, BaseModel): + elif is_subclass(t_, BaseModel): # types must be same before and after processing if id(t_) in _child_models: return _child_models[id(t_)] diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 2575dd15a4..e108267d06 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -1,8 +1,9 @@ -from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence +from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence, Callable as C_Callable from datetime import datetime, date # noqa: I251 import inspect import os from re import Pattern as _REPattern +import sys from types import FunctionType, MethodType, ModuleType from typing import ( ForwardRef, @@ -50,6 +51,16 @@ # in versions of Python>=3.10. UnionType = Never +if sys.version_info[:3] >= (3, 9, 0): + from typing import _SpecialGenericAlias, _GenericAlias # type: ignore[attr-defined] + from types import GenericAlias # type: ignore[attr-defined] + + typingGenericAlias: Tuple[Any, ...] = (_GenericAlias, _SpecialGenericAlias, GenericAlias) +else: + from typing import _GenericAlias # type: ignore[attr-defined] + + typingGenericAlias = (_GenericAlias,) + from dlt.common.pendulum import timedelta, pendulum if TYPE_CHECKING: @@ -317,10 +328,36 @@ def get_all_types_of_class_in_union(hint: Type[Any], cls: Type[TAny]) -> List[Ty return [ t for t in get_args(hint) - if not is_typeddict(t) and inspect.isclass(t) and (issubclass(t, cls) or issubclass(cls, t)) + if not is_typeddict(t) and (is_subclass(t, cls) or is_subclass(cls, t)) ] +def is_generic_alias(tp: Type[Any]) -> bool: + """Tests if type is a generic alias ie. 
List[str]"""
+    return isinstance(tp, typingGenericAlias) and tp.__origin__ not in (
+        Union,
+        tuple,
+        ClassVar,
+        C_Callable,
+    )
+
+
+def is_subclass(subclass: Type[Any], cls: Type[Any]) -> bool:
+    """Return whether 'subclass' is derived from 'cls' or is the same class.
+
+    Will handle generic types by comparing their origins.
+    """
+    if is_generic_alias(subclass):
+        subclass = get_origin(subclass)
+    if is_generic_alias(cls):
+        cls = get_origin(cls)
+
+    print(subclass, cls, inspect.isclass(subclass), inspect.isclass(cls))
+    if inspect.isclass(subclass) and inspect.isclass(cls):
+        return issubclass(subclass, cls)
+    return False
+
+
 def get_generic_type_argument_from_instance(
     instance: Any, sample_value: Optional[Any]
 ) -> Type[Any]:
diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py
index c30466e9bd..bcb6b1cc9a 100644
--- a/dlt/extract/incremental/__init__.py
+++ b/dlt/extract/incremental/__init__.py
@@ -18,6 +18,7 @@
     extract_inner_type,
     get_generic_type_argument_from_instance,
     is_optional_type,
+    is_subclass,
 )
 from dlt.common.schema.typing import TColumnNames
 from dlt.common.configuration import configspec, ConfigurationValueError
@@ -524,10 +525,7 @@ def get_incremental_arg(sig: inspect.Signature) -> Optional[inspect.Parameter]:
     incremental_param: Optional[inspect.Parameter] = None
     for p in sig.parameters.values():
         annotation = extract_inner_type(p.annotation)
-        annotation = get_origin(annotation) or annotation
-        if (inspect.isclass(annotation) and issubclass(annotation, Incremental)) or isinstance(
-            p.default, Incremental
-        ):
+        if is_subclass(annotation, Incremental) or isinstance(p.default, Incremental):
             incremental_param = p
             break
     return incremental_param
diff --git a/tests/common/test_typing.py b/tests/common/test_typing.py
index 5bdd308566..53aa29dd78 100644
--- a/tests/common/test_typing.py
+++ b/tests/common/test_typing.py
@@ -34,6 +34,7 @@
     is_literal_type,
     is_newtype_type,
     is_optional_type,
+    is_subclass,
     is_typeddict,
     is_union_type,
     is_annotated,
@@ -233,3 +234,19 @@ def test_extract_annotated_inner_type() -> None:
     assert extract_inner_type(Annotated[Optional[MyDataclass], "meta"]) is MyDataclass  # type: ignore[arg-type]
     assert extract_inner_type(Annotated[MyDataclass, Optional]) is MyDataclass  # type: ignore[arg-type]
     assert extract_inner_type(Annotated[MyDataclass, "random metadata string"]) is MyDataclass  # type: ignore[arg-type]
+
+
+def test_is_subclass() -> None:
+    from dlt.extract import Incremental
+
+    assert is_subclass(Incremental, BaseConfiguration) is True
+    assert is_subclass(Incremental[float], Incremental[int]) is True
+    assert is_subclass(BaseConfiguration, Incremental[int]) is False
+    assert is_subclass(list, Sequence) is True
+    assert is_subclass(list, Sequence[str]) is True
+    # unions, new types, literals etc. will always produce False
+    assert is_subclass(list, Optional[list]) is False  # type: ignore[arg-type]
+    assert is_subclass(Optional[list], list) is False  # type: ignore[arg-type]
+    assert is_subclass(list, TTestLi) is False  # type: ignore[arg-type]
+    assert is_subclass(TTestLi, TTestLi) is False  # type: ignore[arg-type]
+    assert is_subclass(list, NewType("LT", list)) is False
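A note for readers of the series: the `is_subclass` helper added above reduces generic aliases to their origins before delegating to the builtin `issubclass`, so type parameters are ignored and non-class hints never match (the leftover debug `print` is dropped in the next patch). A minimal illustrative sketch; the asserts mirror `test_is_subclass` above and `Incremental` is imported the same way as in that test:

from typing import Optional, Sequence

from dlt.common.typing import is_subclass
from dlt.extract import Incremental

# a parametrized alias is reduced to its origin, so the parameters are ignored
assert is_subclass(list, Sequence[str]) is True
assert is_subclass(Incremental[float], Incremental[int]) is True
# plain classes delegate straight to the builtin issubclass
assert is_subclass(list, Sequence) is True
# unions (including Optional), NewTypes and Literals are not classes: always False
assert is_subclass(list, Optional[list]) is False

Comparing by origin also sidesteps the TypeError that the builtin issubclass raises on parametrized hints, which the removed get_origin/inspect.isclass combination in get_incremental_arg previously had to guard against.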
From d2f35dab636388b92c4814f0033404e4d5f8a754 Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Sun, 2 Jun 2024 16:21:27 +0200
Subject: [PATCH 15/15] tests configspecs with generics

---
 dlt/common/configuration/resolve.py           |  4 +-
 dlt/common/typing.py                          | 13 +--
 .../configuration/test_configuration.py       | 81 ++++++++++++++++++-
 tests/common/test_typing.py                   | 28 ++++++-
 tests/extract/test_incremental.py             | 41 ++++++++++
 5 files changed, 154 insertions(+), 13 deletions(-)

diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py
index 5b01c2b65b..5ca08a8a66 100644
--- a/dlt/common/configuration/resolve.py
+++ b/dlt/common/configuration/resolve.py
@@ -221,7 +221,9 @@ def _resolve_config_fields(
             if (
                 explicit_value
                 and not is_valid_hint(type(explicit_value))
-                and get_all_types_of_class_in_union(hint, type(explicit_value))
+                and get_all_types_of_class_in_union(
+                    hint, type(explicit_value), with_superclass=True
+                )
             ):
                 current_value, traces = explicit_value, []
             else:
diff --git a/dlt/common/typing.py b/dlt/common/typing.py
index e108267d06..f0b2b7dcb3 100644
--- a/dlt/common/typing.py
+++ b/dlt/common/typing.py
@@ -323,16 +323,18 @@ def extract_inner_type(hint: Type[Any], preserve_new_types: bool = False) -> Typ
     return hint


-def get_all_types_of_class_in_union(hint: Type[Any], cls: Type[TAny]) -> List[Type[TAny]]:
-    # hint is an Union that contains classes, return all classes that are a subclass or superclass of cls
+def get_all_types_of_class_in_union(
+    hint: Any, cls: TAny, with_superclass: bool = False
+) -> List[TAny]:
+    """If `hint` is a Union that contains classes, return all classes that are a subclass or (optionally) superclass of cls"""
     return [
         t
         for t in get_args(hint)
-        if not is_typeddict(t) and (is_subclass(t, cls) or is_subclass(cls, t))
+        if not is_typeddict(t) and (is_subclass(t, cls) or is_subclass(cls, t) and with_superclass)
     ]


-def is_generic_alias(tp: Type[Any]) -> bool:
+def is_generic_alias(tp: Any) -> bool:
     """Tests if type is a generic alias ie. List[str]"""
     return isinstance(tp, typingGenericAlias) and tp.__origin__ not in (
         Union,
@@ -342,7 +344,7 @@ def is_generic_alias(tp: Type[Any]) -> bool:
     )


-def is_subclass(subclass: Type[Any], cls: Type[Any]) -> bool:
+def is_subclass(subclass: Any, cls: Any) -> bool:
     """Return whether 'subclass' is derived from 'cls' or is the same class.

     Will handle generic types by comparing their origins.
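Together with the resolve.py change above, the new `with_superclass` flag decides which members of a union hint are returned: members that are subclasses of `cls` always match, while members that are superclasses of `cls` match only when the flag is set. A short sketch mirroring `test_get_all_types_of_class_in_union` added later in this patch:

from typing import Optional, Union

from dlt.common.configuration.specs import BaseConfiguration
from dlt.common.typing import get_all_types_of_class_in_union
from dlt.extract import Incremental

# subclasses of cls always match; NoneType inside the Optional does not match
assert get_all_types_of_class_in_union(Optional[Incremental], BaseConfiguration) == [
    Incremental
]
# superclasses of cls are returned only on request
assert get_all_types_of_class_in_union(Union[BaseConfiguration, str], Incremental[float]) == []
assert get_all_types_of_class_in_union(
    Union[BaseConfiguration, str], Incremental[float], with_superclass=True
) == [BaseConfiguration]

This is what lets _resolve_config_fields accept an explicit value, such as the sentinel instances in the tests below, whose type relates to a union member only through inheritance.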
@@ -352,7 +354,6 @@ def is_subclass(subclass: Type[Any], cls: Type[Any]) -> bool: if is_generic_alias(cls): cls = get_origin(cls) - print(subclass, cls, inspect.isclass(subclass), inspect.isclass(cls)) if inspect.isclass(subclass) and inspect.isclass(cls): return issubclass(subclass, cls) return False diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index f231b2b6ea..d0a56f526e 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -5,6 +5,7 @@ Any, Dict, Final, + Generic, List, Mapping, MutableMapping, @@ -13,7 +14,7 @@ Type, Union, ) -from typing_extensions import Annotated +from typing_extensions import Annotated, TypeVar from dlt.common import json, pendulum, Decimal, Wei from dlt.common.configuration.providers.provider import ConfigProvider @@ -22,7 +23,14 @@ GcpServiceAccountCredentialsWithoutDefaults, ) from dlt.common.utils import custom_environ, get_exception_trace, get_exception_trace_chain -from dlt.common.typing import AnyType, DictStrAny, StrAny, TSecretValue, extract_inner_type +from dlt.common.typing import ( + AnyType, + ConfigValue, + DictStrAny, + StrAny, + TSecretValue, + extract_inner_type, +) from dlt.common.configuration.exceptions import ( ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported, @@ -1338,3 +1346,72 @@ class EmbeddedConfigurationWithDefaults(BaseConfiguration): c_resolved = resolve.resolve_configuration(c_instance) assert c_resolved.is_resolved() assert c_resolved.conn_str.is_resolved() + + +def test_configuration_with_generic(environment: Dict[str, str]) -> None: + TColumn = TypeVar("TColumn", bound=str) + + @configspec + class IncrementalConfiguration(BaseConfiguration, Generic[TColumn]): + # TODO: support generics field + column: str = ConfigValue + + @configspec + class SourceConfiguration(BaseConfiguration): + name: str = ConfigValue + incremental: IncrementalConfiguration[str] = ConfigValue + + # resolve incremental + environment["COLUMN"] = "column" + c = resolve.resolve_configuration(IncrementalConfiguration[str]()) + assert c.column == "column" + + # resolve embedded config with generic + environment["INCREMENTAL__COLUMN"] = "column_i" + c2 = resolve.resolve_configuration(SourceConfiguration(name="name")) + assert c2.incremental.column == "column_i" + + # put incremental in union + @configspec + class SourceUnionConfiguration(BaseConfiguration): + name: str = ConfigValue + incremental_union: Optional[IncrementalConfiguration[str]] = ConfigValue + + c3 = resolve.resolve_configuration(SourceUnionConfiguration(name="name")) + assert c3.incremental_union is None + environment["INCREMENTAL_UNION__COLUMN"] = "column_u" + c3 = resolve.resolve_configuration(SourceUnionConfiguration(name="name")) + assert c3.incremental_union.column == "column_u" + + class Sentinel: + pass + + class SubSentinel(Sentinel): + pass + + @configspec + class SourceWideUnionConfiguration(BaseConfiguration): + name: str = ConfigValue + incremental_w_union: Union[IncrementalConfiguration[str], str, Sentinel] = ConfigValue + incremental_sub: Optional[Union[IncrementalConfiguration[str], str, SubSentinel]] = None + + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(SourceWideUnionConfiguration(name="name")) + + # use explicit sentinel + sentinel = Sentinel() + c4 = resolve.resolve_configuration( + SourceWideUnionConfiguration(name="name"), explicit_value={"incremental_w_union": sentinel} + ) + assert 
c4.incremental_w_union is sentinel + + # instantiate incremental + environment["INCREMENTAL_W_UNION__COLUMN"] = "column_w_u" + c4 = resolve.resolve_configuration(SourceWideUnionConfiguration(name="name")) + assert c4.incremental_w_union.column == "column_w_u" # type: ignore[union-attr] + + # sentinel (of super class type) also works for hint of subclass type + c4 = resolve.resolve_configuration( + SourceWideUnionConfiguration(name="name"), explicit_value={"incremental_sub": sentinel} + ) + assert c4.incremental_sub is sentinel diff --git a/tests/common/test_typing.py b/tests/common/test_typing.py index 53aa29dd78..3a9e320040 100644 --- a/tests/common/test_typing.py +++ b/tests/common/test_typing.py @@ -29,6 +29,7 @@ StrAny, extract_inner_type, extract_union_types, + get_all_types_of_class_in_union, is_dict_generic_type, is_list_generic_type, is_literal_type, @@ -245,8 +246,27 @@ def test_is_subclass() -> None: assert is_subclass(list, Sequence) is True assert is_subclass(list, Sequence[str]) is True # unions, new types, literals etc. will always produce False - assert is_subclass(list, Optional[list]) is False # type: ignore[arg-type] - assert is_subclass(Optional[list], list) is False # type: ignore[arg-type] - assert is_subclass(list, TTestLi) is False # type: ignore[arg-type] - assert is_subclass(TTestLi, TTestLi) is False # type: ignore[arg-type] + assert is_subclass(list, Optional[list]) is False + assert is_subclass(Optional[list], list) is False + assert is_subclass(list, TTestLi) is False + assert is_subclass(TTestLi, TTestLi) is False assert is_subclass(list, NewType("LT", list)) is False + + +def test_get_all_types_of_class_in_union() -> None: + from dlt.extract import Incremental + + # optional is an union + assert get_all_types_of_class_in_union(Optional[str], str) == [str] + # both classes and type aliases are recognized + assert get_all_types_of_class_in_union(Optional[Incremental], BaseConfiguration) == [ + Incremental + ] + assert get_all_types_of_class_in_union(Optional[Incremental[float]], BaseConfiguration) == [ + Incremental[float] + ] + # by default superclasses are not recognized + assert get_all_types_of_class_in_union(Union[BaseConfiguration, str], Incremental[float]) == [] + assert get_all_types_of_class_in_union( + Union[BaseConfiguration, str], Incremental[float], with_superclass=True + ) == [BaseConfiguration] diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index 69091c2f28..675f44bb14 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -1,5 +1,6 @@ import os import asyncio +import inspect import random from time import sleep from typing import Optional, Any @@ -25,6 +26,7 @@ from dlt.extract.exceptions import InvalidStepFunctionArguments from dlt.extract.resource import DltResource from dlt.sources.helpers.transform import take_first +from dlt.extract.incremental import IncrementalResourceWrapper, Incremental from dlt.extract.incremental.exceptions import ( IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing, @@ -40,6 +42,45 @@ ) +def test_detect_incremental_arg() -> None: + def incr_1(incremental: dlt.sources.incremental): # type: ignore[type-arg] + pass + + assert ( + IncrementalResourceWrapper.get_incremental_arg(inspect.signature(incr_1)).name + == "incremental" + ) + + def incr_2(incremental: Incremental[str]): + pass + + assert ( + IncrementalResourceWrapper.get_incremental_arg(inspect.signature(incr_2)).name + == "incremental" + ) + + def 
incr_3(incremental=dlt.sources.incremental[str]("updated_at")): # noqa + pass + + assert ( + IncrementalResourceWrapper.get_incremental_arg(inspect.signature(incr_3)).name + == "incremental" + ) + + def incr_4(incremental=Incremental[str]("updated_at")): # noqa + pass + + assert ( + IncrementalResourceWrapper.get_incremental_arg(inspect.signature(incr_4)).name + == "incremental" + ) + + def incr_5(incremental: IncrementalResourceWrapper): + pass + + assert IncrementalResourceWrapper.get_incremental_arg(inspect.signature(incr_5)) is None + + @pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) def test_single_items_last_value_state_is_updated(item_type: TestDataItemFormat) -> None: data = [