enable nested generator and add tests
sh-rp committed Jan 23, 2024
1 parent 07fd59c commit 8c8a94f
Showing 4 changed files with 102 additions and 37 deletions.
17 changes: 14 additions & 3 deletions dlt/extract/pipe.py
@@ -322,8 +322,9 @@ def evaluate_gen(self) -> None:
        # verify if transformer can be called
        self._ensure_transform_step(self._gen_idx, gen)

-        # ensure that asyn gens are wrapped
-        self.replace_gen(wrap_async_generator(self.gen))
+        # wrap async generator
+        if inspect.isasyncgen(self.gen):
+            self.replace_gen(wrap_async_generator(self.gen))

        # evaluate transforms
        for step_no, step in enumerate(self._steps):
@@ -623,6 +624,16 @@ def __next__(self) -> PipeItem:
                pipe_item = None
                continue

+            # handle async iterator items as new source
+            if inspect.isasyncgen(item):
+                self._sources.append(
+                    SourcePipeItem(
+                        wrap_async_generator(item), pipe_item.step, pipe_item.pipe, pipe_item.meta
+                    )
+                )
+                pipe_item = None
+                continue
+
            if isinstance(item, Awaitable) or callable(item):
                # do we have a free slot or one of the slots is done?
                if len(self._futures) < self.max_parallel_items or self._next_future() >= 0:
@@ -803,7 +814,7 @@ def _get_next_item_from_generator(
        meta: Any,
    ) -> ResolvablePipeItem:
        item: ResolvablePipeItem = next(gen)
-        if not item:
+        if item is None:
            return item
        # full pipe item may be returned, this is used by ForkPipe step
        # to redirect execution of an item to another pipe
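In short, the iterator changes above do two things: a pipe whose own generator is an async generator gets wrapped up front, and an async generator yielded as an item mid-pipe is appended as a new source, with wrap_async_generator turning it into a single awaitable that drains it. A minimal standalone sketch of that idea (plain asyncio, not dlt internals; all names below are illustrative):

import asyncio
from typing import Any, AsyncGenerator, List


def wrap_async_gen(agen: AsyncGenerator[Any, None]):
    # yield one awaitable that drains the async generator into a list
    async def run() -> List[Any]:
        return [item async for item in agen]

    yield run()


def nested_source():
    # a sync generator that yields async generators, like the nested test below
    async def inner(idx: int):
        for letter in ["a", "b", "c"]:
            await asyncio.sleep(0.01)
            yield {"async_gen": idx, "letter": letter}

    for idx in range(3):
        yield inner(idx)


async def main() -> None:
    rows: List[Any] = []
    for item in nested_source():
        # each nested async generator becomes its own "source": one awaitable per generator
        for awaitable in wrap_async_gen(item):
            rows.extend(await awaitable)
    assert len(rows) == 9


asyncio.run(main())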
35 changes: 16 additions & 19 deletions dlt/extract/utils.py
@@ -119,26 +119,23 @@ def check_compat_transformer(name: str, f: AnyFun, sig: inspect.Signature) -> in
    return meta_arg


-def wrap_async_generator(wrapped) -> Any:
+def wrap_async_generator(f: Any) -> Any:
    """Wraps an async generator into a list with one awaitable"""
-    if inspect.isasyncgen(wrapped):
-
-        async def run() -> List[TDataItem]:
-            result: List[TDataItem] = []
-            try:
-                item: TDataItems = None
-                while item := await wrapped.__anext__():
-                    if isinstance(item, Iterator):
-                        result.extend(item)
-                    else:
-                        result.append(item)
-            except StopAsyncIteration:
-                pass
-            return result
-
-        yield run()
-    else:
-        return wrapped
+
+    async def run() -> List[TDataItem]:
+        result: List[TDataItem] = []
+        try:
+            item: TDataItems = None
+            while item := await f.__anext__():
+                if isinstance(item, Iterator):
+                    result.extend(item)
+                else:
+                    result.append(item)
+        except StopAsyncIteration:
+            pass
+        return result
+
+    yield run()


def wrap_compat_transformer(
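Two properties of the rewritten helper are worth calling out: it buffers the whole async generator into one list that is returned by a single awaitable, and the walrus loop (while item := await f.__anext__():) stops not only on StopAsyncIteration but also as soon as a falsy item (None, an empty list) is yielded. A minimal sketch of the __anext__ / StopAsyncIteration protocol it builds on (plain asyncio, not dlt code):

import asyncio


async def letters():
    for letter in ["a", "b", "c"]:
        yield letter


async def drain(agen):
    # consume an async generator item by item until it is exhausted
    items = []
    try:
        while True:
            items.append(await agen.__anext__())
    except StopAsyncIteration:
        # raised once the async generator has no more items
        pass
    return items


assert asyncio.run(drain(letters())) == ["a", "b", "c"]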
2 changes: 1 addition & 1 deletion pipeline.py
@@ -1,7 +1,7 @@

import dlt, asyncio

-# @dlt.resource(table_name="hello")
+@dlt.resource(table_name="hello")
async def async_gen_resource(idx):
    for l in ["a", "b", "c"] * 3:
        await asyncio.sleep(0.1)
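With the decorator re-enabled, the scratch resource above can be run like any other; a hypothetical usage sketch reusing the pipeline/run pattern from the tests below (the pipeline name and duckdb destination are assumptions, not part of this commit):

import dlt

from pipeline import async_gen_resource  # the module shown above

pipeline = dlt.pipeline("async_demo", destination="duckdb", full_refresh=True)
# async generator resources are extracted like regular ones after this change
load_info = pipeline.run(async_gen_resource(1))
print(load_info)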
85 changes: 71 additions & 14 deletions tests/pipeline/test_pipeline.py
@@ -1631,29 +1631,86 @@ def api_fetch(page_num):
    assert pipeline.last_trace.last_normalize_info.row_counts["product"] == 12


-# @pytest.mark.skip("skipped until async generators are implemented")
-def test_async_generator() -> None:
+#
+# async generators resource tests
+#
+def test_async_generator_resource() -> None:
+    async def async_gen_table():
+        for l_ in ["a", "b", "c"]:
+            await asyncio.sleep(0.1)
+            yield {"letter": l_}
+
+    @dlt.resource
+    async def async_gen_resource():
+        for l_ in ["d", "e", "f"]:
+            await asyncio.sleep(0.1)
+            yield {"letter": l_}
+
+    pipeline_1 = dlt.pipeline("pipeline_1", destination="duckdb", full_refresh=True)
+
+    # pure async function
+    pipeline_1.run(async_gen_table(), table_name="async")
+    with pipeline_1.sql_client() as c:
+        with c.execute_query("SELECT * FROM async") as cur:
+            rows = list(cur.fetchall())
+            assert [r[0] for r in rows] == ["a", "b", "c"]
+
+    # async resource
+    pipeline_1.run(async_gen_resource(), table_name="async")
+    with pipeline_1.sql_client() as c:
+        with c.execute_query("SELECT * FROM async") as cur:
+            rows = list(cur.fetchall())
+            assert [r[0] for r in rows] == ["a", "b", "c", "d", "e", "f"]
+
+
+def test_async_generator_nested() -> None:
    def async_inner_table():
        async def _gen(idx):
            for l_ in ["a", "b", "c"]:
-                await asyncio.sleep(1)
+                await asyncio.sleep(0.1)
                yield {"async_gen": idx, "letter": l_}

        # just yield futures in a loop
-        for idx_ in range(10):
+        for idx_ in range(3):
            yield _gen(idx_)

-    async def async_gen_table(idx):
-        for l_ in ["a", "b", "c"]:
-            await asyncio.sleep(1)
-            yield {"async_gen": idx, "letter": l_}

+    pipeline_1 = dlt.pipeline("pipeline_1", destination="duckdb", full_refresh=True)
+    pipeline_1.run(async_inner_table(), table_name="async")
+    with pipeline_1.sql_client() as c:
+        with c.execute_query("SELECT * FROM async") as cur:
+            rows = list(cur.fetchall())
+            assert [(r[0], r[1]) for r in rows] == [
+                (0, "a"),
+                (0, "b"),
+                (0, "c"),
+                (1, "a"),
+                (1, "b"),
+                (1, "c"),
+                (2, "a"),
+                (2, "b"),
+                (2, "c"),
+            ]
+
+
+def test_async_generator_transformer() -> None:
    @dlt.resource
-    async def async_gen_resource(idx):
+    async def async_resource():
        for l_ in ["a", "b", "c"]:
-            await asyncio.sleep(1)
-            yield {"async_gen": idx, "letter": l_}
+            await asyncio.sleep(0.1)
+            yield {"letter": l_}

+    @dlt.transformer(data_from=async_resource)
+    async def async_transformer(items):
+        for item in items:
+            await asyncio.sleep(0.1)
+            yield {
+                "letter": item["letter"] + "t",
+            }
+
    pipeline_1 = dlt.pipeline("pipeline_1", destination="duckdb", full_refresh=True)
-    pipeline_1.run(async_gen_resource(10))
-    pipeline_1.run(async_gen_table(11), table_name="async_gen_table")
+    pipeline_1.run(async_transformer(), table_name="async")
+
+    with pipeline_1.sql_client() as c:
+        with c.execute_query("SELECT * FROM async") as cur:
+            rows = list(cur.fetchall())
+            assert [r[0] for r in rows] == ["at", "bt", "ct"]
