Skip to content

Commit

Permalink
refactor: deprecate CORSMiddleware from public interface. (litestar-o…
Browse files Browse the repository at this point in the history
…rg#3404)

* refactor: deprecate CORSMiddleware from public interface.

It is used internally, its use is hardcoded and it doesn't work similar to other middleware when added to layers.

Should be implementation detail IMO.

Closes litestar-org#3403

* fix: remove refs to CORSMiddleware in docstrings.

* refactor: remove __all__
  • Loading branch information
peterschutt authored Apr 19, 2024
1 parent fb5f744 commit 7ac2bef
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 85 deletions.
4 changes: 2 additions & 2 deletions litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
NoRouteMatchFoundException,
)
from litestar.logging.config import LoggingConfig, get_logger_placeholder
from litestar.middleware.cors import CORSMiddleware
from litestar.middleware._internal import CORSMiddleware
from litestar.openapi.config import OpenAPIConfig
from litestar.plugins import (
CLIPluginProtocol,
Expand Down Expand Up @@ -245,7 +245,7 @@ def __init__(
this app. Can be overridden by route handlers.
compression_config: Configures compression behaviour of the application, this enabled a builtin or user
defined Compression middleware.
cors_config: If set, configures :class:`CORSMiddleware <.middleware.cors.CORSMiddleware>`.
cors_config: If set, configures CORS handling for the application.
csrf_config: If set, configures :class:`CSRFMiddleware <.middleware.csrf.CSRFMiddleware>`.
debug: If ``True``, app errors rendered as HTML with a stack trace.
dependencies: A string keyed mapping of dependency :class:`Providers <.di.Provide>`.
Expand Down
79 changes: 79 additions & 0 deletions litestar/middleware/_internal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from litestar.datastructures import Headers, MutableScopeHeaders
from litestar.enums import ScopeType
from litestar.middleware.base import AbstractMiddleware

if TYPE_CHECKING:
from litestar.config.cors import CORSConfig
from litestar.types import ASGIApp, Message, Receive, Scope, Send


class CORSMiddleware(AbstractMiddleware):
"""CORS Middleware."""

def __init__(self, app: ASGIApp, config: CORSConfig) -> None:
"""Middleware that adds CORS validation to the application.
Args:
app: The ``next`` ASGI app to call.
config: An instance of :class:`CORSConfig <litestar.config.cors.CORSConfig>`
"""
super().__init__(app=app, scopes={ScopeType.HTTP})
self.config = config

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""ASGI callable.
Args:
scope: The ASGI connection scope.
receive: The ASGI receive function.
send: The ASGI send function.
Returns:
None
"""
headers = Headers.from_scope(scope=scope)
if origin := headers.get("origin"):
await self.app(scope, receive, self.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers))
else:
await self.app(scope, receive, send)

def send_wrapper(self, send: Send, origin: str, has_cookie: bool) -> Send:
"""Wrap ``send`` to ensure that state is not disconnected.
Args:
has_cookie: Boolean flag dictating if the connection has a cookie set.
origin: The value of the ``Origin`` header.
send: The ASGI send function.
Returns:
An ASGI send function.
"""

async def wrapped_send(message: Message) -> None:
if message["type"] == "http.response.start":
message.setdefault("headers", [])
headers = MutableScopeHeaders.from_message(message=message)
headers.update(self.config.simple_headers)

if (self.config.is_allow_all_origins and has_cookie) or (
not self.config.is_allow_all_origins and self.config.is_origin_allowed(origin=origin)
):
headers["Access-Control-Allow-Origin"] = origin
headers["Vary"] = "Origin"

# We don't want to overwrite this for preflight requests.
allow_headers = headers.get("Access-Control-Allow-Headers")
if not allow_headers and self.config.allow_headers:
headers["Access-Control-Allow-Headers"] = ", ".join(sorted(set(self.config.allow_headers)))

allow_methods = headers.get("Access-Control-Allow-Methods")
if not allow_methods and self.config.allow_methods:
headers["Access-Control-Allow-Methods"] = ", ".join(sorted(set(self.config.allow_methods)))

await send(message)

return wrapped_send
97 changes: 17 additions & 80 deletions litestar/middleware/cors.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from litestar.datastructures import Headers, MutableScopeHeaders
from litestar.enums import ScopeType
from litestar.middleware.base import AbstractMiddleware

__all__ = ("CORSMiddleware",)


if TYPE_CHECKING:
from litestar.config.cors import CORSConfig
from litestar.types import ASGIApp, Message, Receive, Scope, Send


class CORSMiddleware(AbstractMiddleware):
"""CORS Middleware."""

def __init__(self, app: ASGIApp, config: CORSConfig) -> None:
"""Middleware that adds CORS validation to the application.
Args:
app: The ``next`` ASGI app to call.
config: An instance of :class:`CORSConfig <litestar.config.cors.CORSConfig>`
"""
super().__init__(app=app, scopes={ScopeType.HTTP})
self.config = config

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""ASGI callable.
Args:
scope: The ASGI connection scope.
receive: The ASGI receive function.
send: The ASGI send function.
Returns:
None
"""
headers = Headers.from_scope(scope=scope)
if origin := headers.get("origin"):
await self.app(scope, receive, self.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers))
else:
await self.app(scope, receive, send)

def send_wrapper(self, send: Send, origin: str, has_cookie: bool) -> Send:
"""Wrap ``send`` to ensure that state is not disconnected.
Args:
has_cookie: Boolean flag dictating if the connection has a cookie set.
origin: The value of the ``Origin`` header.
send: The ASGI send function.
Returns:
An ASGI send function.
"""

async def wrapped_send(message: Message) -> None:
if message["type"] == "http.response.start":
message.setdefault("headers", [])
headers = MutableScopeHeaders.from_message(message=message)
headers.update(self.config.simple_headers)

if (self.config.is_allow_all_origins and has_cookie) or (
not self.config.is_allow_all_origins and self.config.is_origin_allowed(origin=origin)
):
headers["Access-Control-Allow-Origin"] = origin
headers["Vary"] = "Origin"

# We don't want to overwrite this for preflight requests.
allow_headers = headers.get("Access-Control-Allow-Headers")
if not allow_headers and self.config.allow_headers:
headers["Access-Control-Allow-Headers"] = ", ".join(sorted(set(self.config.allow_headers)))

allow_methods = headers.get("Access-Control-Allow-Methods")
if not allow_methods and self.config.allow_methods:
headers["Access-Control-Allow-Methods"] = ", ".join(sorted(set(self.config.allow_methods)))

await send(message)

return wrapped_send
from typing import Any

from litestar.middleware import _internal
from litestar.utils.deprecation import warn_deprecation


def __getattr__(name: str) -> Any:
if name == "CORSMiddleware":
warn_deprecation(
version="2.9",
deprecated_name=name,
kind="class",
removal_in="3.0",
info="CORS middleware has been removed from the public API.",
)
return _internal.CORSMiddleware
raise AttributeError(f"module {__name__} has no attribute {name}")
4 changes: 2 additions & 2 deletions litestar/testing/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test_my_handler() -> None:
this app. Can be overridden by route handlers.
compression_config: Configures compression behaviour of the application, this enabled a builtin or user
defined Compression middleware.
cors_config: If set, configures :class:`CORSMiddleware <.middleware.cors.CORSMiddleware>`.
cors_config: If set, configures CORS handling for the application.
csrf_config: If set, configures :class:`CSRFMiddleware <.middleware.csrf.CSRFMiddleware>`.
debug: If ``True``, app errors rendered as HTML with a stack trace.
dependencies: A string keyed mapping of dependency :class:`Providers <.di.Provide>`.
Expand Down Expand Up @@ -430,7 +430,7 @@ async def test_my_handler() -> None:
this app. Can be overridden by route handlers.
compression_config: Configures compression behaviour of the application, this enabled a builtin or user
defined Compression middleware.
cors_config: If set, configures :class:`CORSMiddleware <.middleware.cors.CORSMiddleware>`.
cors_config: If set, configures CORS handling for the application.
csrf_config: If set, configures :class:`CSRFMiddleware <.middleware.csrf.CSRFMiddleware>`.
debug: If ``True``, app errors rendered as HTML with a stack trace.
dependencies: A string keyed mapping of dependency :class:`Providers <.di.Provide>`.
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/test_deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,8 @@ def test_openapi_config_enabled_endpoints_deprecation() -> None:

with pytest.warns(DeprecationWarning):
OpenAPIConfig(title="API", version="1.0", enabled_endpoints={"redoc"})


def test_cors_middleware_public_interface_deprecation() -> None:
with pytest.warns(DeprecationWarning):
from litestar.middleware.cors import CORSMiddleware # noqa: F401
2 changes: 1 addition & 1 deletion tests/unit/test_middleware/test_cors_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from litestar import get
from litestar.config.cors import CORSConfig
from litestar.middleware.cors import CORSMiddleware
from litestar.middleware._internal import CORSMiddleware
from litestar.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND
from litestar.testing import create_test_client
from litestar.types.asgi_types import Method
Expand Down

0 comments on commit 7ac2bef

Please sign in to comment.