Skip to content

Commit

Permalink
RESTClient: implement AuthConfigBase.__bool__ + update docs (#1413)
Browse files Browse the repository at this point in the history
* Fix AuthConfigBase so its instances always evaluate to True in bool context; change docs to suggest direct inheritance from AuthBase

* Add tests

* Fix formatting

* uses AuthBase as auth type

---------

Co-authored-by: Anton Burnashev <[email protected]>
  • Loading branch information
rudolfix and burnash authored May 27, 2024
1 parent 3dc1874 commit 4fcfa28
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 13 deletions.
6 changes: 5 additions & 1 deletion dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ class AuthConfigBase(AuthBase, CredentialsConfiguration):
configurable via env variables or toml files
"""

pass
def __bool__(self) -> bool:
# This is needed to avoid AuthConfigBase-derived classes
# which do not implement CredentialsConfiguration interface
# to be evaluated as False in requests.sessions.Session.prepare_request()
return True


@configspec
Expand Down
14 changes: 8 additions & 6 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
Any,
TypeVar,
Iterable,
Union,
cast,
)
import copy
from urllib.parse import urlparse
from requests import Session as BaseSession # noqa: I251
from requests import Response, Request
from requests.auth import AuthBase

from dlt.common import jsonpath, logger

Expand Down Expand Up @@ -41,7 +43,7 @@ def __init__(
request: Request,
response: Response,
paginator: BasePaginator,
auth: AuthConfigBase,
auth: AuthBase,
):
super().__init__(__iterable)
self.request = request
Expand All @@ -57,7 +59,7 @@ class RESTClient:
Args:
base_url (str): The base URL of the API to make requests to.
headers (Optional[Dict[str, str]]): Default headers to include in all requests.
auth (Optional[AuthConfigBase]): Authentication configuration for all requests.
auth (Optional[AuthBase]): Authentication configuration for all requests.
paginator (Optional[BasePaginator]): Default paginator for handling paginated responses.
data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for extracting data from responses.
session (BaseSession): HTTP session for making requests.
Expand All @@ -69,7 +71,7 @@ def __init__(
self,
base_url: str,
headers: Optional[Dict[str, str]] = None,
auth: Optional[AuthConfigBase] = None,
auth: Optional[AuthBase] = None,
paginator: Optional[BasePaginator] = None,
data_selector: Optional[jsonpath.TJsonPath] = None,
session: BaseSession = None,
Expand Down Expand Up @@ -105,7 +107,7 @@ def _create_request(
method: HTTPMethod,
params: Dict[str, Any],
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthConfigBase] = None,
auth: Optional[AuthBase] = None,
hooks: Optional[Hooks] = None,
) -> Request:
parsed_url = urlparse(path)
Expand Down Expand Up @@ -154,7 +156,7 @@ def paginate(
method: HTTPMethodBasic = "GET",
params: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthConfigBase] = None,
auth: Optional[AuthBase] = None,
paginator: Optional[BasePaginator] = None,
data_selector: Optional[jsonpath.TJsonPath] = None,
hooks: Optional[Hooks] = None,
Expand All @@ -166,7 +168,7 @@ def paginate(
method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'.
params (Optional[Dict[str, Any]]): URL parameters for the request.
json (Optional[Dict[str, Any]]): JSON payload for the request.
auth (Optional[AuthConfigBase]): Authentication configuration for the request.
auth (Optional[AuthBase): Authentication configuration for the request.
paginator (Optional[BasePaginator]): Paginator instance for handling
pagination logic.
data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for
Expand Down
8 changes: 4 additions & 4 deletions docs/website/docs/general-usage/http/rest-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ The available authentication methods are defined in the `dlt.sources.helpers.res
- [APIKeyAuth](#api-key-authentication)
- [HttpBasicAuth](#http-basic-authentication)

For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthConfigBase` class.
For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library.

### Bearer token authentication

Expand Down Expand Up @@ -479,12 +479,12 @@ response = client.get("/protected/resource")

### Implementing custom authentication

You can implement custom authentication by subclassing the `AuthConfigBase` class and implementing the `__call__` method:
You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method:

```py
from dlt.sources.helpers.rest_client.auth import AuthConfigBase
from requests.auth import AuthBase

class CustomAuth(AuthConfigBase):
class CustomAuth(AuthBase):
def __init__(self, token):
self.token = token

Expand Down
46 changes: 44 additions & 2 deletions tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import pytest
from typing import Any, cast
from requests import PreparedRequest, Request
from requests.auth import AuthBase
from dlt.common.typing import TSecretStrValue
from dlt.sources.helpers.requests import Response, Request
from dlt.sources.helpers.requests import Response
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.client import Hooks
from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator
Expand Down Expand Up @@ -57,7 +59,6 @@ def test_page_context(self, rest_client: RESTClient) -> None:
for page in rest_client.paginate(
"/posts",
paginator=JSONResponsePaginator(next_url_path="next_page"),
auth=AuthConfigBase(),
):
# response that produced data
assert isinstance(page.response, Response)
Expand Down Expand Up @@ -183,3 +184,44 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient):
)

assert_pagination(list(pages_iter))

def test_custom_auth_success(self, rest_client: RESTClient):
class CustomAuthConfigBase(AuthConfigBase):
def __init__(self, token: str):
self.token = token

def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

class CustomAuthAuthBase(AuthBase):
def __init__(self, token: str):
self.token = token

def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

auth_list = [
CustomAuthConfigBase("test-token"),
CustomAuthAuthBase("test-token"),
]

for auth in auth_list:
response = rest_client.get(
"/protected/posts/bearer-token",
auth=auth,
)

assert response.status_code == 200
assert response.json()["data"][0] == {"id": 0, "title": "Post 0"}

pages_iter = rest_client.paginate(
"/protected/posts/bearer-token",
auth=auth,
)

pages_list = list(pages_iter)
assert_pagination(pages_list)

assert pages_list[0].response.request.headers["Authorization"] == "Bearer test-token"

0 comments on commit 4fcfa28

Please sign in to comment.