From 5df748296c4e697de3622f21d5b397572b48a47d Mon Sep 17 00:00:00 2001
From: Dave
Date: Tue, 5 Mar 2024 12:31:05 +0100
Subject: [PATCH] fix add_limit behavior in edge cases

- add_limit(None): normalize `None` to -1 before wrapping, so `None` can
  be passed as the value for "unlimited".
- add_limit(0): return from `_gen_wrap` before anything is consumed, so
  the limited resource produces no items.
- add a parametrized test covering limits None, -1, 0 and 10 for both a
  sync and an async resource.
- drop a superfluous blank line in dlt/extract/incremental/__init__.py.

---
 dlt/extract/incremental/__init__.py |  1 -
 dlt/extract/resource.py             |  9 +++++++++
 tests/extract/test_sources.py       | 26 ++++++++++++++++++++++++++
 3 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py
index d1a5a05c34..24495ccb19 100644
--- a/dlt/extract/incremental/__init__.py
+++ b/dlt/extract/incremental/__init__.py
@@ -7,7 +7,6 @@ from functools import wraps
 
 
 
-
 import dlt
 from dlt.common.exceptions import MissingDependencyException
 from dlt.common import pendulum, logger
diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py
index e5b83b853b..a6ef9c03d9 100644
--- a/dlt/extract/resource.py
+++ b/dlt/extract/resource.py
@@ -313,8 +313,17 @@ def add_limit(self, max_items: int) -> "DltResource":  # noqa: A003
             "DltResource": returns self
         """
 
+        # make sure max_items is a number, to allow "None" as value for unlimited
+        if max_items is None:
+            max_items = -1
+
         def _gen_wrap(gen: TPipeStep) -> TPipeStep:
             """Wrap a generator to take the first `max_items` records"""
+
+            # zero items should produce empty generator
+            if max_items == 0:
+                return
+
             count = 0
             is_async_gen = False
             if inspect.isfunction(gen):
diff --git a/tests/extract/test_sources.py b/tests/extract/test_sources.py
index a94cf680fa..088b88a10e 100644
--- a/tests/extract/test_sources.py
+++ b/tests/extract/test_sources.py
@@ -2,6 +2,7 @@ from typing import Iterator
 
 
 import pytest
+import asyncio
 
 import dlt
 from dlt.common.configuration.container import Container
@@ -789,6 +790,31 @@ def test_limit_infinite_counter() -> None:
     assert list(r) == list(range(10))
 
 
+@pytest.mark.parametrize("limit", (None, -1, 0, 10))
+def test_limit_edge_cases(limit: int) -> None:
+    r = dlt.resource(range(20), name="infinity").add_limit(limit)
+
+    @dlt.resource()
+    async def r_async():
+        for i in range(20):
+            await asyncio.sleep(0.01)
+            yield i
+
+    sync_list = list(r)
+    async_list = list(r_async().add_limit(limit))
+
+    # check the expected results
+    assert sync_list == async_list
+    if limit == 10:
+        assert sync_list == list(range(10))
+    elif limit in [None, -1]:
+        assert sync_list == list(range(20))
+    elif limit == 0:
+        assert sync_list == []
+    else:
+        assert False
+
+
 def test_limit_source() -> None:
     def mul_c(item):
         yield from "A" * (item + 2)