diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 37268c0d2f..029cd75399 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -628,7 +628,14 @@ def row_tuples_to_arrow( " extracting an SQL VIEW that selects with cast." ) json_str_array = pa.array( - [None if s is None else json.dumps(s) if not issubclass(type(s), set) else json.dumps(list(s)) for s in columnar_known_types[field.name]] + [ + ( + None + if s is None + else json.dumps(s) if not issubclass(type(s), set) else json.dumps(list(s)) + ) + for s in columnar_known_types[field.name] + ] ) columnar_known_types[field.name] = json_str_array diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index a619a05a00..ac0532b66e 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -100,6 +100,7 @@ def _create_request( path_or_url: str, method: HTTPMethod, params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None, auth: Optional[AuthBase] = None, hooks: Optional[Hooks] = None, @@ -110,10 +111,14 @@ def _create_request( else: url = join_url(self.base_url, path_or_url) + headers = headers or {} + if self.headers: + headers.update(self.headers) + return Request( method=method, url=url, - headers=self.headers, + headers=headers, params=params, json=json, auth=auth or self.auth, @@ -124,6 +129,7 @@ def _send_request(self, request: Request, **kwargs: Any) -> Response: logger.info( f"Making {request.method.upper()} request to {request.url}" f" with params={request.params}, json={request.json}" + f" with headers={request.headers}" ) prepared_request = self.session.prepare_request(request) @@ -143,6 +149,7 @@ def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) -> prepared_request = self._create_request( path_or_url=path, method=method, + headers=kwargs.pop("headers", None), params=kwargs.pop("params", None), json=kwargs.pop("json", None), auth=kwargs.pop("auth", None), @@ -161,6 +168,7 @@ def paginate( path: str = "", method: HTTPMethodBasic = "GET", params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None, auth: Optional[AuthBase] = None, paginator: Optional[BasePaginator] = None, @@ -176,6 +184,7 @@ def paginate( be used instead of the base_url + path. method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'. params (Optional[Dict[str, Any]]): URL parameters for the request. + headers (Optional[Dict[str, Any]]): Headers for the request. json (Optional[Dict[str, Any]]): JSON payload for the request. auth (Optional[AuthBase): Authentication configuration for the request. paginator (Optional[BasePaginator]): Paginator instance for handling @@ -213,7 +222,13 @@ def paginate( hooks["response"] = [raise_for_status] request = self._create_request( - path_or_url=path, method=method, params=params, json=json, auth=auth, hooks=hooks + path_or_url=path, + headers=headers, + method=method, + params=params, + json=json, + auth=auth, + hooks=hooks, ) if paginator: diff --git a/dlt/sources/rest_api/__init__.py b/dlt/sources/rest_api/__init__.py index 966d9e8b6c..dfdae73d4d 100644 --- a/dlt/sources/rest_api/__init__.py +++ b/dlt/sources/rest_api/__init__.py @@ -291,6 +291,7 @@ def process( def paginate_resource( method: HTTPMethodBasic, path: str, + headers: Dict[str, Any], params: Dict[str, Any], json: Optional[Dict[str, Any]], paginator: Optional[BasePaginator], @@ -313,6 +314,7 @@ def paginate_resource( yield from client.paginate( method=method, + headers=headers, path=path, params=params, json=json, @@ -327,6 +329,7 @@ def paginate_resource( )( method=endpoint_config.get("method", "get"), path=endpoint_config.get("path"), + headers=endpoint_config.get("headers"), params=request_params, json=request_json, paginator=paginator, @@ -346,6 +349,7 @@ def paginate_dependent_resource( items: List[Dict[str, Any]], method: HTTPMethodBasic, path: str, + headers: Dict[str, Any], params: Dict[str, Any], paginator: Optional[BasePaginator], data_selector: Optional[jsonpath.TJsonPath], @@ -368,12 +372,13 @@ def paginate_dependent_resource( ) for item in items: - formatted_path, parent_record = process_parent_data_item( - path, item, resolved_params, include_from_parent + formatted_path, formatted_headers, parent_record = process_parent_data_item( + path, item, resolved_params, include_from_parent, headers ) for child_page in client.paginate( method=method, + headers=formatted_headers, path=formatted_path, params=params, paginator=paginator, @@ -392,6 +397,7 @@ def paginate_dependent_resource( )( method=endpoint_config.get("method", "get"), path=endpoint_config.get("path"), + headers=endpoint_config.get("headers"), params=base_params, paginator=paginator, data_selector=endpoint_config.get("data_selector"), diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index d03a4fd59b..9a63d93bdb 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -64,6 +64,7 @@ ResponseActionDict, Endpoint, EndpointResource, + ResolveParamLocation, ) @@ -330,6 +331,7 @@ def expand_and_index_resources( assert isinstance(endpoint_resource["endpoint"], dict) _setup_single_entity_endpoint(endpoint_resource["endpoint"]) _bind_path_params(endpoint_resource) + _bind_header_params(endpoint_resource) resource_name = endpoint_resource["name"] assert isinstance( @@ -375,50 +377,84 @@ def _make_endpoint_resource( return _merge_resource_endpoints(default_config, resource) -def _bind_path_params(resource: EndpointResource) -> None: - """Binds params declared in path to params available in `params`. Pops the - bound params but. Params of type `resolve` and `incremental` are skipped - and bound later. - """ - path_params: Dict[str, Any] = {} - assert isinstance(resource["endpoint"], dict) # type guard - resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])] - path = resource["endpoint"]["path"] - for format_ in string.Formatter().parse(path): +def _bind_string( + target: str, + resource: EndpointResource, + resolve_params: List[str], + location_params: Dict[str, Any], + location: str, +) -> None: + for format_ in string.Formatter().parse(target): name = format_[1] if name: - params = resource["endpoint"].get("params", {}) - if name not in params and name not in path_params: + params = resource["endpoint"].get("params", {}) # type: ignore[union-attr] + if name not in params and name not in location_params: raise ValueError( - f"The path {path} defined in resource {resource['name']} requires param with" - f" name {name} but it is not found in {params}" + f"The {location} {target} defined in resource {resource['name']} requires param" + f" with name {name} but it is not found in {params}" ) if name in resolve_params: resolve_params.remove(name) if name in params: if not isinstance(params[name], dict): # bind resolved param and pop it from endpoint - path_params[name] = params.pop(name) + location_params[name] = params.pop(name) else: param_type = params[name].get("type") if param_type != "resolve": raise ValueError( - f"The path {path} defined in resource {resource['name']} tries to bind" - f" param {name} with type {param_type}. Paths can only bind 'resolve'" - " type params." + f"The {location} {target} defined in resource {resource['name']} tries" + f" to bind param {name} with type {param_type}. {location} can only" + " bind 'resolve' type params." ) - # resolved params are bound later - path_params[name] = "{" + name + "}" + location_params[name] = "{" + name + "}" + + +def _bind_path_params(resource: EndpointResource) -> None: + """Binds params declared in path to params available in `params`. Pops the + bound params but. Params of type `resolve` and `incremental` are skipped + and bound later. + """ + path_params: Dict[str, Any] = {} + assert isinstance(resource["endpoint"], dict) # type guard + resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"], "path")] + path = resource["endpoint"]["path"] + _bind_string(str(path), resource, resolve_params, path_params, "path") if len(resolve_params) > 0: raise NotImplementedError( f"Resource {resource['name']} defines resolve params {resolve_params} that are not" - f" bound in path {path}. Resolve query params not supported yet." + f" bound in path {path}." ) resource["endpoint"]["path"] = path.format(**path_params) +def _bind_header_params(resource: EndpointResource) -> None: + """Binds params declared in headers to params available in `params`. Pops the + bound params but skips params of type `resolve` and `incremental`, which are bound later. + """ + header_params: Dict[str, Any] = {} + assert isinstance(resource["endpoint"], dict) # type guard + resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"], "header")] + headers = resource["endpoint"].get("headers", {}) + formatted_headers = {} + for header_name, header_value in headers.items(): + _bind_string(str(header_name), resource, resolve_params, header_params, "header") + _bind_string(str(header_value), resource, resolve_params, header_params, "header") + formatted_headers[header_name.format(**header_params)] = header_value.format( + **header_params + ) + + if len(resolve_params) > 0: + raise NotImplementedError( + f"Resource {resource['name']} defines resolve params {resolve_params} that are not" + " bound in headers." + ) + + resource["endpoint"]["headers"] = formatted_headers + + def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint: """Tries to guess if the endpoint refers to a single entity and when detected: * if `data_selector` was not specified (or is None), "$" is selected @@ -435,18 +471,26 @@ def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint: return endpoint -def _find_resolved_params(endpoint_config: Endpoint) -> List[ResolvedParam]: +def _find_resolved_params( + endpoint_config: Endpoint, location: Optional[ResolveParamLocation] = None +) -> List[ResolvedParam]: """ Find all resolved params in the endpoint configuration and return a list of ResolvedParam objects. + Param: + location: Optional[ResolveParamLocation] = None - filter resolved params by location if provided. + Resolved params are of type ResolveParamConfig (bound param with a key "type" set to "resolve".) """ - return [ + resolved_params = [ ResolvedParam(key, value) # type: ignore[arg-type] for key, value in endpoint_config.get("params", {}).items() - if (isinstance(value, dict) and value.get("type") == "resolve") + if isinstance(value, dict) and value.get("type") == "resolve" ] + if location is None: + return resolved_params + return list(filter(lambda rp: rp.resolve_config.get("location") == location, resolved_params)) def _action_type_unless_custom_hook( @@ -574,10 +618,14 @@ def process_parent_data_item( item: Dict[str, Any], resolved_params: List[ResolvedParam], include_from_parent: List[str], -) -> Tuple[str, Dict[str, Any]]: + headers: Optional[Dict[str, Any]] = None, +) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: parent_resource_name = resolved_params[0].resolve_config["resource"] - param_values = {} + param_values: Dict[str, Dict[str, str]] = { + "path": {}, + "header": {}, + } for resolved_param in resolved_params: field_values = jsonpath.find_values(resolved_param.field_path, item) @@ -591,9 +639,15 @@ def process_parent_data_item( f" {', '.join(item.keys())}" ) - param_values[resolved_param.param_name] = field_values[0] + location = resolved_param.resolve_config.get("location") + if location == "path": + param_values["path"][resolved_param.param_name] = field_values[0] + elif location == "header": + param_values["header"][resolved_param.param_name] = field_values[0] + else: + param_values["path"][resolved_param.param_name] = field_values[0] - bound_path = path.format(**param_values) + bound_path = path.format(**param_values["path"]) parent_record: Dict[str, Any] = {} if include_from_parent: @@ -607,7 +661,15 @@ def process_parent_data_item( ) parent_record[child_key] = item[parent_key] - return bound_path, parent_record + if headers is not None: + formatted_headers = { + k.format(**param_values["header"]) if isinstance(k, str) else str(k): ( + v.format(**param_values["header"]) if isinstance(v, str) else str(v) + ) + for k, v in headers.items() + } + return bound_path, formatted_headers, parent_record + return bound_path, {}, parent_record def _merge_resource_endpoints( diff --git a/dlt/sources/rest_api/typing.py b/dlt/sources/rest_api/typing.py index ccef828b1a..e40b5aa0be 100644 --- a/dlt/sources/rest_api/typing.py +++ b/dlt/sources/rest_api/typing.py @@ -10,6 +10,7 @@ Optional, Union, ) +from enum import Enum from dlt.common import jsonpath from dlt.common.schema.typing import ( @@ -224,9 +225,13 @@ class ParamBindConfig(TypedDict): type: ParamBindType # noqa -class ResolveParamConfig(ParamBindConfig): +ResolveParamLocation = Literal["path", "header"] + + +class ResolveParamConfig(ParamBindConfig, total=False): resource: str field: str + location: Optional[ResolveParamLocation] class IncrementalParamConfig(ParamBindConfig, IncrementalRESTArgs): @@ -243,6 +248,7 @@ class ResolvedParam: def __post_init__(self) -> None: self.field_path = jsonpath.compile_path(self.resolve_config["field"]) + self.resolve_config["location"] = self.resolve_config.get("location", "path") class ResponseActionDict(TypedDict, total=False): @@ -259,6 +265,7 @@ class Endpoint(TypedDict, total=False): method: Optional[HTTPMethodBasic] params: Optional[Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]] json: Optional[Dict[str, Any]] + headers: Optional[Dict[str, Any]] paginator: Optional[PaginatorConfig] data_selector: Optional[jsonpath.TJsonPath] response_actions: Optional[List[ResponseAction]] diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py index d3d9308df1..a8f2305a8e 100644 --- a/tests/sources/rest_api/configurations/test_resolve_config.py +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -1,6 +1,5 @@ import re from copy import deepcopy - import pytest from graphlib import CycleError # type: ignore @@ -10,6 +9,7 @@ ) from dlt.sources.rest_api.config_setup import ( _bind_path_params, + _bind_header_params, process_parent_data_item, ) from dlt.sources.rest_api.typing import ( @@ -80,29 +80,34 @@ def test_bind_path_param() -> None: _bind_path_params(tp_4) assert tp_4 == tp_5 - # resolved param will remain unbounded and + # resolved param will remain unbounded and raise an error tp_6 = deepcopy(three_params) tp_6["endpoint"]["path"] = "{org}/{repo}/issues/1234/comments" # type: ignore[index] - with pytest.raises(NotImplementedError): + with pytest.raises(NotImplementedError) as val_ex: # type: ignore[assignment] _bind_path_params(tp_6) + assert ( + "Resource comments defines resolve params ['id'] that are not bound in path" + " {org}/{repo}/issues/1234/comments." + in str(val_ex.value) + ) def test_process_parent_data_item() -> None: resolve_params = [ ResolvedParam("id", {"field": "obj_id", "resource": "issues", "type": "resolve"}) ] - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_params, None ) assert bound_path == "dlt-hub/dlt/issues/12345/comments" assert parent_record == {} - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_params, ["obj_id"] ) assert parent_record == {"_issues_obj_id": 12345} - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345, "obj_node": "node_1"}, resolve_params, @@ -117,21 +122,21 @@ def test_process_parent_data_item() -> None: ) ] item = {"some_results": {"obj_id": 12345}} - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", item, resolve_param_nested, None ) assert bound_path == "dlt-hub/dlt/issues/12345/comments" # param path not found with pytest.raises(ValueError) as val_ex: - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"_id": 12345}, resolve_params, None ) assert "Transformer expects a field 'obj_id'" in str(val_ex.value) # included path not found with pytest.raises(ValueError) as val_ex: - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345, "obj_node": "node_1"}, resolve_params, @@ -145,7 +150,7 @@ def test_process_parent_data_item() -> None: ResolvedParam("id", {"field": "id", "resource": "comments", "type": "resolve"}), ] - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{issue_id}/comments/{id}", {"issue": 12345, "id": 56789}, multi_resolve_params, @@ -156,7 +161,7 @@ def test_process_parent_data_item() -> None: # param path not found with multiple parameters with pytest.raises(ValueError) as val_ex: - bound_path, parent_record = process_parent_data_item( + bound_path, _, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{issue_id}/comments/{id}", {"_issue": 12345, "id": 56789}, multi_resolve_params, @@ -165,6 +170,78 @@ def test_process_parent_data_item() -> None: assert "Transformer expects a field 'issue'" in str(val_ex.value) +def test_process_parent_data_item_headers() -> None: + resolve_params = [ + ResolvedParam( + "token", + {"field": "token", "resource": "authenticate", "type": "resolve", "location": "header"}, + ) + ] + _, resolved_headers, parent_record = process_parent_data_item( + "chicken", + {"token": 12345}, + resolve_params, + None, + {"Authorization": "{token}"}, + ) + assert resolved_headers == {"Authorization": "12345"} + + # multiple params + resolve_params = [ + ResolvedParam( + "token", + {"field": "token", "resource": "authenticate", "type": "resolve", "location": "header"}, + ), + ResolvedParam( + "num", + {"field": "num", "resource": "authenticate", "type": "resolve", "location": "header"}, + ), + ] + _, resolved_headers, parent_record = process_parent_data_item( + "chicken", + {"token": 12345, "num": 2}, + resolve_params, + None, + {"Authorization": "{token}", "num": "{num}"}, + ) + assert resolved_headers == {"Authorization": "12345", "num": "2"} + + resolve_params = [ + ResolvedParam( + "token", + { + "field": "auth.token", + "resource": "authenticate", + "type": "resolve", + "location": "header", + }, + ) + ] + # param path not found + with pytest.raises(ValueError) as val_ex: + _, _, parent_record = process_parent_data_item( + "chicken", + {"_token": 12345}, + resolve_params, + None, + {"Authorization": "{token}"}, + ) + assert "Transformer expects a field 'auth.token'" in str(val_ex.value) + + resolve_params = [ + ResolvedParam("token", {"field": "token", "resource": "authenticate", "type": "resolve"}) + ] + # param not provided + with pytest.raises(KeyError): + _, _, parent_record = process_parent_data_item( + "chicken", + {"token": 12345}, + resolve_params, + None, + {"Authorization": "{token}"}, + ) + + def test_two_resources_can_depend_on_one_parent_resource() -> None: user_id = { "user_id": { @@ -351,3 +428,19 @@ def test_circular_resource_bindingis_invalid() -> None: with pytest.raises(CycleError) as e: rest_api_resources(config) assert e.match(re.escape("'nodes are in a cycle', ['chicken', 'egg', 'chicken']")) + + +def test_bind_header_params() -> None: + resource_with_headers: EndpointResource = { + "name": "test_resource", + "endpoint": { + "path": "test/path", + "headers": {"Authorization": "Bearer {token}"}, + "params": { + "token": "test_token", + }, + }, + } + _bind_header_params(resource_with_headers) + assert resource_with_headers["endpoint"]["headers"]["Authorization"] == "Bearer test_token" # type: ignore[index] + assert len(resource_with_headers["endpoint"]["params"]) == 0 # type: ignore[index] diff --git a/tests/sources/rest_api/test_process_parent_data_item.py b/tests/sources/rest_api/test_process_parent_data_item.py new file mode 100644 index 0000000000..5267648425 --- /dev/null +++ b/tests/sources/rest_api/test_process_parent_data_item.py @@ -0,0 +1,24 @@ +from dlt.sources.rest_api import process_parent_data_item +from dlt.sources.rest_api.typing import ResolvedParam + + +def test_process_parent_data_item(): + path = "{token}" + item = {"token": "1234"} + resolved_params = [ + ResolvedParam( + "token", + {"field": "token", "resource": "authenticate", "type": "resolve", "location": "header"}, + ), + ResolvedParam("token", {"field": "token", "resource": "authenticate", "type": "resolve"}), + ] + include_from_parent = ["token"] + headers = {"{token}": "{token}"} + + formatted_path, formatted_headers, parent_record = process_parent_data_item( + path, item, resolved_params, include_from_parent, headers + ) + + assert formatted_path == "1234" + assert formatted_headers == {"1234": "1234"} + assert parent_record == {"_authenticate_token": "1234"} diff --git a/tests/sources/rest_api/test_rest_api_source.py b/tests/sources/rest_api/test_rest_api_source.py index 904bcaf159..8ea76d2eac 100644 --- a/tests/sources/rest_api/test_rest_api_source.py +++ b/tests/sources/rest_api/test_rest_api_source.py @@ -1,5 +1,7 @@ import dlt import pytest +from unittest.mock import patch, MagicMock +from requests import Response, Request from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContainer @@ -149,3 +151,54 @@ def test_dependent_resource(destination_name: str, invocation_type: str) -> None } assert table_counts["pokemon"] == 2 + + +@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) +@pytest.mark.parametrize("invocation_type", ("deco", "factory")) +@patch("dlt.sources.helpers.rest_client.client.RESTClient._send_request") +def test_request_headers(mock: MagicMock, destination_name: str, invocation_type: str) -> None: + mock_resp = Response() + mock_resp.status_code = 200 + mock_resp.json = lambda: {"success": "ok"} # type: ignore + mock.return_value = mock_resp + + @dlt.resource + def authenticate(): + yield [{"token": 1}] + + base_url = "https://api.example.com" + config: RESTAPIConfig = { + "client": {"base_url": base_url, "headers": {"foo": "bar"}}, + "resources": [ + { + "name": "chicken", + "endpoint": { + "path": "chicken", + "headers": {"token": "Bearer {token}", "num": "2"}, + "params": { + "token": { + "type": "resolve", + "field": "token", + "resource": "authenticate", + "location": "header", + }, + }, + }, + }, + authenticate(), + ], + } + + if invocation_type == "deco": + data = rest_api(**config) + else: + data = rest_api_source(config) + pipeline = _make_pipeline(destination_name) + pipeline.run(data) + + mock.assert_called() + args, kwargs = mock.call_args + request_param: Request = args[0] + + assert request_param.url == f"{base_url}/chicken" + assert request_param.headers == {"foo": "bar", "token": "Bearer 1", "num": "2"}