Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add query params to redirect response #3901

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions litestar/response/redirect.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Any, Iterable, Literal
from typing import TYPE_CHECKING, Any, Iterable, Literal, Sequence
from urllib.parse import urlencode

from litestar.constants import REDIRECT_ALLOWED_MEDIA_TYPES, REDIRECT_STATUS_CODES
from litestar.datastructures import MultiDict
from litestar.enums import MediaType
from litestar.exceptions import ImproperlyConfiguredException
from litestar.response.base import ASGIResponse, Response
Expand All @@ -13,6 +15,8 @@
from litestar.utils.helpers import get_enum_string_value

if TYPE_CHECKING:
from collections.abc import Mapping

from litestar.app import Litestar
from litestar.background_tasks import BackgroundTask, BackgroundTasks
from litestar.connection import Request
Expand Down Expand Up @@ -94,6 +98,7 @@ def __init__(
media_type: str | MediaType | None = None,
status_code: RedirectStatusType | None = None,
type_encoders: TypeEncodersMap | None = None,
query_params: Mapping[str, str | Sequence[str]] | MultiDict | None = None,
) -> None:
"""Initialize the response.

Expand All @@ -108,12 +113,21 @@ def __init__(
status_code: An HTTP status code. The status code should be one of 301, 302, 303, 307 or 308,
otherwise an exception will be raised.
type_encoders: A mapping of types to callables that transform them into types supported for serialization.
query_params: A dictionary of values from which the request's query will be generated.

Raises:
ImproperlyConfiguredException: Either if status code is not a redirect status code or media type is not
supported.
"""
self.url = path
if query_params is None:
self.url = path
elif isinstance(query_params, MultiDict):
# We can't use MultiDictMixin.dict() because it's not deterministic
query_params_dict = {k: query_params.getall(k) for k in query_params}
self.url = f"{path}?{urlencode(query_params_dict, doseq=True)}"
else:
self.url = f"{path}?{urlencode(query_params, doseq=True)}"

if status_code is None:
status_code = HTTP_302_FOUND
super().__init__(
Expand Down
31 changes: 24 additions & 7 deletions tests/unit/test_response/test_redirect_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
their API.
"""

from typing import TYPE_CHECKING, Optional
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from litestar import get
from litestar.datastructures import MultiDict
from litestar.exceptions import ImproperlyConfiguredException
from litestar.response.base import ASGIResponse
from litestar.response.redirect import ASGIRedirectResponse, Redirect
Expand All @@ -20,7 +23,7 @@


def test_redirect_response() -> None:
async def app(scope: "Scope", receive: "Receive", send: "Send") -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/":
response = ASGIResponse(body=b"hello, world", media_type="text/plain")
else:
Expand All @@ -34,7 +37,7 @@ async def app(scope: "Scope", receive: "Receive", send: "Send") -> None:


def test_quoting_redirect_response() -> None:
async def app(scope: "Scope", receive: "Receive", send: "Send") -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/test/":
response = ASGIResponse(body=b"hello, world", media_type="text/plain")
else:
Expand All @@ -48,7 +51,7 @@ async def app(scope: "Scope", receive: "Receive", send: "Send") -> None:


def test_redirect_response_content_length_header() -> None:
async def app(scope: "Scope", receive: "Receive", send: "Send") -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/":
response = ASGIResponse(body=b"hello", media_type="text/plain")
else:
Expand All @@ -67,7 +70,7 @@ def test_redirect_response_status_validation() -> None:


def test_redirect_response_html_media_type() -> None:
async def app(scope: "Scope", receive: "Receive", send: "Send") -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/":
response = ASGIResponse(body=b"hello")
else:
Expand Down Expand Up @@ -95,7 +98,7 @@ def test_redirect_response_media_type_validation() -> None:
(308, 308),
],
)
def test_redirect_dynamic_status_code(status_code: Optional[int], expected_status_code: int) -> None:
def test_redirect_dynamic_status_code(status_code: int | None, expected_status_code: int) -> None:
@get("/")
def handler() -> Redirect:
return Redirect(path="/something-else", status_code=status_code) # type: ignore[arg-type]
Expand All @@ -107,8 +110,22 @@ def handler() -> Redirect:
assert res.status_code == expected_status_code


@pytest.mark.parametrize(
"query_params", [{"single": "a", "list": ["b", "c"]}, MultiDict([("single", "a"), ("list", "b"), ("list", "c")])]
)
def test_redirect_with_query_params(query_params: dict[str, str | list[str]] | MultiDict) -> None:
@get("/")
def handler() -> Redirect:
return Redirect(path="/something-else", query_params=query_params)

with create_test_client([handler]) as client:
location_header = client.get("/", follow_redirects=False).headers["location"]
expected = "/something-else?single=a&list=b&list=c"
assert location_header == expected


@pytest.mark.parametrize("handler_status_code", [301, 307, None])
def test_redirect(handler_status_code: Optional[int]) -> None:
def test_redirect(handler_status_code: int | None) -> None:
@get("/", status_code=handler_status_code)
def handler() -> Redirect:
return Redirect(path="/something-else", status_code=handler_status_code) # type: ignore[arg-type]
Expand Down
Loading