Skip to content

Commit

Permalink
makes all injection contexts thread affine, except config providers
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed Dec 9, 2023
1 parent 5102f77 commit 5f4a489
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 62 deletions.
110 changes: 93 additions & 17 deletions dlt/common/configuration/container.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from contextlib import contextmanager
from typing import Dict, Iterator, Type, TypeVar
import re
import threading
from typing import ClassVar, Dict, Iterator, Tuple, Type, TypeVar

from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext
from dlt.common.configuration.exceptions import (
Expand All @@ -16,20 +18,33 @@ class Container:
Injection context is identified by its type and available via dict indexer. The common pattern is to instantiate default context value
if it is not yet present in container.
By default, the context is thread-affine so it is visible only from the thread that originally set it. This behavior may be changed
in particular context type (spec).
The indexer is settable and allows to explicitly set the value. This is required by for context that needs to be explicitly instantiated.
The `injectable_context` allows to set a context with a `with` keyword and then restore the previous one after it gets out of scope.
"""

_INSTANCE: "Container" = None
_INSTANCE: ClassVar["Container"] = None
_LOCK: ClassVar[threading.Lock] = threading.Lock()
_MAIN_THREAD_ID: ClassVar[int] = threading.get_ident()
"""A main thread id to which get item will fallback for contexts without default"""

contexts: Dict[Type[ContainerInjectableContext], ContainerInjectableContext]
thread_contexts: Dict[int, Dict[Type[ContainerInjectableContext], ContainerInjectableContext]]
"""A thread aware mapping of injection context """
main_context: Dict[Type[ContainerInjectableContext], ContainerInjectableContext]
"""Injection context for the main thread"""

def __new__(cls: Type["Container"]) -> "Container":
if not cls._INSTANCE:
cls._INSTANCE = super().__new__(cls)
cls._INSTANCE.contexts = {}
cls._INSTANCE.thread_contexts = {}
cls._INSTANCE.main_context = cls._INSTANCE.thread_contexts[
Container._MAIN_THREAD_ID
] = {}

return cls._INSTANCE

def __init__(self) -> None:
Expand All @@ -40,48 +55,109 @@ def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration:
if not issubclass(spec, ContainerInjectableContext):
raise KeyError(f"{spec.__name__} is not a context")

item = self.contexts.get(spec)
context, item = self._thread_getitem(spec)
if item is None:
if spec.can_create_default:
item = spec()
self.contexts[spec] = item
self._thread_setitem(context, spec, item)
item.add_extras()
else:
raise ContextDefaultCannotBeCreated(spec)

return item # type: ignore
return item # type: ignore[return-value]

def __setitem__(self, spec: Type[TConfiguration], value: TConfiguration) -> None:
# value passed to container must be final
value.resolve()
# put it into context
self.contexts[spec] = value
self._thread_setitem(self._thread_context(spec), spec, value)

def __delitem__(self, spec: Type[TConfiguration]) -> None:
del self.contexts[spec]
context = self._thread_context(spec)
self._thread_delitem(context, spec)

def __contains__(self, spec: Type[TConfiguration]) -> bool:
return spec in self.contexts
context = self._thread_context(spec)
return spec in context

def _thread_context(
self, spec: Type[TConfiguration]
) -> Dict[Type[ContainerInjectableContext], ContainerInjectableContext]:
if spec.global_affinity:
context = self.main_context
else:
# thread pool names used in dlt contain originating thread id. use this id over pool id
if m := re.match(r"dlt-pool-(\d+)-", threading.currentThread().getName()):
thread_id = int(m.group(1))
else:
thread_id = threading.get_ident()
# we may add a new empty thread context so lock here
with Container._LOCK:
context = self.thread_contexts.get(thread_id)
if context is None:
context = self.thread_contexts[thread_id] = {}
return context

def _thread_getitem(
self, spec: Type[TConfiguration]
) -> Tuple[
Dict[Type[ContainerInjectableContext], ContainerInjectableContext],
ContainerInjectableContext,
]:
# with Container._LOCK:
context = self._thread_context(spec)
item = context.get(spec)
# if item is None and not spec.thread_affinity and context is not self.main_context:
# item = self.main_context.get(spec)
return context, item

def _thread_setitem(
self,
context: Dict[Type[ContainerInjectableContext], ContainerInjectableContext],
spec: Type[ContainerInjectableContext],
value: TConfiguration,
) -> None:
# with Container._LOCK:
context[spec] = value
# set the global context if spec is not thread affine
# if not spec.thread_affinity and context is not self.main_context:
# self.main_context[spec] = value

def _thread_delitem(
self,
context: Dict[Type[ContainerInjectableContext], ContainerInjectableContext],
spec: Type[ContainerInjectableContext],
) -> None:
del context[spec]

@contextmanager
def injectable_context(self, config: TConfiguration) -> Iterator[TConfiguration]:
"""A context manager that will insert `config` into the container and restore the previous value when it gets out of scope."""
config.resolve()
spec = type(config)
previous_config: ContainerInjectableContext = None
if spec in self.contexts:
previous_config = self.contexts[spec]
context, previous_config = self._thread_getitem(spec)

# set new config and yield context
self._thread_setitem(context, spec, config)
try:
self[spec] = config
yield config
finally:
# before setting the previous config for given spec, check if there was no overlapping modification
if self.contexts[spec] is config:
context, current_config = self._thread_getitem(spec)
if current_config is config:
# config is injected for spec so restore previous
if previous_config is None:
del self.contexts[spec]
self._thread_delitem(context, spec)
else:
self.contexts[spec] = previous_config
self._thread_setitem(context, spec, previous_config)
else:
# value was modified in the meantime and not restored
raise ContainerInjectableContextMangled(spec, self.contexts[spec], config)
raise ContainerInjectableContextMangled(spec, context[spec], config)

@staticmethod
def thread_pool_prefix() -> str:
"""Creates a container friendly pool prefix that contains starting thread id. Container implementation will automatically use it
for any thread-affine contexts instead of using id of the pool thread
"""
return f"dlt-pool-{threading.get_ident()}-"
18 changes: 8 additions & 10 deletions dlt/common/configuration/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
_ORIGINAL_ARGS = "_dlt_orig_args"
# keep a registry of all the decorated functions
_FUNC_SPECS: Dict[int, Type[BaseConfiguration]] = {}
_RESOLVE_LOCK = threading.Lock()

TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration)

Expand Down Expand Up @@ -146,15 +145,14 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
sections=curr_sections,
merge_style=sections_merge_style,
)
# this may be called from many threads so make sure context is not mangled
with _RESOLVE_LOCK:
with inject_section(section_context):
# print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}")
config = resolve_configuration(
config or SPEC(),
explicit_value=bound_args.arguments,
accept_partial=accept_partial,
)
# this may be called from many threads so section_context is thread affine
with inject_section(section_context):
# print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}")
config = resolve_configuration(
config or SPEC(),
explicit_value=bound_args.arguments,
accept_partial=accept_partial,
)
resolved_params = dict(config)
# overwrite or add resolved params
for p in sig.parameters.values():
Expand Down
2 changes: 2 additions & 0 deletions dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ class ContainerInjectableContext(BaseConfiguration):

can_create_default: ClassVar[bool] = True
"""If True, `Container` is allowed to create default context instance, if none exists"""
global_affinity: ClassVar[bool] = False
"""If True, `Container` will create context that will be visible in any thread. If False, per thread context is created"""

def add_extras(self) -> None:
"""Called right after context was added to the container. Benefits mostly the config provider injection context which adds extra providers using the initial ones."""
Expand Down
5 changes: 4 additions & 1 deletion dlt/common/configuration/specs/config_providers_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import io
from typing import List
from typing import ClassVar, List

from dlt.common.configuration.exceptions import DuplicateConfigProviderException
from dlt.common.configuration.providers import (
ConfigProvider,
Expand Down Expand Up @@ -34,6 +35,8 @@ class ConfigProvidersConfiguration(BaseConfiguration):
class ConfigProvidersContext(ContainerInjectableContext):
"""Injectable list of providers used by the configuration `resolve` module"""

global_affinity: ClassVar[bool] = True

providers: List[ConfigProvider]
context_provider: ConfigProvider

Expand Down
5 changes: 4 additions & 1 deletion dlt/common/runners/pool_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing_extensions import ParamSpec

from dlt.common import logger, sleep
from dlt.common.configuration.container import Container
from dlt.common.runtime import init
from dlt.common.runners.runnable import Runnable, TExecutor
from dlt.common.runners.configuration import PoolRunnerConfiguration
Expand Down Expand Up @@ -50,7 +51,9 @@ def create_pool(config: PoolRunnerConfiguration) -> Executor:
max_workers=config.workers, mp_context=multiprocessing.get_context()
)
elif config.pool_type == "thread":
return ThreadPoolExecutor(max_workers=config.workers)
return ThreadPoolExecutor(
max_workers=config.workers, thread_name_prefix=Container.thread_pool_prefix()
)
# no pool - single threaded
return NullExecutor()

Expand Down
Loading

0 comments on commit 5f4a489

Please sign in to comment.