Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes config injection edge cases #1430

Merged
merged 16 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_destinations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ env:
RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }}
# Test redshift and filesystem with all buckets
# postgres runs again here so we can test on mac/windows
ACTIVE_DESTINATIONS: "[\"redshift\", \"postgres\", \"duckdb\", \"filesystem\", \"dummy\"]"
ACTIVE_DESTINATIONS: "[\"redshift\", \"postgres\", \"duckdb\", \"filesystem\", \"dummy\", \"motherduck\"]"

jobs:
get_docs_changes:
Expand Down
8 changes: 4 additions & 4 deletions dlt/cli/config_toml_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from tomlkit.container import Container as TOMLContainer
from collections.abc import Sequence as C_Sequence

from dlt.common.configuration.specs.base_configuration import is_hint_not_resolved
from dlt.common.configuration.specs.base_configuration import is_hint_not_resolvable
from dlt.common.pendulum import pendulum
from dlt.common.configuration.specs import (
BaseConfiguration,
is_base_configuration_inner_hint,
extract_inner_hint,
)
from dlt.common.data_types import py_type_to_sc_type
from dlt.common.typing import AnyType, is_optional_type
from dlt.common.typing import AnyType, is_optional_type, is_subclass


class WritableConfigValue(NamedTuple):
Expand All @@ -35,7 +35,7 @@ def generate_typed_example(name: str, hint: AnyType) -> Any:
if sc_type == "bool":
return True
if sc_type == "complex":
if issubclass(inner_hint, C_Sequence):
if is_subclass(inner_hint, C_Sequence):
return ["a", "b", "c"]
else:
table = tomlkit.table(False)
Expand Down Expand Up @@ -65,7 +65,7 @@ def write_value(
return
# do not dump nor resolvable and optional fields if they are not of special interest
if (
is_hint_not_resolved(hint) or is_optional_type(hint) or default_value is not None
is_hint_not_resolvable(hint) or is_optional_type(hint) or default_value is not None
) and not is_default_of_interest:
return
# get the inner hint to generate cool examples
Expand Down
8 changes: 2 additions & 6 deletions dlt/common/configuration/accessors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import abc
import contextlib
import tomlkit
from typing import Any, ClassVar, List, Sequence, Tuple, Type, TypeVar

from dlt.common.configuration.container import Container
Expand All @@ -9,10 +7,8 @@
from dlt.common.configuration.specs import BaseConfiguration, is_base_configuration_inner_hint
from dlt.common.configuration.utils import deserialize_value, log_traces, auto_cast
from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext
from dlt.common.typing import AnyType, ConfigValue, TSecretValue
from dlt.common.typing import AnyType, ConfigValue, SecretValue, TSecretValue

DLT_SECRETS_VALUE = "secrets.value"
DLT_CONFIG_VALUE = "config.value"
TConfigAny = TypeVar("TConfigAny", bound=Any)


Expand Down Expand Up @@ -129,7 +125,7 @@ def writable_provider(self) -> ConfigProvider:
p for p in self._get_providers_from_context() if p.is_writable and p.supports_secrets
)

value: ClassVar[Any] = ConfigValue
value: ClassVar[Any] = SecretValue
"A placeholder that tells dlt to replace it with actual secret during the call to a source or resource decorated function."


Expand Down
3 changes: 2 additions & 1 deletion dlt/common/configuration/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ContainerInjectableContextMangled,
ContextDefaultCannotBeCreated,
)
from dlt.common.typing import is_subclass

TConfiguration = TypeVar("TConfiguration", bound=ContainerInjectableContext)

Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(self) -> None:

def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration:
# return existing config object or create it from spec
if not issubclass(spec, ContainerInjectableContext):
if not is_subclass(spec, ContainerInjectableContext):
raise KeyError(f"{spec.__name__} is not a context")

context, item = self._thread_getitem(spec)
Expand Down
93 changes: 54 additions & 39 deletions dlt/common/configuration/inject.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import inspect

from functools import wraps
from typing import Callable, Dict, Type, Any, Optional, Tuple, TypeVar, overload, cast
from typing import Callable, Dict, Type, Any, Optional, Union, Tuple, TypeVar, overload, cast
from inspect import Signature, Parameter
from contextlib import nullcontext

from dlt.common.typing import DictStrAny, StrAny, TFun, AnyFun
from dlt.common.typing import DictStrAny, TFun, AnyFun
from dlt.common.configuration.resolve import resolve_configuration, inject_section
from dlt.common.configuration.specs.base_configuration import BaseConfiguration
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
Expand All @@ -15,22 +14,24 @@

_LAST_DLT_CONFIG = "_dlt_config"
_ORIGINAL_ARGS = "_dlt_orig_args"
# keep a registry of all the decorated functions
_FUNC_SPECS: Dict[int, Type[BaseConfiguration]] = {}

TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration)


def get_fun_spec(f: AnyFun) -> Type[BaseConfiguration]:
return _FUNC_SPECS.get(id(f))
return getattr(f, "__SPEC__", None) # type: ignore[no-any-return]


def set_fun_spec(f: AnyFun, spec: Type[BaseConfiguration]) -> None:
"""Assigns a spec to a callable from which it was inferred"""
setattr(f, "__SPEC__", spec) # noqa: B010


@overload
def with_config(
func: TFun,
/,
spec: Type[BaseConfiguration] = None,
sections: Tuple[str, ...] = (),
sections: Union[str, Tuple[str, ...]] = (),
sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
auto_pipeline_section: bool = False,
include_defaults: bool = True,
Expand All @@ -46,7 +47,7 @@ def with_config(
func: None = ...,
/,
spec: Type[BaseConfiguration] = None,
sections: Tuple[str, ...] = (),
sections: Union[str, Tuple[str, ...]] = (),
sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
auto_pipeline_section: bool = False,
include_defaults: bool = True,
Expand All @@ -61,7 +62,7 @@ def with_config(
func: Optional[AnyFun] = None,
/,
spec: Type[BaseConfiguration] = None,
sections: Tuple[str, ...] = (),
sections: Union[str, Tuple[str, ...]] = (),
sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
auto_pipeline_section: bool = False,
include_defaults: bool = True,
Expand All @@ -88,17 +89,18 @@ def with_config(
Callable[[TFun], TFun]: A decorated function
"""

section_f: Callable[[StrAny], str] = None
# section may be a function from function arguments to section
if callable(sections):
section_f = sections

def decorator(f: TFun) -> TFun:
SPEC: Type[BaseConfiguration] = None
sig: Signature = inspect.signature(f)
signature_fields: Dict[str, Any]
# find variadic kwargs to which additional arguments and injection context can be injected
kwargs_arg = next(
(p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None
(
p
for p in sig.parameters.values()
if p.kind == Parameter.VAR_KEYWORD and p.name == "injection_kwargs"
),
None,
)
if spec is None:
SPEC, signature_fields = spec_from_signature(f, sig, include_defaults, base=base)
Expand All @@ -109,7 +111,7 @@ def decorator(f: TFun) -> TFun:
# if no signature fields were added we will not wrap `f` for injection
if len(signature_fields) == 0:
# always register new function
_FUNC_SPECS[id(f)] = SPEC
set_fun_spec(f, SPEC)
return f

spec_arg: Parameter = None
Expand All @@ -127,20 +129,23 @@ def decorator(f: TFun) -> TFun:
pipeline_name_arg = p
pipeline_name_arg_default = None if p.default == Parameter.empty else p.default

def resolve_config(bound_args: inspect.BoundArguments) -> BaseConfiguration:
def resolve_config(
bound_args: inspect.BoundArguments, accept_partial_: bool
) -> BaseConfiguration:
"""Resolve arguments using the provided spec"""
# bind parameters to signature
# for calls containing resolved spec in the kwargs, we do not need to resolve again
config: BaseConfiguration = None

# if section derivation function was provided then call it
if section_f:
curr_sections: Tuple[str, ...] = (section_f(bound_args.arguments),)
# sections may be a string
elif isinstance(sections, str):
curr_sections = (sections,)
curr_sections: Union[str, Tuple[str, ...]] = None
# section may be a function from function arguments to section
if callable(sections):
curr_sections = sections(bound_args.arguments)
else:
curr_sections = sections
# sections may be a string
if isinstance(curr_sections, str):
curr_sections = (curr_sections,)

# if one of arguments is spec the use it as initial value
if initial_config:
Expand All @@ -162,18 +167,19 @@ def resolve_config(bound_args: inspect.BoundArguments) -> BaseConfiguration:

# this may be called from many threads so section_context is thread affine
with inject_section(section_context, lock_context=lock_context_on_injection):
# print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}")
# print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections} in {bound_args.arguments}")
return resolve_configuration(
config or SPEC(),
explicit_value=bound_args.arguments,
accept_partial=accept_partial,
accept_partial=accept_partial_,
)

def update_bound_args(
bound_args: inspect.BoundArguments, config: BaseConfiguration, args: Any, kwargs: Any
) -> None:
# overwrite or add resolved params
resolved_params = dict(config)
# print("resolved_params", resolved_params)
# overwrite or add resolved params
for p in sig.parameters.values():
if p.name in resolved_params:
Expand All @@ -191,11 +197,18 @@ def update_bound_args(

def with_partially_resolved_config(config: Optional[BaseConfiguration] = None) -> Any:
# creates a pre-resolved partial of the decorated function
empty_bound_args = sig.bind_partial()
if not config:
config = resolve_config(empty_bound_args)

def wrapped(*args: Any, **kwargs: Any) -> Any:
# TODO: this will not work if correct config is not provided
# esp. in case of parameters in _wrap being ConfigurationBase
# at least we should implement re-resolve with explicit parameters
# so we can merge partial we get here to combine a full config
empty_bound_args = sig.bind_partial()
# TODO: resolve partial here that will be updated in _wrap
config = resolve_config(empty_bound_args, accept_partial_=False)

@wraps(f)
def _wrap(*args: Any, **kwargs: Any) -> Any:
# TODO: we should not change the outer config but deepcopy it
nonlocal config

# Do we need an exception here?
Expand All @@ -213,27 +226,28 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:

# call the function with the pre-resolved config
bound_args = sig.bind(*args, **kwargs)
# TODO: update partial config with bound_args (to cover edge cases with embedded configs)
update_bound_args(bound_args, config, args, kwargs)
return f(*bound_args.args, **bound_args.kwargs)

return wrapped
return _wrap

@wraps(f)
def _wrap(*args: Any, **kwargs: Any) -> Any:
# Resolve config
config: BaseConfiguration = None
bound_args = sig.bind(*args, **kwargs)
bound_args = sig.bind_partial(*args, **kwargs)
if _LAST_DLT_CONFIG in kwargs:
config = last_config(**kwargs)
else:
config = resolve_config(bound_args)
config = resolve_config(bound_args, accept_partial_=accept_partial)

# call the function with resolved config
update_bound_args(bound_args, config, args, kwargs)
return f(*bound_args.args, **bound_args.kwargs)

# register the spec for a wrapped function
_FUNC_SPECS[id(_wrap)] = SPEC
set_fun_spec(_wrap, SPEC)

# add a method to create a pre-resolved partial
setattr(_wrap, "__RESOLVED_PARTIAL_FUNC__", with_partially_resolved_config) # noqa: B010
Expand All @@ -255,13 +269,14 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
return decorator(func)


def last_config(**kwargs: Any) -> Any:
"""Get configuration instance used to inject function arguments"""
return kwargs[_LAST_DLT_CONFIG]
def last_config(**injection_kwargs: Any) -> Any:
"""Get configuration instance used to inject function kwargs"""
return injection_kwargs[_LAST_DLT_CONFIG]


def get_orig_args(**kwargs: Any) -> Tuple[Tuple[Any], DictStrAny]:
return kwargs[_ORIGINAL_ARGS] # type: ignore
def get_orig_args(**injection_kwargs: Any) -> Tuple[Tuple[Any], DictStrAny]:
"""Get original argument with which the injectable function was called"""
return injection_kwargs[_ORIGINAL_ARGS] # type: ignore


def create_resolved_partial(f: AnyFun, config: Optional[BaseConfiguration] = None) -> AnyFun:
Expand Down
3 changes: 2 additions & 1 deletion dlt/common/configuration/providers/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from dlt.common.configuration.container import Container
from dlt.common.configuration.specs import ContainerInjectableContext
from dlt.common.typing import is_subclass

from .provider import ConfigProvider

Expand All @@ -24,7 +25,7 @@ def get_value(

# only context is a valid hint
with contextlib.suppress(KeyError, TypeError):
if issubclass(hint, ContainerInjectableContext):
if is_subclass(hint, ContainerInjectableContext):
# contexts without defaults will raise ContextDefaultCannotBeCreated
return self.container[hint], hint.__name__

Expand Down
Loading
Loading