Skip to content

Commit

Permalink
fix async error test
Browse files (browse the repository at this point in the history)
test async iterator
  • Loading branch information
sh-rp committed Jan 26, 2024
1 parent 435239d commit 3787c63
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
3 changes: 1 addition & 2 deletions dlt/extract/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def evaluate_gen(self) -> None:
self._ensure_transform_step(self._gen_idx, gen)

# wrap async generator
if inspect.isasyncgen(self.gen):
if isinstance(self.gen, AsyncIterator):
self.replace_gen(wrap_async_iterator(self.gen))

# evaluate transforms
Expand Down Expand Up @@ -636,7 +636,6 @@ def __next__(self) -> PipeItem:
if len(self._futures) < self.max_parallel_items or self._next_future() >= 0:
# check if Awaitable first - awaitable can also be a callable
if isinstance(item, Awaitable):
print("schedule")
future = asyncio.run_coroutine_threadsafe(item, self._ensure_async_pool())
elif callable(item):
future = self._ensure_thread_pool().submit(item)
Expand Down
5 changes: 3 additions & 2 deletions tests/extract/test_extract_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,10 +698,11 @@ async def long_gen():
try:
close_pipe_yielding = True
for i in range(0, 10000):
asyncio.sleep(0.01)
await asyncio.sleep(0.01)
yield i
close_pipe_yielding = False
except GeneratorExit:
# an async generator gets asyncio.CancelledError here instead of GeneratorExit
except asyncio.CancelledError:
close_pipe_got_exit = True

def raise_gen(item: int):
Expand Down
29 changes: 29 additions & 0 deletions tests/pipeline/test_resources_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,35 @@
from functools import wraps


def test_async_iterator_resource() -> None:
    """Load a hand-written async iterator through a pipeline and check the rows.

    The iterator class itself (not an instance) is handed to ``pipeline.run``;
    it produces five dicts ``{"i": 1}`` .. ``{"i": 5}``, which are then read
    back from the ``async`` table.
    """

    class AsyncIterator:
        """Minimal async iterator implementing ``__aiter__`` / ``__anext__``."""

        def __init__(self):
            # number of items produced so far
            self.counter = 0

        def __aiter__(self):
            return self

        async def __anext__(self):
            # keep producing while fewer than five items were emitted
            if self.counter < 5:
                self.counter += 1
                # simulate work before handing the item over
                await asyncio.sleep(0.1)
                return {"i": self.counter}
            # exhausted: signal the end of the async iteration
            raise StopAsyncIteration

    pipeline_1 = dlt.pipeline("pipeline_1", destination="duckdb", full_refresh=True)
    pipeline_1.run(AsyncIterator, table_name="async")

    # read back everything that was loaded into the "async" table
    with pipeline_1.sql_client() as client, client.execute_query("SELECT * FROM async") as cursor:
        rows = list(cursor.fetchall())

    # the first column of each row carries the counter value
    loaded_values = [row[0] for row in rows]
    assert loaded_values == [1, 2, 3, 4, 5]


#
# async generators resource tests
#
Expand Down

0 comments on commit 3787c63

Please sign in to comment.