diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 5d7a2f7eb2..99421e2c60 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -30,9 +30,7 @@ else: PrivateKeyTypes = Any -TApiKeyLocation = Literal[ - "header", "cookie", "query", "param" -] # Alias for scheme "in" field +TApiKeyLocation = Literal["header", "cookie", "query", "param"] # Alias for scheme "in" field class AuthConfigBase(AuthBase, CredentialsConfiguration): @@ -102,7 +100,8 @@ def parse_native_representation(self, value: Any) -> None: raise NativeValueError( type(self), value, - f"HttpBasicAuth username and password must be a tuple of two strings, got {type(value)}", + "HttpBasicAuth username and password must be a tuple of two strings, got" + f" {type(value)}", ) def __call__(self, request: PreparedRequest) -> PreparedRequest: @@ -147,9 +146,7 @@ class OAuthJWTAuth(BearerTokenAuth): default_token_expiration: int = 3600 def __post_init__(self) -> None: - self.scopes = ( - self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes) - ) + self.scopes = self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes) self.token = None self.token_expiry: Optional[pendulum.DateTime] = None @@ -171,9 +168,7 @@ def obtain_token(self) -> None: payload = self.create_jwt_payload() data = { "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", - "assertion": jwt.encode( - payload, self.load_private_key(), algorithm="RS256" - ), + "assertion": jwt.encode(payload, self.load_private_key(), algorithm="RS256"), } logger.debug(f"Obtaining token from {self.auth_endpoint}") @@ -208,8 +203,8 @@ def load_private_key(self) -> "PrivateKeyTypes": private_key_bytes = self.private_key.encode("utf-8") return serialization.load_pem_private_key( private_key_bytes, - password=self.private_key_passphrase.encode("utf-8") - if self.private_key_passphrase - else None, + password=( + self.private_key_passphrase.encode("utf-8") if self.private_key_passphrase else None + ), backend=default_backend(), ) diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index 4b5625eebe..027afc7cbb 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -135,9 +135,7 @@ def _send_request(self, request: Request) -> Response: return self.session.send(prepared_request) - def request( - self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any - ) -> Response: + def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) -> Response: prepared_request = self._create_request( path=path, method=method, @@ -145,14 +143,10 @@ def request( ) return self._send_request(prepared_request) - def get( - self, path: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> Response: + def get(self, path: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Response: return self.request(path, method="GET", params=params, **kwargs) - def post( - self, path: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> Response: + def post(self, path: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Response: return self.request(path, method="POST", json=json, **kwargs) def paginate( @@ -224,16 +218,12 @@ def raise_for_status(response: Response, *args: Any, **kwargs: Any) -> None: paginator.update_request(request) # yield data with context - yield PageData( - data, request=request, response=response, paginator=paginator, auth=auth - ) + yield PageData(data, request=request, response=response, paginator=paginator, auth=auth) if not paginator.has_next_page: break - def extract_response( - self, response: Response, data_selector: jsonpath.TJsonPath - ) -> List[Any]: + def extract_response(self, response: Response, data_selector: jsonpath.TJsonPath) -> List[Any]: if data_selector: # we should compile data_selector data: Any = jsonpath.find_values(data_selector, response.json()) @@ -257,8 +247,6 @@ def detect_paginator(self, response: Response) -> BasePaginator: """ paginator = self.pagination_factory.create_paginator(response) if paginator is None: - raise ValueError( - f"No suitable paginator found for the response at {response.url}" - ) + raise ValueError(f"No suitable paginator found for the response at {response.url}") logger.info(f"Detected paginator: {paginator.__class__.__name__}") return paginator diff --git a/dlt/sources/helpers/rest_client/detector.py b/dlt/sources/helpers/rest_client/detector.py index f3af31bb4d..547162358c 100644 --- a/dlt/sources/helpers/rest_client/detector.py +++ b/dlt/sources/helpers/rest_client/detector.py @@ -80,8 +80,7 @@ def find_records( return next( list_info[2] for list_info in lists - if list_info[1] in RECORD_KEY_PATTERNS - and list_info[1] not in NON_RECORD_KEY_PATTERNS + if list_info[1] in RECORD_KEY_PATTERNS and list_info[1] not in NON_RECORD_KEY_PATTERNS ) except StopIteration: # return the least nested element @@ -142,9 +141,7 @@ def single_page_detector(response: Response) -> Optional[SinglePagePaginator]: class PaginatorFactory: - def __init__( - self, detectors: List[Callable[[Response], Optional[BasePaginator]]] = None - ): + def __init__(self, detectors: List[Callable[[Response], Optional[BasePaginator]]] = None): if detectors is None: detectors = [ header_links_detector, diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py index ce414322a0..11a28c22ea 100644 --- a/dlt/sources/helpers/rest_client/paginators.py +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -84,9 +84,7 @@ def update_state(self, response: Response) -> None: total = values[0] if values else None if total is None: - raise ValueError( - f"Total count not found in response for {self.__class__.__name__}" - ) + raise ValueError(f"Total count not found in response for {self.__class__.__name__}") try: total = int(total) diff --git a/tests/sources/helpers/rest_client/conftest.py b/tests/sources/helpers/rest_client/conftest.py index cffce7cb07..ef63c4526d 100644 --- a/tests/sources/helpers/rest_client/conftest.py +++ b/tests/sources/helpers/rest_client/conftest.py @@ -32,9 +32,7 @@ def __init__(self, base_url: str): self.routes: List[Route] = [] self.base_url = base_url - def _add_route( - self, method: str, pattern: str, func: RequestCallback - ) -> RequestCallback: + def _add_route(self, method: str, pattern: str, func: RequestCallback) -> RequestCallback: compiled_pattern = re.compile(f"{self.base_url}{pattern}") def serialize_response(request, context): @@ -116,9 +114,7 @@ def get_page_number(qs, key="page", default=1): return int(qs.get(key, [default])[0]) -def paginate_response( - request, records, page_size=10, records_key="data", use_absolute_url=True -): +def paginate_response(request, records, page_size=10, records_key="data", use_absolute_url=True): page_number = get_page_number(request.qs) total_records = len(records) total_pages = (total_records + page_size - 1) // page_size @@ -173,9 +169,7 @@ def post_detail_404(request, context): @router.get(r"/posts_under_a_different_key$") def posts_with_results_key(request, context): - return paginate_response( - request, generate_posts(), records_key="many-results" - ) + return paginate_response(request, generate_posts(), records_key="many-results") @router.get("/protected/posts/basic-auth") def protected_basic_auth(request, context): @@ -231,6 +225,4 @@ def refresh_token(request, context): def assert_pagination(pages, expected_start=0, page_size=10, total_pages=10): assert len(pages) == total_pages for i, page in enumerate(pages): - assert page == [ - {"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10) - ] + assert page == [{"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10)] diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 88653efefe..4311026e2e 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -154,9 +154,7 @@ def test_bearer_token_auth_success(self, rest_client: RESTClient): def test_api_key_auth_success(self, rest_client: RESTClient): response = rest_client.get( "/protected/posts/api-key", - auth=APIKeyAuth( - name="x-api-key", api_key=cast(TSecretStrValue, "test-api-key") - ), + auth=APIKeyAuth(name="x-api-key", api_key=cast(TSecretStrValue, "test-api-key")), ) assert response.status_code == 200 assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} diff --git a/tests/sources/helpers/rest_client/test_detector.py b/tests/sources/helpers/rest_client/test_detector.py index a9af1d36a4..933c9be9cc 100644 --- a/tests/sources/helpers/rest_client/test_detector.py +++ b/tests/sources/helpers/rest_client/test_detector.py @@ -101,9 +101,7 @@ }, { "response": { - "_embedded": { - "items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}] - }, + "_embedded": {"items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}]}, "_links": { "first": {"href": "http://api.example.com/items?page=0&size=2"}, "self": {"href": "http://api.example.com/items?page=1&size=2"}, @@ -315,9 +313,7 @@ def test_find_records(test_case): @pytest.mark.parametrize("test_case", TEST_RESPONSES) def test_find_next_page_key(test_case): response = test_case["response"] - expected = test_case.get("expected").get( - "next_path", None - ) # Some cases may not have next_path + expected = test_case.get("expected").get("next_path", None) # Some cases may not have next_path assert find_next_page_path(response) == expected diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py index 4d086f1486..03afb17ca6 100644 --- a/tests/sources/helpers/rest_client/test_paginators.py +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -166,9 +166,7 @@ def test_update_state(self): def test_update_state_with_next(self): paginator = SinglePagePaginator() - response = Mock( - Response, json=lambda: {"next": "http://example.com/next", "results": []} - ) + response = Mock(Response, json=lambda: {"next": "http://example.com/next", "results": []}) response.links = {"next": {"url": "http://example.com/next"}} paginator.update_state(response) assert paginator.has_next_page is False