Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: ensure custom session can be provided to rest client #1396

Merged
merged 10 commits on May 27, 2024
2 changes: 1 addition & 1 deletion dlt/sources/helpers/requests/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _make_session(self) -> Session:
session.mount("http://", self._adapter)
session.mount("https://", self._adapter)
retry = _make_retry(**self._retry_kwargs)
session.request = retry.wraps(session.request) # type: ignore[method-assign]
session.send = retry.wraps(session.send) # type: ignore[method-assign]
return session

@property
Expand Down
20 changes: 9 additions & 11 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def __init__(
self.auth = auth

if session:
self._validate_session_raise_for_status(session)
self.session = session
self.session = _warn_session_raise_for_status(session)
z3z1ma marked this conversation as resolved.
Show resolved Hide resolved
else:
self.session = Client(raise_for_status=False).session

Expand All @@ -90,15 +89,6 @@ def __init__(

self.data_selector = data_selector

def _validate_session_raise_for_status(self, session: BaseSession) -> None:
# dlt.sources.helpers.requests.session.Session
# has raise_for_status=True by default
z3z1ma marked this conversation as resolved.
Show resolved Hide resolved
if getattr(self.session, "raise_for_status", False):
logger.warning(
"The session provided has raise_for_status enabled. "
"This may cause unexpected behavior."
)

def _create_request(
self,
path: str,
Expand Down Expand Up @@ -296,3 +286,11 @@ def detect_paginator(self, response: Response, data: Any) -> BasePaginator:
" instance of the paginator as some settings may not be guessed correctly."
)
return paginator


def _warn_session_raise_for_status(session: BaseSession) -> BaseSession:
burnash marked this conversation as resolved.
Show resolved Hide resolved
if getattr(session, "raise_for_status", False):
logger.warning(
"The session provided has raise_for_status enabled. This may cause unexpected behavior."
)
return session
16 changes: 15 additions & 1 deletion tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import pytest
from typing import Any, cast
from dlt.common import logger
from dlt.common.typing import TSecretStrValue
from dlt.sources.helpers.requests import Response, Request
from dlt.sources.helpers.requests import Client, Response, Request
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.client import Hooks
from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator
Expand Down Expand Up @@ -183,3 +184,16 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient):
)

assert_pagination(list(pages_iter))

def test_custom_session_client(self, mocker):
mocked_warning = mocker.patch.object(logger, "warning")
RESTClient(
base_url="https://api.example.com",
headers={"Accept": "application/json"},
session=Client(raise_for_status=True).session,
)
assert (
mocked_warning.call_args[0][0]
== "The session provided has raise_for_status enabled. This may cause unexpected"
" behavior."
)
80 changes: 44 additions & 36 deletions tests/sources/helpers/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def mock_sleep() -> Iterator[mock.MagicMock]:


def test_default_session_retry_settings() -> None:
retry: Retrying = Client().session.request.retry # type: ignore
retry: Retrying = Client().session.send.retry # type: ignore
assert retry.stop.max_attempt_number == 5 # type: ignore
assert isinstance(retry.retry, retry_any)
retries = retry.retry.retries
Expand All @@ -51,7 +51,7 @@ def custom_retry_cond(response, exception):
respect_retry_after_header=False,
).session

retry: Retrying = session.request.retry # type: ignore
retry: Retrying = session.send.retry # type: ignore
assert retry.stop.max_attempt_number == 14 # type: ignore
assert isinstance(retry.retry, retry_any)
retries = retry.retry.retries
Expand All @@ -63,11 +63,12 @@ def custom_retry_cond(response, exception):
def test_retry_on_status_all_fails(mock_sleep: mock.MagicMock) -> None:
session = Client().session
url = "https://example.com/data"
m = requests_mock.Adapter()
session.mount("https://", m)
m.register_uri("GET", url, status_code=503)

with requests_mock.mock(session=session) as m:
m.get(url, status_code=503)
with pytest.raises(requests.HTTPError):
session.get(url)
with pytest.raises(requests.HTTPError):
session.get(url)

assert m.call_count == RunConfiguration.request_max_attempts

Expand All @@ -76,16 +77,17 @@ def test_retry_on_status_success_after_2(mock_sleep: mock.MagicMock) -> None:
"""Test successful request after 2 retries"""
session = Client().session
url = "https://example.com/data"
m = requests_mock.Adapter()
session.mount("https://", m)

responses = [
dict(text="error", status_code=503),
dict(text="error", status_code=503),
dict(text="error", status_code=200),
]

with requests_mock.mock(session=session) as m:
m.get(url, responses)
resp = session.get(url)
m.register_uri("GET", url, responses)
resp = session.get(url)

assert resp.status_code == 200
assert m.call_count == 3
Expand All @@ -94,30 +96,32 @@ def test_retry_on_status_success_after_2(mock_sleep: mock.MagicMock) -> None:
def test_retry_on_status_without_raise_for_status(mock_sleep: mock.MagicMock) -> None:
url = "https://example.com/data"
session = Client(raise_for_status=False).session
m = requests_mock.Adapter()
session.mount("https://", m)

with requests_mock.mock(session=session) as m:
m.get(url, status_code=503)
response = session.get(url)
assert response.status_code == 503
m.register_uri("GET", url, status_code=503)
response = session.get(url)
assert response.status_code == 503

assert m.call_count == RunConfiguration.request_max_attempts


def test_hooks_with_raise_for_statue() -> None:
url = "https://example.com/data"
session = Client(raise_for_status=True).session
m = requests_mock.Adapter()
session.mount("https://", m)

def _no_content(resp: requests.Response, *args, **kwargs) -> requests.Response:
resp.status_code = 204
resp._content = b"[]"
return resp

with requests_mock.mock(session=session) as m:
m.get(url, status_code=503)
response = session.get(url, hooks={"response": _no_content})
# we simulate empty response
assert response.status_code == 204
assert response.json() == []
m.register_uri("GET", url, status_code=503)
response = session.get(url, hooks={"response": _no_content})
# we simulate empty response
assert response.status_code == 204
assert response.json() == []

assert m.call_count == 1

Expand All @@ -130,12 +134,13 @@ def test_retry_on_exception_all_fails(
exception_class: Type[Exception], mock_sleep: mock.MagicMock
) -> None:
session = Client().session
m = requests_mock.Adapter()
session.mount("https://", m)
url = "https://example.com/data"

with requests_mock.mock(session=session) as m:
m.get(url, exc=exception_class)
with pytest.raises(exception_class):
session.get(url)
m.register_uri("GET", url, exc=exception_class)
with pytest.raises(exception_class):
session.get(url)

assert m.call_count == RunConfiguration.request_max_attempts

Expand All @@ -145,12 +150,13 @@ def retry_on(response: requests.Response, exception: BaseException) -> bool:
return response.text == "error"

session = Client(retry_condition=retry_on).session
m = requests_mock.Adapter()
session.mount("https://", m)
url = "https://example.com/data"

with requests_mock.mock(session=session) as m:
m.get(url, text="error")
response = session.get(url)
assert response.content == b"error"
m.register_uri("GET", url, text="error")
response = session.get(url)
assert response.content == b"error"

assert m.call_count == RunConfiguration.request_max_attempts

Expand All @@ -160,12 +166,12 @@ def retry_on(response: requests.Response, exception: BaseException) -> bool:
return response.text == "error"

session = Client(retry_condition=retry_on).session
m = requests_mock.Adapter()
session.mount("https://", m)
url = "https://example.com/data"
responses = [dict(text="error"), dict(text="error"), dict(text="success")]

with requests_mock.mock(session=session) as m:
m.get(url, responses)
resp = session.get(url)
m.register_uri("GET", url, [dict(text="error"), dict(text="error"), dict(text="success")])
resp = session.get(url)

assert resp.text == "success"
assert m.call_count == 3
Expand All @@ -174,14 +180,16 @@ def retry_on(response: requests.Response, exception: BaseException) -> bool:
def test_wait_retry_after_int(mock_sleep: mock.MagicMock) -> None:
session = Client(request_backoff_factor=0).session
url = "https://example.com/data"
m = requests_mock.Adapter()
session.mount("https://", m)
m.register_uri("GET", url, text="error")
responses = [
dict(text="error", headers={"retry-after": "4"}, status_code=429),
dict(text="success"),
]

with requests_mock.mock(session=session) as m:
m.get(url, responses)
session.get(url)
m.register_uri("GET", url, responses)
session.get(url)

mock_sleep.assert_called_once()
assert 4 <= mock_sleep.call_args[0][0] <= 5 # Adds jitter up to 1s
Expand All @@ -206,7 +214,7 @@ def test_init_default_client(existing_session: bool) -> None:

session = default_client.session
assert session.timeout == cfg["RUNTIME__REQUEST_TIMEOUT"]
retry = session.request.retry # type: ignore[attr-defined]
retry = session.send.retry # type: ignore[attr-defined]
assert retry.wait.multiplier == cfg["RUNTIME__REQUEST_BACKOFF_FACTOR"]
assert retry.stop.max_attempt_number == cfg["RUNTIME__REQUEST_MAX_ATTEMPTS"]
assert retry.wait.max == cfg["RUNTIME__REQUEST_MAX_RETRY_DELAY"]
Expand All @@ -226,7 +234,7 @@ def test_client_instance_with_config(existing_session: bool) -> None:

session = client.session
assert session.timeout == cfg["RUNTIME__REQUEST_TIMEOUT"]
retry = session.request.retry # type: ignore[attr-defined]
retry = session.send.retry # type: ignore[attr-defined]
assert retry.wait.multiplier == cfg["RUNTIME__REQUEST_BACKOFF_FACTOR"]
assert retry.stop.max_attempt_number == cfg["RUNTIME__REQUEST_MAX_ATTEMPTS"]
assert retry.wait.max == cfg["RUNTIME__REQUEST_MAX_RETRY_DELAY"]
Loading