From 35a9837aa70d13b447d1d14e0c74f86f1c11a4d9 Mon Sep 17 00:00:00 2001 From: "Stanislav Lyu." Date: Fri, 29 Nov 2024 18:11:12 +0300 Subject: [PATCH 1/2] fix: Enum OAS generation (#3518) (#3525) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --------- Co-authored-by: svlyubovsk Co-authored-by: Alc-Alc <45509143+Alc-Alc@users.noreply.github.com> Co-authored-by: Janek Nouvertné Co-authored-by: Alc-Alc --- .../openapi/plugins/swagger_ui_config.py | 2 +- litestar/_openapi/schema_generation/schema.py | 64 ++++++---- litestar/_openapi/schema_generation/utils.py | 49 +------- litestar/openapi/controller.py | 2 +- litestar/openapi/plugins.py | 2 +- litestar/typing.py | 5 + .../test_piccolo_orm/test_piccolo_orm_dto.py | 23 ++-- tests/unit/test_openapi/conftest.py | 5 +- tests/unit/test_openapi/test_endpoints.py | 2 +- tests/unit/test_openapi/test_parameters.py | 41 ++++--- tests/unit/test_openapi/test_schema.py | 113 +++++++++++++----- .../unit/test_openapi/test_spec_generation.py | 12 +- .../test_converter.py | 1 + tests/unit/test_openapi/utils.py | 5 + .../test_attrs/test_schema_plugin.py | 2 +- .../test_attrs/test_schema_spec_generation.py | 4 +- .../test_pydantic/test_openapi.py | 4 +- .../test_pydantic/test_schema_plugin.py | 2 +- 18 files changed, 192 insertions(+), 146 deletions(-) diff --git a/docs/examples/openapi/plugins/swagger_ui_config.py b/docs/examples/openapi/plugins/swagger_ui_config.py index 94abdf049a..7280f24e0a 100644 --- a/docs/examples/openapi/plugins/swagger_ui_config.py +++ b/docs/examples/openapi/plugins/swagger_ui_config.py @@ -1,3 +1,3 @@ from litestar.openapi.plugins import SwaggerRenderPlugin -swagger_plugin = SwaggerRenderPlugin(version="5.1.3", path="/swagger") +swagger_plugin = SwaggerRenderPlugin(version="5.18.2", path="/swagger") diff --git a/litestar/_openapi/schema_generation/schema.py b/litestar/_openapi/schema_generation/schema.py index 743f30afcd..9464599108 100644 --- a/litestar/_openapi/schema_generation/schema.py +++ b/litestar/_openapi/schema_generation/schema.py @@ -4,7 +4,7 @@ from copy import copy from datetime import date, datetime, time, timedelta from decimal import Decimal -from enum import Enum, EnumMeta +from enum import Enum from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network from pathlib import Path from typing import ( @@ -40,9 +40,7 @@ create_string_constrained_field_schema, ) from litestar._openapi.schema_generation.utils import ( - _should_create_enum_schema, _should_create_literal_schema, - _type_or_first_not_none_inner_type, get_json_schema_formatted_examples, ) from litestar.datastructures import SecretBytes, SecretString, UploadFile @@ -181,22 +179,6 @@ def _get_type_schema_name(field_definition: FieldDefinition) -> str: return name -def create_enum_schema(annotation: EnumMeta, include_null: bool = False) -> Schema: - """Create a schema instance for an enum. - - Args: - annotation: An enum. - include_null: Whether to include null as a possible value. - - Returns: - A schema instance. - """ - enum_values: list[str | int | None] = [v.value for v in annotation] # type: ignore[var-annotated] - if include_null and None not in enum_values: - enum_values.append(None) - return Schema(type=_types_in_list(enum_values), enum=enum_values) - - def _iter_flat_literal_args(annotation: Any) -> Iterable[Any]: """Iterate over the flattened arguments of a Literal. @@ -331,18 +313,20 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re result = self.for_type_alias_type(field_definition) elif plugin_for_annotation := self.get_plugin_for(field_definition): result = self.for_plugin(field_definition, plugin_for_annotation) - elif _should_create_enum_schema(field_definition): - annotation = _type_or_first_not_none_inner_type(field_definition) - result = create_enum_schema(annotation, include_null=field_definition.is_optional) elif _should_create_literal_schema(field_definition): annotation = ( make_non_optional_union(field_definition.annotation) if field_definition.is_optional else field_definition.annotation ) - result = create_literal_schema(annotation, include_null=field_definition.is_optional) + result = create_literal_schema( + annotation, + include_null=field_definition.is_optional, + ) elif field_definition.is_optional: result = self.for_optional_field(field_definition) + elif field_definition.is_enum: + result = self.for_enum_field(field_definition) elif field_definition.is_union: result = self.for_union_field(field_definition) elif field_definition.is_type_var: @@ -445,7 +429,7 @@ def for_optional_field(self, field_definition: FieldDefinition) -> Schema: else: result = [schema_or_reference] - return Schema(one_of=[Schema(type=OpenAPIType.NULL), *result]) + return Schema(one_of=[*result, Schema(type=OpenAPIType.NULL)]) def for_union_field(self, field_definition: FieldDefinition) -> Schema: """Create a Schema for a union FieldDefinition. @@ -569,6 +553,38 @@ def for_collection_constrained_field(self, field_definition: FieldDefinition) -> # INFO: Removed because it was only for pydantic constrained collections return schema + def for_enum_field( + self, + field_definition: FieldDefinition, + ) -> Schema | Reference: + """Create a schema instance for an enum. + + Args: + field_definition: A signature field instance. + + Returns: + A schema or reference instance. + """ + enum_type: None | OpenAPIType | list[OpenAPIType] = None + if issubclass(field_definition.annotation, Enum): # pragma: no branch + # This method is only called for enums, so this branch is always executed + if issubclass(field_definition.annotation, str): # StrEnum + enum_type = OpenAPIType.STRING + elif issubclass(field_definition.annotation, int): # IntEnum + enum_type = OpenAPIType.INTEGER + + enum_values: list[Any] = [v.value for v in field_definition.annotation] + if enum_type is None: + enum_type = _types_in_list(enum_values) + + schema = self.schema_registry.get_schema_for_field_definition(field_definition) + schema.type = enum_type + schema.enum = enum_values + schema.title = get_name(field_definition.annotation) + schema.description = field_definition.annotation.__doc__ + + return self.schema_registry.get_reference_for_field_definition(field_definition) or schema + def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schema | Reference: if field.kwarg_definition and field.is_const and field.has_default and schema.const is None: schema.const = field.default diff --git a/litestar/_openapi/schema_generation/utils.py b/litestar/_openapi/schema_generation/utils.py index cfcb976c82..175a519aea 100644 --- a/litestar/_openapi/schema_generation/utils.py +++ b/litestar/_openapi/schema_generation/utils.py @@ -1,6 +1,5 @@ from __future__ import annotations -from enum import Enum from typing import TYPE_CHECKING, Any, Mapping from litestar.utils.helpers import get_name @@ -11,53 +10,7 @@ from litestar.openapi.spec import Example from litestar.typing import FieldDefinition -__all__ = ( - "_should_create_enum_schema", - "_should_create_literal_schema", - "_type_or_first_not_none_inner_type", -) - - -def _type_or_first_not_none_inner_type(field_definition: FieldDefinition) -> Any: - """Get the first inner type that is not None. - - This is a narrow focussed utility to be used when we know that a field definition either represents - a single type, or a single type in a union with `None`, and we want the single type. - - Args: - field_definition: A field definition instance. - - Returns: - A field definition instance. - """ - if not field_definition.is_optional: - return field_definition.annotation - inner = next((t for t in field_definition.inner_types if not t.is_none_type), None) - if inner is None: - raise ValueError("Field definition has no inner type that is not None") - return inner.annotation - - -def _should_create_enum_schema(field_definition: FieldDefinition) -> bool: - """Predicate to determine if we should create an enum schema for the field def, or not. - - This returns true if the field definition is an enum, or if the field definition is a union - of an enum and ``None``. - - When an annotation is ``SomeEnum | None`` we should create a schema for the enum that includes ``null`` - in the enum values. - - Args: - field_definition: A field definition instance. - - Returns: - A boolean - """ - return field_definition.is_subclass_of(Enum) or ( - field_definition.is_optional - and len(field_definition.args) == 2 - and field_definition.has_inner_subclass_of(Enum) - ) +__all__ = ("_should_create_literal_schema",) def _should_create_literal_schema(field_definition: FieldDefinition) -> bool: diff --git a/litestar/openapi/controller.py b/litestar/openapi/controller.py index 61f1148b1d..ca5c7e56ed 100644 --- a/litestar/openapi/controller.py +++ b/litestar/openapi/controller.py @@ -36,7 +36,7 @@ class OpenAPIController(Controller): """Base styling of the html body.""" redoc_version: str = "next" """Redoc version to download from the CDN.""" - swagger_ui_version: str = "5.1.3" + swagger_ui_version: str = "5.18.2" """SwaggerUI version to download from the CDN.""" stoplight_elements_version: str = "7.7.18" """StopLight Elements version to download from the CDN.""" diff --git a/litestar/openapi/plugins.py b/litestar/openapi/plugins.py index d3503c877e..a006381737 100644 --- a/litestar/openapi/plugins.py +++ b/litestar/openapi/plugins.py @@ -499,7 +499,7 @@ class SwaggerRenderPlugin(OpenAPIRenderPlugin): def __init__( self, - version: str = "5.1.3", + version: str = "5.18.2", js_url: str | None = None, css_url: str | None = None, standalone_preset_js_url: str | None = None, diff --git a/litestar/typing.py b/litestar/typing.py index 191c76b5a9..37dec75825 100644 --- a/litestar/typing.py +++ b/litestar/typing.py @@ -5,6 +5,7 @@ from collections import abc from copy import deepcopy from dataclasses import dataclass, is_dataclass, replace +from enum import Enum from inspect import Parameter, Signature from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, TypeVar, cast @@ -339,6 +340,10 @@ def is_typeddict_type(self) -> bool: return is_typeddict(self.origin or self.annotation) + @property + def is_enum(self) -> bool: + return self.is_subclass_of(Enum) + @property def type_(self) -> Any: """The type of the annotation with all the wrappers removed, including the generic types.""" diff --git a/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py b/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py index c382ae1ba4..7363573673 100644 --- a/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py +++ b/tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py @@ -139,9 +139,14 @@ def test_piccolo_dto_openapi_spec_generation() -> None: assert concert_schema assert concert_schema.to_schema() == { "properties": { - "band_1": {"oneOf": [{"type": "null"}, {"type": "integer"}]}, - "band_2": {"oneOf": [{"type": "null"}, {"type": "integer"}]}, - "venue": {"oneOf": [{"type": "null"}, {"type": "integer"}]}, + "band_1": {"oneOf": [{"type": "integer"}, {"type": "null"}]}, + "band_2": { + "oneOf": [ + {"type": "integer"}, + {"type": "null"}, + ] + }, + "venue": {"oneOf": [{"type": "integer"}, {"type": "null"}]}, }, "required": [], "title": "CreateConcertConcertRequestBody", @@ -152,10 +157,10 @@ def test_piccolo_dto_openapi_spec_generation() -> None: assert record_studio_schema assert record_studio_schema.to_schema() == { "properties": { - "facilities": {"oneOf": [{"type": "null"}, {"type": "string"}]}, - "facilities_b": {"oneOf": [{"type": "null"}, {"type": "string"}]}, - "microphones": {"oneOf": [{"type": "null"}, {"items": {"type": "string"}, "type": "array"}]}, - "id": {"oneOf": [{"type": "null"}, {"type": "integer"}]}, + "facilities": {"oneOf": [{"type": "string"}, {"type": "null"}]}, + "facilities_b": {"oneOf": [{"type": "string"}, {"type": "null"}]}, + "microphones": {"oneOf": [{"items": {"type": "string"}, "type": "array"}, {"type": "null"}]}, + "id": {"oneOf": [{"type": "integer"}, {"type": "null"}]}, }, "required": [], "title": "RetrieveStudioRecordingStudioResponseBody", @@ -166,8 +171,8 @@ def test_piccolo_dto_openapi_spec_generation() -> None: assert venue_schema assert venue_schema.to_schema() == { "properties": { - "id": {"oneOf": [{"type": "null"}, {"type": "integer"}]}, - "name": {"oneOf": [{"type": "null"}, {"type": "string"}]}, + "id": {"oneOf": [{"type": "integer"}, {"type": "null"}]}, + "name": {"oneOf": [{"type": "string"}, {"type": "null"}]}, }, "required": [], "title": "RetrieveVenuesVenueResponseBody", diff --git a/tests/unit/test_openapi/conftest.py b/tests/unit/test_openapi/conftest.py index 20dfeb6c7a..6ec46f8d86 100644 --- a/tests/unit/test_openapi/conftest.py +++ b/tests/unit/test_openapi/conftest.py @@ -10,7 +10,7 @@ from litestar.openapi.spec.example import Example from litestar.params import Parameter from tests.models import DataclassPerson, DataclassPersonFactory, DataclassPet -from tests.unit.test_openapi.utils import Gender, PetException +from tests.unit.test_openapi.utils import Gender, LuckyNumber, PetException class PartialDataclassPersonDTO(DataclassDTO[DataclassPerson]): @@ -45,8 +45,9 @@ def get_persons( from_date: Optional[Union[int, datetime, date]] = None, to_date: Optional[Union[int, datetime, date]] = None, gender: Optional[Union[Gender, List[Gender]]] = Parameter( - examples=[Example(value="M"), Example(value=["M", "O"])] + examples=[Example(value=Gender.MALE), Example(value=[Gender.MALE, Gender.OTHER])] ), + lucky_number: Optional[LuckyNumber] = Parameter(examples=[Example(value=LuckyNumber.SEVEN)]), # header parameter secret_header: str = Parameter(header="secret"), # cookie parameter diff --git a/tests/unit/test_openapi/test_endpoints.py b/tests/unit/test_openapi/test_endpoints.py index 7ad694fb70..6e0230a7e3 100644 --- a/tests/unit/test_openapi/test_endpoints.py +++ b/tests/unit/test_openapi/test_endpoints.py @@ -39,7 +39,7 @@ def test_default_redoc_cdn_urls( def test_default_swagger_ui_cdn_urls( person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig ) -> None: - default_swagger_ui_version = "5.1.3" + default_swagger_ui_version = "5.18.2" default_swagger_bundles = [ f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{default_swagger_ui_version}/swagger-ui.css", f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{default_swagger_ui_version}/swagger-ui-bundle.js", diff --git a/tests/unit/test_openapi/test_parameters.py b/tests/unit/test_openapi/test_parameters.py index 08505d0672..8ef236e12b 100644 --- a/tests/unit/test_openapi/test_parameters.py +++ b/tests/unit/test_openapi/test_parameters.py @@ -15,12 +15,13 @@ from litestar.exceptions import ImproperlyConfiguredException from litestar.handlers import HTTPRouteHandler from litestar.openapi import OpenAPIConfig -from litestar.openapi.spec import Example, OpenAPI, Schema +from litestar.openapi.spec import Example, OpenAPI, Reference, Schema from litestar.openapi.spec.enums import OpenAPIType from litestar.params import Dependency, Parameter from litestar.routes import BaseRoute from litestar.testing import create_test_client from litestar.utils import find_index +from tests.unit.test_openapi.utils import Gender, LuckyNumber if TYPE_CHECKING: from litestar.openapi.spec.parameter import Parameter as OpenAPIParameter @@ -49,8 +50,10 @@ def test_create_parameters(person_controller: Type[Controller]) -> None: ExampleFactory.seed_random(10) parameters = _create_parameters(app=Litestar(route_handlers=[person_controller]), path="/{service_id}/person") - assert len(parameters) == 9 - page, name, service_id, page_size, from_date, to_date, gender, secret_header, cookie_value = tuple(parameters) + assert len(parameters) == 10 + page, name, service_id, page_size, from_date, to_date, gender, lucky_number, secret_header, cookie_value = tuple( + parameters + ) assert service_id.name == "service_id" assert service_id.param_in == ParamType.PATH @@ -104,23 +107,15 @@ def test_create_parameters(person_controller: Type[Controller]) -> None: assert is_schema_value(gender.schema) assert gender.schema == Schema( one_of=[ - Schema(type=OpenAPIType.NULL), - Schema( - type=OpenAPIType.STRING, - enum=["M", "F", "O", "A"], - examples=["M"], - ), + Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_Gender"), Schema( type=OpenAPIType.ARRAY, - items=Schema( - type=OpenAPIType.STRING, - enum=["M", "F", "O", "A"], - examples=["F"], - ), - examples=[["A"]], + items=Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_Gender"), + examples=[[Gender.MALE]], ), + Schema(type=OpenAPIType.NULL), ], - examples=["M", ["M", "O"]], + examples=[Gender.MALE, [Gender.MALE, Gender.OTHER]], ) assert not gender.required @@ -136,6 +131,18 @@ def test_create_parameters(person_controller: Type[Controller]) -> None: assert cookie_value.required assert cookie_value.schema.examples + assert lucky_number.param_in == ParamType.QUERY + assert lucky_number.name == "lucky_number" + assert is_schema_value(lucky_number.schema) + assert lucky_number.schema == Schema( + one_of=[ + Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_LuckyNumber"), + Schema(type=OpenAPIType.NULL), + ], + examples=[LuckyNumber.SEVEN], + ) + assert not lucky_number.required + def test_deduplication_for_param_where_key_and_type_are_equal() -> None: class BaseDep: @@ -397,8 +404,8 @@ async def handler( app = Litestar([handler]) assert app.openapi_schema.paths["/{path_param}"].get.parameters[0].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr] assert app.openapi_schema.paths["/{path_param}"].get.parameters[1].schema.one_of == [ # type: ignore[index, union-attr] - Schema(type=OpenAPIType.NULL), Schema(type=OpenAPIType.STRING), + Schema(type=OpenAPIType.NULL), ] assert app.openapi_schema.paths["/{path_param}"].get.parameters[2].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr] assert ( diff --git a/tests/unit/test_openapi/test_schema.py b/tests/unit/test_openapi/test_schema.py index 3e15b9f51d..4606687027 100644 --- a/tests/unit/test_openapi/test_schema.py +++ b/tests/unit/test_openapi/test_schema.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from datetime import date, datetime, timezone from enum import Enum, auto -from typing import ( # type: ignore[attr-defined] +from typing import ( TYPE_CHECKING, Any, Dict, @@ -13,8 +13,7 @@ Tuple, TypedDict, TypeVar, - Union, - _GenericAlias, # pyright: ignore + Union, # pyright: ignore ) import annotated_types @@ -29,7 +28,6 @@ KWARG_DEFINITION_ATTRIBUTE_TO_OPENAPI_PROPERTY_MAP, SchemaCreator, ) -from litestar._openapi.schema_generation.utils import _type_or_first_not_none_inner_type from litestar.app import DEFAULT_OPENAPI_CONFIG, Litestar from litestar.di import Provide from litestar.enums import ParamType @@ -41,7 +39,6 @@ from litestar.pagination import ClassicPagination, CursorPagination, OffsetPagination from litestar.params import KwargDefinition, Parameter, ParameterKwarg from litestar.testing import create_test_client -from litestar.types.builtin_types import NoneType from litestar.typing import FieldDefinition from litestar.utils.helpers import get_name from tests.helpers import get_schema_for_field_definition @@ -392,7 +389,7 @@ class TypedDictGeneric(TypedDict, Generic[T]): @pytest.mark.parametrize("cls", annotations) def test_schema_generation_with_generic_classes(cls: Any) -> None: expected_foo_schema = Schema(type=OpenAPIType.INTEGER) - expected_optional_foo_schema = Schema(one_of=[Schema(type=OpenAPIType.NULL), Schema(type=OpenAPIType.INTEGER)]) + expected_optional_foo_schema = Schema(one_of=[Schema(type=OpenAPIType.INTEGER), Schema(type=OpenAPIType.NULL)]) properties = get_schema_for_field_definition( FieldDefinition.from_kwarg(name=get_name(cls), annotation=cls) @@ -446,7 +443,7 @@ def test_schema_generation_with_generic_classes_constrained() -> None: ) def test_schema_generation_with_pagination(annotation: Any) -> None: expected_foo_schema = Schema(type=OpenAPIType.INTEGER) - expected_optional_foo_schema = Schema(one_of=[Schema(type=OpenAPIType.NULL), Schema(type=OpenAPIType.INTEGER)]) + expected_optional_foo_schema = Schema(one_of=[Schema(type=OpenAPIType.INTEGER), Schema(type=OpenAPIType.NULL)]) properties = get_schema_for_field_definition(FieldDefinition.from_annotation(annotation).inner_types[-1]).properties @@ -473,13 +470,87 @@ def test_schema_tuple_with_union() -> None: def test_optional_enum() -> None: class Foo(Enum): + A = 1 + B = "b" + + creator = SchemaCreator(plugins=openapi_schema_plugins) + schema = creator.for_field_definition(FieldDefinition.from_annotation(Optional[Foo])) + assert isinstance(schema, Schema) + assert schema.type is None + assert schema.one_of is not None + null_schema = schema.one_of[1] + assert isinstance(null_schema, Schema) + assert null_schema.type is not None + assert null_schema.type is OpenAPIType.NULL + enum_ref = schema.one_of[0] + assert isinstance(enum_ref, Reference) + assert enum_ref.ref == "#/components/schemas/tests_unit_test_openapi_test_schema_test_optional_enum.Foo" + enum_schema = creator.schema_registry.from_reference(enum_ref).schema + assert enum_schema.type + assert set(enum_schema.type) == {OpenAPIType.INTEGER, OpenAPIType.STRING} + assert enum_schema.enum + assert enum_schema.enum[0] == 1 + assert enum_schema.enum[1] == "b" + + +def test_optional_str_specified_enum() -> None: + class StringEnum(str, Enum): + A = "a" + B = "b" + + creator = SchemaCreator(plugins=openapi_schema_plugins) + schema = creator.for_field_definition(FieldDefinition.from_annotation(Optional[StringEnum])) + assert isinstance(schema, Schema) + assert schema.type is None + assert schema.one_of is not None + + enum_ref = schema.one_of[0] + assert isinstance(enum_ref, Reference) + assert ( + enum_ref.ref + == "#/components/schemas/tests_unit_test_openapi_test_schema_test_optional_str_specified_enum.StringEnum" + ) + enum_schema = creator.schema_registry.from_reference(enum_ref).schema + assert enum_schema.type + assert enum_schema.type == OpenAPIType.STRING + assert enum_schema.enum + assert enum_schema.enum[0] == "a" + assert enum_schema.enum[1] == "b" + + null_schema = schema.one_of[1] + assert isinstance(null_schema, Schema) + assert null_schema.type is not None + assert null_schema.type is OpenAPIType.NULL + + +def test_optional_int_specified_enum() -> None: + class IntEnum(int, Enum): A = 1 B = 2 - schema = get_schema_for_field_definition(FieldDefinition.from_annotation(Optional[Foo])) - assert schema.type is not None - assert set(schema.type) == {OpenAPIType.INTEGER, OpenAPIType.NULL} - assert schema.enum == [1, 2, None] + creator = SchemaCreator(plugins=openapi_schema_plugins) + schema = creator.for_field_definition(FieldDefinition.from_annotation(Optional[IntEnum])) + assert isinstance(schema, Schema) + assert schema.type is None + assert schema.one_of is not None + + enum_ref = schema.one_of[0] + assert isinstance(enum_ref, Reference) + assert ( + enum_ref.ref + == "#/components/schemas/tests_unit_test_openapi_test_schema_test_optional_int_specified_enum.IntEnum" + ) + enum_schema = creator.schema_registry.from_reference(enum_ref).schema + assert enum_schema.type + assert enum_schema.type == OpenAPIType.INTEGER + assert enum_schema.enum + assert enum_schema.enum[0] == 1 + assert enum_schema.enum[1] == 2 + + null_schema = schema.one_of[1] + assert isinstance(null_schema, Schema) + assert null_schema.type is not None + assert null_schema.type is OpenAPIType.NULL def test_optional_literal() -> None: @@ -489,24 +560,6 @@ def test_optional_literal() -> None: assert schema.enum == [1, None] -@pytest.mark.parametrize( - ("in_type", "out_type"), - [ - (FieldDefinition.from_annotation(Optional[int]), int), - (FieldDefinition.from_annotation(Union[None, int]), int), - (FieldDefinition.from_annotation(int), int), - # hack to create a union of NoneType, NoneType to hit a branch for coverage - (FieldDefinition.from_annotation(_GenericAlias(Union, (NoneType, NoneType))), ValueError), - ], -) -def test_type_or_first_not_none_inner_type_utility(in_type: Any, out_type: Any) -> None: - if out_type is ValueError: - with pytest.raises(out_type): - _type_or_first_not_none_inner_type(in_type) - else: - assert _type_or_first_not_none_inner_type(in_type) == out_type - - def test_not_generating_examples_property() -> None: with_examples = SchemaCreator(generate_examples=True) without_examples = with_examples.not_generating_examples @@ -576,9 +629,9 @@ class ModelB(base_type): # type: ignore[no-redef, misc] FieldDefinition.from_kwarg(name="Lookup", annotation=Union[ModelA, ModelB, None]) ) assert schema.one_of == [ - Schema(type=OpenAPIType.NULL), Reference(ref="#/components/schemas/tests_unit_test_openapi_test_schema_test_type_union_with_none.ModelA"), Reference("#/components/schemas/tests_unit_test_openapi_test_schema_test_type_union_with_none.ModelB"), + Schema(type=OpenAPIType.NULL), ] diff --git a/tests/unit/test_openapi/test_spec_generation.py b/tests/unit/test_openapi/test_spec_generation.py index f64bd3f569..9601e7a93a 100644 --- a/tests/unit/test_openapi/test_spec_generation.py +++ b/tests/unit/test_openapi/test_spec_generation.py @@ -27,7 +27,7 @@ def handler(data: cls) -> cls: "first_name": {"type": "string"}, "last_name": {"type": "string"}, "id": {"type": "string"}, - "optional": {"oneOf": [{"type": "null"}, {"type": "string"}]}, + "optional": {"oneOf": [{"type": "string"}, {"type": "null"}]}, "complex": { "type": "object", "additionalProperties": { @@ -37,11 +37,11 @@ def handler(data: cls) -> cls: }, "pets": { "oneOf": [ - {"type": "null"}, { "items": {"$ref": "#/components/schemas/DataclassPet"}, "type": "array", }, + {"type": "null"}, ] }, }, @@ -189,8 +189,8 @@ def test_recursive_schema_generation( "properties": { "a": {"$ref": "#/components/schemas/A"}, "b": {"$ref": "#/components/schemas/B"}, - "opt_a": {"oneOf": [{"type": "null"}, {"$ref": "#/components/schemas/A"}]}, - "opt_b": {"oneOf": [{"type": "null"}, {"$ref": "#/components/schemas/B"}]}, + "opt_a": {"oneOf": [{"$ref": "#/components/schemas/A"}, {"type": "null"}]}, + "opt_b": {"oneOf": [{"$ref": "#/components/schemas/B"}, {"type": "null"}]}, "list_a": {"items": {"$ref": "#/components/schemas/A"}, "type": "array"}, "list_b": {"items": {"$ref": "#/components/schemas/B"}, "type": "array"}, }, @@ -202,8 +202,8 @@ def test_recursive_schema_generation( "properties": { "a": {"$ref": "#/components/schemas/A"}, "b": {"$ref": "#/components/schemas/B"}, - "opt_a": {"oneOf": [{"type": "null"}, {"$ref": "#/components/schemas/A"}]}, - "opt_b": {"oneOf": [{"type": "null"}, {"$ref": "#/components/schemas/B"}]}, + "opt_a": {"oneOf": [{"$ref": "#/components/schemas/A"}, {"type": "null"}]}, + "opt_b": {"oneOf": [{"$ref": "#/components/schemas/B"}, {"type": "null"}]}, "list_a": {"items": {"$ref": "#/components/schemas/A"}, "type": "array"}, "list_b": {"items": {"$ref": "#/components/schemas/B"}, "type": "array"}, }, diff --git a/tests/unit/test_openapi/test_typescript_converter/test_converter.py b/tests/unit/test_openapi/test_typescript_converter/test_converter.py index 0241f7152d..eb4d42054d 100644 --- a/tests/unit/test_openapi/test_typescript_converter/test_converter.py +++ b/tests/unit/test_openapi/test_typescript_converter/test_converter.py @@ -334,6 +334,7 @@ def test_openapi_to_typescript_converter(person_controller: Type[Controller], pe export interface QueryParameters { from_date?: null | number | string | string; gender?: "A" | "F" | "M" | "O" | ("A" | "F" | "M" | "O")[] | null; + lucky_number?: 2 | 7 | null; name?: null | string | string[]; page: number; pageSize: number; diff --git a/tests/unit/test_openapi/utils.py b/tests/unit/test_openapi/utils.py index 5190870795..a368c4f0fd 100644 --- a/tests/unit/test_openapi/utils.py +++ b/tests/unit/test_openapi/utils.py @@ -12,3 +12,8 @@ class Gender(str, Enum): FEMALE = "F" OTHER = "O" ANY = "A" + + +class LuckyNumber(int, Enum): + TWO = 2 + SEVEN = 7 diff --git a/tests/unit/test_plugins/test_attrs/test_schema_plugin.py b/tests/unit/test_plugins/test_attrs/test_schema_plugin.py index 2d48aa128a..d73aa030c5 100644 --- a/tests/unit/test_plugins/test_attrs/test_schema_plugin.py +++ b/tests/unit/test_plugins/test_attrs/test_schema_plugin.py @@ -23,7 +23,7 @@ class AttrsGeneric(Generic[T]): def test_schema_generation_with_generic_classes() -> None: cls = AttrsGeneric[int] expected_foo_schema = Schema(type=OpenAPIType.INTEGER) - expected_optional_foo_schema = Schema(one_of=[Schema(type=OpenAPIType.NULL), Schema(type=OpenAPIType.INTEGER)]) + expected_optional_foo_schema = Schema(one_of=[Schema(type=OpenAPIType.INTEGER), Schema(type=OpenAPIType.NULL)]) field_definition = FieldDefinition.from_kwarg(name=get_name(cls), annotation=cls) properties = get_schema_for_field_definition(field_definition, plugins=[AttrsSchemaPlugin()]).properties diff --git a/tests/unit/test_plugins/test_attrs/test_schema_spec_generation.py b/tests/unit/test_plugins/test_attrs/test_schema_spec_generation.py index fb7a229ae1..341f2e8fac 100644 --- a/tests/unit/test_plugins/test_attrs/test_schema_spec_generation.py +++ b/tests/unit/test_plugins/test_attrs/test_schema_spec_generation.py @@ -29,7 +29,7 @@ def handler(data: Person) -> Person: "first_name": {"type": "string"}, "last_name": {"type": "string"}, "id": {"type": "string"}, - "optional": {"oneOf": [{"type": "null"}, {"type": "string"}]}, + "optional": {"oneOf": [{"type": "string"}, {"type": "null"}]}, "complex": { "type": "object", "additionalProperties": { @@ -39,11 +39,11 @@ def handler(data: Person) -> Person: }, "pets": { "oneOf": [ - {"type": "null"}, { "items": {"$ref": "#/components/schemas/DataclassPet"}, "type": "array", }, + {"type": "null"}, ] }, }, diff --git a/tests/unit/test_plugins/test_pydantic/test_openapi.py b/tests/unit/test_plugins/test_pydantic/test_openapi.py index a7e89ef201..01e36fb21b 100644 --- a/tests/unit/test_plugins/test_pydantic/test_openapi.py +++ b/tests/unit/test_plugins/test_pydantic/test_openapi.py @@ -465,7 +465,7 @@ def handler(data: cls) -> cls: "first_name": {"type": "string"}, "last_name": {"type": "string"}, "id": {"type": "string"}, - "optional": {"oneOf": [{"type": "null"}, {"type": "string"}]}, + "optional": {"oneOf": [{"type": "string"}, {"type": "null"}]}, "complex": { "type": "object", "additionalProperties": { @@ -476,11 +476,11 @@ def handler(data: cls) -> cls: "union": {"oneOf": [{"type": "integer"}, {"items": {"type": "string"}, "type": "array"}]}, "pets": { "oneOf": [ - {"type": "null"}, { "items": {"$ref": "#/components/schemas/DataclassPet"}, "type": "array", }, + {"type": "null"}, ] }, }, diff --git a/tests/unit/test_plugins/test_pydantic/test_schema_plugin.py b/tests/unit/test_plugins/test_pydantic/test_schema_plugin.py index bcbf53b7f3..aee70c84a4 100644 --- a/tests/unit/test_plugins/test_pydantic/test_schema_plugin.py +++ b/tests/unit/test_plugins/test_pydantic/test_schema_plugin.py @@ -38,7 +38,7 @@ def test_schema_generation_with_generic_classes(model: Type[Union[PydanticV1Gene field_definition = FieldDefinition.from_kwarg(name=get_name(cls), annotation=cls) properties = get_schema_for_field_definition(field_definition, plugins=[PydanticSchemaPlugin()]).properties expected_foo_schema = Schema(type=OpenAPIType.INTEGER) - expected_optional_foo_schema = Schema(one_of=[Schema(type=OpenAPIType.NULL), Schema(type=OpenAPIType.INTEGER)]) + expected_optional_foo_schema = Schema(one_of=[Schema(type=OpenAPIType.INTEGER), Schema(type=OpenAPIType.NULL)]) assert properties assert properties["foo"] == expected_foo_schema From 22d8f18795807fda51da73acfb3441824f1b05c0 Mon Sep 17 00:00:00 2001 From: Cody Fincher <204685+cofin@users.noreply.github.com> Date: Fri, 29 Nov 2024 12:52:53 -0600 Subject: [PATCH 2/2] docs: fix advanced alchemy references (#3881) * docs: fix advanced alchemy references * fix: correct additional AA references * fix: references * fix: docs building * fix: linting --- docs/conf.py | 30 +++++-------------- .../plugins/tutorial/full_app_with_plugin.py | 3 +- .../01-modeling-and-features.rst | 1 + .../sqlalchemy/4-final-touches-and-recap.rst | 2 +- .../plugins/sqlalchemy_init_plugin.rst | 28 ++++++++--------- .../sqlalchemy/plugins/sqlalchemy_plugin.rst | 10 +++---- docs/usage/plugins/index.rst | 4 +-- 7 files changed, 32 insertions(+), 46 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index c47306c634..be6d8017ac 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,9 +29,9 @@ extensions = [ "sphinx.ext.intersphinx", - "sphinx.ext.autosectionlabel", "sphinx.ext.autodoc", "sphinx.ext.napoleon", + "sphinx.ext.autosectionlabel", "sphinx_design", "auto_pytabs.sphinx_ext", "tools.sphinx_ext", @@ -48,6 +48,7 @@ "msgspec": ("https://jcristharif.com/msgspec/", None), "anyio": ("https://anyio.readthedocs.io/en/stable/", None), "multidict": ("https://multidict.aio-libs.org/en/stable/", None), + "cryptography": ("https://cryptography.io/en/latest/", None), "sqlalchemy": ("https://docs.sqlalchemy.org/en/20/", None), "alembic": ("https://alembic.sqlalchemy.org/en/latest/", None), "click": ("https://click.palletsprojects.com/en/8.1.x/", None), @@ -60,6 +61,8 @@ "advanced-alchemy": ("https://docs.advanced-alchemy.litestar.dev/latest/", None), "jinja2": ("https://jinja.palletsprojects.com/en/latest/", None), "trio": ("https://trio.readthedocs.io/en/stable/", None), + "pydantic": ("https://docs.pydantic.dev/latest/", None), + "typing_extensions": ("https://typing-extensions.readthedocs.io/en/stable/", None), } napoleon_google_docstring = True @@ -74,6 +77,7 @@ autodoc_default_options = {"special-members": "__init__", "show-inheritance": True, "members": True} autodoc_member_order = "bysource" autodoc_typehints_format = "short" +autodoc_mock_imports = [] nitpicky = True nitpick_ignore = [ @@ -164,27 +168,7 @@ (PY_METH, "litestar.dto.factory.DTOData.create_instance"), (PY_METH, "litestar.dto.interface.DTOInterface.data_to_encodable_type"), (PY_CLASS, "MetaData"), - (PY_CLASS, "advanced_alchemy.repository.typing.ModelT"), - (PY_OBJ, "advanced_alchemy.config.common.SessionMakerT"), - (PY_OBJ, "advanced_alchemy.config.common.ConnectionT"), - (PY_CLASS, "advanced_alchemy.extensions.litestar.plugins._slots_base.SlotsBase"), - (PY_CLASS, "advanced_alchemy.config.EngineConfig"), - (PY_CLASS, "advanced_alchemy.config.common.GenericAlembicConfig"), - (PY_CLASS, "advanced_alchemy.extensions.litestar.SQLAlchemyDTO"), - (PY_CLASS, "advanced_alchemy.extensions.litestar.dto.SQLAlchemyDTO"), - (PY_CLASS, "advanced_alchemy.extensions.litestar.plugins.SQLAlchemyPlugin"), - (PY_CLASS, "advanced_alchemy.extensions.litestar.plugins.SQLAlchemySerializationPlugin"), - (PY_CLASS, "advanced_alchemy.extensions.litestar.plugins.SQLAlchemyInitPlugin"), - (PY_CLASS, "advanced_alchemy.extensions.litestar.config.SQLAlchemySyncConfig"), - (PY_CLASS, "advanced_alchemy.extensions.litestar.config.SQLAlchemyAsyncConfig"), - (PY_METH, "advanced_alchemy.extensions.litestar.plugins.SQLAlchemySerializationPlugin.create_dto_for_type"), - (PY_CLASS, "advanced_alchemy.base.BasicAttributes"), - (PY_CLASS, "advanced_alchemy.config.AsyncSessionConfig"), - (PY_CLASS, "advanced_alchemy.config.SyncSessionConfig"), - (PY_CLASS, "advanced_alchemy.types.JsonB"), - (PY_CLASS, "advanced_alchemy.types.BigIntIdentity"), (PY_FUNC, "sqlalchemy.get_engine"), - (PY_ATTR, "advanced_alchemy.repository.AbstractAsyncRepository.id_attribute"), (PY_OBJ, "litestar.template.base.T_co"), ("py:exc", "RepositoryError"), ("py:exc", "InternalServerError"), @@ -204,6 +188,9 @@ (PY_CLASS, "typing.Self"), (PY_CLASS, "attr.AttrsInstance"), (PY_CLASS, "typing_extensions.TypeGuard"), + (PY_CLASS, "advanced_alchemy.types.BigIntIdentity"), + (PY_CLASS, "advanced_alchemy.types.JsonB"), + (PY_CLASS, "advanced_alchemy.repository.SQLAlchemyAsyncRepository"), ] nitpick_ignore_regex = [ @@ -247,7 +234,6 @@ "litestar.template": {"litestar.template.base.T_co"}, "litestar.openapi.OpenAPIController.security": {"SecurityRequirement"}, "litestar.response.file.async_file_iterator": {"FileSystemAdapter"}, - "advanced_alchemy._listeners.touch_updated_timestamp": {"Session"}, re.compile("litestar.response.redirect.*"): {"RedirectStatusType"}, re.compile(r"litestar\.plugins.*"): re.compile(".*ModelT"), re.compile(r"litestar\.(contrib|repository)\.*"): re.compile(".*T"), diff --git a/docs/examples/contrib/sqlalchemy/plugins/tutorial/full_app_with_plugin.py b/docs/examples/contrib/sqlalchemy/plugins/tutorial/full_app_with_plugin.py index fd55548c1f..aec569996d 100644 --- a/docs/examples/contrib/sqlalchemy/plugins/tutorial/full_app_with_plugin.py +++ b/docs/examples/contrib/sqlalchemy/plugins/tutorial/full_app_with_plugin.py @@ -1,6 +1,5 @@ from typing import AsyncGenerator, List, Optional -from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import autocommit_before_send_handler from sqlalchemy import select from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.ext.asyncio import AsyncSession @@ -74,7 +73,7 @@ async def update_item(item_title: str, data: TodoItem, transaction: AsyncSession connection_string="sqlite+aiosqlite:///todo.sqlite", metadata=Base.metadata, create_all=True, - before_send_handler=autocommit_before_send_handler, + before_send_handler="autocommit", ) app = Litestar( diff --git a/docs/tutorials/repository-tutorial/01-modeling-and-features.rst b/docs/tutorials/repository-tutorial/01-modeling-and-features.rst index ce66ea2979..849d7cc4f9 100644 --- a/docs/tutorials/repository-tutorial/01-modeling-and-features.rst +++ b/docs/tutorials/repository-tutorial/01-modeling-and-features.rst @@ -74,6 +74,7 @@ Additional features provided by the built-in base models include: reverts to an ``Integer`` for unsupported variants. - A custom :class:`JsonB ` type that uses native ``JSONB`` where possible and ``Binary`` or ``Blob`` as an alternative. +- A custom :class:`EncryptedString ` encrypted string that supports multiple cryptography backends. Let's build on this as we look at the repository classes. diff --git a/docs/tutorials/sqlalchemy/4-final-touches-and-recap.rst b/docs/tutorials/sqlalchemy/4-final-touches-and-recap.rst index 9b70b8f776..a619dff114 100644 --- a/docs/tutorials/sqlalchemy/4-final-touches-and-recap.rst +++ b/docs/tutorials/sqlalchemy/4-final-touches-and-recap.rst @@ -59,7 +59,7 @@ engine and session lifecycle, and register our ``transaction`` dependency. .. literalinclude:: /examples/contrib/sqlalchemy/plugins/tutorial/full_app_with_plugin.py :language: python :linenos: - :lines: 80-84 + :lines: 80-83 .. seealso:: diff --git a/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_init_plugin.rst b/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_init_plugin.rst index 2ebd069f6a..581e39fe53 100644 --- a/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_init_plugin.rst +++ b/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_init_plugin.rst @@ -1,7 +1,7 @@ SQLAlchemy Init Plugin ---------------------- -The :class:`SQLAlchemyInitPlugin ` adds functionality to the +The :class:`SQLAlchemyInitPlugin ` adds functionality to the application that supports using Litestar with `SQLAlchemy `_. The plugin: @@ -39,8 +39,8 @@ Renaming the dependencies ######################### You can change the name that the engine and session are bound to by setting the -:attr:`engine_dependency_key ` -and :attr:`session_dependency_key ` +:attr:`engine_dependency_key ` +and :attr:`session_dependency_key ` attributes on the plugin configuration. Configuring the before send handler @@ -50,7 +50,7 @@ The plugin configures a ``before_send`` handler that is called before sending a session and removes it from the connection scope. You can change the handler by setting the -:attr:`before_send_handler ` +:attr:`before_send_handler ` attribute on the configuration object. For example, an alternate handler is available that will also commit the session on success and rollback upon failure. @@ -73,21 +73,21 @@ on success and rollback upon failure. Configuring the plugins ####################### -Both the :class:`SQLAlchemyAsyncConfig ` and the -:class:`SQLAlchemySyncConfig ` have an ``engine_config`` +Both the :class:`SQLAlchemyAsyncConfig ` and the +:class:`SQLAlchemySyncConfig ` have an ``engine_config`` attribute that is used to configure the engine. The ``engine_config`` attribute is an instance of -:class:`EngineConfig ` and exposes all of the configuration options +:class:`EngineConfig ` and exposes all of the configuration options available to the SQLAlchemy engine. -The :class:`SQLAlchemyAsyncConfig ` class and the -:class:`SQLAlchemySyncConfig ` class also have a +The :class:`SQLAlchemyAsyncConfig ` class and the +:class:`SQLAlchemySyncConfig ` class also have a ``session_config`` attribute that is used to configure the session. This is either an instance of -:class:`AsyncSessionConfig ` or -:class:`SyncSessionConfig ` depending on the type of config +:class:`AsyncSessionConfig ` or +:class:`SyncSessionConfig ` depending on the type of config object. These classes expose all of the configuration options available to the SQLAlchemy session. -Finally, the :class:`SQLAlchemyAsyncConfig ` class and the -:class:`SQLAlchemySyncConfig ` class expose configuration +Finally, the :class:`SQLAlchemyAsyncConfig ` class and the +:class:`SQLAlchemySyncConfig ` class expose configuration options to control their behavior. Consult the reference documentation for more information. @@ -98,7 +98,7 @@ Example The below example is a complete demonstration of use of the init plugin. Readers who are familiar with the prior section may note the additional complexity involved in managing the conversion to and from SQLAlchemy objects within the handlers. Read on to see how this increased complexity is efficiently handled by the -:class:`SQLAlchemySerializationPlugin `. +:class:`SQLAlchemySerializationPlugin `. .. tab-set:: diff --git a/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_plugin.rst b/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_plugin.rst index 8b0a702a85..c28b5e654b 100644 --- a/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_plugin.rst +++ b/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_plugin.rst @@ -1,18 +1,18 @@ SQLAlchemy Plugin ----------------- -The :class:`SQLAlchemyPlugin ` provides complete support for +The :class:`SQLAlchemyPlugin ` provides complete support for working with `SQLAlchemy `_ in Litestar applications. .. note:: This plugin is only compatible with SQLAlchemy 2.0+. -The :class:`SQLAlchemyPlugin ` combines the functionality of -:class:`SQLAlchemyInitPlugin ` and -:class:`SQLAlchemySerializationPlugin `, each of +The :class:`SQLAlchemyPlugin ` combines the functionality of +:class:`SQLAlchemyInitPlugin ` and +:class:`SQLAlchemySerializationPlugin `, each of which are examined in detail in the following sections. As such, this section describes a complete example of using the -:class:`SQLAlchemyPlugin ` with a Litestar application and a +:class:`SQLAlchemyPlugin ` with a Litestar application and a SQLite database. Or, skip ahead to :doc:`/usage/databases/sqlalchemy/plugins/sqlalchemy_init_plugin` or diff --git a/docs/usage/plugins/index.rst b/docs/usage/plugins/index.rst index 13076750de..5eb27c9d32 100644 --- a/docs/usage/plugins/index.rst +++ b/docs/usage/plugins/index.rst @@ -91,11 +91,11 @@ The following example shows the actual implementation of the ``SerializationPlug :language: python :caption: ``SerializationPluginProtocol`` implementation example -:meth:`supports_type(self, field_definition: FieldDefinition) -> bool: ` +:meth:`supports_type(self, field_definition: FieldDefinition) -> bool: ` returns a :class:`bool` indicating whether the plugin supports serialization for the given type. Specifically, we return ``True`` if the parsed type is either a collection of SQLAlchemy models or a single SQLAlchemy model. -:meth:`create_dto_for_type(self, field_definition: FieldDefinition) -> type[AbstractDTO]: ` +:meth:`create_dto_for_type(self, field_definition: FieldDefinition) -> type[AbstractDTO]: ` takes a :class:`FieldDefinition ` instance as an argument and returns a :class:`SQLAlchemyDTO ` subclass and includes some logic that may be interesting to potential serialization plugin authors.