From 659c054b0c032c0df9cedb4feb85c4bc7dccc51a Mon Sep 17 00:00:00 2001 From: Arne De Peuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Thu, 21 Nov 2024 14:06:13 +0100 Subject: [PATCH 01/13] FEAT: parameterised headers rest_api_source --- dlt/sources/helpers/rest_client/client.py | 12 +- dlt/sources/rest_api/__init__.py | 7 +- dlt/sources/rest_api/config_setup.py | 55 +++++++-- dlt/sources/rest_api/typing.py | 1 + .../configurations/test_resolve_config.py | 18 ++- .../sources/rest_api/test_rest_api_source.py | 108 ++++++++++++++++++ 6 files changed, 188 insertions(+), 13 deletions(-) diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index 6d04373d8d..c1481a145f 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -99,6 +99,7 @@ def _create_request( path: str, method: HTTPMethod, params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None, auth: Optional[AuthBase] = None, hooks: Optional[Hooks] = None, @@ -109,10 +110,14 @@ def _create_request( else: url = join_url(self.base_url, path) + headers = headers or {} + if self.headers: + headers.update(self.headers) + return Request( method=method, url=url, - headers=self.headers, + headers=headers, params=params, json=json, auth=auth or self.auth, @@ -142,6 +147,7 @@ def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) -> prepared_request = self._create_request( path=path, method=method, + headers=kwargs.pop("headers", None), params=kwargs.pop("params", None), json=kwargs.pop("json", None), auth=kwargs.pop("auth", None), @@ -160,6 +166,7 @@ def paginate( path: str = "", method: HTTPMethodBasic = "GET", params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None, auth: Optional[AuthBase] = None, paginator: Optional[BasePaginator] = None, @@ -173,6 +180,7 @@ def paginate( path (str): Endpoint path for the request, relative to `base_url`. method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'. params (Optional[Dict[str, Any]]): URL parameters for the request. + headers (Optional[Dict[str, Any]]): Headers for the request. json (Optional[Dict[str, Any]]): JSON payload for the request. auth (Optional[AuthBase): Authentication configuration for the request. paginator (Optional[BasePaginator]): Paginator instance for handling @@ -210,7 +218,7 @@ def paginate( hooks["response"] = [raise_for_status] request = self._create_request( - path=path, method=method, params=params, json=json, auth=auth, hooks=hooks + path=path, headers=headers, method=method, params=params, json=json, auth=auth, hooks=hooks ) if paginator: diff --git a/dlt/sources/rest_api/__init__.py b/dlt/sources/rest_api/__init__.py index ed55f71e10..4b7e735626 100644 --- a/dlt/sources/rest_api/__init__.py +++ b/dlt/sources/rest_api/__init__.py @@ -346,6 +346,7 @@ def paginate_dependent_resource( items: List[Dict[str, Any]], method: HTTPMethodBasic, path: str, + headers: Dict[str, Any], params: Dict[str, Any], paginator: Optional[BasePaginator], data_selector: Optional[jsonpath.TJsonPath], @@ -368,12 +369,13 @@ def paginate_dependent_resource( ) for item in items: - formatted_path, parent_record = process_parent_data_item( - path, item, resolved_params, include_from_parent + formatted_path, formatted_headers, parent_record = process_parent_data_item( + path, headers, item, resolved_params, include_from_parent ) for child_page in client.paginate( method=method, + headers=formatted_headers, path=formatted_path, params=params, paginator=paginator, @@ -392,6 +394,7 @@ def paginate_dependent_resource( )( method=endpoint_config.get("method", "get"), path=endpoint_config.get("path"), + headers=endpoint_config.get("headers"), params=base_params, paginator=paginator, data_selector=endpoint_config.get("data_selector"), diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index d03a4fd59b..d7057aff45 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -330,6 +330,7 @@ def expand_and_index_resources( assert isinstance(endpoint_resource["endpoint"], dict) _setup_single_entity_endpoint(endpoint_resource["endpoint"]) _bind_path_params(endpoint_resource) + _bind_header_params(endpoint_resource) resource_name = endpoint_resource["name"] assert isinstance( @@ -410,15 +411,46 @@ def _bind_path_params(resource: EndpointResource) -> None: # resolved params are bound later path_params[name] = "{" + name + "}" - if len(resolve_params) > 0: - raise NotImplementedError( - f"Resource {resource['name']} defines resolve params {resolve_params} that are not" - f" bound in path {path}. Resolve query params not supported yet." - ) - resource["endpoint"]["path"] = path.format(**path_params) +def _bind_header_params(resource: EndpointResource) -> None: + """Binds params declared in headers to params available in `params`. Pops the + bound params but skips params of type `resolve` and `incremental`, which are bound later. + """ + header_params: Dict[str, Any] = {} + assert isinstance(resource["endpoint"], dict) # type guard + headers = resource["endpoint"].get("headers", {}) + params = resource["endpoint"].get("params", {}) + resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])] + + for header_key, header_value in headers.items(): + if isinstance(header_value, str) and header_value.startswith("{") and header_value.endswith("}"): + param_name = header_value.strip("{}") + if param_name not in params: + raise ValueError( + f"The header '{header_key}' in resource '{resource['name']}' requires a param with " + f"name '{param_name}' but it is not found in {params}." + ) + if param_name in resolve_params: + resolve_params.remove(param_name) + if param_name in params: + if not isinstance(params[param_name], dict): + # Bind the header param and pop it from params + header_params[header_key] = params.pop(param_name) + else: + param_type = params[param_name].get("type") + if param_type != "resolve": + raise ValueError( + f"The header '{header_key}' in resource '{resource['name']}' tries to bind param " + f"'{param_name}' with type '{param_type}'. Headers can only bind 'resolve' type params." + ) + # Resolved params are bound later + header_params[header_key] = "{" + param_name + "}" + + resource["endpoint"]["headers"] = {**headers, **header_params} + + def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint: """Tries to guess if the endpoint refers to a single entity and when detected: * if `data_selector` was not specified (or is None), "$" is selected @@ -569,12 +601,14 @@ def remove_field(response: Response, *args, **kwargs) -> Response: return None + def process_parent_data_item( path: str, + headers: Dict[str, Any], item: Dict[str, Any], resolved_params: List[ResolvedParam], include_from_parent: List[str], -) -> Tuple[str, Dict[str, Any]]: +) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: parent_resource_name = resolved_params[0].resolve_config["resource"] param_values = {} @@ -594,6 +628,11 @@ def process_parent_data_item( param_values[resolved_param.param_name] = field_values[0] bound_path = path.format(**param_values) + formatted_headers = {} + for k, v in headers.items(): + key = k if not isinstance(k, str) else k.format(**param_values) + val = v if not isinstance(v, str) else v.format(**param_values) + formatted_headers[key] = val parent_record: Dict[str, Any] = {} if include_from_parent: @@ -607,7 +646,7 @@ def process_parent_data_item( ) parent_record[child_key] = item[parent_key] - return bound_path, parent_record + return bound_path, formatted_headers, parent_record def _merge_resource_endpoints( diff --git a/dlt/sources/rest_api/typing.py b/dlt/sources/rest_api/typing.py index d4cea892a3..455d012b10 100644 --- a/dlt/sources/rest_api/typing.py +++ b/dlt/sources/rest_api/typing.py @@ -259,6 +259,7 @@ class Endpoint(TypedDict, total=False): method: Optional[HTTPMethodBasic] params: Optional[Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]] json: Optional[Dict[str, Any]] + headers: Optional[Dict[str, Any]] paginator: Optional[PaginatorConfig] data_selector: Optional[jsonpath.TJsonPath] response_actions: Optional[List[ResponseAction]] diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py index d3d9308df1..ef9fbe6af6 100644 --- a/tests/sources/rest_api/configurations/test_resolve_config.py +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -1,15 +1,16 @@ import re from copy import deepcopy - import pytest from graphlib import CycleError # type: ignore +import dtl from dlt.sources.rest_api import ( rest_api_resources, rest_api_source, ) from dlt.sources.rest_api.config_setup import ( _bind_path_params, + _bind_header_params, process_parent_data_item, ) from dlt.sources.rest_api.typing import ( @@ -351,3 +352,18 @@ def test_circular_resource_bindingis_invalid() -> None: with pytest.raises(CycleError) as e: rest_api_resources(config) assert e.match(re.escape("'nodes are in a cycle', ['chicken', 'egg', 'chicken']")) + + +def test_bind_header_params() -> None: + resource_with_headers: EndpointResource = { + "name": "test_resource", + "endpoint": { + "path": "test/path", + "headers": {"Authorization": "{token}"}, + "params": { + "token": "test_token", + }, + }, + } + _bind_header_params(resource_with_headers) + assert resource_with_headers["endpoint"]["headers"]["Authorization"] == "test_token" diff --git a/tests/sources/rest_api/test_rest_api_source.py b/tests/sources/rest_api/test_rest_api_source.py index 904bcaf159..4e38413cc4 100644 --- a/tests/sources/rest_api/test_rest_api_source.py +++ b/tests/sources/rest_api/test_rest_api_source.py @@ -1,5 +1,7 @@ import dlt import pytest +from unittest.mock import patch, MagicMock +from requests import Response, Request from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContainer @@ -149,3 +151,109 @@ def test_dependent_resource(destination_name: str, invocation_type: str) -> None } assert table_counts["pokemon"] == 2 + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +@pytest.mark.parametrize("invocation_type", ("deco", "factory")) +@patch("dlt.sources.helpers.rest_client.client.RESTClient._send_request") +def test_request_headers(mock: MagicMock, destination_name: str, invocation_type: str) -> None: + mock_resp = Response() + mock_resp.status_code = 200 + mock_resp.json = lambda: {"success": "ok"} + mock.return_value = mock_resp + + @dlt.resource + def authenticate(): + yield [{"token": 1}] + + base_url = "https://api.example.com" + config: RESTAPIConfig = { + "client": { + "base_url": base_url, + "headers": {"foo": "bar"} + }, + "resources": [ + { + "name": "chicken", + "endpoint": { + "path": "chicken", + "headers": {"token": "{token}", "num": "2"}, + "params": { + "token": { + "type": "resolve", + "field": "token", + "resource": "authenticate", + }, + }, + }, + }, + authenticate(), + ], + } + + if invocation_type == "deco": + data = rest_api(**config) + else: + data = rest_api_source(config) + pipeline = _make_pipeline(destination_name) + pipeline.run(data) + + mock.assert_called() + args, kwargs = mock.call_args + request_param: Request = args[0] + + assert request_param.url == f"{base_url}/chicken" + assert request_param.headers == {"foo": "bar", "token": "1", "num": "2"} + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +@pytest.mark.parametrize("invocation_type", ("deco", "factory")) +@patch("dlt.sources.helpers.rest_client.client.RESTClient._send_request") +def test_request_headers_dynamic_key(mock: MagicMock, destination_name: str, invocation_type: str) -> None: + mock_resp = Response() + mock_resp.status_code = 200 + mock_resp.json = lambda: {"success": "ok"} + mock.return_value = mock_resp + + @dlt.resource + def authenticate(): + yield [{"token": 1}] + + base_url = "https://api.example.com" + config: RESTAPIConfig = { + "client": { + "base_url": base_url, + "headers": {"foo": "bar"} + }, + "resources": [ + { + "name": "chicken", + "endpoint": { + "path": "chicken", + "headers": {"{token}": "{token}", "num": "2"}, + "params": { + "token": { + "type": "resolve", + "field": "token", + "resource": "authenticate", + }, + }, + }, + }, + authenticate(), + ], + } + + if invocation_type == "deco": + data = rest_api(**config) + else: + data = rest_api_source(config) + pipeline = _make_pipeline(destination_name) + pipeline.run(data) + + mock.assert_called() + args, kwargs = mock.call_args + request_param: Request = args[0] + + assert request_param.url == f"{base_url}/chicken" + assert request_param.headers == {"foo": "bar", "1": "1", "num": "2"} From 8ed6812490a22ecd1b8763c51adf93cf8d1544a4 Mon Sep 17 00:00:00 2001 From: Arne De Peuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:16:54 +0100 Subject: [PATCH 02/13] FEAT: parameterised headers rest_api_source - nested header support --- dlt/sources/rest_api/config_setup.py | 14 +++-- .../sources/rest_api/test_rest_api_source.py | 53 +++++++++++++++++++ 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index d7057aff45..e4a68f1cfa 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -600,6 +600,14 @@ def remove_field(response: Response, *args, **kwargs) -> Response: return {"response": hooks + fallback_hooks} return None +def generic_format(input: Union[dict, str, list], param_values) -> Union[dict, str, list]: + if isinstance(input, dict): + return {generic_format(key, param_values): generic_format(val, param_values) for key, val in input.items()} + if isinstance(input, list): + return [generic_format(item, param_values) for item in input] + if isinstance(input, str): + return input.format(**param_values) + raise NotImplementedError(f"Param resolution formatting not supported for type: {type(input)}") def process_parent_data_item( @@ -628,11 +636,7 @@ def process_parent_data_item( param_values[resolved_param.param_name] = field_values[0] bound_path = path.format(**param_values) - formatted_headers = {} - for k, v in headers.items(): - key = k if not isinstance(k, str) else k.format(**param_values) - val = v if not isinstance(v, str) else v.format(**param_values) - formatted_headers[key] = val + formatted_headers = generic_format(headers, param_values) parent_record: Dict[str, Any] = {} if include_from_parent: diff --git a/tests/sources/rest_api/test_rest_api_source.py b/tests/sources/rest_api/test_rest_api_source.py index 4e38413cc4..c4721f303c 100644 --- a/tests/sources/rest_api/test_rest_api_source.py +++ b/tests/sources/rest_api/test_rest_api_source.py @@ -257,3 +257,56 @@ def authenticate(): assert request_param.url == f"{base_url}/chicken" assert request_param.headers == {"foo": "bar", "1": "1", "num": "2"} + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +@pytest.mark.parametrize("invocation_type", ("deco", "factory")) +@patch("dlt.sources.helpers.rest_client.client.RESTClient._send_request") +def test_request_headers_nested(mock: MagicMock, destination_name: str, invocation_type: str) -> None: + mock_resp = Response() + mock_resp.status_code = 200 + mock_resp.json = lambda: {"success": "ok"} + mock.return_value = mock_resp + + @dlt.resource + def authenticate(): + yield [{"token": 1}] + + base_url = "https://api.example.com" + config: RESTAPIConfig = { + "client": { + "base_url": base_url, + "headers": {"foo": "bar"} + }, + "resources": [ + { + "name": "chicken", + "endpoint": { + "path": "chicken", + "headers": {"{token}": "{token}", "num": "2", "nested": {"nested": "{token}", "{token}": "other"}}, + "params": { + "token": { + "type": "resolve", + "field": "token", + "resource": "authenticate", + }, + }, + }, + }, + authenticate(), + ], + } + + if invocation_type == "deco": + data = rest_api(**config) + else: + data = rest_api_source(config) + pipeline = _make_pipeline(destination_name) + pipeline.run(data) + + mock.assert_called() + args, kwargs = mock.call_args + request_param: Request = args[0] + + assert request_param.url == f"{base_url}/chicken" + assert request_param.headers == {"foo": "bar", "1": "1", "num": "2", "nested": {"nested": "1", "1": "other"}} From 64b3f79422d1797f9cc23e9ab3956fd8281a9952 Mon Sep 17 00:00:00 2001 From: Arne De Peuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:28:06 +0100 Subject: [PATCH 03/13] FEAT: parameterised headers rest_api_source - nested header support --- dlt/sources/rest_api/config_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index e4a68f1cfa..988382910b 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -607,7 +607,7 @@ def generic_format(input: Union[dict, str, list], param_values) -> Union[dict, s return [generic_format(item, param_values) for item in input] if isinstance(input, str): return input.format(**param_values) - raise NotImplementedError(f"Param resolution formatting not supported for type: {type(input)}") + return str(input) def process_parent_data_item( From b033f5cd1ab3ba10e9320d0a88512e929af7787b Mon Sep 17 00:00:00 2001 From: ArneDePeuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Thu, 21 Nov 2024 20:03:15 +0100 Subject: [PATCH 04/13] FIX: tests + linting fixed --- dlt/sources/helpers/rest_client/client.py | 8 +++- dlt/sources/rest_api/__init__.py | 2 +- dlt/sources/rest_api/config_setup.py | 45 ++++++++++++------- .../deploy-with-modal-snippets.py | 6 +-- .../configurations/test_resolve_config.py | 22 +++++---- .../sources/rest_api/test_rest_api_source.py | 42 +++++++++-------- 6 files changed, 71 insertions(+), 54 deletions(-) diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index c1481a145f..44599d8d63 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -218,7 +218,13 @@ def paginate( hooks["response"] = [raise_for_status] request = self._create_request( - path=path, headers=headers, method=method, params=params, json=json, auth=auth, hooks=hooks + path=path, + headers=headers, + method=method, + params=params, + json=json, + auth=auth, + hooks=hooks, ) if paginator: diff --git a/dlt/sources/rest_api/__init__.py b/dlt/sources/rest_api/__init__.py index 4b7e735626..53803dcabe 100644 --- a/dlt/sources/rest_api/__init__.py +++ b/dlt/sources/rest_api/__init__.py @@ -370,7 +370,7 @@ def paginate_dependent_resource( for item in items: formatted_path, formatted_headers, parent_record = process_parent_data_item( - path, headers, item, resolved_params, include_from_parent + path, item, resolved_params, include_from_parent, headers ) for child_page in client.paginate( diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index 988382910b..d3ce9e01b9 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -425,12 +425,16 @@ def _bind_header_params(resource: EndpointResource) -> None: resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])] for header_key, header_value in headers.items(): - if isinstance(header_value, str) and header_value.startswith("{") and header_value.endswith("}"): + if ( + isinstance(header_value, str) + and header_value.startswith("{") + and header_value.endswith("}") + ): param_name = header_value.strip("{}") if param_name not in params: raise ValueError( - f"The header '{header_key}' in resource '{resource['name']}' requires a param with " - f"name '{param_name}' but it is not found in {params}." + f"The header '{header_key}' in resource '{resource['name']}' requires a param" + f" with name '{param_name}' but it is not found in {params}." ) if param_name in resolve_params: resolve_params.remove(param_name) @@ -442,8 +446,9 @@ def _bind_header_params(resource: EndpointResource) -> None: param_type = params[param_name].get("type") if param_type != "resolve": raise ValueError( - f"The header '{header_key}' in resource '{resource['name']}' tries to bind param " - f"'{param_name}' with type '{param_type}'. Headers can only bind 'resolve' type params." + f"The header '{header_key}' in resource '{resource['name']}' tries to" + f" bind param '{param_name}' with type '{param_type}'. Headers can only" + " bind 'resolve' type params." ) # Resolved params are bound later header_params[header_key] = "{" + param_name + "}" @@ -600,22 +605,28 @@ def remove_field(response: Response, *args, **kwargs) -> Response: return {"response": hooks + fallback_hooks} return None -def generic_format(input: Union[dict, str, list], param_values) -> Union[dict, str, list]: - if isinstance(input, dict): - return {generic_format(key, param_values): generic_format(val, param_values) for key, val in input.items()} - if isinstance(input, list): - return [generic_format(item, param_values) for item in input] - if isinstance(input, str): - return input.format(**param_values) - return str(input) + +def generic_format( + to_format: Union[Dict[str, Any], str, List[Any]], param_values: Dict[str, Any] +) -> Union[Dict[str, Any], str, List[Any]]: + if isinstance(to_format, dict): + return { + key.format(**param_values): generic_format(val, param_values) + for key, val in to_format.items() + } + if isinstance(to_format, list): + return [generic_format(item, param_values) for item in to_format] + if isinstance(to_format, str): + return to_format.format(**param_values) + return str(to_format) def process_parent_data_item( path: str, - headers: Dict[str, Any], item: Dict[str, Any], resolved_params: List[ResolvedParam], include_from_parent: List[str], + headers: Optional[Dict[str, Any]] = None, ) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: parent_resource_name = resolved_params[0].resolve_config["resource"] @@ -636,7 +647,6 @@ def process_parent_data_item( param_values[resolved_param.param_name] = field_values[0] bound_path = path.format(**param_values) - formatted_headers = generic_format(headers, param_values) parent_record: Dict[str, Any] = {} if include_from_parent: @@ -650,7 +660,10 @@ def process_parent_data_item( ) parent_record[child_key] = item[parent_key] - return bound_path, formatted_headers, parent_record + if headers is not None: + formatted_headers = generic_format(headers, param_values) + return bound_path, formatted_headers, parent_record # type: ignore[return-value] + return bound_path, {}, parent_record def _merge_resource_endpoints( diff --git a/docs/website/docs/walkthroughs/deploy-a-pipeline/deploy_snippets/deploy-with-modal-snippets.py b/docs/website/docs/walkthroughs/deploy-a-pipeline/deploy_snippets/deploy-with-modal-snippets.py index 5c50f06a04..8a488159c1 100644 --- a/docs/website/docs/walkthroughs/deploy-a-pipeline/deploy_snippets/deploy-with-modal-snippets.py +++ b/docs/website/docs/walkthroughs/deploy-a-pipeline/deploy_snippets/deploy-with-modal-snippets.py @@ -21,11 +21,7 @@ def test_modal_snippet() -> None: # @@@DLT_SNIPPET_END modal_image # @@@DLT_SNIPPET_START modal_function - @app.function( - volumes={"/data/": vol}, - schedule=modal.Period(days=1), - serialized=True - ) + @app.function(volumes={"/data/": vol}, schedule=modal.Period(days=1), serialized=True) def load_tables() -> None: import dlt import os diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py index ef9fbe6af6..e9003874e5 100644 --- a/tests/sources/rest_api/configurations/test_resolve_config.py +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -3,7 +3,6 @@ import pytest from graphlib import CycleError # type: ignore -import dtl from dlt.sources.rest_api import ( rest_api_resources, rest_api_source, @@ -84,26 +83,25 @@ def test_bind_path_param() -> None: # resolved param will remain unbounded and tp_6 = deepcopy(three_params) tp_6["endpoint"]["path"] = "{org}/{repo}/issues/1234/comments" # type: ignore[index] - with pytest.raises(NotImplementedError): - _bind_path_params(tp_6) + _bind_path_params(tp_6) # Does not raise because headers are now supported... and so are query params because they reside in the URL def test_process_parent_data_item() -> None: resolve_params = [ ResolvedParam("id", {"field": "obj_id", "resource": "issues", "type": "resolve"}) ] - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_params, None ) assert bound_path == "dlt-hub/dlt/issues/12345/comments" assert parent_record == {} - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_params, ["obj_id"] ) assert parent_record == {"_issues_obj_id": 12345} - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345, "obj_node": "node_1"}, resolve_params, @@ -118,21 +116,21 @@ def test_process_parent_data_item() -> None: ) ] item = {"some_results": {"obj_id": 12345}} - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", item, resolve_param_nested, None ) assert bound_path == "dlt-hub/dlt/issues/12345/comments" # param path not found with pytest.raises(ValueError) as val_ex: - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"_id": 12345}, resolve_params, None ) assert "Transformer expects a field 'obj_id'" in str(val_ex.value) # included path not found with pytest.raises(ValueError) as val_ex: - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345, "obj_node": "node_1"}, resolve_params, @@ -146,7 +144,7 @@ def test_process_parent_data_item() -> None: ResolvedParam("id", {"field": "id", "resource": "comments", "type": "resolve"}), ] - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{issue_id}/comments/{id}", {"issue": 12345, "id": 56789}, multi_resolve_params, @@ -157,7 +155,7 @@ def test_process_parent_data_item() -> None: # param path not found with multiple parameters with pytest.raises(ValueError) as val_ex: - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{issue_id}/comments/{id}", {"_issue": 12345, "id": 56789}, multi_resolve_params, @@ -366,4 +364,4 @@ def test_bind_header_params() -> None: }, } _bind_header_params(resource_with_headers) - assert resource_with_headers["endpoint"]["headers"]["Authorization"] == "test_token" + assert resource_with_headers["endpoint"]["headers"]["Authorization"] == "test_token" # type: ignore[index] diff --git a/tests/sources/rest_api/test_rest_api_source.py b/tests/sources/rest_api/test_rest_api_source.py index c4721f303c..c0bbf75dd9 100644 --- a/tests/sources/rest_api/test_rest_api_source.py +++ b/tests/sources/rest_api/test_rest_api_source.py @@ -159,7 +159,7 @@ def test_dependent_resource(destination_name: str, invocation_type: str) -> None def test_request_headers(mock: MagicMock, destination_name: str, invocation_type: str) -> None: mock_resp = Response() mock_resp.status_code = 200 - mock_resp.json = lambda: {"success": "ok"} + mock_resp.json = lambda: {"success": "ok"} # type: ignore mock.return_value = mock_resp @dlt.resource @@ -168,10 +168,7 @@ def authenticate(): base_url = "https://api.example.com" config: RESTAPIConfig = { - "client": { - "base_url": base_url, - "headers": {"foo": "bar"} - }, + "client": {"base_url": base_url, "headers": {"foo": "bar"}}, "resources": [ { "name": "chicken", @@ -209,10 +206,12 @@ def authenticate(): @pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) @pytest.mark.parametrize("invocation_type", ("deco", "factory")) @patch("dlt.sources.helpers.rest_client.client.RESTClient._send_request") -def test_request_headers_dynamic_key(mock: MagicMock, destination_name: str, invocation_type: str) -> None: +def test_request_headers_dynamic_key( + mock: MagicMock, destination_name: str, invocation_type: str +) -> None: mock_resp = Response() mock_resp.status_code = 200 - mock_resp.json = lambda: {"success": "ok"} + mock_resp.json = lambda: {"success": "ok"} # type: ignore mock.return_value = mock_resp @dlt.resource @@ -221,10 +220,7 @@ def authenticate(): base_url = "https://api.example.com" config: RESTAPIConfig = { - "client": { - "base_url": base_url, - "headers": {"foo": "bar"} - }, + "client": {"base_url": base_url, "headers": {"foo": "bar"}}, "resources": [ { "name": "chicken", @@ -262,10 +258,12 @@ def authenticate(): @pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) @pytest.mark.parametrize("invocation_type", ("deco", "factory")) @patch("dlt.sources.helpers.rest_client.client.RESTClient._send_request") -def test_request_headers_nested(mock: MagicMock, destination_name: str, invocation_type: str) -> None: +def test_request_headers_nested( + mock: MagicMock, destination_name: str, invocation_type: str +) -> None: mock_resp = Response() mock_resp.status_code = 200 - mock_resp.json = lambda: {"success": "ok"} + mock_resp.json = lambda: {"success": "ok"} # type: ignore mock.return_value = mock_resp @dlt.resource @@ -274,16 +272,17 @@ def authenticate(): base_url = "https://api.example.com" config: RESTAPIConfig = { - "client": { - "base_url": base_url, - "headers": {"foo": "bar"} - }, + "client": {"base_url": base_url, "headers": {"foo": "bar"}}, "resources": [ { "name": "chicken", "endpoint": { "path": "chicken", - "headers": {"{token}": "{token}", "num": "2", "nested": {"nested": "{token}", "{token}": "other"}}, + "headers": { + "{token}": "{token}", + "num": "2", + "nested": {"nested": "{token}", "{token}": "other"}, + }, "params": { "token": { "type": "resolve", @@ -309,4 +308,9 @@ def authenticate(): request_param: Request = args[0] assert request_param.url == f"{base_url}/chicken" - assert request_param.headers == {"foo": "bar", "1": "1", "num": "2", "nested": {"nested": "1", "1": "other"}} + assert request_param.headers == { + "foo": "bar", + "1": "1", + "num": "2", + "nested": {"nested": "1", "1": "other"}, + } From cf6740586ebd01a96a26ecaab02aad6c91ad4e49 Mon Sep 17 00:00:00 2001 From: ArneDePeuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Thu, 21 Nov 2024 20:08:17 +0100 Subject: [PATCH 05/13] FIX: tests + linting fixed --- tests/sources/rest_api/configurations/test_resolve_config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py index e9003874e5..b75a86bdf0 100644 --- a/tests/sources/rest_api/configurations/test_resolve_config.py +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -83,7 +83,9 @@ def test_bind_path_param() -> None: # resolved param will remain unbounded and tp_6 = deepcopy(three_params) tp_6["endpoint"]["path"] = "{org}/{repo}/issues/1234/comments" # type: ignore[index] - _bind_path_params(tp_6) # Does not raise because headers are now supported... and so are query params because they reside in the URL + _bind_path_params( + tp_6 + ) # Does not raise because headers are now supported... and so are query params because they reside in the URL def test_process_parent_data_item() -> None: From c82b423993ea99f1de7410084acdfbcdce7b859a Mon Sep 17 00:00:00 2001 From: ArneDePeuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Thu, 21 Nov 2024 20:58:44 +0100 Subject: [PATCH 06/13] FEAT: test_process_parent_data_item_headers --- .../configurations/test_resolve_config.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py index b75a86bdf0..493e324bd3 100644 --- a/tests/sources/rest_api/configurations/test_resolve_config.py +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -166,6 +166,85 @@ def test_process_parent_data_item() -> None: assert "Transformer expects a field 'issue'" in str(val_ex.value) +def test_process_parent_data_item_headers() -> None: + resolve_params = [ + ResolvedParam("token", {"field": "token", "resource": "authenticate", "type": "resolve"}) + ] + _, resolved_headers, parent_record = process_parent_data_item( + "chicken", + {"token": 12345}, + resolve_params, + None, + {"Authorization": "{token}"}, + ) + assert resolved_headers == {"Authorization": "12345"} + + # multiple params + resolve_params = [ + ResolvedParam("token", {"field": "token", "resource": "authenticate", "type": "resolve"}), + ResolvedParam("num", {"field": "num", "resource": "authenticate", "type": "resolve"}), + ] + _, resolved_headers, parent_record = process_parent_data_item( + "chicken", + {"token": 12345, "num": 2}, + resolve_params, + None, + {"Authorization": "{token}", "num": "{num}"}, + ) + assert resolved_headers == {"Authorization": "12345", "num": "2"} + + # nested params + resolve_params = [ + ResolvedParam("token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"}), + ResolvedParam("num", {"field": "auth.num", "resource": "authenticate", "type": "resolve"}), + ] + _, resolved_headers, parent_record = process_parent_data_item( + "chicken", + {"auth": {"token": 12345, "num": 2}}, + resolve_params, + None, + {"Authorization": "{token}", "num": "{num}"}, + ) + assert resolved_headers == {"Authorization": "12345", "num": "2"} + + # nested header dict + resolve_params = [ + ResolvedParam("token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"}) + ] + _, resolved_headers, parent_record = process_parent_data_item( + "chicken", + {"auth": {"token": 12345}}, + resolve_params, + None, + {"Authorization": {"Bearer": "{token}"}}, + ) + assert resolved_headers == {"Authorization": {"Bearer": "12345"}} + + # nested header list + resolve_params = [ + ResolvedParam("token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"}) + ] + _, resolved_headers, parent_record = process_parent_data_item( + "chicken", + {"auth": {"token": 12345}}, + resolve_params, + None, + {"Authorization": ["Bearer", "{token}"]}, + ) + assert resolved_headers == {"Authorization": ["Bearer", "12345"]} + + # param path not found + with pytest.raises(ValueError) as val_ex: + _, _, parent_record = process_parent_data_item( + "chicken", + {"_token": 12345}, + resolve_params, + None, + {"Authorization": "{token}"}, + ) + assert "Transformer expects a field 'auth.token'" in str(val_ex.value) + + def test_two_resources_can_depend_on_one_parent_resource() -> None: user_id = { "user_id": { From ee0861fe15c65cf0015d8d85dde046bcb4571cbd Mon Sep 17 00:00:00 2001 From: ArneDePeuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Thu, 21 Nov 2024 21:02:27 +0100 Subject: [PATCH 07/13] CHORE: reformat test_process_parent_data_item_headers --- .../rest_api/configurations/test_resolve_config.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py index 493e324bd3..51584ca8d1 100644 --- a/tests/sources/rest_api/configurations/test_resolve_config.py +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -195,7 +195,9 @@ def test_process_parent_data_item_headers() -> None: # nested params resolve_params = [ - ResolvedParam("token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"}), + ResolvedParam( + "token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"} + ), ResolvedParam("num", {"field": "auth.num", "resource": "authenticate", "type": "resolve"}), ] _, resolved_headers, parent_record = process_parent_data_item( @@ -209,7 +211,9 @@ def test_process_parent_data_item_headers() -> None: # nested header dict resolve_params = [ - ResolvedParam("token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"}) + ResolvedParam( + "token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"} + ) ] _, resolved_headers, parent_record = process_parent_data_item( "chicken", @@ -222,7 +226,9 @@ def test_process_parent_data_item_headers() -> None: # nested header list resolve_params = [ - ResolvedParam("token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"}) + ResolvedParam( + "token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"} + ) ] _, resolved_headers, parent_record = process_parent_data_item( "chicken", From 261c89dae63c539761b2e437701082a5c20623f0 Mon Sep 17 00:00:00 2001 From: ArneDePeuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Fri, 22 Nov 2024 02:08:58 +0100 Subject: [PATCH 08/13] FEAT: _bind_header_params supports: nested header structures, key replacement --- dlt/sources/rest_api/config_setup.py | 46 +++++-------------- .../configurations/test_resolve_config.py | 20 ++++++++ 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index d3ce9e01b9..fc36cddeed 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -418,42 +418,20 @@ def _bind_header_params(resource: EndpointResource) -> None: """Binds params declared in headers to params available in `params`. Pops the bound params but skips params of type `resolve` and `incremental`, which are bound later. """ - header_params: Dict[str, Any] = {} assert isinstance(resource["endpoint"], dict) # type guard - headers = resource["endpoint"].get("headers", {}) params = resource["endpoint"].get("params", {}) - resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])] - - for header_key, header_value in headers.items(): - if ( - isinstance(header_value, str) - and header_value.startswith("{") - and header_value.endswith("}") - ): - param_name = header_value.strip("{}") - if param_name not in params: - raise ValueError( - f"The header '{header_key}' in resource '{resource['name']}' requires a param" - f" with name '{param_name}' but it is not found in {params}." - ) - if param_name in resolve_params: - resolve_params.remove(param_name) - if param_name in params: - if not isinstance(params[param_name], dict): - # Bind the header param and pop it from params - header_params[header_key] = params.pop(param_name) - else: - param_type = params[param_name].get("type") - if param_type != "resolve": - raise ValueError( - f"The header '{header_key}' in resource '{resource['name']}' tries to" - f" bind param '{param_name}' with type '{param_type}'. Headers can only" - " bind 'resolve' type params." - ) - # Resolved params are bound later - header_params[header_key] = "{" + param_name + "}" + bind_params = {} + # copy must be made because size of dict changes during iteration + params_iter = params.copy() + for k, v in params_iter.items(): + if not isinstance(v, dict): + bind_params[k] = params.pop(k) + else: + # resolved params are bound later + bind_params[k] = "{" + k + "}" - resource["endpoint"]["headers"] = {**headers, **header_params} + headers = resource["endpoint"].get("headers", {}) + resource["endpoint"]["headers"] = generic_format(headers, bind_params) # type: ignore[typeddict-item] def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint: @@ -611,7 +589,7 @@ def generic_format( ) -> Union[Dict[str, Any], str, List[Any]]: if isinstance(to_format, dict): return { - key.format(**param_values): generic_format(val, param_values) + generic_format(key, param_values): generic_format(val, param_values) # type: ignore[misc] for key, val in to_format.items() } if isinstance(to_format, list): diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py index 51584ca8d1..354d98ee23 100644 --- a/tests/sources/rest_api/configurations/test_resolve_config.py +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -452,3 +452,23 @@ def test_bind_header_params() -> None: } _bind_header_params(resource_with_headers) assert resource_with_headers["endpoint"]["headers"]["Authorization"] == "test_token" # type: ignore[index] + assert len(resource_with_headers["endpoint"]["params"]) == 0 # type: ignore[index] + + +def test_bind_header_params_nested() -> None: + resource_with_headers: EndpointResource = { + "name": "test_resource", + "endpoint": { + "path": "test/path", + "headers": {"{token}": "{token}", "deeper": {"{token}": ["{token}"]}}, + "params": { + "token": "test_token", + }, + }, + } + _bind_header_params(resource_with_headers) + assert resource_with_headers["endpoint"]["headers"] == { # type: ignore[index] + "test_token": "test_token", + "deeper": {"test_token": ["test_token"]}, + } + assert len(resource_with_headers["endpoint"]["params"]) == 0 # type: ignore[index] From ce4034253076fcdc7547278d6214095a96d163b9 Mon Sep 17 00:00:00 2001 From: ArneDePeuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Fri, 22 Nov 2024 12:46:38 +0100 Subject: [PATCH 09/13] FIX: also pass headers if no resolved params + log headers --- dlt/sources/helpers/rest_client/client.py | 1 + dlt/sources/rest_api/__init__.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index 44599d8d63..3065728a81 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -128,6 +128,7 @@ def _send_request(self, request: Request, **kwargs: Any) -> Response: logger.info( f"Making {request.method.upper()} request to {request.url}" f" with params={request.params}, json={request.json}" + f" with headers={request.headers}" ) prepared_request = self.session.prepare_request(request) diff --git a/dlt/sources/rest_api/__init__.py b/dlt/sources/rest_api/__init__.py index 53803dcabe..78c16dba06 100644 --- a/dlt/sources/rest_api/__init__.py +++ b/dlt/sources/rest_api/__init__.py @@ -291,6 +291,7 @@ def process( def paginate_resource( method: HTTPMethodBasic, path: str, + headers: Dict[str, Any], params: Dict[str, Any], json: Optional[Dict[str, Any]], paginator: Optional[BasePaginator], @@ -313,6 +314,7 @@ def paginate_resource( yield from client.paginate( method=method, + headers=headers, path=path, params=params, json=json, @@ -327,6 +329,7 @@ def paginate_resource( )( method=endpoint_config.get("method", "get"), path=endpoint_config.get("path"), + headers=endpoint_config.get("headers"), params=request_params, json=request_json, paginator=paginator, From a070b23f11221b874d7dc6d85a485eeed51c3f4a Mon Sep 17 00:00:00 2001 From: ArneDePeuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Mon, 25 Nov 2024 21:22:56 +0100 Subject: [PATCH 10/13] FIX: rm nested headers functionality --- dlt/sources/rest_api/config_setup.py | 32 +++-- .../configurations/test_resolve_config.py | 64 +--------- .../sources/rest_api/test_rest_api_source.py | 117 +----------------- 3 files changed, 18 insertions(+), 195 deletions(-) diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index fc36cddeed..551712b2fc 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -431,7 +431,13 @@ def _bind_header_params(resource: EndpointResource) -> None: bind_params[k] = "{" + k + "}" headers = resource["endpoint"].get("headers", {}) - resource["endpoint"]["headers"] = generic_format(headers, bind_params) # type: ignore[typeddict-item] + formatted_headers = { + k.format(**bind_params) if isinstance(k, str) else str(k): ( + v.format(**bind_params) if isinstance(v, str) else str(v) + ) + for k, v in headers.items() + } + resource["endpoint"]["headers"] = formatted_headers def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint: @@ -584,21 +590,6 @@ def remove_field(response: Response, *args, **kwargs) -> Response: return None -def generic_format( - to_format: Union[Dict[str, Any], str, List[Any]], param_values: Dict[str, Any] -) -> Union[Dict[str, Any], str, List[Any]]: - if isinstance(to_format, dict): - return { - generic_format(key, param_values): generic_format(val, param_values) # type: ignore[misc] - for key, val in to_format.items() - } - if isinstance(to_format, list): - return [generic_format(item, param_values) for item in to_format] - if isinstance(to_format, str): - return to_format.format(**param_values) - return str(to_format) - - def process_parent_data_item( path: str, item: Dict[str, Any], @@ -639,8 +630,13 @@ def process_parent_data_item( parent_record[child_key] = item[parent_key] if headers is not None: - formatted_headers = generic_format(headers, param_values) - return bound_path, formatted_headers, parent_record # type: ignore[return-value] + formatted_headers = { + k.format(**param_values) if isinstance(k, str) else str(k): ( + v.format(**param_values) if isinstance(v, str) else str(v) + ) + for k, v in headers.items() + } + return bound_path, formatted_headers, parent_record return bound_path, {}, parent_record diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py index 354d98ee23..f498ee36f6 100644 --- a/tests/sources/rest_api/configurations/test_resolve_config.py +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -193,52 +193,11 @@ def test_process_parent_data_item_headers() -> None: ) assert resolved_headers == {"Authorization": "12345", "num": "2"} - # nested params - resolve_params = [ - ResolvedParam( - "token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"} - ), - ResolvedParam("num", {"field": "auth.num", "resource": "authenticate", "type": "resolve"}), - ] - _, resolved_headers, parent_record = process_parent_data_item( - "chicken", - {"auth": {"token": 12345, "num": 2}}, - resolve_params, - None, - {"Authorization": "{token}", "num": "{num}"}, - ) - assert resolved_headers == {"Authorization": "12345", "num": "2"} - - # nested header dict resolve_params = [ ResolvedParam( "token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"} ) ] - _, resolved_headers, parent_record = process_parent_data_item( - "chicken", - {"auth": {"token": 12345}}, - resolve_params, - None, - {"Authorization": {"Bearer": "{token}"}}, - ) - assert resolved_headers == {"Authorization": {"Bearer": "12345"}} - - # nested header list - resolve_params = [ - ResolvedParam( - "token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"} - ) - ] - _, resolved_headers, parent_record = process_parent_data_item( - "chicken", - {"auth": {"token": 12345}}, - resolve_params, - None, - {"Authorization": ["Bearer", "{token}"]}, - ) - assert resolved_headers == {"Authorization": ["Bearer", "12345"]} - # param path not found with pytest.raises(ValueError) as val_ex: _, _, parent_record = process_parent_data_item( @@ -444,31 +403,12 @@ def test_bind_header_params() -> None: "name": "test_resource", "endpoint": { "path": "test/path", - "headers": {"Authorization": "{token}"}, + "headers": {"Authorization": "Bearer {token}"}, "params": { "token": "test_token", }, }, } _bind_header_params(resource_with_headers) - assert resource_with_headers["endpoint"]["headers"]["Authorization"] == "test_token" # type: ignore[index] - assert len(resource_with_headers["endpoint"]["params"]) == 0 # type: ignore[index] - - -def test_bind_header_params_nested() -> None: - resource_with_headers: EndpointResource = { - "name": "test_resource", - "endpoint": { - "path": "test/path", - "headers": {"{token}": "{token}", "deeper": {"{token}": ["{token}"]}}, - "params": { - "token": "test_token", - }, - }, - } - _bind_header_params(resource_with_headers) - assert resource_with_headers["endpoint"]["headers"] == { # type: ignore[index] - "test_token": "test_token", - "deeper": {"test_token": ["test_token"]}, - } + assert resource_with_headers["endpoint"]["headers"]["Authorization"] == "Bearer test_token" # type: ignore[index] assert len(resource_with_headers["endpoint"]["params"]) == 0 # type: ignore[index] diff --git a/tests/sources/rest_api/test_rest_api_source.py b/tests/sources/rest_api/test_rest_api_source.py index c0bbf75dd9..f376b197f5 100644 --- a/tests/sources/rest_api/test_rest_api_source.py +++ b/tests/sources/rest_api/test_rest_api_source.py @@ -174,7 +174,7 @@ def authenticate(): "name": "chicken", "endpoint": { "path": "chicken", - "headers": {"token": "{token}", "num": "2"}, + "headers": {"token": "Bearer {token}", "num": "2"}, "params": { "token": { "type": "resolve", @@ -200,117 +200,4 @@ def authenticate(): request_param: Request = args[0] assert request_param.url == f"{base_url}/chicken" - assert request_param.headers == {"foo": "bar", "token": "1", "num": "2"} - - -@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) -@pytest.mark.parametrize("invocation_type", ("deco", "factory")) -@patch("dlt.sources.helpers.rest_client.client.RESTClient._send_request") -def test_request_headers_dynamic_key( - mock: MagicMock, destination_name: str, invocation_type: str -) -> None: - mock_resp = Response() - mock_resp.status_code = 200 - mock_resp.json = lambda: {"success": "ok"} # type: ignore - mock.return_value = mock_resp - - @dlt.resource - def authenticate(): - yield [{"token": 1}] - - base_url = "https://api.example.com" - config: RESTAPIConfig = { - "client": {"base_url": base_url, "headers": {"foo": "bar"}}, - "resources": [ - { - "name": "chicken", - "endpoint": { - "path": "chicken", - "headers": {"{token}": "{token}", "num": "2"}, - "params": { - "token": { - "type": "resolve", - "field": "token", - "resource": "authenticate", - }, - }, - }, - }, - authenticate(), - ], - } - - if invocation_type == "deco": - data = rest_api(**config) - else: - data = rest_api_source(config) - pipeline = _make_pipeline(destination_name) - pipeline.run(data) - - mock.assert_called() - args, kwargs = mock.call_args - request_param: Request = args[0] - - assert request_param.url == f"{base_url}/chicken" - assert request_param.headers == {"foo": "bar", "1": "1", "num": "2"} - - -@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) -@pytest.mark.parametrize("invocation_type", ("deco", "factory")) -@patch("dlt.sources.helpers.rest_client.client.RESTClient._send_request") -def test_request_headers_nested( - mock: MagicMock, destination_name: str, invocation_type: str -) -> None: - mock_resp = Response() - mock_resp.status_code = 200 - mock_resp.json = lambda: {"success": "ok"} # type: ignore - mock.return_value = mock_resp - - @dlt.resource - def authenticate(): - yield [{"token": 1}] - - base_url = "https://api.example.com" - config: RESTAPIConfig = { - "client": {"base_url": base_url, "headers": {"foo": "bar"}}, - "resources": [ - { - "name": "chicken", - "endpoint": { - "path": "chicken", - "headers": { - "{token}": "{token}", - "num": "2", - "nested": {"nested": "{token}", "{token}": "other"}, - }, - "params": { - "token": { - "type": "resolve", - "field": "token", - "resource": "authenticate", - }, - }, - }, - }, - authenticate(), - ], - } - - if invocation_type == "deco": - data = rest_api(**config) - else: - data = rest_api_source(config) - pipeline = _make_pipeline(destination_name) - pipeline.run(data) - - mock.assert_called() - args, kwargs = mock.call_args - request_param: Request = args[0] - - assert request_param.url == f"{base_url}/chicken" - assert request_param.headers == { - "foo": "bar", - "1": "1", - "num": "2", - "nested": {"nested": "1", "1": "other"}, - } + assert request_param.headers == {"foo": "bar", "token": "Bearer 1", "num": "2"} From 5396496b134c34b4fb2bd4bc48ae2b2b7feb1377 Mon Sep 17 00:00:00 2001 From: ArneDePeuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Wed, 27 Nov 2024 22:07:50 +0100 Subject: [PATCH 11/13] FEAT: location flag for improved clarity --- dlt/sources/rest_api/config_setup.py | 121 +++++++++++------- dlt/sources/rest_api/typing.py | 8 +- .../configurations/test_resolve_config.py | 48 +++++-- .../rest_api/test_process_parent_data_item.py | 24 ++++ .../sources/rest_api/test_rest_api_source.py | 1 + 5 files changed, 150 insertions(+), 52 deletions(-) create mode 100644 tests/sources/rest_api/test_process_parent_data_item.py diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index 551712b2fc..fd10847d3e 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -64,6 +64,7 @@ ResponseActionDict, Endpoint, EndpointResource, + ResolveParamLocation, ) @@ -376,40 +377,55 @@ def _make_endpoint_resource( return _merge_resource_endpoints(default_config, resource) -def _bind_path_params(resource: EndpointResource) -> None: - """Binds params declared in path to params available in `params`. Pops the - bound params but. Params of type `resolve` and `incremental` are skipped - and bound later. - """ - path_params: Dict[str, Any] = {} - assert isinstance(resource["endpoint"], dict) # type guard - resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])] - path = resource["endpoint"]["path"] - for format_ in string.Formatter().parse(path): +def _bind_string( + target: str, + resource: EndpointResource, + resolve_params: List[str], + location_params: Dict[str, Any], + location: str, +) -> None: + for format_ in string.Formatter().parse(target): name = format_[1] if name: - params = resource["endpoint"].get("params", {}) - if name not in params and name not in path_params: + params = resource["endpoint"].get("params", {}) # type: ignore[union-attr] + if name not in params and name not in location_params: raise ValueError( - f"The path {path} defined in resource {resource['name']} requires param with" - f" name {name} but it is not found in {params}" + f"The {location} {target} defined in resource {resource['name']} requires param" + f" with name {name} but it is not found in {params}" ) if name in resolve_params: resolve_params.remove(name) if name in params: if not isinstance(params[name], dict): # bind resolved param and pop it from endpoint - path_params[name] = params.pop(name) + location_params[name] = params.pop(name) else: param_type = params[name].get("type") if param_type != "resolve": raise ValueError( - f"The path {path} defined in resource {resource['name']} tries to bind" - f" param {name} with type {param_type}. Paths can only bind 'resolve'" - " type params." + f"The {location} {target} defined in resource {resource['name']} tries" + f" to bind param {name} with type {param_type}. {location} can only" + " bind 'resolve' type params." ) - # resolved params are bound later - path_params[name] = "{" + name + "}" + location_params[name] = "{" + name + "}" + + +def _bind_path_params(resource: EndpointResource) -> None: + """Binds params declared in path to params available in `params`. Pops the + bound params but. Params of type `resolve` and `incremental` are skipped + and bound later. + """ + path_params: Dict[str, Any] = {} + assert isinstance(resource["endpoint"], dict) # type guard + resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"], "path")] + path = resource["endpoint"]["path"] + _bind_string(str(path), resource, resolve_params, path_params, "path") + + if len(resolve_params) > 0: + raise NotImplementedError( + f"Resource {resource['name']} defines resolve params {resolve_params} that are not" + f" bound in path {path}." + ) resource["endpoint"]["path"] = path.format(**path_params) @@ -418,25 +434,24 @@ def _bind_header_params(resource: EndpointResource) -> None: """Binds params declared in headers to params available in `params`. Pops the bound params but skips params of type `resolve` and `incremental`, which are bound later. """ + header_params: Dict[str, Any] = {} assert isinstance(resource["endpoint"], dict) # type guard - params = resource["endpoint"].get("params", {}) - bind_params = {} - # copy must be made because size of dict changes during iteration - params_iter = params.copy() - for k, v in params_iter.items(): - if not isinstance(v, dict): - bind_params[k] = params.pop(k) - else: - # resolved params are bound later - bind_params[k] = "{" + k + "}" - + resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"], "header")] headers = resource["endpoint"].get("headers", {}) - formatted_headers = { - k.format(**bind_params) if isinstance(k, str) else str(k): ( - v.format(**bind_params) if isinstance(v, str) else str(v) + formatted_headers = {} + for header_name, header_value in headers.items(): + _bind_string(str(header_name), resource, resolve_params, header_params, "header") + _bind_string(str(header_value), resource, resolve_params, header_params, "header") + formatted_headers[header_name.format(**header_params)] = header_value.format( + **header_params ) - for k, v in headers.items() - } + + if len(resolve_params) > 0: + raise NotImplementedError( + f"Resource {resource['name']} defines resolve params {resolve_params} that are not" + " bound in headers." + ) + resource["endpoint"]["headers"] = formatted_headers @@ -456,7 +471,9 @@ def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint: return endpoint -def _find_resolved_params(endpoint_config: Endpoint) -> List[ResolvedParam]: +def _find_resolved_params( + endpoint_config: Endpoint, location: Optional[ResolveParamLocation] = None +) -> List[ResolvedParam]: """ Find all resolved params in the endpoint configuration and return a list of ResolvedParam objects. @@ -466,7 +483,16 @@ def _find_resolved_params(endpoint_config: Endpoint) -> List[ResolvedParam]: return [ ResolvedParam(key, value) # type: ignore[arg-type] for key, value in endpoint_config.get("params", {}).items() - if (isinstance(value, dict) and value.get("type") == "resolve") + if ( + isinstance(value, dict) + and value.get("type") == "resolve" + and ( + value.get("location") == location + or location is None + or value.get("location") is None + and location == "path" + ) + ) ] @@ -599,7 +625,10 @@ def process_parent_data_item( ) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: parent_resource_name = resolved_params[0].resolve_config["resource"] - param_values = {} + param_values: Dict[str, Dict[str, str]] = { + "path": {}, + "header": {}, + } for resolved_param in resolved_params: field_values = jsonpath.find_values(resolved_param.field_path, item) @@ -613,9 +642,15 @@ def process_parent_data_item( f" {', '.join(item.keys())}" ) - param_values[resolved_param.param_name] = field_values[0] + location = resolved_param.resolve_config.get("location") + if location == "path": + param_values["path"][resolved_param.param_name] = field_values[0] + elif location == "header": + param_values["header"][resolved_param.param_name] = field_values[0] + else: + param_values["path"][resolved_param.param_name] = field_values[0] - bound_path = path.format(**param_values) + bound_path = path.format(**param_values["path"]) parent_record: Dict[str, Any] = {} if include_from_parent: @@ -631,8 +666,8 @@ def process_parent_data_item( if headers is not None: formatted_headers = { - k.format(**param_values) if isinstance(k, str) else str(k): ( - v.format(**param_values) if isinstance(v, str) else str(v) + k.format(**param_values["header"]) if isinstance(k, str) else str(k): ( + v.format(**param_values["header"]) if isinstance(v, str) else str(v) ) for k, v in headers.items() } diff --git a/dlt/sources/rest_api/typing.py b/dlt/sources/rest_api/typing.py index 455d012b10..54f3247c72 100644 --- a/dlt/sources/rest_api/typing.py +++ b/dlt/sources/rest_api/typing.py @@ -10,6 +10,7 @@ Optional, Union, ) +from enum import Enum from dlt.common import jsonpath from dlt.common.schema.typing import ( @@ -224,9 +225,13 @@ class ParamBindConfig(TypedDict): type: ParamBindType # noqa -class ResolveParamConfig(ParamBindConfig): +ResolveParamLocation = Literal["path", "header"] + + +class ResolveParamConfig(ParamBindConfig, total=False): resource: str field: str + location: Optional[ResolveParamLocation] class IncrementalParamConfig(ParamBindConfig, IncrementalRESTArgs): @@ -243,6 +248,7 @@ class ResolvedParam: def __post_init__(self) -> None: self.field_path = jsonpath.compile_path(self.resolve_config["field"]) + self.resolve_config["location"] = self.resolve_config.get("location", "path") class ResponseActionDict(TypedDict, total=False): diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py index f498ee36f6..a8f2305a8e 100644 --- a/tests/sources/rest_api/configurations/test_resolve_config.py +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -80,12 +80,16 @@ def test_bind_path_param() -> None: _bind_path_params(tp_4) assert tp_4 == tp_5 - # resolved param will remain unbounded and + # resolved param will remain unbounded and raise an error tp_6 = deepcopy(three_params) tp_6["endpoint"]["path"] = "{org}/{repo}/issues/1234/comments" # type: ignore[index] - _bind_path_params( - tp_6 - ) # Does not raise because headers are now supported... and so are query params because they reside in the URL + with pytest.raises(NotImplementedError) as val_ex: # type: ignore[assignment] + _bind_path_params(tp_6) + assert ( + "Resource comments defines resolve params ['id'] that are not bound in path" + " {org}/{repo}/issues/1234/comments." + in str(val_ex.value) + ) def test_process_parent_data_item() -> None: @@ -168,7 +172,10 @@ def test_process_parent_data_item() -> None: def test_process_parent_data_item_headers() -> None: resolve_params = [ - ResolvedParam("token", {"field": "token", "resource": "authenticate", "type": "resolve"}) + ResolvedParam( + "token", + {"field": "token", "resource": "authenticate", "type": "resolve", "location": "header"}, + ) ] _, resolved_headers, parent_record = process_parent_data_item( "chicken", @@ -181,8 +188,14 @@ def test_process_parent_data_item_headers() -> None: # multiple params resolve_params = [ - ResolvedParam("token", {"field": "token", "resource": "authenticate", "type": "resolve"}), - ResolvedParam("num", {"field": "num", "resource": "authenticate", "type": "resolve"}), + ResolvedParam( + "token", + {"field": "token", "resource": "authenticate", "type": "resolve", "location": "header"}, + ), + ResolvedParam( + "num", + {"field": "num", "resource": "authenticate", "type": "resolve", "location": "header"}, + ), ] _, resolved_headers, parent_record = process_parent_data_item( "chicken", @@ -195,7 +208,13 @@ def test_process_parent_data_item_headers() -> None: resolve_params = [ ResolvedParam( - "token", {"field": "auth.token", "resource": "authenticate", "type": "resolve"} + "token", + { + "field": "auth.token", + "resource": "authenticate", + "type": "resolve", + "location": "header", + }, ) ] # param path not found @@ -209,6 +228,19 @@ def test_process_parent_data_item_headers() -> None: ) assert "Transformer expects a field 'auth.token'" in str(val_ex.value) + resolve_params = [ + ResolvedParam("token", {"field": "token", "resource": "authenticate", "type": "resolve"}) + ] + # param not provided + with pytest.raises(KeyError): + _, _, parent_record = process_parent_data_item( + "chicken", + {"token": 12345}, + resolve_params, + None, + {"Authorization": "{token}"}, + ) + def test_two_resources_can_depend_on_one_parent_resource() -> None: user_id = { diff --git a/tests/sources/rest_api/test_process_parent_data_item.py b/tests/sources/rest_api/test_process_parent_data_item.py new file mode 100644 index 0000000000..5267648425 --- /dev/null +++ b/tests/sources/rest_api/test_process_parent_data_item.py @@ -0,0 +1,24 @@ +from dlt.sources.rest_api import process_parent_data_item +from dlt.sources.rest_api.typing import ResolvedParam + + +def test_process_parent_data_item(): + path = "{token}" + item = {"token": "1234"} + resolved_params = [ + ResolvedParam( + "token", + {"field": "token", "resource": "authenticate", "type": "resolve", "location": "header"}, + ), + ResolvedParam("token", {"field": "token", "resource": "authenticate", "type": "resolve"}), + ] + include_from_parent = ["token"] + headers = {"{token}": "{token}"} + + formatted_path, formatted_headers, parent_record = process_parent_data_item( + path, item, resolved_params, include_from_parent, headers + ) + + assert formatted_path == "1234" + assert formatted_headers == {"1234": "1234"} + assert parent_record == {"_authenticate_token": "1234"} diff --git a/tests/sources/rest_api/test_rest_api_source.py b/tests/sources/rest_api/test_rest_api_source.py index f376b197f5..8ea76d2eac 100644 --- a/tests/sources/rest_api/test_rest_api_source.py +++ b/tests/sources/rest_api/test_rest_api_source.py @@ -180,6 +180,7 @@ def authenticate(): "type": "resolve", "field": "token", "resource": "authenticate", + "location": "header", }, }, }, From 944378291c5bda700c51d507a500092c91a80441 Mon Sep 17 00:00:00 2001 From: ArneDePeuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Thu, 28 Nov 2024 19:39:22 +0100 Subject: [PATCH 12/13] CHORE: improve code quality --- dlt/sources/rest_api/config_setup.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index fd10847d3e..9a63d93bdb 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -478,22 +478,19 @@ def _find_resolved_params( Find all resolved params in the endpoint configuration and return a list of ResolvedParam objects. + Param: + location: Optional[ResolveParamLocation] = None - filter resolved params by location if provided. + Resolved params are of type ResolveParamConfig (bound param with a key "type" set to "resolve".) """ - return [ + resolved_params = [ ResolvedParam(key, value) # type: ignore[arg-type] for key, value in endpoint_config.get("params", {}).items() - if ( - isinstance(value, dict) - and value.get("type") == "resolve" - and ( - value.get("location") == location - or location is None - or value.get("location") is None - and location == "path" - ) - ) + if isinstance(value, dict) and value.get("type") == "resolve" ] + if location is None: + return resolved_params + return list(filter(lambda rp: rp.resolve_config.get("location") == location, resolved_params)) def _action_type_unless_custom_hook( From b3d53e88072247209c7639e55f1d98b977ca3e0a Mon Sep 17 00:00:00 2001 From: ArneDePeuter <107651037+ArneDePeuter@users.noreply.github.com> Date: Fri, 6 Dec 2024 11:06:12 +0100 Subject: [PATCH 13/13] CHORE: reformat --- dlt/common/libs/pyarrow.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 37268c0d2f..029cd75399 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -628,7 +628,14 @@ def row_tuples_to_arrow( " extracting an SQL VIEW that selects with cast." ) json_str_array = pa.array( - [None if s is None else json.dumps(s) if not issubclass(type(s), set) else json.dumps(list(s)) for s in columnar_known_types[field.name]] + [ + ( + None + if s is None + else json.dumps(s) if not issubclass(type(s), set) else json.dumps(list(s)) + ) + for s in columnar_known_types[field.name] + ] ) columnar_known_types[field.name] = json_str_array