diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 37c0de3db1..020c63a195 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -38,7 +38,11 @@ class AuthConfigBase(AuthBase, CredentialsConfiguration): configurable via env variables or toml files """ - pass + def __bool__(self) -> bool: + # This is needed to avoid AuthConfigBase-derived classes + # which do not implement CredentialsConfiguration interface + # to be evaluated as False in requests.sessions.Session.prepare_request() + return True @configspec diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index b4b62fa849..7d1145a890 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -6,12 +6,14 @@ Any, TypeVar, Iterable, + Union, cast, ) import copy from urllib.parse import urlparse from requests import Session as BaseSession # noqa: I251 from requests import Response, Request +from requests.auth import AuthBase from dlt.common import jsonpath, logger @@ -41,7 +43,7 @@ def __init__( request: Request, response: Response, paginator: BasePaginator, - auth: AuthConfigBase, + auth: AuthBase, ): super().__init__(__iterable) self.request = request @@ -57,7 +59,7 @@ class RESTClient: Args: base_url (str): The base URL of the API to make requests to. headers (Optional[Dict[str, str]]): Default headers to include in all requests. - auth (Optional[AuthConfigBase]): Authentication configuration for all requests. + auth (Optional[AuthBase]): Authentication configuration for all requests. paginator (Optional[BasePaginator]): Default paginator for handling paginated responses. data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for extracting data from responses. session (BaseSession): HTTP session for making requests. @@ -69,7 +71,7 @@ def __init__( self, base_url: str, headers: Optional[Dict[str, str]] = None, - auth: Optional[AuthConfigBase] = None, + auth: Optional[AuthBase] = None, paginator: Optional[BasePaginator] = None, data_selector: Optional[jsonpath.TJsonPath] = None, session: BaseSession = None, @@ -105,7 +107,7 @@ def _create_request( method: HTTPMethod, params: Dict[str, Any], json: Optional[Dict[str, Any]] = None, - auth: Optional[AuthConfigBase] = None, + auth: Optional[AuthBase] = None, hooks: Optional[Hooks] = None, ) -> Request: parsed_url = urlparse(path) @@ -154,7 +156,7 @@ def paginate( method: HTTPMethodBasic = "GET", params: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None, - auth: Optional[AuthConfigBase] = None, + auth: Optional[AuthBase] = None, paginator: Optional[BasePaginator] = None, data_selector: Optional[jsonpath.TJsonPath] = None, hooks: Optional[Hooks] = None, @@ -166,7 +168,7 @@ def paginate( method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'. params (Optional[Dict[str, Any]]): URL parameters for the request. json (Optional[Dict[str, Any]]): JSON payload for the request. - auth (Optional[AuthConfigBase]): Authentication configuration for the request. + auth (Optional[AuthBase): Authentication configuration for the request. paginator (Optional[BasePaginator]): Paginator instance for handling pagination logic. data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for diff --git a/docs/website/docs/general-usage/http/rest-client.md b/docs/website/docs/general-usage/http/rest-client.md index 19cc95bf78..1093428b0f 100644 --- a/docs/website/docs/general-usage/http/rest-client.md +++ b/docs/website/docs/general-usage/http/rest-client.md @@ -407,7 +407,7 @@ The available authentication methods are defined in the `dlt.sources.helpers.res - [APIKeyAuth](#api-key-authentication) - [HttpBasicAuth](#http-basic-authentication) -For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthConfigBase` class. +For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library. ### Bearer token authentication @@ -479,12 +479,12 @@ response = client.get("/protected/resource") ### Implementing custom authentication -You can implement custom authentication by subclassing the `AuthConfigBase` class and implementing the `__call__` method: +You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method: ```py -from dlt.sources.helpers.rest_client.auth import AuthConfigBase +from requests.auth import AuthBase -class CustomAuth(AuthConfigBase): +class CustomAuth(AuthBase): def __init__(self, token): self.token = token diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 50defa8edb..7f03c6d167 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -1,8 +1,10 @@ import os import pytest from typing import Any, cast +from requests import PreparedRequest, Request +from requests.auth import AuthBase from dlt.common.typing import TSecretStrValue -from dlt.sources.helpers.requests import Response, Request +from dlt.sources.helpers.requests import Response from dlt.sources.helpers.rest_client import RESTClient from dlt.sources.helpers.rest_client.client import Hooks from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator @@ -57,7 +59,6 @@ def test_page_context(self, rest_client: RESTClient) -> None: for page in rest_client.paginate( "/posts", paginator=JSONResponsePaginator(next_url_path="next_page"), - auth=AuthConfigBase(), ): # response that produced data assert isinstance(page.response, Response) @@ -183,3 +184,44 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient): ) assert_pagination(list(pages_iter)) + + def test_custom_auth_success(self, rest_client: RESTClient): + class CustomAuthConfigBase(AuthConfigBase): + def __init__(self, token: str): + self.token = token + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + request.headers["Authorization"] = f"Bearer {self.token}" + return request + + class CustomAuthAuthBase(AuthBase): + def __init__(self, token: str): + self.token = token + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + request.headers["Authorization"] = f"Bearer {self.token}" + return request + + auth_list = [ + CustomAuthConfigBase("test-token"), + CustomAuthAuthBase("test-token"), + ] + + for auth in auth_list: + response = rest_client.get( + "/protected/posts/bearer-token", + auth=auth, + ) + + assert response.status_code == 200 + assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + + pages_iter = rest_client.paginate( + "/protected/posts/bearer-token", + auth=auth, + ) + + pages_list = list(pages_iter) + assert_pagination(pages_list) + + assert pages_list[0].response.request.headers["Authorization"] == "Bearer test-token"