From 56eaeeb4fbe49962d0c45e8a3b2bb15014fec78d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Sat, 21 Oct 2023 12:44:44 +0200 Subject: [PATCH] test: Remove Pydantic as default model from tests (#2458) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Remove Pydantic as default model from tests --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --- litestar/testing/request_factory.py | 5 +- tests/__init__.py | 105 ----------------- .../e2e/test_routing/test_path_resolution.py | 11 +- tests/models.py | 100 ++++++++++++++++ tests/unit/test_app.py | 16 +-- tests/unit/test_contrib/test_jwt/test_auth.py | 7 +- .../test_pydantic/test_integration.py | 13 +++ tests/unit/test_controller.py | 22 ++-- .../unit/test_dto/test_factory/test_utils.py | 14 +-- .../test_http_handlers/test_media_type.py | 4 +- .../test_http_handlers/test_validations.py | 4 +- .../test_reserved_kwargs_injection.py | 55 +++------ tests/unit/test_openapi/conftest.py | 50 ++++---- tests/unit/test_openapi/test_responses.py | 39 +++---- tests/unit/test_openapi/test_schema.py | 12 +- .../unit/test_openapi/test_spec_generation.py | 12 +- tests/unit/test_pagination.py | 70 +++++------ .../test_response_to_asgi_response.py | 12 +- .../unit/test_response/test_serialization.py | 110 +++++++++++------- tests/unit/test_security/test_session_auth.py | 6 +- .../unit/test_testing/test_request_factory.py | 35 +++--- tests/unit/test_utils/test_typing.py | 42 +++---- 22 files changed, 368 insertions(+), 376 deletions(-) create mode 100644 tests/models.py diff --git a/litestar/testing/request_factory.py b/litestar/testing/request_factory.py index 2b72b8b52a..6e294dd173 100644 --- a/litestar/testing/request_factory.py +++ b/litestar/testing/request_factory.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from dataclasses import asdict from typing import TYPE_CHECKING, Any, cast from urllib.parse import urlencode @@ -288,9 +289,7 @@ def _create_request_with_data( data = attrs_as_dict(data) # type: ignore[arg-type] elif is_pydantic_model_instance(data): - from litestar.contrib.pydantic import _model_dump - - data = _model_dump(data) + data = data.model_dump(mode="json") if hasattr(data, "model_dump") else json.loads(data.json()) if request_media_type == RequestEncodingType.JSON: encoding_headers, stream = httpx_encode_json(data) diff --git a/tests/__init__.py b/tests/__init__.py index a7f36af8f3..e69de29bb2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,105 +0,0 @@ -from dataclasses import dataclass as vanilla_dataclass -from enum import Enum -from typing import Dict, List, Optional -from uuid import UUID - -import attrs -import msgspec -from polyfactory.factories.pydantic_factory import ModelFactory -from pydantic import BaseModel, Field -from pydantic.dataclasses import dataclass as pydantic_dataclass -from typing_extensions import NotRequired, Required, TypedDict - -from litestar.contrib.pydantic import PydanticDTO -from litestar.dto import DTOConfig - - -class Species(str, Enum): - DOG = "Dog" - CAT = "Cat" - MONKEY = "Monkey" - PIG = "Pig" - - -class PydanticPet(BaseModel): - name: str - species: Species = Field(default=Species.MONKEY) - age: float - - -class PydanticPerson(BaseModel): - first_name: str - last_name: str - id: str - optional: Optional[str] - complex: Dict[str, List[Dict[str, str]]] - pets: Optional[List[PydanticPet]] = None - - -class PydanticPersonFactory(ModelFactory[PydanticPerson]): - __model__ = PydanticPerson - - -class PydanticPetFactory(ModelFactory[PydanticPet]): - __model__ = PydanticPet - - -class PartialPersonDTO(PydanticDTO[PydanticPerson]): - config = DTOConfig(partial=True) - - -@vanilla_dataclass -class VanillaDataClassPerson: - first_name: str - last_name: str - id: str - optional: Optional[str] - complex: Dict[str, List[Dict[str, str]]] - pets: Optional[List[PydanticPet]] = None - - -@pydantic_dataclass -class PydanticDataClassPerson: - first_name: str - last_name: str - id: str - optional: Optional[str] - complex: Dict[str, List[Dict[str, str]]] - pets: Optional[List[PydanticPet]] = None - - -class TypedDictPerson(TypedDict): - first_name: Required[str] - last_name: Required[str] - id: Required[str] - optional: NotRequired[Optional[str]] - complex: Required[Dict[str, List[Dict[str, str]]]] - pets: NotRequired[Optional[List[PydanticPet]]] - - -@attrs.define -class AttrsPerson: - first_name: str - last_name: str - id: str - optional: Optional[str] - complex: Dict[str, List[Dict[str, str]]] - pets: Optional[List[PydanticPet]] - - -class MsgSpecStructPerson(msgspec.Struct): - first_name: str - last_name: str - id: str - optional: Optional[str] - complex: Dict[str, List[Dict[str, str]]] - pets: Optional[List[PydanticPet]] - - -class User(BaseModel): - name: str - id: UUID - - -class UserFactory(ModelFactory[User]): - __model__ = User diff --git a/tests/e2e/test_routing/test_path_resolution.py b/tests/e2e/test_routing/test_path_resolution.py index 7b91374847..0d5afebc00 100644 --- a/tests/e2e/test_routing/test_path_resolution.py +++ b/tests/e2e/test_routing/test_path_resolution.py @@ -4,7 +4,6 @@ import pytest from litestar import Controller, MediaType, Router, delete, get, post -from litestar.contrib.pydantic import _model_dump from litestar.status_codes import ( HTTP_200_OK, HTTP_204_NO_CONTENT, @@ -12,7 +11,6 @@ HTTP_405_METHOD_NOT_ALLOWED, ) from litestar.testing import create_test_client -from tests import PydanticPerson, PydanticPersonFactory @delete(sync_to_thread=False) @@ -94,19 +92,16 @@ def mixed_params(path_param: int, value: int) -> str: def test_root_route_handler( decorator: Type[get], test_path: str, decorator_path: str, delete_handler: Optional[Callable] ) -> None: - person_instance = PydanticPersonFactory.build() - class MyController(Controller): path = test_path @decorator(path=decorator_path) - def test_method(self) -> PydanticPerson: - return person_instance + def test_method(self) -> str: + return "hello" with create_test_client([MyController, delete_handler] if delete_handler else MyController) as client: response = client.get(decorator_path or test_path) - assert response.status_code == HTTP_200_OK, response.json() - assert response.json() == _model_dump(person_instance) + assert response.status_code == HTTP_200_OK if delete_handler: delete_response = client.delete("/") assert delete_response.status_code == HTTP_204_NO_CONTENT diff --git a/tests/models.py b/tests/models.py new file mode 100644 index 0000000000..fc8ec4402f --- /dev/null +++ b/tests/models.py @@ -0,0 +1,100 @@ +import dataclasses +from enum import Enum +from typing import Dict, List, Optional +from uuid import UUID + +import attrs +import msgspec +from polyfactory.factories import DataclassFactory +from pydantic import BaseModel +from pydantic.dataclasses import dataclass as pydantic_dataclass +from typing_extensions import NotRequired, Required, TypedDict + + +class Species(str, Enum): + DOG = "Dog" + CAT = "Cat" + MONKEY = "Monkey" + PIG = "Pig" + + +@dataclasses.dataclass +class DataclassPet: + name: str + age: float + species: Species = Species.MONKEY + + +@dataclasses.dataclass +class DataclassPerson: + first_name: str + last_name: str + id: str + optional: Optional[str] + complex: Dict[str, List[Dict[str, str]]] + pets: Optional[List[DataclassPet]] = None + + +@pydantic_dataclass +class PydanticDataclassPerson: + first_name: str + last_name: str + id: str + optional: Optional[str] + complex: Dict[str, List[Dict[str, str]]] + pets: Optional[List[DataclassPet]] = None + + +class TypedDictPerson(TypedDict): + first_name: Required[str] + last_name: Required[str] + id: Required[str] + optional: NotRequired[Optional[str]] + complex: Required[Dict[str, List[Dict[str, str]]]] + pets: NotRequired[Optional[List[DataclassPet]]] + + +class PydanticPerson(BaseModel): + first_name: str + last_name: str + id: str + optional: Optional[str] + complex: Dict[str, List[Dict[str, str]]] + pets: Optional[List[DataclassPet]] = None + + +@attrs.define +class AttrsPerson: + first_name: str + last_name: str + id: str + optional: Optional[str] + complex: Dict[str, List[Dict[str, str]]] + pets: Optional[List[DataclassPet]] + + +class MsgSpecStructPerson(msgspec.Struct): + first_name: str + last_name: str + id: str + optional: Optional[str] + complex: Dict[str, List[Dict[str, str]]] + pets: Optional[List[DataclassPet]] + + +@dataclasses.dataclass +class User: + name: str + id: UUID + + +class UserFactory(DataclassFactory[User]): + __model__ = User + + +class DataclassPersonFactory(DataclassFactory[DataclassPerson]): + __model__ = DataclassPerson + + +class DataclassPetFactory(DataclassFactory[DataclassPet]): + __model__ = DataclassPet diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index 735d565ec7..daaa8e18de 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -9,10 +9,9 @@ import pytest from click import Group -from pydantic import VERSION from pytest import MonkeyPatch -from litestar import Litestar, MediaType, Request, Response, get, post +from litestar import Litestar, MediaType, Request, Response, get from litestar.config.app import AppConfig from litestar.config.response_cache import ResponseCacheConfig from litestar.contrib.sqlalchemy.plugins import SQLAlchemySerializationPlugin @@ -29,7 +28,6 @@ from litestar.router import Router from litestar.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from litestar.testing import TestClient, create_test_client -from tests import PydanticPerson if TYPE_CHECKING: from typing import Dict @@ -244,18 +242,6 @@ def on_startup(app: Litestar) -> None: assert response.headers.get("My Header") == "value injected during send" -def test_default_handling_of_pydantic_errors() -> None: - @post("/{param:int}") - def my_route_handler(param: int, data: PydanticPerson) -> None: - ... - - with create_test_client(my_route_handler) as client: - response = client.post("/123", json={"first_name": "moishe"}) - extra = response.json().get("extra") - assert extra is not None - assert 3 if len(extra) == VERSION.startswith("1") else 4 - - def test_using_custom_http_exception_handler() -> None: @get("/{param:int}") def my_route_handler(param: int) -> None: diff --git a/tests/unit/test_contrib/test_jwt/test_auth.py b/tests/unit/test_contrib/test_jwt/test_auth.py index c78ef8d3d0..f65ba3ff8b 100644 --- a/tests/unit/test_contrib/test_jwt/test_auth.py +++ b/tests/unit/test_contrib/test_jwt/test_auth.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional from uuid import uuid4 +import msgspec import pytest from hypothesis import given, settings from hypothesis.strategies import dictionaries, integers, none, one_of, sampled_from, text, timedeltas @@ -14,7 +15,7 @@ from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_401_UNAUTHORIZED from litestar.stores.memory import MemoryStore from litestar.testing import create_test_client -from tests import User, UserFactory +from tests.models import User, UserFactory if TYPE_CHECKING: from litestar.connection import ASGIConnection @@ -75,7 +76,7 @@ async def retrieve_user_handler(token: Token, _: "ASGIConnection") -> Any: @get("/my-endpoint", middleware=[jwt_auth.middleware]) def my_handler(request: Request["User", Token, Any]) -> None: assert request.user - assert _model_dump(request.user) == _model_dump(user) + assert msgspec.to_builtins(request.user) == msgspec.to_builtins(user) assert request.auth.sub == str(user.id) @get("/login") @@ -183,7 +184,7 @@ async def retrieve_user_handler(token: Token, connection: Any) -> Any: @get("/my-endpoint", middleware=[jwt_auth.middleware]) def my_handler(request: Request["User", Token, Any]) -> None: assert request.user - assert _model_dump(request.user) == _model_dump(user) + assert msgspec.to_builtins(request.user) == msgspec.to_builtins(user) assert request.auth.sub == str(user.id) @get("/login") diff --git a/tests/unit/test_contrib/test_pydantic/test_integration.py b/tests/unit/test_contrib/test_pydantic/test_integration.py index c75b434500..fd11117c96 100644 --- a/tests/unit/test_contrib/test_pydantic/test_integration.py +++ b/tests/unit/test_contrib/test_pydantic/test_integration.py @@ -7,6 +7,7 @@ from litestar import post from litestar.contrib.pydantic.pydantic_dto_factory import PydanticDTO from litestar.testing import create_test_client +from tests.models import PydanticPerson def test_pydantic_validation_error_raises_400() -> None: @@ -54,3 +55,15 @@ def handler(data: Model) -> Model: extra[0].pop("url") assert extra == expected_errors + + +def test_default_handling_of_pydantic_errors() -> None: + @post("/{param:int}") + def my_route_handler(param: int, data: PydanticPerson) -> None: + ... + + with create_test_client(my_route_handler) as client: + response = client.post("/123", json={"first_name": "moishe"}) + extra = response.json().get("extra") + assert extra is not None + assert 3 if len(extra) == VERSION.startswith("1") else 4 diff --git a/tests/unit/test_controller.py b/tests/unit/test_controller.py index 638632cc01..a2d008683a 100644 --- a/tests/unit/test_controller.py +++ b/tests/unit/test_controller.py @@ -1,5 +1,6 @@ from typing import Any, Type, Union +import msgspec import pytest from pydantic import BaseModel @@ -16,11 +17,10 @@ websocket, ) from litestar.connection import WebSocket -from litestar.contrib.pydantic import _model_dump from litestar.exceptions import ImproperlyConfiguredException from litestar.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT from litestar.testing import create_test_client -from tests import PydanticPerson, PydanticPersonFactory +from tests.models import DataclassPerson, DataclassPersonFactory @pytest.mark.parametrize( @@ -30,13 +30,13 @@ get, HttpMethod.GET, HTTP_200_OK, - Response(content=PydanticPersonFactory.build()), - Response[PydanticPerson], + Response(content=DataclassPersonFactory.build()), + Response[DataclassPerson], ), - (get, HttpMethod.GET, HTTP_200_OK, PydanticPersonFactory.build(), PydanticPerson), - (post, HttpMethod.POST, HTTP_201_CREATED, PydanticPersonFactory.build(), PydanticPerson), - (put, HttpMethod.PUT, HTTP_200_OK, PydanticPersonFactory.build(), PydanticPerson), - (patch, HttpMethod.PATCH, HTTP_200_OK, PydanticPersonFactory.build(), PydanticPerson), + (get, HttpMethod.GET, HTTP_200_OK, DataclassPersonFactory.build(), DataclassPerson), + (post, HttpMethod.POST, HTTP_201_CREATED, DataclassPersonFactory.build(), DataclassPerson), + (put, HttpMethod.PUT, HTTP_200_OK, DataclassPersonFactory.build(), DataclassPerson), + (patch, HttpMethod.PATCH, HTTP_200_OK, DataclassPersonFactory.build(), DataclassPerson), (delete, HttpMethod.DELETE, HTTP_204_NO_CONTENT, None, None), ], ) @@ -60,7 +60,7 @@ def test_method(self) -> return_annotation: response = client.request(http_method, test_path) assert response.status_code == expected_status_code if return_value and isinstance(return_value, BaseModel): - assert response.json() == _model_dump(return_value) + assert response.json() == msgspec.to_builtins(return_value) def test_controller_with_websocket_handler() -> None: @@ -70,8 +70,8 @@ class MyController(Controller): path = test_path @get() - def get_person(self) -> PydanticPerson: - return PydanticPersonFactory.build() + def get_person(self) -> DataclassPerson: + return DataclassPersonFactory.build() @websocket(path="/socket") async def ws(self, socket: WebSocket) -> None: diff --git a/tests/unit/test_dto/test_factory/test_utils.py b/tests/unit/test_dto/test_factory/test_utils.py index a7595d2a96..90dddc82d2 100644 --- a/tests/unit/test_dto/test_factory/test_utils.py +++ b/tests/unit/test_dto/test_factory/test_utils.py @@ -2,26 +2,24 @@ from litestar.dto import DataclassDTO from litestar.typing import FieldDefinition -from tests import VanillaDataClassPerson +from tests.models import DataclassPerson T = TypeVar("T") def test_resolve_model_type_optional() -> None: field_definition = FieldDefinition.from_annotation(Optional[int]) - assert DataclassDTO[VanillaDataClassPerson].resolve_model_type(field_definition) == FieldDefinition.from_annotation( - int - ) + assert DataclassDTO[DataclassPerson].resolve_model_type(field_definition) == FieldDefinition.from_annotation(int) def test_resolve_generic_wrapper_type_no_origin() -> None: field_definition = FieldDefinition.from_annotation(int) - assert DataclassDTO[VanillaDataClassPerson].resolve_generic_wrapper_type(field_definition) is None + assert DataclassDTO[DataclassPerson].resolve_generic_wrapper_type(field_definition) is None def test_resolve_generic_wrapper_type_origin_no_parameters() -> None: field_definition = FieldDefinition.from_annotation(List[int]) - assert DataclassDTO[VanillaDataClassPerson].resolve_generic_wrapper_type(field_definition) is None + assert DataclassDTO[DataclassPerson].resolve_generic_wrapper_type(field_definition) is None def test_resolve_generic_wrapper_type_model_type_not_subtype_of_specialized_type() -> None: @@ -30,7 +28,7 @@ class Wrapper(Generic[T]): field_definition = FieldDefinition.from_annotation(Wrapper[int]) - assert DataclassDTO[VanillaDataClassPerson].resolve_generic_wrapper_type(field_definition) is None + assert DataclassDTO[DataclassPerson].resolve_generic_wrapper_type(field_definition) is None def test_resolve_generic_wrapper_type_type_var_not_attribute() -> None: @@ -40,4 +38,4 @@ def returns_t(self) -> T: # type:ignore[empty-body] field_definition = FieldDefinition.from_annotation(Wrapper[int]) - assert DataclassDTO[VanillaDataClassPerson].resolve_generic_wrapper_type(field_definition) is None + assert DataclassDTO[DataclassPerson].resolve_generic_wrapper_type(field_definition) is None diff --git a/tests/unit/test_handlers/test_http_handlers/test_media_type.py b/tests/unit/test_handlers/test_http_handlers/test_media_type.py index 3b27df0845..52a4874128 100644 --- a/tests/unit/test_handlers/test_http_handlers/test_media_type.py +++ b/tests/unit/test_handlers/test_http_handlers/test_media_type.py @@ -5,7 +5,7 @@ from pydantic.types import PaymentCardBrand from litestar import Litestar, MediaType, get -from tests import PydanticPerson +from tests.models import DataclassPerson class MyEnum(Enum): @@ -26,7 +26,7 @@ class MyBytes(bytes): (PaymentCardBrand, MediaType.TEXT), (MyEnum, MediaType.JSON), (dict, MediaType.JSON), - (PydanticPerson, MediaType.JSON), + (DataclassPerson, MediaType.JSON), ), ) def test_media_type_inference(annotation: Any, expected_media_type: MediaType) -> None: diff --git a/tests/unit/test_handlers/test_http_handlers/test_validations.py b/tests/unit/test_handlers/test_http_handlers/test_validations.py index 673a111534..5ee119d62d 100644 --- a/tests/unit/test_handlers/test_http_handlers/test_validations.py +++ b/tests/unit/test_handlers/test_http_handlers/test_validations.py @@ -13,7 +13,7 @@ HTTP_304_NOT_MODIFIED, HTTP_307_TEMPORARY_REDIRECT, ) -from tests import PydanticPerson +from tests.models import DataclassPerson def test_route_handler_validation_http_method() -> None: @@ -104,7 +104,7 @@ def test_function_1(socket: WebSocket) -> None: with pytest.raises(ImproperlyConfiguredException): @get("/person") - def test_function_2(self, data: PydanticPerson) -> None: # type: ignore + def test_function_2(self, data: DataclassPerson) -> None: # type: ignore return None Litestar(route_handlers=[test_function_2]) diff --git a/tests/unit/test_kwargs/test_reserved_kwargs_injection.py b/tests/unit/test_kwargs/test_reserved_kwargs_injection.py index f5a1442dea..7bba51cc3d 100644 --- a/tests/unit/test_kwargs/test_reserved_kwargs_injection.py +++ b/tests/unit/test_kwargs/test_reserved_kwargs_injection.py @@ -1,9 +1,7 @@ from typing import Any, List, Optional, Type, cast -import pydantic +import msgspec.json import pytest -from polyfactory.factories.pydantic_factory import ModelFactory -from pydantic import BaseModel, Field from litestar import ( Controller, @@ -17,7 +15,6 @@ post, put, ) -from litestar.contrib.pydantic import _model_dump from litestar.datastructures.state import ImmutableState, State from litestar.exceptions import ImproperlyConfiguredException from litestar.status_codes import ( @@ -28,7 +25,7 @@ ) from litestar.testing import create_test_client from litestar.types import Scope -from tests import PydanticPerson, PydanticPersonFactory +from tests.models import DataclassPerson, DataclassPersonFactory class CustomState(State): @@ -66,19 +63,7 @@ def route_handler(state: state_typing) -> str: # type: ignore assert client.app.state.called -class QueryParams(BaseModel): - first: str - second: List[str] = ( - Field(min_items=3) if pydantic.VERSION.startswith("1") else Field(min_length=1) # pyright: ignore - ) - third: Optional[int] - - -class QueryParamsFactory(ModelFactory): - __model__ = QueryParams - - -person_instance = PydanticPersonFactory.build() +person_instance = DataclassPersonFactory.build() @pytest.mark.parametrize( @@ -97,11 +82,11 @@ class MyController(Controller): path = test_path @decorator() - def test_method(self, data: PydanticPerson) -> None: + def test_method(self, data: DataclassPerson) -> None: assert data == person_instance with create_test_client(MyController) as client: - response = client.request(http_method, test_path, json=_model_dump(person_instance)) + response = client.request(http_method, test_path, json=msgspec.to_builtins(person_instance)) assert response.status_code == expected_status_code @@ -117,17 +102,17 @@ def test_method(self, data: PydanticPerson) -> None: def test_data_using_list_of_models(decorator: Any, http_method: Any, expected_status_code: Any) -> None: test_path = "/person" - people = PydanticPersonFactory.batch(size=5) + people = DataclassPersonFactory.batch(size=5) class MyController(Controller): path = test_path @decorator() - def test_method(self, data: List[PydanticPerson]) -> None: + def test_method(self, data: List[DataclassPerson]) -> None: assert data == people with create_test_client(MyController) as client: - response = client.request(http_method, test_path, json=[_model_dump(p) for p in people]) + response = client.request(http_method, test_path, json=msgspec.to_builtins(people)) assert response.status_code == expected_status_code @@ -178,21 +163,15 @@ def test_method(self, person_id: str) -> None: ], ) def test_query_params(decorator: Any, http_method: Any, expected_status_code: Any) -> None: - test_path = "/person" - - query_params_instance = QueryParamsFactory.build() - - class MyController(Controller): - path = test_path - - @decorator() - def test_method(self, first: str, second: List[str], third: Optional[int] = None) -> None: - assert first == query_params_instance.first - assert second == query_params_instance.second - assert third == query_params_instance.third - - with create_test_client(MyController) as client: - response = client.request(http_method, test_path, params=query_params_instance.dict(exclude_none=True)) + @decorator("/person") + def handler(first: str, second: List[str], third: int, fourth: Optional[str] = None) -> None: + assert first == "foo" + assert second == ["a", "b"] + assert third == 2 + assert fourth is None + + with create_test_client(handler) as client: + response = client.request(http_method, "/person", params={"first": "foo", "second": ["a", "b"], "third": "2"}) assert response.status_code == expected_status_code diff --git a/tests/unit/test_openapi/conftest.py b/tests/unit/test_openapi/conftest.py index 358fd2c5c6..b6d4e4c080 100644 --- a/tests/unit/test_openapi/conftest.py +++ b/tests/unit/test_openapi/conftest.py @@ -8,14 +8,18 @@ from litestar import Controller, MediaType, delete, get, patch, post, put from litestar.datastructures import ResponseHeader, State -from litestar.dto import DTOData +from litestar.dto import DataclassDTO, DTOConfig, DTOData from litestar.openapi.spec.example import Example from litestar.params import Parameter -from tests import PartialPersonDTO, PydanticPerson, PydanticPersonFactory, PydanticPet, VanillaDataClassPerson +from tests.models import DataclassPerson, DataclassPersonFactory, DataclassPet from .utils import Gender, PetException +class PartialDataclassPersonDTO(DataclassDTO[DataclassPerson]): + config = DTOConfig(partial=True) + + def create_person_controller() -> Type[Controller]: class PersonController(Controller): path = "/{service_id:int}/person" @@ -50,45 +54,45 @@ def get_persons( secret_header: str = Parameter(header="secret"), # cookie parameter cookie_value: int = Parameter(cookie="value"), - ) -> List[PydanticPerson]: + ) -> List[DataclassPerson]: return [] @post(media_type=MediaType.TEXT) def create_person( - self, data: PydanticPerson, secret_header: str = Parameter(header="secret") - ) -> PydanticPerson: + self, data: DataclassPerson, secret_header: str = Parameter(header="secret") + ) -> DataclassPerson: return data - @post(path="/bulk", dto=PartialPersonDTO) + @post(path="/bulk", dto=PartialDataclassPersonDTO) def bulk_create_person( - self, data: List[DTOData[PydanticPerson]], secret_header: str = Parameter(header="secret") - ) -> List[PydanticPerson]: + self, data: List[DTOData[DataclassPerson]], secret_header: str = Parameter(header="secret") + ) -> List[DataclassPerson]: return [] @put(path="/bulk") def bulk_update_person( - self, data: List[PydanticPerson], secret_header: str = Parameter(header="secret") - ) -> List[PydanticPerson]: + self, data: List[DataclassPerson], secret_header: str = Parameter(header="secret") + ) -> List[DataclassPerson]: return [] - @patch(path="/bulk", dto=PartialPersonDTO) + @patch(path="/bulk", dto=PartialDataclassPersonDTO) def bulk_partial_update_person( - self, data: List[DTOData[PydanticPerson]], secret_header: str = Parameter(header="secret") - ) -> List[PydanticPerson]: + self, data: List[DTOData[DataclassPerson]], secret_header: str = Parameter(header="secret") + ) -> List[DataclassPerson]: return [] @get(path="/{person_id:str}") - def get_person_by_id(self, person_id: str) -> PydanticPerson: + def get_person_by_id(self, person_id: str) -> DataclassPerson: """Description in docstring.""" - return PydanticPersonFactory.build(id=person_id) + return DataclassPersonFactory.build(id=person_id) - @patch(path="/{person_id:str}", description="Description in decorator", dto=PartialPersonDTO) - def partial_update_person(self, person_id: str, data: DTOData[PydanticPerson]) -> PydanticPerson: + @patch(path="/{person_id:str}", description="Description in decorator", dto=PartialDataclassPersonDTO) + def partial_update_person(self, person_id: str, data: DTOData[DataclassPerson]) -> DataclassPerson: """Description in docstring.""" - return PydanticPersonFactory.build(id=person_id) + return DataclassPersonFactory.build(id=person_id) @put(path="/{person_id:str}") - def update_person(self, person_id: str, data: PydanticPerson) -> PydanticPerson: + def update_person(self, person_id: str, data: DataclassPerson) -> DataclassPerson: """Multiline docstring example. Line 3. @@ -100,8 +104,8 @@ def delete_person(self, person_id: str) -> None: return None @get(path="/dataclass") - def get_person_dataclass(self) -> VanillaDataClassPerson: - return VanillaDataClassPerson( + def get_person_dataclass(self) -> DataclassPerson: + return DataclassPerson( first_name="Moishe", last_name="zuchmir", id="1", optional=None, complex={}, pets=None ) @@ -113,13 +117,13 @@ class PetController(Controller): path = "/pet" @get() - def pets(self) -> List[PydanticPet]: + def pets(self) -> List[DataclassPet]: return [] @get( path="/owner-or-pet", response_headers=[ResponseHeader(name="x-my-tag", value="123")], raises=[PetException] ) - def get_pets_or_owners(self) -> List[Union[PydanticPerson, PydanticPet]]: + def get_pets_or_owners(self) -> List[Union[DataclassPerson, DataclassPet]]: return [] return PetController diff --git a/tests/unit/test_openapi/test_responses.py b/tests/unit/test_openapi/test_responses.py index ef918193c6..ec80a3d5c6 100644 --- a/tests/unit/test_openapi/test_responses.py +++ b/tests/unit/test_openapi/test_responses.py @@ -19,7 +19,6 @@ create_success_response, ) from litestar._openapi.schema_generation import SchemaCreator -from litestar.contrib.pydantic import PydanticSchemaPlugin from litestar.datastructures import Cookie, ResponseHeader from litestar.dto import AbstractDTO from litestar.exceptions import ( @@ -41,7 +40,7 @@ HTTP_406_NOT_ACCEPTABLE, ) from litestar.typing import FieldDefinition -from tests import PydanticPerson, PydanticPersonFactory +from tests.models import DataclassPerson, DataclassPersonFactory from .utils import PetException @@ -70,7 +69,7 @@ def test_create_responses(person_controller: type[Controller], pet_controller: t responses = create_responses( handler, raises_validation_error=False, - schema_creator=SchemaCreator(generate_examples=True, plugins=[PydanticSchemaPlugin()]), + schema_creator=SchemaCreator(generate_examples=True), ) assert responses assert str(HTTP_400_BAD_REQUEST) not in responses @@ -216,14 +215,12 @@ def handler() -> list: def test_create_success_response_with_response_class() -> None: @get(path="/test", name="test") - def handler() -> Response[PydanticPerson]: - return Response(content=PydanticPersonFactory.build()) + def handler() -> Response[DataclassPerson]: + return Response(content=DataclassPersonFactory.build()) handler = get_registered_route_handler(handler, "test") schemas: dict[str, Schema] = {} - response = create_success_response( - handler, SchemaCreator(generate_examples=True, schemas=schemas, plugins=[PydanticSchemaPlugin()]) - ) + response = create_success_response(handler, SchemaCreator(generate_examples=True, schemas=schemas)) assert response.content reference = response.content["application/json"].schema @@ -231,7 +228,7 @@ def handler() -> Response[PydanticPerson]: assert isinstance(reference, Reference) key = reference.ref.split("/")[-1] assert isinstance(schemas[key], Schema) - assert key == PydanticPerson.__name__ + assert key == DataclassPerson.__name__ def test_create_success_response_with_stream() -> None: @@ -323,7 +320,7 @@ def test_create_additional_responses() -> None: class ServerError: message: str - class AuthenticationError(BaseModel): + class AuthenticationError(TypedDict): message: str class UnknownError(TypedDict): @@ -336,11 +333,11 @@ class UnknownError(TypedDict): 505: ResponseSpec(data_container=UnknownError), } ) - def handler() -> PydanticPerson: - return PydanticPersonFactory.build() + def handler() -> DataclassPerson: + return DataclassPersonFactory.build() schemas: dict[str, Schema] = {} - responses = create_additional_responses(handler, SchemaCreator(schemas=schemas, plugins=[PydanticSchemaPlugin()])) + responses = create_additional_responses(handler, SchemaCreator(schemas=schemas)) first_response = next(responses) assert first_response[0] == "401" @@ -380,8 +377,8 @@ class OkResponse(BaseModel): message: str @get(responses={200: ResponseSpec(data_container=OkResponse, description="Overwritten response")}, name="test") - def handler() -> PydanticPerson: - return PydanticPersonFactory.build() + def handler() -> DataclassPerson: + return DataclassPersonFactory.build() handler = get_registered_route_handler(handler, "test") responses = create_responses( @@ -402,7 +399,7 @@ class ErrorResponse(BaseModel): responses={400: ResponseSpec(data_container=ErrorResponse, description="Overwritten response")}, name="test", ) - def handler() -> PydanticPerson: + def handler() -> DataclassPerson: raise ValidationException() handler = get_registered_route_handler(handler, "test") @@ -421,21 +418,19 @@ class CustomResponse(Response[T]): pass @get(path="/test", name="test", signature_types=[CustomResponse]) - def handler() -> CustomResponse[PydanticPerson]: - return CustomResponse(content=PydanticPersonFactory.build()) + def handler() -> CustomResponse[DataclassPerson]: + return CustomResponse(content=DataclassPersonFactory.build()) handler = get_registered_route_handler(handler, "test") schemas: dict[str, Schema] = {} - response = create_success_response( - handler, SchemaCreator(generate_examples=True, schemas=schemas, plugins=[PydanticSchemaPlugin()]) - ) + response = create_success_response(handler, SchemaCreator(generate_examples=True, schemas=schemas)) assert response.content assert isinstance(response.content["application/json"], OpenAPIMediaType) reference = response.content["application/json"].schema assert isinstance(reference, Reference) schema = schemas[reference.value] - assert schema.title == "PydanticPerson" + assert schema.title == "DataclassPerson" def test_success_response_with_future_annotations(create_module: Callable[[str], ModuleType]) -> None: diff --git a/tests/unit/test_openapi/test_schema.py b/tests/unit/test_openapi/test_schema.py index 4c72ee8e09..29c5d40899 100644 --- a/tests/unit/test_openapi/test_schema.py +++ b/tests/unit/test_openapi/test_schema.py @@ -27,7 +27,7 @@ from litestar.params import BodyKwarg, Parameter, ParameterKwarg from litestar.testing import create_test_client from litestar.typing import FieldDefinition -from tests import PydanticPerson, PydanticPet +from tests.models import DataclassPerson, DataclassPet if TYPE_CHECKING: from types import ModuleType @@ -166,16 +166,16 @@ def test_title_validation() -> None: schemas: Dict[str, Schema] = {} schema_creator = SchemaCreator(schemas=schemas, plugins=[PydanticSchemaPlugin()]) - schema_creator.for_field_definition(FieldDefinition.from_kwarg(name="Person", annotation=PydanticPerson)) - assert schemas.get("PydanticPerson") + schema_creator.for_field_definition(FieldDefinition.from_kwarg(name="Person", annotation=DataclassPerson)) + assert schemas.get("DataclassPerson") - schema_creator.for_field_definition(FieldDefinition.from_kwarg(name="Pet", annotation=PydanticPet)) - assert schemas.get("PydanticPet") + schema_creator.for_field_definition(FieldDefinition.from_kwarg(name="Pet", annotation=DataclassPet)) + assert schemas.get("DataclassPet") with pytest.raises(ImproperlyConfiguredException): schema_creator.for_field_definition( FieldDefinition.from_kwarg( - name="PydanticPerson", annotation=PydanticPet, kwarg_definition=BodyKwarg(title="PydanticPerson") + name="DataclassPerson", annotation=DataclassPet, kwarg_definition=BodyKwarg(title="DataclassPerson") ) ) diff --git a/tests/unit/test_openapi/test_spec_generation.py b/tests/unit/test_openapi/test_spec_generation.py index 74858c8228..3471261c53 100644 --- a/tests/unit/test_openapi/test_spec_generation.py +++ b/tests/unit/test_openapi/test_spec_generation.py @@ -5,13 +5,13 @@ from litestar import post from litestar.testing import create_test_client -from tests import ( +from tests.models import ( AttrsPerson, + DataclassPerson, MsgSpecStructPerson, - PydanticDataClassPerson, + PydanticDataclassPerson, PydanticPerson, TypedDictPerson, - VanillaDataClassPerson, ) @@ -19,8 +19,8 @@ "cls", ( PydanticPerson, - VanillaDataClassPerson, - PydanticDataClassPerson, + DataclassPerson, + PydanticDataclassPerson, TypedDictPerson, MsgSpecStructPerson, AttrsPerson, @@ -51,7 +51,7 @@ def handler(data: cls) -> cls: "pets": { "oneOf": [ {"type": "null"}, - {"items": {"$ref": "#/components/schemas/PydanticPet"}, "type": "array"}, + {"items": {"$ref": "#/components/schemas/DataclassPet"}, "type": "array"}, ] }, }, diff --git a/tests/unit/test_pagination.py b/tests/unit/test_pagination.py index afa2d53afa..c545a91d9a 100644 --- a/tests/unit/test_pagination.py +++ b/tests/unit/test_pagination.py @@ -18,72 +18,72 @@ ) from litestar.status_codes import HTTP_200_OK from litestar.testing import create_test_client -from tests import PydanticPerson, PydanticPersonFactory +from tests.models import DataclassPerson, DataclassPersonFactory -class TestSyncClassicPaginator(AbstractSyncClassicPaginator[PydanticPerson]): +class TestSyncClassicPaginator(AbstractSyncClassicPaginator[DataclassPerson]): __test__ = False - def __init__(self, data: List[PydanticPerson]): + def __init__(self, data: List[DataclassPerson]): self.data = data def get_total(self, page_size: int) -> int: return round(len(self.data) / page_size) - def get_items(self, page_size: int, current_page: int) -> List[PydanticPerson]: + def get_items(self, page_size: int, current_page: int) -> List[DataclassPerson]: return [self.data[i : i + page_size] for i in range(0, len(self.data), page_size)][current_page - 1] -class TestAsyncClassicPaginator(AbstractAsyncClassicPaginator[PydanticPerson]): +class TestAsyncClassicPaginator(AbstractAsyncClassicPaginator[DataclassPerson]): __test__ = False - def __init__(self, data: List[PydanticPerson]): + def __init__(self, data: List[DataclassPerson]): self.data = data async def get_total(self, page_size: int) -> int: return round(len(self.data) / page_size) - async def get_items(self, page_size: int, current_page: int) -> List[PydanticPerson]: + async def get_items(self, page_size: int, current_page: int) -> List[DataclassPerson]: return [self.data[i : i + page_size] for i in range(0, len(self.data), page_size)][current_page - 1] -class TestSyncOffsetPaginator(AbstractSyncOffsetPaginator[PydanticPerson]): +class TestSyncOffsetPaginator(AbstractSyncOffsetPaginator[DataclassPerson]): __test__ = False - def __init__(self, data: List[PydanticPerson]): + def __init__(self, data: List[DataclassPerson]): self.data = data def get_total(self) -> int: return len(self.data) - def get_items(self, limit: int, offset: int) -> List[PydanticPerson]: + def get_items(self, limit: int, offset: int) -> List[DataclassPerson]: return list(islice(islice(self.data, offset, None), limit)) -class TestAsyncOffsetPaginator(AbstractAsyncOffsetPaginator[PydanticPerson]): +class TestAsyncOffsetPaginator(AbstractAsyncOffsetPaginator[DataclassPerson]): __test__ = False - def __init__(self, data: List[PydanticPerson]): + def __init__(self, data: List[DataclassPerson]): self.data = data async def get_total(self) -> int: return len(self.data) - async def get_items(self, limit: int, offset: int) -> List[PydanticPerson]: + async def get_items(self, limit: int, offset: int) -> List[DataclassPerson]: return list(islice(islice(self.data, offset, None), limit)) -data = PydanticPersonFactory.batch(50) +data = DataclassPersonFactory.batch(50) @pytest.mark.parametrize("paginator", (TestSyncClassicPaginator(data=data), TestAsyncClassicPaginator(data=data))) def test_classic_pagination_data_shape(paginator: Any) -> None: @get("/async") - async def async_handler(page_size: int, current_page: int) -> ClassicPagination[PydanticPerson]: + async def async_handler(page_size: int, current_page: int) -> ClassicPagination[DataclassPerson]: return await paginator(page_size=page_size, current_page=current_page) # type: ignore @get("/sync") - def sync_handler(page_size: int, current_page: int) -> ClassicPagination[PydanticPerson]: + def sync_handler(page_size: int, current_page: int) -> ClassicPagination[DataclassPerson]: return paginator(page_size=page_size, current_page=current_page) # type: ignore with create_test_client([async_handler, sync_handler]) as client: @@ -103,11 +103,11 @@ def sync_handler(page_size: int, current_page: int) -> ClassicPagination[Pydanti @pytest.mark.parametrize("paginator", (TestSyncClassicPaginator(data=data), TestAsyncClassicPaginator(data=data))) def test_classic_pagination_openapi_schema(paginator: Any) -> None: @get("/async") - async def async_handler(page_size: int, current_page: int) -> ClassicPagination[PydanticPerson]: + async def async_handler(page_size: int, current_page: int) -> ClassicPagination[DataclassPerson]: return await paginator(page_size=page_size, current_page=current_page) # type: ignore @get("/sync") - def sync_handler(page_size: int, current_page: int) -> ClassicPagination[PydanticPerson]: + def sync_handler(page_size: int, current_page: int) -> ClassicPagination[DataclassPerson]: return paginator(page_size=page_size, current_page=current_page) # type: ignore with create_test_client([async_handler, sync_handler], openapi_config=DEFAULT_OPENAPI_CONFIG) as client: @@ -120,7 +120,7 @@ def sync_handler(page_size: int, current_page: int) -> ClassicPagination[Pydanti assert spec == { "schema": { "properties": { - "items": {"items": {"$ref": "#/components/schemas/PydanticPerson"}, "type": "array"}, + "items": {"items": {"$ref": "#/components/schemas/DataclassPerson"}, "type": "array"}, "page_size": {"type": "integer", "description": "Number of items per page."}, "current_page": {"type": "integer", "description": "Current page number."}, "total_pages": {"type": "integer", "description": "Total number of pages."}, @@ -133,11 +133,11 @@ def sync_handler(page_size: int, current_page: int) -> ClassicPagination[Pydanti @pytest.mark.parametrize("paginator", (TestSyncOffsetPaginator(data=data), TestAsyncOffsetPaginator(data=data))) def test_limit_offset_pagination_data_shape(paginator: Any) -> None: @get("/async") - async def async_handler(limit: int, offset: int) -> OffsetPagination[PydanticPerson]: + async def async_handler(limit: int, offset: int) -> OffsetPagination[DataclassPerson]: return await paginator(limit=limit, offset=offset) # type: ignore @get("/sync") - def sync_handler(limit: int, offset: int) -> OffsetPagination[PydanticPerson]: + def sync_handler(limit: int, offset: int) -> OffsetPagination[DataclassPerson]: return paginator(limit=limit, offset=offset) # type: ignore with create_test_client([async_handler, sync_handler]) as client: @@ -157,11 +157,11 @@ def sync_handler(limit: int, offset: int) -> OffsetPagination[PydanticPerson]: @pytest.mark.parametrize("paginator", (TestSyncOffsetPaginator(data=data), TestAsyncOffsetPaginator(data=data))) def test_limit_offset_pagination_openapi_schema(paginator: Any) -> None: @get("/async") - async def async_handler(limit: int, offset: int) -> OffsetPagination[PydanticPerson]: + async def async_handler(limit: int, offset: int) -> OffsetPagination[DataclassPerson]: return await paginator(limit=limit, offset=offset) # type: ignore @get("/sync") - def sync_handler(limit: int, offset: int) -> OffsetPagination[PydanticPerson]: + def sync_handler(limit: int, offset: int) -> OffsetPagination[DataclassPerson]: return paginator(limit=limit, offset=offset) # type: ignore with create_test_client([async_handler, sync_handler], openapi_config=DEFAULT_OPENAPI_CONFIG) as client: @@ -174,7 +174,7 @@ def sync_handler(limit: int, offset: int) -> OffsetPagination[PydanticPerson]: assert spec == { "schema": { "properties": { - "items": {"items": {"$ref": "#/components/schemas/PydanticPerson"}, "type": "array"}, + "items": {"items": {"$ref": "#/components/schemas/DataclassPerson"}, "type": "array"}, "limit": {"type": "integer", "description": "Maximal number of items to send."}, "offset": {"type": "integer", "description": "Offset from the beginning of the query."}, "total": {"type": "integer", "description": "Total number of items."}, @@ -184,26 +184,26 @@ def sync_handler(limit: int, offset: int) -> OffsetPagination[PydanticPerson]: } -class TestSyncCursorPagination(AbstractSyncCursorPaginator[str, PydanticPerson]): +class TestSyncCursorPagination(AbstractSyncCursorPaginator[str, DataclassPerson]): __test__ = False - def __init__(self, data: List[PydanticPerson]): + def __init__(self, data: List[DataclassPerson]): self.data = data - def get_items(self, cursor: Optional[str], results_per_page: int) -> "Tuple[List[PydanticPerson], Optional[str]]": + def get_items(self, cursor: Optional[str], results_per_page: int) -> "Tuple[List[DataclassPerson], Optional[str]]": results = self.data[:results_per_page] return results, results[-1].id -class TestAsyncCursorPagination(AbstractAsyncCursorPaginator[str, PydanticPerson]): +class TestAsyncCursorPagination(AbstractAsyncCursorPaginator[str, DataclassPerson]): __test__ = False - def __init__(self, data: List[PydanticPerson]): + def __init__(self, data: List[DataclassPerson]): self.data = data async def get_items( self, cursor: Optional[str], results_per_page: int - ) -> "Tuple[List[PydanticPerson], Optional[str]]": + ) -> "Tuple[List[DataclassPerson], Optional[str]]": results = self.data[:results_per_page] return results, results[-1].id @@ -211,11 +211,11 @@ async def get_items( @pytest.mark.parametrize("paginator", (TestSyncCursorPagination(data=data), TestAsyncCursorPagination(data=data))) def test_cursor_pagination_data_shape(paginator: Any) -> None: @get("/async") - async def async_handler(cursor: Optional[str] = None) -> CursorPagination[str, PydanticPerson]: + async def async_handler(cursor: Optional[str] = None) -> CursorPagination[str, DataclassPerson]: return await paginator(cursor=cursor, results_per_page=5) # type: ignore @get("/sync") - def sync_handler(cursor: Optional[str] = None) -> CursorPagination[str, PydanticPerson]: + def sync_handler(cursor: Optional[str] = None) -> CursorPagination[str, DataclassPerson]: return paginator(cursor=cursor, results_per_page=5) # type: ignore with create_test_client([async_handler, sync_handler]) as client: @@ -234,11 +234,11 @@ def sync_handler(cursor: Optional[str] = None) -> CursorPagination[str, Pydantic @pytest.mark.parametrize("paginator", (TestSyncCursorPagination(data=data), TestAsyncCursorPagination(data=data))) def test_cursor_pagination_openapi_schema(paginator: Any) -> None: @get("/async") - async def async_handler(cursor: Optional[str] = None) -> CursorPagination[str, PydanticPerson]: + async def async_handler(cursor: Optional[str] = None) -> CursorPagination[str, DataclassPerson]: return await paginator(cursor=cursor, results_per_page=5) # type: ignore @get("/sync") - def sync_handler(cursor: Optional[str] = None) -> CursorPagination[str, PydanticPerson]: + def sync_handler(cursor: Optional[str] = None) -> CursorPagination[str, DataclassPerson]: return paginator(cursor=cursor, results_per_page=5) # type: ignore with create_test_client([async_handler, sync_handler], openapi_config=DEFAULT_OPENAPI_CONFIG) as client: @@ -251,7 +251,7 @@ def sync_handler(cursor: Optional[str] = None) -> CursorPagination[str, Pydantic assert spec == { "schema": { "properties": { - "items": {"items": {"$ref": "#/components/schemas/PydanticPerson"}, "type": "array"}, + "items": {"items": {"$ref": "#/components/schemas/DataclassPerson"}, "type": "array"}, "cursor": { "type": "string", "description": "Unique ID, designating the last identifier in the given data set. This value can be used to request the 'next' batch of records.", diff --git a/tests/unit/test_response/test_response_to_asgi_response.py b/tests/unit/test_response/test_response_to_asgi_response.py index f398e61801..15ac3bdfa2 100644 --- a/tests/unit/test_response/test_response_to_asgi_response.py +++ b/tests/unit/test_response/test_response_to_asgi_response.py @@ -6,6 +6,7 @@ from time import sleep from typing import TYPE_CHECKING, Any, Generator, Iterator, cast +import msgspec import pytest from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse from starlette.responses import Response as StarletteResponse @@ -14,7 +15,6 @@ from litestar._signature import SignatureModel from litestar.background_tasks import BackgroundTask from litestar.contrib.jinja import JinjaTemplateEngine -from litestar.contrib.pydantic import _model_dump from litestar.datastructures import Cookie, ResponseHeader from litestar.response import ServerSentEvent from litestar.response.base import ASGIResponse @@ -28,7 +28,7 @@ from litestar.types import StreamType from litestar.utils import AsyncIteratorWrapper from litestar.utils.signature import ParsedSignature -from tests import PydanticPerson, PydanticPersonFactory +from tests.models import DataclassPerson, DataclassPersonFactory if TYPE_CHECKING: from typing import AsyncGenerator @@ -74,11 +74,11 @@ def __init__(self) -> None: async def test_to_response_async_await(anyio_backend: str) -> None: @route(http_method=HttpMethod.POST, path="/person") - async def handler(data: PydanticPerson) -> PydanticPerson: - assert isinstance(data, PydanticPerson) + async def handler(data: DataclassPerson) -> DataclassPerson: + assert isinstance(data, DataclassPerson) return data - person_instance = PydanticPersonFactory.build() + person_instance = DataclassPersonFactory.build() handler._signature_model = SignatureModel.create( dependency_name_set=set(), fn=handler.fn.value, @@ -92,7 +92,7 @@ async def handler(data: PydanticPerson) -> PydanticPerson: app=Litestar(route_handlers=[handler]), request=RequestFactory().get(route_handler=handler), ) - assert loads(response.body) == _model_dump(person_instance) # type: ignore + assert loads(response.body) == msgspec.to_builtins(person_instance) # type: ignore[attr-defined] async def test_to_response_returning_litestar_response() -> None: diff --git a/tests/unit/test_response/test_serialization.py b/tests/unit/test_response/test_serialization.py index 6eb8f4a4fd..c60cee7954 100644 --- a/tests/unit/test_response/test_serialization.py +++ b/tests/unit/test_response/test_serialization.py @@ -1,28 +1,24 @@ import enum -from json import loads from pathlib import Path, PurePath -from typing import Any, Dict, List +from typing import Any, Callable, cast import msgspec import pytest from pydantic import SecretStr +from pytest import FixtureRequest from litestar import MediaType, Response -from litestar.contrib.pydantic import PydanticInitPlugin, _model_dump +from litestar.contrib.pydantic import PydanticInitPlugin from litestar.exceptions import ImproperlyConfiguredException from litestar.serialization import get_serializer -from tests import ( +from tests.models import ( + DataclassPersonFactory, MsgSpecStructPerson, - PydanticDataClassPerson, + PydanticDataclassPerson, PydanticPerson, - PydanticPersonFactory, - VanillaDataClassPerson, ) -person = PydanticPersonFactory.build() -secret = SecretStr("secret_text") -pure_path = PurePath("/path/to/file") -path = Path("/path/to/file") +person = DataclassPersonFactory.build() class _TestEnum(enum.Enum): @@ -30,42 +26,68 @@ class _TestEnum(enum.Enum): B = "beta" -@pytest.mark.parametrize("media_type", [MediaType.JSON, MediaType.MESSAGEPACK]) -@pytest.mark.parametrize( - "content, response_type", - [ - [person, PydanticPerson], - [{"key": 123}, Dict[str, int]], - [[{"key": 123}], List[Dict[str, int]]], - [VanillaDataClassPerson(**_model_dump(person)), VanillaDataClassPerson], - [PydanticDataClassPerson(**_model_dump(person)), PydanticDataClassPerson], - [MsgSpecStructPerson(**_model_dump(person)), MsgSpecStructPerson], - [{"enum": _TestEnum.A}, Dict[str, _TestEnum]], - [{"secret": secret}, Dict[str, SecretStr]], - [{"pure_path": pure_path}, Dict[str, PurePath]], - [{"path": path}, Dict[str, PurePath]], - ], -) -def test_response_serialization_structured_types(content: Any, response_type: Any, media_type: MediaType) -> None: +@pytest.fixture(params=[MediaType.JSON, MediaType.MESSAGEPACK]) +def media_type(request: FixtureRequest) -> MediaType: + return cast(MediaType, request.param) + + +DecodeMediaType = Callable[[Any], Any] + + +@pytest.fixture() +def decode_media_type(media_type: MediaType) -> DecodeMediaType: + if media_type == MediaType.JSON: + return msgspec.json.decode + return msgspec.msgpack.decode + + +def test_pydantic(media_type: MediaType, decode_media_type: DecodeMediaType) -> None: + content = PydanticPerson.parse_obj(msgspec.to_builtins(person)) encoded = Response(None).render( content, media_type=media_type, enc_hook=get_serializer(type_encoders=PydanticInitPlugin.encoders()) ) - if media_type == media_type.JSON: - value = loads(encoded) - else: - value = msgspec.msgpack.decode(encoded) - if isinstance(value, dict) and "enum" in value: - assert content.__class__(**value)["enum"] == content["enum"].value - elif isinstance(value, dict) and "secret" in value: - assert content.__class__(**value)["secret"] == str(content["secret"]) - elif isinstance(value, dict) and "pure_path" in value: - assert content.__class__(**value)["pure_path"] == str(content["pure_path"]) - elif isinstance(value, dict) and "path" in value: - assert content.__class__(**value)["path"] == str(content["path"]) - elif isinstance(value, dict): - assert content.__class__(**value) == content - else: - assert [content[0].__class__(**value[0])] == content + assert PydanticPerson.parse_obj(decode_media_type(encoded)) == content + + +def test_dataclass(media_type: MediaType, decode_media_type: DecodeMediaType) -> None: + encoded = Response(None).render(person, media_type=media_type) + assert decode_media_type(encoded) == msgspec.to_builtins(person) + + +def test_pydantic_dataclass(media_type: MediaType, decode_media_type: DecodeMediaType) -> None: + encoded = Response(None).render(PydanticDataclassPerson(**msgspec.to_builtins(person)), media_type=media_type) + assert decode_media_type(encoded) == msgspec.to_builtins(person) + + +def test_struct(media_type: MediaType, decode_media_type: DecodeMediaType) -> None: + encoded = Response(None).render(MsgSpecStructPerson(**msgspec.to_builtins(person)), media_type=media_type) + assert decode_media_type(encoded) == msgspec.to_builtins(person) + + +@pytest.mark.parametrize("content", [{"value": 1}, [{"value": 1}]]) +def test_dict(media_type: MediaType, decode_media_type: DecodeMediaType, content: Any) -> None: + encoded = Response(None).render(content, media_type=media_type) + assert decode_media_type(encoded) == content + + +def test_enum(media_type: MediaType, decode_media_type: DecodeMediaType) -> None: + encoded = Response(None).render({"value": _TestEnum.A}, media_type=media_type) + assert decode_media_type(encoded) == {"value": _TestEnum.A.value} + + +def test_pydantic_secret(media_type: MediaType, decode_media_type: DecodeMediaType) -> None: + encoded = Response(None).render( + {"value": SecretStr("secret_text")}, + media_type=media_type, + enc_hook=get_serializer(type_encoders=PydanticInitPlugin.encoders()), + ) + assert decode_media_type(encoded) == {"value": "**********"} + + +@pytest.mark.parametrize("path", [PurePath("/path/to/file"), Path("/path/to/file")]) +def test_path(media_type: MediaType, decode_media_type: DecodeMediaType, path: Path) -> None: + encoded = Response(None).render({"value": path}, media_type=media_type) + assert decode_media_type(encoded) == {"value": "/path/to/file"} @pytest.mark.parametrize( diff --git a/tests/unit/test_security/test_session_auth.py b/tests/unit/test_security/test_session_auth.py index fd8d8cbfa5..e66f65892e 100644 --- a/tests/unit/test_security/test_session_auth.py +++ b/tests/unit/test_security/test_session_auth.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional from uuid import uuid4 +import msgspec from starlette.status import ( HTTP_200_OK, HTTP_201_CREATED, @@ -9,14 +10,13 @@ ) from litestar import Litestar, Request, delete, get, post -from litestar.contrib.pydantic import _model_dump from litestar.middleware.session.server_side import ( ServerSideSessionBackend, ServerSideSessionConfig, ) from litestar.security.session_auth import SessionAuth from litestar.testing import create_test_client -from tests import User, UserFactory +from tests.models import User, UserFactory if TYPE_CHECKING: from litestar.connection import ASGIConnection @@ -39,7 +39,7 @@ def test_authentication(session_backend_config_memory: ServerSideSessionConfig) @post("/login") def login_handler(request: "Request[Any, Any, Any]", data: User) -> None: - request.set_session(_model_dump(data)) + request.set_session(msgspec.to_builtins(data)) @delete("/user/{user_id:str}") def delete_user_handler(request: "Request[User, Any, Any]") -> None: diff --git a/tests/unit/test_testing/test_request_factory.py b/tests/unit/test_testing/test_request_factory.py index 198bfd07fa..b4023eba3a 100644 --- a/tests/unit/test_testing/test_request_factory.py +++ b/tests/unit/test_testing/test_request_factory.py @@ -1,26 +1,27 @@ import json -from typing import Callable, Dict +from typing import Any, Callable, Dict +import msgspec import pytest from pydantic import BaseModel from litestar import HttpMethod, Litestar, get -from litestar.contrib.pydantic import _model_dump from litestar.datastructures import Cookie, MultiDict from litestar.enums import ParamType, RequestEncodingType +from litestar.serialization import encode_json from litestar.testing import RequestFactory from litestar.types import DataContainerType -from tests import ( +from tests.models import ( AttrsPerson, + DataclassPerson, + DataclassPersonFactory, + DataclassPetFactory, MsgSpecStructPerson, PydanticPerson, - PydanticPersonFactory, - PydanticPetFactory, - VanillaDataClassPerson, ) _DEFAULT_REQUEST_FACTORY_URL = "http://test.org:3000/" -pet = PydanticPetFactory.build() +pet = DataclassPetFactory.build() async def test_request_factory_empty_body() -> None: @@ -64,26 +65,30 @@ def test_request_factory_build_headers() -> None: assert headers[decoded_key] == decoded_value -@pytest.mark.parametrize("data_cls", [PydanticPerson, VanillaDataClassPerson, AttrsPerson, MsgSpecStructPerson]) +def _json_roundtrip(obj: Any) -> Any: + return + + +@pytest.mark.parametrize("data_cls", [PydanticPerson, DataclassPerson, AttrsPerson, MsgSpecStructPerson]) async def test_request_factory_create_with_data(data_cls: DataContainerType) -> None: - person = _model_dump(PydanticPersonFactory.build()) + person_data = msgspec.json.decode(encode_json(DataclassPersonFactory.build())) request = RequestFactory()._create_request_with_data( HttpMethod.POST, "/", - data=data_cls(**person), # type: ignore + data=data_cls(**person_data), # type: ignore ) body = await request.body() - assert json.loads(body.decode()) == person + assert json.loads(body) == person_data @pytest.mark.parametrize( "request_media_type, verify_data", [ - [RequestEncodingType.JSON, lambda data: json.loads(data) == _model_dump(pet)], + [RequestEncodingType.JSON, lambda data: json.loads(data) == msgspec.to_builtins(pet)], [RequestEncodingType.MULTI_PART, lambda data: "Content-Disposition" in data], [ RequestEncodingType.URL_ENCODED, - lambda data: data == f"name={pet.name}&species={pet.species.value}&age={pet.age}", + lambda data: data == f"name={pet.name}&age={pet.age}&species={pet.species.value}", ], ], ) @@ -93,7 +98,7 @@ async def test_request_factory_create_with_content_type( request = RequestFactory()._create_request_with_data( HttpMethod.POST, "/", - data=_model_dump(pet), + data=msgspec.to_builtins(pet), request_media_type=request_media_type, ) assert request.headers["Content-Type"].startswith(request_media_type.value) @@ -197,4 +202,4 @@ async def test_request_factory_post_put_patch(factory: Callable, method: HttpMet assert len(request.headers.keys()) == 3 assert request.headers.get("header1") == "value1" body = await request.body() - assert json.loads(body) == _model_dump(pet) + assert json.loads(body) == msgspec.to_builtins(pet) diff --git a/tests/unit/test_utils/test_typing.py b/tests/unit/test_utils/test_typing.py index db3c897e24..b5e2e41f23 100644 --- a/tests/unit/test_utils/test_typing.py +++ b/tests/unit/test_utils/test_typing.py @@ -7,7 +7,7 @@ from typing_extensions import Annotated from litestar.utils.typing import annotation_is_iterable_of_type, get_origin_or_inner_type, make_non_optional_union -from tests import PydanticPerson, PydanticPet +from tests.models import DataclassPerson, DataclassPet if version_info >= (3, 10): from collections import deque # noqa: F401 @@ -15,12 +15,12 @@ py_310_plus_annotation = [ (eval(tp), exp) for tp, exp in [ - ("tuple[PydanticPerson, ...]", True), - ("list[PydanticPerson]", True), - ("deque[PydanticPerson]", True), - ("tuple[PydanticPet, ...]", False), - ("list[PydanticPet]", False), - ("deque[PydanticPet]", False), + ("tuple[DataclassPerson, ...]", True), + ("list[DataclassPerson]", True), + ("deque[DataclassPerson]", True), + ("tuple[DataclassPet, ...]", False), + ("list[DataclassPet]", False), + ("deque[DataclassPet]", False), ] ] else: @@ -30,16 +30,16 @@ @pytest.mark.parametrize( "annotation, expected", ( - (List[PydanticPerson], True), - (Sequence[PydanticPerson], True), - (Iterable[PydanticPerson], True), - (Tuple[PydanticPerson, ...], True), - (Deque[PydanticPerson], True), - (List[PydanticPet], False), - (Sequence[PydanticPet], False), - (Iterable[PydanticPet], False), - (Tuple[PydanticPet, ...], False), - (Deque[PydanticPet], False), + (List[DataclassPerson], True), + (Sequence[DataclassPerson], True), + (Iterable[DataclassPerson], True), + (Tuple[DataclassPerson, ...], True), + (Deque[DataclassPerson], True), + (List[DataclassPet], False), + (Sequence[DataclassPet], False), + (Iterable[DataclassPet], False), + (Tuple[DataclassPet, ...], False), + (Deque[DataclassPet], False), *py_310_plus_annotation, (int, False), (str, False), @@ -47,7 +47,7 @@ ), ) def test_annotation_is_iterable_of_type(annotation: Any, expected: bool) -> None: - assert annotation_is_iterable_of_type(annotation=annotation, type_value=PydanticPerson) is expected + assert annotation_is_iterable_of_type(annotation=annotation, type_value=DataclassPerson) is expected @pytest.mark.parametrize( @@ -58,6 +58,6 @@ def test_make_non_optional_union(annotation: Any, expected: Any) -> None: def test_get_origin_or_inner_type() -> None: - assert get_origin_or_inner_type(List[PydanticPerson]) == list - assert get_origin_or_inner_type(Annotated[List[PydanticPerson], "foo"]) == list - assert get_origin_or_inner_type(Annotated[Dict[str, List[PydanticPerson]], "foo"]) == dict + assert get_origin_or_inner_type(List[DataclassPerson]) == list + assert get_origin_or_inner_type(Annotated[List[DataclassPerson], "foo"]) == list + assert get_origin_or_inner_type(Annotated[Dict[str, List[DataclassPerson]], "foo"]) == dict