Skip to content

Commit

Permalink
fix add_limit behavior in edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Mar 5, 2024
1 parent 9410bc4 commit 5df7482
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
1 change: 0 additions & 1 deletion dlt/extract/incremental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from functools import wraps



import dlt
from dlt.common.exceptions import MissingDependencyException
from dlt.common import pendulum, logger
Expand Down
9 changes: 9 additions & 0 deletions dlt/extract/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,17 @@ def add_limit(self, max_items: int) -> "DltResource": # noqa: A003
"DltResource": returns self
"""

# normalize None to -1 so callers can pass None to mean "no limit"
if max_items == None:
max_items = -1

def _gen_wrap(gen: TPipeStep) -> TPipeStep:
"""Wrap a generator to take the first `max_items` records"""

# zero items should produce empty generator
if max_items == 0:
return

count = 0
is_async_gen = False
if inspect.isfunction(gen):
Expand Down
26 changes: 26 additions & 0 deletions tests/extract/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Iterator

import pytest
import asyncio

import dlt
from dlt.common.configuration.container import Container
Expand Down Expand Up @@ -789,6 +790,31 @@ def test_limit_infinite_counter() -> None:
assert list(r) == list(range(10))


@pytest.mark.parametrize("limit", (None, -1, 0, 10))
def test_limit_edge_cases(limit: int) -> None:
    """Check `add_limit` with None, negative, zero and positive limits.

    The same limit is applied to a sync resource built from `range(20)` and
    to an async generator resource yielding the same 20 values; both must
    produce identical items, which are then compared against the expected
    result for the given limit.
    """
    limited_sync = dlt.resource(range(20), name="infinity").add_limit(limit)

    @dlt.resource()
    async def r_async():
        for value in range(20):
            await asyncio.sleep(0.01)
            yield value

    sync_list = list(limited_sync)
    async_list = list(r_async().add_limit(limit))

    # sync and async variants must agree before checking concrete values
    assert sync_list == async_list

    if limit == 0:
        # a zero limit yields nothing
        assert sync_list == []
        return
    if limit == 10:
        # a positive limit truncates to the first `limit` items
        assert sync_list == list(range(10))
        return
    if limit in [None, -1]:
        # None and -1 both leave the resource unlimited
        assert sync_list == list(range(20))
        return

    # guards against a parametrization value with no expectation defined
    assert False


def test_limit_source() -> None:
def mul_c(item):
yield from "A" * (item + 2)
Expand Down

0 comments on commit 5df7482

Please sign in to comment.