From 0670551d739751d3d8daa876729ad831811e5d3e Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Sun, 28 Apr 2024 18:48:03 +1000 Subject: [PATCH] fix: clear session cookie if new session gt CHUNK_SIZE (#3446) * fix: clear session cookie if new session gt CHUNK_SIZE Fix an issue where the connection session cookie is not cleared if the response session is stored across multiple cookies. Closes #3441 * Update litestar/middleware/session/client_side.py Co-authored-by: Jacob Coffee * refactor: use dataclass utils to iterate over Cookie fields --------- Co-authored-by: Jacob Coffee --- litestar/middleware/session/client_side.py | 74 ++++++++++--------- .../test_session/test_client_side_backend.py | 26 ++++++- 2 files changed, 64 insertions(+), 36 deletions(-) diff --git a/litestar/middleware/session/client_side.py b/litestar/middleware/session/client_side.py index f709410478..ddde1f288e 100644 --- a/litestar/middleware/session/client_side.py +++ b/litestar/middleware/session/client_side.py @@ -5,9 +5,9 @@ import re import time from base64 import b64decode, b64encode -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from os import urandom -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Final, Literal, Mapping from litestar.datastructures import MutableScopeHeaders from litestar.datastructures.cookie import Cookie @@ -38,12 +38,19 @@ NONCE_SIZE = 12 CHUNK_SIZE = 4096 - 64 AAD = b"additional_authenticated_data=" +SET_COOKIE_INCLUDE = {f.name for f in fields(Cookie) if f.name not in {"key", "secret"}} +CLEAR_COOKIE_INCLUDE = {f.name for f in fields(Cookie) if f.name not in {"key", "secret", "max_age"}} class ClientSideSessionBackend(BaseSessionBackend["CookieBackendConfig"]): """Cookie backend for SessionMiddleware.""" - __slots__ = ("aesgcm", "cookie_re") + __slots__ = ( + "_clear_cookie_params", + "_set_cookie_params", + "aesgcm", + "cookie_re", + ) def __init__(self, config: CookieBackendConfig) -> None: """Initialize ``ClientSideSessionBackend``. @@ -54,6 +61,12 @@ def __init__(self, config: CookieBackendConfig) -> None: super().__init__(config) self.aesgcm = AESGCM(config.secret) self.cookie_re = re.compile(rf"{self.config.key}(?:-\d+)?") + self._set_cookie_params: Final[Mapping[str, Any]] = dict( + extract_dataclass_items(config, exclude_none=True, include=SET_COOKIE_INCLUDE) + ) + self._clear_cookie_params: Final[Mapping[str, Any]] = dict( + extract_dataclass_items(config, exclude_none=True, include=CLEAR_COOKIE_INCLUDE) + ) def dump_data(self, data: Any, scope: Scope | None = None) -> list[bytes]: """Given serializable data, including pydantic models and numpy types, dump it into a bytes string, encrypt, @@ -107,19 +120,25 @@ def get_cookie_keys(self, connection: ASGIConnection) -> list[str]: """ return sorted(key for key in connection.cookies if self.cookie_re.fullmatch(key)) - def _create_session_cookies(self, data: list[bytes], cookie_params: dict[str, Any] | None = None) -> list[Cookie]: + def get_cookie_key_set(self, connection: ASGIConnection) -> set[str]: + """Return a set of cookie-keys from the connection if they match the session-cookie pattern. + + .. versionadded:: 2.8.3 + + Args: + connection: An ASGIConnection instance + + Returns: + A set of session-cookie keys + """ + return {key for key in connection.cookies if self.cookie_re.fullmatch(key)} + + def _create_session_cookies(self, data: list[bytes]) -> list[Cookie]: """Create a list of cookies containing the session data. If the data is split into multiple cookies, the key will be of the format ``session-{segment number}``, however if only one cookie is needed, the key will be ``session``. """ - if cookie_params is None: - cookie_params = dict( - extract_dataclass_items( - self.config, - exclude_none=True, - include={f for f in Cookie.__dict__ if f not in ("key", "secret")}, - ) - ) + cookie_params = self._set_cookie_params if len(data) == 1: return [ @@ -156,38 +175,23 @@ async def store_in_message(self, scope_session: ScopeSession, message: Message, scope = connection.scope headers = MutableScopeHeaders.from_message(message) - cookie_keys = self.get_cookie_keys(connection) + connection_cookies = self.get_cookie_key_set(connection) + response_cookies: set[str] = set() if scope_session and scope_session is not Empty: data = self.dump_data(scope_session, scope=scope) - cookie_params = dict( - extract_dataclass_items( - self.config, - exclude_none=True, - include={f for f in Cookie.__dict__ if f not in ("key", "secret")}, - ) - ) - for cookie in self._create_session_cookies(data, cookie_params): + for cookie in self._create_session_cookies(data): headers.add("Set-Cookie", cookie.to_header(header="")) - # Cookies with the same key overwrite the earlier cookie with that key. To expire earlier session - # cookies, first check how many session cookies will not be overwritten in this upcoming response. - # If leftover cookies are greater than or equal to 1, that means older session cookies have to be - # expired and their names are in cookie_keys. - cookies_to_clear = cookie_keys[len(data) :] if len(cookie_keys) - len(data) > 0 else [] + response_cookies.add(cookie.key) + + cookies_to_clear = connection_cookies - response_cookies else: - cookies_to_clear = cookie_keys + cookies_to_clear = connection_cookies for cookie_key in cookies_to_clear: - cookie_params = dict( - extract_dataclass_items( - self.config, - exclude_none=True, - include={f for f in Cookie.__dict__ if f not in ("key", "secret", "max_age")}, - ) - ) headers.add( "Set-Cookie", - Cookie(value="null", key=cookie_key, expires=0, **cookie_params).to_header(header=""), + Cookie(value="null", key=cookie_key, expires=0, **self._clear_cookie_params).to_header(header=""), ) async def load_from_connection(self, connection: ASGIConnection) -> dict[str, Any]: diff --git a/tests/unit/test_middleware/test_session/test_client_side_backend.py b/tests/unit/test_middleware/test_session/test_client_side_backend.py index 782fa9f561..8d1a51aaa0 100644 --- a/tests/unit/test_middleware/test_session/test_client_side_backend.py +++ b/tests/unit/test_middleware/test_session/test_client_side_backend.py @@ -9,6 +9,7 @@ from cryptography.exceptions import InvalidTag from litestar import Request, get, post +from litestar.datastructures.headers import MutableScopeHeaders from litestar.exceptions import ImproperlyConfiguredException from litestar.middleware.session import SessionMiddleware from litestar.middleware.session.client_side import ( @@ -18,7 +19,8 @@ CookieBackendConfig, ) from litestar.serialization import encode_json -from litestar.testing import create_test_client +from litestar.testing import RequestFactory, create_test_client +from litestar.types.asgi_types import HTTPResponseStartEvent from tests.helpers import randbytes @@ -220,3 +222,25 @@ def test_load_data_should_raise_invalid_tag_if_tampered_aad(cookie_session_backe with pytest.raises(InvalidTag): cookie_session_backend.load_data(encoded) + + +async def test_store_in_message_clears_cookies_when_session_grows_gt_chunk_size( + cookie_session_backend: ClientSideSessionBackend, +) -> None: + """Should clear the cookies when the session grows larger than the chunk size.""" + # we have a connection that already contains a cookie header with the "session" key in it + connection = RequestFactory().get("/", headers={"Cookie": "session=foo"}) + # we want to persist a new session that is larger than the chunk size + # by the time the encrypted data, nonce and associated data are b64 encoded, the size of + # this session will be > 2x larger than the chunk size + session = create_session(size=CHUNK_SIZE) + message: HTTPResponseStartEvent = {"type": "http.response.start", "status": 200, "headers": []} + await cookie_session_backend.store_in_message(session, message, connection) + # due to the large session stored in multiple chunks, we now enumerate the name of the cookies + # e.g., session-0, session-1, session-2, etc. This means we need to have a cookie with the name + # "session" in the response headers that is set to null to clear the original cookie. + headers = MutableScopeHeaders.from_message(message) + assert len(headers.headers) > 1 + header_name, header_content = headers.headers[-1] + assert header_name == b"set-cookie" + assert header_content.startswith(b"session=null;")