From 9e99ca569140861084af258a14ede58ebd54010a Mon Sep 17 00:00:00 2001 From: guacs <126393040+guacs@users.noreply.github.com> Date: Fri, 5 Jan 2024 20:15:21 +0530 Subject: [PATCH] fix: handle single file in cases of multiple file uploads (#2950) * fix: correctly parse files when uploading list of files When a model is specified for uploading files and a specific field in that model is annotated with `list[UploadFile]` and a single file is uploaded, the validation fails. This is because the extracted file is not passed onto `msgspec` as a list. This fixes that by ensuring that the value is converted into a list if needed before giving it to `msgspec` for validation. * fix: handle case where name may not be in the form data If there was no field name with the expected name (as per the given model), we don't want to raise a KeyError there. Instead the missing field name should be raised as a validation error by the validation library. * refactor: use API provided by FieldDefinition --- litestar/_kwargs/extractors.py | 13 ++++++++++-- tests/unit/test_kwargs/test_multipart_data.py | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/litestar/_kwargs/extractors.py b/litestar/_kwargs/extractors.py index 4c2b8d19ff..9fd1ccf25a 100644 --- a/litestar/_kwargs/extractors.py +++ b/litestar/_kwargs/extractors.py @@ -15,6 +15,7 @@ from litestar.exceptions import ValidationException from litestar.params import BodyKwarg from litestar.types import Empty +from litestar.utils.predicates import is_non_string_sequence from litestar.utils.scope.state import ScopeState if TYPE_CHECKING: @@ -355,7 +356,7 @@ async def extract_multipart( if field_definition.is_non_string_sequence: values = list(form_values.values()) - if field_definition.inner_types[0].annotation is UploadFile and isinstance(values[0], list): + if field_definition.has_inner_subclass_of(UploadFile) and isinstance(values[0], list): return values[0] return values @@ -366,7 +367,15 @@ async def extract_multipart( if not form_values and is_data_optional: return None - return data_dto(connection).decode_builtins(form_values) if data_dto else form_values + if data_dto: + return data_dto(connection).decode_builtins(form_values) + + for name, tp in field_definition.get_type_hints().items(): + value = form_values.get(name) + if value is not None and is_non_string_sequence(tp) and not isinstance(value, list): + form_values[name] = [value] + + return form_values return cast("Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_multipart) diff --git a/tests/unit/test_kwargs/test_multipart_data.py b/tests/unit/test_kwargs/test_multipart_data.py index 2cfbc5962e..94b707db03 100644 --- a/tests/unit/test_kwargs/test_multipart_data.py +++ b/tests/unit/test_kwargs/test_multipart_data.py @@ -410,6 +410,27 @@ async def handler(data: List[UploadFile] = Body(media_type=RequestEncodingType.M assert response.status_code == HTTP_201_CREATED +@dataclass +class Files: + file_list: List[UploadFile] + + +@pytest.mark.parametrize("file_count", (1, 2)) +def test_upload_multiple_files_in_model(file_count: int) -> None: + @post("/") + async def handler(data: Files = Body(media_type=RequestEncodingType.MULTI_PART)) -> None: + assert len(data.file_list) == file_count + + for file in data.file_list: + assert await file.read() == b"1" + + with create_test_client([handler]) as client: + files_to_upload = [("file_list", b"1") for _ in range(file_count)] + response = client.post("/", files=files_to_upload) + + assert response.status_code == HTTP_201_CREATED + + def test_optional_formdata() -> None: @post("/", signature_types=[UploadFile]) async def hello_world(data: Optional[UploadFile] = Body(media_type=RequestEncodingType.MULTI_PART)) -> None: