Skip to content

Commit

Permalink
fix: handle single file in cases of multiple file uploads (#2950)
Browse files Browse the repository at this point in the history
* fix: correctly parse files when uploading list of files

When a model is specified for uploading files and a specific field in that model is annotated with
`list[UploadFile]` and a single file is uploaded, the validation fails. This is because the extracted
file is not passed onto `msgspec` as a list. This fixes that by ensuring that the value is converted
into a list if needed before giving it to `msgspec` for validation.

* fix: handle case where name may not be in the form data

If there was no field name with the expected name (as per the given model), we don't want to raise a KeyError there.
Instead the missing field name should be raised as a validation error by the validation library.

* refactor: use API provided by FieldDefinition
  • Loading branch information
guacs authored Jan 5, 2024
1 parent e62cd3a commit 9e99ca5
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
13 changes: 11 additions & 2 deletions litestar/_kwargs/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from litestar.exceptions import ValidationException
from litestar.params import BodyKwarg
from litestar.types import Empty
from litestar.utils.predicates import is_non_string_sequence
from litestar.utils.scope.state import ScopeState

if TYPE_CHECKING:
Expand Down Expand Up @@ -355,7 +356,7 @@ async def extract_multipart(

if field_definition.is_non_string_sequence:
values = list(form_values.values())
if field_definition.inner_types[0].annotation is UploadFile and isinstance(values[0], list):
if field_definition.has_inner_subclass_of(UploadFile) and isinstance(values[0], list):
return values[0]

return values
Expand All @@ -366,7 +367,15 @@ async def extract_multipart(
if not form_values and is_data_optional:
return None

return data_dto(connection).decode_builtins(form_values) if data_dto else form_values
if data_dto:
return data_dto(connection).decode_builtins(form_values)

for name, tp in field_definition.get_type_hints().items():
value = form_values.get(name)
if value is not None and is_non_string_sequence(tp) and not isinstance(value, list):
form_values[name] = [value]

return form_values

return cast("Callable[[ASGIConnection[Any, Any, Any, Any]], Coroutine[Any, Any, Any]]", extract_multipart)

Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_kwargs/test_multipart_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,27 @@ async def handler(data: List[UploadFile] = Body(media_type=RequestEncodingType.M
assert response.status_code == HTTP_201_CREATED


@dataclass
class Files:
file_list: List[UploadFile]


@pytest.mark.parametrize("file_count", (1, 2))
def test_upload_multiple_files_in_model(file_count: int) -> None:
@post("/")
async def handler(data: Files = Body(media_type=RequestEncodingType.MULTI_PART)) -> None:
assert len(data.file_list) == file_count

for file in data.file_list:
assert await file.read() == b"1"

with create_test_client([handler]) as client:
files_to_upload = [("file_list", b"1") for _ in range(file_count)]
response = client.post("/", files=files_to_upload)

assert response.status_code == HTTP_201_CREATED


def test_optional_formdata() -> None:
@post("/", signature_types=[UploadFile])
async def hello_world(data: Optional[UploadFile] = Body(media_type=RequestEncodingType.MULTI_PART)) -> None:
Expand Down

0 comments on commit 9e99ca5

Please sign in to comment.