From a44c9b4a0716970bf4b242d7252245bbef7eca88 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 25 Mar 2024 10:59:29 +0300 Subject: [PATCH] Add initial version of `rest_client.paginate()` --- dlt/sources/helpers/rest_client/__init__.py | 45 ++++++++++ dlt/sources/helpers/rest_client/client.py | 7 +- dlt/sources/helpers/rest_client/typing.py | 8 ++ dlt/sources/helpers/rest_client/utils.py | 16 +++- .../helpers/rest_client/test_client.py | 12 ++- .../sources/helpers/rest_client/test_utils.py | 90 +++++++++++++++++++ 6 files changed, 168 insertions(+), 10 deletions(-) create mode 100644 tests/sources/helpers/rest_client/test_utils.py diff --git a/dlt/sources/helpers/rest_client/__init__.py b/dlt/sources/helpers/rest_client/__init__.py index 3264ea4aae..fd5d558018 100644 --- a/dlt/sources/helpers/rest_client/__init__.py +++ b/dlt/sources/helpers/rest_client/__init__.py @@ -1 +1,46 @@ +from typing import Optional, Dict, Iterator, Union, Any + +from dlt.common import jsonpath + from .client import RESTClient # noqa: F401 +from .client import PageData +from .auth import AuthConfigBase +from .paginators import BasePaginator +from .typing import HTTPMethodBasic, Hooks + + +def paginate( + url: str, + method: HTTPMethodBasic = "GET", + headers: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Any]] = None, + json: Optional[Dict[str, Any]] = None, + auth: AuthConfigBase = None, + paginator: Union[str, BasePaginator] = None, + data_selector: Optional[jsonpath.TJsonPath] = None, + hooks: Optional[Hooks] = None, +) -> Iterator[PageData[Any]]: + """ + Paginate over a REST API endpoint. + + Args: + url: URL to paginate over. + **kwargs: Keyword arguments to pass to `RESTClient.paginate`. + + Returns: + Iterator[Page]: Iterator over pages. + """ + client = RESTClient( + base_url=url, + headers=headers, + ) + return client.paginate( + path="", + method=method, + params=params, + json=json, + auth=auth, + paginator=paginator, + data_selector=data_selector, + hooks=hooks, + ) diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index 12e22c072d..4b5625eebe 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -6,8 +6,6 @@ Any, TypeVar, Iterable, - Union, - Callable, cast, ) import copy @@ -20,7 +18,7 @@ from dlt.sources.helpers.requests.retry import Client from dlt.sources.helpers.requests import Response, Request -from .typing import HTTPMethodBasic, HTTPMethod +from .typing import HTTPMethodBasic, HTTPMethod, Hooks from .paginators import BasePaginator from .auth import AuthConfigBase from .detector import PaginatorFactory, find_records @@ -30,9 +28,6 @@ _T = TypeVar("_T") -HookFunction = Callable[[Response, Any, Any], None] -HookEvent = Union[HookFunction, List[HookFunction]] -Hooks = Dict[str, HookEvent] class PageData(List[_T]): diff --git a/dlt/sources/helpers/rest_client/typing.py b/dlt/sources/helpers/rest_client/typing.py index dad9842071..626aee4877 100644 --- a/dlt/sources/helpers/rest_client/typing.py +++ b/dlt/sources/helpers/rest_client/typing.py @@ -1,9 +1,17 @@ from typing import ( + List, + Dict, Union, Literal, + Callable, + Any, ) +from dlt.sources.helpers.requests import Response HTTPMethodBasic = Literal["GET", "POST"] HTTPMethodExtended = Literal["PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"] HTTPMethod = Union[HTTPMethodBasic, HTTPMethodExtended] +HookFunction = Callable[[Response, Any, Any], None] +HookEvent = Union[HookFunction, List[HookFunction]] +Hooks = Dict[str, HookEvent] diff --git a/dlt/sources/helpers/rest_client/utils.py b/dlt/sources/helpers/rest_client/utils.py index 6001fad4ec..8732437a88 100644 --- a/dlt/sources/helpers/rest_client/utils.py +++ b/dlt/sources/helpers/rest_client/utils.py @@ -1,14 +1,24 @@ -from functools import reduce -from operator import getitem -from typing import Any, Sequence, Union, Tuple +from typing import Tuple from dlt.common import logger from dlt.extract.source import DltSource def join_url(base_url: str, path: str) -> str: + if base_url is None: + raise ValueError("Base URL must be provided or set to an empty string.") + + if base_url == "": + return path + + if path == "": + return base_url + + # Normalize the base URL + base_url = base_url.rstrip("/") if not base_url.endswith("/"): base_url += "/" + return base_url + path.lstrip("/") diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 9984dabb06..17d445042c 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -3,7 +3,7 @@ from typing import Any, cast from dlt.common.typing import TSecretStrValue from dlt.sources.helpers.requests import Response, Request -from dlt.sources.helpers.rest_client import RESTClient +from dlt.sources.helpers.rest_client import RESTClient, paginate from dlt.sources.helpers.rest_client.client import Hooks from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator @@ -171,3 +171,13 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient): ) self._assert_pagination(list(pages_iter)) + + def test_paginate_function(self, rest_client: RESTClient): + pages_iter = paginate( + "https://api.example.com/posts", + paginator=JSONResponsePaginator(next_url_path="next_page"), + ) + + pages = list(pages_iter) + + self._assert_pagination(pages) diff --git a/tests/sources/helpers/rest_client/test_utils.py b/tests/sources/helpers/rest_client/test_utils.py new file mode 100644 index 0000000000..0de9729a42 --- /dev/null +++ b/tests/sources/helpers/rest_client/test_utils.py @@ -0,0 +1,90 @@ +import pytest +from dlt.sources.helpers.rest_client.utils import join_url + + +@pytest.mark.parametrize( + "base_url, path, expected", + [ + # Normal cases + ( + "http://example.com", + "path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com/", + "/path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com/", + "path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com", + "//path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com///", + "//path/to/resource", + "http://example.com/path/to/resource", + ), + # Trailing and leading slashes + ("http://example.com/", "/", "http://example.com/"), + ("http://example.com", "/", "http://example.com/"), + ("http://example.com/", "///", "http://example.com/"), + ("http://example.com", "///", "http://example.com/"), + ("/", "path/to/resource", "/path/to/resource"), + ("/", "/path/to/resource", "/path/to/resource"), + # Empty strings + ("", "", ""), + ( + "", + "http://example.com/path/to/resource", + "http://example.com/path/to/resource", + ), + ("", "path/to/resource", "path/to/resource"), + ("http://example.com", "", "http://example.com"), + # Query parameters and fragments + ( + "http://example.com", + "path/to/resource?query=123", + "http://example.com/path/to/resource?query=123", + ), + ( + "http://example.com/", + "path/to/resource#fragment", + "http://example.com/path/to/resource#fragment", + ), + # Special characters in the path + ( + "http://example.com", + "/path/to/resource with spaces", + "http://example.com/path/to/resource with spaces", + ), + ("http://example.com", "/path/with/中文", "http://example.com/path/with/中文"), + # Protocols and subdomains + ("https://sub.example.com", "path", "https://sub.example.com/path"), + ("ftp://example.com", "/path", "ftp://example.com/path"), + # Missing protocol in base_url + ("example.com", "path", "example.com/path"), + ], +) +def test_join_url(base_url, path, expected): + assert join_url(base_url, path) == expected + + +@pytest.mark.parametrize( + "base_url, path, exception", + [ + (None, "path", ValueError), + ("http://example.com", None, AttributeError), + (123, "path", AttributeError), + ("http://example.com", 123, AttributeError), + ], +) +def test_join_url_invalid_input_types(base_url, path, exception): + with pytest.raises(exception): + join_url(base_url, path)