-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
enable async generators as resources #905
Changes from 25 commits
07fd59c
8c8a94f
a07e37f
e4ca5c3
b8396a6
79a42ed
dbea27f
28a6a12
21c6db3
a151428
0211d3d
c087ebf
a7bf8e0
3c047a2
8d81d99
435239d
3787c63
05d0c55
614b80b
fb9c564
4a61e60
9446e29
d830086
b844231
568a2ce
2a168e5
8cf3c3c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,6 +55,7 @@ | |
simulate_func_call, | ||
wrap_compat_transformer, | ||
wrap_resource_gen, | ||
wrap_async_iterator, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
|
@@ -108,7 +109,7 @@ class SourcePipeItem(NamedTuple): | |
Callable[[TDataItems], Iterator[ResolvablePipeItem]], | ||
] | ||
|
||
TPipeNextItemMode = Union[Literal["fifo"], Literal["round_robin"]] | ||
TPipeNextItemMode = Literal["fifo", "round_robin"] | ||
|
||
|
||
class ForkPipe: | ||
|
@@ -321,6 +322,10 @@ def evaluate_gen(self) -> None: | |
# verify if transformer can be called | ||
self._ensure_transform_step(self._gen_idx, gen) | ||
|
||
# wrap async generator | ||
if isinstance(self.gen, AsyncIterator): | ||
self.replace_gen(wrap_async_iterator(self.gen)) | ||
|
||
# evaluate transforms | ||
for step_no, step in enumerate(self._steps): | ||
# print(f"pipe {self.name} step no {step_no} step({step})") | ||
|
@@ -366,9 +371,10 @@ def _wrap_gen(self, *args: Any, **kwargs: Any) -> Any: | |
|
||
def _verify_head_step(self, step: TPipeStep) -> None: | ||
# first element must be Iterable, Iterator or Callable in resource pipe | ||
if not isinstance(step, (Iterable, Iterator)) and not callable(step): | ||
if not isinstance(step, (Iterable, Iterator, AsyncIterator)) and not callable(step): | ||
raise CreatePipeException( | ||
self.name, "A head of a resource pipe must be Iterable, Iterator or a Callable" | ||
self.name, | ||
"A head of a resource pipe must be Iterable, Iterator, AsyncIterator or a Callable", | ||
) | ||
|
||
def _wrap_transform_step_meta(self, step_no: int, step: TPipeStep) -> TPipeStep: | ||
|
@@ -498,20 +504,20 @@ def __init__( | |
max_parallel_items: int, | ||
workers: int, | ||
futures_poll_interval: float, | ||
sources: List[SourcePipeItem], | ||
next_item_mode: TPipeNextItemMode, | ||
) -> None: | ||
self.max_parallel_items = max_parallel_items | ||
self.workers = workers | ||
self.futures_poll_interval = futures_poll_interval | ||
|
||
self._round_robin_index: int = -1 | ||
self._initial_sources_count: int = 0 | ||
self._async_pool: asyncio.AbstractEventLoop = None | ||
self._async_pool_thread: Thread = None | ||
self._thread_pool: ThreadPoolExecutor = None | ||
self._sources: List[SourcePipeItem] = [] | ||
self._sources = sources | ||
self._futures: List[FuturePipeItem] = [] | ||
self._next_item_mode = next_item_mode | ||
self._next_item_mode: TPipeNextItemMode = next_item_mode | ||
self._initial_sources_count = len(sources) | ||
self._current_source_index: int = 0 | ||
|
||
@classmethod | ||
@with_config(spec=PipeIteratorConfiguration) | ||
|
@@ -533,12 +539,10 @@ def from_pipe( | |
pipe.evaluate_gen() | ||
if not isinstance(pipe.gen, Iterator): | ||
raise PipeGenInvalid(pipe.name, pipe.gen) | ||
|
||
# create extractor | ||
extract = cls(max_parallel_items, workers, futures_poll_interval, next_item_mode) | ||
# add as first source | ||
extract._sources.append(SourcePipeItem(pipe.gen, 0, pipe, None)) | ||
cls._initial_sources_count = 1 | ||
return extract | ||
sources = [SourcePipeItem(pipe.gen, 0, pipe, None)] | ||
return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode) | ||
|
||
@classmethod | ||
@with_config(spec=PipeIteratorConfiguration) | ||
|
@@ -554,7 +558,8 @@ def from_pipes( | |
next_item_mode: TPipeNextItemMode = "fifo", | ||
) -> "PipeIterator": | ||
# print(f"max_parallel_items: {max_parallel_items} workers: {workers}") | ||
extract = cls(max_parallel_items, workers, futures_poll_interval, next_item_mode) | ||
sources: List[SourcePipeItem] = [] | ||
|
||
# clone all pipes before iterating (recursively) as we will fork them (this add steps) and evaluate gens | ||
pipes, _ = PipeIterator.clone_pipes(pipes) | ||
|
||
|
@@ -574,18 +579,16 @@ def _fork_pipeline(pipe: Pipe) -> None: | |
if not isinstance(pipe.gen, Iterator): | ||
raise PipeGenInvalid(pipe.name, pipe.gen) | ||
# add every head as source only once | ||
if not any(i.pipe == pipe for i in extract._sources): | ||
extract._sources.append(SourcePipeItem(pipe.gen, 0, pipe, None)) | ||
if not any(i.pipe == pipe for i in sources): | ||
sources.append(SourcePipeItem(pipe.gen, 0, pipe, None)) | ||
|
||
# reverse pipes for current mode, as we start processing from the back | ||
if next_item_mode == "fifo": | ||
pipes.reverse() | ||
pipes.reverse() | ||
for pipe in pipes: | ||
_fork_pipeline(pipe) | ||
|
||
extract._initial_sources_count = len(extract._sources) | ||
|
||
return extract | ||
# create extractor | ||
return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode) | ||
|
||
def __next__(self) -> PipeItem: | ||
pipe_item: Union[ResolvablePipeItem, SourcePipeItem] = None | ||
|
@@ -619,6 +622,16 @@ def __next__(self) -> PipeItem: | |
pipe_item = None | ||
continue | ||
|
||
# handle async iterator items as new source | ||
if isinstance(item, AsyncIterator): | ||
self._sources.append( | ||
SourcePipeItem( | ||
wrap_async_iterator(item), pipe_item.step, pipe_item.pipe, pipe_item.meta | ||
), | ||
) | ||
pipe_item = None | ||
continue | ||
|
||
if isinstance(item, Awaitable) or callable(item): | ||
# do we have a free slot or one of the slots is done? | ||
if len(self._futures) < self.max_parallel_items or self._next_future() >= 0: | ||
|
@@ -689,20 +702,25 @@ def close(self) -> None: | |
def stop_background_loop(loop: asyncio.AbstractEventLoop) -> None: | ||
loop.stop() | ||
|
||
# stop all futures | ||
for f, _, _, _ in self._futures: | ||
if not f.done(): | ||
f.cancel() | ||
self._futures.clear() | ||
|
||
# close all generators | ||
for gen, _, _, _ in self._sources: | ||
if inspect.isgenerator(gen): | ||
gen.close() | ||
self._sources.clear() | ||
|
||
# print("stopping loop") | ||
# stop all futures | ||
for f, _, _, _ in self._futures: | ||
if not f.done(): | ||
f.cancel() | ||
|
||
# let tasks cancel | ||
if self._async_pool: | ||
# wait for all async generators to be closed | ||
future = asyncio.run_coroutine_threadsafe( | ||
self._async_pool.shutdown_asyncgens(), self._ensure_async_pool() | ||
) | ||
while not future.done(): | ||
sleep(self.futures_poll_interval) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not 100% certain I implemented this right, but I think I did. We need to shut down all the open gens in this pool. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LGTM! I'm just worried that our current
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rudolfix I am not sure what you mean with "async". there is a test called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean async iterators? |
||
self._async_pool.call_soon_threadsafe(stop_background_loop, self._async_pool) | ||
# print("joining thread") | ||
self._async_pool_thread.join() | ||
|
@@ -712,6 +730,8 @@ def stop_background_loop(loop: asyncio.AbstractEventLoop) -> None: | |
self._thread_pool.shutdown(wait=True) | ||
self._thread_pool = None | ||
|
||
self._futures.clear() | ||
|
||
def _ensure_async_pool(self) -> asyncio.AbstractEventLoop: | ||
# lazily create async pool in a separate thread | ||
if self._async_pool: | ||
|
@@ -773,91 +793,70 @@ def _resolve_futures(self) -> ResolvablePipeItem: | |
|
||
if future.exception(): | ||
ex = future.exception() | ||
if isinstance(ex, StopAsyncIteration): | ||
return None | ||
if isinstance( | ||
ex, (PipelineException, ExtractorException, DltSourceException, PipeException) | ||
): | ||
raise ex | ||
raise ResourceExtractionError(pipe.name, future, str(ex), "future") from ex | ||
|
||
item = future.result() | ||
if isinstance(item, DataItemWithMeta): | ||
|
||
# we also interpret future items that are None to not be value to be consumed | ||
if item is None: | ||
return None | ||
elif isinstance(item, DataItemWithMeta): | ||
return ResolvablePipeItem(item.data, step, pipe, item.meta) | ||
else: | ||
return ResolvablePipeItem(item, step, pipe, meta) | ||
|
||
def _get_source_item(self) -> ResolvablePipeItem: | ||
if self._next_item_mode == "fifo": | ||
return self._get_source_item_current() | ||
elif self._next_item_mode == "round_robin": | ||
return self._get_source_item_round_robin() | ||
|
||
def _get_source_item_current(self) -> ResolvablePipeItem: | ||
# no more sources to iterate | ||
if len(self._sources) == 0: | ||
return None | ||
try: | ||
# get items from last added iterator, this makes the overall Pipe as close to FIFO as possible | ||
gen, step, pipe, meta = self._sources[-1] | ||
# print(f"got {pipe.name}") | ||
# register current pipe name during the execution of gen | ||
set_current_pipe_name(pipe.name) | ||
item = None | ||
while item is None: | ||
item = next(gen) | ||
# full pipe item may be returned, this is used by ForkPipe step | ||
# to redirect execution of an item to another pipe | ||
if isinstance(item, ResolvablePipeItem): | ||
return item | ||
else: | ||
# keep the item assigned step and pipe when creating resolvable item | ||
if isinstance(item, DataItemWithMeta): | ||
return ResolvablePipeItem(item.data, step, pipe, item.meta) | ||
else: | ||
return ResolvablePipeItem(item, step, pipe, meta) | ||
except StopIteration: | ||
# remove empty iterator and try another source | ||
self._sources.pop() | ||
return self._get_source_item() | ||
except (PipelineException, ExtractorException, DltSourceException, PipeException): | ||
raise | ||
except Exception as ex: | ||
raise ResourceExtractionError(pipe.name, gen, str(ex), "generator") from ex | ||
|
||
def _get_source_item_round_robin(self) -> ResolvablePipeItem: | ||
sources_count = len(self._sources) | ||
# no more sources to iterate | ||
if sources_count == 0: | ||
return None | ||
# if there are currently more sources than added initially, we need to process the new ones first | ||
if sources_count > self._initial_sources_count: | ||
return self._get_source_item_current() | ||
try: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function now nicely combines fifo and round_robin. In fifo mode it stays on the first source and only ventures into the next ones if that returns None. It would be quite easy to switch it back to the old behavior though. I removed this part that switches from round robin to fifo in some cases, as it does not really make sense anymore imho if fifo can also switch the source index. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be very easy to keep it by adding this condition to line 820. My worry here is that if, let's say, we have a resource that feeds items to a transformer and it is itself a generator, we generate a million items, and this will produce millions of source slots. My take is that we switch to FIFO mode when sources_count - self._initial_sources_count > self.max_parallel_items to exhaust new generators. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I implemented it this way (and corrected the decreasing of the initial sources counter along the way). Now there is a differentiation between fifo and strict_fifo; I think this is necessary to prevent a scenario like the one you describe. |
||
# print(f"got {pipe.name}") | ||
# register current pipe name during the execution of gen | ||
item = None | ||
while item is None: | ||
self._round_robin_index = (self._round_robin_index + 1) % sources_count | ||
gen, step, pipe, meta = self._sources[self._round_robin_index] | ||
set_current_pipe_name(pipe.name) | ||
item = next(gen) | ||
# full pipe item may be returned, this is used by ForkPipe step | ||
# to redirect execution of an item to another pipe | ||
if isinstance(item, ResolvablePipeItem): | ||
return item | ||
first_evaluated_index: int = None | ||
# always reset to end of list for fifo mode, also take into account that new sources can be added | ||
# if too many new sources is added we switch to fifo not to exhaust them | ||
if ( | ||
self._next_item_mode == "fifo" | ||
or (sources_count - self._initial_sources_count) >= self.max_parallel_items | ||
): | ||
self._current_source_index = sources_count - 1 | ||
else: | ||
# keep the item assigned step and pipe when creating resolvable item | ||
if isinstance(item, DataItemWithMeta): | ||
return ResolvablePipeItem(item.data, step, pipe, item.meta) | ||
else: | ||
return ResolvablePipeItem(item, step, pipe, meta) | ||
self._current_source_index = (self._current_source_index - 1) % sources_count | ||
while True: | ||
# if we have checked all sources once and all returned None, then we can sleep a bit | ||
if self._current_source_index == first_evaluated_index: | ||
sleep(self.futures_poll_interval) | ||
# get next item from the current source | ||
gen, step, pipe, meta = self._sources[self._current_source_index] | ||
set_current_pipe_name(pipe.name) | ||
if (item := next(gen)) is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. imo this already will check for not None as |
||
# full pipe item may be returned, this is used by ForkPipe step | ||
# to redirect execution of an item to another pipe | ||
if isinstance(item, ResolvablePipeItem): | ||
return item | ||
else: | ||
# keep the item assigned step and pipe when creating resolvable item | ||
if isinstance(item, DataItemWithMeta): | ||
return ResolvablePipeItem(item.data, step, pipe, item.meta) | ||
else: | ||
return ResolvablePipeItem(item, step, pipe, meta) | ||
# remember the first evaluated index | ||
if first_evaluated_index is None: | ||
first_evaluated_index = self._current_source_index | ||
# always go round robin if None was returned | ||
self._current_source_index = (self._current_source_index - 1) % sources_count | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we also check if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no! check how the modulo operator works for negative numbers :) |
||
except StopIteration: | ||
# remove empty iterator and try another source | ||
self._sources.pop(self._round_robin_index) | ||
# we need to decrease the index to keep the round robin order | ||
self._round_robin_index -= 1 | ||
# since in this case we have popped an initial source, we need to decrease the initial sources count | ||
self._initial_sources_count -= 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is actually a bug i would say.. |
||
return self._get_source_item_round_robin() | ||
self._sources.pop(self._current_source_index) | ||
# decrease initial source count if we popped an initial source | ||
if self._current_source_index < self._initial_sources_count: | ||
self._initial_sources_count -= 1 | ||
return self._get_source_item() | ||
except (PipelineException, ExtractorException, DltSourceException, PipeException): | ||
raise | ||
except Exception as ex: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from copy import deepcopy | ||
import inspect | ||
import asyncio | ||
from typing import ( | ||
AsyncIterable, | ||
AsyncIterator, | ||
|
@@ -24,6 +25,7 @@ | |
pipeline_state, | ||
) | ||
from dlt.common.utils import flatten_list_or_items, get_callable_name, uniq_id | ||
from dlt.extract.utils import wrap_async_iterator | ||
|
||
from dlt.extract.typing import ( | ||
DataItemWithMeta, | ||
|
@@ -123,8 +125,6 @@ def from_data( | |
data = wrap_additional_type(data) | ||
|
||
# several iterable types are not allowed and must be excluded right away | ||
if isinstance(data, (AsyncIterator, AsyncIterable)): | ||
raise InvalidResourceDataTypeAsync(name, data, type(data)) | ||
if isinstance(data, (str, dict)): | ||
raise InvalidResourceDataTypeBasic(name, data, type(data)) | ||
|
||
|
@@ -135,7 +135,7 @@ def from_data( | |
parent_pipe = DltResource._get_parent_pipe(name, data_from) | ||
|
||
# create resource from iterator, iterable or generator function | ||
if isinstance(data, (Iterable, Iterator)) or callable(data): | ||
if isinstance(data, (Iterable, Iterator, AsyncIterable)) or callable(data): | ||
pipe = Pipe.from_data(name, data, parent=parent_pipe) | ||
return cls( | ||
pipe, | ||
|
@@ -306,16 +306,26 @@ def add_limit(self, max_items: int) -> "DltResource": # noqa: A003 | |
|
||
def _gen_wrap(gen: TPipeStep) -> TPipeStep: | ||
"""Wrap a generator to take the first `max_items` records""" | ||
nonlocal max_items | ||
count = 0 | ||
is_async_gen = False | ||
if inspect.isfunction(gen): | ||
gen = gen() | ||
|
||
# wrap async gen already here | ||
if isinstance(gen, AsyncIterator): | ||
gen = wrap_async_iterator(gen) | ||
is_async_gen = True | ||
|
||
try: | ||
for i in gen: # type: ignore # TODO: help me fix this later | ||
yield i | ||
count += 1 | ||
if count == max_items: | ||
return | ||
if i is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line is needed for the async generator to work properly. It changes the behavior of the limit, but that is probably OK; I am not sure. |
||
count += 1 | ||
# async gen yields awaitable so we must count one awaitable more | ||
# so the previous one is evaluated and yielded. | ||
# new awaitable will be cancelled | ||
if count == max_items + int(is_async_gen): | ||
return | ||
finally: | ||
if inspect.isgenerator(gen): | ||
gen.close() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since it is annotated in the argument should we also remove the annotation here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch! but somehow mypy sees this as string here so I forced the type (or maybe it is VSCode language server)