Skip to content

Commit

Permalink
fix: clear session cookie if new session gt CHUNK_SIZE (#3446)
Browse files Browse the repository at this point in the history
* fix: clear session cookie if new session gt CHUNK_SIZE

Fix an issue where the connection session cookie is not cleared if the response session is stored across multiple cookies.

Closes #3441

* Update litestar/middleware/session/client_side.py

Co-authored-by: Jacob Coffee <[email protected]>

* refactor: use dataclass utils to iterate over Cookie fields

---------

Co-authored-by: Jacob Coffee <[email protected]>
  • Loading branch information
peterschutt and JacobCoffee authored Apr 28, 2024
1 parent f06e951 commit 0670551
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 36 deletions.
74 changes: 39 additions & 35 deletions litestar/middleware/session/client_side.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import re
import time
from base64 import b64decode, b64encode
from dataclasses import dataclass, field
from dataclasses import dataclass, field, fields
from os import urandom
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Final, Literal, Mapping

from litestar.datastructures import MutableScopeHeaders
from litestar.datastructures.cookie import Cookie
Expand Down Expand Up @@ -38,12 +38,19 @@
NONCE_SIZE = 12
CHUNK_SIZE = 4096 - 64
AAD = b"additional_authenticated_data="
SET_COOKIE_INCLUDE = {f.name for f in fields(Cookie) if f.name not in {"key", "secret"}}
CLEAR_COOKIE_INCLUDE = {f.name for f in fields(Cookie) if f.name not in {"key", "secret", "max_age"}}


class ClientSideSessionBackend(BaseSessionBackend["CookieBackendConfig"]):
"""Cookie backend for SessionMiddleware."""

__slots__ = ("aesgcm", "cookie_re")
__slots__ = (
"_clear_cookie_params",
"_set_cookie_params",
"aesgcm",
"cookie_re",
)

def __init__(self, config: CookieBackendConfig) -> None:
"""Initialize ``ClientSideSessionBackend``.
Expand All @@ -54,6 +61,12 @@ def __init__(self, config: CookieBackendConfig) -> None:
super().__init__(config)
self.aesgcm = AESGCM(config.secret)
self.cookie_re = re.compile(rf"{self.config.key}(?:-\d+)?")
self._set_cookie_params: Final[Mapping[str, Any]] = dict(
extract_dataclass_items(config, exclude_none=True, include=SET_COOKIE_INCLUDE)
)
self._clear_cookie_params: Final[Mapping[str, Any]] = dict(
extract_dataclass_items(config, exclude_none=True, include=CLEAR_COOKIE_INCLUDE)
)

def dump_data(self, data: Any, scope: Scope | None = None) -> list[bytes]:
"""Given serializable data, including pydantic models and numpy types, dump it into a bytes string, encrypt,
Expand Down Expand Up @@ -107,19 +120,25 @@ def get_cookie_keys(self, connection: ASGIConnection) -> list[str]:
"""
return sorted(key for key in connection.cookies if self.cookie_re.fullmatch(key))

def _create_session_cookies(self, data: list[bytes], cookie_params: dict[str, Any] | None = None) -> list[Cookie]:
def get_cookie_key_set(self, connection: ASGIConnection) -> set[str]:
"""Return a set of cookie-keys from the connection if they match the session-cookie pattern.
.. versionadded:: 2.8.3
Args:
connection: An ASGIConnection instance
Returns:
A set of session-cookie keys
"""
return {key for key in connection.cookies if self.cookie_re.fullmatch(key)}

def _create_session_cookies(self, data: list[bytes]) -> list[Cookie]:
"""Create a list of cookies containing the session data.
If the data is split into multiple cookies, the key will be of the format ``session-{segment number}``,
however if only one cookie is needed, the key will be ``session``.
"""
if cookie_params is None:
cookie_params = dict(
extract_dataclass_items(
self.config,
exclude_none=True,
include={f for f in Cookie.__dict__ if f not in ("key", "secret")},
)
)
cookie_params = self._set_cookie_params

if len(data) == 1:
return [
Expand Down Expand Up @@ -156,38 +175,23 @@ async def store_in_message(self, scope_session: ScopeSession, message: Message,

scope = connection.scope
headers = MutableScopeHeaders.from_message(message)
cookie_keys = self.get_cookie_keys(connection)
connection_cookies = self.get_cookie_key_set(connection)
response_cookies: set[str] = set()

if scope_session and scope_session is not Empty:
data = self.dump_data(scope_session, scope=scope)
cookie_params = dict(
extract_dataclass_items(
self.config,
exclude_none=True,
include={f for f in Cookie.__dict__ if f not in ("key", "secret")},
)
)
for cookie in self._create_session_cookies(data, cookie_params):
for cookie in self._create_session_cookies(data):
headers.add("Set-Cookie", cookie.to_header(header=""))
# Cookies with the same key overwrite the earlier cookie with that key. To expire earlier session
# cookies, first check how many session cookies will not be overwritten in this upcoming response.
# If leftover cookies are greater than or equal to 1, that means older session cookies have to be
# expired and their names are in cookie_keys.
cookies_to_clear = cookie_keys[len(data) :] if len(cookie_keys) - len(data) > 0 else []
response_cookies.add(cookie.key)

cookies_to_clear = connection_cookies - response_cookies
else:
cookies_to_clear = cookie_keys
cookies_to_clear = connection_cookies

for cookie_key in cookies_to_clear:
cookie_params = dict(
extract_dataclass_items(
self.config,
exclude_none=True,
include={f for f in Cookie.__dict__ if f not in ("key", "secret", "max_age")},
)
)
headers.add(
"Set-Cookie",
Cookie(value="null", key=cookie_key, expires=0, **cookie_params).to_header(header=""),
Cookie(value="null", key=cookie_key, expires=0, **self._clear_cookie_params).to_header(header=""),
)

async def load_from_connection(self, connection: ASGIConnection) -> dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from cryptography.exceptions import InvalidTag

from litestar import Request, get, post
from litestar.datastructures.headers import MutableScopeHeaders
from litestar.exceptions import ImproperlyConfiguredException
from litestar.middleware.session import SessionMiddleware
from litestar.middleware.session.client_side import (
Expand All @@ -18,7 +19,8 @@
CookieBackendConfig,
)
from litestar.serialization import encode_json
from litestar.testing import create_test_client
from litestar.testing import RequestFactory, create_test_client
from litestar.types.asgi_types import HTTPResponseStartEvent
from tests.helpers import randbytes


Expand Down Expand Up @@ -220,3 +222,25 @@ def test_load_data_should_raise_invalid_tag_if_tampered_aad(cookie_session_backe

with pytest.raises(InvalidTag):
cookie_session_backend.load_data(encoded)


async def test_store_in_message_clears_cookies_when_session_grows_gt_chunk_size(
cookie_session_backend: ClientSideSessionBackend,
) -> None:
"""Should clear the cookies when the session grows larger than the chunk size."""
# we have a connection that already contains a cookie header with the "session" key in it
connection = RequestFactory().get("/", headers={"Cookie": "session=foo"})
# we want to persist a new session that is larger than the chunk size
# by the time the encrypted data, nonce and associated data are b64 encoded, the size of
# this session will be > 2x larger than the chunk size
session = create_session(size=CHUNK_SIZE)
message: HTTPResponseStartEvent = {"type": "http.response.start", "status": 200, "headers": []}
await cookie_session_backend.store_in_message(session, message, connection)
# due to the large session stored in multiple chunks, we now enumerate the name of the cookies
# e.g., session-0, session-1, session-2, etc. This means we need to have a cookie with the name
# "session" in the response headers that is set to null to clear the original cookie.
headers = MutableScopeHeaders.from_message(message)
assert len(headers.headers) > 1
header_name, header_content = headers.headers[-1]
assert header_name == b"set-cookie"
assert header_content.startswith(b"session=null;")

0 comments on commit 0670551

Please sign in to comment.