diff --git a/litestar/_openapi/datastructures.py b/litestar/_openapi/datastructures.py index d97c8db405..bc55be74f5 100644 --- a/litestar/_openapi/datastructures.py +++ b/litestar/_openapi/datastructures.py @@ -1,14 +1,67 @@ from __future__ import annotations from collections import defaultdict -from typing import TYPE_CHECKING, Iterator, Sequence +from typing import TYPE_CHECKING, Iterator, Sequence, _GenericAlias # type: ignore[attr-defined] from litestar.exceptions import ImproperlyConfiguredException from litestar.openapi.spec import Reference, Schema +from litestar.params import KwargDefinition if TYPE_CHECKING: from litestar.openapi import OpenAPIConfig from litestar.plugins import OpenAPISchemaPluginProtocol + from litestar.typing import FieldDefinition + + +def _longest_common_prefix(tuples_: list[tuple[str, ...]]) -> tuple[str, ...]: + """Find the longest common prefix of a list of tuples. + + Args: + tuples_: A list of tuples to find the longest common prefix of. + + Returns: + The longest common prefix of the tuples. + """ + prefix_ = tuples_[0] + for t in tuples_: + # Compare the current prefix with each tuple and shorten it + prefix_ = prefix_[: min(len(prefix_), len(t))] + for i in range(len(prefix_)): + if prefix_[i] != t[i]: + prefix_ = prefix_[:i] + break + return prefix_ + + +def _get_component_key_override(field: FieldDefinition) -> str | None: + if ( + (kwarg_definition := field.kwarg_definition) + and isinstance(kwarg_definition, KwargDefinition) + and (schema_key := kwarg_definition.schema_component_key) + ): + return schema_key + return None + + +def _get_normalized_schema_key(field_definition: FieldDefinition) -> tuple[str, ...]: + """Create a key for a type annotation. + + The key should be a tuple such as ``("path", "to", "type", "TypeName")``. + + Args: + field_definition: Field definition + + Returns: + A tuple of strings. + """ + if override := _get_component_key_override(field_definition): + return (override,) + + annotation = field_definition.annotation + module = getattr(annotation, "__module__", "") + name = str(annotation)[len(module) + 1 :] if isinstance(annotation, _GenericAlias) else annotation.__qualname__ + name = name.replace("..", ".") + return *module.split("."), name class RegisteredSchema: @@ -43,32 +96,63 @@ def __init__(self) -> None: self._schema_key_map: dict[tuple[str, ...], RegisteredSchema] = {} self._schema_reference_map: dict[int, RegisteredSchema] = {} self._model_name_groups: defaultdict[str, list[RegisteredSchema]] = defaultdict(list) + self._component_type_map: dict[tuple[str, ...], FieldDefinition] = {} - def get_schema_for_key(self, key: tuple[str, ...]) -> Schema: + def get_schema_for_field_definition(self, field: FieldDefinition) -> Schema: """Get a registered schema by its key. Args: - key: The key to the schema to get. + field: The field definition to get the schema for Returns: A RegisteredSchema object. """ + key = _get_normalized_schema_key(field) if key not in self._schema_key_map: self._schema_key_map[key] = registered_schema = RegisteredSchema(key, Schema(), []) self._model_name_groups[key[-1]].append(registered_schema) + self._component_type_map[key] = field + else: + if (existing_type := self._component_type_map[key]) != field: + raise ImproperlyConfiguredException( + f"Schema component keys must be unique. Cannot override existing key {'_'.join(key)!r} for type " + f"{existing_type.raw!r} with new type {field.raw!r}" + ) return self._schema_key_map[key].schema - def get_reference_for_key(self, key: tuple[str, ...]) -> Reference | None: + def get_reference_for_field_definition(self, field: FieldDefinition) -> Reference | None: """Get a reference to a registered schema by its key. Args: - key: The key to the schema to get. + field: The field definition to get the reference for Returns: A Reference object. """ + key = _get_normalized_schema_key(field) if key not in self._schema_key_map: return None + + if (existing_type := self._component_type_map[key]) != field: + # TODO: This should check for strict equality, e.g. changes in type metadata + # However, this is currently not possible to do without breaking things, as + # we allow to define metadata on a type annotation in one place to be used + # for the same type in a different place, where that same type is *not* + # annotated with this metadata. The proper fix for this would be to e.g. + # inline DTO definitions when they are created at the handler level, as + # they won't be reused (they already generate a unique key), and create a + # more strict lookup policy for component schemas + msg = ( + f"Schema component keys must be unique. While obtaining a reference for the type '{field.raw!r}', the " + f"generated key {'_'.join(key)!r} was already associated with a different type '{existing_type.raw!r}'. " + ) + if key_override := _get_component_key_override(field): # pragma: no branch + # Currently, this can never not be true, however, in the future we might + # decide to do a stricter equality check as lined out above, in which + # case there can be other cases than overrides that cause this error + msg += f"Hint: Both types are defining a 'schema_component_key' with the value of {key_override!r}" + raise ImproperlyConfiguredException(msg) + registered_schema = self._schema_key_map[key] reference = Reference(f"#/components/schemas/{'_'.join(key)}") registered_schema.references.append(reference) @@ -107,26 +191,7 @@ def remove_common_prefix(tuples: list[tuple[str, ...]]) -> list[tuple[str, ...]] A list of tuples with the common prefix removed. """ - def longest_common_prefix(tuples_: list[tuple[str, ...]]) -> tuple[str, ...]: - """Find the longest common prefix of a list of tuples. - - Args: - tuples_: A list of tuples to find the longest common prefix of. - - Returns: - The longest common prefix of the tuples. - """ - prefix_ = tuples_[0] - for t in tuples_: - # Compare the current prefix with each tuple and shorten it - prefix_ = prefix_[: min(len(prefix_), len(t))] - for i in range(len(prefix_)): - if prefix_[i] != t[i]: - prefix_ = prefix_[:i] - break - return prefix_ - - prefix = longest_common_prefix(tuples) + prefix = _longest_common_prefix(tuples) prefix_length = len(prefix) return [t[prefix_length:] for t in tuples] diff --git a/litestar/_openapi/schema_generation/schema.py b/litestar/_openapi/schema_generation/schema.py index 029a3f7bcf..defa0e0717 100644 --- a/litestar/_openapi/schema_generation/schema.py +++ b/litestar/_openapi/schema_generation/schema.py @@ -40,7 +40,6 @@ create_string_constrained_field_schema, ) from litestar._openapi.schema_generation.utils import ( - _get_normalized_schema_key, _should_create_enum_schema, _should_create_literal_schema, _type_or_first_not_none_inner_type, @@ -508,8 +507,7 @@ def for_plugin(self, field_definition: FieldDefinition, plugin: OpenAPISchemaPlu Returns: A schema instance. """ - key = _get_normalized_schema_key(field_definition.annotation) - if (ref := self.schema_registry.get_reference_for_key(key)) is not None: + if (ref := self.schema_registry.get_reference_for_field_definition(field_definition)) is not None: return ref schema = plugin.to_openapi_schema(field_definition=field_definition, schema_creator=self) @@ -612,8 +610,7 @@ def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schem schema.examples = get_json_schema_formatted_examples(create_examples_for_field(field)) if schema.title and schema.type == OpenAPIType.OBJECT: - key = _get_normalized_schema_key(field.annotation) - return self.schema_registry.get_reference_for_key(key) or schema + return self.schema_registry.get_reference_for_field_definition(field) or schema return schema def create_component_schema( @@ -644,7 +641,7 @@ def create_component_schema( Returns: A schema instance. """ - schema = self.schema_registry.get_schema_for_key(_get_normalized_schema_key(type_.annotation)) + schema = self.schema_registry.get_schema_for_field_definition(type_) schema.title = title or _get_type_schema_name(type_) schema.required = required schema.type = openapi_type diff --git a/litestar/_openapi/schema_generation/utils.py b/litestar/_openapi/schema_generation/utils.py index 7ce27ca945..832952cb79 100644 --- a/litestar/_openapi/schema_generation/utils.py +++ b/litestar/_openapi/schema_generation/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, Any, Mapping, _GenericAlias # type: ignore[attr-defined] +from typing import TYPE_CHECKING, Any, Mapping from litestar.utils.helpers import get_name @@ -15,7 +15,6 @@ "_type_or_first_not_none_inner_type", "_should_create_enum_schema", "_should_create_literal_schema", - "_get_normalized_schema_key", ) @@ -83,23 +82,6 @@ def _should_create_literal_schema(field_definition: FieldDefinition) -> bool: ) -def _get_normalized_schema_key(annotation: Any) -> tuple[str, ...]: - """Create a key for a type annotation. - - The key should be a tuple such as ``("path", "to", "type", "TypeName")``. - - Args: - annotation: a type annotation - - Returns: - A tuple of strings. - """ - module = getattr(annotation, "__module__", "") - name = str(annotation)[len(module) + 1 :] if isinstance(annotation, _GenericAlias) else annotation.__qualname__ - name = name.replace("..", ".") - return *module.split("."), name - - def get_formatted_examples(field_definition: FieldDefinition, examples: Sequence[Example]) -> Mapping[str, Example]: """Format the examples into the OpenAPI schema format.""" diff --git a/litestar/params.py b/litestar/params.py index c52389e0f4..b1ef2361ce 100644 --- a/litestar/params.py +++ b/litestar/params.py @@ -119,6 +119,11 @@ class KwargDefinition: .. versionadded:: 2.8.0 """ + schema_component_key: str | None = None + """ + Use as the key for the reference when creating a component for this type + .. versionadded:: 2.12.0 + """ @property def is_constrained(self) -> bool: @@ -195,6 +200,7 @@ def Parameter( required: bool | None = None, title: str | None = None, schema_extra: dict[str, Any] | None = None, + schema_component_key: str | None = None, ) -> Any: """Create an extended parameter kwarg definition. @@ -239,6 +245,8 @@ def Parameter( schema. .. versionadded:: 2.8.0 + schema_component_key: Use this as the key for the reference when creating a component for this type + .. versionadded:: 2.12.0 """ return ParameterKwarg( annotation=annotation, @@ -264,6 +272,7 @@ def Parameter( max_length=max_length, pattern=pattern, schema_extra=schema_extra, + schema_component_key=schema_component_key, ) @@ -308,6 +317,7 @@ def Body( pattern: str | None = None, title: str | None = None, schema_extra: dict[str, Any] | None = None, + schema_component_key: str | None = None, ) -> Any: """Create an extended request body kwarg definition. @@ -349,6 +359,8 @@ def Body( schema. .. versionadded:: 2.8.0 + schema_component_key: Use this as the key for the reference when creating a component for this type + .. versionadded:: 2.12.0 """ return BodyKwarg( media_type=media_type, @@ -371,6 +383,7 @@ def Body( pattern=pattern, multipart_form_part_limit=multipart_form_part_limit, schema_extra=schema_extra, + schema_component_key=schema_component_key, ) diff --git a/tests/unit/test_openapi/test_datastructures.py b/tests/unit/test_openapi/test_datastructures.py index 6e5e6a92c5..abafa68553 100644 --- a/tests/unit/test_openapi/test_datastructures.py +++ b/tests/unit/test_openapi/test_datastructures.py @@ -1,9 +1,16 @@ from __future__ import annotations +from typing import Dict, Generic, List, TypeVar + +import msgspec import pytest -from litestar._openapi.datastructures import SchemaRegistry +from litestar._openapi.datastructures import SchemaRegistry, _get_normalized_schema_key +from litestar.exceptions import ImproperlyConfiguredException from litestar.openapi.spec import Reference, Schema +from litestar.params import KwargDefinition +from litestar.typing import FieldDefinition +from tests.models import DataclassPerson @pytest.fixture() @@ -11,28 +18,125 @@ def schema_registry() -> SchemaRegistry: return SchemaRegistry() -def test_get_schema_for_key(schema_registry: SchemaRegistry) -> None: +def test_get_schema_for_field_definition(schema_registry: SchemaRegistry) -> None: assert not schema_registry._schema_key_map assert not schema_registry._schema_reference_map assert not schema_registry._model_name_groups - key = ("a", "b", "c") - schema = schema_registry.get_schema_for_key(key) + field = FieldDefinition.from_annotation(str) + schema = schema_registry.get_schema_for_field_definition(field) + key = _get_normalized_schema_key(field) assert isinstance(schema, Schema) assert key in schema_registry._schema_key_map assert not schema_registry._schema_reference_map - assert len(schema_registry._model_name_groups["c"]) == 1 - assert schema_registry._model_name_groups["c"][0].schema is schema - assert schema_registry.get_schema_for_key(key) is schema + assert len(schema_registry._model_name_groups[key[-1]]) == 1 + assert schema_registry._model_name_groups[key[-1]][0].schema is schema + assert schema_registry.get_schema_for_field_definition(field) is schema -def test_get_reference_for_key(schema_registry: SchemaRegistry) -> None: +def test_get_reference_for_field_definition(schema_registry: SchemaRegistry) -> None: assert not schema_registry._schema_key_map assert not schema_registry._schema_reference_map assert not schema_registry._model_name_groups - key = ("a", "b", "c") - assert schema_registry.get_reference_for_key(key) is None - schema_registry.get_schema_for_key(key) - reference = schema_registry.get_reference_for_key(key) + field = FieldDefinition.from_annotation(str) + key = _get_normalized_schema_key(field) + + assert schema_registry.get_reference_for_field_definition(field) is None + schema_registry.get_schema_for_field_definition(field) + reference = schema_registry.get_reference_for_field_definition(field) assert isinstance(reference, Reference) assert id(reference) in schema_registry._schema_reference_map assert reference in schema_registry._schema_key_map[key].references + + +def test_get_normalized_schema_key() -> None: + class LocalClass(msgspec.Struct): + id: str + + T = TypeVar("T") + + # replace each of the long strings with underscores with a tuple of strings split at each underscore + assert _get_normalized_schema_key(FieldDefinition.from_annotation(LocalClass)) == ( + "tests", + "unit", + "test_openapi", + "test_datastructures", + "test_get_normalized_schema_key.LocalClass", + ) + + assert _get_normalized_schema_key(FieldDefinition.from_annotation(DataclassPerson)) == ( + "tests", + "models", + "DataclassPerson", + ) + + builtin_dict = Dict[str, List[int]] + assert _get_normalized_schema_key(FieldDefinition.from_annotation(builtin_dict)) == ( + "typing", + "Dict[str, typing.List[int]]", + ) + + builtin_with_custom = Dict[str, DataclassPerson] + assert _get_normalized_schema_key(FieldDefinition.from_annotation(builtin_with_custom)) == ( + "typing", + "Dict[str, tests.models.DataclassPerson]", + ) + + class LocalGeneric(Generic[T]): + pass + + assert _get_normalized_schema_key(FieldDefinition.from_annotation(LocalGeneric)) == ( + "tests", + "unit", + "test_openapi", + "test_datastructures", + "test_get_normalized_schema_key.LocalGeneric", + ) + + generic_int = LocalGeneric[int] + generic_str = LocalGeneric[str] + + assert _get_normalized_schema_key(FieldDefinition.from_annotation(generic_int)) == ( + "tests", + "unit", + "test_openapi", + "test_datastructures", + "test_get_normalized_schema_key.LocalGeneric[int]", + ) + + assert _get_normalized_schema_key(FieldDefinition.from_annotation(generic_str)) == ( + "tests", + "unit", + "test_openapi", + "test_datastructures", + "test_get_normalized_schema_key.LocalGeneric[str]", + ) + + assert _get_normalized_schema_key(FieldDefinition.from_annotation(generic_int)) != _get_normalized_schema_key( + FieldDefinition.from_annotation(generic_str) + ) + + +def test_raise_on_override_for_same_field_definition() -> None: + registry = SchemaRegistry() + schema = registry.get_schema_for_field_definition( + FieldDefinition.from_annotation(str, kwarg_definition=KwargDefinition(schema_component_key="foo")) + ) + # registering the same thing again with the same name should work + assert ( + registry.get_schema_for_field_definition( + FieldDefinition.from_annotation(str, kwarg_definition=KwargDefinition(schema_component_key="foo")) + ) + is schema + ) + # registering the same *type* with a different name should result in a different schema + assert ( + registry.get_schema_for_field_definition( + FieldDefinition.from_annotation(str, kwarg_definition=KwargDefinition(schema_component_key="bar")) + ) + is not schema + ) + # registering a different type with a previously used name should raise an exception + with pytest.raises(ImproperlyConfiguredException): + registry.get_schema_for_field_definition( + FieldDefinition.from_annotation(int, kwarg_definition=KwargDefinition(schema_component_key="foo")) + ) diff --git a/tests/unit/test_openapi/test_schema.py b/tests/unit/test_openapi/test_schema.py index 21623ec0b2..daa57ad4c0 100644 --- a/tests/unit/test_openapi/test_schema.py +++ b/tests/unit/test_openapi/test_schema.py @@ -29,10 +29,11 @@ KWARG_DEFINITION_ATTRIBUTE_TO_OPENAPI_PROPERTY_MAP, SchemaCreator, ) -from litestar._openapi.schema_generation.utils import _get_normalized_schema_key, _type_or_first_not_none_inner_type +from litestar._openapi.schema_generation.utils import _type_or_first_not_none_inner_type from litestar.app import DEFAULT_OPENAPI_CONFIG, Litestar from litestar.di import Provide from litestar.enums import ParamType +from litestar.exceptions import ImproperlyConfiguredException from litestar.openapi.spec import ExternalDocumentation, OpenAPIType, Reference from litestar.openapi.spec.example import Example from litestar.openapi.spec.parameter import Parameter as OpenAPIParameter @@ -88,58 +89,68 @@ def test_process_schema_result() -> None: assert getattr(schema, schema_key) == getattr(kwarg_definition, signature_key) -def test_get_normalized_schema_key() -> None: - class LocalClass(msgspec.Struct): - id: str +def test_override_schema_component_key() -> None: + @dataclass + class Data: + pass - # replace each of the long strings with underscores with a tuple of strings split at each underscore - assert ( - "tests", - "unit", - "test_openapi", - "test_schema", - "test_get_normalized_schema_key.LocalClass", - ) == _get_normalized_schema_key(LocalClass) + @post("/") + def handler( + data: Data, + ) -> Annotated[Data, Parameter(schema_component_key="not_data")]: + return Data() - assert ("tests", "models", "DataclassPerson") == _get_normalized_schema_key(DataclassPerson) + @get("/") + def handler_2() -> Annotated[Data, Parameter(schema_component_key="not_data")]: + return Data() + + app = Litestar([handler, handler_2]) + schema = app.openapi_schema.to_schema() + # we expect the annotated / non-annotated to generate independent components + assert schema["paths"]["/"]["post"]["requestBody"]["content"]["application/json"] == { + "schema": {"$ref": "#/components/schemas/test_override_schema_component_key.Data"} + } + assert schema["paths"]["/"]["post"]["responses"]["201"]["content"] == { + "application/json": {"schema": {"$ref": "#/components/schemas/not_data"}} + } + # a response with the same type and the same name should reference the same component + assert schema["paths"]["/"]["get"]["responses"]["200"]["content"] == { + "application/json": {"schema": {"$ref": "#/components/schemas/not_data"}} + } + assert app.openapi_schema.to_schema()["components"] == { + "schemas": { + "not_data": {"properties": {}, "type": "object", "required": [], "title": "Data"}, + "test_override_schema_component_key.Data": { + "properties": {}, + "type": "object", + "required": [], + "title": "Data", + }, + } + } - builtin_dict = Dict[str, List[int]] - assert ("typing", "Dict[str, typing.List[int]]") == _get_normalized_schema_key(builtin_dict) - builtin_with_custom = Dict[str, DataclassPerson] - assert ("typing", "Dict[str, tests.models.DataclassPerson]") == _get_normalized_schema_key(builtin_with_custom) +def test_override_schema_component_key_raise_if_keys_are_not_unique() -> None: + @dataclass + class Data: + pass - class LocalGeneric(Generic[T]): + @dataclass + class Data2: pass - assert ( - "tests", - "unit", - "test_openapi", - "test_schema", - "test_get_normalized_schema_key.LocalGeneric", - ) == _get_normalized_schema_key(LocalGeneric) - - generic_int = LocalGeneric[int] - generic_str = LocalGeneric[str] - - assert ( - "tests", - "unit", - "test_openapi", - "test_schema", - "test_get_normalized_schema_key.LocalGeneric[int]", - ) == _get_normalized_schema_key(generic_int) - - assert ( - "tests", - "unit", - "test_openapi", - "test_schema", - "test_get_normalized_schema_key.LocalGeneric[str]", - ) == _get_normalized_schema_key(generic_str) - - assert _get_normalized_schema_key(generic_int) != _get_normalized_schema_key(generic_str) + @post("/") + def handler( + data: Data, + ) -> Annotated[Data, Parameter(schema_component_key="not_data")]: + return Data() + + @get("/") + def handler_2() -> Annotated[Data2, Parameter(schema_component_key="not_data")]: + return Data2() + + with pytest.raises(ImproperlyConfiguredException, match="Schema component keys must be unique"): + Litestar([handler, handler_2]).openapi_schema.to_schema() def test_dependency_schema_generation() -> None: