diff --git a/src/typesense/configuration.py b/src/typesense/configuration.py index 6bff96f..96ea7e1 100644 --- a/src/typesense/configuration.py +++ b/src/typesense/configuration.py @@ -74,6 +74,8 @@ class ConfigDict(typing.TypedDict): master_node (typing.Union[str, NodeConfigDict], deprecated): A dictionary or URL that represents the master node. + additional_headers (dict): Additional headers to include in the request. + read_replica_nodes (list[typing.Union[str, NodeConfigDict]], deprecated): A list of dictionaries or URLs that represent the read replica nodes. """ @@ -87,6 +89,7 @@ class ConfigDict(typing.TypedDict): verify: typing.NotRequired[bool] timeout_seconds: typing.NotRequired[int] # deprecated master_node: typing.NotRequired[typing.Union[str, NodeConfigDict]] # deprecated + additional_headers: typing.NotRequired[typing.Dict[str, str]] read_replica_nodes: typing.NotRequired[ typing.List[typing.Union[str, NodeConfigDict]] ] # deprecated @@ -213,6 +216,7 @@ def __init__( 60, ) self.verify = config_dict.get("verify", True) + self.additional_headers = config_dict.get("additional_headers", {}) def _handle_nearest_node( self, diff --git a/src/typesense/request_handler.py b/src/typesense/request_handler.py index 3ef16ca..b9f822a 100644 --- a/src/typesense/request_handler.py +++ b/src/typesense/request_handler.py @@ -204,7 +204,11 @@ def make_request( Raises: TypesenseClientError: If the API returns an error response. """ - headers = {self.api_key_header_name: self.config.api_key} + headers = { + self.api_key_header_name: self.config.api_key, + } + headers.update(self.config.additional_headers) + kwargs.setdefault("headers", {}).update(headers) kwargs.setdefault("timeout", self.config.connection_timeout_seconds) kwargs.setdefault("verify", self.config.verify) diff --git a/tests/api_call_test.py b/tests/api_call_test.py index 3ae8b34..caaa4a1 100644 --- a/tests/api_call_test.py +++ b/tests/api_call_test.py @@ -6,6 +6,7 @@ import sys import time +from isort import Config from pytest_mock import MockFixture if sys.version_info >= (3, 11): @@ -135,6 +136,43 @@ def test_normalize_params_with_no_booleans() -> None: assert parameter_dict == {"key1": "value", "key2": 123} +def test_additional_headers(fake_api_call: ApiCall) -> None: + """Test the `make_request` method with additional headers from the config.""" + session = requests.sessions.Session() + api_call = ApiCall( + Configuration( + { + "additional_headers": { + "AdditionalHeader1": "test", + "AdditionalHeader2": "test2", + }, + "api_key": "test-api", + "nodes": [ + "http://nearest:8108", + ], + }, + ), + ) + + with requests_mock.mock(session=session) as request_mocker: + request_mocker.get( + "http://nearest:8108/test", + json={"key": "value"}, + status_code=200, + ) + + api_call._execute_request( + session.get, + "/test", + as_json=True, + entity_type=typing.Dict[str, str], + ) + + request = request_mocker.request_history[-1] + assert request.headers["AdditionalHeader1"] == "test" + assert request.headers["AdditionalHeader2"] == "test2" + + def test_make_request_as_json(fake_api_call: ApiCall) -> None: """Test the `make_request` method with JSON response.""" session = requests.sessions.Session() @@ -172,6 +210,7 @@ def test_make_request_as_text(fake_api_call: ApiCall) -> None: as_json=False, entity_type=typing.Dict[str, str], ) + assert response == "response text" @@ -431,7 +470,8 @@ def test_get_node_no_healthy_nodes( assert "No healthy nodes were found. Returning the next node." in caplog.text assert ( - selected_node == fake_api_call.node_manager.nodes[fake_api_call.node_manager.node_index] + selected_node + == fake_api_call.node_manager.nodes[fake_api_call.node_manager.node_index] ) assert fake_api_call.node_manager.node_index == 0 diff --git a/tests/configuration_test.py b/tests/configuration_test.py index 120888f..da3166f 100644 --- a/tests/configuration_test.py +++ b/tests/configuration_test.py @@ -66,6 +66,7 @@ def test_configuration_explicit() -> None: "num_retries": 5, "retry_interval_seconds": 2.0, "verify": False, + "additional_headers": {"X-Test": "test", "X-Test2": "test2"}, } configuration = Configuration(config) @@ -82,6 +83,7 @@ def test_configuration_explicit() -> None: "num_retries": 5, "retry_interval_seconds": 2.0, "verify": False, + "additional_headers": {"X-Test": "test", "X-Test2": "test2"}, } assert_to_contain_object(configuration, expected)