Skip to content

Commit b8fb7fd

Browse files
authored
Add RESTClient: (#1141)
* Add RESTClient and tests * Add PyJWT * Add initial version of `rest_client.paginate()` * Export `rest_client.paginate` to `helpers.requests` module * Fix the typing error * Use dlt.common.json * Add dependency checks for PyJWT and cryptography in auth module * Remove unused imports and check_connection function from rest_client utils * Refactor pagination assertion into a standalone function * Move `paginate` function test to new file `test_requests_paginate.py` * Remove PyJWT from deps * Remove explicit initializers and meta fields from configspec classes * Implement lazy loading for jwt and cryptography in auth * Set username default to None * Add PyJWT to dev dependencies
1 parent 28434b6 commit b8fb7fd

19 files changed

+1857
-33
lines changed

dlt/sources/helpers/requests/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
from requests.exceptions import ChunkedEncodingError
1616
from dlt.sources.helpers.requests.retry import Client
1717
from dlt.sources.helpers.requests.session import Session
18+
from dlt.sources.helpers.rest_client import paginate
1819
from dlt.common.configuration.specs import RunConfiguration
1920

2021
client = Client()
2122

22-
get, post, put, patch, delete, options, head, request = (
23+
get, post, put, patch, delete, options, head, request, paginate = (
2324
client.get,
2425
client.post,
2526
client.put,
@@ -28,6 +29,7 @@
2829
client.options,
2930
client.head,
3031
client.request,
32+
paginate,
3133
)
3234

3335

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import Optional, Dict, Iterator, Union, Any
2+
3+
from dlt.common import jsonpath
4+
5+
from .client import RESTClient # noqa: F401
6+
from .client import PageData
7+
from .auth import AuthConfigBase
8+
from .paginators import BasePaginator
9+
from .typing import HTTPMethodBasic, Hooks
10+
11+
12+
def paginate(
13+
url: str,
14+
method: HTTPMethodBasic = "GET",
15+
headers: Optional[Dict[str, str]] = None,
16+
params: Optional[Dict[str, Any]] = None,
17+
json: Optional[Dict[str, Any]] = None,
18+
auth: AuthConfigBase = None,
19+
paginator: Optional[BasePaginator] = None,
20+
data_selector: Optional[jsonpath.TJsonPath] = None,
21+
hooks: Optional[Hooks] = None,
22+
) -> Iterator[PageData[Any]]:
23+
"""
24+
Paginate over a REST API endpoint.
25+
26+
Args:
27+
url: URL to paginate over.
28+
**kwargs: Keyword arguments to pass to `RESTClient.paginate`.
29+
30+
Returns:
31+
Iterator[Page]: Iterator over pages.
32+
"""
33+
client = RESTClient(
34+
base_url=url,
35+
headers=headers,
36+
)
37+
return client.paginate(
38+
path="",
39+
method=method,
40+
params=params,
41+
json=json,
42+
auth=auth,
43+
paginator=paginator,
44+
data_selector=data_selector,
45+
hooks=hooks,
46+
)
+215
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
from base64 import b64encode
2+
import math
3+
from typing import (
4+
List,
5+
Dict,
6+
Final,
7+
Literal,
8+
Optional,
9+
Union,
10+
Any,
11+
cast,
12+
Iterable,
13+
TYPE_CHECKING,
14+
)
15+
from dlt.sources.helpers import requests
16+
from requests.auth import AuthBase
17+
from requests import PreparedRequest # noqa: I251
18+
import pendulum
19+
20+
from dlt.common.exceptions import MissingDependencyException
21+
22+
from dlt.common import logger
23+
from dlt.common.configuration.specs.base_configuration import configspec
24+
from dlt.common.configuration.specs import CredentialsConfiguration
25+
from dlt.common.configuration.specs.exceptions import NativeValueError
26+
from dlt.common.typing import TSecretStrValue
27+
28+
if TYPE_CHECKING:
29+
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
30+
else:
31+
PrivateKeyTypes = Any
32+
33+
TApiKeyLocation = Literal[
34+
"header", "cookie", "query", "param"
35+
] # Alias for scheme "in" field
36+
37+
38+
class AuthConfigBase(AuthBase, CredentialsConfiguration):
39+
"""Authenticator base which is both `requests` friendly AuthBase and dlt SPEC
40+
configurable via env variables or toml files
41+
"""
42+
43+
pass
44+
45+
46+
@configspec
47+
class BearerTokenAuth(AuthConfigBase):
48+
token: TSecretStrValue = None
49+
50+
def parse_native_representation(self, value: Any) -> None:
51+
if isinstance(value, str):
52+
self.token = cast(TSecretStrValue, value)
53+
else:
54+
raise NativeValueError(
55+
type(self),
56+
value,
57+
f"BearerTokenAuth token must be a string, got {type(value)}",
58+
)
59+
60+
def __call__(self, request: PreparedRequest) -> PreparedRequest:
61+
request.headers["Authorization"] = f"Bearer {self.token}"
62+
return request
63+
64+
65+
@configspec
66+
class APIKeyAuth(AuthConfigBase):
67+
name: str = "Authorization"
68+
api_key: TSecretStrValue = None
69+
location: TApiKeyLocation = "header"
70+
71+
def parse_native_representation(self, value: Any) -> None:
72+
if isinstance(value, str):
73+
self.api_key = cast(TSecretStrValue, value)
74+
else:
75+
raise NativeValueError(
76+
type(self),
77+
value,
78+
f"APIKeyAuth api_key must be a string, got {type(value)}",
79+
)
80+
81+
def __call__(self, request: PreparedRequest) -> PreparedRequest:
82+
if self.location == "header":
83+
request.headers[self.name] = self.api_key
84+
elif self.location in ["query", "param"]:
85+
request.prepare_url(request.url, {self.name: self.api_key})
86+
elif self.location == "cookie":
87+
raise NotImplementedError()
88+
return request
89+
90+
91+
@configspec
92+
class HttpBasicAuth(AuthConfigBase):
93+
username: str = None
94+
password: TSecretStrValue = None
95+
96+
def parse_native_representation(self, value: Any) -> None:
97+
if isinstance(value, Iterable) and not isinstance(value, str):
98+
value = list(value)
99+
if len(value) == 2:
100+
self.username, self.password = value
101+
return
102+
raise NativeValueError(
103+
type(self),
104+
value,
105+
f"HttpBasicAuth username and password must be a tuple of two strings, got {type(value)}",
106+
)
107+
108+
def __call__(self, request: PreparedRequest) -> PreparedRequest:
109+
encoded = b64encode(f"{self.username}:{self.password}".encode()).decode()
110+
request.headers["Authorization"] = f"Basic {encoded}"
111+
return request
112+
113+
114+
@configspec
115+
class OAuth2AuthBase(AuthConfigBase):
116+
"""Base class for oauth2 authenticators. requires access_token"""
117+
118+
# TODO: Separate class for flows (implicit, authorization_code, client_credentials, etc)
119+
access_token: TSecretStrValue = None
120+
121+
def parse_native_representation(self, value: Any) -> None:
122+
if isinstance(value, str):
123+
self.access_token = cast(TSecretStrValue, value)
124+
else:
125+
raise NativeValueError(
126+
type(self),
127+
value,
128+
f"OAuth2AuthBase access_token must be a string, got {type(value)}",
129+
)
130+
131+
def __call__(self, request: PreparedRequest) -> PreparedRequest:
132+
request.headers["Authorization"] = f"Bearer {self.access_token}"
133+
return request
134+
135+
136+
@configspec
137+
class OAuthJWTAuth(BearerTokenAuth):
138+
"""This is a form of Bearer auth, actually there's not standard way to declare it in openAPI"""
139+
140+
format: Final[Literal["JWT"]] = "JWT" # noqa: A003
141+
client_id: str = None
142+
private_key: TSecretStrValue = None
143+
auth_endpoint: str = None
144+
scopes: Optional[Union[str, List[str]]] = None
145+
headers: Optional[Dict[str, str]] = None
146+
private_key_passphrase: Optional[TSecretStrValue] = None
147+
default_token_expiration: int = 3600
148+
149+
def __post_init__(self) -> None:
150+
self.scopes = (
151+
self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes)
152+
)
153+
self.token = None
154+
self.token_expiry: Optional[pendulum.DateTime] = None
155+
156+
def __call__(self, r: PreparedRequest) -> PreparedRequest:
157+
if self.token is None or self.is_token_expired():
158+
self.obtain_token()
159+
r.headers["Authorization"] = f"Bearer {self.token}"
160+
return r
161+
162+
def is_token_expired(self) -> bool:
163+
return not self.token_expiry or pendulum.now() >= self.token_expiry
164+
165+
def obtain_token(self) -> None:
166+
try:
167+
import jwt
168+
except ModuleNotFoundError:
169+
raise MissingDependencyException("dlt OAuth helpers", ["PyJWT"])
170+
171+
payload = self.create_jwt_payload()
172+
data = {
173+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
174+
"assertion": jwt.encode(
175+
payload, self.load_private_key(), algorithm="RS256"
176+
),
177+
}
178+
179+
logger.debug(f"Obtaining token from {self.auth_endpoint}")
180+
181+
response = requests.post(self.auth_endpoint, headers=self.headers, data=data)
182+
response.raise_for_status()
183+
184+
token_response = response.json()
185+
self.token = token_response["access_token"]
186+
self.token_expiry = pendulum.now().add(
187+
seconds=token_response.get("expires_in", self.default_token_expiration)
188+
)
189+
190+
def create_jwt_payload(self) -> Dict[str, Union[str, int]]:
191+
now = pendulum.now()
192+
return {
193+
"iss": self.client_id,
194+
"sub": self.client_id,
195+
"aud": self.auth_endpoint,
196+
"exp": math.floor((now.add(hours=1)).timestamp()),
197+
"iat": math.floor(now.timestamp()),
198+
"scope": cast(str, self.scopes),
199+
}
200+
201+
def load_private_key(self) -> "PrivateKeyTypes":
202+
try:
203+
from cryptography.hazmat.backends import default_backend
204+
from cryptography.hazmat.primitives import serialization
205+
except ModuleNotFoundError:
206+
raise MissingDependencyException("dlt OAuth helpers", ["cryptography"])
207+
208+
private_key_bytes = self.private_key.encode("utf-8")
209+
return serialization.load_pem_private_key(
210+
private_key_bytes,
211+
password=self.private_key_passphrase.encode("utf-8")
212+
if self.private_key_passphrase
213+
else None,
214+
backend=default_backend(),
215+
)

0 commit comments

Comments
 (0)