Skip to content

Commit

Permalink
Fix serialization of Link/BackLink and OpenAPI schema generation (#1080)
Browse files Browse the repository at this point in the history
* fix: serialization and JSON schema generation in Python and FastAPI contexts

* tests: write tests for serialization and JSON schema generation

* Fix typo

Co-authored-by: Adeel Ahmed <[email protected]>

---------

Co-authored-by: Adeel Ahmed <[email protected]>
  • Loading branch information
staticxterm and adeelsohailahmed authored Dec 23, 2024
1 parent a2e6a5a commit ff5ca5b
Show file tree
Hide file tree
Showing 12 changed files with 511 additions and 176 deletions.
285 changes: 179 additions & 106 deletions beanie/odm/fields.py

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions tests/fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@

from beanie import init_beanie
from tests.conftest import Settings
from tests.fastapi.models import DoorAPI, HouseAPI, RoofAPI, WindowAPI
from tests.fastapi.models import (
DoorAPI,
House,
HouseAPI,
Person,
RoofAPI,
WindowAPI,
)
from tests.fastapi.routes import house_router


Expand All @@ -17,7 +24,7 @@ async def live_span(_: FastAPI):
# INIT BEANIE
await init_beanie(
client.beanie_db,
document_models=[HouseAPI, WindowAPI, DoorAPI, RoofAPI],
document_models=[House, Person, HouseAPI, WindowAPI, DoorAPI, RoofAPI],
)
yield

Expand Down
11 changes: 9 additions & 2 deletions tests/fastapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from httpx import ASGITransport, AsyncClient

from tests.fastapi.app import app
from tests.fastapi.models import DoorAPI, HouseAPI, RoofAPI, WindowAPI
from tests.fastapi.models import (
DoorAPI,
House,
HouseAPI,
Person,
RoofAPI,
WindowAPI,
)


@pytest.fixture(autouse=True)
Expand All @@ -19,7 +26,7 @@ async def api_client(clean_db):

@pytest.fixture(autouse=True)
async def clean_db(db):
models = [HouseAPI, WindowAPI, DoorAPI, RoofAPI]
models = [House, Person, HouseAPI, WindowAPI, DoorAPI, RoofAPI]
yield None

for model in models:
Expand Down
18 changes: 18 additions & 0 deletions tests/fastapi/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import List

from pydantic import Field

from beanie import Document, Indexed, Link
from beanie.odm.fields import BackLink
from beanie.odm.utils.pydantic import IS_PYDANTIC_V2


class WindowAPI(Document):
Expand All @@ -20,3 +24,17 @@ class HouseAPI(Document):
windows: List[Link[WindowAPI]]
name: Indexed(str)
height: Indexed(int) = 2


class House(Document):
name: str
owner: Link["Person"]


class Person(Document):
name: str
house: BackLink[House] = (
Field(json_schema_extra={"original_field": "owner"})
if IS_PYDANTIC_V2
else Field(original_field="owner")
)
17 changes: 15 additions & 2 deletions tests/fastapi/routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from fastapi import APIRouter
from fastapi import APIRouter, Body, status
from pydantic import BaseModel

from beanie import PydanticObjectId, WriteRules
from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
from tests.fastapi.models import HouseAPI, WindowAPI
from tests.fastapi.models import House, HouseAPI, Person, WindowAPI

house_router = APIRouter()
if not IS_PYDANTIC_V2:
Expand Down Expand Up @@ -51,3 +51,16 @@ async def create_houses_with_window_link(window: WindowInput):
async def create_houses_2(house: HouseAPI):
await house.insert(link_rule=WriteRules.WRITE)
return house


@house_router.post(
"/house",
response_model=House,
status_code=status.HTTP_201_CREATED,
)
async def create_house_new(house: House = Body(...)):
person = Person(name="Bob")
house.owner = person
await house.save(link_rule=WriteRules.WRITE)
await house.sync()
return house
13 changes: 13 additions & 0 deletions tests/fastapi/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,16 @@ async def test_revision_id(api_client):
resp_json = resp.json()
assert "revision_id" not in resp_json
assert resp_json == {"x": 10, "y": 20, "_id": resp_json["_id"]}


async def test_create_house_new(api_client):
payload = {
"name": "FreshHouse",
"owner": {"name": "will_be_overridden_to_Bob"},
}
resp = await api_client.post("/v1/house", json=payload)
resp_json = resp.json()

assert resp_json["name"] == payload["name"]
assert resp_json["owner"]["name"] == payload["owner"]["name"][-3:]
assert resp_json["owner"]["house"]["collection"] == "House"
19 changes: 19 additions & 0 deletions tests/fastapi/test_openapi_schema_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from json import dumps

from fastapi.openapi.utils import get_openapi

from tests.fastapi.app import app


def test_openapi_schema_generation():
openapi_schema_json_str = dumps(
get_openapi(
title=app.title,
version=app.version,
openapi_version=app.openapi_version,
description=app.description,
routes=app.routes,
),
)

assert openapi_schema_json_str is not None
55 changes: 26 additions & 29 deletions tests/odm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Optional,
Set,
Tuple,
Type,
Union,
)
from uuid import UUID, uuid4
Expand All @@ -36,7 +37,6 @@
SecretBytes,
SecretStr,
)
from pydantic.fields import FieldInfo
from pydantic_core import core_schema
from pymongo import IndexModel
from typing_extensions import Annotated
Expand Down Expand Up @@ -85,41 +85,33 @@ def as_hex(self):
return self.value

@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, value):
def _validate(cls, value: Any) -> "Color":
if isinstance(value, Color):
return value
if isinstance(value, dict):
return Color(value["value"])
return Color(value)

@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: Callable[[Any], core_schema.CoreSchema], # type: ignore
) -> core_schema.CoreSchema: # type: ignore
def validate(value, _: FieldInfo) -> Color:
if isinstance(value, Color):
return value
if isinstance(value, dict):
return Color(value["value"])
return Color(value)

vf = (
core_schema.with_info_plain_validator_function
if hasattr(core_schema, "with_info_plain_validator_function")
else core_schema.general_plain_validator_function
)
python_schema = vf(validate)
if IS_PYDANTIC_V2:

return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(),
python_schema=python_schema,
)
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Type[Any],
_handler: Callable[[Any], core_schema.CoreSchema],
) -> core_schema.CoreSchema:
return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(),
python_schema=core_schema.no_info_plain_validator_function(
cls._validate
),
)

else:

@classmethod
def __get_validators__(cls):
yield cls._validate


class Extra(str, Enum):
Expand Down Expand Up @@ -917,6 +909,11 @@ class DocumentWithLink(Document):
s: str = "TEST"


class DocumentWithOptionalLink(Document):
link: Optional[Link["DocumentWithBackLink"]]
s: str = "TEST"


class DocumentWithBackLink(Document):
if IS_PYDANTIC_V2:
back_link: BackLink[DocumentWithLink] = Field(
Expand Down
23 changes: 15 additions & 8 deletions tests/odm/test_beanie_object_dumping.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from pydantic import BaseModel, Field

from beanie import Link, PydanticObjectId
Expand All @@ -19,15 +20,21 @@ def data_maker():
)


@pytest.mark.skipif(
not IS_PYDANTIC_V2,
reason="model dumping support is more complete with pydantic v2",
)
def test_id_types_preserved_when_dumping_to_python():
if IS_PYDANTIC_V2:
dumped = data_maker().model_dump(mode="python")
assert isinstance(dumped["my_id"], PydanticObjectId)
assert isinstance(dumped["fake_doc"]["id"], PydanticObjectId)
dumped = data_maker().model_dump(mode="python")
assert isinstance(dumped["my_id"], PydanticObjectId)
assert isinstance(dumped["fake_doc"]["id"], PydanticObjectId)


@pytest.mark.skipif(
not IS_PYDANTIC_V2,
reason="model dumping support is more complete with pydantic v2",
)
def test_id_types_serialized_when_dumping_to_json():
if IS_PYDANTIC_V2:
dumped = data_maker().model_dump(mode="json")
assert isinstance(dumped["my_id"], str)
assert isinstance(dumped["fake_doc"]["id"], str)
dumped = data_maker().model_dump(mode="json")
assert isinstance(dumped["my_id"], str)
assert isinstance(dumped["fake_doc"]["id"], str)
56 changes: 31 additions & 25 deletions tests/odm/test_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from tests.odm.models import DocumentWithCustomIdInt, DocumentWithCustomIdUUID


class A(BaseModel):
id: PydanticObjectId


async def test_uuid_id():
doc = DocumentWithCustomIdUUID(name="TEST")
await doc.insert()
Expand All @@ -22,28 +26,30 @@ async def test_integer_id():
assert isinstance(new_doc.id, int)


if IS_PYDANTIC_V2:

class A(BaseModel):
id: PydanticObjectId

async def test_pydantic_object_id_validation_json():
deserialized = A.model_validate_json(
'{"id": "5eb7cf5a86d9755df3a6c593"}'
)
assert isinstance(deserialized.id, PydanticObjectId)
assert str(deserialized.id) == "5eb7cf5a86d9755df3a6c593"
assert deserialized.id == PydanticObjectId("5eb7cf5a86d9755df3a6c593")

@pytest.mark.parametrize(
"data",
[
"5eb7cf5a86d9755df3a6c593",
PydanticObjectId("5eb7cf5a86d9755df3a6c593"),
],
)
async def test_pydantic_object_id_serialization(data):
deserialized = A(**{"id": data})
assert isinstance(deserialized.id, PydanticObjectId)
assert str(deserialized.id) == "5eb7cf5a86d9755df3a6c593"
assert deserialized.id == PydanticObjectId("5eb7cf5a86d9755df3a6c593")
@pytest.mark.skipif(
not IS_PYDANTIC_V2,
reason="supports only pydantic v2",
)
async def test_pydantic_object_id_validation_json():
deserialized = A.model_validate_json('{"id": "5eb7cf5a86d9755df3a6c593"}')
assert isinstance(deserialized.id, PydanticObjectId)
assert str(deserialized.id) == "5eb7cf5a86d9755df3a6c593"
assert deserialized.id == PydanticObjectId("5eb7cf5a86d9755df3a6c593")


@pytest.mark.skipif(
not IS_PYDANTIC_V2,
reason="supports only pydantic v2",
)
@pytest.mark.parametrize(
"data",
[
"5eb7cf5a86d9755df3a6c593",
PydanticObjectId("5eb7cf5a86d9755df3a6c593"),
],
)
async def test_pydantic_object_id_serialization(data):
deserialized = A(**{"id": data})
assert isinstance(deserialized.id, PydanticObjectId)
assert str(deserialized.id) == "5eb7cf5a86d9755df3a6c593"
assert deserialized.id == PydanticObjectId("5eb7cf5a86d9755df3a6c593")
Loading

0 comments on commit ff5ca5b

Please sign in to comment.