Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allows to run parallel pipelines in separate threads #813

Merged
merged 15 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 96 additions & 17 deletions dlt/common/configuration/container.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from contextlib import contextmanager
from typing import Dict, Iterator, Type, TypeVar
import re
import threading
from typing import ClassVar, Dict, Iterator, Tuple, Type, TypeVar

from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext
from dlt.common.configuration.exceptions import (
Expand All @@ -16,20 +18,33 @@ class Container:
Injection context is identified by its type and available via dict indexer. The common pattern is to instantiate default context value
if it is not yet present in container.

By default, the context is thread-affine so it is visible only from the thread that originally set it. This behavior may be changed
in particular context type (spec).

The indexer is settable and allows the value to be set explicitly. This is required for contexts that need to be explicitly instantiated.

The `injectable_context` allows to set a context with a `with` keyword and then restore the previous one after it gets out of scope.

"""

_INSTANCE: "Container" = None
_INSTANCE: ClassVar["Container"] = None
_LOCK: ClassVar[threading.Lock] = threading.Lock()
_MAIN_THREAD_ID: ClassVar[int] = threading.get_ident()
"""A main thread id to which get item will fallback for contexts without default"""

contexts: Dict[Type[ContainerInjectableContext], ContainerInjectableContext]
thread_contexts: Dict[int, Dict[Type[ContainerInjectableContext], ContainerInjectableContext]]
"""A thread aware mapping of injection context """
main_context: Dict[Type[ContainerInjectableContext], ContainerInjectableContext]
"""Injection context for the main thread"""

def __new__(cls: Type["Container"]) -> "Container":
if not cls._INSTANCE:
cls._INSTANCE = super().__new__(cls)
cls._INSTANCE.contexts = {}
cls._INSTANCE.thread_contexts = {}
cls._INSTANCE.main_context = cls._INSTANCE.thread_contexts[
Container._MAIN_THREAD_ID
] = {}

return cls._INSTANCE

def __init__(self) -> None:
Expand All @@ -40,48 +55,112 @@ def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration:
if not issubclass(spec, ContainerInjectableContext):
raise KeyError(f"{spec.__name__} is not a context")

item = self.contexts.get(spec)
context, item = self._thread_getitem(spec)
if item is None:
if spec.can_create_default:
item = spec()
self.contexts[spec] = item
self._thread_setitem(context, spec, item)
item.add_extras()
else:
raise ContextDefaultCannotBeCreated(spec)

return item # type: ignore
return item # type: ignore[return-value]

def __setitem__(self, spec: Type[TConfiguration], value: TConfiguration) -> None:
# value passed to container must be final
value.resolve()
# put it into context
self.contexts[spec] = value
self._thread_setitem(self._thread_context(spec), spec, value)

def __delitem__(self, spec: Type[TConfiguration]) -> None:
del self.contexts[spec]
context = self._thread_context(spec)
self._thread_delitem(context, spec)

def __contains__(self, spec: Type[TConfiguration]) -> bool:
return spec in self.contexts
context = self._thread_context(spec)
return spec in context

def _thread_context(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't you use Python's thread-local data to do all this? https://docs.python.org/3/library/threading.html#thread-local-data

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I know it but when you look at the code, there are exceptions to that behavior.

  1. some types of context are available globally (I use the main thread id)
  2. there's a special treatment of the executor thread pool. I use a context of a thread that started a pool, not the current thread

so yeah I could use local() but there are exceptions so I'd need to keep more dictionaries. or you can force the thread id for local()?

self, spec: Type[TConfiguration]
) -> Dict[Type[ContainerInjectableContext], ContainerInjectableContext]:
if spec.global_affinity:
context = self.main_context
else:
# thread pool names used in dlt contain originating thread id. use this id over pool id
if m := re.match(r"dlt-pool-(\d+)-", threading.currentThread().getName()):
thread_id = int(m.group(1))
else:
thread_id = threading.get_ident()
# return main context for main thread
if thread_id == Container._MAIN_THREAD_ID:
return self.main_context
# we may add a new empty thread context so lock here
with Container._LOCK:
context = self.thread_contexts.get(thread_id)
if context is None:
context = self.thread_contexts[thread_id] = {}
return context

def _thread_getitem(
self, spec: Type[TConfiguration]
) -> Tuple[
Dict[Type[ContainerInjectableContext], ContainerInjectableContext],
ContainerInjectableContext,
]:
# with Container._LOCK:
context = self._thread_context(spec)
item = context.get(spec)
# if item is None and not spec.thread_affinity and context is not self.main_context:
# item = self.main_context.get(spec)
return context, item

def _thread_setitem(
self,
context: Dict[Type[ContainerInjectableContext], ContainerInjectableContext],
spec: Type[ContainerInjectableContext],
value: TConfiguration,
) -> None:
# with Container._LOCK:
context[spec] = value
# set the global context if spec is not thread affine
# if not spec.thread_affinity and context is not self.main_context:
# self.main_context[spec] = value

def _thread_delitem(
self,
context: Dict[Type[ContainerInjectableContext], ContainerInjectableContext],
spec: Type[ContainerInjectableContext],
) -> None:
del context[spec]

@contextmanager
def injectable_context(self, config: TConfiguration) -> Iterator[TConfiguration]:
"""A context manager that will insert `config` into the container and restore the previous value when it gets out of scope."""
config.resolve()
spec = type(config)
previous_config: ContainerInjectableContext = None
if spec in self.contexts:
previous_config = self.contexts[spec]
context, previous_config = self._thread_getitem(spec)

# set new config and yield context
self._thread_setitem(context, spec, config)
try:
self[spec] = config
yield config
finally:
# before setting the previous config for given spec, check if there was no overlapping modification
if self.contexts[spec] is config:
context, current_config = self._thread_getitem(spec)
if current_config is config:
# config is injected for spec so restore previous
if previous_config is None:
del self.contexts[spec]
self._thread_delitem(context, spec)
else:
self.contexts[spec] = previous_config
self._thread_setitem(context, spec, previous_config)
else:
# value was modified in the meantime and not restored
raise ContainerInjectableContextMangled(spec, self.contexts[spec], config)
raise ContainerInjectableContextMangled(spec, context[spec], config)

@staticmethod
def thread_pool_prefix() -> str:
    """Builds a thread pool name prefix that embeds the id of the calling thread.

    The container recognizes this prefix in the names of pool threads and, for
    thread-affine contexts, uses the embedded (originating) thread id instead of
    the id of the pool thread.
    """
    origin_thread_id = threading.get_ident()
    return f"dlt-pool-{origin_thread_id}-"
19 changes: 8 additions & 11 deletions dlt/common/configuration/inject.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import threading
from functools import wraps
from typing import Callable, Dict, Type, Any, Optional, Tuple, TypeVar, overload
from inspect import Signature, Parameter
Expand All @@ -15,7 +14,6 @@
_ORIGINAL_ARGS = "_dlt_orig_args"
# keep a registry of all the decorated functions
_FUNC_SPECS: Dict[int, Type[BaseConfiguration]] = {}
_RESOLVE_LOCK = threading.Lock()

TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration)

Expand Down Expand Up @@ -146,15 +144,14 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
sections=curr_sections,
merge_style=sections_merge_style,
)
# this may be called from many threads so make sure context is not mangled
with _RESOLVE_LOCK:
with inject_section(section_context):
# print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}")
config = resolve_configuration(
config or SPEC(),
explicit_value=bound_args.arguments,
accept_partial=accept_partial,
)
# this may be called from many threads so section_context is thread affine
with inject_section(section_context):
# print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}")
config = resolve_configuration(
config or SPEC(),
explicit_value=bound_args.arguments,
accept_partial=accept_partial,
)
resolved_params = dict(config)
# overwrite or add resolved params
for p in sig.parameters.values():
Expand Down
12 changes: 8 additions & 4 deletions dlt/common/configuration/providers/toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,14 @@ def set_value(self, key: str, value: Any, pipeline_name: str, *sections: str) ->
if k not in master:
master[k] = tomlkit.table()
master = master[k] # type: ignore
if isinstance(value, dict) and isinstance(master.get(key), dict):
update_dict_nested(master[key], value) # type: ignore
else:
master[key] = value
if isinstance(value, dict):
# remove none values, TODO: we need recursive None removal
value = {k: v for k, v in value.items() if v is not None}
# if target is also dict then merge recursively
if isinstance(master.get(key), dict):
update_dict_nested(master[key], value) # type: ignore
return
master[key] = value

@property
def supports_sections(self) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ class ContainerInjectableContext(BaseConfiguration):

can_create_default: ClassVar[bool] = True
"""If True, `Container` is allowed to create default context instance, if none exists"""
global_affinity: ClassVar[bool] = False
"""If True, `Container` will create context that will be visible in any thread. If False, per thread context is created"""

def add_extras(self) -> None:
"""Called right after context was added to the container. Benefits mostly the config provider injection context which adds extra providers using the initial ones."""
Expand Down
5 changes: 4 additions & 1 deletion dlt/common/configuration/specs/config_providers_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import io
from typing import List
from typing import ClassVar, List

from dlt.common.configuration.exceptions import DuplicateConfigProviderException
from dlt.common.configuration.providers import (
ConfigProvider,
Expand Down Expand Up @@ -34,6 +35,8 @@ class ConfigProvidersConfiguration(BaseConfiguration):
class ConfigProvidersContext(ContainerInjectableContext):
"""Injectable list of providers used by the configuration `resolve` module"""

global_affinity: ClassVar[bool] = True

providers: List[ConfigProvider]
context_provider: ConfigProvider

Expand Down
4 changes: 1 addition & 3 deletions dlt/common/configuration/specs/run_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ class RunConfiguration(BaseConfiguration):
slack_incoming_hook: Optional[TSecretStrValue] = None
dlthub_telemetry: bool = True # enable or disable dlthub telemetry
dlthub_telemetry_segment_write_key: str = "a1F2gc6cNYw2plyAt02sZouZcsRjG7TD"
log_format: str = (
"{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}"
)
log_format: str = "{asctime}|[{levelname:<21}]|{process}|{thread}|{name}|{filename}|{funcName}:{lineno}|{message}"
log_level: str = "WARNING"
request_timeout: float = 60
"""Timeout for http requests"""
Expand Down
6 changes: 4 additions & 2 deletions dlt/common/data_writers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dlt.common.data_writers.writers import DataWriter, TLoaderFileFormat
from dlt.common.data_writers.buffered import BufferedDataWriter
from dlt.common.data_writers.writers import DataWriter, DataWriterMetrics, TLoaderFileFormat
from dlt.common.data_writers.buffered import BufferedDataWriter, new_file_id
from dlt.common.data_writers.escape import (
escape_redshift_literal,
escape_redshift_identifier,
Expand All @@ -8,8 +8,10 @@

__all__ = [
"DataWriter",
"DataWriterMetrics",
"TLoaderFileFormat",
"BufferedDataWriter",
"new_file_id",
"escape_redshift_literal",
"escape_redshift_identifier",
"escape_bigquery_identifier",
Expand Down
31 changes: 24 additions & 7 deletions dlt/common/data_writers/buffered.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
import gzip
from typing import List, IO, Any, Optional, Type, TypeVar, Generic

from dlt.common.utils import uniq_id
from dlt.common.typing import TDataItem, TDataItems
from dlt.common.data_writers import TLoaderFileFormat
from dlt.common.data_writers.exceptions import (
BufferedDataWriterClosed,
DestinationCapabilitiesRequired,
InvalidFileNameTemplateException,
)
from dlt.common.data_writers.writers import DataWriter
from dlt.common.data_writers.writers import DataWriter, DataWriterMetrics
from dlt.common.schema.typing import TTableSchemaColumns
from dlt.common.configuration import with_config, known_sections, configspec
from dlt.common.configuration.specs import BaseConfiguration
from dlt.common.destination import DestinationCapabilitiesContext

from dlt.common.utils import uniq_id

TWriter = TypeVar("TWriter", bound=DataWriter)


def new_file_id() -> str:
    """Returns a new file id, obtained from `uniq_id(5)`.

    The id is meant to be globally unique within the scope of a table name.
    """
    file_id = uniq_id(5)
    return file_id


class BufferedDataWriter(Generic[TWriter]):
@configspec
class BufferedDataWriterConfiguration(BaseConfiguration):
Expand Down Expand Up @@ -49,7 +53,7 @@ def __init__(
self._caps = _caps
# validate if template has correct placeholders
self.file_name_template = file_name_template
self.closed_files: List[str] = [] # all fully processed files
self.closed_files: List[DataWriterMetrics] = [] # all fully processed files
# buffered items must be less than max items in file
self.buffer_max_items = min(buffer_max_items, file_max_items or buffer_max_items)
self.file_max_bytes = file_max_bytes
Expand Down Expand Up @@ -121,10 +125,20 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> int
return new_rows_count

def write_empty_file(self, columns: TTableSchemaColumns) -> None:
    """Writes a file that holds only the header and footer, with no items.

    Args:
        columns: Column schema to use for the file. When `None`, the columns
            currently set on the writer are left untouched.
    """
    columns_given = columns is not None
    if columns_given:
        # keep a shallow copy rather than the caller's mapping
        self._current_columns = dict(columns)
    # force a flush even though there are no buffered items
    self._flush_items(allow_empty_file=True)

def import_file(self, file_path: str, metrics: DataWriterMetrics) -> None:
    """Adds an already existing file to the files handled by this writer.

    The writer rotates to a new file name, links `file_path` under that name
    with `FileStorage.link_hard_with_fallback`, and records `metrics` in
    `closed_files` with `file_path` replaced by the new name.

    Args:
        file_path: Path of the existing file to import.
        metrics: Metrics of the imported file; its `file_path` is replaced
            before it is recorded.
    """
    # TODO: we should separate file storage from other storages. this creates circular deps
    from dlt.common.storages import FileStorage

    # rotating assigns a fresh name to self._file_name
    self._rotate_file()
    imported_name = self._file_name
    FileStorage.link_hard_with_fallback(file_path, imported_name)
    self.closed_files.append(metrics._replace(file_path=imported_name))
    # nothing is open under this name, so clear the current file name
    self._file_name = None

def close(self) -> None:
self._ensure_open()
self._flush_and_close_file()
Expand All @@ -143,7 +157,7 @@ def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb
def _rotate_file(self) -> None:
self._flush_and_close_file()
self._file_name = (
self.file_name_template % uniq_id(5) + "." + self._file_format_spec.file_extension
self.file_name_template % new_file_id() + "." + self._file_format_spec.file_extension
)

def _flush_items(self, allow_empty_file: bool = False) -> None:
Expand Down Expand Up @@ -171,9 +185,12 @@ def _flush_and_close_file(self) -> None:
if self._writer:
# write the footer of a file
self._writer.write_footer()
self._file.close()
self._file.flush()
# add file written to the list so we can commit all the files later
self.closed_files.append(self._file_name)
self.closed_files.append(
DataWriterMetrics(self._file_name, self._writer.items_count, self._file.tell())
)
self._file.close()
self._writer = None
self._file = None
self._file_name = None
Expand Down
8 changes: 7 additions & 1 deletion dlt/common/data_writers/writers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from dataclasses import dataclass
from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union
from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, NamedTuple

from dlt.common import json
from dlt.common.configuration import configspec, known_sections, with_config
Expand All @@ -23,6 +23,12 @@ class TFileFormatSpec:
supports_compression: bool = False


class DataWriterMetrics(NamedTuple):
    """Metrics recorded for a single file handled by a data writer."""

    # path of the file the metrics describe
    file_path: str
    # number of items written to the file
    items_count: int
    # size of the file as reported when the file was closed
    file_size: int
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe column count? but that is not really important tbh.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan to add elapsed time (start stop). Column count is not known at this moment

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not during extract. but it is known during normalize. you can however get the column count from the relevant schema...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes elapsed would be cool too!



class DataWriter(abc.ABC):
def __init__(self, f: IO[Any], caps: DestinationCapabilitiesContext = None) -> None:
self._f = f
Expand Down
Loading