diff --git a/GNUmakefile b/GNUmakefile index 8e743feec..74622648b 100644 --- a/GNUmakefile +++ b/GNUmakefile @@ -72,7 +72,7 @@ unit-tests: venv/.deps-dev rm -fr runtime/* .PHONY: integration-tests -unit-tests: export PYTEST_ARGS ?= +integration-tests: export PYTEST_ARGS ?= integration-tests: venv/.deps-dev rm -fr runtime/* $(PYTHON) -m pytest -s -vvv $(PYTEST_ARGS) tests/integration/ diff --git a/src/karapace/forward_client.py b/src/karapace/forward_client.py index 9d791303c..239de41c4 100644 --- a/src/karapace/forward_client.py +++ b/src/karapace/forward_client.py @@ -24,10 +24,10 @@ class ForwardClient: USER_AGENT = f"Karapace/{__version__}" def __init__(self) -> None: - self._forward_client: aiohttp.ClientSession | None = None + self._forward_client: aiohttp.ClientSession = aiohttp.ClientSession(headers={"User-Agent": self.USER_AGENT}) - def _get_forward_client(self) -> aiohttp.ClientSession: - return aiohttp.ClientSession(headers={"User-Agent": ForwardClient.USER_AGENT}) + async def close(self) -> None: + await self._forward_client.close() def _acceptable_response_content_type(self, *, content_type: str) -> bool: return ( @@ -42,11 +42,7 @@ async def _forward_request_remote( ) -> bytes: LOG.info("Forwarding %s request to remote url: %r since we're not the master", request.method, request.url) timeout = 60.0 - headers = request.headers.mutablecopy() - func = getattr(self._get_forward_client(), request.method.lower()) - # auth_header = request.headers.get("Authorization") - # if auth_header is not None: - # headers["Authorization"] = auth_header + func = getattr(self._forward_client, request.method.lower()) forward_url = f"{primary_url}{request.url.path}" if request.url.query: @@ -55,7 +51,7 @@ async def _forward_request_remote( async with async_timeout.timeout(timeout): body_data = await request.body() - async with func(forward_url, headers=headers, data=body_data) as response: + async with func(forward_url, headers=request.headers.mutablecopy(), data=body_data) as response: if self._acceptable_response_content_type(content_type=response.headers.get("Content-Type")): return await response.text() LOG.error("Unknown response for forwarded request: %s", response) diff --git a/src/schema_registry/factory.py b/src/schema_registry/factory.py index 12a80775d..4ef678841 100644 --- a/src/schema_registry/factory.py +++ b/src/schema_registry/factory.py @@ -9,6 +9,7 @@ from karapace import version as karapace_version from karapace.auth import AuthenticatorAndAuthorizer from karapace.config import Config +from karapace.forward_client import ForwardClient from karapace.logging_setup import configure_logging, log_config_without_secrets from karapace.schema_registry import KarapaceSchemaRegistry from karapace.statsd import StatsClient @@ -26,6 +27,7 @@ @inject async def karapace_schema_registry_lifespan( _: FastAPI, + forward_client: ForwardClient = Depends(Provide[SchemaRegistryContainer.karapace_container.forward_client]), stastd: StatsClient = Depends(Provide[SchemaRegistryContainer.karapace_container.statsd]), schema_registry: KarapaceSchemaRegistry = Depends(Provide[SchemaRegistryContainer.karapace_container.schema_registry]), authorizer: AuthenticatorAndAuthorizer = Depends(Provide[SchemaRegistryContainer.karapace_container.authorizer]), @@ -37,19 +39,17 @@ async def karapace_schema_registry_lifespan( yield finally: - if schema_registry: - await schema_registry.close() - if authorizer: - await authorizer.close() - if stastd: - stastd.close() + await schema_registry.close() + await authorizer.close() + await forward_client.close() + stastd.close() def create_karapace_application( *, config: Config, lifespan: Callable[ - [FastAPI, StatsClient, KarapaceSchemaRegistry, AuthenticatorAndAuthorizer], AsyncContextManager[None] + [FastAPI, ForwardClient, StatsClient, KarapaceSchemaRegistry, AuthenticatorAndAuthorizer], AsyncContextManager[None] ], ) -> FastAPI: configure_logging(config=config) diff --git a/tests/unit/test_forwarding_client.py b/tests/unit/test_forwarding_client.py index 744b1e9da..6dc9700c9 100644 --- a/tests/unit/test_forwarding_client.py +++ b/tests/unit/test_forwarding_client.py @@ -1,7 +1,7 @@ """ -karapace - schema registry authentication and authorization tests +karapace - master forwarding tests -Copyright (c) 2023 Aiven Ltd +Copyright (c) 2024 Aiven Ltd See LICENSE for details """ from __future__ import annotations @@ -15,6 +15,7 @@ from tests.base_testcase import BaseTestCase from unittest.mock import AsyncMock, Mock, patch +import aiohttp import pytest @@ -28,6 +29,20 @@ class ContentTypeTestCase(BaseTestCase): content_type: str +@pytest.fixture(name="forward_client") +def fixture_forward_client() -> ForwardClient: + with patch("karapace.forward_client.aiohttp") as mocked_aiohttp: + mocked_aiohttp.ClientSession.return_value = Mock( + spec=aiohttp.ClientSession, headers={"User-Agent": ForwardClient.USER_AGENT} + ) + return ForwardClient() + + +async def test_forward_client_close(forward_client: ForwardClient) -> None: + await forward_client.close() + forward_client._forward_client.close.assert_awaited_once() # pylint: disable=protected-access + + @pytest.mark.parametrize( "testcase", [ @@ -42,112 +57,100 @@ class ContentTypeTestCase(BaseTestCase): ], ids=str, ) -async def test_forward_request_with_basemodel_response(testcase: ContentTypeTestCase) -> None: - forward_client = ForwardClient() - with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client: - mock_request = Mock(spec=Request) - mock_request.method = "GET" - mock_request.headers = Headers() - - mock_aiohttp_session = Mock() - mock_get_forward_client.return_value = mock_aiohttp_session - mock_get_func = Mock() - mock_response_context = AsyncMock - mock_response = AsyncMock() - mock_response_context.call_function = lambda _: mock_response - mock_response.text.return_value = '{"number":10,"string":"important"}' - headers = MutableHeaders() - headers["Content-Type"] = testcase.content_type - mock_response.headers = headers - - async def mock_aenter(_) -> Mock: - return mock_response - - async def mock_aexit(_, __, ___, ____) -> None: - return - - mock_get_func.__aenter__ = mock_aenter - mock_get_func.__aexit__ = mock_aexit - mock_aiohttp_session.get.return_value = mock_get_func - - response = await forward_client.forward_request_remote( - request=mock_request, - primary_url="test-url", - response_type=TestResponse, - ) - - assert response == TestResponse(number=10, string="important") - - -async def test_forward_request_with_integer_list_response() -> None: - forward_client = ForwardClient() - with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client: - mock_request = Mock(spec=Request) - mock_request.method = "GET" - mock_request.headers = Headers() - - mock_aiohttp_session = Mock() - mock_get_forward_client.return_value = mock_aiohttp_session - mock_get_func = Mock() - mock_response_context = AsyncMock - mock_response = AsyncMock() - mock_response_context.call_function = lambda _: mock_response - mock_response.text.return_value = "[1, 2, 3, 10]" - headers = MutableHeaders() - headers["Content-Type"] = "application/json" - mock_response.headers = headers - - async def mock_aenter(_) -> Mock: - return mock_response - - async def mock_aexit(_, __, ___, ____) -> None: - return - - mock_get_func.__aenter__ = mock_aenter - mock_get_func.__aexit__ = mock_aexit - mock_aiohttp_session.get.return_value = mock_get_func - - response = await forward_client.forward_request_remote( - request=mock_request, - primary_url="test-url", - response_type=list[int], - ) - - assert response == [1, 2, 3, 10] - - -async def test_forward_request_with_integer_response() -> None: - forward_client = ForwardClient() - with patch.object(forward_client, "_get_forward_client", autospec=True) as mock_get_forward_client: - mock_request = Mock(spec=Request) - mock_request.method = "GET" - mock_request.headers = Headers() - - mock_aiohttp_session = Mock() - mock_get_forward_client.return_value = mock_aiohttp_session - mock_get_func = Mock() - mock_response_context = AsyncMock - mock_response = AsyncMock() - mock_response_context.call_function = lambda _: mock_response - mock_response.text.return_value = "12" - headers = MutableHeaders() - headers["Content-Type"] = "application/json" - mock_response.headers = headers - - async def mock_aenter(_) -> Mock: - return mock_response - - async def mock_aexit(_, __, ___, ____) -> None: - return - - mock_get_func.__aenter__ = mock_aenter - mock_get_func.__aexit__ = mock_aexit - mock_aiohttp_session.get.return_value = mock_get_func - - response = await forward_client.forward_request_remote( - request=mock_request, - primary_url="test-url", - response_type=int, - ) - - assert response == 12 +async def test_forward_request_with_basemodel_response(forward_client: ForwardClient, testcase: ContentTypeTestCase) -> None: + mock_request = Mock(spec=Request) + mock_request.method = "GET" + mock_request.headers = Headers() + + mock_get_func = Mock() + mock_response_context = AsyncMock + mock_response = AsyncMock() + mock_response_context.call_function = lambda _: mock_response + mock_response.text.return_value = '{"number":10,"string":"important"}' + headers = MutableHeaders() + headers["Content-Type"] = testcase.content_type + mock_response.headers = headers + + async def mock_aenter(_) -> Mock: + return mock_response + + async def mock_aexit(_, __, ___, ____) -> None: + return + + mock_get_func.__aenter__ = mock_aenter + mock_get_func.__aexit__ = mock_aexit + forward_client._forward_client.get.return_value = mock_get_func # pylint: disable=protected-access + + response = await forward_client.forward_request_remote( + request=mock_request, + primary_url="test-url", + response_type=TestResponse, + ) + + assert response == TestResponse(number=10, string="important") + + +async def test_forward_request_with_integer_list_response(forward_client: ForwardClient) -> None: + mock_request = Mock(spec=Request) + mock_request.method = "GET" + mock_request.headers = Headers() + + mock_get_func = Mock() + mock_response_context = AsyncMock + mock_response = AsyncMock() + mock_response_context.call_function = lambda _: mock_response + mock_response.text.return_value = "[1, 2, 3, 10]" + headers = MutableHeaders() + headers["Content-Type"] = "application/json" + mock_response.headers = headers + + async def mock_aenter(_) -> Mock: + return mock_response + + async def mock_aexit(_, __, ___, ____) -> None: + return + + mock_get_func.__aenter__ = mock_aenter + mock_get_func.__aexit__ = mock_aexit + forward_client._forward_client.get.return_value = mock_get_func # pylint: disable=protected-access + + response = await forward_client.forward_request_remote( + request=mock_request, + primary_url="test-url", + response_type=list[int], + ) + + assert response == [1, 2, 3, 10] + + +async def test_forward_request_with_integer_response(forward_client: ForwardClient) -> None: + mock_request = Mock(spec=Request) + mock_request.method = "GET" + mock_request.headers = Headers() + + mock_get_func = Mock() + mock_response_context = AsyncMock + mock_response = AsyncMock() + mock_response_context.call_function = lambda _: mock_response + mock_response.text.return_value = "12" + headers = MutableHeaders() + headers["Content-Type"] = "application/json" + mock_response.headers = headers + + async def mock_aenter(_) -> Mock: + return mock_response + + async def mock_aexit(_, __, ___, ____) -> None: + return + + mock_get_func.__aenter__ = mock_aenter + mock_get_func.__aexit__ = mock_aexit + forward_client._forward_client.get.return_value = mock_get_func # pylint: disable=protected-access + + response = await forward_client.forward_request_remote( + request=mock_request, + primary_url="test-url", + response_type=int, + ) + + assert response == 12