diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py index 5edc52f836..1deea591a5 100644 --- a/dlt/extract/pipe.py +++ b/dlt/extract/pipe.py @@ -109,7 +109,7 @@ class SourcePipeItem(NamedTuple): Callable[[TDataItems], Iterator[ResolvablePipeItem]], ] -TPipeNextItemMode = Union[Literal["fifo"], Literal["round_robin"]] +TPipeNextItemMode = Union[Literal["auto"], Literal["fifo"], Literal["round_robin"]] class ForkPipe: @@ -144,6 +144,7 @@ def __init__(self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = No self._gen_idx = 0 self._steps: List[TPipeStep] = [] self.parent = parent + self.generates_awaitables = False # add the steps, this will check and mod transformations if steps: for step in steps: @@ -325,6 +326,7 @@ def evaluate_gen(self) -> None: # wrap async generator if inspect.isasyncgen(self.gen): self.replace_gen(wrap_async_generator(self.gen)) + self.generates_awaitables = True # evaluate transforms for step_no, step in enumerate(self._steps): @@ -495,7 +497,7 @@ class PipeIteratorConfiguration(BaseConfiguration): workers: int = 5 futures_poll_interval: float = 0.01 copy_on_fork: bool = False - next_item_mode: str = "fifo" + next_item_mode: str = "auto" __section__ = "extract" @@ -504,6 +506,7 @@ def __init__( max_parallel_items: int, workers: int, futures_poll_interval: float, + sources: List[SourcePipeItem], next_item_mode: TPipeNextItemMode, ) -> None: self.max_parallel_items = max_parallel_items @@ -515,8 +518,20 @@ def __init__( self._async_pool: asyncio.AbstractEventLoop = None self._async_pool_thread: Thread = None self._thread_pool: ThreadPoolExecutor = None - self._sources: List[SourcePipeItem] = [] + self._sources = sources + self._initial_sources_count = len(sources) self._futures: List[FuturePipeItem] = [] + + # evaluate next item mode, switch to round robin if we have any async generators + if next_item_mode == "auto": + next_item_mode = ( + "round_robin" if any(s.pipe.generates_awaitables for s in self._sources) else "fifo" + ) + + 
# we process fifo backwards + if next_item_mode == "fifo": + self._sources.reverse() + self._next_item_mode = next_item_mode @classmethod @@ -528,7 +543,7 @@ def from_pipe( max_parallel_items: int = 20, workers: int = 5, futures_poll_interval: float = 0.01, - next_item_mode: TPipeNextItemMode = "fifo", + next_item_mode: TPipeNextItemMode = "auto", ) -> "PipeIterator": # join all dependent pipes if pipe.parent: @@ -539,12 +554,10 @@ def from_pipe( pipe.evaluate_gen() if not isinstance(pipe.gen, Iterator): raise PipeGenInvalid(pipe.name, pipe.gen) + # create extractor - extract = cls(max_parallel_items, workers, futures_poll_interval, next_item_mode) - # add as first source - extract._sources.append(SourcePipeItem(pipe.gen, 0, pipe, None)) - cls._initial_sources_count = 1 - return extract + sources = [SourcePipeItem(pipe.gen, 0, pipe, None)] + return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode) @classmethod @with_config(spec=PipeIteratorConfiguration) @@ -557,10 +570,11 @@ def from_pipes( workers: int = 5, futures_poll_interval: float = 0.01, copy_on_fork: bool = False, - next_item_mode: TPipeNextItemMode = "fifo", + next_item_mode: TPipeNextItemMode = "auto", ) -> "PipeIterator": # print(f"max_parallel_items: {max_parallel_items} workers: {workers}") - extract = cls(max_parallel_items, workers, futures_poll_interval, next_item_mode) + sources: List[SourcePipeItem] = [] + # clone all pipes before iterating (recursively) as we will fork them (this add steps) and evaluate gens pipes, _ = PipeIterator.clone_pipes(pipes) @@ -580,18 +594,14 @@ def _fork_pipeline(pipe: Pipe) -> None: if not isinstance(pipe.gen, Iterator): raise PipeGenInvalid(pipe.name, pipe.gen) # add every head as source only once - if not any(i.pipe == pipe for i in extract._sources): - extract._sources.append(SourcePipeItem(pipe.gen, 0, pipe, None)) + if not any(i.pipe == pipe for i in sources): + sources.append(SourcePipeItem(pipe.gen, 0, pipe, None)) - # reverse 
pipes for current mode, as we start processing from the back - if next_item_mode == "fifo": - pipes.reverse() for pipe in pipes: _fork_pipeline(pipe) - extract._initial_sources_count = len(extract._sources) - - return extract + # create extractor + return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode) def __next__(self) -> PipeItem: pipe_item: Union[ResolvablePipeItem, SourcePipeItem] = None @@ -805,7 +815,7 @@ def _resolve_futures(self) -> ResolvablePipeItem: def _get_source_item(self) -> ResolvablePipeItem: if self._next_item_mode == "fifo": - return self._get_source_item_current() + return self._get_source_item_fifo() elif self._next_item_mode == "round_robin": return self._get_source_item_round_robin() @@ -830,7 +840,7 @@ def _get_next_item_from_generator( else: return ResolvablePipeItem(item, step, pipe, meta) - def _get_source_item_current(self) -> ResolvablePipeItem: + def _get_source_item_fifo(self) -> ResolvablePipeItem: # no more sources to iterate if len(self._sources) == 0: return None @@ -860,7 +870,7 @@ def _get_source_item_round_robin(self) -> ResolvablePipeItem: return None # if there are currently more sources than added initially, we need to process the new ones first if sources_count > self._initial_sources_count: - return self._get_source_item_current() + return self._get_source_item_fifo() try: # print(f"got {pipe.name}") # register current pipe name during the execution of gen diff --git a/dlt/extract/utils.py b/dlt/extract/utils.py index 5720ee239e..b1b0180d3c 100644 --- a/dlt/extract/utils.py +++ b/dlt/extract/utils.py @@ -1,6 +1,18 @@ import inspect import makefun -from typing import Optional, Tuple, Union, List, Any, Sequence, cast, Iterator +import asyncio +from typing import ( + Optional, + Tuple, + Union, + List, + Any, + Sequence, + cast, + AsyncGenerator, + Awaitable, + Generator, +) from collections.abc import Mapping as C_Mapping from dlt.common.exceptions import MissingDependencyException @@ 
-119,26 +131,27 @@ def check_compat_transformer(name: str, f: AnyFun, sig: inspect.Signature) -> in return meta_arg -def wrap_async_generator(gen: Any) -> Any: - """Wraps an async generator into a list of awaitables""" - is_running = False +def wrap_async_generator( + gen: AsyncGenerator[TDataItems, None] +) -> Generator[Awaitable[TDataItems], None, None]: + """Wraps an async generator into a list of awaitables""" exhausted = False + lock = asyncio.Lock() + # creates an awaitable that will return the next item from the async generator async def run() -> TDataItems: - nonlocal is_running, exhausted - try: - return await gen.__anext__() - except StopAsyncIteration: - exhausted = True - raise - finally: - is_running = False - - # it is best to use the round robin strategy here if multiple async generators are used in resources + async with lock: + try: + return await gen.__anext__() + except StopAsyncIteration: + nonlocal exhausted + exhausted = True + raise + + # this generator yields None while the async generator is not exhausted while not exhausted: - while is_running: yield None - is_running = True yield run() diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index ce3ef1b8d4..aecad7e037 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -1716,7 +1716,7 @@ async def async_transformer(item): assert {r[0] for r in rows} == {"at", "bt", "ct"} -@pytest.mark.parametrize("next_item_mode", ["fifo", "round_robin"]) +@pytest.mark.parametrize("next_item_mode", ["auto", "fifo", "round_robin"]) def test_parallel_async_generators(next_item_mode: str) -> None: os.environ["EXTRACT__NEXT_ITEM_MODE"] = next_item_mode execution_order = [] @@ -1756,8 +1756,9 @@ def source(): assert len(rows) == 3 assert {r[0] for r in rows} == {"e", "f", "g"} + # auto mode will switch to round robin if we have awaitables assert ( execution_order == ["one", "two", "one", "two", "one", "two"] - if next_item_mode == 
"round_robin" + if next_item_mode in ["auto", "round_robin"] else ["one", "one", "one", "two", "two", "two"] )