diff --git a/litestar/_openapi/path_item.py b/litestar/_openapi/path_item.py index 74a04ce756..0081eead24 100644 --- a/litestar/_openapi/path_item.py +++ b/litestar/_openapi/path_item.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses from inspect import cleandoc from typing import TYPE_CHECKING @@ -8,6 +9,7 @@ from litestar._openapi.responses import create_responses_for_handler from litestar._openapi.utils import SEPARATORS_CLEANUP_PATTERN from litestar.enums import HttpMethod +from litestar.exceptions import ImproperlyConfiguredException from litestar.openapi.spec import Operation, PathItem from litestar.utils.helpers import unwrap_partial @@ -16,7 +18,7 @@ from litestar.handlers.http_handlers import HTTPRouteHandler from litestar.routes import HTTPRoute -__all__ = ("create_path_item_for_route",) +__all__ = ("create_path_item_for_route", "merge_path_item_operations") class PathItemFactory: @@ -135,3 +137,32 @@ def create_path_item_for_route(openapi_context: OpenAPIContext, route: HTTPRoute """ path_item_factory = PathItemFactory(openapi_context, route) return path_item_factory.create_path_item() + + +def merge_path_item_operations(source: PathItem, other: PathItem, for_path: str) -> PathItem: + """Merge operations from path items, creating a new path item that includes + operations from both. + """ + attrs_to_merge = {"get", "put", "post", "delete", "options", "head", "patch", "trace"} + fields = {f.name for f in dataclasses.fields(PathItem)} - attrs_to_merge + if any(getattr(source, attr) and getattr(other, attr) for attr in attrs_to_merge): + raise ValueError("Cannot merge operation for PathItem if operation is set on both items") + + if differing_values := [ + (value_a, value_b) for attr in fields if (value_a := getattr(source, attr)) != (value_b := getattr(other, attr)) + ]: + raise ImproperlyConfiguredException( + f"Conflicting OpenAPI path configuration for {for_path!r}. " + f"{', '.join(f'{a} != {b}' for a, b in differing_values)}" + ) + + return dataclasses.replace( + source, + get=source.get or other.get, + post=source.post or other.post, + patch=source.patch or other.patch, + put=source.put or other.put, + delete=source.delete or other.delete, + options=source.options or other.options, + trace=source.trace or other.trace, + ) diff --git a/litestar/_openapi/plugin.py b/litestar/_openapi/plugin.py index 9bdbdecebd..4c5c32236a 100644 --- a/litestar/_openapi/plugin.py +++ b/litestar/_openapi/plugin.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from litestar._openapi.datastructures import OpenAPIContext -from litestar._openapi.path_item import create_path_item_for_route +from litestar._openapi.path_item import create_path_item_for_route, merge_path_item_operations from litestar.exceptions import ImproperlyConfiguredException from litestar.plugins import InitPluginProtocol from litestar.plugins.base import ReceiveRoutePlugin @@ -13,7 +13,7 @@ from litestar.app import Litestar from litestar.config.app import AppConfig from litestar.openapi.config import OpenAPIConfig - from litestar.openapi.spec import OpenAPI + from litestar.openapi.spec import OpenAPI, PathItem from litestar.routes import BaseRoute @@ -41,10 +41,15 @@ def _build_openapi_schema(self) -> OpenAPI: openapi = openapi_config.to_openapi_schema() context = OpenAPIContext(openapi_config=openapi_config, plugins=self.app.plugins.openapi) - openapi.paths = { - route.path_format or "/": create_path_item_for_route(context, route) - for route in self.included_routes.values() - } + path_items: dict[str, PathItem] = {} + for route in self.included_routes.values(): + path = route.path_format or "/" + path_item = create_path_item_for_route(context, route) + if existing_path_item := path_items.get(path): + path_item = merge_path_item_operations(existing_path_item, path_item, for_path=path) + path_items[path] = path_item + + openapi.paths = path_items openapi.components.schemas = context.schema_registry.generate_components_schemas() return openapi diff --git a/tests/unit/test_openapi/test_path_item.py b/tests/unit/test_openapi/test_path_item.py index caede875b4..1f0dadc5b3 100644 --- a/tests/unit/test_openapi/test_path_item.py +++ b/tests/unit/test_openapi/test_path_item.py @@ -1,19 +1,21 @@ from __future__ import annotations +import dataclasses from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, cast +from unittest.mock import MagicMock import pytest from typing_extensions import TypeAlias -from litestar import Controller, Litestar, Request, Router, delete, get +from litestar import Controller, HttpMethod, Litestar, Request, Router, delete, get from litestar._openapi.datastructures import OpenAPIContext -from litestar._openapi.path_item import PathItemFactory +from litestar._openapi.path_item import PathItemFactory, merge_path_item_operations from litestar._openapi.utils import default_operation_id_creator from litestar.exceptions import ImproperlyConfiguredException from litestar.handlers.http_handlers import HTTPRouteHandler from litestar.openapi.config import OpenAPIConfig -from litestar.openapi.spec import Operation +from litestar.openapi.spec import Operation, PathItem from litestar.utils import find_index if TYPE_CHECKING: @@ -215,3 +217,30 @@ def handler_2() -> None: ... schema = factory.create_path_item() assert schema.get assert schema.delete is None + + +@pytest.mark.parametrize("method", HttpMethod) +def test_merge_path_item_operations_operation_set_on_both_raises(method: HttpMethod) -> None: + with pytest.raises(ValueError, match="Cannot merge operation"): + merge_path_item_operations( + PathItem(**{method.value.lower(): MagicMock()}), + PathItem(**{method.value.lower(): MagicMock()}), + for_path="/", + ) + + +@pytest.mark.parametrize( + "attr", + [ + f.name + for f in dataclasses.fields(PathItem) + if f.name.upper() + not in [ + *HttpMethod, + "TRACE", # remove once https://github.com/litestar-org/litestar/pull/3294 is merged + ] + ], +) +def test_merge_path_item_operation_differing_values_raises(attr: str) -> None: + with pytest.raises(ImproperlyConfiguredException, match="Conflicting OpenAPI path configuration for '/'"): + merge_path_item_operations(PathItem(), PathItem(**{attr: MagicMock()}), for_path="/") diff --git a/tests/unit/test_openapi/test_schema.py b/tests/unit/test_openapi/test_schema.py index a4aceee0cf..acae797c3c 100644 --- a/tests/unit/test_openapi/test_schema.py +++ b/tests/unit/test_openapi/test_schema.py @@ -23,14 +23,14 @@ from msgspec import Struct from typing_extensions import Annotated, TypeAlias -from litestar import Controller, MediaType, get +from litestar import Controller, MediaType, get, post from litestar._openapi.schema_generation.plugins import openapi_schema_plugins from litestar._openapi.schema_generation.schema import ( KWARG_DEFINITION_ATTRIBUTE_TO_OPENAPI_PROPERTY_MAP, SchemaCreator, ) from litestar._openapi.schema_generation.utils import _get_normalized_schema_key, _type_or_first_not_none_inner_type -from litestar.app import DEFAULT_OPENAPI_CONFIG +from litestar.app import DEFAULT_OPENAPI_CONFIG, Litestar from litestar.di import Provide from litestar.enums import ParamType from litestar.openapi.spec import ExternalDocumentation, OpenAPIType, Reference @@ -570,3 +570,19 @@ def test_default_not_provided_for_kwarg_but_for_field() -> None: schema = get_schema_for_field_definition(field_definition) assert schema.default == 10 + + +def test_routes_with_different_path_param_types_get_merged() -> None: + @get("/{param:int}") + async def get_handler(param: int) -> None: + pass + + @post("/{param:str}") + async def post_handler(param: str) -> None: + pass + + app = Litestar([get_handler, post_handler]) + assert app.openapi_schema.paths + paths = app.openapi_schema.paths["/{param}"] + assert paths.get is not None + assert paths.post is not None