diff --git a/dlt/common/configuration/container.py b/dlt/common/configuration/container.py index c410d18dd9..57f2121f18 100644 --- a/dlt/common/configuration/container.py +++ b/dlt/common/configuration/container.py @@ -1,5 +1,7 @@ from contextlib import contextmanager -from typing import Dict, Iterator, Type, TypeVar +import re +import threading +from typing import ClassVar, Dict, Iterator, Tuple, Type, TypeVar from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext from dlt.common.configuration.exceptions import ( @@ -16,20 +18,33 @@ class Container: Injection context is identified by its type and available via dict indexer. The common pattern is to instantiate default context value if it is not yet present in container. + By default, the context is thread-affine so it is visible only from the thread that originally set it. This behavior may be changed + in particular context type (spec). + The indexer is settable and allows to explicitly set the value. This is required by for context that needs to be explicitly instantiated. The `injectable_context` allows to set a context with a `with` keyword and then restore the previous one after it gets out of scope. """ - _INSTANCE: "Container" = None + _INSTANCE: ClassVar["Container"] = None + _LOCK: ClassVar[threading.Lock] = threading.Lock() + _MAIN_THREAD_ID: ClassVar[int] = threading.get_ident() + """A main thread id to which get item will fallback for contexts without default""" - contexts: Dict[Type[ContainerInjectableContext], ContainerInjectableContext] + thread_contexts: Dict[int, Dict[Type[ContainerInjectableContext], ContainerInjectableContext]] + """A thread aware mapping of injection context """ + main_context: Dict[Type[ContainerInjectableContext], ContainerInjectableContext] + """Injection context for the main thread""" def __new__(cls: Type["Container"]) -> "Container": if not cls._INSTANCE: cls._INSTANCE = super().__new__(cls) - cls._INSTANCE.contexts = {} + cls._INSTANCE.thread_contexts = {} + cls._INSTANCE.main_context = cls._INSTANCE.thread_contexts[ + Container._MAIN_THREAD_ID + ] = {} + return cls._INSTANCE def __init__(self) -> None: @@ -40,48 +55,112 @@ def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration: if not issubclass(spec, ContainerInjectableContext): raise KeyError(f"{spec.__name__} is not a context") - item = self.contexts.get(spec) + context, item = self._thread_getitem(spec) if item is None: if spec.can_create_default: item = spec() - self.contexts[spec] = item + self._thread_setitem(context, spec, item) item.add_extras() else: raise ContextDefaultCannotBeCreated(spec) - return item # type: ignore + return item # type: ignore[return-value] def __setitem__(self, spec: Type[TConfiguration], value: TConfiguration) -> None: # value passed to container must be final value.resolve() # put it into context - self.contexts[spec] = value + self._thread_setitem(self._thread_context(spec), spec, value) def __delitem__(self, spec: Type[TConfiguration]) -> None: - del self.contexts[spec] + context = self._thread_context(spec) + self._thread_delitem(context, spec) def __contains__(self, spec: Type[TConfiguration]) -> bool: - return spec in self.contexts + context = self._thread_context(spec) + return spec in context + + def _thread_context( + self, spec: Type[TConfiguration] + ) -> Dict[Type[ContainerInjectableContext], ContainerInjectableContext]: + if spec.global_affinity: + context = self.main_context + else: + # thread pool names used in dlt contain originating thread id. use this id over pool id + if m := re.match(r"dlt-pool-(\d+)-", threading.currentThread().getName()): + thread_id = int(m.group(1)) + else: + thread_id = threading.get_ident() + # return main context for main thread + if thread_id == Container._MAIN_THREAD_ID: + return self.main_context + # we may add a new empty thread context so lock here + with Container._LOCK: + context = self.thread_contexts.get(thread_id) + if context is None: + context = self.thread_contexts[thread_id] = {} + return context + + def _thread_getitem( + self, spec: Type[TConfiguration] + ) -> Tuple[ + Dict[Type[ContainerInjectableContext], ContainerInjectableContext], + ContainerInjectableContext, + ]: + # with Container._LOCK: + context = self._thread_context(spec) + item = context.get(spec) + # if item is None and not spec.thread_affinity and context is not self.main_context: + # item = self.main_context.get(spec) + return context, item + + def _thread_setitem( + self, + context: Dict[Type[ContainerInjectableContext], ContainerInjectableContext], + spec: Type[ContainerInjectableContext], + value: TConfiguration, + ) -> None: + # with Container._LOCK: + context[spec] = value + # set the global context if spec is not thread affine + # if not spec.thread_affinity and context is not self.main_context: + # self.main_context[spec] = value + + def _thread_delitem( + self, + context: Dict[Type[ContainerInjectableContext], ContainerInjectableContext], + spec: Type[ContainerInjectableContext], + ) -> None: + del context[spec] @contextmanager def injectable_context(self, config: TConfiguration) -> Iterator[TConfiguration]: """A context manager that will insert `config` into the container and restore the previous value when it gets out of scope.""" + config.resolve() spec = type(config) previous_config: ContainerInjectableContext = None - if spec in self.contexts: - previous_config = self.contexts[spec] + context, previous_config = self._thread_getitem(spec) + # set new config and yield context + self._thread_setitem(context, spec, config) try: - self[spec] = config yield config finally: # before setting the previous config for given spec, check if there was no overlapping modification - if self.contexts[spec] is config: + context, current_config = self._thread_getitem(spec) + if current_config is config: # config is injected for spec so restore previous if previous_config is None: - del self.contexts[spec] + self._thread_delitem(context, spec) else: - self.contexts[spec] = previous_config + self._thread_setitem(context, spec, previous_config) else: # value was modified in the meantime and not restored - raise ContainerInjectableContextMangled(spec, self.contexts[spec], config) + raise ContainerInjectableContextMangled(spec, context[spec], config) + + @staticmethod + def thread_pool_prefix() -> str: + """Creates a container friendly pool prefix that contains starting thread id. Container implementation will automatically use it + for any thread-affine contexts instead of using id of the pool thread + """ + return f"dlt-pool-{threading.get_ident()}-" diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index 6478c3258c..a22f299ae8 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -1,5 +1,4 @@ import inspect -import threading from functools import wraps from typing import Callable, Dict, Type, Any, Optional, Tuple, TypeVar, overload from inspect import Signature, Parameter @@ -15,7 +14,6 @@ _ORIGINAL_ARGS = "_dlt_orig_args" # keep a registry of all the decorated functions _FUNC_SPECS: Dict[int, Type[BaseConfiguration]] = {} -_RESOLVE_LOCK = threading.Lock() TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) @@ -146,15 +144,14 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: sections=curr_sections, merge_style=sections_merge_style, ) - # this may be called from many threads so make sure context is not mangled - with _RESOLVE_LOCK: - with inject_section(section_context): - # print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}") - config = resolve_configuration( - config or SPEC(), - explicit_value=bound_args.arguments, - accept_partial=accept_partial, - ) + # this may be called from many threads so section_context is thread affine + with inject_section(section_context): + # print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}") + config = resolve_configuration( + config or SPEC(), + explicit_value=bound_args.arguments, + accept_partial=accept_partial, + ) resolved_params = dict(config) # overwrite or add resolved params for p in sig.parameters.values(): diff --git a/dlt/common/configuration/providers/toml.py b/dlt/common/configuration/providers/toml.py index 3c4fa2c145..7c856e8c27 100644 --- a/dlt/common/configuration/providers/toml.py +++ b/dlt/common/configuration/providers/toml.py @@ -72,10 +72,14 @@ def set_value(self, key: str, value: Any, pipeline_name: str, *sections: str) -> if k not in master: master[k] = tomlkit.table() master = master[k] # type: ignore - if isinstance(value, dict) and isinstance(master.get(key), dict): - update_dict_nested(master[key], value) # type: ignore - else: - master[key] = value + if isinstance(value, dict): + # remove none values, TODO: we need recursive None removal + value = {k: v for k, v in value.items() if v is not None} + # if target is also dict then merge recursively + if isinstance(master.get(key), dict): + update_dict_nested(master[key], value) # type: ignore + return + master[key] = value @property def supports_sections(self) -> bool: diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index c96ac1c4fc..f526ec0841 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -395,6 +395,8 @@ class ContainerInjectableContext(BaseConfiguration): can_create_default: ClassVar[bool] = True """If True, `Container` is allowed to create default context instance, if none exists""" + global_affinity: ClassVar[bool] = False + """If True, `Container` will create context that will be visible in any thread. If False, per thread context is created""" def add_extras(self) -> None: """Called right after context was added to the container. Benefits mostly the config provider injection context which adds extra providers using the initial ones.""" diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py index 0c852edfa5..860e7414de 100644 --- a/dlt/common/configuration/specs/config_providers_context.py +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -1,6 +1,7 @@ import contextlib import io -from typing import List +from typing import ClassVar, List + from dlt.common.configuration.exceptions import DuplicateConfigProviderException from dlt.common.configuration.providers import ( ConfigProvider, @@ -34,6 +35,8 @@ class ConfigProvidersConfiguration(BaseConfiguration): class ConfigProvidersContext(ContainerInjectableContext): """Injectable list of providers used by the configuration `resolve` module""" + global_affinity: ClassVar[bool] = True + providers: List[ConfigProvider] context_provider: ConfigProvider diff --git a/dlt/common/configuration/specs/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py index 78cca1fbad..54ce46ceba 100644 --- a/dlt/common/configuration/specs/run_configuration.py +++ b/dlt/common/configuration/specs/run_configuration.py @@ -16,9 +16,7 @@ class RunConfiguration(BaseConfiguration): slack_incoming_hook: Optional[TSecretStrValue] = None dlthub_telemetry: bool = True # enable or disable dlthub telemetry dlthub_telemetry_segment_write_key: str = "a1F2gc6cNYw2plyAt02sZouZcsRjG7TD" - log_format: str = ( - "{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}" - ) + log_format: str = "{asctime}|[{levelname:<21}]|{process}|{thread}|{name}|{filename}|{funcName}:{lineno}|{message}" log_level: str = "WARNING" request_timeout: float = 60 """Timeout for http requests""" diff --git a/dlt/common/data_writers/__init__.py b/dlt/common/data_writers/__init__.py index fefe2d6486..04c5d04328 100644 --- a/dlt/common/data_writers/__init__.py +++ b/dlt/common/data_writers/__init__.py @@ -1,5 +1,5 @@ -from dlt.common.data_writers.writers import DataWriter, TLoaderFileFormat -from dlt.common.data_writers.buffered import BufferedDataWriter +from dlt.common.data_writers.writers import DataWriter, DataWriterMetrics, TLoaderFileFormat +from dlt.common.data_writers.buffered import BufferedDataWriter, new_file_id from dlt.common.data_writers.escape import ( escape_redshift_literal, escape_redshift_identifier, @@ -8,8 +8,10 @@ __all__ = [ "DataWriter", + "DataWriterMetrics", "TLoaderFileFormat", "BufferedDataWriter", + "new_file_id", "escape_redshift_literal", "escape_redshift_identifier", "escape_bigquery_identifier", diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index 1c95daa979..d8ba8b9075 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -1,7 +1,6 @@ import gzip from typing import List, IO, Any, Optional, Type, TypeVar, Generic -from dlt.common.utils import uniq_id from dlt.common.typing import TDataItem, TDataItems from dlt.common.data_writers import TLoaderFileFormat from dlt.common.data_writers.exceptions import ( @@ -9,16 +8,21 @@ DestinationCapabilitiesRequired, InvalidFileNameTemplateException, ) -from dlt.common.data_writers.writers import DataWriter +from dlt.common.data_writers.writers import DataWriter, DataWriterMetrics from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.configuration import with_config, known_sections, configspec from dlt.common.configuration.specs import BaseConfiguration from dlt.common.destination import DestinationCapabilitiesContext - +from dlt.common.utils import uniq_id TWriter = TypeVar("TWriter", bound=DataWriter) +def new_file_id() -> str: + """Creates new file id which is globally unique within table_name scope""" + return uniq_id(5) + + class BufferedDataWriter(Generic[TWriter]): @configspec class BufferedDataWriterConfiguration(BaseConfiguration): @@ -49,7 +53,7 @@ def __init__( self._caps = _caps # validate if template has correct placeholders self.file_name_template = file_name_template - self.closed_files: List[str] = [] # all fully processed files + self.closed_files: List[DataWriterMetrics] = [] # all fully processed files # buffered items must be less than max items in file self.buffer_max_items = min(buffer_max_items, file_max_items or buffer_max_items) self.file_max_bytes = file_max_bytes @@ -121,10 +125,20 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> int return new_rows_count def write_empty_file(self, columns: TTableSchemaColumns) -> None: + """Writes empty file: only header and footer without actual items""" if columns is not None: self._current_columns = dict(columns) self._flush_items(allow_empty_file=True) + def import_file(self, file_path: str, metrics: DataWriterMetrics) -> None: + # TODO: we should separate file storage from other storages. this creates circular deps + from dlt.common.storages import FileStorage + + self._rotate_file() + FileStorage.link_hard_with_fallback(file_path, self._file_name) + self.closed_files.append(metrics._replace(file_path=self._file_name)) + self._file_name = None + def close(self) -> None: self._ensure_open() self._flush_and_close_file() @@ -143,7 +157,7 @@ def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb def _rotate_file(self) -> None: self._flush_and_close_file() self._file_name = ( - self.file_name_template % uniq_id(5) + "." + self._file_format_spec.file_extension + self.file_name_template % new_file_id() + "." + self._file_format_spec.file_extension ) def _flush_items(self, allow_empty_file: bool = False) -> None: @@ -171,9 +185,12 @@ def _flush_and_close_file(self) -> None: if self._writer: # write the footer of a file self._writer.write_footer() - self._file.close() + self._file.flush() # add file written to the list so we can commit all the files later - self.closed_files.append(self._file_name) + self.closed_files.append( + DataWriterMetrics(self._file_name, self._writer.items_count, self._file.tell()) + ) + self._file.close() self._writer = None self._file = None self._file_name = None diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 2801656dc3..b0030951a8 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -1,6 +1,6 @@ import abc from dataclasses import dataclass -from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union +from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, NamedTuple from dlt.common import json from dlt.common.configuration import configspec, known_sections, with_config @@ -23,6 +23,12 @@ class TFileFormatSpec: supports_compression: bool = False +class DataWriterMetrics(NamedTuple): + file_path: str + items_count: int + file_size: int + + class DataWriter(abc.ABC): def __init__(self, f: IO[Any], caps: DestinationCapabilitiesContext = None) -> None: self._f = f diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 6665feff5d..07b8871a85 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -227,7 +227,7 @@ def file_name(self) -> str: return self._file_name def job_id(self) -> str: - """The job id that is derived from the file name""" + """The job id that is derived from the file name and does not changes during job lifecycle""" return self._parsed_file_name.job_id() def job_file_info(self) -> ParsedLoadJobFileName: diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 6feee2a812..9f12adf3a7 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -400,12 +400,16 @@ class PipelineContext(ContainerInjectableContext): _deferred_pipeline: Callable[[], SupportsPipeline] _pipeline: SupportsPipeline - can_create_default: ClassVar[bool] = False + can_create_default: ClassVar[bool] = True def pipeline(self) -> SupportsPipeline: """Creates or returns exiting pipeline""" if not self._pipeline: # delayed pipeline creation + assert self._deferred_pipeline is not None, ( + "Deferred pipeline creation function not provided to PipelineContext. Are you" + " calling dlt.pipeline() from another thread?" + ) self.activate(self._deferred_pipeline()) return self._pipeline @@ -425,7 +429,7 @@ def deactivate(self) -> None: self._pipeline._set_context(False) self._pipeline = None - def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline]) -> None: + def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline] = None) -> None: """Initialize the context with a function returning the Pipeline object to allow creation on first use""" self._deferred_pipeline = deferred_pipeline diff --git a/dlt/common/runners/configuration.py b/dlt/common/runners/configuration.py index 6953c72cf1..c5de2353f4 100644 --- a/dlt/common/runners/configuration.py +++ b/dlt/common/runners/configuration.py @@ -8,10 +8,21 @@ @configspec class PoolRunnerConfiguration(BaseConfiguration): - pool_type: TPoolType = None # type of pool to run, must be set in derived configs - workers: Optional[int] = None # how many threads/processes in the pool - run_sleep: float = 0.1 # how long to sleep between runs with workload, seconds + pool_type: TPoolType = None + """type of pool to run, must be set in derived configs""" + start_method: Optional[str] = None + """start method for the pool (typically process). None is system default""" + workers: Optional[int] = None + """# how many threads/processes in the pool""" + run_sleep: float = 0.1 + """how long to sleep between runs with workload, seconds""" if TYPE_CHECKING: - def __init__(self, pool_type: TPoolType = None, workers: int = None) -> None: ... + def __init__( + self, + pool_type: TPoolType = None, + start_method: str = None, + workers: int = None, + run_sleep: float = 0.1, + ) -> None: ... diff --git a/dlt/common/runners/pool_runner.py b/dlt/common/runners/pool_runner.py index 31a809dc9c..491c74cd18 100644 --- a/dlt/common/runners/pool_runner.py +++ b/dlt/common/runners/pool_runner.py @@ -5,6 +5,7 @@ from typing_extensions import ParamSpec from dlt.common import logger, sleep +from dlt.common.configuration.container import Container from dlt.common.runtime import init from dlt.common.runners.runnable import Runnable, TExecutor from dlt.common.runners.configuration import PoolRunnerConfiguration @@ -38,19 +39,22 @@ def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Futur def create_pool(config: PoolRunnerConfiguration) -> Executor: if config.pool_type == "process": # if not fork method, provide initializer for logs and configuration - if multiprocessing.get_start_method() != "fork" and init._INITIALIZED: + start_method = config.start_method or multiprocessing.get_start_method() + if start_method != "fork" and init._INITIALIZED: return ProcessPoolExecutor( max_workers=config.workers, initializer=init.initialize_runtime, initargs=(init._RUN_CONFIGURATION,), - mp_context=multiprocessing.get_context(), + mp_context=multiprocessing.get_context(method=start_method), ) else: return ProcessPoolExecutor( max_workers=config.workers, mp_context=multiprocessing.get_context() ) elif config.pool_type == "thread": - return ThreadPoolExecutor(max_workers=config.workers) + return ThreadPoolExecutor( + max_workers=config.workers, thread_name_prefix=Container.thread_pool_prefix() + ) # no pool - single threaded return NullExecutor() diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index 04e5302794..a338c7086f 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -3,9 +3,10 @@ from abc import ABC, abstractmethod from dlt.common import logger +from dlt.common.destination import TLoaderFileFormat from dlt.common.schema import TTableSchemaColumns from dlt.common.typing import StrAny, TDataItems -from dlt.common.data_writers import TLoaderFileFormat, BufferedDataWriter, DataWriter +from dlt.common.data_writers import BufferedDataWriter, DataWriter, DataWriterMetrics class DataItemStorage(ABC): @@ -39,12 +40,25 @@ def write_data_item( # write item(s) return writer.write_data_item(item, columns) - def write_empty_file( + def write_empty_items_file( self, load_id: str, schema_name: str, table_name: str, columns: TTableSchemaColumns ) -> None: + """Writes empty file: only header and footer without actual items""" writer = self.get_writer(load_id, schema_name, table_name) writer.write_empty_file(columns) + def import_items_file( + self, + load_id: str, + schema_name: str, + table_name: str, + file_path: str, + metrics: DataWriterMetrics, + ) -> None: + """Imports external file from `file_path`. Requires external metrics to be passed as internal data writer is not used.""" + writer = self.get_writer(load_id, schema_name, table_name) + writer.import_file(file_path, metrics) + def close_writers(self, load_id: str) -> None: # flush and close all files for name, writer in self.buffered_writers.items(): @@ -55,8 +69,8 @@ def close_writers(self, load_id: str) -> None: ) writer.close() - def closed_files(self) -> List[str]: - files: List[str] = [] + def closed_files(self) -> List[DataWriterMetrics]: + files: List[DataWriterMetrics] = [] for writer in self.buffered_writers.values(): files.extend(writer.closed_files) diff --git a/dlt/common/storages/exceptions.py b/dlt/common/storages/exceptions.py index 30f372e692..22d6dfaf79 100644 --- a/dlt/common/storages/exceptions.py +++ b/dlt/common/storages/exceptions.py @@ -2,7 +2,7 @@ from typing import Iterable from dlt.common.exceptions import DltException, TerminalValueError -from dlt.common.data_writers import TLoaderFileFormat +from dlt.common.destination import TLoaderFileFormat class StorageException(DltException): diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 2fa0ad6713..889a7c64db 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -19,7 +19,7 @@ ) from dlt.common import pendulum, json -from dlt.common.data_writers.writers import DataWriter +from dlt.common.data_writers import DataWriter, new_file_id from dlt.common.destination import TLoaderFileFormat from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema, TSchemaTables @@ -36,14 +36,28 @@ class ParsedLoadJobFileName(NamedTuple): + """Represents a file name of a job in load package. The file name contains name of a table, number of times the job was retired, extension + and a 5 bytes random string to make job file name unique. + The job id does not contain retry count and is immutable during loading of the data + """ + table_name: str file_id: str retry_count: int file_format: TLoaderFileFormat def job_id(self) -> str: + """Unique identifier of the job""" + return f"{self.table_name}.{self.file_id}.{self.file_format}" + + def file_name(self) -> str: + """A name of the file with the data to be loaded""" return f"{self.table_name}.{self.file_id}.{int(self.retry_count)}.{self.file_format}" + def with_retry(self) -> "ParsedLoadJobFileName": + """Returns a job with increased retry count""" + return self._replace(retry_count=self.retry_count + 1) + @staticmethod def parse(file_name: str) -> "ParsedLoadJobFileName": p = Path(file_name) @@ -55,6 +69,10 @@ def parse(file_name: str) -> "ParsedLoadJobFileName": parts[0], parts[1], int(parts[2]), cast(TLoaderFileFormat, parts[3]) ) + @staticmethod + def new_file_id() -> str: + return new_file_id() + def __str__(self) -> str: return self.job_id() @@ -288,19 +306,14 @@ def fail_job(self, load_id: str, file_name: str, failed_message: Optional[str]) def retry_job(self, load_id: str, file_name: str) -> str: # when retrying job we must increase the retry count source_fn = ParsedLoadJobFileName.parse(file_name) - dest_fn = ParsedLoadJobFileName( - source_fn.table_name, - source_fn.file_id, - source_fn.retry_count + 1, - source_fn.file_format, - ) + dest_fn = source_fn.with_retry() # move it directly to new file name return self._move_job( load_id, PackageStorage.STARTED_JOBS_FOLDER, PackageStorage.NEW_JOBS_FOLDER, file_name, - dest_fn.job_id(), + dest_fn.file_name(), ) def complete_job(self, load_id: str, file_name: str) -> str: diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index bd4819462b..fa4f5f0419 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -120,8 +120,8 @@ def state(self) -> TLoadJobState: else: return "running" - def job_id(self) -> str: - return BigQueryLoadJob.get_job_id_from_file_path(super().job_id()) + def bigquery_job_id(self) -> str: + return BigQueryLoadJob.get_job_id_from_file_path(super().file_name()) def exception(self) -> str: exception: str = json.dumps( diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 5ed9e9ce2f..d97a098669 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -37,7 +37,9 @@ def from_table_chain( """ params = cast(SqlJobParams, {**DEFAULTS, **(params or {})}) # type: ignore top_table = table_chain[0] - file_info = ParsedLoadJobFileName(top_table["name"], uniq_id()[:10], 0, "sql") + file_info = ParsedLoadJobFileName( + top_table["name"], ParsedLoadJobFileName.new_file_id(), 0, "sql" + ) try: # Remove line breaks from multiline statements and write one SQL statement per line in output file # to support clients that need to execute one statement at a time (i.e. snowflake) @@ -45,14 +47,14 @@ def from_table_chain( " ".join(stmt.splitlines()) for stmt in cls.generate_sql(table_chain, sql_client, params) ] - job = cls(file_info.job_id(), "running") + job = cls(file_info.file_name(), "running") job._save_text_file("\n".join(sql)) except Exception: # return failed job tables_str = yaml.dump( table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False ) - job = cls(file_info.job_id(), "failed", pretty_format_exception()) + job = cls(file_info.file_name(), "failed", pretty_format_exception()) job._save_text_file("\n".join([cls.failed_text, tables_str])) return job diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index 9dd9cca5b7..3bd7ad9bbb 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -234,7 +234,7 @@ def _extract_single_source( for table in tables_by_resources[resource.name]: # we only need to write empty files for the top tables if not table.get("parent", None): - extractors["puae-jsonl"].write_empty_file(table["name"]) + extractors["puae-jsonl"].write_empty_items_file(table["name"]) if left_gens > 0: # go to 100% diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py index 1f90ef5dc8..f16688a515 100644 --- a/dlt/extract/extractors.py +++ b/dlt/extract/extractors.py @@ -92,9 +92,9 @@ def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> No # table has name or other hints depending on data items self._write_to_dynamic_table(resource, items) - def write_empty_file(self, table_name: str) -> None: + def write_empty_items_file(self, table_name: str) -> None: table_name = self.naming.normalize_table_identifier(table_name) - self.storage.write_empty_file(self.load_id, self.schema.name, table_name, None) + self.storage.write_empty_items_file(self.load_id, self.schema.name, table_name, None) def _get_static_table_name(self, resource: DltResource, meta: Any) -> Optional[str]: if resource._table_name_hint_fun: diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py index 85c654b46c..d3b8725cc7 100644 --- a/dlt/extract/pipe.py +++ b/dlt/extract/pipe.py @@ -725,7 +725,7 @@ def start_background_loop(loop: asyncio.AbstractEventLoop) -> None: target=start_background_loop, args=(self._async_pool,), daemon=True, - name="DltFuturesThread", + name=Container.thread_pool_prefix() + "futures", ) self._async_pool_thread.start() @@ -737,7 +737,9 @@ def _ensure_thread_pool(self) -> ThreadPoolExecutor: if self._thread_pool: return self._thread_pool - self._thread_pool = ThreadPoolExecutor(self.workers) + self._thread_pool = ThreadPoolExecutor( + self.workers, thread_name_prefix=Container.thread_pool_prefix() + "threads" + ) return self._thread_pool def __enter__(self) -> "PipeIterator": diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index 5312b24c80..9247fecfd3 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -1,18 +1,16 @@ -import os from typing import List, Dict, Set, Tuple, Any from abc import abstractmethod from dlt.common import json, logger +from dlt.common.data_writers import DataWriterMetrics from dlt.common.json import custom_pua_decode, may_have_pua from dlt.common.runtime import signals from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns, TSchemaContractDict from dlt.common.storages import ( NormalizeStorage, LoadStorage, - FileStorage, - PackageStorage, - ParsedLoadJobFileName, ) +from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.typing import DictStrAny, TDataItem from dlt.common.schema import TSchemaUpdate, Schema from dlt.common.utils import RowCounts, merge_row_counts, increase_row_count @@ -337,19 +335,19 @@ def __call__( with self.normalize_storage.extracted_packages.storage.open_file( extracted_items_file, "rb" ) as f: - items_count = get_row_count(f) + file_metrics = DataWriterMetrics(extracted_items_file, get_row_count(f), f.tell()) parts = ParsedLoadJobFileName.parse(extracted_items_file) - new_file_name = PackageStorage.build_job_file_name( - parts.table_name, parts.file_id, loader_file_format=self.load_storage.loader_file_format - ) - target_file_path = self.load_storage.new_packages.storage.make_full_path( - self.load_storage.new_packages.get_job_file_path( - self.load_id, PackageStorage.NEW_JOBS_FOLDER, new_file_name - ) - ) - FileStorage.link_hard_with_fallback( + self.load_storage.import_items_file( + self.load_id, + self.schema.name, + parts.table_name, self.normalize_storage.extracted_packages.storage.make_full_path(extracted_items_file), - target_file_path, + file_metrics, + ) + + return ( + base_schema_update, + file_metrics.items_count, + {root_table_name: file_metrics.items_count}, ) - return base_schema_update, items_count, {root_table_name: items_count} diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index e03278acbf..7b1b06a0ec 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -1,12 +1,13 @@ import os import datetime # noqa: 251 -from typing import Callable, List, Dict, Sequence, Tuple, Set, Optional +from typing import Callable, List, Dict, NamedTuple, Sequence, Tuple, Set, Optional from concurrent.futures import Future, Executor from dlt.common import logger, sleep from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config from dlt.common.configuration.container import Container +from dlt.common.data_writers import DataWriterMetrics from dlt.common.destination import TLoaderFileFormat from dlt.common.runners import TRunMetrics, Runnable, NullExecutor from dlt.common.runtime import signals @@ -40,14 +41,18 @@ ItemsNormalizer, ) -# normalize worker wrapping function (map_parallel, map_single) return type -TMapFuncRV = Tuple[Sequence[TSchemaUpdate], RowCounts] + +class TWorkerRV(NamedTuple): + schema_updates: List[TSchemaUpdate] + total_items: int + file_metrics: List[DataWriterMetrics] + row_counts: RowCounts + + # normalize worker wrapping function signature TMapFuncType = Callable[ - [Schema, str, Sequence[str]], TMapFuncRV + [Schema, str, Sequence[str]], TWorkerRV ] # input parameters: (schema name, load_id, list of files to process) -# tuple returned by the worker -TWorkerRV = Tuple[List[TSchemaUpdate], int, List[str], RowCounts] class Normalize(Runnable[Executor], WithStepInfo[NormalizeMetrics, NormalizeInfo]): @@ -101,9 +106,10 @@ def w_normalize_files( schema_updates: List[TSchemaUpdate] = [] total_items = 0 row_counts: RowCounts = {} - load_storages: Dict[TLoaderFileFormat, LoadStorage] = {} + item_normalizers: Dict[TLoaderFileFormat, ItemsNormalizer] = {} - def _get_load_storage(file_format: TLoaderFileFormat) -> LoadStorage: + def _create_load_storage(file_format: TLoaderFileFormat) -> LoadStorage: + """Creates a load storage for particular file_format""" # TODO: capabilities.supported_*_formats can be None, it should have defaults supported_formats = destination_caps.supported_loader_file_formats or [] if file_format == "parquet": @@ -123,34 +129,21 @@ def _get_load_storage(file_format: TLoaderFileFormat) -> LoadStorage: destination_caps.preferred_loader_file_format or destination_caps.preferred_staging_file_format ) - if storage := load_storages.get(file_format): - return storage - storage = load_storages[file_format] = LoadStorage( - False, file_format, supported_formats, loader_storage_config - ) - return storage + return LoadStorage(False, file_format, supported_formats, loader_storage_config) # process all files with data items and write to buffered item storage with Container().injectable_context(destination_caps): schema = Schema.from_stored_schema(stored_schema) - load_storage = _get_load_storage( - destination_caps.preferred_loader_file_format - ) # Default load storage, used for empty tables when no data normalize_storage = NormalizeStorage(False, normalize_storage_config) - item_normalizers: Dict[TLoaderFileFormat, ItemsNormalizer] = {} - - def _get_items_normalizer( - file_format: TLoaderFileFormat, - ) -> Tuple[ItemsNormalizer, LoadStorage]: - load_storage = _get_load_storage(file_format) + def _get_items_normalizer(file_format: TLoaderFileFormat) -> ItemsNormalizer: if file_format in item_normalizers: - return item_normalizers[file_format], load_storage + return item_normalizers[file_format] klass = ParquetItemsNormalizer if file_format == "parquet" else JsonLItemsNormalizer norm = item_normalizers[file_format] = klass( - load_storage, normalize_storage, schema, load_id, config + _create_load_storage(file_format), normalize_storage, schema, load_id, config ) - return norm, load_storage + return norm try: root_tables: Set[str] = set() @@ -164,13 +157,11 @@ def _get_items_normalizer( parsed_file_name.table_name ) root_tables.add(root_table_name) + normalizer = _get_items_normalizer(parsed_file_name.file_format) logger.debug( f"Processing extracted items in {extracted_items_file} in load_id" f" {load_id} with table name {root_table_name} and schema {schema.name}" ) - - file_format = parsed_file_name.file_format - normalizer, load_storage = _get_items_normalizer(file_format) partial_updates, items_count, r_counts = normalizer( extracted_items_file, root_table_name ) @@ -186,12 +177,18 @@ def _get_items_normalizer( # make sure base tables are all covered increase_row_count(row_counts, root_table_name, 0) # write empty jobs for tables without items if table exists in schema - for table_name in root_tables - populated_root_tables: - if table_name not in schema.tables: - continue - logger.debug(f"Writing empty job for table {table_name}") - columns = schema.get_table_columns(table_name) - load_storage.write_empty_file(load_id, schema.name, table_name, columns) + empty_tables = root_tables - populated_root_tables + if empty_tables: + # get first normalizer created + normalizer = list(item_normalizers.values())[0] + for table_name in root_tables - populated_root_tables: + if table_name not in schema.tables: + continue + logger.debug(f"Writing empty job for table {table_name}") + columns = schema.get_table_columns(table_name) + normalizer.load_storage.write_empty_items_file( + load_id, schema.name, table_name, columns + ) except Exception: # TODO: raise a wrapper exception with job_id, load_id, line_no and schema name logger.exception( @@ -199,10 +196,14 @@ def _get_items_normalizer( ) raise finally: - load_storage.close_writers(load_id) + for normalizer in item_normalizers.values(): + normalizer.load_storage.close_writers(load_id) logger.info(f"Processed total {total_items} items in {len(extracted_items_files)} files") - return schema_updates, total_items, load_storage.closed_files(), row_counts + writer_metrics: List[DataWriterMetrics] = [] + for normalizer in item_normalizers.values(): + writer_metrics.extend(normalizer.load_storage.closed_files()) + return TWorkerRV(schema_updates, total_items, writer_metrics, row_counts) def update_table(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None: for schema_update in schema_updates: @@ -231,7 +232,7 @@ def group_worker_files(files: Sequence[str], no_groups: int) -> List[Sequence[st l_idx = idx + 1 return chunk_files - def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMapFuncRV: + def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV: workers: int = getattr(self.pool, "_max_workers", 1) chunk_files = self.group_worker_files(files, workers) schema_dict: TStoredSchema = schema.to_dict() @@ -246,11 +247,8 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TM ) for files in chunk_files ] - row_counts: RowCounts = {} - # return stats - schema_updates: List[TSchemaUpdate] = [] - + summary = TWorkerRV([], 0, [], {}) # push all tasks to queue tasks = [ (self.pool.submit(Normalize.w_normalize_files, *params), params) @@ -269,20 +267,20 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TM try: # gather schema from all manifests, validate consistency and combine self.update_table(schema, result[0]) - schema_updates.extend(result[0]) + summary.schema_updates.extend(result.schema_updates) # update metrics - self.collector.update("Files", len(result[2])) - self.collector.update("Items", result[1]) + self.collector.update("Files", len(result.file_metrics)) + self.collector.update("Items", result.total_items) # merge row counts - merge_row_counts(row_counts, result[3]) + merge_row_counts(summary.row_counts, result.row_counts) except CannotCoerceColumnException as exc: # schema conflicts resulting from parallel executing logger.warning( f"Parallel schema update conflict, retrying task ({str(exc)}" ) # delete all files produced by the task - for file in result[2]: - os.remove(file) + for metrics in result.file_metrics: + os.remove(metrics.file_path) # schedule the task again schema_dict = schema.to_dict() # TODO: it's time for a named tuple @@ -293,10 +291,11 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TM tasks.append((retry_pending, params)) # remove finished tasks tasks.remove(task) + logger.debug(f"{len(tasks)} tasks still remaining for {load_id}...") - return schema_updates, row_counts + return summary - def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMapFuncRV: + def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV: result = Normalize.w_normalize_files( self.config, self.normalize_storage.config, @@ -305,16 +304,16 @@ def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMap load_id, files, ) - self.update_table(schema, result[0]) - self.collector.update("Files", len(result[2])) - self.collector.update("Items", result[1]) - return result[0], result[3] + self.update_table(schema, result.schema_updates) + self.collector.update("Files", len(result.file_metrics)) + self.collector.update("Items", result.total_items) + return result def spool_files( self, load_id: str, schema: Schema, map_f: TMapFuncType, files: Sequence[str] ) -> None: # process files in parallel or in single thread, depending on map_f - schema_updates, row_counts = map_f(schema, load_id, files) + schema_updates, _, _, row_counts = map_f(schema, load_id, files) # remove normalizer specific info for table in schema.tables.values(): table.pop("x-normalizer", None) # type: ignore[typeddict-item] @@ -381,7 +380,7 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: if len(schema_files) == 0: # delete empty package self.normalize_storage.extracted_packages.delete_package(load_id) - logger.dlt_version_info(f"Empty package {load_id} processed") + logger.info(f"Empty package {load_id} processed") continue with self.collector(f"Normalize {schema.name} in {load_id}"): self.collector.update("Files", 0, len(schema_files)) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index b70a0057c7..e33f0890cb 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -578,6 +578,7 @@ def run( LoadInfo: Information on loaded data including the list of package ids and failed job statuses. Please not that `dlt` will not raise if a single job terminally fails. Such information is provided via LoadInfo. """ signals.raise_if_signalled() + self.activate() self._set_destinations(destination=destination, staging=staging) self._set_dataset_name(dataset_name) diff --git a/docs/website/docs/reference/performance.md b/docs/website/docs/reference/performance.md index 98773c2ef4..f1a405684f 100644 --- a/docs/website/docs/reference/performance.md +++ b/docs/website/docs/reference/performance.md @@ -136,6 +136,7 @@ PROGRESS=log python pipeline_script.py ``` ## Parallelism +You can create pipelines that extract, normalize and load data in parallel. ### Extract You can extract data concurrently if you write your pipelines to yield callables or awaitables that can be then evaluated in a thread or futures pool respectively. @@ -249,6 +250,17 @@ The default is to not parallelize normalization and to perform it in the main pr Normalization is CPU bound and can easily saturate all your cores. Never allow `dlt` to use all cores on your local machine. ::: +:::caution +The default method of spawning a process pool on Linux is **fork**. If you are using threads in your code (or libraries that use threads), +you should rather switch to **spawn**. Process forking does not respawn the threads and may destroy the critical sections in your code. Even logging +with Python loggers from multiple threads may lock the `normalize` step. Here's how you switch to **spawn**: +```toml +[normalize] +workers=3 +start_method="spawn" +``` +::: + ### Load The **load** stage uses a thread pool for parallelization. Loading is input/output bound. `dlt` avoids any processing of the content of the load package produced by the normalizer. By default loading happens in 20 threads, each loading a single file. @@ -314,9 +326,11 @@ if __name__ == "__main__" or "PYTEST_CURRENT_TEST" in os.environ: pipeline = dlt.pipeline("parallel_load", destination="duckdb", full_refresh=True) pipeline.extract(read_table(1000000)) + load_id = pipeline.list_extracted_load_packages()[0] + extracted_package = pipeline.get_load_package_info(load_id) # we should have 11 files (10 pieces for `table` and 1 for state) - extracted_files = pipeline.list_extracted_resources() - print(extracted_files) + extracted_jobs = extracted_package.jobs["new_jobs"] + print([str(job.job_file_info) for job in extracted_jobs]) # normalize and print counts print(pipeline.normalize(loader_file_format="jsonl")) # print jobs in load package (10 + 1 as above) @@ -360,6 +374,79 @@ the schema, that should be a problem though as long as your data does not create should be accessed serially to avoid losing details on parallel runs. +## Running several pipelines in parallel in single process +You can run several pipeline instances in parallel from a single process by placing them in +separate threads. The most straightforward way is to use `ThreadPoolExecutor` and `asyncio` to execute pipeline methods. + + +```py +import asyncio +import dlt +from time import sleep +from concurrent.futures import ThreadPoolExecutor + +# create both futures and thread parallel resources + +def async_table(): + async def _gen(idx): + await asyncio.sleep(0.1) + return {"async_gen": idx} + + # just yield futures in a loop + for idx_ in range(10): + yield _gen(idx_) + +def defer_table(): + @dlt.defer + def _gen(idx): + sleep(0.1) + return {"thread_gen": idx} + + # just yield futures in a loop + for idx_ in range(5): + yield _gen(idx_) + +def _run_pipeline(pipeline, gen_): + # run the pipeline in a thread, also instantiate generators here! + # Python does not let you use generators across threads + return pipeline.run(gen_()) + +# declare pipelines in main thread then run them "async" +pipeline_1 = dlt.pipeline("pipeline_1", destination="duckdb", full_refresh=True) +pipeline_2 = dlt.pipeline("pipeline_2", destination="duckdb", full_refresh=True) + +async def _run_async(): + loop = asyncio.get_running_loop() + # from Python 3.9 you do not need explicit pool. loop.to_thread will suffice + with ThreadPoolExecutor() as executor: + results = await asyncio.gather( + loop.run_in_executor(executor, _run_pipeline, pipeline_1, async_table), + loop.run_in_executor(executor, _run_pipeline, pipeline_2, defer_table), + ) + # result contains two LoadInfo instances + results[0].raise_on_failed_jobs() + results[1].raise_on_failed_jobs() + +# load data +asyncio.run(_run_async()) +# activate pipelines before they are used +pipeline_1.activate() +# assert load_data_table_counts(pipeline_1) == {"async_table": 10} +pipeline_2.activate() +# assert load_data_table_counts(pipeline_2) == {"defer_table": 5} +``` + + +:::tip +Please note the following: +1. Do not run pipelines with the same name and working dir in parallel. State synchronization will not +work in that case. +2. When running in multiple threads and using [parallel normalize step](#normalize) , use **spawn** +process start method. +3. If you created the `Pipeline` object in the worker thread and you use it from another (ie. main thread) +call `pipeline.activate()` to inject the right context into current thread. +::: + ## Resources extraction, `fifo` vs. `round robin` When extracting from resources, you have two options to determine what the order of queries to your diff --git a/docs/website/docs/reference/performance_snippets/performance-snippets.py b/docs/website/docs/reference/performance_snippets/performance-snippets.py index 84bf36adb0..a6ad2f2618 100644 --- a/docs/website/docs/reference/performance_snippets/performance-snippets.py +++ b/docs/website/docs/reference/performance_snippets/performance-snippets.py @@ -111,5 +111,64 @@ def database_cursor_chunked(): assert len(list(database_cursor_chunked())) == 10000 +def parallel_pipelines_asyncio_snippet() -> None: + # @@@DLT_SNIPPET_START parallel_pipelines + import asyncio + import dlt + from time import sleep + from concurrent.futures import ThreadPoolExecutor + + # create both futures and thread parallel resources + + def async_table(): + async def _gen(idx): + await asyncio.sleep(0.1) + return {"async_gen": idx} + + # just yield futures in a loop + for idx_ in range(10): + yield _gen(idx_) + + def defer_table(): + @dlt.defer + def _gen(idx): + sleep(0.1) + return {"thread_gen": idx} + + # just yield futures in a loop + for idx_ in range(5): + yield _gen(idx_) + + def _run_pipeline(pipeline, gen_): + # run the pipeline in a thread, also instantiate generators here! + # Python does not let you use generators across threads + return pipeline.run(gen_()) + + # declare pipelines in main thread then run them "async" + pipeline_1 = dlt.pipeline("pipeline_1", destination="duckdb", full_refresh=True) + pipeline_2 = dlt.pipeline("pipeline_2", destination="duckdb", full_refresh=True) + + async def _run_async(): + loop = asyncio.get_running_loop() + # from Python 3.9 you do not need explicit pool. loop.to_thread will suffice + with ThreadPoolExecutor() as executor: + results = await asyncio.gather( + loop.run_in_executor(executor, _run_pipeline, pipeline_1, async_table), + loop.run_in_executor(executor, _run_pipeline, pipeline_2, defer_table), + ) + # result contains two LoadInfo instances + results[0].raise_on_failed_jobs() + results[1].raise_on_failed_jobs() + + # load data + asyncio.run(_run_async()) + # activate pipelines before they are used + pipeline_1.activate() + # assert load_data_table_counts(pipeline_1) == {"async_table": 10} + pipeline_2.activate() + # assert load_data_table_counts(pipeline_2) == {"defer_table": 5} + # @@@DLT_SNIPPET_END parallel_pipelines + + def test_toml_snippets() -> None: parse_toml_file("./toml-snippets.toml") diff --git a/pyproject.toml b/pyproject.toml index e503edeea0..29522c4827 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dlt" -version = "0.4.1a0" +version = "0.4.1a1" description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run." authors = ["dltHub Inc. "] maintainers = [ "Marcin Rudolf ", "Adrian Brudaru ", "Ty Dunn "] diff --git a/tests/common/configuration/test_accessors.py b/tests/common/configuration/test_accessors.py index 4fda3b27a9..147d56abec 100644 --- a/tests/common/configuration/test_accessors.py +++ b/tests/common/configuration/test_accessors.py @@ -189,7 +189,12 @@ def test_setter(toml_providers: ConfigProvidersContext, environment: Any) -> Non # mod the config and use it to resolve the configuration dlt.config["pool"] = {"pool_type": "process", "workers": 21} c = resolve_configuration(PoolRunnerConfiguration(), sections=("pool",)) - assert dict(c) == {"pool_type": "process", "workers": 21, "run_sleep": 0.1} + assert dict(c) == { + "pool_type": "process", + "start_method": None, + "workers": 21, + "run_sleep": 0.1, + } def test_secrets_separation(toml_providers: ConfigProvidersContext) -> None: diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index e6091b3c70..81d49432d7 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -558,9 +558,7 @@ class _SecretCredentials(RunConfiguration): "slack_incoming_hook": None, "dlthub_telemetry": True, "dlthub_telemetry_segment_write_key": "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB", - "log_format": ( - "{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}" - ), + "log_format": "{asctime}|[{levelname:<21}]|{process}|{thread}|{name}|{filename}|{funcName}:{lineno}|{message}", "log_level": "WARNING", "request_timeout": 60, "request_max_attempts": 5, diff --git a/tests/common/configuration/test_container.py b/tests/common/configuration/test_container.py index 21c8de5782..9521f5960d 100644 --- a/tests/common/configuration/test_container.py +++ b/tests/common/configuration/test_container.py @@ -1,5 +1,7 @@ +from concurrent.futures import ThreadPoolExecutor import pytest -from typing import Any, ClassVar, Literal, Optional, Iterator, TYPE_CHECKING +import threading +from typing import Any, ClassVar, Literal, Optional, Iterator, Type, TYPE_CHECKING from dlt.common.configuration import configspec from dlt.common.configuration.providers.context import ContextProvider @@ -38,6 +40,11 @@ class NoDefaultInjectableContext(ContainerInjectableContext): can_create_default: ClassVar[bool] = False +@configspec +class GlobalTestContext(InjectableTestContext): + global_affinity: ClassVar[bool] = True + + @configspec class EmbeddedWithNoDefaultInjectableContext(BaseConfiguration): injected: NoDefaultInjectableContext @@ -60,25 +67,26 @@ def container() -> Iterator[Container]: def test_singleton(container: Container) -> None: # keep the old configurations list - container_configurations = container.contexts + container_configurations = container.thread_contexts singleton = Container() # make sure it is the same object assert container is singleton # that holds the same configurations dictionary - assert container_configurations is singleton.contexts + assert container_configurations is singleton.thread_contexts -def test_container_items(container: Container) -> None: +@pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) +def test_container_items(container: Container, spec: Type[InjectableTestContext]) -> None: # will add InjectableTestContext instance to container - container[InjectableTestContext] - assert InjectableTestContext in container - del container[InjectableTestContext] - assert InjectableTestContext not in container - container[InjectableTestContext] = InjectableTestContext(current_value="S") - assert container[InjectableTestContext].current_value == "S" - container[InjectableTestContext] = InjectableTestContext(current_value="SS") - assert container[InjectableTestContext].current_value == "SS" + container[spec] + assert spec in container + del container[spec] + assert spec not in container + container[spec] = spec(current_value="S") + assert container[spec].current_value == "S" + container[spec] = spec(current_value="SS") + assert container[spec].current_value == "SS" def test_get_default_injectable_config(container: Container) -> None: @@ -96,7 +104,10 @@ def test_raise_on_no_default_value(container: Container) -> None: assert container[NoDefaultInjectableContext] is injected -def test_container_injectable_context(container: Container) -> None: +@pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) +def test_container_injectable_context( + container: Container, spec: Type[InjectableTestContext] +) -> None: with container.injectable_context(InjectableTestContext()) as current_config: assert current_config.current_value is None current_config.current_value = "TEST" @@ -106,43 +117,131 @@ def test_container_injectable_context(container: Container) -> None: assert InjectableTestContext not in container -def test_container_injectable_context_restore(container: Container) -> None: +@pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) +def test_container_injectable_context_restore( + container: Container, spec: Type[InjectableTestContext] +) -> None: # this will create InjectableTestConfiguration - original = container[InjectableTestContext] + original = container[spec] original.current_value = "ORIGINAL" - with container.injectable_context(InjectableTestContext()) as current_config: + with container.injectable_context(spec()) as current_config: current_config.current_value = "TEST" # nested context is supported - with container.injectable_context(InjectableTestContext()) as inner_config: + with container.injectable_context(spec()) as inner_config: assert inner_config.current_value is None - assert container[InjectableTestContext] is inner_config - assert container[InjectableTestContext] is current_config + assert container[spec] is inner_config + assert container[spec] is current_config - assert container[InjectableTestContext] is original - assert container[InjectableTestContext].current_value == "ORIGINAL" + assert container[spec] is original + assert container[spec].current_value == "ORIGINAL" -def test_container_injectable_context_mangled(container: Container) -> None: - original = container[InjectableTestContext] +@pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) +def test_container_injectable_context_mangled( + container: Container, spec: Type[InjectableTestContext] +) -> None: + original = container[spec] original.current_value = "ORIGINAL" - context = InjectableTestContext() + context = spec() with pytest.raises(ContainerInjectableContextMangled) as py_ex: with container.injectable_context(context) as current_config: current_config.current_value = "TEST" # overwrite the config in container - container[InjectableTestContext] = InjectableTestContext() - assert py_ex.value.spec == InjectableTestContext + container[spec] = spec() + assert py_ex.value.spec == spec assert py_ex.value.expected_config == context -def test_container_provider(container: Container) -> None: +@pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) +def test_container_thread_affinity(container: Container, spec: Type[InjectableTestContext]) -> None: + event = threading.Semaphore(0) + thread_item: InjectableTestContext = None + + def _thread() -> None: + container[spec] = spec(current_value="THREAD") + event.release() + event.acquire() + nonlocal thread_item + thread_item = container[spec] + event.release() + + threading.Thread(target=_thread, daemon=True).start() + event.acquire() + # it may be or separate copy (InjectableTestContext) or single copy (GlobalTestContext) + main_item = container[spec] + main_item.current_value = "MAIN" + event.release() + main_item = container[spec] + event.release() + if spec is GlobalTestContext: + # just one context is kept globally + assert main_item is thread_item + # MAIN was set after thread + assert thread_item.current_value == "MAIN" + else: + assert main_item is not thread_item + assert main_item.current_value == "MAIN" + assert thread_item.current_value == "THREAD" + + +@pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) +def test_container_pool_affinity(container: Container, spec: Type[InjectableTestContext]) -> None: + event = threading.Semaphore(0) + thread_item: InjectableTestContext = None + + def _thread() -> None: + container[spec] = spec(current_value="THREAD") + event.release() + event.acquire() + nonlocal thread_item + thread_item = container[spec] + event.release() + + threading.Thread(target=_thread, daemon=True, name=Container.thread_pool_prefix()).start() + event.acquire() + # it may be or separate copy (InjectableTestContext) or single copy (GlobalTestContext) + main_item = container[spec] + main_item.current_value = "MAIN" + event.release() + main_item = container[spec] + event.release() + + # just one context is kept globally - Container user pool thread name to get the starting thread id + # and uses it to retrieve context + assert main_item is thread_item + # MAIN was set after thread + assert thread_item.current_value == "MAIN" + + +def test_thread_pool_affinity(container: Container) -> None: + def _context() -> InjectableTestContext: + return container[InjectableTestContext] + + main_item = container[InjectableTestContext] = InjectableTestContext(current_value="MAIN") + + with ThreadPoolExecutor(thread_name_prefix=container.thread_pool_prefix()) as p: + future = p.submit(_context) + item = future.result() + + assert item is main_item + + # create non affine pool + with ThreadPoolExecutor() as p: + future = p.submit(_context) + item = future.result() + + assert item is not main_item + + +@pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) +def test_container_provider(container: Container, spec: Type[InjectableTestContext]) -> None: provider = ContextProvider() # default value will be created - v, k = provider.get_value("n/a", InjectableTestContext, None) - assert isinstance(v, InjectableTestContext) - assert k == "InjectableTestContext" - assert InjectableTestContext in container + v, k = provider.get_value("n/a", spec, None) + assert isinstance(v, spec) + assert k == spec.__name__ + assert spec in container # provider does not create default value in Container v, k = provider.get_value("n/a", NoDefaultInjectableContext, None) @@ -157,7 +256,7 @@ def test_container_provider(container: Container) -> None: # must assert if sections are provided with pytest.raises(AssertionError): - provider.get_value("n/a", InjectableTestContext, None, "ns1") + provider.get_value("n/a", spec, None, "ns1") # type hints that are not classes literal = Literal["a"] @@ -176,7 +275,10 @@ def test_container_provider_embedded_inject(container: Container, environment: A assert C.injected is injected -def test_container_provider_embedded_no_default(container: Container) -> None: +@pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) +def test_container_provider_embedded_no_default( + container: Container, spec: Type[InjectableTestContext] +) -> None: with container.injectable_context(NoDefaultInjectableContext()): resolve_configuration(EmbeddedWithNoDefaultInjectableContext()) # default cannot be created so fails diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index db5333f610..fcec881521 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -371,7 +371,10 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: pool = PoolRunnerConfiguration(pool_type="none", workers=10) provider.set_value("runner_config", dict(pool), "new_pipeline") # print(provider._toml["new_pipeline"]["runner_config"].as_string()) - assert provider._toml["new_pipeline"]["runner_config"] == dict(pool) # type: ignore[index] + expected_pool = dict(pool) + # None is removed + expected_pool.pop("start_method") + assert provider._toml["new_pipeline"]["runner_config"] == expected_pool # type: ignore[index] # dict creates only shallow dict so embedded credentials will fail creds = WithCredentialsConfiguration() diff --git a/tests/common/data_writers/test_buffered_writer.py b/tests/common/data_writers/test_buffered_writer.py index 5832341fb2..1cd6a84552 100644 --- a/tests/common/data_writers/test_buffered_writer.py +++ b/tests/common/data_writers/test_buffered_writer.py @@ -55,8 +55,10 @@ def c3_doc(count: int) -> Iterator[DictStrAny]: assert writer._file is None # writer is closed and data was written assert len(writer.closed_files) == 1 + assert writer.closed_files[0].items_count == 9 + assert writer.closed_files[0].file_size > 0 # check the content, mind that we swapped the columns - with FileStorage.open_zipsafe_ro(writer.closed_files[0], "r", encoding="utf-8") as f: + with FileStorage.open_zipsafe_ro(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: content = f.readlines() assert "col2,col1" in content[0] assert "NULL,0" in content[2] @@ -108,9 +110,12 @@ def c3_doc(count: int) -> Iterator[DictStrAny]: assert len(writer.closed_files) == 2 assert writer._buffered_items == [] # the last file must contain text value of the column3 - with FileStorage.open_zipsafe_ro(writer.closed_files[-1], "r", encoding="utf-8") as f: + with FileStorage.open_zipsafe_ro(writer.closed_files[-1].file_path, "r", encoding="utf-8") as f: content = f.readlines() assert "(col3_value" in content[-1] + # check metrics + assert writer.closed_files[0].items_count == 11 + assert writer.closed_files[1].items_count == 22 @pytest.mark.parametrize( @@ -140,7 +145,7 @@ def c2_doc(count: int) -> Iterator[DictStrAny]: # only the initial 15 items written assert writer._writer.items_count == 15 # all written - with FileStorage.open_zipsafe_ro(writer.closed_files[-1], "r", encoding="utf-8") as f: + with FileStorage.open_zipsafe_ro(writer.closed_files[-1].file_path, "r", encoding="utf-8") as f: content = f.readlines() assert content[-1] == '{"col1":1,"col2":3}\n' @@ -168,3 +173,13 @@ def test_writer_optional_schema(disable_compression: bool) -> None: with get_writer(_format="jsonl", disable_compression=disable_compression) as writer: writer.write_data_item([{"col1": 1}], None) writer.write_data_item([{"col1": 1}], None) + + +# @pytest.mark.parametrize( +# "disable_compression", [True, False], ids=["no_compression", "compression"] +# ) +# def test_write_empty_file() -> None: +# pass + + +# def test_import_file() diff --git a/tests/common/runners/test_runners.py b/tests/common/runners/test_runners.py index 2b81c2ea54..3b56b64156 100644 --- a/tests/common/runners/test_runners.py +++ b/tests/common/runners/test_runners.py @@ -129,7 +129,7 @@ def test_runnable_with_runner() -> None: @pytest.mark.parametrize("method", ALL_METHODS) -def test_pool_runner_process_methods(method) -> None: +def test_pool_runner_process_methods_forced(method) -> None: multiprocessing.set_start_method(method, force=True) r = _TestRunnableWorker(4) # make sure signals and logging is initialized @@ -139,3 +139,15 @@ def test_pool_runner_process_methods(method) -> None: runs_count = runner.run_pool(configure(ProcessPoolConfiguration), r) assert runs_count == 1 assert [v[0] for v in r.rv] == list(range(4)) + + +@pytest.mark.parametrize("method", ALL_METHODS) +def test_pool_runner_process_methods_configured(method) -> None: + r = _TestRunnableWorker(4) + # make sure signals and logging is initialized + C = resolve_configuration(RunConfiguration()) + initialize_runtime(C) + + runs_count = runner.run_pool(ProcessPoolConfiguration(start_method=method), r) + assert runs_count == 1 + assert [v[0] for v in r.rv] == list(range(4)) diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index 6afbf0910f..f671ddcf32 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -93,7 +93,7 @@ def test_retry_job(load_storage: LoadStorage) -> None: def test_build_parse_job_path(load_storage: LoadStorage) -> None: - file_id = uniq_id(5) + file_id = ParsedLoadJobFileName.new_file_id() f_n_t = ParsedLoadJobFileName("test_table", file_id, 0, "jsonl") job_f_n = PackageStorage.build_job_file_name( f_n_t.table_name, file_id, 0, loader_file_format=load_storage.loader_file_format diff --git a/tests/libs/test_parquet_writer.py b/tests/libs/test_parquet_writer.py index 92d4950624..b1c19114fe 100644 --- a/tests/libs/test_parquet_writer.py +++ b/tests/libs/test_parquet_writer.py @@ -51,7 +51,7 @@ def test_parquet_writer_schema_evolution_with_big_buffer() -> None: {"col1": c1, "col2": c2, "col3": c3, "col4": c4}, ) - with open(writer.closed_files[0], "rb") as f: + with open(writer.closed_files[0].file_path, "rb") as f: table = pq.read_table(f) assert table.column("col1").to_pylist() == [1, 1] assert table.column("col2").to_pylist() == [2, 2] @@ -78,11 +78,11 @@ def test_parquet_writer_schema_evolution_with_small_buffer() -> None: assert len(writer.closed_files) == 2 - with open(writer.closed_files[0], "rb") as f: + with open(writer.closed_files[0].file_path, "rb") as f: table = pq.read_table(f) assert len(table.schema) == 3 - with open(writer.closed_files[1], "rb") as f: + with open(writer.closed_files[1].file_path, "rb") as f: table = pq.read_table(f) assert len(table.schema) == 4 @@ -108,7 +108,7 @@ def test_parquet_writer_json_serialization() -> None: [{"col1": 1, "col2": 2, "col3": []}], {"col1": c1, "col2": c2, "col3": c3} ) - with open(writer.closed_files[0], "rb") as f: + with open(writer.closed_files[0].file_path, "rb") as f: table = pq.read_table(f) assert table.column("col1").to_pylist() == [1, 1, 1, 1] assert table.column("col2").to_pylist() == [2, 2, 2, 2] @@ -140,7 +140,7 @@ def test_parquet_writer_all_data_fields() -> None: microsecond=int(str(data["col11_precision"].microsecond)[:3] + "000") # type: ignore[attr-defined] ) - with open(writer.closed_files[0], "rb") as f: + with open(writer.closed_files[0].file_path, "rb") as f: table = pq.read_table(f) for key, value in data.items(): # what we have is pandas Timezone which is naive @@ -168,7 +168,7 @@ def test_parquet_writer_items_file_rotation() -> None: writer.write_data_item([{"col1": i}], columns) assert len(writer.closed_files) == 10 - with open(writer.closed_files[4], "rb") as f: + with open(writer.closed_files[4].file_path, "rb") as f: table = pq.read_table(f) assert table.column("col1").to_pylist() == list(range(40, 50)) @@ -183,7 +183,7 @@ def test_parquet_writer_size_file_rotation() -> None: writer.write_data_item([{"col1": i}], columns) assert len(writer.closed_files) == 25 - with open(writer.closed_files[4], "rb") as f: + with open(writer.closed_files[4].file_path, "rb") as f: table = pq.read_table(f) assert table.column("col1").to_pylist() == list(range(16, 20)) diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index bdf3347580..7436023f03 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -119,7 +119,7 @@ def test_get_completed_table_chain_single_job_per_table() -> None: load_id, schema, top_job_table, - "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl", + "event_user.839c6e6b514e427687586ccc65bf133f.jsonl", ) ) == 1 diff --git a/tests/load/utils.py b/tests/load/utils.py index e99886b1b1..6811ca59a6 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -358,8 +358,11 @@ def expect_load_file( status="completed", ) -> LoadJob: file_name = ParsedLoadJobFileName( - table_name, uniq_id(), 0, client.capabilities.preferred_loader_file_format - ).job_id() + table_name, + ParsedLoadJobFileName.new_file_id(), + 0, + client.capabilities.preferred_loader_file_format, + ).file_name() file_storage.save(file_name, query.encode("utf-8")) table = client.get_load_table(table_name) job = client.start_file_load(table, file_storage.make_full_path(file_name), uniq_id()) diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 3ebd4f53a2..438a42c4d3 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -1,7 +1,11 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor import itertools import logging import os -from typing import Any, Any, cast +from time import sleep +from typing import Any, Tuple, cast +import threading from tenacity import retry_if_exception, Retrying, stop_after_attempt import pytest @@ -21,9 +25,10 @@ PipelineStateNotAvailable, UnknownDestinationModule, ) -from dlt.common.pipeline import PipelineContext +from dlt.common.pipeline import LoadInfo, PipelineContext from dlt.common.runtime.collector import LogCollector from dlt.common.schema.utils import new_column, new_table +from dlt.common.typing import DictStrAny from dlt.common.utils import uniq_id from dlt.common.schema import Schema @@ -39,7 +44,12 @@ from tests.common.configuration.utils import environment from tests.utils import TEST_STORAGE_ROOT from tests.extract.utils import expect_extracted_file -from tests.pipeline.utils import assert_load_info, airtable_emojis, many_delayed +from tests.pipeline.utils import ( + assert_load_info, + airtable_emojis, + load_data_table_counts, + many_delayed, +) def test_default_pipeline() -> None: @@ -809,12 +819,13 @@ def github_repo_events_table_meta(page): @dlt.resource -def _get_shuffled_events(): - with open( - "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8" - ) as f: - issues = json.load(f) - yield issues +def _get_shuffled_events(repeat: int = 1): + for _ in range(repeat): + with open( + "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8" + ) as f: + issues = json.load(f) + yield issues @pytest.mark.parametrize("github_resource", (github_repo_events_table_meta, github_repo_events)) @@ -1388,3 +1399,197 @@ def test_remove_pending_packages() -> None: assert pipeline.has_pending_data pipeline.drop_pending_packages() assert pipeline.has_pending_data is False + + +@pytest.mark.parametrize("workers", (1, 4), ids=("1 norm worker", "4 norm workers")) +def test_parallel_pipelines_threads(workers: int) -> None: + # critical section to control pipeline steps + init_lock = threading.Lock() + extract_ev = threading.Event() + normalize_ev = threading.Event() + load_ev = threading.Event() + # control main thread + sem = threading.Semaphore(0) + + # rotate the files frequently so we have parallel normalize and load + os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "10" + os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "10" + + # force spawn process pool + os.environ["NORMALIZE__START_METHOD"] = "spawn" + + page_repeats = 1 + + # set the extra per pipeline + os.environ["PIPELINE_1__EXTRA"] = "CFG_P_1" + os.environ["PIPELINE_2__EXTRA"] = "CFG_P_2" + + def _run_pipeline(pipeline_name: str) -> Tuple[LoadInfo, PipelineContext, DictStrAny]: + try: + + @dlt.transformer( + name="github_repo_events", + write_disposition="append", + table_name=lambda i: i["type"], + ) + def github_repo_events(page, extra): + # test setting the resource state + dlt.current.resource_state()["extra"] = extra + yield page + + @dlt.transformer + async def slow(items): + await asyncio.sleep(0.1) + return items + + @dlt.transformer + @dlt.defer + def slow_func(items, extra): + # sdd configurable extra to each element + sleep(0.1) + return map(lambda item: {**item, **{"extra": extra}}, items) + + @dlt.source + def github(extra: str = dlt.config.value): + # generate github events, push them through futures and thread pools and then dispatch to separate tables + return ( + _get_shuffled_events(repeat=page_repeats) + | slow + | slow_func(extra) + | github_repo_events(extra) + ) + + # make sure that only one pipeline is created + with init_lock: + pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb") + context = Container()[PipelineContext] + finally: + sem.release() + # start every step at the same moment to increase chances of any race conditions to happen + extract_ev.wait() + context_2 = Container()[PipelineContext] + try: + pipeline.extract(github()) + finally: + sem.release() + normalize_ev.wait() + try: + pipeline.normalize(workers=workers) + finally: + sem.release() + load_ev.wait() + info = pipeline.load() + + # get counts in the thread + counts = load_data_table_counts(pipeline) + + assert context is context_2 + return info, context, counts + + with ThreadPoolExecutor(max_workers=4) as pool: + f_1 = pool.submit(_run_pipeline, "pipeline_1") + f_2 = pool.submit(_run_pipeline, "pipeline_2") + + sem.acquire() + sem.acquire() + if f_1.done(): + raise f_1.exception() + if f_2.done(): + raise f_2.exception() + extract_ev.set() + sem.acquire() + sem.acquire() + if f_1.done(): + raise f_1.exception() + if f_2.done(): + raise f_2.exception() + normalize_ev.set() + sem.acquire() + sem.acquire() + if f_1.done(): + raise f_1.exception() + if f_2.done(): + raise f_2.exception() + load_ev.set() + + info_1, context_1, counts_1 = f_1.result() + info_2, context_2, counts_2 = f_2.result() + + assert_load_info(info_1) + assert_load_info(info_2) + + pipeline_1: dlt.Pipeline = context_1.pipeline() # type: ignore + pipeline_2: dlt.Pipeline = context_2.pipeline() # type: ignore + + n_counts_1 = pipeline_1.last_trace.last_normalize_info + assert n_counts_1.row_counts["push_event"] == 8 * page_repeats == counts_1["push_event"] + n_counts_2 = pipeline_2.last_trace.last_normalize_info + assert n_counts_2.row_counts["push_event"] == 8 * page_repeats == counts_2["push_event"] + + assert pipeline_1.pipeline_name == "pipeline_1" + assert pipeline_2.pipeline_name == "pipeline_2" + + # check if resource state has extra + assert pipeline_1.state["sources"]["github"]["resources"]["github_repo_events"] == { + "extra": "CFG_P_1" + } + assert pipeline_2.state["sources"]["github"]["resources"]["github_repo_events"] == { + "extra": "CFG_P_2" + } + + # make sure we can still access data + pipeline_1.activate() # activate pipeline to access inner duckdb + assert load_data_table_counts(pipeline_1) == counts_1 + pipeline_2.activate() + assert load_data_table_counts(pipeline_2) == counts_2 + + +@pytest.mark.parametrize("workers", (1, 4), ids=("1 norm worker", "4 norm workers")) +def test_parallel_pipelines_async(workers: int) -> None: + os.environ["NORMALIZE__WORKERS"] = str(workers) + + # create both futures and thread parallel resources + + def async_table(): + async def _gen(idx): + await asyncio.sleep(0.1) + return {"async_gen": idx} + + # just yield futures in a loop + for idx_ in range(10): + yield _gen(idx_) + + def defer_table(): + @dlt.defer + def _gen(idx): + sleep(0.1) + return {"thread_gen": idx} + + # just yield futures in a loop + for idx_ in range(5): + yield _gen(idx_) + + def _run_pipeline(pipeline, gen_) -> LoadInfo: + # run the pipeline in a thread, also instantiate generators here! + # Python does not let you use generators across instances + return pipeline.run(gen_()) + + # declare pipelines in main thread then run them "async" + pipeline_1 = dlt.pipeline("pipeline_1", destination="duckdb", full_refresh=True) + pipeline_2 = dlt.pipeline("pipeline_2", destination="duckdb", full_refresh=True) + + async def _run_async(): + loop = asyncio.get_running_loop() + with ThreadPoolExecutor() as executor: + results = await asyncio.gather( + loop.run_in_executor(executor, _run_pipeline, pipeline_1, async_table), + loop.run_in_executor(executor, _run_pipeline, pipeline_2, defer_table), + ) + assert_load_info(results[0]) + assert_load_info(results[1]) + + asyncio.run(_run_async()) + pipeline_1.activate() # activate pipeline 1 to access inner duckdb + assert load_data_table_counts(pipeline_1) == {"async_table": 10} + pipeline_2.activate() # activate pipeline 2 to access inner duckdb + assert load_data_table_counts(pipeline_2) == {"defer_table": 5}