From 7ac2bef5c74ca587a46dc8a6ef78369945fd2cec Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Fri, 19 Apr 2024 13:00:20 +1000 Subject: [PATCH] refactor: deprecate CORSMiddleware from public interface. (#3404) * refactor: deprecate CORSMiddleware from public interface. It is used internally, its use is hardcoded and it doesn't work similar to other middleware when added to layers. Should be implementation detail IMO. Closes #3403 * fix: remove refs to CORSMiddleware in docstrings. * refactor: remove __all__ --- litestar/app.py | 4 +- litestar/middleware/_internal.py | 79 +++++++++++++++ litestar/middleware/cors.py | 97 ++++--------------- litestar/testing/helpers.py | 4 +- tests/unit/test_deprecations.py | 5 + .../test_middleware/test_cors_middleware.py | 2 +- 6 files changed, 106 insertions(+), 85 deletions(-) create mode 100644 litestar/middleware/_internal.py diff --git a/litestar/app.py b/litestar/app.py index b4826a9077..9877d48210 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -32,7 +32,7 @@ NoRouteMatchFoundException, ) from litestar.logging.config import LoggingConfig, get_logger_placeholder -from litestar.middleware.cors import CORSMiddleware +from litestar.middleware._internal import CORSMiddleware from litestar.openapi.config import OpenAPIConfig from litestar.plugins import ( CLIPluginProtocol, @@ -245,7 +245,7 @@ def __init__( this app. Can be overridden by route handlers. compression_config: Configures compression behaviour of the application, this enabled a builtin or user defined Compression middleware. - cors_config: If set, configures :class:`CORSMiddleware <.middleware.cors.CORSMiddleware>`. + cors_config: If set, configures CORS handling for the application. csrf_config: If set, configures :class:`CSRFMiddleware <.middleware.csrf.CSRFMiddleware>`. debug: If ``True``, app errors rendered as HTML with a stack trace. dependencies: A string keyed mapping of dependency :class:`Providers <.di.Provide>`. diff --git a/litestar/middleware/_internal.py b/litestar/middleware/_internal.py new file mode 100644 index 0000000000..393e5355d1 --- /dev/null +++ b/litestar/middleware/_internal.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar.datastructures import Headers, MutableScopeHeaders +from litestar.enums import ScopeType +from litestar.middleware.base import AbstractMiddleware + +if TYPE_CHECKING: + from litestar.config.cors import CORSConfig + from litestar.types import ASGIApp, Message, Receive, Scope, Send + + +class CORSMiddleware(AbstractMiddleware): + """CORS Middleware.""" + + def __init__(self, app: ASGIApp, config: CORSConfig) -> None: + """Middleware that adds CORS validation to the application. + + Args: + app: The ``next`` ASGI app to call. + config: An instance of :class:`CORSConfig ` + """ + super().__init__(app=app, scopes={ScopeType.HTTP}) + self.config = config + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ASGI callable. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + headers = Headers.from_scope(scope=scope) + if origin := headers.get("origin"): + await self.app(scope, receive, self.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers)) + else: + await self.app(scope, receive, send) + + def send_wrapper(self, send: Send, origin: str, has_cookie: bool) -> Send: + """Wrap ``send`` to ensure that state is not disconnected. + + Args: + has_cookie: Boolean flag dictating if the connection has a cookie set. + origin: The value of the ``Origin`` header. + send: The ASGI send function. + + Returns: + An ASGI send function. + """ + + async def wrapped_send(message: Message) -> None: + if message["type"] == "http.response.start": + message.setdefault("headers", []) + headers = MutableScopeHeaders.from_message(message=message) + headers.update(self.config.simple_headers) + + if (self.config.is_allow_all_origins and has_cookie) or ( + not self.config.is_allow_all_origins and self.config.is_origin_allowed(origin=origin) + ): + headers["Access-Control-Allow-Origin"] = origin + headers["Vary"] = "Origin" + + # We don't want to overwrite this for preflight requests. + allow_headers = headers.get("Access-Control-Allow-Headers") + if not allow_headers and self.config.allow_headers: + headers["Access-Control-Allow-Headers"] = ", ".join(sorted(set(self.config.allow_headers))) + + allow_methods = headers.get("Access-Control-Allow-Methods") + if not allow_methods and self.config.allow_methods: + headers["Access-Control-Allow-Methods"] = ", ".join(sorted(set(self.config.allow_methods))) + + await send(message) + + return wrapped_send diff --git a/litestar/middleware/cors.py b/litestar/middleware/cors.py index 010576aa6a..c323a4de6f 100644 --- a/litestar/middleware/cors.py +++ b/litestar/middleware/cors.py @@ -1,82 +1,19 @@ from __future__ import annotations -from typing import TYPE_CHECKING - -from litestar.datastructures import Headers, MutableScopeHeaders -from litestar.enums import ScopeType -from litestar.middleware.base import AbstractMiddleware - -__all__ = ("CORSMiddleware",) - - -if TYPE_CHECKING: - from litestar.config.cors import CORSConfig - from litestar.types import ASGIApp, Message, Receive, Scope, Send - - -class CORSMiddleware(AbstractMiddleware): - """CORS Middleware.""" - - def __init__(self, app: ASGIApp, config: CORSConfig) -> None: - """Middleware that adds CORS validation to the application. - - Args: - app: The ``next`` ASGI app to call. - config: An instance of :class:`CORSConfig ` - """ - super().__init__(app=app, scopes={ScopeType.HTTP}) - self.config = config - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """ASGI callable. - - Args: - scope: The ASGI connection scope. - receive: The ASGI receive function. - send: The ASGI send function. - - Returns: - None - """ - headers = Headers.from_scope(scope=scope) - if origin := headers.get("origin"): - await self.app(scope, receive, self.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers)) - else: - await self.app(scope, receive, send) - - def send_wrapper(self, send: Send, origin: str, has_cookie: bool) -> Send: - """Wrap ``send`` to ensure that state is not disconnected. - - Args: - has_cookie: Boolean flag dictating if the connection has a cookie set. - origin: The value of the ``Origin`` header. - send: The ASGI send function. - - Returns: - An ASGI send function. - """ - - async def wrapped_send(message: Message) -> None: - if message["type"] == "http.response.start": - message.setdefault("headers", []) - headers = MutableScopeHeaders.from_message(message=message) - headers.update(self.config.simple_headers) - - if (self.config.is_allow_all_origins and has_cookie) or ( - not self.config.is_allow_all_origins and self.config.is_origin_allowed(origin=origin) - ): - headers["Access-Control-Allow-Origin"] = origin - headers["Vary"] = "Origin" - - # We don't want to overwrite this for preflight requests. - allow_headers = headers.get("Access-Control-Allow-Headers") - if not allow_headers and self.config.allow_headers: - headers["Access-Control-Allow-Headers"] = ", ".join(sorted(set(self.config.allow_headers))) - - allow_methods = headers.get("Access-Control-Allow-Methods") - if not allow_methods and self.config.allow_methods: - headers["Access-Control-Allow-Methods"] = ", ".join(sorted(set(self.config.allow_methods))) - - await send(message) - - return wrapped_send +from typing import Any + +from litestar.middleware import _internal +from litestar.utils.deprecation import warn_deprecation + + +def __getattr__(name: str) -> Any: + if name == "CORSMiddleware": + warn_deprecation( + version="2.9", + deprecated_name=name, + kind="class", + removal_in="3.0", + info="CORS middleware has been removed from the public API.", + ) + return _internal.CORSMiddleware + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/litestar/testing/helpers.py b/litestar/testing/helpers.py index ab15609ffe..0635ef6531 100644 --- a/litestar/testing/helpers.py +++ b/litestar/testing/helpers.py @@ -169,7 +169,7 @@ def test_my_handler() -> None: this app. Can be overridden by route handlers. compression_config: Configures compression behaviour of the application, this enabled a builtin or user defined Compression middleware. - cors_config: If set, configures :class:`CORSMiddleware <.middleware.cors.CORSMiddleware>`. + cors_config: If set, configures CORS handling for the application. csrf_config: If set, configures :class:`CSRFMiddleware <.middleware.csrf.CSRFMiddleware>`. debug: If ``True``, app errors rendered as HTML with a stack trace. dependencies: A string keyed mapping of dependency :class:`Providers <.di.Provide>`. @@ -430,7 +430,7 @@ async def test_my_handler() -> None: this app. Can be overridden by route handlers. compression_config: Configures compression behaviour of the application, this enabled a builtin or user defined Compression middleware. - cors_config: If set, configures :class:`CORSMiddleware <.middleware.cors.CORSMiddleware>`. + cors_config: If set, configures CORS handling for the application. csrf_config: If set, configures :class:`CSRFMiddleware <.middleware.csrf.CSRFMiddleware>`. debug: If ``True``, app errors rendered as HTML with a stack trace. dependencies: A string keyed mapping of dependency :class:`Providers <.di.Provide>`. diff --git a/tests/unit/test_deprecations.py b/tests/unit/test_deprecations.py index acc1beaa75..f644394ad1 100644 --- a/tests/unit/test_deprecations.py +++ b/tests/unit/test_deprecations.py @@ -168,3 +168,8 @@ def test_openapi_config_enabled_endpoints_deprecation() -> None: with pytest.warns(DeprecationWarning): OpenAPIConfig(title="API", version="1.0", enabled_endpoints={"redoc"}) + + +def test_cors_middleware_public_interface_deprecation() -> None: + with pytest.warns(DeprecationWarning): + from litestar.middleware.cors import CORSMiddleware # noqa: F401 diff --git a/tests/unit/test_middleware/test_cors_middleware.py b/tests/unit/test_middleware/test_cors_middleware.py index 4b9b49deec..58dfee120d 100644 --- a/tests/unit/test_middleware/test_cors_middleware.py +++ b/tests/unit/test_middleware/test_cors_middleware.py @@ -4,7 +4,7 @@ from litestar import get from litestar.config.cors import CORSConfig -from litestar.middleware.cors import CORSMiddleware +from litestar.middleware._internal import CORSMiddleware from litestar.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND from litestar.testing import create_test_client from litestar.types.asgi_types import Method