Skip to content

Commit

Permalink
temp
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Jan 22, 2024
1 parent d52558b commit 07fd59c
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 31 deletions.
53 changes: 30 additions & 23 deletions dlt/extract/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
simulate_func_call,
wrap_compat_transformer,
wrap_resource_gen,
wrap_async_generator,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -321,6 +322,9 @@ def evaluate_gen(self) -> None:
# verify if transformer can be called
self._ensure_transform_step(self._gen_idx, gen)

# ensure that async gens are wrapped
self.replace_gen(wrap_async_generator(self.gen))

# evaluate transforms
for step_no, step in enumerate(self._steps):
# print(f"pipe {self.name} step no {step_no} step({step})")
Expand Down Expand Up @@ -366,7 +370,7 @@ def _wrap_gen(self, *args: Any, **kwargs: Any) -> Any:

def _verify_head_step(self, step: TPipeStep) -> None:
# first element must be Iterable, Iterator or Callable in resource pipe
if not isinstance(step, (Iterable, Iterator)) and not callable(step):
if not isinstance(step, (Iterable, Iterator, AsyncIterator)) and not callable(step):
raise CreatePipeException(
self.name, "A head of a resource pipe must be Iterable, Iterator or a Callable"
)
Expand Down Expand Up @@ -791,6 +795,27 @@ def _get_source_item(self) -> ResolvablePipeItem:
elif self._next_item_mode == "round_robin":
return self._get_source_item_round_robin()

def _get_next_item_from_generator(
    self,
    gen: Any,
    step: int,
    pipe: Pipe,
    meta: Any,
) -> ResolvablePipeItem:
    """Advance `gen` by one item and normalize it to a `ResolvablePipeItem`.

    Args:
        gen: Iterator that is advanced with `next()`.
        step: Step index stored on a newly created resolvable item.
        pipe: Pipe stored on a newly created resolvable item.
        meta: Metadata used when the item does not carry its own.

    Returns:
        `None` if the generator yielded `None`; the item itself if it already
        is a `ResolvablePipeItem`; otherwise a new `ResolvablePipeItem` built
        from the item, `step` and `pipe`.

    Raises:
        StopIteration: propagated from `next(gen)` when `gen` is exhausted.
    """
    item: Any = next(gen)
    # Only `None` is passed through unwrapped. A plain truthiness check would
    # also return falsy payloads such as `[]`, `{}`, `0` or `""` raw, so the
    # caller would receive something that is not a `ResolvablePipeItem`.
    if item is None:
        return item
    # full pipe item may be returned, this is used by ForkPipe step
    # to redirect execution of an item to another pipe
    if isinstance(item, ResolvablePipeItem):
        return item
    # keep the item assigned step and pipe when creating resolvable item;
    # metadata attached to the item takes precedence over `meta`
    if isinstance(item, DataItemWithMeta):
        return ResolvablePipeItem(item.data, step, pipe, item.meta)
    return ResolvablePipeItem(item, step, pipe, meta)

def _get_source_item_current(self) -> ResolvablePipeItem:
# no more sources to iterate
if len(self._sources) == 0:
Expand All @@ -803,17 +828,8 @@ def _get_source_item_current(self) -> ResolvablePipeItem:
set_current_pipe_name(pipe.name)
item = None
while item is None:
item = next(gen)
# full pipe item may be returned, this is used by ForkPipe step
# to redirect execution of an item to another pipe
if isinstance(item, ResolvablePipeItem):
return item
else:
# keep the item assigned step and pipe when creating resolvable item
if isinstance(item, DataItemWithMeta):
return ResolvablePipeItem(item.data, step, pipe, item.meta)
else:
return ResolvablePipeItem(item, step, pipe, meta)
item = self._get_next_item_from_generator(gen, step, pipe, meta)
return item
except StopIteration:
# remove empty iterator and try another source
self._sources.pop()
Expand All @@ -839,17 +855,8 @@ def _get_source_item_round_robin(self) -> ResolvablePipeItem:
self._round_robin_index = (self._round_robin_index + 1) % sources_count
gen, step, pipe, meta = self._sources[self._round_robin_index]
set_current_pipe_name(pipe.name)
item = next(gen)
# full pipe item may be returned, this is used by ForkPipe step
# to redirect execution of an item to another pipe
if isinstance(item, ResolvablePipeItem):
return item
else:
# keep the item assigned step and pipe when creating resolvable item
if isinstance(item, DataItemWithMeta):
return ResolvablePipeItem(item.data, step, pipe, item.meta)
else:
return ResolvablePipeItem(item, step, pipe, meta)
item = self._get_next_item_from_generator(gen, step, pipe, meta)
return item
except StopIteration:
# remove empty iterator and try another source
self._sources.pop(self._round_robin_index)
Expand Down
4 changes: 1 addition & 3 deletions dlt/extract/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ def from_data(
data = wrap_additional_type(data)

# several iterable types are not allowed and must be excluded right away
if isinstance(data, (AsyncIterator, AsyncIterable)):
raise InvalidResourceDataTypeAsync(name, data, type(data))
if isinstance(data, (str, dict)):
raise InvalidResourceDataTypeBasic(name, data, type(data))

Expand All @@ -135,7 +133,7 @@ def from_data(
parent_pipe = DltResource._get_parent_pipe(name, data_from)

# create resource from iterator, iterable or generator function
if isinstance(data, (Iterable, Iterator)) or callable(data):
if isinstance(data, (Iterable, Iterator, AsyncIterable)) or callable(data):
pipe = Pipe.from_data(name, data, parent=parent_pipe)
return cls(
pipe,
Expand Down
32 changes: 29 additions & 3 deletions dlt/extract/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
import makefun
from typing import Optional, Tuple, Union, List, Any, Sequence, cast
from typing import Optional, Tuple, Union, List, Any, Sequence, cast, Iterator
from collections.abc import Mapping as C_Mapping

from dlt.common.exceptions import MissingDependencyException
Expand Down Expand Up @@ -119,6 +119,28 @@ def check_compat_transformer(name: str, f: AnyFun, sig: inspect.Signature) -> in
return meta_arg


def wrap_async_generator(wrapped: Any) -> Any:
    """Wrap an async generator in a sync generator yielding one awaitable.

    Args:
        wrapped: Any object. Only async generator objects (as reported by
            `inspect.isasyncgen`) are wrapped.

    Returns:
        `wrapped` itself when it is not an async generator. Otherwise a
        synchronous generator that yields exactly one awaitable; awaiting it
        drains `wrapped` and returns all collected items as a list. Items
        that are iterators are flattened into the list, and `None` items are
        skipped.
    """
    # Return early from a plain function. If this function itself contained
    # a `yield`, it would become a generator function and `return wrapped`
    # would never hand `wrapped` back to the caller.
    if not inspect.isasyncgen(wrapped):
        return wrapped

    async def run() -> List[Any]:
        result: List[Any] = []
        # `async for` stops only when the generator is exhausted, so falsy
        # items such as `0` or `{}` do not end the drain early.
        async for item in wrapped:
            if item is None:
                continue
            if isinstance(item, Iterator):
                result.extend(item)
            else:
                result.append(item)
        return result

    def _single_awaitable() -> Iterator[Any]:
        # the coroutine is created lazily, when the generator is iterated
        yield run()

    return _single_awaitable()


def wrap_compat_transformer(
name: str, f: AnyFun, sig: inspect.Signature, *args: Any, **kwargs: Any
) -> AnyFun:
Expand All @@ -142,8 +164,12 @@ def wrap_resource_gen(
name: str, f: AnyFun, sig: inspect.Signature, *args: Any, **kwargs: Any
) -> AnyFun:
"""Wraps a generator or generator function so it is evaluated on extraction"""
if inspect.isgeneratorfunction(inspect.unwrap(f)) or inspect.isgenerator(f):
# always wrap generators and generator functions. evaluate only at runtime!

if (
inspect.isgeneratorfunction(inspect.unwrap(f))
or inspect.isgenerator(f)
or inspect.isasyncgenfunction(f)
):

def _partial() -> Any:
# print(f"_PARTIAL: {args} {kwargs}")
Expand Down
22 changes: 22 additions & 0 deletions pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

import dlt, asyncio

# @dlt.resource(table_name="hello")
async def async_gen_resource(idx):
    """Asynchronously yield nine records, cycling the letters a, b, c three times.

    Each record is a dict with `idx` under "async_gen" and the current letter
    under "letter". The generator sleeps 0.1 s before every record.
    """
    letters = ["a", "b", "c"] * 3
    for letter in letters:
        await asyncio.sleep(0.1)
        record = {"async_gen": idx, "letter": letter}
        yield record

# Run the async generator resource through a duckdb pipeline into table "hello".
pipeline_1 = dlt.pipeline("pipeline_1", destination="duckdb", full_refresh=True)
pipeline_1.run(async_gen_resource(10), table_name="hello")

# Read the loaded rows back and print them one per line.
with pipeline_1.sql_client() as client, client.execute_query("SELECT * FROM hello") as cursor:
    rows = list(cursor.fetchall())

for row in rows:
    print(row)

# pipeline_1.run(
# async_gen_resource(11)
# )
24 changes: 24 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

import asyncio, inspect
from typing import Awaitable
from asyncio import Future

async def async_gen_resource():
    """Asynchronously yield one record for each of the letters a, b and c.

    Every record is a dict with "async_gen" set to 1 and the letter under
    "letter".
    """
    for letter in ("a", "b", "c"):
        record = {"async_gen": 1, "letter": letter}
        yield record


async def run() -> list:
    """Drain `async_gen_resource()` and return its items.

    Returns:
        A one-element list whose single entry is the list of all items the
        async generator produced, in order.
    """
    # `async for` ends only on exhaustion. A walrus loop over `__anext__()`
    # would stop at the first falsy item and fall out of the function without
    # reaching a `return`, so the caller would get `None`.
    result = [item async for item in async_gen_resource()]
    return [result]


if __name__ == "__main__":
    # `asyncio.run` creates a fresh event loop, runs the coroutine to
    # completion and closes the loop. Calling `asyncio.get_event_loop()`
    # without a running loop is deprecated in recent Python versions.
    print(asyncio.run(run()))
4 changes: 2 additions & 2 deletions tests/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,7 +1631,7 @@ def api_fetch(page_num):
assert pipeline.last_trace.last_normalize_info.row_counts["product"] == 12


@pytest.mark.skip("skipped until async generators are implemented")
# @pytest.mark.skip("skipped until async generators are implemented")
def test_async_generator() -> None:
def async_inner_table():
async def _gen(idx):
Expand All @@ -1656,4 +1656,4 @@ async def async_gen_resource(idx):

pipeline_1 = dlt.pipeline("pipeline_1", destination="duckdb", full_refresh=True)
pipeline_1.run(async_gen_resource(10))
pipeline_1.run(async_gen_table(11))
pipeline_1.run(async_gen_table(11), table_name="async_gen_table")

0 comments on commit 07fd59c

Please sign in to comment.