Skip to content

Commit

Permalink
Add tests which shows that code doesnt work correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
kedod committed Mar 9, 2024
1 parent baeeabc commit 71749c1
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 26 deletions.
65 changes: 42 additions & 23 deletions litestar/contrib/pydantic/pydantic_schema_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,16 @@ def for_pydantic_model(
Args:
field_definition: FieldDefinition instance.
schema_creator: An instance of the schema creator class
exclude: Whether to exclude specified fields
exclude_defaults: Whether to exclude default fields
exclude_none: Whether to exclude None value fields
exclude_unset: Whether to exclude not set fields
generate_examples: Whether to generate examples if none are given
include: Whether to include only specified fields
Returns:
A schema instance.
"""

annotation = field_definition.annotation
if is_generic(annotation):
is_generic_model = True
Expand Down Expand Up @@ -346,11 +351,40 @@ def for_pydantic_model(
)
property_fields.update(computed_field_definitions)

# TODO: refactor
required = set()
return schema_creator.create_component_schema(
field_definition,
required=cls.get_required_fields(
property_fields,
exclude=exclude,
include=include,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
exclude_unset=exclude_unset,
),
property_fields=property_fields,
title=title,
examples=(
None
if example is None
else get_formatted_examples(
field_definition, [Example(description=f"Example {field_definition.name} value", value=example)]
)
),
)

@classmethod
def get_required_fields(
cls,
property_fields: dict[str, FieldDefinition],
exclude: PydanticFieldsList = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticFieldsList = None,
) -> list[str]:
required = []
exclude = exclude or set()
include = include or []

for prop in property_fields.values():
name = prop.name

Expand All @@ -361,28 +395,13 @@ def for_pydantic_model(
if exclude_none and (prop.is_optional or prop.is_none_type):
continue

if exclude_defaults and prop.has_default:
continue

if exclude_unset and prop.has_default:
if any({exclude_defaults, exclude_unset}) and prop.has_default:
continue

Check warning on line 399 in litestar/contrib/pydantic/pydantic_schema_plugin.py

View check run for this annotation

Codecov / codecov/patch

litestar/contrib/pydantic/pydantic_schema_plugin.py#L399

Added line #L399 was not covered by tests

if include:
if name in include:
required.add(name)
required.append(name)
elif prop.is_required:
required.add(name)
required.append(name)

return schema_creator.create_component_schema(
field_definition,
required=sorted(required),
property_fields=property_fields,
title=title,
examples=(
None
if example is None
else get_formatted_examples(
field_definition, [Example(description=f"Example {field_definition.name} value", value=example)]
)
),
)
return sorted(required)
6 changes: 3 additions & 3 deletions litestar/types/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny
except ImportError:
BaseModel = Any # type: ignore[assignment, misc]
IncEx = Any # type: ignore[assignment, misc]
AbstractSetIntStr = Any # type: ignore[assignment, misc]
MappingIntStrAny = Any # type: ignore[assignment, misc]
IncEx = Any # type: ignore[misc]
AbstractSetIntStr = Any
MappingIntStrAny = Any

try:
from attrs import AttrsInstance
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/test_contrib/test_pydantic/test_schema_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,61 @@ def test_exclude_private_fields(model_class: Type[Union[pydantic_v1.BaseModel, p
FieldDefinition.from_annotation(model_class), schema_creator=SchemaCreator(plugins=[PydanticSchemaPlugin()])
)
assert not schema.properties


@pytest.mark.parametrize(
"plugin_params, required_fields",
(
(
{"exclude": {"default_field"}},
["none_field"],
),
(
{"exclude_defaults": True},
["none_field"],
),
(
{"exclude_none": True},
[],
),
(
{"exclude_unset": True},
["none_field"],
),
(
{"include": {"default_field"}},
["default_field"],
),
),
ids=(
"Exclude specific field",
"Exclude defaults field",
"Exclude None fields",
"Exclude unset fields",
"Include excluded field",
),
)
def test_required_schema_fields(plugin_params: dict, required_fields: dict) -> None:
class ModelV1(pydantic_v1.BaseModel): # pyright: ignore
default_field: str = "default"
default_none_field: None = None
none_field: None

class ModelV2(pydantic_v2.BaseModel):
default_field: str = "default"
default_none_field: None = None
none_field: None

schema_v1 = PydanticSchemaPlugin.for_pydantic_model(
FieldDefinition.from_annotation(ModelV1),
schema_creator=SchemaCreator(plugins=[PydanticSchemaPlugin()]),
**plugin_params,
)

schema_v2 = PydanticSchemaPlugin.for_pydantic_model(
FieldDefinition.from_annotation(ModelV2),
schema_creator=SchemaCreator(plugins=[PydanticSchemaPlugin()]),
**plugin_params,
)

assert schema_v1.required == schema_v2.required == required_fields # type:ignore[comparison-overlap]

0 comments on commit 71749c1

Please sign in to comment.