diff --git a/litestar/app.py b/litestar/app.py index e50e808fa7..c16a9eb4d2 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -341,7 +341,7 @@ def __init__( opt=dict(opt or {}), parameters=parameters or {}, pdb_on_exception=pdb_on_exception, - plugins=[*(plugins or []), *self._get_default_plugins()], + plugins=self._get_default_plugins(list(plugins or [])), request_class=request_class, response_cache_config=response_cache_config or ResponseCacheConfig(), response_class=response_class, @@ -466,18 +466,28 @@ def serialization_plugins(self) -> list[SerializationPluginProtocol]: return list(self.plugins.serialization) @staticmethod - def _get_default_plugins() -> list[PluginProtocol]: - default_plugins: list[PluginProtocol] = [] + def _get_default_plugins(plugins: list[PluginProtocol] | None = None) -> list[PluginProtocol]: + if plugins is None: + plugins = [] with suppress(MissingDependencyException): - from litestar.contrib.pydantic import PydanticInitPlugin, PydanticSchemaPlugin - - default_plugins.extend((PydanticInitPlugin(), PydanticSchemaPlugin())) - + from litestar.contrib.pydantic import PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin + + pydantic_plugin_found = any(isinstance(plugin, PydanticPlugin) for plugin in plugins) + pydantic_init_plugin_found = any(isinstance(plugin, PydanticInitPlugin) for plugin in plugins) + pydantic_schema_plugin_found = any(isinstance(plugin, PydanticSchemaPlugin) for plugin in plugins) + if not pydantic_plugin_found and not pydantic_init_plugin_found and not pydantic_schema_plugin_found: + plugins.append(PydanticPlugin()) + elif not pydantic_plugin_found and pydantic_init_plugin_found and not pydantic_schema_plugin_found: + plugins.append(PydanticSchemaPlugin()) + elif not pydantic_plugin_found and not pydantic_init_plugin_found and pydantic_schema_plugin_found: + plugins.append(PydanticInitPlugin()) with suppress(MissingDependencyException): from litestar.contrib.attrs import AttrsSchemaPlugin - default_plugins.append(AttrsSchemaPlugin()) - return default_plugins + pre_configured = any(isinstance(plugin, AttrsSchemaPlugin) for plugin in plugins) + if not pre_configured: + plugins.append(AttrsSchemaPlugin()) + return plugins @property def debug(self) -> bool: diff --git a/litestar/contrib/pydantic/__init__.py b/litestar/contrib/pydantic/__init__.py index de217e56f7..3c3d0dcf9e 100644 --- a/litestar/contrib/pydantic/__init__.py +++ b/litestar/contrib/pydantic/__init__.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Any +from litestar.plugins import InitPluginProtocol + from .pydantic_dto_factory import PydanticDTO from .pydantic_init_plugin import PydanticInitPlugin from .pydantic_schema_plugin import PydanticSchemaPlugin @@ -9,7 +11,9 @@ if TYPE_CHECKING: import pydantic -__all__ = ("PydanticDTO", "PydanticInitPlugin", "PydanticSchemaPlugin") + from litestar.config.app import AppConfig + +__all__ = ("PydanticDTO", "PydanticInitPlugin", "PydanticSchemaPlugin", "PydanticPlugin") def _model_dump(model: pydantic.BaseModel, *, by_alias: bool = False) -> dict[str, Any]: @@ -20,5 +24,32 @@ def _model_dump(model: pydantic.BaseModel, *, by_alias: bool = False) -> dict[st ) -def _model_dump_json(model: pydantic.BaseModel) -> str: - return model.model_dump_json() if hasattr(model, "model_dump_json") else model.json() +def _model_dump_json(model: pydantic.BaseModel, by_alias: bool = False) -> str: + return ( + model.model_dump_json(by_alias=by_alias) if hasattr(model, "model_dump_json") else model.json(by_alias=by_alias) + ) + + +class PydanticPlugin(InitPluginProtocol): + """A plugin that provides Pydantic integration.""" + + __slots__ = ("prefer_alias",) + + def __init__(self, prefer_alias: bool = False) -> None: + """Initialize ``PydanticPlugin``. + + Args: + prefer_alias: OpenAPI and ``type_encoders`` will export by alias + """ + self.prefer_alias = prefer_alias + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Configure application for use with Pydantic. + + Args: + app_config: The :class:`AppConfig <.config.app.AppConfig>` instance. + """ + app_config.plugins.extend( + [PydanticInitPlugin(prefer_alias=self.prefer_alias), PydanticSchemaPlugin(prefer_alias=self.prefer_alias)] + ) + return app_config diff --git a/litestar/contrib/pydantic/config.py b/litestar/contrib/pydantic/config.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/litestar/contrib/pydantic/pydantic_init_plugin.py b/litestar/contrib/pydantic/pydantic_init_plugin.py index 4a57bf5aba..eb7b851f93 100644 --- a/litestar/contrib/pydantic/pydantic_init_plugin.py +++ b/litestar/contrib/pydantic/pydantic_init_plugin.py @@ -71,11 +71,16 @@ def _is_pydantic_uuid(value: Any) -> bool: # pragma: no cover class PydanticInitPlugin(InitPluginProtocol): + __slots__ = ("prefer_alias",) + + def __init__(self, prefer_alias: bool = False) -> None: + self.prefer_alias = prefer_alias + @classmethod - def encoders(cls) -> dict[Any, Callable[[Any], Any]]: + def encoders(cls, prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: if pydantic.VERSION.startswith("1"): # pragma: no cover - return {**_base_encoders, **cls._create_pydantic_v1_encoders()} - return {**_base_encoders, **cls._create_pydantic_v2_encoders()} + return {**_base_encoders, **cls._create_pydantic_v1_encoders(prefer_alias)} + return {**_base_encoders, **cls._create_pydantic_v2_encoders(prefer_alias)} @classmethod def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]: @@ -89,10 +94,10 @@ def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any] return decoders @staticmethod - def _create_pydantic_v1_encoders() -> dict[Any, Callable[[Any], Any]]: # pragma: no cover + def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover return { pydantic.BaseModel: lambda model: { - k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict().items() + k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict(by_alias=prefer_alias).items() }, pydantic.SecretField: str, pydantic.StrictBool: int, @@ -102,9 +107,9 @@ def _create_pydantic_v1_encoders() -> dict[Any, Callable[[Any], Any]]: # pragma } @staticmethod - def _create_pydantic_v2_encoders() -> dict[Any, Callable[[Any], Any]]: + def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: encoders: dict[Any, Callable[[Any], Any]] = { - pydantic.BaseModel: lambda model: model.model_dump(mode="json"), + pydantic.BaseModel: lambda model: model.model_dump(mode="json", by_alias=prefer_alias), pydantic.types.SecretStr: lambda val: "**********" if val else "", pydantic.types.SecretBytes: lambda val: "**********" if val else "", } @@ -117,6 +122,6 @@ def _create_pydantic_v2_encoders() -> dict[Any, Callable[[Any], Any]]: return encoders def on_app_init(self, app_config: AppConfig) -> AppConfig: - app_config.type_encoders = {**self.encoders(), **(app_config.type_encoders or {})} + app_config.type_encoders = {**self.encoders(self.prefer_alias), **(app_config.type_encoders or {})} app_config.type_decoders = [*self.decoders(), *(app_config.type_decoders or [])] return app_config diff --git a/litestar/contrib/pydantic/pydantic_schema_plugin.py b/litestar/contrib/pydantic/pydantic_schema_plugin.py index b644fcb570..eaf04131e2 100644 --- a/litestar/contrib/pydantic/pydantic_schema_plugin.py +++ b/litestar/contrib/pydantic/pydantic_schema_plugin.py @@ -132,6 +132,11 @@ class PydanticSchemaPlugin(OpenAPISchemaPluginProtocol): + __slots__ = ("prefer_alias",) + + def __init__(self, prefer_alias: bool = False) -> None: + self.prefer_alias = prefer_alias + @staticmethod def is_plugin_supported_type(value: Any) -> bool: return isinstance(value, _supported_types) or is_class_and_subclass(value, _supported_types) # type: ignore @@ -146,6 +151,8 @@ def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: S Returns: An :class:`OpenAPI ` instance. """ + if schema_creator.prefer_alias != self.prefer_alias: + schema_creator.prefer_alias = True if is_pydantic_model_class(field_definition.annotation): return self.for_pydantic_model(annotation=field_definition.annotation, schema_creator=schema_creator) return PYDANTIC_TYPE_MAP[field_definition.annotation] # pragma: no cover diff --git a/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py b/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py index d730d33f47..ba2c52d4b2 100644 --- a/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py +++ b/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py @@ -175,9 +175,13 @@ def test_serialization_of_model_instance(model: BaseModel) -> None: assert serializer(model) == _model_dump(model) -def test_pydantic_json_compatibility(model: BaseModel) -> None: - raw = _model_dump_json(model) - encoded_json = encode_json(model, serializer=get_serializer(PydanticInitPlugin.encoders())) +@pytest.mark.parametrize( + "prefer_alias", + [(False), (True)], +) +def test_pydantic_json_compatibility(model: BaseModel, prefer_alias: bool) -> None: + raw = _model_dump_json(model, by_alias=prefer_alias) + encoded_json = encode_json(model, serializer=get_serializer(PydanticInitPlugin.encoders(prefer_alias=prefer_alias))) raw_result = json.loads(raw) encoded_result = json.loads(encoded_json) @@ -203,17 +207,25 @@ def test_decode_json_raises_serialization_exception(model: BaseModel, decoder: A decoder(b"str") -def test_decode_json_typed(model: BaseModel) -> None: - dumped_model = _model_dump_json(model) +@pytest.mark.parametrize( + "prefer_alias", + [(False), (True)], +) +def test_decode_json_typed(model: BaseModel, prefer_alias: bool) -> None: + dumped_model = _model_dump_json(model, by_alias=prefer_alias) decoded_model = decode_json(value=dumped_model, target_type=Model, type_decoders=PydanticInitPlugin.decoders()) - assert _model_dump_json(decoded_model) == dumped_model + assert _model_dump_json(decoded_model, by_alias=prefer_alias) == dumped_model -def test_decode_msgpack_typed(model: BaseModel) -> None: - model_json = _model_dump_json(model) +@pytest.mark.parametrize( + "prefer_alias", + [(False), (True)], +) +def test_decode_msgpack_typed(model: BaseModel, prefer_alias: bool) -> None: + model_json = _model_dump_json(model, by_alias=prefer_alias) assert ( decode_msgpack( - encode_msgpack(model, serializer=get_serializer(PydanticInitPlugin.encoders())), + encode_msgpack(model, serializer=get_serializer(PydanticInitPlugin.encoders(prefer_alias=prefer_alias))), Model, type_decoders=PydanticInitPlugin.decoders(), ).json() diff --git a/tests/unit/test_openapi/test_config.py b/tests/unit/test_openapi/test_config.py index 24c9e7e8a8..83fb90c594 100644 --- a/tests/unit/test_openapi/test_config.py +++ b/tests/unit/test_openapi/test_config.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field from litestar import Litestar, get, post +from litestar.contrib.pydantic import PydanticPlugin from litestar.exceptions import ImproperlyConfiguredException from litestar.openapi.config import OpenAPIConfig from litestar.openapi.spec import Components, Example, OpenAPIHeader, OpenAPIType, Schema @@ -83,6 +84,45 @@ def handler(data: RequestWithAlias) -> ResponseWithAlias: assert response.json() == {response_key: "foo"} +def test_pydantic_plugin_override_by_alias() -> None: + class RequestWithAlias(BaseModel): + first: str = Field(alias="second") + + class ResponseWithAlias(BaseModel): + first: str = Field(alias="second") + + @post("/") + def handler(data: RequestWithAlias) -> ResponseWithAlias: + return ResponseWithAlias(second=data.first) + + app = Litestar( + route_handlers=[handler], + openapi_config=OpenAPIConfig(title="my title", version="1.0.0"), + plugins=[PydanticPlugin(prefer_alias=True)], + ) + + assert app.openapi_schema + schemas = app.openapi_schema.to_schema()["components"]["schemas"] + request_key = "second" + assert schemas["RequestWithAlias"] == { + "properties": {request_key: {"type": "string"}}, + "type": "object", + "required": [request_key], + "title": "RequestWithAlias", + } + response_key = "second" + assert schemas["ResponseWithAlias"] == { + "properties": {response_key: {"type": "string"}}, + "type": "object", + "required": [response_key], + "title": "ResponseWithAlias", + } + + with TestClient(app) as client: + response = client.post("/", json={request_key: "foo"}) + assert response.json() == {response_key: "foo"} + + def test_allows_customization_of_operation_id_creator() -> None: def operation_id_creator(handler: "HTTPRouteHandler", _: Any, __: Any) -> str: return handler.name or ""