From 7bf0a51f10f2c8254ecb43b0ba5da0685fc2f550 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Thu, 28 Nov 2024 18:57:39 +0100 Subject: [PATCH] feat(core): Streaming multipart parser (#3872) --- litestar/_kwargs/extractors.py | 4 +- litestar/_multipart.py | 169 +++++++----------- litestar/connection/request.py | 6 +- litestar/datastructures/upload_file.py | 2 +- pyproject.toml | 1 + tests/unit/test_kwargs/test_multipart_data.py | 27 ++- uv.lock | 11 ++ 7 files changed, 106 insertions(+), 114 deletions(-) diff --git a/litestar/_kwargs/extractors.py b/litestar/_kwargs/extractors.py index e95add74ae..657e3a2c4b 100644 --- a/litestar/_kwargs/extractors.py +++ b/litestar/_kwargs/extractors.py @@ -340,8 +340,8 @@ async def _extract_multipart( connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key] connection.scope["_form"] # type: ignore[typeddict-item] if "_form" in connection.scope - else parse_multipart_form( - body=await connection.body(), + else await parse_multipart_form( + stream=connection.stream(), boundary=connection.content_type[-1].get("boundary", "").encode(), multipart_form_part_limit=multipart_form_part_limit, type_decoders=connection.route_handler.resolve_type_decoders(), diff --git a/litestar/_multipart.py b/litestar/_multipart.py index 55b36208aa..430fa13043 100644 --- a/litestar/_multipart.py +++ b/litestar/_multipart.py @@ -1,41 +1,22 @@ -"""The contents of this file were adapted from sanic. - -MIT License - -Copyright (c) 2016-present Sanic Community - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -""" - from __future__ import annotations import re from collections import defaultdict -from email.utils import decode_rfc2231 -from typing import TYPE_CHECKING, Any -from urllib.parse import unquote +from typing import TYPE_CHECKING, Any, AsyncGenerator + +from multipart import ( # type: ignore[import-untyped] + MultipartSegment, + ParserError, + ParserLimitReached, + PushMultipartParser, +) from litestar.datastructures.upload_file import UploadFile -from litestar.exceptions import ValidationException +from litestar.exceptions import ClientException -__all__ = ("parse_body", "parse_content_header", "parse_multipart_form") +__all__ = ("parse_content_header", "parse_multipart_form") +from litestar.utils.compat import async_next if TYPE_CHECKING: from litestar.types import TypeDecodersSequence @@ -67,34 +48,8 @@ def parse_content_header(value: str) -> tuple[str, dict[str, str]]: return value.strip().lower(), options -def parse_body(body: bytes, boundary: bytes, multipart_form_part_limit: int) -> list[bytes]: - """Split the body using the boundary - and validate the number of form parts is within the allowed limit. - - Args: - body: The form body. - boundary: The boundary used to separate form components. - multipart_form_part_limit: The limit of allowed form components - - Returns: - A list of form components. - """ - if not (body and boundary): - return [] - - form_parts = body.split(boundary, multipart_form_part_limit + 3)[1:-1] - - if len(form_parts) > multipart_form_part_limit: - raise ValidationException( - f"number of multipart components exceeds the allowed limit of {multipart_form_part_limit}, " - f"this potentially indicates a DoS attack" - ) - - return form_parts - - -def parse_multipart_form( - body: bytes, +async def parse_multipart_form( # noqa: C901 + stream: AsyncGenerator[bytes, None], boundary: bytes, multipart_form_part_limit: int = 1000, type_decoders: TypeDecodersSequence | None = None, @@ -102,7 +57,7 @@ def parse_multipart_form( """Parse multipart form data. Args: - body: Body of the request. + stream: Body of the request. boundary: Boundary of the multipart message. multipart_form_part_limit: Limit of the number of parts allowed. type_decoders: A sequence of type decoders to use. @@ -113,51 +68,55 @@ def parse_multipart_form( fields: defaultdict[str, list[Any]] = defaultdict(list) - for form_part in parse_body(body=body, boundary=boundary, multipart_form_part_limit=multipart_form_part_limit): - file_name = None - content_type = "text/plain" - content_charset = "utf-8" - field_name = None - line_index = 2 - line_end_index = 0 - headers: list[tuple[str, str]] = [] - - while line_end_index != -1: - line_end_index = form_part.find(b"\r\n", line_index) - form_line = form_part[line_index:line_end_index].decode("utf-8") - - if not form_line: - break - - line_index = line_end_index + 2 - colon_index = form_line.index(":") - current_idx = colon_index + 2 - form_header_field = form_line[:colon_index].lower() - form_header_value, form_parameters = parse_content_header(form_line[current_idx:]) - - if form_header_field == "content-disposition": - field_name = form_parameters.get("name") - file_name = form_parameters.get("filename") - - if file_name is None and (filename_with_asterisk := form_parameters.get("filename*")): - encoding, _, value = decode_rfc2231(filename_with_asterisk) - file_name = unquote(value, encoding=encoding or content_charset) - - elif form_header_field == "content-type": - content_type = form_header_value - content_charset = form_parameters.get("charset", "utf-8") - headers.append((form_header_field, form_header_value)) - - if field_name: - post_data = form_part[line_index:-4].lstrip(b"\r\n") - if file_name: - form_file = UploadFile( - content_type=content_type, filename=file_name, file_data=post_data, headers=dict(headers) - ) - fields[field_name].append(form_file) - elif post_data: - fields[field_name].append(post_data.decode(content_charset)) - else: - fields[field_name].append(None) + chunk = await async_next(stream, b"") + if not chunk: + return fields + + try: + with PushMultipartParser(boundary, max_segment_count=multipart_form_part_limit) as parser: + segment: MultipartSegment | None = None + data: UploadFile | bytearray = bytearray() + while not parser.closed: + for form_part in parser.parse(chunk): + if isinstance(form_part, MultipartSegment): + segment = form_part + if segment.filename: + data = UploadFile( + content_type=segment.content_type or "text/plain", + filename=segment.filename, + headers=dict(segment.headerlist), + ) + elif form_part: + if isinstance(data, UploadFile): + await data.write(form_part) + else: + data.extend(form_part) + else: + # end of part + if segment is None: + # we have reached the end of a segment before we have + # received a complete header segment + raise ClientException("Unexpected eof in multipart/form-data") + + if isinstance(data, UploadFile): + await data.seek(0) + fields[segment.name].append(data) + elif data: + fields[segment.name].append(data.decode(segment.charset or "utf-8")) + else: + fields[segment.name].append(None) + + # reset for next part + data = bytearray() + segment = None + + chunk = await async_next(stream, b"") + + except ParserError as exc: + raise ClientException("Invalid multipart/form-data") from exc + except ParserLimitReached: + # FIXME (3.0): This should raise a '413 - Request Entity Too Large', but for + # backwards compatibility, we keep it as a 400 for now + raise ClientException("Request Entity Too Large") from None return {k: v if len(v) > 1 else v[0] for k, v in fields.items()} diff --git a/litestar/connection/request.py b/litestar/connection/request.py index d874586b7f..db16c4cdc7 100644 --- a/litestar/connection/request.py +++ b/litestar/connection/request.py @@ -77,7 +77,7 @@ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = """ super().__init__(scope, receive, send) self.is_connected: bool = True - self._body: bytes | EmptyType = Empty + self._body: bytes | EmptyType = self._connection_state.body self._form: FormMultiDict | EmptyType = Empty self._json: Any = Empty self._msgpack: Any = Empty @@ -264,8 +264,8 @@ async def form(self) -> FormMultiDict: if (form_data := self._connection_state.form) is Empty: content_type, options = self.content_type if content_type == RequestEncodingType.MULTI_PART: - form_data = parse_multipart_form( - body=await self.body(), + form_data = await parse_multipart_form( + stream=self.stream(), boundary=options.get("boundary", "").encode(), multipart_form_part_limit=self.app.multipart_form_part_limit, ) diff --git a/litestar/datastructures/upload_file.py b/litestar/datastructures/upload_file.py index 82b1da1ade..a1e5c7d6b4 100644 --- a/litestar/datastructures/upload_file.py +++ b/litestar/datastructures/upload_file.py @@ -49,7 +49,7 @@ def rolled_to_disk(self) -> bool: """ return getattr(self.file, "_rolled", False) - async def write(self, data: bytes) -> int: + async def write(self, data: bytes | bytearray) -> int: """Proxy for data writing. Args: diff --git a/pyproject.toml b/pyproject.toml index 2cf3bf128a..04194ab21e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "click", "rich>=13.0.0", "rich-click", + "multipart>=1.2.0", # default litestar plugins "litestar-htmx>=0.3.0" ] diff --git a/tests/unit/test_kwargs/test_multipart_data.py b/tests/unit/test_kwargs/test_multipart_data.py index fe67da01c5..9b2d295f6a 100644 --- a/tests/unit/test_kwargs/test_multipart_data.py +++ b/tests/unit/test_kwargs/test_multipart_data.py @@ -186,7 +186,11 @@ def test_multipart_request_multiple_files_with_headers(tmpdir: Any) -> None: "filename": "test2.txt", "content": "", "content_type": "text/plain", - "headers": [["content-disposition", "form-data"], ["x-custom", "f2"], ["content-type", "text/plain"]], + "headers": [ + ["content-disposition", 'form-data; name="test2"; filename="test2.txt"'], + ["x-custom", "f2"], + ["content-type", "text/plain"], + ], }, } @@ -292,6 +296,7 @@ def test_multipart_request_without_charset_for_filename() -> None: } +@pytest.mark.xfail(reason="filename* is deprecated and should not be used according to RFC-7578") def test_multipart_request_with_asterisks_filename() -> None: with create_test_client(form_handler) as client: response = client.post( @@ -456,13 +461,14 @@ async def hello_world(data: Optional[UploadFile] = Body(media_type=RequestEncodi @pytest.mark.parametrize("limit", (1000, 100, 10)) def test_multipart_form_part_limit(limit: int) -> None: @post("/", signature_types=[UploadFile]) - async def hello_world(data: List[UploadFile] = Body(media_type=RequestEncodingType.MULTI_PART)) -> None: - assert len(data) == limit + async def hello_world(data: List[UploadFile] = Body(media_type=RequestEncodingType.MULTI_PART)) -> dict: + return {"limit": len(data)} with create_test_client(route_handlers=[hello_world], multipart_form_part_limit=limit) as client: data = {str(i): "a" for i in range(limit)} response = client.post("/", files=data) assert response.status_code == HTTP_201_CREATED + assert response.json() == {"limit": limit} data = {str(i): "a" for i in range(limit)} data[str(limit + 1)] = "b" @@ -577,3 +583,18 @@ async def form_(request: Request, data: Annotated[AddProductFormMsgspec, Body(me headers={"Content-Type": "multipart/form-data; boundary=1f35df74046888ceaa62d8a534a076dd"}, ) assert response.status_code == HTTP_201_CREATED + + +def test_invalid_multipart_raises_client_error() -> None: + with create_test_client(form_handler) as client: + response = client.post( + "/form", + content=( + b"--20b303e711c4ab8c443184ac833ab00f\r\n" + b"Content-Disposition: form-data; " + b'name="value"\r\n\r\n' + b"--20b303e711c4ab8c44318833ab00f--\r\n" + ), + headers={"Content-Type": "multipart/form-data; charset=utf-8; boundary=20b303e711c4ab8c443184ac833ab00f"}, + ) + assert response.status_code == HTTP_400_BAD_REQUEST diff --git a/uv.lock b/uv.lock index 169d9dfbba..a5598faa54 100644 --- a/uv.lock +++ b/uv.lock @@ -1641,6 +1641,7 @@ dependencies = [ { name = "litestar-htmx" }, { name = "msgspec" }, { name = "multidict" }, + { name = "multipart" }, { name = "polyfactory" }, { name = "pyyaml" }, { name = "rich" }, @@ -1843,6 +1844,7 @@ requires-dist = [ { name = "minijinja", marker = "extra == 'minijinja'", specifier = ">=1.0.0" }, { name = "msgspec", specifier = ">=0.18.2" }, { name = "multidict", specifier = ">=6.0.2" }, + { name = "multipart", specifier = ">=1.2.0" }, { name = "opentelemetry-instrumentation-asgi", marker = "extra == 'opentelemetry'" }, { name = "piccolo", marker = "extra == 'piccolo'" }, { name = "picologging", marker = "extra == 'picologging'" }, @@ -2298,6 +2300,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/b7/b9e70fde2c0f0c9af4cc5277782a89b66d35948ea3369ec9f598358c3ac5/multidict-6.1.0-py3-none-any.whl", hash = "sha256:48e171e52d1c4d33888e529b999e5900356b9ae588c2f09a52dcefb158b27506", size = 10051 }, ] +[[package]] +name = "multipart" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/5a/66075b21cf7622958e724c482ab9e0c150fe057b20743ec215706bf5dbac/multipart-1.2.0.tar.gz", hash = "sha256:fc9ec7177b642e07c3c360126b9c845160376e65db6fcfd04df63d58135e1e8b", size = 35932 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/49/89bdcbefe6fdb59ef67e7d6d407c88acdce7ca6441b73945f949862081c6/multipart-1.2.0-py3-none-any.whl", hash = "sha256:79ecedd8ad13e1b4888224cedcc7caee6b784cb89337b50c83ad24a09c66be0f", size = 12915 }, +] + [[package]] name = "mypy" version = "1.13.0"