Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: parameterised headers rest_api_source #1

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _create_request(
path: str,
method: HTTPMethod,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthBase] = None,
hooks: Optional[Hooks] = None,
Expand All @@ -109,10 +110,14 @@ def _create_request(
else:
url = join_url(self.base_url, path)

headers = headers or {}
if self.headers:
headers.update(self.headers)

return Request(
method=method,
url=url,
headers=self.headers,
headers=headers,
params=params,
json=json,
auth=auth or self.auth,
Expand Down Expand Up @@ -142,6 +147,7 @@ def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) ->
prepared_request = self._create_request(
path=path,
method=method,
headers=kwargs.pop("headers", None),
params=kwargs.pop("params", None),
json=kwargs.pop("json", None),
auth=kwargs.pop("auth", None),
Expand All @@ -160,6 +166,7 @@ def paginate(
path: str = "",
method: HTTPMethodBasic = "GET",
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthBase] = None,
paginator: Optional[BasePaginator] = None,
Expand All @@ -173,6 +180,7 @@ def paginate(
path (str): Endpoint path for the request, relative to `base_url`.
method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'.
params (Optional[Dict[str, Any]]): URL parameters for the request.
headers (Optional[Dict[str, Any]]): Headers for the request.
json (Optional[Dict[str, Any]]): JSON payload for the request.
auth (Optional[AuthBase): Authentication configuration for the request.
paginator (Optional[BasePaginator]): Paginator instance for handling
Expand Down Expand Up @@ -210,7 +218,7 @@ def paginate(
hooks["response"] = [raise_for_status]

request = self._create_request(
path=path, method=method, params=params, json=json, auth=auth, hooks=hooks
path=path, headers=headers, method=method, params=params, json=json, auth=auth, hooks=hooks
)

if paginator:
Expand Down
7 changes: 5 additions & 2 deletions dlt/sources/rest_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def paginate_dependent_resource(
items: List[Dict[str, Any]],
method: HTTPMethodBasic,
path: str,
headers: Dict[str, Any],
params: Dict[str, Any],
paginator: Optional[BasePaginator],
data_selector: Optional[jsonpath.TJsonPath],
Expand All @@ -368,12 +369,13 @@ def paginate_dependent_resource(
)

for item in items:
formatted_path, parent_record = process_parent_data_item(
path, item, resolved_params, include_from_parent
formatted_path, formatted_headers, parent_record = process_parent_data_item(
path, headers, item, resolved_params, include_from_parent
)

for child_page in client.paginate(
method=method,
headers=formatted_headers,
path=formatted_path,
params=params,
paginator=paginator,
Expand All @@ -392,6 +394,7 @@ def paginate_dependent_resource(
)(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
headers=endpoint_config.get("headers"),
params=base_params,
paginator=paginator,
data_selector=endpoint_config.get("data_selector"),
Expand Down
55 changes: 47 additions & 8 deletions dlt/sources/rest_api/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def expand_and_index_resources(
assert isinstance(endpoint_resource["endpoint"], dict)
_setup_single_entity_endpoint(endpoint_resource["endpoint"])
_bind_path_params(endpoint_resource)
_bind_header_params(endpoint_resource)

resource_name = endpoint_resource["name"]
assert isinstance(
Expand Down Expand Up @@ -410,15 +411,46 @@ def _bind_path_params(resource: EndpointResource) -> None:
# resolved params are bound later
path_params[name] = "{" + name + "}"

if len(resolve_params) > 0:
raise NotImplementedError(
f"Resource {resource['name']} defines resolve params {resolve_params} that are not"
f" bound in path {path}. Resolve query params not supported yet."
)

resource["endpoint"]["path"] = path.format(**path_params)


def _bind_header_params(resource: EndpointResource) -> None:
"""Binds params declared in headers to params available in `params`. Pops the
bound params but skips params of type `resolve` and `incremental`, which are bound later.
"""
header_params: Dict[str, Any] = {}
assert isinstance(resource["endpoint"], dict) # type guard
headers = resource["endpoint"].get("headers", {})
params = resource["endpoint"].get("params", {})
resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])]

for header_key, header_value in headers.items():
if isinstance(header_value, str) and header_value.startswith("{") and header_value.endswith("}"):
param_name = header_value.strip("{}")
if param_name not in params:
raise ValueError(
f"The header '{header_key}' in resource '{resource['name']}' requires a param with "
f"name '{param_name}' but it is not found in {params}."
)
if param_name in resolve_params:
resolve_params.remove(param_name)
if param_name in params:
if not isinstance(params[param_name], dict):
# Bind the header param and pop it from params
header_params[header_key] = params.pop(param_name)
else:
param_type = params[param_name].get("type")
if param_type != "resolve":
raise ValueError(
f"The header '{header_key}' in resource '{resource['name']}' tries to bind param "
f"'{param_name}' with type '{param_type}'. Headers can only bind 'resolve' type params."
)
# Resolved params are bound later
header_params[header_key] = "{" + param_name + "}"

resource["endpoint"]["headers"] = {**headers, **header_params}


def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint:
"""Tries to guess if the endpoint refers to a single entity and when detected:
* if `data_selector` was not specified (or is None), "$" is selected
Expand Down Expand Up @@ -569,12 +601,14 @@ def remove_field(response: Response, *args, **kwargs) -> Response:
return None



def process_parent_data_item(
path: str,
headers: Dict[str, Any],
item: Dict[str, Any],
resolved_params: List[ResolvedParam],
include_from_parent: List[str],
) -> Tuple[str, Dict[str, Any]]:
) -> Tuple[str, Dict[str, Any], Dict[str, Any]]:
parent_resource_name = resolved_params[0].resolve_config["resource"]

param_values = {}
Expand All @@ -594,6 +628,11 @@ def process_parent_data_item(
param_values[resolved_param.param_name] = field_values[0]

bound_path = path.format(**param_values)
formatted_headers = {}
for k, v in headers.items():
key = k if not isinstance(k, str) else k.format(**param_values)
val = v if not isinstance(v, str) else v.format(**param_values)
formatted_headers[key] = val

parent_record: Dict[str, Any] = {}
if include_from_parent:
Expand All @@ -607,7 +646,7 @@ def process_parent_data_item(
)
parent_record[child_key] = item[parent_key]

return bound_path, parent_record
return bound_path, formatted_headers, parent_record


def _merge_resource_endpoints(
Expand Down
1 change: 1 addition & 0 deletions dlt/sources/rest_api/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ class Endpoint(TypedDict, total=False):
method: Optional[HTTPMethodBasic]
params: Optional[Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]]
json: Optional[Dict[str, Any]]
headers: Optional[Dict[str, Any]]
paginator: Optional[PaginatorConfig]
data_selector: Optional[jsonpath.TJsonPath]
response_actions: Optional[List[ResponseAction]]
Expand Down
18 changes: 17 additions & 1 deletion tests/sources/rest_api/configurations/test_resolve_config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import re
from copy import deepcopy

import pytest
from graphlib import CycleError # type: ignore

import dtl
from dlt.sources.rest_api import (
rest_api_resources,
rest_api_source,
)
from dlt.sources.rest_api.config_setup import (
_bind_path_params,
_bind_header_params,
process_parent_data_item,
)
from dlt.sources.rest_api.typing import (
Expand Down Expand Up @@ -351,3 +352,18 @@ def test_circular_resource_bindingis_invalid() -> None:
with pytest.raises(CycleError) as e:
rest_api_resources(config)
assert e.match(re.escape("'nodes are in a cycle', ['chicken', 'egg', 'chicken']"))


def test_bind_header_params() -> None:
resource_with_headers: EndpointResource = {
"name": "test_resource",
"endpoint": {
"path": "test/path",
"headers": {"Authorization": "{token}"},
"params": {
"token": "test_token",
},
},
}
_bind_header_params(resource_with_headers)
assert resource_with_headers["endpoint"]["headers"]["Authorization"] == "test_token"
108 changes: 108 additions & 0 deletions tests/sources/rest_api/test_rest_api_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import dlt
import pytest
from unittest.mock import patch, MagicMock
from requests import Response, Request

from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContainer

Expand Down Expand Up @@ -149,3 +151,109 @@ def test_dependent_resource(destination_name: str, invocation_type: str) -> None
}

assert table_counts["pokemon"] == 2


@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
@pytest.mark.parametrize("invocation_type", ("deco", "factory"))
@patch("dlt.sources.helpers.rest_client.client.RESTClient._send_request")
def test_request_headers(mock: MagicMock, destination_name: str, invocation_type: str) -> None:
mock_resp = Response()
mock_resp.status_code = 200
mock_resp.json = lambda: {"success": "ok"}
mock.return_value = mock_resp

@dlt.resource
def authenticate():
yield [{"token": 1}]

base_url = "https://api.example.com"
config: RESTAPIConfig = {
"client": {
"base_url": base_url,
"headers": {"foo": "bar"}
},
"resources": [
{
"name": "chicken",
"endpoint": {
"path": "chicken",
"headers": {"token": "{token}", "num": "2"},
"params": {
"token": {
"type": "resolve",
"field": "token",
"resource": "authenticate",
},
},
},
},
authenticate(),
],
}

if invocation_type == "deco":
data = rest_api(**config)
else:
data = rest_api_source(config)
pipeline = _make_pipeline(destination_name)
pipeline.run(data)

mock.assert_called()
args, kwargs = mock.call_args
request_param: Request = args[0]

assert request_param.url == f"{base_url}/chicken"
assert request_param.headers == {"foo": "bar", "token": "1", "num": "2"}


@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
@pytest.mark.parametrize("invocation_type", ("deco", "factory"))
@patch("dlt.sources.helpers.rest_client.client.RESTClient._send_request")
def test_request_headers_dynamic_key(mock: MagicMock, destination_name: str, invocation_type: str) -> None:
mock_resp = Response()
mock_resp.status_code = 200
mock_resp.json = lambda: {"success": "ok"}
mock.return_value = mock_resp

@dlt.resource
def authenticate():
yield [{"token": 1}]

base_url = "https://api.example.com"
config: RESTAPIConfig = {
"client": {
"base_url": base_url,
"headers": {"foo": "bar"}
},
"resources": [
{
"name": "chicken",
"endpoint": {
"path": "chicken",
"headers": {"{token}": "{token}", "num": "2"},
"params": {
"token": {
"type": "resolve",
"field": "token",
"resource": "authenticate",
},
},
},
},
authenticate(),
],
}

if invocation_type == "deco":
data = rest_api(**config)
else:
data = rest_api_source(config)
pipeline = _make_pipeline(destination_name)
pipeline.run(data)

mock.assert_called()
args, kwargs = mock.call_args
request_param: Request = args[0]

assert request_param.url == f"{base_url}/chicken"
assert request_param.headers == {"foo": "bar", "1": "1", "num": "2"}