diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py
index 6feee2a812..9f12adf3a7 100644
--- a/dlt/common/pipeline.py
+++ b/dlt/common/pipeline.py
@@ -400,12 +400,16 @@ class PipelineContext(ContainerInjectableContext):
     _deferred_pipeline: Callable[[], SupportsPipeline]
     _pipeline: SupportsPipeline
 
-    can_create_default: ClassVar[bool] = False
+    can_create_default: ClassVar[bool] = True
 
     def pipeline(self) -> SupportsPipeline:
         """Creates or returns existing pipeline"""
         if not self._pipeline:
             # delayed pipeline creation
+            assert self._deferred_pipeline is not None, (
+                "Deferred pipeline creation function not provided to PipelineContext. Are you"
+                " calling dlt.pipeline() from another thread?"
+            )
             self.activate(self._deferred_pipeline())
         return self._pipeline
 
@@ -425,7 +429,7 @@ def deactivate(self) -> None:
             self._pipeline._set_context(False)
         self._pipeline = None
 
-    def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline]) -> None:
+    def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline] = None) -> None:
         """Initialize the context with a function returning the Pipeline object to allow creation
         on first use"""
         self._deferred_pipeline = deferred_pipeline
diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py
index 85c654b46c..d3b8725cc7 100644
--- a/dlt/extract/pipe.py
+++ b/dlt/extract/pipe.py
@@ -725,7 +725,7 @@ def start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
             target=start_background_loop,
             args=(self._async_pool,),
             daemon=True,
-            name="DltFuturesThread",
+            name=Container.thread_pool_prefix() + "futures",
         )
         self._async_pool_thread.start()
 
@@ -737,7 +737,9 @@ def _ensure_thread_pool(self) -> ThreadPoolExecutor:
         if self._thread_pool:
             return self._thread_pool
 
-        self._thread_pool = ThreadPoolExecutor(self.workers)
+        self._thread_pool = ThreadPoolExecutor(
+            self.workers, thread_name_prefix=Container.thread_pool_prefix() + "threads"
+        )
         return self._thread_pool
 
     def __enter__(self) -> "PipeIterator":
diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py
index 3ebd4f53a2..187fd878a9 100644
--- a/tests/pipeline/test_pipeline.py
+++ b/tests/pipeline/test_pipeline.py
@@ -1,7 +1,11 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 import itertools
 import logging
 import os
-from typing import Any, Any, cast
+from time import sleep
+from typing import Any, Tuple, cast
+import threading
 
 from tenacity import retry_if_exception, Retrying, stop_after_attempt
 import pytest
@@ -21,7 +25,7 @@
     PipelineStateNotAvailable,
     UnknownDestinationModule,
 )
-from dlt.common.pipeline import PipelineContext
+from dlt.common.pipeline import LoadInfo, PipelineContext
 from dlt.common.runtime.collector import LogCollector
 from dlt.common.schema.utils import new_column, new_table
 from dlt.common.utils import uniq_id
@@ -809,12 +813,13 @@ def github_repo_events_table_meta(page):
 
 
 @dlt.resource
-def _get_shuffled_events():
-    with open(
-        "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8"
-    ) as f:
-        issues = json.load(f)
-    yield issues
+def _get_shuffled_events(repeat: int = 1):
+    for _ in range(repeat):
+        with open(
+            "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8"
+        ) as f:
+            issues = json.load(f)
+        yield issues
 
 
 @pytest.mark.parametrize("github_resource", (github_repo_events_table_meta, github_repo_events))
@@ -1388,3 +1393,102 @@ def test_remove_pending_packages() -> None:
     assert pipeline.has_pending_data
     pipeline.drop_pending_packages()
     assert pipeline.has_pending_data is False
+
+
+def test_parallel_threads_pipeline() -> None:
+    init_lock = threading.Lock()
+    extract_ev = threading.Event()
+    sem = threading.Semaphore(0)
+    normalize_ev = threading.Event()
+    load_ev = threading.Event()
+
+    def _run_pipeline(pipeline_name: str) -> Tuple[LoadInfo, PipelineContext]:
+        # rotate the files frequently so normalize and load run in parallel
+        os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "10"
+        os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "10"
+
+        @dlt.transformer(
+            name="github_repo_events",
+            write_disposition="append",
+            table_name=lambda i: i["type"],
+        )
+        def github_repo_events(page):
+            yield page
+
+        @dlt.transformer
+        async def slow(item):
+            await asyncio.sleep(0.1)
+            return item
+
+        @dlt.transformer
+        @dlt.defer
+        def slow_func(item):
+            sleep(0.1)
+            return item
+
+        # make sure that only one pipeline is created at a time
+        with init_lock:
+            pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb")
+            context = Container()[PipelineContext]
+        sem.release()
+        # start every step at the same moment to increase the chance of race conditions
+        extract_ev.wait()
+        context_2 = Container()[PipelineContext]
+        try:
+            # generate github events, push them through futures and thread pools,
+            # then dispatch to separate tables
+            pipeline.extract(
+                _get_shuffled_events(repeat=2) | slow | slow_func | github_repo_events
+            )
+        finally:
+            sem.release()
+        normalize_ev.wait()
+        try:
+            pipeline.normalize(workers=4)
+        finally:
+            sem.release()
+        load_ev.wait()
+        info = pipeline.load()
+
+        assert context is context_2
+        return info, context
+
+    with ThreadPoolExecutor(max_workers=4) as pool:
+        f_1 = pool.submit(_run_pipeline, "pipeline_1")
+        f_2 = pool.submit(_run_pipeline, "pipeline_2")
+
+        sem.acquire()
+        sem.acquire()
+        if f_1.done():
+            raise f_1.exception()
+        if f_2.done():
+            raise f_2.exception()
+        extract_ev.set()
+        sem.acquire()
+        sem.acquire()
+        if f_1.done():
+            raise f_1.exception()
+        if f_2.done():
+            raise f_2.exception()
+        normalize_ev.set()
+        sem.acquire()
+        sem.acquire()
+        if f_1.done():
+            raise f_1.exception()
+        if f_2.done():
+            raise f_2.exception()
+        load_ev.set()
+
+        info_1, context_1 = f_1.result()
+        info_2, context_2 = f_2.result()
+
+    print(info_1)
+    print(info_2)
+
+    counts_1 = context_1.pipeline().last_trace.last_normalize_info  # type: ignore
+    assert counts_1.row_counts["push_event"] == 16
+    counts_2 = context_2.pipeline().last_trace.last_normalize_info  # type: ignore
+    assert counts_2.row_counts["push_event"] == 16
+
+    assert context_1.pipeline().pipeline_name == "pipeline_1"
+    assert context_2.pipeline().pipeline_name == "pipeline_2"
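
A note on the `dlt/common/pipeline.py` change: with `can_create_default = True` and a now-optional `deferred_pipeline` factory, a container may materialize a default `PipelineContext` on a thread that never called `dlt.pipeline()`; the new assert then replaces a cryptic `'NoneType' object is not callable` crash with a message pointing at cross-thread usage. Below is a minimal, self-contained sketch of that pattern — `DemoPipelineContext` is a hypothetical stand-in, not the dlt class:

```python
from typing import Callable, Optional


class DemoPipelineContext:
    """Toy stand-in for PipelineContext (illustrative only, not the dlt class)."""

    # was False before the change; True lets a container create this
    # context implicitly, e.g. on a freshly spawned thread
    can_create_default = True

    def __init__(self, deferred_pipeline: Optional[Callable[[], str]] = None) -> None:
        self._deferred_pipeline = deferred_pipeline
        self._pipeline: Optional[str] = None

    def pipeline(self) -> str:
        if not self._pipeline:
            # the assert turns a "'NoneType' object is not callable" crash
            # into an actionable message, mirroring the diff above
            assert self._deferred_pipeline is not None, (
                "Deferred pipeline creation function not provided to DemoPipelineContext."
            )
            self._pipeline = self._deferred_pipeline()
        return self._pipeline


# a context created with a factory works as before: creation is delayed to first use
ctx = DemoPipelineContext(lambda: "my_pipeline")
assert ctx.pipeline() == "my_pipeline"

# a default-created context (no factory) now fails loudly
try:
    DemoPipelineContext().pipeline()
except AssertionError as exc:
    print(exc)
```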
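The `pipe.py` hunks replace the fixed `DltFuturesThread` name with names derived from `Container.thread_pool_prefix()`, for both the futures thread and the extract thread pool. Presumably this lets the container map a worker thread back to the thread that owns its pipeline context. The sketch below shows the stdlib mechanics only; the prefix format here (embedding the parent thread id) is an assumption, not the actual `Container` implementation:

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def thread_pool_prefix() -> str:
    # hypothetical prefix embedding the id of the thread that creates the pool,
    # mimicking what Container.thread_pool_prefix() might encode
    return f"dlt-pool-{threading.get_ident()}-"


def worker_name() -> str:
    # a worker can recover the owner's marker from its own thread name
    return threading.current_thread().name


with ThreadPoolExecutor(2, thread_name_prefix=thread_pool_prefix() + "threads") as pool:
    names = {pool.submit(worker_name).result() for _ in range(4)}
    # e.g. {'dlt-pool-<ident>-threads_0', 'dlt-pool-<ident>-threads_1'}
    print(sorted(names))
```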
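The new test coordinates its two pipeline threads with a semaphore-plus-event rendezvous: each worker checks in via `sem.release()`, and the coordinator releases the next phase for both workers at once only after two check-ins (first verifying that neither future failed early). A stripped-down sketch of that rendezvous, independent of dlt:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

sem = threading.Semaphore(0)
go = threading.Event()


def phase(name: str) -> str:
    sem.release()  # check in: "ready for the next phase"
    go.wait()      # block until the coordinator starts everyone together
    return f"{name} done"


with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(phase, n) for n in ("a", "b")]
    sem.acquire()  # wait for the first check-in
    sem.acquire()  # ...and the second
    go.set()       # start the phase in both workers at the same moment
    print([f.result() for f in futures])
```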
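From the user's side, the behavior under test is simply that `dlt.pipeline()` called on separate threads yields independent pipelines with their own contexts. A minimal usage sketch (assumes the duckdb extra is installed; pipeline names, table name, and data are illustrative):

```python
from concurrent.futures import ThreadPoolExecutor

import dlt


def run_one(name: str):
    # each thread gets its own PipelineContext, so the two pipelines do not clash
    pipeline = dlt.pipeline(pipeline_name=name, destination="duckdb")
    return pipeline.run(
        [{"id": i, "origin": name} for i in range(10)], table_name="items"
    )


with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(run_one, n) for n in ("pipeline_a", "pipeline_b")]
    for f in futures:
        print(f.result())  # one LoadInfo per pipeline
```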