Skip to content

Commit

Permalink
uses AuthBase as auth type
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed May 24, 2024
1 parent 4124673 commit d3fd1ed
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
14 changes: 8 additions & 6 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
Any,
TypeVar,
Iterable,
Union,
cast,
)
import copy
from urllib.parse import urlparse
from requests import Session as BaseSession # noqa: I251
from requests import Response, Request
from requests.auth import AuthBase

from dlt.common import jsonpath, logger

Expand Down Expand Up @@ -41,7 +43,7 @@ def __init__(
request: Request,
response: Response,
paginator: BasePaginator,
auth: AuthConfigBase,
auth: AuthBase,
):
super().__init__(__iterable)
self.request = request
Expand All @@ -57,7 +59,7 @@ class RESTClient:
Args:
base_url (str): The base URL of the API to make requests to.
headers (Optional[Dict[str, str]]): Default headers to include in all requests.
auth (Optional[AuthConfigBase]): Authentication configuration for all requests.
auth (Optional[AuthBase]): Authentication configuration for all requests.
paginator (Optional[BasePaginator]): Default paginator for handling paginated responses.
data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for extracting data from responses.
session (BaseSession): HTTP session for making requests.
Expand All @@ -69,7 +71,7 @@ def __init__(
self,
base_url: str,
headers: Optional[Dict[str, str]] = None,
auth: Optional[AuthConfigBase] = None,
auth: Optional[AuthBase] = None,
paginator: Optional[BasePaginator] = None,
data_selector: Optional[jsonpath.TJsonPath] = None,
session: BaseSession = None,
Expand Down Expand Up @@ -105,7 +107,7 @@ def _create_request(
method: HTTPMethod,
params: Dict[str, Any],
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthConfigBase] = None,
auth: Optional[AuthBase] = None,
hooks: Optional[Hooks] = None,
) -> Request:
parsed_url = urlparse(path)
Expand Down Expand Up @@ -154,7 +156,7 @@ def paginate(
method: HTTPMethodBasic = "GET",
params: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthConfigBase] = None,
auth: Optional[AuthBase] = None,
paginator: Optional[BasePaginator] = None,
data_selector: Optional[jsonpath.TJsonPath] = None,
hooks: Optional[Hooks] = None,
Expand All @@ -166,7 +168,7 @@ def paginate(
method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'.
params (Optional[Dict[str, Any]]): URL parameters for the request.
json (Optional[Dict[str, Any]]): JSON payload for the request.
auth (Optional[AuthConfigBase]): Authentication configuration for the request.
auth (Optional[AuthBase): Authentication configuration for the request.
paginator (Optional[BasePaginator]): Paginator instance for handling
pagination logic.
data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def resource(url: str):
# dlt bigquery custom destination
# we can use the dlt provided credentials class
# to retrieve the gcp credentials from the secrets
@dlt.destination(name="bigquery", loader_file_format="parquet", batch_size=0, naming_convention="snake_case")
@dlt.destination(
name="bigquery", loader_file_format="parquet", batch_size=0, naming_convention="snake_case"
)
def bigquery_insert(
items, table=BIGQUERY_TABLE_ID, credentials: GcpServiceAccountCredentials = dlt.secrets.value
) -> None:
Expand Down
7 changes: 4 additions & 3 deletions tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import pytest
from typing import Any, cast
from requests import PreparedRequest, Request
from requests.auth import AuthBase
from dlt.common.typing import TSecretStrValue
from dlt.sources.helpers.requests import Response, Request
from dlt.sources.helpers.requests import Response
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.client import Hooks
from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator
Expand Down Expand Up @@ -189,15 +190,15 @@ class CustomAuthConfigBase(AuthConfigBase):
def __init__(self, token: str):
self.token = token

def __call__(self, request: Request) -> Request:
def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

class CustomAuthAuthBase(AuthBase):
def __init__(self, token: str):
self.token = token

def __call__(self, request: Request) -> Request:
def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

Expand Down

0 comments on commit d3fd1ed

Please sign in to comment.