diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index b4b62fa849..b10c0b4049 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -80,8 +80,7 @@ def __init__( self.auth = auth if session: - self._validate_session_raise_for_status(session) - self.session = session + self.session = _warn_session_raise_for_status(session) else: self.session = Client(raise_for_status=False).session @@ -90,15 +89,6 @@ def __init__( self.data_selector = data_selector - def _validate_session_raise_for_status(self, session: BaseSession) -> None: - # dlt.sources.helpers.requests.session.Session - # has raise_for_status=True by default - if getattr(self.session, "raise_for_status", False): - logger.warning( - "The session provided has raise_for_status enabled. " - "This may cause unexpected behavior." - ) - def _create_request( self, path: str, @@ -296,3 +286,11 @@ def detect_paginator(self, response: Response, data: Any) -> BasePaginator: " instance of the paginator as some settings may not be guessed correctly." ) return paginator + + +def _warn_session_raise_for_status(session: BaseSession) -> BaseSession: + if getattr(session, "raise_for_status", False): + logger.warning( + "The session provided has raise_for_status enabled. This may cause unexpected behavior." + ) + return session diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 50defa8edb..d94685fa5c 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -1,8 +1,9 @@ import os import pytest from typing import Any, cast +from dlt.common import logger from dlt.common.typing import TSecretStrValue -from dlt.sources.helpers.requests import Response, Request +from dlt.sources.helpers.requests import Client, Response, Request from dlt.sources.helpers.rest_client import RESTClient from dlt.sources.helpers.rest_client.client import Hooks from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator @@ -183,3 +184,16 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient): ) assert_pagination(list(pages_iter)) + + def test_custom_session_client(self, mocker): + mocked_warning = mocker.patch.object(logger, "warning") + RESTClient( + base_url="https://api.example.com", + headers={"Accept": "application/json"}, + session=Client(raise_for_status=True).session, + ) + assert ( + mocked_warning.call_args[0][0] + == "The session provided has raise_for_status enabled. This may cause unexpected" + " behavior." + )