Skip to content

Commit

Permalink
fix serialization of unannounced FrozenDict (#8)
Browse files Browse the repository at this point in the history
Pydantic did not recognize FrozenDicts when they are not explicitly
announced as part of the type in a model. This is due to issue:
pydantic/pydantic#7779

Bumps version 0.1.1.
  • Loading branch information
KerstenBreuer authored Apr 3, 2024
1 parent 865948a commit 6dd967a
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 50 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
"Intended Audience :: Developers",
]
name = "arcticfreeze"
version = "0.1.0"
version = "0.1.1"
description = "Enjoy Python on the rocks with deeply (recursively) frozen data structures."
dependencies = [
"immutabledict >=4, <5"
Expand All @@ -30,7 +30,6 @@ dependencies = [
[project.optional-dependencies]
pydantic = [
"pydantic >=2, <3",
"pydantic_core >=2, <3",
]

[project.license]
Expand Down
61 changes: 48 additions & 13 deletions src/arcticfreeze/_internal/frozendict.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,48 +64,83 @@ def __new__(cls, *args: Any, **kwargs: Any) -> FrozenDict:

if PYDANTIC_V2_INSTALLED:
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from pydantic_core import SchemaSerializer, core_schema

def get_pydantic_core_schema(
cls, source: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
"""Get the pydantic core schema for this type."""
# Validate the type against a Mapping:
args = typing.get_args(source)
if not args:
validation_schema = handler.generate_schema(Mapping)
key_type = Any
value_type = Any
elif len(args) == 2:
validation_schema = handler.generate_schema(Mapping[args[0], args[1]]) # type: ignore
key_type, value_type = args
else:
raise TypeError(
"Expected exactly two (or no) type arguments for FrozenDict, got"
+ f" {len(args)}"
)

validation_schema = handler.generate_schema(
Mapping[key_type, value_type] # type: ignore
)

python_serialization_schema = core_schema.plain_serializer_function_ser_schema(
lambda x: x, return_schema=core_schema.any_schema()
)

python_schema = core_schema.no_info_after_validator_function(
# callable to use after validation against the schema (convert to
# FrozenDict):
cls,
# the validation schema to use before executing the callable:
validation_schema,
function=cls,
schema=validation_schema,
serialization=python_serialization_schema,
)

json_serialization_schema = core_schema.plain_serializer_function_ser_schema(
dict, return_schema=validation_schema, when_used="json"
)
json_schema = core_schema.no_info_after_validator_function(
function=cls,
schema=validation_schema,
serialization=json_serialization_schema,
)

# Uses cls as validator function to convert the dict to a FrozenDict:
return core_schema.json_or_python_schema(
json_schema=validation_schema,
schema = core_schema.json_or_python_schema(
json_schema=json_schema,
python_schema=python_schema,
)

return schema

def pydantic_serializer(self) -> SchemaSerializer:
"""This is needed due to issue:
https://github.com/pydantic/pydantic/issues/7779
"""
validation_schema = core_schema.any_schema()

python_serialization_schema = core_schema.plain_serializer_function_ser_schema(
lambda x: x, return_schema=validation_schema
)
python_schema = core_schema.any_schema(
serialization=python_serialization_schema,
)

json_serialization_schema = core_schema.plain_serializer_function_ser_schema(
dict, return_schema=validation_schema, when_used="json"
)
json_schema = core_schema.any_schema(
serialization=json_serialization_schema,
)

schema = core_schema.json_or_python_schema(
json_schema=json_schema,
python_schema=python_schema,
)

return SchemaSerializer(schema)

FrozenDict.__get_pydantic_core_schema__ = classmethod( # type: ignore
get_pydantic_core_schema
)
FrozenDict.__pydantic_serializer__ = property( # type: ignore
pydantic_serializer
)
6 changes: 1 addition & 5 deletions src/arcticfreeze/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,10 @@

# Check if Pydantic v2 is installed and store the result in a constant:
PYDANTIC_V2_INSTALLED = False
PYDANTIC_VERSION_PREFIX = "2."
try:
from pydantic import __version__ as pydantic_version
from pydantic_core import __version__ as pydantic_core_version
except ImportError:
pass
else:
if pydantic_version.startswith(
PYDANTIC_VERSION_PREFIX
) and pydantic_core_version.startswith(PYDANTIC_VERSION_PREFIX):
if pydantic_version.startswith("2."):
PYDANTIC_V2_INSTALLED = True
124 changes: 94 additions & 30 deletions tests/test_frozendict.py → tests/test_frozendict_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test the FrozenDict class."""
"""Test the Pydantic integration of the FrozenDict class."""

import json
from typing import Any

import pydantic
import pytest
from arcticfreeze import FrozenDict
from pydantic import BaseModel, ConfigDict

Expand All @@ -39,54 +42,50 @@ class TestModel(BaseModel):
)


def test_frozen_dict_serialization():
"""Test serialization of FrozenDict used in pydantic models when using model_dump."""
def test_frozen_dict_validation_invalid():
"""Test validation using FrozenDict in pydantic models with invalid input."""

class TestModel(BaseModel):
frozen_dict: FrozenDict

input_dict = {"a": 1, "b": 2}

model = TestModel.model_validate({"frozen_dict": input_dict})

dumped_data = model.model_dump()
observed_dict = dumped_data["frozen_dict"]
assert isinstance(observed_dict, FrozenDict)
assert dict(observed_dict) == input_dict
with pytest.raises(pydantic.ValidationError):
TestModel.model_validate({"frozen_dict": ["invalid_input"]})


def test_frozen_dict_serialization_json():
"""Test serialization of FrozenDict used in pydantic models when using
model_dump_json.
def test_frozen_dict_validation_with_args():
"""Test validation and type conversion using FrozenDict with type arguments in
pydantic models.
"""

class TestModel(BaseModel):
frozen_dict: FrozenDict
frozen_dict: FrozenDict[str, int]

input_dict = {"a": 1, "b": 2}
frozen_input_dict = FrozenDict(input_dict)

model = TestModel.model_validate({"frozen_dict": input_dict})
model_from_dict = TestModel.model_validate({"frozen_dict": input_dict})
model_from_frozendict = TestModel(frozen_dict=frozen_input_dict)

dumped_json = model.model_dump_json()
observed_dict = json.loads(dumped_json)["frozen_dict"]
assert observed_dict == input_dict
assert (
model_from_dict.frozen_dict
== model_from_frozendict.frozen_dict
== frozen_input_dict
)


def test_frozen_dict_json_schema():
"""Test JSON schema generation of FrozenDicts-containing pydantic models."""
def test_frozen_dict_validation_with_args_invalid():
"""Test validation and type conversion using FrozenDict with type arguments in
pydantic models with invalid input.
"""

class TestModel(BaseModel):
frozen_dict: FrozenDict
frozen_dict: FrozenDict[str, int]

schema_from_frozen_dict = TestModel.model_json_schema()
# values are strings but expected integers:
input_dict = {"a": "invalid", "b": "invalid"}

# redefine TestModel to not using standard dict:
class TestModel(BaseModel): # type: ignore
frozen_dict: dict

schema_from_dict = TestModel.model_json_schema()

assert schema_from_frozen_dict == schema_from_dict
with pytest.raises(pydantic.ValidationError):
TestModel.model_validate({"frozen_dict": input_dict})


def test_frozen_dict_hashing():
Expand Down Expand Up @@ -119,3 +118,68 @@ class TestModel(BaseModel):

model = TestModel.model_validate({"inner": {"test": {"frozen_dict": input_dict}}})
assert isinstance(model.inner["test"], Inner)


def test_frozen_dict_json_schema():
"""Test JSON schema generation of FrozenDicts-containing pydantic models."""

class TestModel(BaseModel):
frozen_dict: FrozenDict

schema_from_frozen_dict = TestModel.model_json_schema()

# redefine TestModel to not using standard dict:
class TestModel(BaseModel): # type: ignore
frozen_dict: dict

schema_from_dict = TestModel.model_json_schema()

assert schema_from_frozen_dict == schema_from_dict


def test_frozen_dict_serialization():
"""Test serialization of FrozenDict used in pydantic models when using model_dump."""

class TestModel(BaseModel):
frozen_dict: FrozenDict

input_dict = {"a": 1, "b": 2}

model = TestModel.model_validate({"frozen_dict": input_dict})

dumped_data = model.model_dump()
observed_dict = dumped_data["frozen_dict"]
assert isinstance(observed_dict, FrozenDict)
assert dict(observed_dict) == input_dict

dumped_json = model.model_dump_json()
observed_json_dict = json.loads(dumped_json)["frozen_dict"]
assert observed_json_dict == input_dict


def test_frozen_dict_serialization_when_not_announced():
"""Test serialization of FrozenDict when not announced in the model. Moreover,
check that also the serialization of children of the FrozenDict is working.
This test was added as response to a bug.
"""

class TestModel(BaseModel):
"""A model defining a tuple with arbitrary content (hence also FrozenDict)
should be valid.
"""

frozen_tuple: tuple[Any]

input_tuple = (FrozenDict({"a": (1, 2)}),)
model = TestModel.model_validate({"frozen_tuple": input_tuple})

# serialization to dict:
dumped_dict = model.model_dump()
observed_output = dumped_dict["frozen_tuple"]
assert observed_output == input_tuple

# serialization to json:
dumped_json = model.model_dump_json()
expected_json_output = [{"a": [1, 2]}]
observed_json_output = json.loads(dumped_json)["frozen_tuple"]
assert observed_json_output == expected_json_output

0 comments on commit 6dd967a

Please sign in to comment.