diff --git a/litestar/connection/request.py b/litestar/connection/request.py index 254c31561f..23c60f0b3c 100644 --- a/litestar/connection/request.py +++ b/litestar/connection/request.py @@ -74,7 +74,7 @@ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = super().__init__(scope, receive, send) self.is_connected: bool = True self._body: bytes | EmptyType = Empty - self._form: dict[str, str | list[str]] | EmptyType = Empty + self._form: FormMultiDict | EmptyType = Empty self._json: Any = Empty self._msgpack: Any = Empty self._content_type: tuple[str, dict[str, str]] | EmptyType = Empty @@ -205,26 +205,36 @@ async def form(self) -> FormMultiDict: A FormMultiDict instance """ if self._form is Empty: - if (form := self._connection_state.form) is not Empty: - self._form = form - else: + if (form_data := self._connection_state.form) is Empty: content_type, options = self.content_type if content_type == RequestEncodingType.MULTI_PART: - self._form = parse_multipart_form( + form_data = parse_multipart_form( body=await self.body(), boundary=options.get("boundary", "").encode(), multipart_form_part_limit=self.app.multipart_form_part_limit, ) elif content_type == RequestEncodingType.URL_ENCODED: - self._form = parse_url_encoded_form_data( + form_data = parse_url_encoded_form_data( await self.body(), ) else: - self._form = {} - - self._connection_state.form = self._form + form_data = {} + + self._connection_state.form = form_data + + # form_data is a dict[str, list[str] | str | UploadFile]. Convert it to a + # list[tuple[str, str | UploadFile]] before passing it to FormMultiDict so + # multi-keys can be accessed properly + items = [] + for k, v in form_data.items(): + if isinstance(v, list): + for sv in v: + items.append((k, sv)) + else: + items.append((k, v)) + self._form = FormMultiDict(items) - return FormMultiDict(self._form) + return self._form async def send_push_promise(self, path: str, raise_if_unavailable: bool = False) -> None: """Send a push promise. diff --git a/tests/unit/test_connection/test_request.py b/tests/unit/test_connection/test_request.py index 690e4c5afa..7393211647 100644 --- a/tests/unit/test_connection/test_request.py +++ b/tests/unit/test_connection/test_request.py @@ -11,7 +11,7 @@ import pytest -from litestar import MediaType, Request, asgi, get +from litestar import MediaType, Request, asgi, get, post from litestar.connection.base import empty_send from litestar.datastructures import Address, Cookie from litestar.exceptions import ( @@ -282,6 +282,24 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.json() == {"form": {"abc": "123 @"}} +def test_request_form_urlencoded_multi_keys() -> None: + @post("/") + async def handler(request: Request) -> Any: + return (await request.form()).getall("foo") + + with create_test_client(handler) as client: + assert client.post("/", data={"foo": ["1", "2"]}).json() == ["1", "2"] + + +def test_request_form_multipart_multi_keys() -> None: + @post("/") + async def handler(request: Request) -> int: + return len((await request.form()).getall("foo")) + + with create_test_client(handler) as client: + assert client.post("/", data={"foo": "1"}, files={"foo": b"a"}).json() == 2 + + def test_request_body_then_stream() -> None: async def app(scope: Any, receive: Receive, send: Send) -> None: request = Request[Any, Any, Any](scope, receive) diff --git a/tests/unit/test_kwargs/test_multipart_data.py b/tests/unit/test_kwargs/test_multipart_data.py index 1f0e15f3ff..fe67da01c5 100644 --- a/tests/unit/test_kwargs/test_multipart_data.py +++ b/tests/unit/test_kwargs/test_multipart_data.py @@ -51,18 +51,17 @@ async def form_multi_item_handler(request: Request) -> DefaultDict[str, list]: data = await request.form() output = defaultdict(list) for key, value in data.multi_items(): - for v in value: - if isinstance(v, UploadFile): - content = await v.read() - output[key].append( - { - "filename": v.filename, - "content": content.decode(), - "content_type": v.content_type, - } - ) - else: - output[key].append(v) + if isinstance(value, UploadFile): + content = await value.read() + output[key].append( + { + "filename": value.filename, + "content": content.decode(), + "content_type": value.content_type, + } + ) + else: + output[key].append(value) return output