diff --git a/litestar/contrib/pydantic/pydantic_dto_factory.py b/litestar/contrib/pydantic/pydantic_dto_factory.py index 6942b0b5aa..4322aa4c32 100644 --- a/litestar/contrib/pydantic/pydantic_dto_factory.py +++ b/litestar/contrib/pydantic/pydantic_dto_factory.py @@ -1,12 +1,13 @@ from __future__ import annotations +import dataclasses from dataclasses import replace from typing import TYPE_CHECKING, Any, Collection, Generic, TypeVar from warnings import warn from typing_extensions import Annotated, TypeAlias, override -from litestar.contrib.pydantic.utils import is_pydantic_undefined, is_pydantic_v2 +from litestar.contrib.pydantic.utils import is_pydantic_2_model, is_pydantic_undefined, is_pydantic_v2 from litestar.dto.base_dto import AbstractDTO from litestar.dto.data_structures import DTOFieldDefinition from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field @@ -17,6 +18,8 @@ if TYPE_CHECKING: from typing import Generator + from litestar.dto import DTOConfig + try: import pydantic as _ # noqa: F401 except ImportError as e: @@ -160,3 +163,13 @@ def detect_nested_field(cls, field_definition: FieldDefinition) -> bool: if pydantic_v2 is not Empty: # type: ignore[comparison-overlap] return field_definition.is_subclass_of((pydantic_v1.BaseModel, pydantic_v2.BaseModel)) return field_definition.is_subclass_of(pydantic_v1.BaseModel) # type: ignore[unreachable] + + @classmethod + def get_config_for_model_type(cls, config: DTOConfig, model_type: type[Any]) -> DTOConfig: + if is_pydantic_2_model(model_type) and (model_config := getattr(model_type, "model_config", None)): + if model_config.get("extra") == "forbid": + config = dataclasses.replace(config, forbid_unknown_fields=True) + elif issubclass(model_type, pydantic_v1.BaseModel) and (model_config := getattr(model_type, "Config", None)): # noqa: SIM102 + if getattr(model_config, "extra", None) == "forbid": + config = dataclasses.replace(config, forbid_unknown_fields=True) + return config diff --git a/litestar/dto/base_dto.py b/litestar/dto/base_dto.py index 83fc16ce9c..352987bb3b 100644 --- a/litestar/dto/base_dto.py +++ b/litestar/dto/base_dto.py @@ -84,6 +84,24 @@ def __class_getitem__(cls, annotation: Any) -> type[Self]: return type(f"{cls.__name__}[{annotation}]", (cls,), cls_dict) # pyright: ignore + def __init_subclass__(cls, **kwargs: Any) -> None: + if (config := getattr(cls, "config", None)) and (model_type := getattr(cls, "model_type", None)): + # it's a concrete class + cls.config = cls.get_config_for_model_type(config, model_type) + + @classmethod + def get_config_for_model_type(cls, config: DTOConfig, model_type: type[Any]) -> DTOConfig: + """Create a new configuration for this specific ``model_type``, during the + creation of the factory. + + The returned config object will be set as the ``config`` attribute on the newly + defined factory class. + + .. versionadded: 2.11 + + """ + return config + def decode_builtins(self, value: dict[str, Any]) -> Any: """Decode a dictionary of Python values into an the DTO's datatype.""" diff --git a/tests/unit/test_contrib/test_pydantic/test_dto.py b/tests/unit/test_contrib/test_pydantic/test_dto.py index b6baddfa0e..782d915902 100644 --- a/tests/unit/test_contrib/test_pydantic/test_dto.py +++ b/tests/unit/test_contrib/test_pydantic/test_dto.py @@ -2,8 +2,10 @@ from typing import TYPE_CHECKING, Optional, cast +import pydantic as pydantic_v2 import pytest from pydantic import v1 as pydantic_v1 +from typing_extensions import Annotated, Literal from litestar import Request, post from litestar.contrib.pydantic import PydanticDTO, _model_dump_json @@ -100,3 +102,71 @@ def get_user() -> User: component_schema = schema.components.schemas["GetUserUserResponseBody"] assert component_schema.properties is not None assert component_schema.properties["id"].description == "This is a test (id description)." + + +@pytest.mark.parametrize( + "model_config_option, forbid_unknown_fields_default, expected_dto_config_option", + [ + ("forbid", False, True), + ("forbid", True, True), + ("allow", False, False), + ("allow", True, True), + ("ignore", True, True), + ("ignore", False, False), + ], +) +def test_forbid_unknown_fields_if_forbid_extra_is_set_v1( + use_experimental_dto_backend: bool, + forbid_unknown_fields_default: bool, + model_config_option: Literal["forbid", "allow", "ignore"], + expected_dto_config_option: bool, +) -> None: + class Model(pydantic_v1.BaseModel): + class Config: + extra = model_config_option + + a: str + + dto_config = DTOConfig( + experimental_codegen_backend=use_experimental_dto_backend, + # config set on the pydantic model should take precedence + forbid_unknown_fields=forbid_unknown_fields_default, + ) + dto = PydanticDTO[Annotated[Model, dto_config]] + + assert dto.config.forbid_unknown_fields is expected_dto_config_option + # ensure the config is merged + assert dto.config.experimental_codegen_backend is use_experimental_dto_backend + + +@pytest.mark.parametrize( + "model_config_option, forbid_unknown_fields_default, expected_dto_config_option", + [ + ("forbid", False, True), + ("forbid", True, True), + ("allow", False, False), + ("allow", True, True), + ("ignore", True, True), + ("ignore", False, False), + ], +) +def test_forbid_unknown_fields_if_forbid_extra_is_set_v2( + use_experimental_dto_backend: bool, + forbid_unknown_fields_default: bool, + model_config_option: Literal["forbid", "allow", "ignore"], + expected_dto_config_option: bool, +) -> None: + class Model(pydantic_v2.BaseModel): + a: str + model_config = pydantic_v2.ConfigDict(extra=model_config_option) + + dto_config = DTOConfig( + experimental_codegen_backend=use_experimental_dto_backend, + # config set on the pydantic model should take precedence + forbid_unknown_fields=forbid_unknown_fields_default, + ) + dto = PydanticDTO[Annotated[Model, dto_config]] + + assert dto.config.forbid_unknown_fields is expected_dto_config_option + # ensure the config is merged + assert dto.config.experimental_codegen_backend is use_experimental_dto_backend diff --git a/tests/unit/test_dto/test_factory/test_base_dto.py b/tests/unit/test_dto/test_factory/test_base_dto.py index 73daa53991..647291bb18 100644 --- a/tests/unit/test_dto/test_factory/test_base_dto.py +++ b/tests/unit/test_dto/test_factory/test_base_dto.py @@ -1,8 +1,9 @@ # ruff: noqa: UP006 from __future__ import annotations +import dataclasses from dataclasses import dataclass -from typing import TYPE_CHECKING, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, Generic, Tuple, TypeVar, Union import pytest from typing_extensions import Annotated @@ -160,3 +161,31 @@ class SubType(Model): assert ( dto_type._dto_backends["handler_id"]["data_backend"].parsed_field_definitions[-1].name == "c" # pyright: ignore ) + + +def test_get_config_for_model_type() -> None: + base_config = DTOConfig(rename_strategy="camel") + + class CustomDTO(DataclassDTO[T], Generic[T]): + @classmethod + def get_config_for_model_type(cls, config: DTOConfig, model_type: type[Any]) -> DTOConfig: + return dataclasses.replace(config, exclude={"foo"}) + + annotated_dto = CustomDTO[Model] + annotated_dto_with_config = CustomDTO[Annotated[Model, base_config]] + + class SubclassDTO(CustomDTO[Model]): + pass + + class SubclassDTOWithConfig(CustomDTO[Model]): + config = base_config + + assert annotated_dto.config.exclude == {"foo"} + assert SubclassDTO.config.exclude == {"foo"} + + # we expect existing configs to have been merged + assert annotated_dto_with_config.config.exclude == {"foo"} + assert annotated_dto_with_config.config.rename_strategy == "camel" + + assert SubclassDTOWithConfig.config.exclude == {"foo"} + assert SubclassDTOWithConfig.config.rename_strategy == "camel"