Skip to content

Commit

Permalink
fix: Enum OAS generation (#3518) (#3525)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: svlyubovsk <[email protected]>
Co-authored-by: Alc-Alc <[email protected]>
Co-authored-by: Janek Nouvertné <[email protected]>
Co-authored-by: Alc-Alc <alc@localhost>
  • Loading branch information
5 people authored Nov 29, 2024
1 parent 3510cab commit 35a9837
Show file tree
Hide file tree
Showing 18 changed files with 192 additions and 146 deletions.
2 changes: 1 addition & 1 deletion docs/examples/openapi/plugins/swagger_ui_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from litestar.openapi.plugins import SwaggerRenderPlugin

swagger_plugin = SwaggerRenderPlugin(version="5.1.3", path="/swagger")
swagger_plugin = SwaggerRenderPlugin(version="5.18.2", path="/swagger")
64 changes: 40 additions & 24 deletions litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from copy import copy
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum, EnumMeta
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -40,9 +40,7 @@
create_string_constrained_field_schema,
)
from litestar._openapi.schema_generation.utils import (
_should_create_enum_schema,
_should_create_literal_schema,
_type_or_first_not_none_inner_type,
get_json_schema_formatted_examples,
)
from litestar.datastructures import SecretBytes, SecretString, UploadFile
Expand Down Expand Up @@ -181,22 +179,6 @@ def _get_type_schema_name(field_definition: FieldDefinition) -> str:
return name


def create_enum_schema(annotation: EnumMeta, include_null: bool = False) -> Schema:
"""Create a schema instance for an enum.
Args:
annotation: An enum.
include_null: Whether to include null as a possible value.
Returns:
A schema instance.
"""
enum_values: list[str | int | None] = [v.value for v in annotation] # type: ignore[var-annotated]
if include_null and None not in enum_values:
enum_values.append(None)
return Schema(type=_types_in_list(enum_values), enum=enum_values)


def _iter_flat_literal_args(annotation: Any) -> Iterable[Any]:
"""Iterate over the flattened arguments of a Literal.
Expand Down Expand Up @@ -331,18 +313,20 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re
result = self.for_type_alias_type(field_definition)
elif plugin_for_annotation := self.get_plugin_for(field_definition):
result = self.for_plugin(field_definition, plugin_for_annotation)
elif _should_create_enum_schema(field_definition):
annotation = _type_or_first_not_none_inner_type(field_definition)
result = create_enum_schema(annotation, include_null=field_definition.is_optional)
elif _should_create_literal_schema(field_definition):
annotation = (
make_non_optional_union(field_definition.annotation)
if field_definition.is_optional
else field_definition.annotation
)
result = create_literal_schema(annotation, include_null=field_definition.is_optional)
result = create_literal_schema(
annotation,
include_null=field_definition.is_optional,
)
elif field_definition.is_optional:
result = self.for_optional_field(field_definition)
elif field_definition.is_enum:
result = self.for_enum_field(field_definition)
elif field_definition.is_union:
result = self.for_union_field(field_definition)
elif field_definition.is_type_var:
Expand Down Expand Up @@ -445,7 +429,7 @@ def for_optional_field(self, field_definition: FieldDefinition) -> Schema:
else:
result = [schema_or_reference]

return Schema(one_of=[Schema(type=OpenAPIType.NULL), *result])
return Schema(one_of=[*result, Schema(type=OpenAPIType.NULL)])

def for_union_field(self, field_definition: FieldDefinition) -> Schema:
"""Create a Schema for a union FieldDefinition.
Expand Down Expand Up @@ -569,6 +553,38 @@ def for_collection_constrained_field(self, field_definition: FieldDefinition) ->
# INFO: Removed because it was only for pydantic constrained collections
return schema

def for_enum_field(
self,
field_definition: FieldDefinition,
) -> Schema | Reference:
"""Create a schema instance for an enum.
Args:
field_definition: A signature field instance.
Returns:
A schema or reference instance.
"""
enum_type: None | OpenAPIType | list[OpenAPIType] = None
if issubclass(field_definition.annotation, Enum): # pragma: no branch
# This method is only called for enums, so this branch is always executed
if issubclass(field_definition.annotation, str): # StrEnum
enum_type = OpenAPIType.STRING
elif issubclass(field_definition.annotation, int): # IntEnum
enum_type = OpenAPIType.INTEGER

enum_values: list[Any] = [v.value for v in field_definition.annotation]
if enum_type is None:
enum_type = _types_in_list(enum_values)

schema = self.schema_registry.get_schema_for_field_definition(field_definition)
schema.type = enum_type
schema.enum = enum_values
schema.title = get_name(field_definition.annotation)
schema.description = field_definition.annotation.__doc__

return self.schema_registry.get_reference_for_field_definition(field_definition) or schema

def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schema | Reference:
if field.kwarg_definition and field.is_const and field.has_default and schema.const is None:
schema.const = field.default
Expand Down
49 changes: 1 addition & 48 deletions litestar/_openapi/schema_generation/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Mapping

from litestar.utils.helpers import get_name
Expand All @@ -11,53 +10,7 @@
from litestar.openapi.spec import Example
from litestar.typing import FieldDefinition

__all__ = (
"_should_create_enum_schema",
"_should_create_literal_schema",
"_type_or_first_not_none_inner_type",
)


def _type_or_first_not_none_inner_type(field_definition: FieldDefinition) -> Any:
"""Get the first inner type that is not None.
This is a narrow focussed utility to be used when we know that a field definition either represents
a single type, or a single type in a union with `None`, and we want the single type.
Args:
field_definition: A field definition instance.
Returns:
A field definition instance.
"""
if not field_definition.is_optional:
return field_definition.annotation
inner = next((t for t in field_definition.inner_types if not t.is_none_type), None)
if inner is None:
raise ValueError("Field definition has no inner type that is not None")
return inner.annotation


def _should_create_enum_schema(field_definition: FieldDefinition) -> bool:
"""Predicate to determine if we should create an enum schema for the field def, or not.
This returns true if the field definition is an enum, or if the field definition is a union
of an enum and ``None``.
When an annotation is ``SomeEnum | None`` we should create a schema for the enum that includes ``null``
in the enum values.
Args:
field_definition: A field definition instance.
Returns:
A boolean
"""
return field_definition.is_subclass_of(Enum) or (
field_definition.is_optional
and len(field_definition.args) == 2
and field_definition.has_inner_subclass_of(Enum)
)
__all__ = ("_should_create_literal_schema",)


def _should_create_literal_schema(field_definition: FieldDefinition) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion litestar/openapi/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class OpenAPIController(Controller):
"""Base styling of the html body."""
redoc_version: str = "next"
"""Redoc version to download from the CDN."""
swagger_ui_version: str = "5.1.3"
swagger_ui_version: str = "5.18.2"
"""SwaggerUI version to download from the CDN."""
stoplight_elements_version: str = "7.7.18"
"""StopLight Elements version to download from the CDN."""
Expand Down
2 changes: 1 addition & 1 deletion litestar/openapi/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ class SwaggerRenderPlugin(OpenAPIRenderPlugin):

def __init__(
self,
version: str = "5.1.3",
version: str = "5.18.2",
js_url: str | None = None,
css_url: str | None = None,
standalone_preset_js_url: str | None = None,
Expand Down
5 changes: 5 additions & 0 deletions litestar/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import abc
from copy import deepcopy
from dataclasses import dataclass, is_dataclass, replace
from enum import Enum
from inspect import Parameter, Signature
from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, TypeVar, cast

Expand Down Expand Up @@ -339,6 +340,10 @@ def is_typeddict_type(self) -> bool:

return is_typeddict(self.origin or self.annotation)

@property
def is_enum(self) -> bool:
return self.is_subclass_of(Enum)

@property
def type_(self) -> Any:
"""The type of the annotation with all the wrappers removed, including the generic types."""
Expand Down
23 changes: 14 additions & 9 deletions tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,14 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
assert concert_schema
assert concert_schema.to_schema() == {
"properties": {
"band_1": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
"band_2": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
"venue": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
"band_1": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
"band_2": {
"oneOf": [
{"type": "integer"},
{"type": "null"},
]
},
"venue": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
},
"required": [],
"title": "CreateConcertConcertRequestBody",
Expand All @@ -152,10 +157,10 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
assert record_studio_schema
assert record_studio_schema.to_schema() == {
"properties": {
"facilities": {"oneOf": [{"type": "null"}, {"type": "string"}]},
"facilities_b": {"oneOf": [{"type": "null"}, {"type": "string"}]},
"microphones": {"oneOf": [{"type": "null"}, {"items": {"type": "string"}, "type": "array"}]},
"id": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
"facilities": {"oneOf": [{"type": "string"}, {"type": "null"}]},
"facilities_b": {"oneOf": [{"type": "string"}, {"type": "null"}]},
"microphones": {"oneOf": [{"items": {"type": "string"}, "type": "array"}, {"type": "null"}]},
"id": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
},
"required": [],
"title": "RetrieveStudioRecordingStudioResponseBody",
Expand All @@ -166,8 +171,8 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
assert venue_schema
assert venue_schema.to_schema() == {
"properties": {
"id": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
"name": {"oneOf": [{"type": "null"}, {"type": "string"}]},
"id": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
"name": {"oneOf": [{"type": "string"}, {"type": "null"}]},
},
"required": [],
"title": "RetrieveVenuesVenueResponseBody",
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_openapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from litestar.openapi.spec.example import Example
from litestar.params import Parameter
from tests.models import DataclassPerson, DataclassPersonFactory, DataclassPet
from tests.unit.test_openapi.utils import Gender, PetException
from tests.unit.test_openapi.utils import Gender, LuckyNumber, PetException


class PartialDataclassPersonDTO(DataclassDTO[DataclassPerson]):
Expand Down Expand Up @@ -45,8 +45,9 @@ def get_persons(
from_date: Optional[Union[int, datetime, date]] = None,
to_date: Optional[Union[int, datetime, date]] = None,
gender: Optional[Union[Gender, List[Gender]]] = Parameter(
examples=[Example(value="M"), Example(value=["M", "O"])]
examples=[Example(value=Gender.MALE), Example(value=[Gender.MALE, Gender.OTHER])]
),
lucky_number: Optional[LuckyNumber] = Parameter(examples=[Example(value=LuckyNumber.SEVEN)]),
# header parameter
secret_header: str = Parameter(header="secret"),
# cookie parameter
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_openapi/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_default_redoc_cdn_urls(
def test_default_swagger_ui_cdn_urls(
person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig
) -> None:
default_swagger_ui_version = "5.1.3"
default_swagger_ui_version = "5.18.2"
default_swagger_bundles = [
f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{default_swagger_ui_version}/swagger-ui.css",
f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{default_swagger_ui_version}/swagger-ui-bundle.js",
Expand Down
41 changes: 24 additions & 17 deletions tests/unit/test_openapi/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from litestar.exceptions import ImproperlyConfiguredException
from litestar.handlers import HTTPRouteHandler
from litestar.openapi import OpenAPIConfig
from litestar.openapi.spec import Example, OpenAPI, Schema
from litestar.openapi.spec import Example, OpenAPI, Reference, Schema
from litestar.openapi.spec.enums import OpenAPIType
from litestar.params import Dependency, Parameter
from litestar.routes import BaseRoute
from litestar.testing import create_test_client
from litestar.utils import find_index
from tests.unit.test_openapi.utils import Gender, LuckyNumber

if TYPE_CHECKING:
from litestar.openapi.spec.parameter import Parameter as OpenAPIParameter
Expand Down Expand Up @@ -49,8 +50,10 @@ def test_create_parameters(person_controller: Type[Controller]) -> None:
ExampleFactory.seed_random(10)

parameters = _create_parameters(app=Litestar(route_handlers=[person_controller]), path="/{service_id}/person")
assert len(parameters) == 9
page, name, service_id, page_size, from_date, to_date, gender, secret_header, cookie_value = tuple(parameters)
assert len(parameters) == 10
page, name, service_id, page_size, from_date, to_date, gender, lucky_number, secret_header, cookie_value = tuple(
parameters
)

assert service_id.name == "service_id"
assert service_id.param_in == ParamType.PATH
Expand Down Expand Up @@ -104,23 +107,15 @@ def test_create_parameters(person_controller: Type[Controller]) -> None:
assert is_schema_value(gender.schema)
assert gender.schema == Schema(
one_of=[
Schema(type=OpenAPIType.NULL),
Schema(
type=OpenAPIType.STRING,
enum=["M", "F", "O", "A"],
examples=["M"],
),
Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_Gender"),
Schema(
type=OpenAPIType.ARRAY,
items=Schema(
type=OpenAPIType.STRING,
enum=["M", "F", "O", "A"],
examples=["F"],
),
examples=[["A"]],
items=Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_Gender"),
examples=[[Gender.MALE]],
),
Schema(type=OpenAPIType.NULL),
],
examples=["M", ["M", "O"]],
examples=[Gender.MALE, [Gender.MALE, Gender.OTHER]],
)
assert not gender.required

Expand All @@ -136,6 +131,18 @@ def test_create_parameters(person_controller: Type[Controller]) -> None:
assert cookie_value.required
assert cookie_value.schema.examples

assert lucky_number.param_in == ParamType.QUERY
assert lucky_number.name == "lucky_number"
assert is_schema_value(lucky_number.schema)
assert lucky_number.schema == Schema(
one_of=[
Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_LuckyNumber"),
Schema(type=OpenAPIType.NULL),
],
examples=[LuckyNumber.SEVEN],
)
assert not lucky_number.required


def test_deduplication_for_param_where_key_and_type_are_equal() -> None:
class BaseDep:
Expand Down Expand Up @@ -397,8 +404,8 @@ async def handler(
app = Litestar([handler])
assert app.openapi_schema.paths["/{path_param}"].get.parameters[0].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
assert app.openapi_schema.paths["/{path_param}"].get.parameters[1].schema.one_of == [ # type: ignore[index, union-attr]
Schema(type=OpenAPIType.NULL),
Schema(type=OpenAPIType.STRING),
Schema(type=OpenAPIType.NULL),
]
assert app.openapi_schema.paths["/{path_param}"].get.parameters[2].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
assert (
Expand Down
Loading

0 comments on commit 35a9837

Please sign in to comment.