Skip to content

Commit

Permalink
Add initial version of rest_client.paginate()
Browse files Browse the repository at this point in the history
  • Loading branch information
burnash committed Mar 25, 2024
1 parent 7377c23 commit a44c9b4
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 10 deletions.
45 changes: 45 additions & 0 deletions dlt/sources/helpers/rest_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,46 @@
from typing import Optional, Dict, Iterator, Union, Any

from dlt.common import jsonpath

from .client import RESTClient # noqa: F401
from .client import PageData
from .auth import AuthConfigBase
from .paginators import BasePaginator
from .typing import HTTPMethodBasic, Hooks


def paginate(
url: str,
method: HTTPMethodBasic = "GET",
headers: Optional[Dict[str, str]] = None,
params: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: AuthConfigBase = None,
paginator: Union[str, BasePaginator] = None,
data_selector: Optional[jsonpath.TJsonPath] = None,
hooks: Optional[Hooks] = None,
) -> Iterator[PageData[Any]]:
"""
Paginate over a REST API endpoint.
Args:
url: URL to paginate over.
**kwargs: Keyword arguments to pass to `RESTClient.paginate`.
Returns:
Iterator[Page]: Iterator over pages.
"""
client = RESTClient(
base_url=url,
headers=headers,
)
return client.paginate(
path="",
method=method,
params=params,
json=json,
auth=auth,
paginator=paginator,
data_selector=data_selector,
hooks=hooks,
)
7 changes: 1 addition & 6 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
Any,
TypeVar,
Iterable,
Union,
Callable,
cast,
)
import copy
Expand All @@ -20,7 +18,7 @@
from dlt.sources.helpers.requests.retry import Client
from dlt.sources.helpers.requests import Response, Request

from .typing import HTTPMethodBasic, HTTPMethod
from .typing import HTTPMethodBasic, HTTPMethod, Hooks
from .paginators import BasePaginator
from .auth import AuthConfigBase
from .detector import PaginatorFactory, find_records
Expand All @@ -30,9 +28,6 @@


_T = TypeVar("_T")
HookFunction = Callable[[Response, Any, Any], None]
HookEvent = Union[HookFunction, List[HookFunction]]
Hooks = Dict[str, HookEvent]


class PageData(List[_T]):
Expand Down
8 changes: 8 additions & 0 deletions dlt/sources/helpers/rest_client/typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from typing import (
List,
Dict,
Union,
Literal,
Callable,
Any,
)
from dlt.sources.helpers.requests import Response


HTTPMethodBasic = Literal["GET", "POST"]
HTTPMethodExtended = Literal["PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]
HTTPMethod = Union[HTTPMethodBasic, HTTPMethodExtended]
HookFunction = Callable[[Response, Any, Any], None]
HookEvent = Union[HookFunction, List[HookFunction]]
Hooks = Dict[str, HookEvent]
16 changes: 13 additions & 3 deletions dlt/sources/helpers/rest_client/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
from functools import reduce
from operator import getitem
from typing import Any, Sequence, Union, Tuple
from typing import Tuple

from dlt.common import logger
from dlt.extract.source import DltSource


def join_url(base_url: str, path: str) -> str:
if base_url is None:
raise ValueError("Base URL must be provided or set to an empty string.")

if base_url == "":
return path

if path == "":
return base_url

# Normalize the base URL
base_url = base_url.rstrip("/")
if not base_url.endswith("/"):
base_url += "/"

return base_url + path.lstrip("/")


Expand Down
12 changes: 11 additions & 1 deletion tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, cast
from dlt.common.typing import TSecretStrValue
from dlt.sources.helpers.requests import Response, Request
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client import RESTClient, paginate
from dlt.sources.helpers.rest_client.client import Hooks
from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator

Expand Down Expand Up @@ -171,3 +171,13 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient):
)

self._assert_pagination(list(pages_iter))

def test_paginate_function(self, rest_client: RESTClient):
pages_iter = paginate(
"https://api.example.com/posts",
paginator=JSONResponsePaginator(next_url_path="next_page"),
)

pages = list(pages_iter)

self._assert_pagination(pages)
90 changes: 90 additions & 0 deletions tests/sources/helpers/rest_client/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pytest
from dlt.sources.helpers.rest_client.utils import join_url


@pytest.mark.parametrize(
"base_url, path, expected",
[
# Normal cases
(
"http://example.com",
"path/to/resource",
"http://example.com/path/to/resource",
),
(
"http://example.com/",
"/path/to/resource",
"http://example.com/path/to/resource",
),
(
"http://example.com/",
"path/to/resource",
"http://example.com/path/to/resource",
),
(
"http://example.com",
"//path/to/resource",
"http://example.com/path/to/resource",
),
(
"http://example.com///",
"//path/to/resource",
"http://example.com/path/to/resource",
),
# Trailing and leading slashes
("http://example.com/", "/", "http://example.com/"),
("http://example.com", "/", "http://example.com/"),
("http://example.com/", "///", "http://example.com/"),
("http://example.com", "///", "http://example.com/"),
("/", "path/to/resource", "/path/to/resource"),
("/", "/path/to/resource", "/path/to/resource"),
# Empty strings
("", "", ""),
(
"",
"http://example.com/path/to/resource",
"http://example.com/path/to/resource",
),
("", "path/to/resource", "path/to/resource"),
("http://example.com", "", "http://example.com"),
# Query parameters and fragments
(
"http://example.com",
"path/to/resource?query=123",
"http://example.com/path/to/resource?query=123",
),
(
"http://example.com/",
"path/to/resource#fragment",
"http://example.com/path/to/resource#fragment",
),
# Special characters in the path
(
"http://example.com",
"/path/to/resource with spaces",
"http://example.com/path/to/resource with spaces",
),
("http://example.com", "/path/with/中文", "http://example.com/path/with/中文"),
# Protocols and subdomains
("https://sub.example.com", "path", "https://sub.example.com/path"),
("ftp://example.com", "/path", "ftp://example.com/path"),
# Missing protocol in base_url
("example.com", "path", "example.com/path"),
],
)
def test_join_url(base_url, path, expected):
assert join_url(base_url, path) == expected


@pytest.mark.parametrize(
"base_url, path, exception",
[
(None, "path", ValueError),
("http://example.com", None, AttributeError),
(123, "path", AttributeError),
("http://example.com", 123, AttributeError),
],
)
def test_join_url_invalid_input_types(base_url, path, exception):
with pytest.raises(exception):
join_url(base_url, path)

0 comments on commit a44c9b4

Please sign in to comment.