adds experiment for parallelizing regular resources
sh-rp committed Jan 24, 2024
1 parent 28a6a12 commit 21c6db3
Showing 3 changed files with 105 additions and 6 deletions.
8 changes: 6 additions & 2 deletions dlt/extract/pipe.py
@@ -636,6 +636,7 @@ def __next__(self) -> PipeItem:
        if len(self._futures) < self.max_parallel_items or self._next_future() >= 0:
            # check if Awaitable first - awaitable can also be a callable
            if isinstance(item, Awaitable):
                print("schedule")
                future = asyncio.run_coroutine_threadsafe(item, self._ensure_async_pool())
            elif callable(item):
                future = self._ensure_thread_pool().submit(item)
@@ -800,7 +801,11 @@ def _resolve_futures(self) -> ResolvablePipeItem:
            raise ResourceExtractionError(pipe.name, future, str(ex), "future") from ex

        item = future.result()
        if isinstance(item, DataItemWithMeta):

        # we also interpret future items that are None as not being a value to consume
        if item is None:
            return None
        elif isinstance(item, DataItemWithMeta):
            return ResolvablePipeItem(item.data, step, pipe, item.meta)
        else:
            return ResolvablePipeItem(item, step, pipe, meta)
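
This hunk gives None a meaning at the futures layer: a future that resolves to None now signals "nothing to consume" and is dropped instead of being forwarded as a data item. A minimal standalone sketch of that pattern (the consume helper is hypothetical, not dlt's actual code):

    # sketch: futures resolving to None are skipped rather than forwarded
    from concurrent.futures import ThreadPoolExecutor

    def consume(futures):
        for future in futures:
            item = future.result()
            if item is None:
                # None signals "nothing produced this round"; skip it
                continue
            yield item

    with ThreadPoolExecutor(max_workers=2) as pool:
        futures = [pool.submit(fn) for fn in (lambda: None, lambda: {"letter": "a"})]
        assert list(consume(futures)) == [{"letter": "a"}]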
@@ -816,7 +821,6 @@ def _get_source_item(self) -> ResolvablePipeItem:
        self._current_source_index = -1
        first_evaluated_index = -1
        while True:
            print(self._current_source_index)
            self._current_source_index = (self._current_source_index + 1) % sources_count
            # if we have checked all sources once and all returned None, then we can sleep a bit
            if self._current_source_index == first_evaluated_index:
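
The rotation in _get_source_item is what makes the None convention useful: the source index advances modulo the source count, and once a full pass yields only None the pipe can sleep briefly before polling again. A rough standalone sketch of that loop, with hypothetical names and none of dlt's future handling:

    import time

    def round_robin(sources):
        # rotate through sources; a None item means "nothing ready right now"
        current = -1
        first_none = -1
        while sources:
            current = (current + 1) % len(sources)
            # every source returned None in a full pass: back off briefly
            if current == first_none:
                time.sleep(0.01)
            try:
                item = next(sources[current])
            except StopIteration:
                sources.pop(current)
                current, first_none = -1, -1
                continue
            if item is None:
                if first_none < 0:
                    first_none = current
            else:
                first_none = -1
                yield item

    assert list(round_robin([iter([None, "a"]), iter(["b"])])) == ["b", "a"]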
2 changes: 1 addition & 1 deletion dlt/extract/utils.py
@@ -149,7 +149,7 @@ async def run() -> TDataItems:
            exhausted = True
            raise

    # this generator yields None while the async generator is not exhauste
    # this generator yields None while the async generator is not exhausted
    try:
        while not exhausted:
            while lock.locked():
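
The corrected comment describes the trick that lets the pipe drive an async generator without blocking: a sync wrapper yields None placeholders while a fetch is in flight, so the pipe can rotate to other sources in the meantime. A hedged sketch of the idea (wrap_async_gen is hypothetical; the real utils.py code differs in details, for example it re-raises on exhaustion):

    import asyncio
    import threading

    def wrap_async_gen(agen):
        exhausted = False
        lock = threading.Lock()

        async def fetch():
            nonlocal exhausted
            try:
                return await agen.__anext__()
            except StopAsyncIteration:
                exhausted = True
                return None
            finally:
                lock.release()

        while not exhausted:
            while lock.locked():
                yield None  # a fetch is in flight: let the pipe rotate sources
            lock.acquire()
            yield fetch()  # an awaitable the pipe schedules on its event loop

    async def consume_all():
        async def numbers():
            for i in range(3):
                yield i

        items = []
        for step in wrap_async_gen(numbers()):
            if step is not None:
                item = await step
                if item is not None:
                    items.append(item)
        return items

    assert asyncio.run(consume_all()) == [0, 1, 2]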
101 changes: 98 additions & 3 deletions tests/pipeline/test_resources_evaluation.py
@@ -1,4 +1,7 @@
import dlt, asyncio, pytest, os
from typing import Any

import dlt, asyncio, pytest, os, threading, inspect, time
from functools import wraps


#
@@ -96,14 +99,14 @@ def test_parallel_async_generators(next_item_mode: str, resource_mode: str) -> N
    execution_order = []

    @dlt.resource(table_name="table1")
    async def sync_resource1():
    def sync_resource1():
        for l_ in ["a", "b", "c"]:
            nonlocal execution_order
            execution_order.append("one")
            yield {"letter": l_}

    @dlt.resource(table_name="table2")
    async def sync_resource2():
    def sync_resource2():
        for l_ in ["e", "f", "g"]:
            nonlocal execution_order
            execution_order.append("two")
@@ -183,3 +186,95 @@ async def async_resource1():
    with c.execute_query("SELECT * FROM table1") as cur:
        rows = list(cur.fetchall())
        assert len(rows) == 13


@pytest.mark.parametrize("parallelized", [True, False])
def test_async_decorator_experiment(parallelized: bool) -> None:
    os.environ["EXTRACT__NEXT_ITEM_MODE"] = "fifo"
    execution_order = []
    threads = set()

    def parallelize(f) -> Any:
        """Converts a regular iterable to a generator of functions that can be run in parallel in the pipe."""
        exhausted = False
        lock = threading.Lock()

        @wraps(f)
        def _wrap(*args: Any, **kwargs: Any) -> Any:
            gen = f(*args, **kwargs)
            # unpack generator
            if inspect.isfunction(gen):
                gen = gen()
            # if we have an async gen, no further action is needed
            if inspect.isasyncgen(gen):
                raise Exception("Already async gen")

            # get the next item from the generator
            def _gen():
                nonlocal exhausted
                with lock:
                    try:
                        return next(gen)
                    # on stop iteration mark as exhausted
                    except StopIteration:
                        exhausted = True
                        return None

            try:
                while not exhausted:
                    while lock.locked():
                        yield None
                    yield _gen
            except GeneratorExit:
                # clean up the inner generator
                gen.close()

        return _wrap

    @parallelize
    def resource1():
        for l_ in ["a", "b", "c"]:
            time.sleep(0.1)
            nonlocal execution_order
            execution_order.append("one")
            threads.add(threading.get_ident())
            yield {"letter": l_}

    @parallelize
    def resource2():
        time.sleep(0.05)
        for l_ in ["e", "f", "g"]:
            time.sleep(0.1)
            nonlocal execution_order
            execution_order.append("two")
            threads.add(threading.get_ident())
            yield {"letter": l_}

    @dlt.source
    def source():
        if parallelized:
            return [resource1(), resource2()]
        else:  # return unwrapped resources
            return [resource1.__wrapped__(), resource2.__wrapped__()]

    pipeline_1 = dlt.pipeline("pipeline_1", destination="duckdb", full_refresh=True)
    pipeline_1.run(source())

    # all records should be here
    with pipeline_1.sql_client() as c:
        with c.execute_query("SELECT * FROM resource1") as cur:
            rows = list(cur.fetchall())
            assert len(rows) == 3
            assert {r[0] for r in rows} == {"a", "b", "c"}

        with c.execute_query("SELECT * FROM resource2") as cur:
            rows = list(cur.fetchall())
            assert len(rows) == 3
            assert {r[0] for r in rows} == {"e", "f", "g"}

    if parallelized:
        assert len(threads) > 1
        assert execution_order == ["one", "two", "one", "two", "one", "two"]
    else:
        assert execution_order == ["one", "one", "one", "two", "two", "two"]
        assert len(threads) == 1
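
The non-parallelized branch works because functools.wraps records the original function on the wrapper as __wrapped__, which is what lets source() hand dlt the undecorated generators. A tiny self-contained illustration:

    from functools import wraps

    def deco(f):
        @wraps(f)
        def _wrap(*args, **kwargs):
            return f(*args, **kwargs)
        return _wrap

    @deco
    def letters():
        yield "a"

    # wraps() sets _wrap.__wrapped__ = letters, so the original stays reachable
    assert list(letters.__wrapped__()) == ["a"]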
