
Commit

tests running parallel pipelines in thread pool
rudolfix committed Dec 9, 2023
1 parent 5f4a489 commit 2847c5b
Showing 3 changed files with 127 additions and 12 deletions.
8 changes: 6 additions & 2 deletions dlt/common/pipeline.py
@@ -400,12 +400,16 @@ class PipelineContext(ContainerInjectableContext):
    _deferred_pipeline: Callable[[], SupportsPipeline]
    _pipeline: SupportsPipeline

-    can_create_default: ClassVar[bool] = False
+    can_create_default: ClassVar[bool] = True

    def pipeline(self) -> SupportsPipeline:
        """Creates or returns existing pipeline"""
        if not self._pipeline:
            # delayed pipeline creation
+            assert self._deferred_pipeline is not None, (
+                "Deferred pipeline creation function not provided to PipelineContext. Are you"
+                " calling dlt.pipeline() from another thread?"
+            )
            self.activate(self._deferred_pipeline())
        return self._pipeline

@@ -425,7 +429,7 @@ def deactivate(self) -> None:
            self._pipeline._set_context(False)
        self._pipeline = None

-    def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline]) -> None:
+    def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline] = None) -> None:
        """Initialize the context with a function returning the Pipeline object to allow creation on first use"""
        self._deferred_pipeline = deferred_pipeline

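
The change above makes PipelineContext creatable on demand (can_create_default = True) without a deferred pipeline factory, and the new assertion points at the likely cause when none exists: dlt.pipeline() was never called in the current thread. A minimal sketch of the resulting usage pattern — editorial illustration only, not part of the commit; it assumes dlt and duckdb are installed and uses made-up pipeline names:

import dlt
from concurrent.futures import ThreadPoolExecutor

def work(name: str) -> str:
    # create the pipeline inside the worker so it is bound to this thread's own context;
    # relying on a pipeline created in another thread is what the new assertion guards against
    pipeline = dlt.pipeline(pipeline_name=name, destination="duckdb")
    return pipeline.pipeline_name

with ThreadPoolExecutor(max_workers=2) as pool:
    print(list(pool.map(work, ["pipeline_a", "pipeline_b"])))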
6 changes: 4 additions & 2 deletions dlt/extract/pipe.py
@@ -725,7 +725,7 @@ def start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
            target=start_background_loop,
            args=(self._async_pool,),
            daemon=True,
-            name="DltFuturesThread",
+            name=Container.thread_pool_prefix() + "futures",
        )
        self._async_pool_thread.start()

@@ -737,7 +737,9 @@ def _ensure_thread_pool(self) -> ThreadPoolExecutor:
        if self._thread_pool:
            return self._thread_pool

-        self._thread_pool = ThreadPoolExecutor(self.workers)
+        self._thread_pool = ThreadPoolExecutor(
+            self.workers, thread_name_prefix=Container.thread_pool_prefix() + "threads"
+        )
        return self._thread_pool

    def __enter__(self) -> "PipeIterator":
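
Both hunks derive worker thread names from Container.thread_pool_prefix() instead of a fixed or default name, presumably so that dlt can associate pool workers with the thread (and therefore the pipeline) that spawned them. The sketch below only illustrates the stdlib thread_name_prefix behavior the change relies on; the prefix string is a made-up example, not the actual value returned by Container.thread_pool_prefix():

import threading
from concurrent.futures import ThreadPoolExecutor

def current_name() -> str:
    return threading.current_thread().name

# ThreadPoolExecutor names its workers "<thread_name_prefix>_<n>"
with ThreadPoolExecutor(2, thread_name_prefix="dlt-pipeline_1-threads") as pool:
    print(pool.submit(current_name).result())  # e.g. "dlt-pipeline_1-threads_0"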
125 changes: 117 additions & 8 deletions tests/pipeline/test_pipeline.py
@@ -1,7 +1,11 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
import itertools
import logging
import os
-from typing import Any, Any, cast
+from time import sleep
+from typing import Any, Tuple, cast
+import threading
from tenacity import retry_if_exception, Retrying, stop_after_attempt

import pytest
@@ -21,7 +25,7 @@
    PipelineStateNotAvailable,
    UnknownDestinationModule,
)
-from dlt.common.pipeline import PipelineContext
+from dlt.common.pipeline import LoadInfo, PipelineContext
from dlt.common.runtime.collector import LogCollector
from dlt.common.schema.utils import new_column, new_table
from dlt.common.utils import uniq_id
@@ -809,12 +813,13 @@ def github_repo_events_table_meta(page):


@dlt.resource
-def _get_shuffled_events():
-    with open(
-        "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8"
-    ) as f:
-        issues = json.load(f)
-        yield issues
+def _get_shuffled_events(repeat: int = 1):
+    for _ in range(repeat):
+        with open(
+            "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8"
+        ) as f:
+            issues = json.load(f)
+            yield issues


@pytest.mark.parametrize("github_resource", (github_repo_events_table_meta, github_repo_events))
@@ -1388,3 +1393,107 @@ def test_remove_pending_packages() -> None:
    assert pipeline.has_pending_data
    pipeline.drop_pending_packages()
    assert pipeline.has_pending_data is False


+def test_parallel_threads_pipeline() -> None:
+    init_lock = threading.Lock()
+    extract_ev = threading.Event()
+    sem = threading.Semaphore(0)
+    normalize_ev = threading.Event()
+    load_ev = threading.Event()
+
+    def _run_pipeline(pipeline_name: str) -> Tuple[LoadInfo, PipelineContext]:
+        # rotate the files frequently so we have parallel normalize and load
+        os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "10"
+        os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "10"
+
+        @dlt.transformer(
+            name="github_repo_events",
+            write_disposition="append",
+            table_name=lambda i: i["type"],
+        )
+        def github_repo_events(page):
+            yield page
+
+        @dlt.transformer
+        async def slow(item):
+            await asyncio.sleep(0.1)
+            return item
+
+        @dlt.transformer
+        @dlt.defer
+        def slow_func(item):
+            sleep(0.1)
+            return item
+
+        # @dlt.resource
+        # def slow():
+        #     for i in range(10):
+        #         yield slowly_rotate(i)
+
+        # make sure that only one pipeline is created
+        with init_lock:
+            pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb")
+            context = Container()[PipelineContext]
+        sem.release()
+        # start every step at the same moment to increase chances of any race conditions to happen
+        extract_ev.wait()
+        context_2 = Container()[PipelineContext]
+        try:
+            # generate github events, push them through futures and thread pools and then dispatch to separate tables
+            pipeline.extract(_get_shuffled_events(repeat=2) | slow | slow_func | github_repo_events)
+        finally:
+            sem.release()
+        normalize_ev.wait()
+        try:
+            pipeline.normalize(workers=4)
+        finally:
+            sem.release()
+        load_ev.wait()
+        info = pipeline.load()
+
+        # info = pipeline.run(slow())
+
+        assert context is context_2
+        return info, context
+
+    with ThreadPoolExecutor(max_workers=4) as pool:
+        f_1 = pool.submit(_run_pipeline, "pipeline_1")
+        f_2 = pool.submit(_run_pipeline, "pipeline_2")
+
+        sem.acquire()
+        sem.acquire()
+        if f_1.done():
+            raise f_1.exception()
+        if f_2.done():
+            raise f_2.exception()
+        extract_ev.set()
+        sem.acquire()
+        sem.acquire()
+        if f_1.done():
+            raise f_1.exception()
+        if f_2.done():
+            raise f_2.exception()
+        normalize_ev.set()
+        sem.acquire()
+        sem.acquire()
+        if f_1.done():
+            raise f_1.exception()
+        if f_2.done():
+            raise f_2.exception()
+        load_ev.set()
+
+        info_1, context_1 = f_1.result()
+        info_2, context_2 = f_2.result()
+
+    print("EXIT")
+    print(info_1)
+    print(info_2)
+
+    counts_1 = context_1.pipeline().last_trace.last_normalize_info  # type: ignore
+    assert counts_1.row_counts["push_event"] == 16
+    counts_2 = context_2.pipeline().last_trace.last_normalize_info  # type: ignore
+    assert counts_2.row_counts["push_event"] == 16
+
+    assert context_1.pipeline().pipeline_name == "pipeline_1"
+    assert context_2.pipeline().pipeline_name == "pipeline_2"
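
The new test coordinates the two pipelines with a Semaphore(0) plus per-step Events so that extract, normalize, and load begin at the same moment in both threads. A stand-alone sketch of that lock-step pattern — editorial illustration only, not part of the commit:

import threading
from concurrent.futures import ThreadPoolExecutor

ready = threading.Semaphore(0)
go = threading.Event()

def worker(name: str) -> str:
    ready.release()  # signal "reached the barrier"
    go.wait()        # block until the coordinator releases everyone at once
    return f"{name} stepped"

with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(worker, n) for n in ("pipeline_1", "pipeline_2")]
    ready.acquire()
    ready.acquire()  # both workers are now waiting at the barrier
    go.set()         # release them simultaneously
    print([f.result() for f in futures])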
