diff --git a/src/web/test/unit/application/test_create_app.py b/src/web/test/unit/application/test_create_app.py index 9916d68c..5b700ebd 100644 --- a/src/web/test/unit/application/test_create_app.py +++ b/src/web/test/unit/application/test_create_app.py @@ -18,7 +18,7 @@ from pydantic import BaseModel from pytest_mock import MockerFixture -from ..create_app import CreateApp, FlaskClientConfigurable, TFlaskClient +from ..create_app import ClientConfigurable, CreateApp, TFlaskClient class TestCreateApp(CreateApp): @@ -274,7 +274,7 @@ def test__configure_openapi__creates_flask_app_using_config( connexion_mock.assert_called_with(app_name, specification_dir=spec_path) def test__create_app__requires_flask_config( - self, flask_client_configurable: FlaskClientConfigurable[TFlaskClient] + self, flask_client_configurable: ClientConfigurable[TFlaskClient] ): with pytest.raises( Exception, diff --git a/src/web/test/unit/create_app.py b/src/web/test/unit/create_app.py index 7ca97b12..a9f0ce85 100644 --- a/src/web/test/unit/create_app.py +++ b/src/web/test/unit/create_app.py @@ -5,7 +5,16 @@ from contextlib import ExitStack from dataclasses import dataclass from functools import lru_cache -from typing import Any, Callable, Generator, Generic, Protocol, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Generic, + Protocol, + TypeVar, + cast, +) import json_logging import pytest @@ -21,7 +30,6 @@ ) from BL_Python.web.encryption import encrypt_flask_cookie from connexion import FlaskApp -from connexion.apps.abstract import TestClient as _TestClient from flask import Flask, Request, Response from flask.ctx import RequestContext from flask.sessions import SecureCookieSession @@ -30,10 +38,15 @@ from mock import MagicMock from pytest_mock import MockerFixture - -# In BL_Python.web `app` is always a Flask application -class TestClient(_TestClient): - app: Flask +# fmt: off +if TYPE_CHECKING: + from connexion.apps.abstract import TestClient as _TestClient # isort: skip + # In BL_Python.web `app` is always a Flask application + class TestClient(_TestClient): + app: Flask +else: + from connexion.apps.abstract import TestClient +# fmt: on TFlaskClient = FlaskClient | TestClient @@ -49,10 +62,9 @@ class FlaskClientInjector(Generic[T_flask_client]): client: T_flask_client injector: FlaskInjector - # connexion_app: FlaskApp | None = None -class FlaskAppGetter(Protocol, Generic[T_flask_app]): +class AppGetter(Protocol, Generic[T_flask_app]): """ A callable that instantiates a Flask application and returns the application with its IoC container. @@ -62,7 +74,7 @@ def __call__(self) -> FlaskAppInjector[T_flask_app]: ... -class FlaskClientConfigurable(Protocol, Generic[T_flask_client]): +class ClientConfigurable(Protocol, Generic[T_flask_client]): """ Get a Flask test client using the specified application configuration. @@ -72,14 +84,37 @@ class FlaskClientConfigurable(Protocol, Generic[T_flask_client]): Returns ------ - `FlaskClientInjector` + `FlaskClientInjector[T_flask_client]` """ def __call__(self, config: Config) -> FlaskClientInjector[T_flask_client]: ... -class FlaskRequestConfigurable(Protocol): +class OpenAPIClientConfigurable(ClientConfigurable[TestClient], Protocol): + """ + Get a Flask test client using the specified application configuration. + + Args + ------ + config: `Config` The custom application configuration used to instantiate the Flask app. + app_init_hook: `Callable[[FlaskAppInjector[FlaskApp]], None] | None = None` A method that is + called after the application is created, but before it is started. + + Returns + ------ + `FlaskClientInjector[TestClient]` + """ + + def __call__( # pyright: ignore[reportImplicitOverride] + self, + config: Config, + app_init_hook: Callable[[FlaskAppInjector[FlaskApp]], None] | None = None, + ) -> FlaskClientInjector[TestClient]: + ... + + +class RequestConfigurable(Protocol): """ Get a Flask request context, creating a Flask test client that uses the specified application configuration and, @@ -210,7 +245,7 @@ def _get_openapi_app( return next(self.__get_openapi_app(openapi_config, mocker)) def _flask_client( - self, flask_app_getter: FlaskAppGetter[Flask] + self, flask_app_getter: AppGetter[Flask] ) -> Generator[FlaskClientInjector[FlaskClient], Any, None]: with ExitStack() as stack: result = flask_app_getter() @@ -221,7 +256,7 @@ def _flask_client( client, FlaskClient ): # pyright: ignore[reportUnnecessaryIsInstance] raise Exception( - f"""This fixture created a `{type(client)}` test client, but is only meant for `{type(FlaskClient)}`. + f"""This fixture created a `{type(client)}` test client, but is only meant for `{FlaskClient}`. Ensure either that [openapi] is not set in the [flask] config, or use the `openapi_client` fixture.""" ) @@ -247,18 +282,16 @@ def _flask_client( yield FlaskClientInjector(client, result.injector) # , connexion_app) def _openapi_client( - self, flask_app_getter: FlaskAppGetter[FlaskApp] + self, flask_app_getter: AppGetter[FlaskApp] ) -> Generator[FlaskClientInjector[TestClient], Any, None]: with ExitStack() as stack: result = flask_app_getter() app = result.app client = stack.enter_context(app.test_client()) - if not isinstance( - client, TestClient - ): # pyright: ignore[reportUnnecessaryIsInstance] + if not isinstance(client, TestClient): raise Exception( - f"""This fixture created a `{type(client)}` test client, but is only meant for `{type(TestClient)}`. + f"""This fixture created a `{type(client)}` test client, but is only meant for `{TestClient}`. Ensure either that [openapi] is set in the [flask] config, or use the `flask_client` fixture.""" ) # client.cookies.set( @@ -284,10 +317,16 @@ def flask_client( ) -> FlaskClientInjector[FlaskClient]: return next(self._flask_client(lambda: _get_basic_flask_app)) + @pytest.fixture() + def openapi_client( + self, _get_basic_flask_app: FlaskAppInjector[FlaskApp] + ) -> FlaskClientInjector[TestClient]: + return next(self._openapi_client(lambda: _get_basic_flask_app)) + @pytest.fixture() def flask_client_configurable( self, mocker: MockerFixture - ) -> FlaskClientConfigurable[FlaskClient]: + ) -> ClientConfigurable[FlaskClient]: def _flask_client_getter(config: Config): return next( self._flask_client( @@ -297,6 +336,21 @@ def _flask_client_getter(config: Config): return _flask_client_getter + @pytest.fixture() + def openapi_client_configurable( + self, mocker: MockerFixture + ) -> OpenAPIClientConfigurable: + def _openapi_client_getter( + config: Config, + app_init_hook: Callable[[FlaskAppInjector[FlaskApp]], None] | None = None, + ): + application_result = next(self.__get_openapi_app(config, mocker)) + if app_init_hook is not None: + app_init_hook(application_result) + return next(self._openapi_client(lambda: application_result)) + + return _openapi_client_getter + def _flask_request( self, flask_client: FlaskClient, @@ -317,8 +371,8 @@ def flask_request( @pytest.fixture() def flask_request_configurable( self, - flask_client_configurable: FlaskClientConfigurable[FlaskClient], - ) -> FlaskRequestConfigurable: + flask_client_configurable: ClientConfigurable[FlaskClient], + ) -> RequestConfigurable: def _flask_request_getter( config: Config, request_context_args: dict[Any, Any] | None = None ): diff --git a/src/web/test/unit/middleware/test_api_response_handlers.py b/src/web/test/unit/middleware/test_api_response_handlers.py index 4511b4c6..fd1bf215 100644 --- a/src/web/test/unit/middleware/test_api_response_handlers.py +++ b/src/web/test/unit/middleware/test_api_response_handlers.py @@ -17,7 +17,7 @@ from pytest import LogCaptureFixture from pytest_mock import MockerFixture -from ..create_app import CreateApp, FlaskClientConfigurable, FlaskClientInjector +from ..create_app import ClientConfigurable, CreateApp, FlaskClientInjector class TestApiResponseHandlers(CreateApp): @@ -34,7 +34,7 @@ def test__register_api_response_handlers__binds_flask_before_request( def test__wrap_all_api_responses__sets_CSP_header( self, - flask_client_configurable: FlaskClientConfigurable[FlaskClient], + flask_client_configurable: ClientConfigurable[FlaskClient], basic_config: Config, ): csp_value = "default-src 'self' cdn.example.com;" @@ -71,7 +71,7 @@ def test__wrap_all_api_responses__sets_CORS_headers( header: str, value: str, config_attribute_name: str, - flask_client_configurable: FlaskClientConfigurable[FlaskClient], + flask_client_configurable: ClientConfigurable[FlaskClient], basic_config: Config, ): setattr(basic_config.web.security.cors, config_attribute_name, value) diff --git a/src/web/test/unit/middleware/test_middleware.py b/src/web/test/unit/middleware/test_middleware.py index bc6d75d0..2edd7ccf 100644 --- a/src/web/test/unit/middleware/test_middleware.py +++ b/src/web/test/unit/middleware/test_middleware.py @@ -5,6 +5,7 @@ import pytest import toml +from BL_Python.web.application import FlaskAppInjector from BL_Python.web.config import Config, FlaskConfig, FlaskOpenApiConfig from BL_Python.web.middleware import ( _get_correlation_id, # pyright: ignore[reportPrivateUsage] @@ -14,6 +15,7 @@ bind_errorhandler, bind_requesthandler, ) +from connexion import FlaskApp from flask import Flask, Response, abort from flask.testing import FlaskClient from mock import MagicMock @@ -21,9 +23,11 @@ from werkzeug.exceptions import BadRequest, HTTPException, Unauthorized from ..create_app import ( + ClientConfigurable, CreateApp, - FlaskClientConfigurable, - FlaskRequestConfigurable, + FlaskClientInjector, + OpenAPIClientConfigurable, + RequestConfigurable, TestClient, TFlaskClient, ) @@ -41,7 +45,7 @@ def test___register_api_response_handlers__sets_correlation_id_response_header_w self, config_type: str, format: Literal["plaintext", "JSON"], - flask_client_configurable: FlaskClientConfigurable[TFlaskClient], + flask_client_configurable: ClientConfigurable[TFlaskClient], basic_config: Config, ): basic_config.logging.format = format @@ -66,7 +70,7 @@ def test___register_api_response_handlers__sets_correlation_id_response_header_w self, config_type: str, format: Literal["plaintext", "JSON"], - flask_client_configurable: FlaskClientConfigurable[TFlaskClient], + flask_client_configurable: ClientConfigurable[TFlaskClient], basic_config: Config, ): basic_config.logging.format = format @@ -93,7 +97,7 @@ def test___get_correlation_id__validates_correlation_id_when_set_in_request_head self, config_type: str, format: Literal["plaintext", "JSON"], - flask_request_configurable: FlaskRequestConfigurable, + flask_request_configurable: RequestConfigurable, basic_config: Config, ): basic_config.logging.format = format @@ -121,7 +125,7 @@ def test___get_correlation_id__uses_existing_correlation_id_when_set_in_request_ self, config_type: str, format: Literal["plaintext", "JSON"], - flask_request_configurable: FlaskRequestConfigurable, + flask_request_configurable: RequestConfigurable, basic_config: Config, ): basic_config.logging.format = format @@ -147,7 +151,7 @@ def test___get_correlation_id__sets_correlation_id( self, config_type: str, format: Literal["plaintext", "JSON"], - flask_request_configurable: FlaskRequestConfigurable, + flask_request_configurable: RequestConfigurable, basic_config: Config, ): basic_config.logging.format = format @@ -165,7 +169,7 @@ def test___get_correlation_id__sets_correlation_id( def test__bind_requesthandler__returns_decorated_flask_request_hook( self, config_type: str, - flask_client_configurable: FlaskClientConfigurable[TFlaskClient], + flask_client_configurable: ClientConfigurable[TFlaskClient], basic_config: Config, ): flask_request_hook_mock = MagicMock() @@ -192,7 +196,7 @@ def test__bind_requesthandler__returns_decorated_flask_request_hook( def test__bind_requesthandler__calls_decorated_function_when_app_is_run( self, config_type: str, - flask_client_configurable: FlaskClientConfigurable[FlaskClient], + flask_client_configurable: ClientConfigurable[FlaskClient], basic_config: Config, ): if config_type == "openapi": @@ -226,7 +230,7 @@ def test__bind_errorhandler__binds_flask_errorhandler( self, code_or_exception: type[Exception] | int, config_type: str, - flask_client_configurable: FlaskClientConfigurable[TFlaskClient], + flask_client_configurable: ClientConfigurable[TFlaskClient], basic_config: Config, mocker: MockerFixture, ): @@ -261,7 +265,7 @@ def test__bind_errorhandler__from_Flask_calls_decorated_function_with_correct_er expected_exception_type: type[Exception], failure_lambda: Callable[[], Response], basic_config: Config, - flask_client_configurable: FlaskClientConfigurable[FlaskClient], + flask_client_configurable: ClientConfigurable[FlaskClient], ): flask_client = flask_client_configurable(basic_config) @@ -297,7 +301,7 @@ def test__bind_errorhandler__from_Connexion_calls_decorated_function_with_correc expected_exception_type: type[Exception], failure_lambda: Callable[[], Response], openapi_config: Config, - flask_client_configurable: FlaskClientConfigurable[TestClient], + openapi_client_configurable: OpenAPIClientConfigurable, mocker: MockerFixture, ): # fake_config_dict: AnyDict = { @@ -307,11 +311,13 @@ def test__bind_errorhandler__from_Connexion_calls_decorated_function_with_correc # The resolver in Connexion uses importlib to find operations # in the OpenAPI spec. Instead, just replace `import_module` # with this method as the return value. Connexion also - # requires that the `root` attribute exists. + # requires that the `root` attribute exists because that is + # the name of the OperationID in the fake OpenAPI spec. def fake_operation_method(): return "Hello" fake_operation_method.root = "/" + _ = mocker.patch( "connexion.utils.importlib", spec=importlib, @@ -323,16 +329,17 @@ def fake_operation_method(): return_value=toml.dumps(openapi_config.model_dump()), ) - # FIXME need a fixture to create a Connexion test client? - flask_client = flask_client_configurable(openapi_config) - application_errorhandler_mock = MagicMock() - _ = bind_errorhandler(flask_client.client.app, code_or_exception_type)( - application_errorhandler_mock - ) + + def app_init_hook(app: FlaskAppInjector[FlaskApp]): + _ = bind_errorhandler(app.app, code_or_exception_type)( + application_errorhandler_mock + ) + # _ = app.app.route("/")(failure_lambda) + + flask_client = openapi_client_configurable(openapi_config, app_init_hook) # this probably doesn't need to be done w/ connexion - _ = flask_client.route("/")(failure_lambda) _ = flask_client.client.get("/")