Skip to content

Commit

Permalink
fix(connection): Fix creation of FormMultiDict in Request.form to pro…
Browse files Browse the repository at this point in the history
…perly handle multi-keys (#3639)

Fix creation of FormMultiDict in Request.form to properly handle multi-keys
  • Loading branch information
provinzkraut authored Jul 27, 2024
1 parent 592b77d commit 18d84d8
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 23 deletions.
30 changes: 20 additions & 10 deletions litestar/connection/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send =
super().__init__(scope, receive, send)
self.is_connected: bool = True
self._body: bytes | EmptyType = Empty
self._form: dict[str, str | list[str]] | EmptyType = Empty
self._form: FormMultiDict | EmptyType = Empty
self._json: Any = Empty
self._msgpack: Any = Empty
self._content_type: tuple[str, dict[str, str]] | EmptyType = Empty
Expand Down Expand Up @@ -205,26 +205,36 @@ async def form(self) -> FormMultiDict:
A FormMultiDict instance
"""
if self._form is Empty:
if (form := self._connection_state.form) is not Empty:
self._form = form
else:
if (form_data := self._connection_state.form) is Empty:
content_type, options = self.content_type
if content_type == RequestEncodingType.MULTI_PART:
self._form = parse_multipart_form(
form_data = parse_multipart_form(
body=await self.body(),
boundary=options.get("boundary", "").encode(),
multipart_form_part_limit=self.app.multipart_form_part_limit,
)
elif content_type == RequestEncodingType.URL_ENCODED:
self._form = parse_url_encoded_form_data(
form_data = parse_url_encoded_form_data(
await self.body(),
)
else:
self._form = {}

self._connection_state.form = self._form
form_data = {}

self._connection_state.form = form_data

# form_data is a dict[str, list[str] | str | UploadFile]. Convert it to a
# list[tuple[str, str | UploadFile]] before passing it to FormMultiDict so
# multi-keys can be accessed properly
items = []
for k, v in form_data.items():
if isinstance(v, list):
for sv in v:
items.append((k, sv))
else:
items.append((k, v))
self._form = FormMultiDict(items)

return FormMultiDict(self._form)
return self._form

async def send_push_promise(self, path: str, raise_if_unavailable: bool = False) -> None:
"""Send a push promise.
Expand Down
20 changes: 19 additions & 1 deletion tests/unit/test_connection/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import pytest

from litestar import MediaType, Request, asgi, get
from litestar import MediaType, Request, asgi, get, post
from litestar.connection.base import empty_send
from litestar.datastructures import Address, Cookie
from litestar.exceptions import (
Expand Down Expand Up @@ -282,6 +282,24 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
assert response.json() == {"form": {"abc": "123 @"}}


def test_request_form_urlencoded_multi_keys() -> None:
@post("/")
async def handler(request: Request) -> Any:
return (await request.form()).getall("foo")

with create_test_client(handler) as client:
assert client.post("/", data={"foo": ["1", "2"]}).json() == ["1", "2"]


def test_request_form_multipart_multi_keys() -> None:
@post("/")
async def handler(request: Request) -> int:
return len((await request.form()).getall("foo"))

with create_test_client(handler) as client:
assert client.post("/", data={"foo": "1"}, files={"foo": b"a"}).json() == 2


def test_request_body_then_stream() -> None:
async def app(scope: Any, receive: Receive, send: Send) -> None:
request = Request[Any, Any, Any](scope, receive)
Expand Down
23 changes: 11 additions & 12 deletions tests/unit/test_kwargs/test_multipart_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,17 @@ async def form_multi_item_handler(request: Request) -> DefaultDict[str, list]:
data = await request.form()
output = defaultdict(list)
for key, value in data.multi_items():
for v in value:
if isinstance(v, UploadFile):
content = await v.read()
output[key].append(
{
"filename": v.filename,
"content": content.decode(),
"content_type": v.content_type,
}
)
else:
output[key].append(v)
if isinstance(value, UploadFile):
content = await value.read()
output[key].append(
{
"filename": value.filename,
"content": content.decode(),
"content_type": value.content_type,
}
)
else:
output[key].append(value)
return output


Expand Down

0 comments on commit 18d84d8

Please sign in to comment.