Skip to content

Commit

Permalink
Add tests and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
kedod committed Mar 9, 2024
1 parent a135229 commit 6203f9b
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 32 deletions.
74 changes: 45 additions & 29 deletions litestar/contrib/pydantic/pydantic_schema_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,16 @@ def for_pydantic_model(
Args:
field_definition: FieldDefinition instance.
schema_creator: An instance of the schema creator class
exclude: Whether to exclude specified fields
exclude_defaults: Whether to exclude default fields
exclude_none: Whether to exclude None value fields
exclude_unset: Whether to exclude not set fields
generate_examples: Whether to generate examples if none are given
include: Whether to include only specified fields
Returns:
A schema instance.
"""

annotation = field_definition.annotation
if is_generic(annotation):
is_generic_model = True
Expand Down Expand Up @@ -346,36 +351,16 @@ def for_pydantic_model(
)
property_fields.update(computed_field_definitions)

# TODO: refactor
required = set()
exclude = exclude or set()
include = include or []

for prop in property_fields.values():
name = prop.name

if prop.is_required:
if name in exclude:
continue

if exclude_none and (prop.is_optional or prop.is_none_type):
continue

if exclude_defaults and prop.has_default:
continue

if exclude_unset and prop.has_default:
continue

if include:
if name in include:
required.add(name)
elif prop.is_required:
required.add(name)

return schema_creator.create_component_schema(
field_definition,
required=sorted(required),
required=cls.get_required_fields(
property_fields,
exclude=exclude,
include=include,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
exclude_unset=exclude_unset,
),
property_fields=property_fields,
title=title,
examples=(
Expand All @@ -386,3 +371,34 @@ def for_pydantic_model(
)
),
)

@classmethod
def get_required_fields(
cls,
property_fields: dict[str, FieldDefinition],
exclude: PydanticFieldsList = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticFieldsList = None,
) -> list[str]:
required = []
exclude = exclude or set()
include = include or []

for prop in property_fields.values():
name = prop.name

if any(
[
prop.name in exclude,
exclude_none and (prop.is_optional or prop.is_none_type),
(exclude_defaults or exclude_unset) and prop.has_default,
]
):
continue

if name in include or (not include and prop.is_required):
required.append(name)

return sorted(required)
6 changes: 3 additions & 3 deletions litestar/types/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny
except ImportError:
BaseModel = Any # type: ignore[assignment, misc]
IncEx = Any # type: ignore[assignment, misc]
AbstractSetIntStr = Any # type: ignore[assignment, misc]
MappingIntStrAny = Any # type: ignore[assignment, misc]
IncEx = Any # type: ignore[misc]
AbstractSetIntStr = Any
MappingIntStrAny = Any

try:
from attrs import AttrsInstance
Expand Down
64 changes: 64 additions & 0 deletions tests/unit/test_contrib/test_pydantic/test_schema_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,67 @@ def test_exclude_private_fields(model_class: Type[Union[pydantic_v1.BaseModel, p
FieldDefinition.from_annotation(model_class), schema_creator=SchemaCreator(plugins=[PydanticSchemaPlugin()])
)
assert not schema.properties


@pytest.mark.parametrize(
"plugin_params, required_fields",
(
(
{"exclude": {"default_field"}},
["none_field"],
),
(
{"exclude_defaults": True},
["none_field"],
),
(
{"exclude_none": True},
[],
),
(
{"exclude_unset": True},
["none_field"],
),
(
{"include": {"default_field"}},
["default_field"],
),
(
{"include": {"default_field"}, "exclude": {"default_field"}},
[],
),
),
ids=(
"Exclude specific field",
"Exclude defaults field",
"Exclude None fields",
"Exclude unset fields",
"Include excluded field",
"Exclude over include",
),
)
def test_required_schema_fields(plugin_params: dict, required_fields: dict) -> None:
class ModelV1(pydantic_v1.BaseModel): # pyright: ignore
default_field: str = "default"
default_none_field: None = None
none_field: None

class ModelV2(pydantic_v2.BaseModel):
default_field: str = "default"
default_none_field: None = None
none_field: None

schema_v1 = PydanticSchemaPlugin.for_pydantic_model(
FieldDefinition.from_annotation(ModelV1),
schema_creator=SchemaCreator(plugins=[PydanticSchemaPlugin()]),
**plugin_params,
)

schema_v2 = PydanticSchemaPlugin.for_pydantic_model(
FieldDefinition.from_annotation(ModelV2),
schema_creator=SchemaCreator(plugins=[PydanticSchemaPlugin()]),
**plugin_params,
)

assert schema_v1.required == required_fields # type:ignore[comparison-overlap]
assert schema_v2.required == required_fields # type:ignore[comparison-overlap]

0 comments on commit 6203f9b

Please sign in to comment.