diff --git a/litestar/_signature/model.py b/litestar/_signature/model.py index 42c79947f5..33653ed548 100644 --- a/litestar/_signature/model.py +++ b/litestar/_signature/model.py @@ -38,7 +38,7 @@ from litestar.exceptions import InternalServerException, ValidationException from litestar.params import KwargDefinition, ParameterKwarg from litestar.typing import FieldDefinition # noqa -from litestar.utils import is_class_and_subclass +from litestar.utils import get_origin_or_inner_type, is_class_and_subclass from litestar.utils.dataclass import simple_asdict if TYPE_CHECKING: @@ -85,8 +85,15 @@ def _deserializer(target_type: Any, value: Any, default_deserializer: Callable[[ if isinstance(value, DTOData): return value - if isinstance(value, target_type): - return value + try: + if isinstance(value, target_type): + return value + except TypeError as exc: + if (origin := get_origin_or_inner_type(target_type)) is not None: + if isinstance(value, origin): + return value + else: + raise exc if decoder := getattr(target_type, "_decoder", None): return decoder(target_type, value) diff --git a/litestar/serialization/msgspec_hooks.py b/litestar/serialization/msgspec_hooks.py index ba98ee4a4f..06c6a65324 100644 --- a/litestar/serialization/msgspec_hooks.py +++ b/litestar/serialization/msgspec_hooks.py @@ -22,6 +22,7 @@ from litestar.datastructures.secret_values import SecretBytes, SecretString from litestar.exceptions import SerializationException from litestar.types import Empty, EmptyType, Serializer, TypeDecodersSequence +from litestar.utils.typing import get_origin_or_inner_type if TYPE_CHECKING: from litestar.types import TypeEncodersMap @@ -107,8 +108,19 @@ def default_deserializer( from litestar.datastructures.state import ImmutableState - if isinstance(value, target_type): - return value + try: + if isinstance(value, target_type): + return value + except TypeError as exc: + # we might get a TypeError here if target_type is a subscribed generic. For + # performance reasons, we let this happen and only unwrap this when we're + # certain this might be the case + if (origin := get_origin_or_inner_type(target_type)) is not None: + target_type = origin + if isinstance(value, target_type): + return value + else: + raise exc if type_decoders: for predicate, decoder in type_decoders: diff --git a/tests/unit/test_signature/test_validation.py b/tests/unit/test_signature/test_validation.py index db968c433d..bc39616723 100644 --- a/tests/unit/test_signature/test_validation.py +++ b/tests/unit/test_signature/test_validation.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import Generic, List, Optional, TypeVar import pytest from attr import define @@ -289,3 +289,17 @@ def fn(a: Annotated[int, Parameter(gt=5)], b: Annotated[int, Parameter(lt=5)]) - {"message": "Expected `int` >= 6", "key": "a", "source": ParamType.QUERY}, {"message": "Expected `int` <= 4", "key": "b", "source": ParamType.QUERY}, ] + + +def test_validate_subscribed_generics() -> None: + T = TypeVar("T") + + class Foo(Generic[T]): + pass + + @get("/") + async def something(foo: Foo[str] = Foo()) -> None: + return None + + with create_test_client([something]) as client: + assert client.get("/").status_code == 200