Commit

temp
sh-rp committed Jan 24, 2024
1 parent e4ca5c3 commit b8396a6
Showing 3 changed files with 64 additions and 40 deletions.
54 changes: 32 additions & 22 deletions dlt/extract/pipe.py
@@ -109,7 +109,7 @@ class SourcePipeItem(NamedTuple):
Callable[[TDataItems], Iterator[ResolvablePipeItem]],
]

TPipeNextItemMode = Union[Literal["fifo"], Literal["round_robin"]]
TPipeNextItemMode = Union[Literal["auto"], Literal["fifo"], Literal["round_robin"]]


class ForkPipe:
@@ -144,6 +144,7 @@ def __init__(self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = No
self._gen_idx = 0
self._steps: List[TPipeStep] = []
self.parent = parent
self.generates_awaitables = False
# add the steps, this will check and mod transformations
if steps:
for step in steps:
@@ -325,6 +326,7 @@ def evaluate_gen(self) -> None:
# wrap async generator
if inspect.isasyncgen(self.gen):
self.replace_gen(wrap_async_generator(self.gen))
self.generates_awaitables = True

# evaluate transforms
for step_no, step in enumerate(self._steps):
@@ -495,7 +497,7 @@ class PipeIteratorConfiguration(BaseConfiguration):
workers: int = 5
futures_poll_interval: float = 0.01
copy_on_fork: bool = False
next_item_mode: str = "fifo"
next_item_mode: str = "auto"

__section__ = "extract"

@@ -504,6 +506,7 @@ def __init__(
max_parallel_items: int,
workers: int,
futures_poll_interval: float,
sources: List[SourcePipeItem],
next_item_mode: TPipeNextItemMode,
) -> None:
self.max_parallel_items = max_parallel_items
@@ -515,8 +518,20 @@ def __init__(
self._async_pool: asyncio.AbstractEventLoop = None
self._async_pool_thread: Thread = None
self._thread_pool: ThreadPoolExecutor = None
self._sources: List[SourcePipeItem] = []
self._sources = sources
self._initial_sources_count = len(sources)
self._futures: List[FuturePipeItem] = []

# evaluate next item mode, switch to round robin if we have any async generators
if next_item_mode == "auto":
next_item_mode = (
"round_robin" if any(s.pipe.generates_awaitables for s in self._sources) else "fifo"
)

# in fifo mode sources are processed from the back of the list, so reverse them here
if next_item_mode == "fifo":
self._sources.reverse()

self._next_item_mode = next_item_mode

@classmethod
@@ -528,7 +543,7 @@ def from_pipe(
max_parallel_items: int = 20,
workers: int = 5,
futures_poll_interval: float = 0.01,
next_item_mode: TPipeNextItemMode = "fifo",
next_item_mode: TPipeNextItemMode = "auto",
) -> "PipeIterator":
# join all dependent pipes
if pipe.parent:
@@ -539,12 +554,10 @@ def from_pipe(
pipe.evaluate_gen()
if not isinstance(pipe.gen, Iterator):
raise PipeGenInvalid(pipe.name, pipe.gen)

# create extractor
extract = cls(max_parallel_items, workers, futures_poll_interval, next_item_mode)
# add as first source
extract._sources.append(SourcePipeItem(pipe.gen, 0, pipe, None))
cls._initial_sources_count = 1
return extract
sources = [SourcePipeItem(pipe.gen, 0, pipe, None)]
return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode)

@classmethod
@with_config(spec=PipeIteratorConfiguration)
@@ -557,10 +570,11 @@ def from_pipes(
workers: int = 5,
futures_poll_interval: float = 0.01,
copy_on_fork: bool = False,
next_item_mode: TPipeNextItemMode = "fifo",
next_item_mode: TPipeNextItemMode = "auto",
) -> "PipeIterator":
# print(f"max_parallel_items: {max_parallel_items} workers: {workers}")
extract = cls(max_parallel_items, workers, futures_poll_interval, next_item_mode)
sources: List[SourcePipeItem] = []

# clone all pipes before iterating (recursively) as we will fork them (this adds steps) and evaluate gens
pipes, _ = PipeIterator.clone_pipes(pipes)

@@ -580,18 +594,14 @@ def _fork_pipeline(pipe: Pipe) -> None:
if not isinstance(pipe.gen, Iterator):
raise PipeGenInvalid(pipe.name, pipe.gen)
# add every head as source only once
if not any(i.pipe == pipe for i in extract._sources):
extract._sources.append(SourcePipeItem(pipe.gen, 0, pipe, None))
if not any(i.pipe == pipe for i in sources):
sources.append(SourcePipeItem(pipe.gen, 0, pipe, None))

# reverse pipes for current mode, as we start processing from the back
if next_item_mode == "fifo":
pipes.reverse()
for pipe in pipes:
_fork_pipeline(pipe)

extract._initial_sources_count = len(extract._sources)

return extract
# create extractor
return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode)

def __next__(self) -> PipeItem:
pipe_item: Union[ResolvablePipeItem, SourcePipeItem] = None
@@ -805,7 +815,7 @@ def _resolve_futures(self) -> ResolvablePipeItem:

def _get_source_item(self) -> ResolvablePipeItem:
if self._next_item_mode == "fifo":
return self._get_source_item_current()
return self._get_source_item_fifo()
elif self._next_item_mode == "round_robin":
return self._get_source_item_round_robin()

@@ -830,7 +840,7 @@ def _get_next_item_from_generator(
else:
return ResolvablePipeItem(item, step, pipe, meta)

def _get_source_item_current(self) -> ResolvablePipeItem:
def _get_source_item_fifo(self) -> ResolvablePipeItem:
# no more sources to iterate
if len(self._sources) == 0:
return None
@@ -860,7 +870,7 @@ def _get_source_item_round_robin(self) -> ResolvablePipeItem:
return None
# if there are currently more sources than added initially, we need to process the new ones first
if sources_count > self._initial_sources_count:
return self._get_source_item_current()
return self._get_source_item_fifo()
try:
# print(f"got {pipe.name}")
# register current pipe name during the execution of gen
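For context, a minimal standalone sketch (not dlt's PipeIterator internals; the helper names are made up) of the difference between the two scheduling strategies that next_item_mode selects between:

from typing import Iterator, List

def fifo(sources: List[Iterator[str]]) -> List[str]:
    # drain each source completely before moving on to the next one
    out: List[str] = []
    for src in sources:
        out.extend(src)
    return out

def round_robin(sources: List[Iterator[str]]) -> List[str]:
    # take one item from each live source in turn until all are exhausted
    out: List[str] = []
    live = list(sources)
    while live:
        for src in list(live):
            try:
                out.append(next(src))
            except StopIteration:
                live.remove(src)
    return out

print(fifo([iter(["one"] * 3), iter(["two"] * 3)]))         # ['one', 'one', 'one', 'two', 'two', 'two']
print(round_robin([iter(["one"] * 3), iter(["two"] * 3)]))  # ['one', 'two', 'one', 'two', 'one', 'two']

Draining one source to completion would serialize async work, which is why the new "auto" default above falls back to round_robin as soon as any source pipe sets generates_awaitables.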
45 changes: 29 additions & 16 deletions dlt/extract/utils.py
@@ -1,6 +1,18 @@
import inspect
import makefun
from typing import Optional, Tuple, Union, List, Any, Sequence, cast, Iterator
import asyncio
from typing import (
Optional,
Tuple,
Union,
List,
Any,
Sequence,
cast,
AsyncGenerator,
Awaitable,
Generator,
)
from collections.abc import Mapping as C_Mapping

from dlt.common.exceptions import MissingDependencyException
@@ -119,26 +131,27 @@ def check_compat_transformer(name: str, f: AnyFun, sig: inspect.Signature) -> in
return meta_arg


def wrap_async_generator(gen: Any) -> Any:
"""Wraps an async generator into a list of awaitables"""
is_running = False
def wrap_async_generator(
gen: AsyncGenerator[TDataItems, None]
) -> Generator[Awaitable[TDataItems], None, None]:
"""Wraps an async generatqor into a list of awaitables"""
exhausted = False
lock = asyncio.Lock()

# creates an awaitable that will return the next item from the async generator
async def run() -> TDataItems:
nonlocal is_running, exhausted
try:
return await gen.__anext__()
except StopAsyncIteration:
exhausted = True
raise
finally:
is_running = False

# it is best to use the round robin strategy here if multiple async generators are used in resources
async with lock:
try:
return await gen.__anext__()
except StopAsyncIteration:
nonlocal exhausted
exhausted = True
raise

# yields awaitables until the async generator is exhausted, and None while a previous awaitable is still pending
while not exhausted:
while is_running:
while lock.locked():
yield None
is_running = True
yield run()


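To make the contract of wrap_async_generator concrete, here is a small sketch of how its sync generator could be consumed: each yielded object is either None (the previous awaitable still holds the lock) or an awaitable resolving to the next item, and StopAsyncIteration signals exhaustion. Only the import path follows this diff; numbers and drain are hypothetical helpers, and dlt's own iterator dispatches these awaitables to its async pool instead of awaiting them inline.

import asyncio
from dlt.extract.utils import wrap_async_generator  # module changed in this diff

async def numbers():
    for i in range(3):
        await asyncio.sleep(0)  # stand-in for real async work
        yield i

async def drain(async_gen):
    items = []
    for awaitable in wrap_async_generator(async_gen):
        if awaitable is None:
            # an earlier awaitable is still running; a real scheduler would
            # switch to another source here (hence the round robin default)
            await asyncio.sleep(0)
            continue
        try:
            items.append(await awaitable)
        except StopAsyncIteration:
            break  # the wrapped async generator is exhausted
    return items

print(asyncio.run(drain(numbers())))  # [0, 1, 2]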
5 changes: 3 additions & 2 deletions tests/pipeline/test_pipeline.py
@@ -1716,7 +1716,7 @@ async def async_transformer(item):
assert {r[0] for r in rows} == {"at", "bt", "ct"}


@pytest.mark.parametrize("next_item_mode", ["fifo", "round_robin"])
@pytest.mark.parametrize("next_item_mode", ["auto", "fifo", "round_robin"])
def test_parallel_async_generators(next_item_mode: str) -> None:
os.environ["EXTRACT__NEXT_ITEM_MODE"] = next_item_mode
execution_order = []
@@ -1756,8 +1756,9 @@ def source():
assert len(rows) == 3
assert {r[0] for r in rows} == {"e", "f", "g"}

# auto mode will switch to round robin if we have awaitables
assert (
execution_order == ["one", "two", "one", "two", "one", "two"]
if next_item_mode == "round_robin"
if next_item_mode in ["auto", "round_robin"]
else ["one", "one", "one", "two", "two", "two"]
)
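
An end-to-end sketch of what the new default means for user code: with next_item_mode left at "auto", a source made of async-generator resources should be extracted round robin, interleaving the resources roughly as the test above asserts. The resource/pipeline calls are standard dlt API; the pipeline name is made up and duckdb is assumed to be installed.

import asyncio
import dlt

@dlt.resource(table_name="table_one")
async def one():
    for i in range(3):
        await asyncio.sleep(0.1)
        yield {"value": i}

@dlt.resource(table_name="table_two")
async def two():
    for i in range(3):
        await asyncio.sleep(0.1)
        yield {"value": i}

@dlt.source
def demo_source():
    return [one, two]

pipeline = dlt.pipeline(pipeline_name="async_demo", destination="duckdb")
pipeline.run(demo_source())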
