Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow customizing schema component keys #3738

Merged
merged 6 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 90 additions & 25 deletions litestar/_openapi/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,67 @@
from __future__ import annotations

from collections import defaultdict
from typing import TYPE_CHECKING, Iterator, Sequence
from typing import TYPE_CHECKING, Iterator, Sequence, _GenericAlias # type: ignore[attr-defined]

from litestar.exceptions import ImproperlyConfiguredException
from litestar.openapi.spec import Reference, Schema
from litestar.params import KwargDefinition

if TYPE_CHECKING:
from litestar.openapi import OpenAPIConfig
from litestar.plugins import OpenAPISchemaPluginProtocol
from litestar.typing import FieldDefinition


def _longest_common_prefix(tuples_: list[tuple[str, ...]]) -> tuple[str, ...]:
provinzkraut marked this conversation as resolved.
Show resolved Hide resolved
"""Find the longest common prefix of a list of tuples.

Args:
tuples_: A list of tuples to find the longest common prefix of.

Returns:
The longest common prefix of the tuples.
"""
prefix_ = tuples_[0]
for t in tuples_:
# Compare the current prefix with each tuple and shorten it
prefix_ = prefix_[: min(len(prefix_), len(t))]
for i in range(len(prefix_)):
if prefix_[i] != t[i]:
prefix_ = prefix_[:i]
break
return prefix_


def _get_component_key_override(field: FieldDefinition) -> str | None:
if (
(kwarg_definition := field.kwarg_definition)
and isinstance(kwarg_definition, KwargDefinition)
and (schema_key := kwarg_definition.schema_component_key)
):
return schema_key
return None


def _get_normalized_schema_key(field_definition: FieldDefinition) -> tuple[str, ...]:
"""Create a key for a type annotation.

The key should be a tuple such as ``("path", "to", "type", "TypeName")``.

Args:
field_definition: Field definition

Returns:
A tuple of strings.
"""
if override := _get_component_key_override(field_definition):
return (override,)

annotation = field_definition.annotation
module = getattr(annotation, "__module__", "")
name = str(annotation)[len(module) + 1 :] if isinstance(annotation, _GenericAlias) else annotation.__qualname__
name = name.replace(".<locals>.", ".")
return *module.split("."), name


class RegisteredSchema:
Expand Down Expand Up @@ -43,32 +96,63 @@ def __init__(self) -> None:
self._schema_key_map: dict[tuple[str, ...], RegisteredSchema] = {}
self._schema_reference_map: dict[int, RegisteredSchema] = {}
self._model_name_groups: defaultdict[str, list[RegisteredSchema]] = defaultdict(list)
self._component_type_map: dict[tuple[str, ...], FieldDefinition] = {}

def get_schema_for_key(self, key: tuple[str, ...]) -> Schema:
def get_schema_for_field_definition(self, field: FieldDefinition) -> Schema:
"""Get a registered schema by its key.

Args:
key: The key to the schema to get.
field: The field definition to get the schema for

Returns:
A RegisteredSchema object.
"""
key = _get_normalized_schema_key(field)
if key not in self._schema_key_map:
self._schema_key_map[key] = registered_schema = RegisteredSchema(key, Schema(), [])
self._model_name_groups[key[-1]].append(registered_schema)
self._component_type_map[key] = field
else:
if (existing_type := self._component_type_map[key]) != field:
raise ImproperlyConfiguredException(
f"Schema component keys must be unique. Cannot override existing key {'_'.join(key)!r} for type "
f"{existing_type.raw!r} with new type {field.raw!r}"
)
return self._schema_key_map[key].schema

def get_reference_for_key(self, key: tuple[str, ...]) -> Reference | None:
def get_reference_for_field_definition(self, field: FieldDefinition) -> Reference | None:
"""Get a reference to a registered schema by its key.

Args:
key: The key to the schema to get.
field: The field definition to get the reference for

Returns:
A Reference object.
"""
key = _get_normalized_schema_key(field)
if key not in self._schema_key_map:
return None

if (existing_type := self._component_type_map[key]) != field:
# TODO: This should check for strict equality, e.g. changes in type metadata
# However, this is currently not possible to do without breaking things, as
# we allow to define metadata on a type annotation in one place to be used
# for the same type in a different place, where that same type is *not*
# annotated with this metadata. The proper fix for this would be to e.g.
# inline DTO definitions when they are created at the handler level, as
# they won't be reused (they already generate a unique key), and create a
# more strict lookup policy for component schemas
msg = (
f"Schema component keys must be unique. While obtaining a reference for the type '{field.raw!r}', the "
f"generated key {'_'.join(key)!r} was already associated with a different type '{existing_type.raw!r}'. "
)
if key_override := _get_component_key_override(field): # pragma: no cover
# Currently, this can never not be true, however, in the future we might
# decide to do a stricter equality check as lined out above, in which
# case there can be other cases than overrides that cause this error
msg += f"Hint: Both types are defining a 'schema_component_key' with the value of {key_override!r}"
raise ImproperlyConfiguredException(msg)

registered_schema = self._schema_key_map[key]
reference = Reference(f"#/components/schemas/{'_'.join(key)}")
registered_schema.references.append(reference)
Expand Down Expand Up @@ -107,26 +191,7 @@ def remove_common_prefix(tuples: list[tuple[str, ...]]) -> list[tuple[str, ...]]
A list of tuples with the common prefix removed.
"""

def longest_common_prefix(tuples_: list[tuple[str, ...]]) -> tuple[str, ...]:
"""Find the longest common prefix of a list of tuples.

Args:
tuples_: A list of tuples to find the longest common prefix of.

Returns:
The longest common prefix of the tuples.
"""
prefix_ = tuples_[0]
for t in tuples_:
# Compare the current prefix with each tuple and shorten it
prefix_ = prefix_[: min(len(prefix_), len(t))]
for i in range(len(prefix_)):
if prefix_[i] != t[i]:
prefix_ = prefix_[:i]
break
return prefix_

prefix = longest_common_prefix(tuples)
prefix = _longest_common_prefix(tuples)
prefix_length = len(prefix)
return [t[prefix_length:] for t in tuples]

Expand Down
9 changes: 3 additions & 6 deletions litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
create_string_constrained_field_schema,
)
from litestar._openapi.schema_generation.utils import (
_get_normalized_schema_key,
_should_create_enum_schema,
_should_create_literal_schema,
_type_or_first_not_none_inner_type,
Expand Down Expand Up @@ -508,8 +507,7 @@ def for_plugin(self, field_definition: FieldDefinition, plugin: OpenAPISchemaPlu
Returns:
A schema instance.
"""
key = _get_normalized_schema_key(field_definition.annotation)
if (ref := self.schema_registry.get_reference_for_key(key)) is not None:
if (ref := self.schema_registry.get_reference_for_field_definition(field_definition)) is not None:
return ref

schema = plugin.to_openapi_schema(field_definition=field_definition, schema_creator=self)
Expand Down Expand Up @@ -612,8 +610,7 @@ def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schem
schema.examples = get_json_schema_formatted_examples(create_examples_for_field(field))

if schema.title and schema.type == OpenAPIType.OBJECT:
key = _get_normalized_schema_key(field.annotation)
return self.schema_registry.get_reference_for_key(key) or schema
return self.schema_registry.get_reference_for_field_definition(field) or schema
return schema

def create_component_schema(
Expand Down Expand Up @@ -644,7 +641,7 @@ def create_component_schema(
Returns:
A schema instance.
"""
schema = self.schema_registry.get_schema_for_key(_get_normalized_schema_key(type_.annotation))
schema = self.schema_registry.get_schema_for_field_definition(type_)
schema.title = title or _get_type_schema_name(type_)
schema.required = required
schema.type = openapi_type
Expand Down
20 changes: 1 addition & 19 deletions litestar/_openapi/schema_generation/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Mapping, _GenericAlias # type: ignore[attr-defined]
from typing import TYPE_CHECKING, Any, Mapping

from litestar.utils.helpers import get_name

Expand All @@ -15,7 +15,6 @@
"_type_or_first_not_none_inner_type",
"_should_create_enum_schema",
"_should_create_literal_schema",
"_get_normalized_schema_key",
)


Expand Down Expand Up @@ -83,23 +82,6 @@ def _should_create_literal_schema(field_definition: FieldDefinition) -> bool:
)


def _get_normalized_schema_key(annotation: Any) -> tuple[str, ...]:
"""Create a key for a type annotation.

The key should be a tuple such as ``("path", "to", "type", "TypeName")``.

Args:
annotation: a type annotation

Returns:
A tuple of strings.
"""
module = getattr(annotation, "__module__", "")
name = str(annotation)[len(module) + 1 :] if isinstance(annotation, _GenericAlias) else annotation.__qualname__
name = name.replace(".<locals>.", ".")
return *module.split("."), name


def get_formatted_examples(field_definition: FieldDefinition, examples: Sequence[Example]) -> Mapping[str, Example]:
"""Format the examples into the OpenAPI schema format."""

Expand Down
9 changes: 9 additions & 0 deletions litestar/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ class KwargDefinition:

.. versionadded:: 2.8.0
"""
schema_component_key: str | None = None
"""
Use as the key for the reference when creating a component for this type
.. versionadded:: 2.12.0
"""

@property
def is_constrained(self) -> bool:
Expand Down Expand Up @@ -195,6 +200,7 @@ def Parameter(
required: bool | None = None,
title: str | None = None,
schema_extra: dict[str, Any] | None = None,
schema_component_key: str | None = None,
) -> Any:
"""Create an extended parameter kwarg definition.

Expand Down Expand Up @@ -239,6 +245,8 @@ def Parameter(
schema.

.. versionadded:: 2.8.0
schema_component_key: Use this as the key for the reference when creating a component for this type
.. versionadded:: 2.12.0
"""
return ParameterKwarg(
annotation=annotation,
Expand All @@ -264,6 +272,7 @@ def Parameter(
max_length=max_length,
pattern=pattern,
schema_extra=schema_extra,
schema_component_key=schema_component_key,
)


Expand Down
Loading
Loading