diff --git a/litestar/response/redirect.py b/litestar/response/redirect.py index 6a070769e9..b768ed8ca9 100644 --- a/litestar/response/redirect.py +++ b/litestar/response/redirect.py @@ -1,9 +1,11 @@ from __future__ import annotations import itertools -from typing import TYPE_CHECKING, Any, Iterable, Literal +from typing import TYPE_CHECKING, Any, Iterable, Literal, Sequence +from urllib.parse import urlencode from litestar.constants import REDIRECT_ALLOWED_MEDIA_TYPES, REDIRECT_STATUS_CODES +from litestar.datastructures import MultiDict from litestar.enums import MediaType from litestar.exceptions import ImproperlyConfiguredException from litestar.response.base import ASGIResponse, Response @@ -13,6 +15,8 @@ from litestar.utils.helpers import get_enum_string_value if TYPE_CHECKING: + from collections.abc import Mapping + from litestar.app import Litestar from litestar.background_tasks import BackgroundTask, BackgroundTasks from litestar.connection import Request @@ -94,6 +98,7 @@ def __init__( media_type: str | MediaType | None = None, status_code: RedirectStatusType | None = None, type_encoders: TypeEncodersMap | None = None, + query_params: Mapping[str, str | Sequence[str]] | MultiDict | None = None, ) -> None: """Initialize the response. @@ -108,12 +113,21 @@ def __init__( status_code: An HTTP status code. The status code should be one of 301, 302, 303, 307 or 308, otherwise an exception will be raised. type_encoders: A mapping of types to callables that transform them into types supported for serialization. + query_params: A dictionary of values from which the request's query will be generated. Raises: ImproperlyConfiguredException: Either if status code is not a redirect status code or media type is not supported. """ - self.url = path + if query_params is None: + self.url = path + elif isinstance(query_params, MultiDict): + # We can't use MultiDictMixin.dict() because it's not deterministic + query_params_dict = {k: query_params.getall(k) for k in query_params} + self.url = f"{path}?{urlencode(query_params_dict, doseq=True)}" + else: + self.url = f"{path}?{urlencode(query_params, doseq=True)}" + if status_code is None: status_code = HTTP_302_FOUND super().__init__( diff --git a/tests/unit/test_response/test_redirect_response.py b/tests/unit/test_response/test_redirect_response.py index fc1919697a..4058a85904 100644 --- a/tests/unit/test_response/test_redirect_response.py +++ b/tests/unit/test_response/test_redirect_response.py @@ -4,11 +4,14 @@ their API. """ -from typing import TYPE_CHECKING, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING import pytest from litestar import get +from litestar.datastructures import MultiDict from litestar.exceptions import ImproperlyConfiguredException from litestar.response.base import ASGIResponse from litestar.response.redirect import ASGIRedirectResponse, Redirect @@ -20,7 +23,7 @@ def test_redirect_response() -> None: - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/": response = ASGIResponse(body=b"hello, world", media_type="text/plain") else: @@ -34,7 +37,7 @@ async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: def test_quoting_redirect_response() -> None: - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/test/": response = ASGIResponse(body=b"hello, world", media_type="text/plain") else: @@ -48,7 +51,7 @@ async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: def test_redirect_response_content_length_header() -> None: - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/": response = ASGIResponse(body=b"hello", media_type="text/plain") else: @@ -67,7 +70,7 @@ def test_redirect_response_status_validation() -> None: def test_redirect_response_html_media_type() -> None: - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/": response = ASGIResponse(body=b"hello") else: @@ -95,7 +98,7 @@ def test_redirect_response_media_type_validation() -> None: (308, 308), ], ) -def test_redirect_dynamic_status_code(status_code: Optional[int], expected_status_code: int) -> None: +def test_redirect_dynamic_status_code(status_code: int | None, expected_status_code: int) -> None: @get("/") def handler() -> Redirect: return Redirect(path="/something-else", status_code=status_code) # type: ignore[arg-type] @@ -107,8 +110,22 @@ def handler() -> Redirect: assert res.status_code == expected_status_code +@pytest.mark.parametrize( + "query_params", [{"single": "a", "list": ["b", "c"]}, MultiDict([("single", "a"), ("list", "b"), ("list", "c")])] +) +def test_redirect_with_query_params(query_params: dict[str, str | list[str]] | MultiDict) -> None: + @get("/") + def handler() -> Redirect: + return Redirect(path="/something-else", query_params=query_params) + + with create_test_client([handler]) as client: + location_header = client.get("/", follow_redirects=False).headers["location"] + expected = "/something-else?single=a&list=b&list=c" + assert location_header == expected + + @pytest.mark.parametrize("handler_status_code", [301, 307, None]) -def test_redirect(handler_status_code: Optional[int]) -> None: +def test_redirect(handler_status_code: int | None) -> None: @get("/", status_code=handler_status_code) def handler() -> Redirect: return Redirect(path="/something-else", status_code=handler_status_code) # type: ignore[arg-type]