Skip to content

Commit

Permalink
RESTClient: add support for relative next URLs in LinkPaginators (#1163)
Browse files Browse the repository at this point in the history
* Extend `mock_api_server()` to support relative next urls
* Enhance BaseNextUrlPaginator to support relative next URLs in pagination
burnash authored Apr 2, 2024
1 parent ea22515 commit ecb5aa0
Showing 4 changed files with 209 additions and 42 deletions.
7 changes: 7 additions & 0 deletions dlt/sources/helpers/rest_client/paginators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional
from urllib.parse import urlparse, urljoin

from dlt.sources.helpers.requests import Response, Request
from dlt.common import jsonpath
@@ -102,6 +103,12 @@ def update_request(self, request: Request) -> None:

class BaseNextUrlPaginator(BasePaginator):
    """Paginator base that points the next request at ``self.next_reference``."""

    def update_request(self, request: Request) -> None:
        """Set ``request.url`` to the URL of the next page.

        A ``next_reference`` that has no URL scheme is treated as relative:
        it is resolved against the current ``request.url`` and the resulting
        absolute URL is stored back on the paginator before being applied.
        """
        next_url = self.next_reference
        if next_url and not urlparse(next_url).scheme:
            # Relative reference: resolve it against the URL just requested.
            self.next_reference = urljoin(request.url, next_url)

        request.url = self.next_reference


88 changes: 60 additions & 28 deletions tests/sources/helpers/rest_client/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import NamedTuple, Callable, Pattern, List, TYPE_CHECKING
from typing import NamedTuple, Callable, Pattern, List, Union, TYPE_CHECKING
import base64

from urllib.parse import urlsplit, urlunsplit
@@ -10,17 +10,21 @@
from dlt.common import json

if TYPE_CHECKING:
RequestCallback = Callable[[requests_mock.Request, requests_mock.Context], str]
RequestCallback = Callable[
[requests_mock.Request, requests_mock.Context], Union[str, dict, list]
]
ResponseSerializer = Callable[[requests_mock.Request, requests_mock.Context], str]
else:
RequestCallback = Callable
ResponseSerializer = Callable

MOCK_BASE_URL = "https://api.example.com"


class Route(NamedTuple):
method: str
pattern: Pattern[str]
callback: RequestCallback
callback: ResponseSerializer


class APIRouter:
@@ -32,8 +36,17 @@ def _add_route(
self, method: str, pattern: str, func: RequestCallback
) -> RequestCallback:
compiled_pattern = re.compile(f"{self.base_url}{pattern}")
self.routes.append(Route(method, compiled_pattern, func))
return func

def serialize_response(request, context):
    """Call the wrapped route callback and JSON-encode dict or list results."""
    payload = func(request, context)
    # Callbacks may already return a string; only containers are encoded.
    return json.dumps(payload) if isinstance(payload, (dict, list)) else payload

self.routes.append(Route(method, compiled_pattern, serialize_response))
return serialize_response

def get(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]:
def decorator(func: RequestCallback) -> RequestCallback:
@@ -59,9 +72,17 @@ def register_routes(self, mocker: requests_mock.Mocker) -> None:
router = APIRouter(MOCK_BASE_URL)


def serialize_page(records, page_number, total_pages, base_url, records_key="data"):
def serialize_page(
records,
page_number,
total_pages,
request_url,
records_key="data",
use_absolute_url=True,
):
"""Serialize a page of records into a dict with pagination metadata."""
if records_key is None:
return json.dumps(records)
return records

response = {
records_key: records,
@@ -72,11 +93,15 @@ def serialize_page(records, page_number, total_pages, base_url, records_key="dat
if page_number < total_pages:
next_page = page_number + 1

scheme, netloc, path, _, _ = urlsplit(base_url)
next_page = urlunsplit([scheme, netloc, path, f"page={next_page}", ""])
response["next_page"] = next_page
scheme, netloc, path, _, _ = urlsplit(request_url)
if use_absolute_url:
next_page_url = urlunsplit([scheme, netloc, path, f"page={next_page}", ""])
else:
next_page_url = f"{path}?page={next_page}"

return json.dumps(response)
response["next_page"] = next_page_url

return response


def generate_posts(count=100):
@@ -91,15 +116,22 @@ def get_page_number(qs, key="page", default=1):
return int(qs.get(key, [default])[0])


def paginate_response(
    request, records, page_size=10, records_key="data", use_absolute_url=True
):
    """Return the page of ``records`` selected by the request's query string.

    Args:
        request: Mock request; ``request.qs`` is read for the page number and
            ``request.url`` is forwarded to ``serialize_page``.
        records: Full list of records to paginate.
        page_size: Number of records per page.
        records_key: Forwarded to ``serialize_page``.
        use_absolute_url: Forwarded to ``serialize_page``.

    Returns:
        The result of ``serialize_page`` for the selected slice.
    """
    page_number = get_page_number(request.qs)
    # Ceiling division so a partially filled last page still counts.
    total_pages = (len(records) + page_size - 1) // page_size
    # Both slice bounds use page_size so the slice agrees with total_pages.
    start_index = (page_number - 1) * page_size
    end_index = start_index + page_size
    return serialize_page(
        records[start_index:end_index],
        page_number,
        total_pages,
        request.url,
        records_key,
        use_absolute_url,
    )


@@ -115,6 +147,10 @@ def posts_no_key(request, context):
def posts(request, context):
return paginate_response(request, generate_posts())

@router.get(r"/posts_relative_next_url(\?page=\d+)?$")
def posts_relative_next_url(request, context):
return paginate_response(request, generate_posts(), use_absolute_url=False)

@router.get(r"/posts/(\d+)/comments")
def post_comments(request, context):
post_id = int(request.url.split("/")[-2])
@@ -123,17 +159,17 @@ def post_comments(request, context):
@router.get(r"/posts/\d+$")
def post_detail(request, context):
post_id = request.url.split("/")[-1]
return json.dumps({"id": post_id, "body": f"Post body {post_id}"})
return {"id": post_id, "body": f"Post body {post_id}"}

@router.get(r"/posts/\d+/some_details_404")
def post_detail_404(request, context):
"""Return 404 for post with id > 0. Used to test ignoring 404 errors."""
post_id = int(request.url.split("/")[-2])
if post_id < 1:
return json.dumps({"id": post_id, "body": f"Post body {post_id}"})
return {"id": post_id, "body": f"Post body {post_id}"}
else:
context.status_code = 404
return json.dumps({"error": "Post not found"})
return {"error": "Post not found"}

@router.get(r"/posts_under_a_different_key$")
def posts_with_results_key(request, context):
@@ -149,15 +185,15 @@ def protected_basic_auth(request, context):
if auth == f"Basic {creds_base64}":
return paginate_response(request, generate_posts())
context.status_code = 401
return json.dumps({"error": "Unauthorized"})
return {"error": "Unauthorized"}

@router.get("/protected/posts/bearer-token")
def protected_bearer_token(request, context):
auth = request.headers.get("Authorization")
if auth == "Bearer test-token":
return paginate_response(request, generate_posts())
context.status_code = 401
return json.dumps({"error": "Unauthorized"})
return {"error": "Unauthorized"}

@router.get("/protected/posts/bearer-token-plain-text-error")
def protected_bearer_token_plain_text_erorr(request, context):
@@ -173,31 +209,27 @@ def protected_api_key(request, context):
if api_key == "test-api-key":
return paginate_response(request, generate_posts())
context.status_code = 401
return json.dumps({"error": "Unauthorized"})
return {"error": "Unauthorized"}

@router.post("/oauth/token")
def oauth_token(request, context):
return json.dumps(
{
"access_token": "test-token",
"expires_in": 3600,
}
)
return {"access_token": "test-token", "expires_in": 3600}

@router.post("/auth/refresh")
def refresh_token(request, context):
body = request.json()
if body.get("refresh_token") == "valid-refresh-token":
return json.dumps({"access_token": "new-valid-token"})
return {"access_token": "new-valid-token"}
context.status_code = 401
return json.dumps({"error": "Invalid refresh token"})
return {"error": "Invalid refresh token"}

router.register_routes(m)

yield m


def assert_pagination(pages, expected_start=0, page_size=10):
def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10):
assert len(pages) == total_pages
for i, page in enumerate(pages):
assert page == [
{"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10)
17 changes: 17 additions & 0 deletions tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
@@ -74,6 +74,23 @@ def test_default_paginator(self, rest_client: RESTClient):

assert_pagination(pages)

def test_excplicit_paginator(self, rest_client: RESTClient):
    """Paginate ``/posts`` with an explicitly configured next-URL paginator."""
    paginator = JSONResponsePaginator(next_url_path="next_page")
    pages = list(rest_client.paginate("/posts", paginator=paginator))

    assert_pagination(pages)

def test_excplicit_paginator_relative_next_url(self, rest_client: RESTClient):
    """Paginate the endpoint that is requested as ``/posts_relative_next_url``."""
    paginator = JSONResponsePaginator(next_url_path="next_page")
    pages = list(
        rest_client.paginate("/posts_relative_next_url", paginator=paginator)
    )

    assert_pagination(pages)

def test_paginate_with_hooks(self, rest_client: RESTClient):
def response_hook(response: Response, *args: Any, **kwargs: Any) -> None:
if response.status_code == 404:
139 changes: 125 additions & 14 deletions tests/sources/helpers/rest_client/test_paginators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
from unittest.mock import Mock

from requests.models import Response
import pytest

from requests.models import Response, Request

from dlt.sources.helpers.rest_client.paginators import (
SinglePagePaginator,
@@ -29,21 +30,131 @@ def test_update_state_without_next(self):


class TestJSONResponsePaginator:
def test_update_state_with_next(self):
paginator = JSONResponsePaginator()
response = Mock(
Response, json=lambda: {"next": "http://example.com/next", "results": []}
)
@pytest.mark.parametrize(
"test_case",
[
# Test with empty next_url_path, i.e. auto-detect
{
"next_url_path": None,
"response_json": {"next": "http://example.com/next", "results": []},
"expected": {
"next_reference": "http://example.com/next",
"has_next_page": True,
},
},
# Test with explicit next_url_path
{
"next_url_path": "next_page",
"response_json": {
"next_page": "http://example.com/next",
"results": [],
},
"expected": {
"next_reference": "http://example.com/next",
"has_next_page": True,
},
},
# Test with nested next_url_path
{
"next_url_path": "next_page.url",
"response_json": {
"next_page": {"url": "http://example.com/next"},
"results": [],
},
"expected": {
"next_reference": "http://example.com/next",
"has_next_page": True,
},
},
# Test without next_page
{
"next_url_path": None,
"response_json": {"results": []},
"expected": {
"next_reference": None,
"has_next_page": False,
},
},
],
)
def test_update_state(self, test_case):
next_url_path = test_case["next_url_path"]

if next_url_path is None:
paginator = JSONResponsePaginator()
else:
paginator = JSONResponsePaginator(next_url_path=next_url_path)
response = Mock(Response, json=lambda: test_case["response_json"])
paginator.update_state(response)
assert paginator.next_reference == "http://example.com/next"
assert paginator.has_next_page is True
assert paginator.next_reference == test_case["expected"]["next_reference"]
assert paginator.has_next_page == test_case["expected"]["has_next_page"]

def test_update_state_without_next(self):
# Test update_request from BaseNextUrlPaginator
@pytest.mark.parametrize(
"test_case",
[
# Test with absolute URL
{
"next_reference": "http://example.com/api/resource?page=2",
"request_url": "http://example.com/api/resource",
"expected": "http://example.com/api/resource?page=2",
},
# Test with relative URL
{
"next_reference": "/api/resource?page=2",
"request_url": "http://example.com/api/resource",
"expected": "http://example.com/api/resource?page=2",
},
# Test with more nested path
{
"next_reference": "/api/resource/subresource?page=3&sort=desc",
"request_url": "http://example.com/api/resource/subresource",
"expected": "http://example.com/api/resource/subresource?page=3&sort=desc",
},
# Test with 'page' in path
{
"next_reference": "/api/page/4/items?filter=active",
"request_url": "http://example.com/api/page/3/items",
"expected": "http://example.com/api/page/4/items?filter=active",
},
# Test with complex query parameters
{
"next_reference": "/api/resource?page=3&category=books&sort=author",
"request_url": "http://example.com/api/resource?page=2",
"expected": "http://example.com/api/resource?page=3&category=books&sort=author",
},
# Test with URL having port number
{
"next_reference": "/api/resource?page=2",
"request_url": "http://example.com:8080/api/resource",
"expected": "http://example.com:8080/api/resource?page=2",
},
# Test with HTTPS protocol
{
"next_reference": "https://secure.example.com/api/resource?page=2",
"request_url": "https://secure.example.com/api/resource",
"expected": "https://secure.example.com/api/resource?page=2",
},
# Test with encoded characters in URL
{
"next_reference": "/api/resource?page=2&query=%E3%81%82",
"request_url": "http://example.com/api/resource",
"expected": "http://example.com/api/resource?page=2&query=%E3%81%82",
},
# Test with missing 'page' parameter in next_reference
{
"next_reference": "/api/resource?sort=asc",
"request_url": "http://example.com/api/resource?page=1",
"expected": "http://example.com/api/resource?sort=asc",
},
],
)
def test_update_request(self, test_case):
paginator = JSONResponsePaginator()
response = Mock(Response, json=lambda: {"results": []})
paginator.update_state(response)
assert paginator.next_reference is None
assert paginator.has_next_page is False
paginator.next_reference = test_case["next_reference"]
request = Mock(Request, url=test_case["request_url"])
paginator.update_request(request)
assert request.url == test_case["expected"]


class TestSinglePagePaginator:

0 comments on commit ecb5aa0

Please sign in to comment.