Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reformat with Black #1179

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
else:
PrivateKeyTypes = Any

TApiKeyLocation = Literal[
"header", "cookie", "query", "param"
] # Alias for scheme "in" field
TApiKeyLocation = Literal["header", "cookie", "query", "param"] # Alias for scheme "in" field


class AuthConfigBase(AuthBase, CredentialsConfiguration):
Expand Down Expand Up @@ -102,7 +100,8 @@ def parse_native_representation(self, value: Any) -> None:
raise NativeValueError(
type(self),
value,
f"HttpBasicAuth username and password must be a tuple of two strings, got {type(value)}",
"HttpBasicAuth username and password must be a tuple of two strings, got"
f" {type(value)}",
)

def __call__(self, request: PreparedRequest) -> PreparedRequest:
Expand Down Expand Up @@ -147,9 +146,7 @@ class OAuthJWTAuth(BearerTokenAuth):
default_token_expiration: int = 3600

def __post_init__(self) -> None:
self.scopes = (
self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes)
)
self.scopes = self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes)
self.token = None
self.token_expiry: Optional[pendulum.DateTime] = None

Expand All @@ -171,9 +168,7 @@ def obtain_token(self) -> None:
payload = self.create_jwt_payload()
data = {
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"assertion": jwt.encode(
payload, self.load_private_key(), algorithm="RS256"
),
"assertion": jwt.encode(payload, self.load_private_key(), algorithm="RS256"),
}

logger.debug(f"Obtaining token from {self.auth_endpoint}")
Expand Down Expand Up @@ -208,8 +203,8 @@ def load_private_key(self) -> "PrivateKeyTypes":
private_key_bytes = self.private_key.encode("utf-8")
return serialization.load_pem_private_key(
private_key_bytes,
password=self.private_key_passphrase.encode("utf-8")
if self.private_key_passphrase
else None,
password=(
self.private_key_passphrase.encode("utf-8") if self.private_key_passphrase else None
),
backend=default_backend(),
)
24 changes: 6 additions & 18 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,24 +135,18 @@ def _send_request(self, request: Request) -> Response:

return self.session.send(prepared_request)

def request(
self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any
) -> Response:
def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) -> Response:
prepared_request = self._create_request(
path=path,
method=method,
**kwargs,
)
return self._send_request(prepared_request)

def get(
self, path: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Response:
def get(self, path: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Response:
return self.request(path, method="GET", params=params, **kwargs)

def post(
self, path: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Response:
def post(self, path: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Response:
return self.request(path, method="POST", json=json, **kwargs)

def paginate(
Expand Down Expand Up @@ -224,16 +218,12 @@ def raise_for_status(response: Response, *args: Any, **kwargs: Any) -> None:
paginator.update_request(request)

# yield data with context
yield PageData(
data, request=request, response=response, paginator=paginator, auth=auth
)
yield PageData(data, request=request, response=response, paginator=paginator, auth=auth)

if not paginator.has_next_page:
break

def extract_response(
self, response: Response, data_selector: jsonpath.TJsonPath
) -> List[Any]:
def extract_response(self, response: Response, data_selector: jsonpath.TJsonPath) -> List[Any]:
if data_selector:
# we should compile data_selector
data: Any = jsonpath.find_values(data_selector, response.json())
Expand All @@ -257,8 +247,6 @@ def detect_paginator(self, response: Response) -> BasePaginator:
"""
paginator = self.pagination_factory.create_paginator(response)
if paginator is None:
raise ValueError(
f"No suitable paginator found for the response at {response.url}"
)
raise ValueError(f"No suitable paginator found for the response at {response.url}")
logger.info(f"Detected paginator: {paginator.__class__.__name__}")
return paginator
7 changes: 2 additions & 5 deletions dlt/sources/helpers/rest_client/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def find_records(
return next(
list_info[2]
for list_info in lists
if list_info[1] in RECORD_KEY_PATTERNS
and list_info[1] not in NON_RECORD_KEY_PATTERNS
if list_info[1] in RECORD_KEY_PATTERNS and list_info[1] not in NON_RECORD_KEY_PATTERNS
)
except StopIteration:
# return the least nested element
Expand Down Expand Up @@ -142,9 +141,7 @@ def single_page_detector(response: Response) -> Optional[SinglePagePaginator]:


class PaginatorFactory:
def __init__(
self, detectors: List[Callable[[Response], Optional[BasePaginator]]] = None
):
def __init__(self, detectors: List[Callable[[Response], Optional[BasePaginator]]] = None):
if detectors is None:
detectors = [
header_links_detector,
Expand Down
4 changes: 1 addition & 3 deletions dlt/sources/helpers/rest_client/paginators.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def update_state(self, response: Response) -> None:
total = values[0] if values else None

if total is None:
raise ValueError(
f"Total count not found in response for {self.__class__.__name__}"
)
raise ValueError(f"Total count not found in response for {self.__class__.__name__}")

try:
total = int(total)
Expand Down
16 changes: 4 additions & 12 deletions tests/sources/helpers/rest_client/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def __init__(self, base_url: str):
self.routes: List[Route] = []
self.base_url = base_url

def _add_route(
self, method: str, pattern: str, func: RequestCallback
) -> RequestCallback:
def _add_route(self, method: str, pattern: str, func: RequestCallback) -> RequestCallback:
compiled_pattern = re.compile(f"{self.base_url}{pattern}")

def serialize_response(request, context):
Expand Down Expand Up @@ -116,9 +114,7 @@ def get_page_number(qs, key="page", default=1):
return int(qs.get(key, [default])[0])


def paginate_response(
request, records, page_size=10, records_key="data", use_absolute_url=True
):
def paginate_response(request, records, page_size=10, records_key="data", use_absolute_url=True):
page_number = get_page_number(request.qs)
total_records = len(records)
total_pages = (total_records + page_size - 1) // page_size
Expand Down Expand Up @@ -173,9 +169,7 @@ def post_detail_404(request, context):

@router.get(r"/posts_under_a_different_key$")
def posts_with_results_key(request, context):
return paginate_response(
request, generate_posts(), records_key="many-results"
)
return paginate_response(request, generate_posts(), records_key="many-results")

@router.get("/protected/posts/basic-auth")
def protected_basic_auth(request, context):
Expand Down Expand Up @@ -231,6 +225,4 @@ def refresh_token(request, context):
def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10):
assert len(pages) == total_pages
for i, page in enumerate(pages):
assert page == [
{"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10)
]
assert page == [{"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10)]
4 changes: 1 addition & 3 deletions tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,7 @@ def test_bearer_token_auth_success(self, rest_client: RESTClient):
def test_api_key_auth_success(self, rest_client: RESTClient):
response = rest_client.get(
"/protected/posts/api-key",
auth=APIKeyAuth(
name="x-api-key", api_key=cast(TSecretStrValue, "test-api-key")
),
auth=APIKeyAuth(name="x-api-key", api_key=cast(TSecretStrValue, "test-api-key")),
)
assert response.status_code == 200
assert response.json()["data"][0] == {"id": 0, "title": "Post 0"}
Expand Down
8 changes: 2 additions & 6 deletions tests/sources/helpers/rest_client/test_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@
},
{
"response": {
"_embedded": {
"items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}]
},
"_embedded": {"items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}]},
"_links": {
"first": {"href": "http://api.example.com/items?page=0&size=2"},
"self": {"href": "http://api.example.com/items?page=1&size=2"},
Expand Down Expand Up @@ -315,9 +313,7 @@ def test_find_records(test_case):
@pytest.mark.parametrize("test_case", TEST_RESPONSES)
def test_find_next_page_key(test_case):
response = test_case["response"]
expected = test_case.get("expected").get(
"next_path", None
) # Some cases may not have next_path
expected = test_case.get("expected").get("next_path", None) # Some cases may not have next_path
assert find_next_page_path(response) == expected


Expand Down
4 changes: 1 addition & 3 deletions tests/sources/helpers/rest_client/test_paginators.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,7 @@ def test_update_state(self):

def test_update_state_with_next(self):
paginator = SinglePagePaginator()
response = Mock(
Response, json=lambda: {"next": "http://example.com/next", "results": []}
)
response = Mock(Response, json=lambda: {"next": "http://example.com/next", "results": []})
response.links = {"next": {"url": "http://example.com/next"}}
paginator.update_state(response)
assert paginator.has_next_page is False
Expand Down
Loading