Skip to content

Commit

Permalink
fix: ensure custom session can be provided to rest client
Browse files Browse the repository at this point in the history
  • Loading branch information
z3z1ma committed May 22, 2024
1 parent b1e0f77 commit 64c9538
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
20 changes: 9 additions & 11 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def __init__(
self.auth = auth

if session:
self._validate_session_raise_for_status(session)
self.session = session
self.session = _warn_session_raise_for_status(session)
else:
self.session = Client(raise_for_status=False).session

Expand All @@ -90,15 +89,6 @@ def __init__(

self.data_selector = data_selector

def _validate_session_raise_for_status(self, session: BaseSession) -> None:
# dlt.sources.helpers.requests.session.Session
# has raise_for_status=True by default
if getattr(self.session, "raise_for_status", False):
logger.warning(
"The session provided has raise_for_status enabled. "
"This may cause unexpected behavior."
)

def _create_request(
self,
path: str,
Expand Down Expand Up @@ -296,3 +286,11 @@ def detect_paginator(self, response: Response, data: Any) -> BasePaginator:
" instance of the paginator as some settings may not be guessed correctly."
)
return paginator


def _warn_session_raise_for_status(session: BaseSession) -> BaseSession:
if getattr(session, "raise_for_status", False):
logger.warning(
"The session provided has raise_for_status enabled. This may cause unexpected behavior."
)
return session
16 changes: 15 additions & 1 deletion tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import pytest
from typing import Any, cast
from dlt.common import logger
from dlt.common.typing import TSecretStrValue
from dlt.sources.helpers.requests import Response, Request
from dlt.sources.helpers.requests import Client, Response, Request
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.client import Hooks
from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator
Expand Down Expand Up @@ -183,3 +184,16 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient):
)

assert_pagination(list(pages_iter))

def test_custom_session_client(self, mocker):
mocked_warning = mocker.patch.object(logger, "warning")
RESTClient(
base_url="https://api.example.com",
headers={"Accept": "application/json"},
session=Client(raise_for_status=True).session,
)
assert (
mocked_warning.call_args[0][0]
== "The session provided has raise_for_status enabled. This may cause unexpected"
" behavior."
)

0 comments on commit 64c9538

Please sign in to comment.