diff --git a/dlt/extract/items.py b/dlt/extract/items.py
index 888787e6b7..21738abcef 100644
--- a/dlt/extract/items.py
+++ b/dlt/extract/items.py
@@ -238,3 +238,23 @@ class ValidateItem(ItemTransform[TDataItem]):
     def bind(self, pipe: SupportsPipe) -> ItemTransform[TDataItem]:
         self.table_name = pipe.name
         return self
+
+
+class LimitItem(ItemTransform[TDataItem]):
+    placement_affinity: ClassVar[float] = 1.1  # stick to end right behind incremental
+
+    def __init__(self, max_items: int) -> None:
+        self.max_items = max_items if max_items is not None else -1
+
+    def bind(self, pipe: SupportsPipe) -> "LimitItem":
+        self.gen = pipe.gen
+        self.count = 0
+        return self
+
+    def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]:
+        if self.count == self.max_items:
+            if inspect.isgenerator(self.gen):
+                self.gen.close()
+            return None
+        self.count += 1
+        return item
diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py
index 02b52c4623..0e24de9558 100644
--- a/dlt/extract/pipe.py
+++ b/dlt/extract/pipe.py
@@ -39,6 +39,7 @@
     wrap_compat_transformer,
     wrap_resource_gen,
     wrap_async_iterator,
+    wrap_iterator,
 )
 
 
@@ -279,6 +280,10 @@ def evaluate_gen(self) -> None:
         if isinstance(self.gen, AsyncIterator):
             self.replace_gen(wrap_async_iterator(self.gen))
 
+        # we also wrap iterators to make them stoppable
+        if isinstance(self.gen, Iterator):
+            self.replace_gen(wrap_iterator(self.gen))
+
         # evaluate transforms
         for step_no, step in enumerate(self._steps):
             # print(f"pipe {self.name} step no {step_no} step({step})")
diff --git a/dlt/extract/pipe_iterator.py b/dlt/extract/pipe_iterator.py
index 465040f9f4..5fa62ffa63 100644
--- a/dlt/extract/pipe_iterator.py
+++ b/dlt/extract/pipe_iterator.py
@@ -24,7 +24,11 @@
 )
 from dlt.common.configuration.container import Container
 from dlt.common.exceptions import PipelineException
-from dlt.common.pipeline import unset_current_pipe_name, set_current_pipe_name
+from dlt.common.pipeline import (
+    unset_current_pipe_name,
+    set_current_pipe_name,
+    get_current_pipe_name,
+)
 from dlt.common.utils import get_callable_name
 
 from dlt.extract.exceptions import (
@@ -38,7 +42,7 @@
 )
 from dlt.extract.pipe import Pipe
 from dlt.extract.items import DataItemWithMeta, PipeItem, ResolvablePipeItem, SourcePipeItem
-from dlt.extract.utils import wrap_async_iterator
+from dlt.extract.utils import wrap_async_iterator, wrap_iterator
 from dlt.extract.concurrency import FuturesPool
 
 TPipeNextItemMode = Literal["fifo", "round_robin"]
@@ -179,10 +183,12 @@ def __next__(self) -> PipeItem:
             item = pipe_item.item
 
             # if item is iterator, then add it as a new source
+            # we wrap it to make it stoppable
             if isinstance(item, Iterator):
-                # print(f"adding iterable {item}")
                 self._sources.append(
-                    SourcePipeItem(item, pipe_item.step, pipe_item.pipe, pipe_item.meta)
+                    SourcePipeItem(
+                        wrap_iterator(item), pipe_item.step, pipe_item.pipe, pipe_item.meta
+                    )
                 )
                 pipe_item = None
                 continue
@@ -291,7 +297,6 @@ def _get_source_item(self) -> ResolvablePipeItem:
                     first_evaluated_index = self._current_source_index
                 # always go round robin if None was returned or item is to be run as future
                 self._current_source_index = (self._current_source_index - 1) % sources_count
-
         except StopIteration:
             # remove empty iterator and try another source
             self._sources.pop(self._current_source_index)
diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py
index 42e3905162..689e0a91f8 100644
--- a/dlt/extract/resource.py
+++ b/dlt/extract/resource.py
@@ -41,6 +41,7 @@
     MapItem,
     YieldMapItem,
     ValidateItem,
+    LimitItem,
 )
 from dlt.extract.pipe_iterator import ManagedPipeIterator
 from dlt.extract.pipe import Pipe, TPipeStep
@@ -363,56 +364,14 @@ def add_limit(self: TDltResourceImpl, max_items: int) -> TDltResourceImpl:  # no
             "DltResource": returns self
         """
 
-        # make sure max_items is a number, to allow "None" as value for unlimited
-        if max_items is None:
-            max_items = -1
-
-        def _gen_wrap(gen: TPipeStep) -> TPipeStep:
-            """Wrap a generator to take the first `max_items` records"""
-
-            # zero items should produce empty generator
-            if max_items == 0:
-                return
-
-            count = 0
-            is_async_gen = False
-            if callable(gen):
-                gen = gen()  # type: ignore
-
-            # wrap async gen already here
-            if isinstance(gen, AsyncIterator):
-                gen = wrap_async_iterator(gen)
-                is_async_gen = True
-
-            try:
-                for i in gen:  # type: ignore # TODO: help me fix this later
-                    yield i
-                    if i is not None:
-                        count += 1
-                    # async gen yields awaitable so we must count one awaitable more
-                    # so the previous one is evaluated and yielded.
-                    # new awaitable will be cancelled
-                    if count == max_items + int(is_async_gen):
-                        return
-            finally:
-                if inspect.isgenerator(gen):
-                    gen.close()
-            return
-
-        # transformers should be limited by their input, so we only limit non-transformers
-        if not self.is_transformer:
-            gen = self._pipe.gen
-            # wrap gen directly
-            if inspect.isgenerator(gen):
-                self._pipe.replace_gen(_gen_wrap(gen))
-            else:
-                # keep function as function to not evaluate generators before pipe starts
-                self._pipe.replace_gen(partial(_gen_wrap, gen))
-        else:
+        if self.is_transformer:
             logger.warning(
                 f"Setting add_limit to a transformer {self.name} has no effect. Set the limit on"
                 " the top level resource."
             )
+        else:
+            self.add_step(LimitItem(max_items))
+
         return self
 
     def parallelize(self: TDltResourceImpl) -> TDltResourceImpl:
diff --git a/dlt/extract/utils.py b/dlt/extract/utils.py
index 68570d0995..0bcd13155e 100644
--- a/dlt/extract/utils.py
+++ b/dlt/extract/utils.py
@@ -183,6 +183,17 @@ def check_compat_transformer(name: str, f: AnyFun, sig: inspect.Signature) -> in
     return meta_arg
 
 
+def wrap_iterator(gen: Iterator[TDataItems]) -> Iterator[TDataItems]:
+    """Wraps an iterator into a generator"""
+    if inspect.isgenerator(gen):
+        return gen
+
+    def wrapped_gen() -> Iterator[TDataItems]:
+        yield from gen
+
+    return wrapped_gen()
+
+
 def wrap_async_iterator(
     gen: AsyncIterator[TDataItems],
 ) -> Generator[Awaitable[TDataItems], None, None]: