Skip to content

Commit 4fcfa28

Browse files
rudolfixburnash
andauthored
RESTClient: implement AuthConfigBase.__bool__ + update docs (#1413)
* Fix AuthConfigBase so its instances always evaluate to True in bool context; change docs to suggest direct inheritance from AuthBase * Add tests * Fix formatting * uses AuthBase as auth type --------- Co-authored-by: Anton Burnashev <[email protected]>
1 parent 3dc1874 commit 4fcfa28

File tree

4 files changed

+61
-13
lines changed

4 files changed

+61
-13
lines changed

dlt/sources/helpers/rest_client/auth.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ class AuthConfigBase(AuthBase, CredentialsConfiguration):
3838
configurable via env variables or toml files
3939
"""
4040

41-
pass
41+
def __bool__(self) -> bool:
42+
# This is needed to avoid AuthConfigBase-derived classes
43+
# which do not implement CredentialsConfiguration interface
44+
# to be evaluated as False in requests.sessions.Session.prepare_request()
45+
return True
4246

4347

4448
@configspec

dlt/sources/helpers/rest_client/client.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
Any,
77
TypeVar,
88
Iterable,
9+
Union,
910
cast,
1011
)
1112
import copy
1213
from urllib.parse import urlparse
1314
from requests import Session as BaseSession # noqa: I251
1415
from requests import Response, Request
16+
from requests.auth import AuthBase
1517

1618
from dlt.common import jsonpath, logger
1719

@@ -41,7 +43,7 @@ def __init__(
4143
request: Request,
4244
response: Response,
4345
paginator: BasePaginator,
44-
auth: AuthConfigBase,
46+
auth: AuthBase,
4547
):
4648
super().__init__(__iterable)
4749
self.request = request
@@ -57,7 +59,7 @@ class RESTClient:
5759
Args:
5860
base_url (str): The base URL of the API to make requests to.
5961
headers (Optional[Dict[str, str]]): Default headers to include in all requests.
60-
auth (Optional[AuthConfigBase]): Authentication configuration for all requests.
62+
auth (Optional[AuthBase]): Authentication configuration for all requests.
6163
paginator (Optional[BasePaginator]): Default paginator for handling paginated responses.
6264
data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for extracting data from responses.
6365
session (BaseSession): HTTP session for making requests.
@@ -69,7 +71,7 @@ def __init__(
6971
self,
7072
base_url: str,
7173
headers: Optional[Dict[str, str]] = None,
72-
auth: Optional[AuthConfigBase] = None,
74+
auth: Optional[AuthBase] = None,
7375
paginator: Optional[BasePaginator] = None,
7476
data_selector: Optional[jsonpath.TJsonPath] = None,
7577
session: BaseSession = None,
@@ -105,7 +107,7 @@ def _create_request(
105107
method: HTTPMethod,
106108
params: Dict[str, Any],
107109
json: Optional[Dict[str, Any]] = None,
108-
auth: Optional[AuthConfigBase] = None,
110+
auth: Optional[AuthBase] = None,
109111
hooks: Optional[Hooks] = None,
110112
) -> Request:
111113
parsed_url = urlparse(path)
@@ -154,7 +156,7 @@ def paginate(
154156
method: HTTPMethodBasic = "GET",
155157
params: Optional[Dict[str, Any]] = None,
156158
json: Optional[Dict[str, Any]] = None,
157-
auth: Optional[AuthConfigBase] = None,
159+
auth: Optional[AuthBase] = None,
158160
paginator: Optional[BasePaginator] = None,
159161
data_selector: Optional[jsonpath.TJsonPath] = None,
160162
hooks: Optional[Hooks] = None,
@@ -166,7 +168,7 @@ def paginate(
166168
method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'.
167169
params (Optional[Dict[str, Any]]): URL parameters for the request.
168170
json (Optional[Dict[str, Any]]): JSON payload for the request.
169-
auth (Optional[AuthConfigBase]): Authentication configuration for the request.
171+
auth (Optional[AuthBase): Authentication configuration for the request.
170172
paginator (Optional[BasePaginator]): Paginator instance for handling
171173
pagination logic.
172174
data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for

docs/website/docs/general-usage/http/rest-client.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ The available authentication methods are defined in the `dlt.sources.helpers.res
407407
- [APIKeyAuth](#api-key-authentication)
408408
- [HttpBasicAuth](#http-basic-authentication)
409409

410-
For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthConfigBase` class.
410+
For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library.
411411

412412
### Bearer token authentication
413413

@@ -479,12 +479,12 @@ response = client.get("/protected/resource")
479479

480480
### Implementing custom authentication
481481

482-
You can implement custom authentication by subclassing the `AuthConfigBase` class and implementing the `__call__` method:
482+
You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method:
483483

484484
```py
485-
from dlt.sources.helpers.rest_client.auth import AuthConfigBase
485+
from requests.auth import AuthBase
486486

487-
class CustomAuth(AuthConfigBase):
487+
class CustomAuth(AuthBase):
488488
def __init__(self, token):
489489
self.token = token
490490

tests/sources/helpers/rest_client/test_client.py

+44-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import os
22
import pytest
33
from typing import Any, cast
4+
from requests import PreparedRequest, Request
5+
from requests.auth import AuthBase
46
from dlt.common.typing import TSecretStrValue
5-
from dlt.sources.helpers.requests import Response, Request
7+
from dlt.sources.helpers.requests import Response
68
from dlt.sources.helpers.rest_client import RESTClient
79
from dlt.sources.helpers.rest_client.client import Hooks
810
from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator
@@ -57,7 +59,6 @@ def test_page_context(self, rest_client: RESTClient) -> None:
5759
for page in rest_client.paginate(
5860
"/posts",
5961
paginator=JSONResponsePaginator(next_url_path="next_page"),
60-
auth=AuthConfigBase(),
6162
):
6263
# response that produced data
6364
assert isinstance(page.response, Response)
@@ -183,3 +184,44 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient):
183184
)
184185

185186
assert_pagination(list(pages_iter))
187+
188+
def test_custom_auth_success(self, rest_client: RESTClient):
189+
class CustomAuthConfigBase(AuthConfigBase):
190+
def __init__(self, token: str):
191+
self.token = token
192+
193+
def __call__(self, request: PreparedRequest) -> PreparedRequest:
194+
request.headers["Authorization"] = f"Bearer {self.token}"
195+
return request
196+
197+
class CustomAuthAuthBase(AuthBase):
198+
def __init__(self, token: str):
199+
self.token = token
200+
201+
def __call__(self, request: PreparedRequest) -> PreparedRequest:
202+
request.headers["Authorization"] = f"Bearer {self.token}"
203+
return request
204+
205+
auth_list = [
206+
CustomAuthConfigBase("test-token"),
207+
CustomAuthAuthBase("test-token"),
208+
]
209+
210+
for auth in auth_list:
211+
response = rest_client.get(
212+
"/protected/posts/bearer-token",
213+
auth=auth,
214+
)
215+
216+
assert response.status_code == 200
217+
assert response.json()["data"][0] == {"id": 0, "title": "Post 0"}
218+
219+
pages_iter = rest_client.paginate(
220+
"/protected/posts/bearer-token",
221+
auth=auth,
222+
)
223+
224+
pages_list = list(pages_iter)
225+
assert_pagination(pages_list)
226+
227+
assert pages_list[0].response.request.headers["Authorization"] == "Bearer test-token"

0 commit comments

Comments
 (0)