From aba12c79a3ffff053cf475fbd92a6b45c7254fe0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Thu, 12 Sep 2024 09:02:47 +0200 Subject: [PATCH] feat: Support strings in `media_type` for `ResponseSpec` (#3729) --- litestar/_openapi/responses.py | 7 ++++++- litestar/openapi/datastructures.py | 2 +- tests/unit/test_openapi/test_responses.py | 8 ++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/litestar/_openapi/responses.py b/litestar/_openapi/responses.py index 6b0f312d3c..1be71838f2 100644 --- a/litestar/_openapi/responses.py +++ b/litestar/_openapi/responses.py @@ -252,7 +252,12 @@ def create_additional_responses(self) -> Iterator[tuple[str, OpenAPIResponse]]: content: dict[str, OpenAPIMediaType] | None if additional_response.data_container is not None: schema = schema_creator.for_field_definition(field_def) - content = {additional_response.media_type: OpenAPIMediaType(schema=schema, examples=examples)} + media_type = additional_response.media_type + content = { + get_enum_string_value(media_type) + if not isinstance(media_type, str) + else media_type: OpenAPIMediaType(schema=schema, examples=examples) + } else: content = None diff --git a/litestar/openapi/datastructures.py b/litestar/openapi/datastructures.py index 5796a48d4c..cc0981c6c3 100644 --- a/litestar/openapi/datastructures.py +++ b/litestar/openapi/datastructures.py @@ -23,7 +23,7 @@ class ResponseSpec: """Generate examples for the response content.""" description: str = field(default="Additional response") """A description of the response.""" - media_type: MediaType = field(default=MediaType.JSON) + media_type: MediaType | str = field(default=MediaType.JSON) """Response media type.""" examples: list[Example] | None = field(default=None) """A list of Example models.""" diff --git a/tests/unit/test_openapi/test_responses.py b/tests/unit/test_openapi/test_responses.py index 148fe28a71..b8e0baea48 100644 --- a/tests/unit/test_openapi/test_responses.py +++ b/tests/unit/test_openapi/test_responses.py @@ -361,6 +361,7 @@ class UnknownError(TypedDict): 401: ResponseSpec(data_container=AuthenticationError, description="Authentication error"), 500: ResponseSpec(data_container=ServerError, generate_examples=False, media_type=MediaType.TEXT), 505: ResponseSpec(data_container=UnknownError), + 900: ResponseSpec(data_container=UnknownError, media_type="application/vnd.custom"), } ) def handler() -> DataclassPerson: @@ -398,6 +399,13 @@ def handler() -> DataclassPerson: assert third_response[0] == "505" assert third_response[1].description == "Additional response" + fourth_response = next(responses) + assert fourth_response[0] == "900" + assert fourth_response[1].description == "Additional response" + custom_media_type_content = fourth_response[1].content.get("application/vnd.custom") # type: ignore[union-attr] + assert custom_media_type_content + assert isinstance(custom_media_type_content, OpenAPIMediaType) + with pytest.raises(StopIteration): next(responses)