diff --git a/pyproject.toml b/pyproject.toml index d2dc8cd..0b20c22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ "Intended Audience :: Developers", ] name = "arcticfreeze" -version = "0.1.0" +version = "0.1.1" description = "Enjoy Python on the rocks with deeply (recursively) frozen data structures." dependencies = [ "immutabledict >=4, <5" @@ -30,7 +30,6 @@ dependencies = [ [project.optional-dependencies] pydantic = [ "pydantic >=2, <3", - "pydantic_core >=2, <3", ] [project.license] diff --git a/src/arcticfreeze/_internal/frozendict.py b/src/arcticfreeze/_internal/frozendict.py index e1a82ad..fa1dadb 100644 --- a/src/arcticfreeze/_internal/frozendict.py +++ b/src/arcticfreeze/_internal/frozendict.py @@ -64,48 +64,83 @@ def __new__(cls, *args: Any, **kwargs: Any) -> FrozenDict: if PYDANTIC_V2_INSTALLED: from pydantic import GetCoreSchemaHandler - from pydantic_core import core_schema + from pydantic_core import SchemaSerializer, core_schema def get_pydantic_core_schema( cls, source: Any, handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: """Get the pydantic core schema for this type.""" - # Validate the type against a Mapping: args = typing.get_args(source) if not args: - validation_schema = handler.generate_schema(Mapping) + key_type = Any + value_type = Any elif len(args) == 2: - validation_schema = handler.generate_schema(Mapping[args[0], args[1]]) # type: ignore + key_type, value_type = args else: raise TypeError( "Expected exactly two (or no) type arguments for FrozenDict, got" + f" {len(args)}" ) + validation_schema = handler.generate_schema( + Mapping[key_type, value_type] # type: ignore + ) + python_serialization_schema = core_schema.plain_serializer_function_ser_schema( lambda x: x, return_schema=core_schema.any_schema() ) - python_schema = core_schema.no_info_after_validator_function( - # callable to use after validation against the schema (convert to - # FrozenDict): - cls, - # the validation schema to use before executing the callable: - validation_schema, + function=cls, + schema=validation_schema, serialization=python_serialization_schema, ) json_serialization_schema = core_schema.plain_serializer_function_ser_schema( dict, return_schema=validation_schema, when_used="json" ) + json_schema = core_schema.no_info_after_validator_function( + function=cls, + schema=validation_schema, + serialization=json_serialization_schema, + ) - # Uses cls as validator function to convert the dict to a FrozenDict: - return core_schema.json_or_python_schema( - json_schema=validation_schema, + schema = core_schema.json_or_python_schema( + json_schema=json_schema, python_schema=python_schema, + ) + + return schema + + def pydantic_serializer(self) -> SchemaSerializer: + """This is needed due to issue: + https://github.com/pydantic/pydantic/issues/7779 + """ + validation_schema = core_schema.any_schema() + + python_serialization_schema = core_schema.plain_serializer_function_ser_schema( + lambda x: x, return_schema=validation_schema + ) + python_schema = core_schema.any_schema( + serialization=python_serialization_schema, + ) + + json_serialization_schema = core_schema.plain_serializer_function_ser_schema( + dict, return_schema=validation_schema, when_used="json" + ) + json_schema = core_schema.any_schema( serialization=json_serialization_schema, ) + schema = core_schema.json_or_python_schema( + json_schema=json_schema, + python_schema=python_schema, + ) + + return SchemaSerializer(schema) + FrozenDict.__get_pydantic_core_schema__ = classmethod( # type: ignore get_pydantic_core_schema ) + FrozenDict.__pydantic_serializer__ = property( # type: ignore + pydantic_serializer + ) diff --git a/src/arcticfreeze/_internal/utils.py b/src/arcticfreeze/_internal/utils.py index 944de82..c24adcb 100644 --- a/src/arcticfreeze/_internal/utils.py +++ b/src/arcticfreeze/_internal/utils.py @@ -16,14 +16,10 @@ # Check if Pydantic v2 is installed and store the result in a constant: PYDANTIC_V2_INSTALLED = False -PYDANTIC_VERSION_PREFIX = "2." try: from pydantic import __version__ as pydantic_version - from pydantic_core import __version__ as pydantic_core_version except ImportError: pass else: - if pydantic_version.startswith( - PYDANTIC_VERSION_PREFIX - ) and pydantic_core_version.startswith(PYDANTIC_VERSION_PREFIX): + if pydantic_version.startswith("2."): PYDANTIC_V2_INSTALLED = True diff --git a/tests/test_frozendict.py b/tests/test_frozendict_pydantic.py similarity index 56% rename from tests/test_frozendict.py rename to tests/test_frozendict_pydantic.py index bec7b05..d37c73a 100644 --- a/tests/test_frozendict.py +++ b/tests/test_frozendict_pydantic.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test the FrozenDict class.""" +"""Test the Pydantic integration of the FrozenDict class.""" import json +from typing import Any +import pydantic +import pytest from arcticfreeze import FrozenDict from pydantic import BaseModel, ConfigDict @@ -39,54 +42,50 @@ class TestModel(BaseModel): ) -def test_frozen_dict_serialization(): - """Test serialization of FrozenDict used in pydantic models when using model_dump.""" +def test_frozen_dict_validation_invalid(): + """Test validation using FrozenDict in pydantic models with invalid input.""" class TestModel(BaseModel): frozen_dict: FrozenDict - input_dict = {"a": 1, "b": 2} - - model = TestModel.model_validate({"frozen_dict": input_dict}) - - dumped_data = model.model_dump() - observed_dict = dumped_data["frozen_dict"] - assert isinstance(observed_dict, FrozenDict) - assert dict(observed_dict) == input_dict + with pytest.raises(pydantic.ValidationError): + TestModel.model_validate({"frozen_dict": ["invalid_input"]}) -def test_frozen_dict_serialization_json(): - """Test serialization of FrozenDict used in pydantic models when using - model_dump_json. +def test_frozen_dict_validation_with_args(): + """Test validation and type conversion using FrozenDict with type arguments in + pydantic models. """ class TestModel(BaseModel): - frozen_dict: FrozenDict + frozen_dict: FrozenDict[str, int] input_dict = {"a": 1, "b": 2} + frozen_input_dict = FrozenDict(input_dict) - model = TestModel.model_validate({"frozen_dict": input_dict}) + model_from_dict = TestModel.model_validate({"frozen_dict": input_dict}) + model_from_frozendict = TestModel(frozen_dict=frozen_input_dict) - dumped_json = model.model_dump_json() - observed_dict = json.loads(dumped_json)["frozen_dict"] - assert observed_dict == input_dict + assert ( + model_from_dict.frozen_dict + == model_from_frozendict.frozen_dict + == frozen_input_dict + ) -def test_frozen_dict_json_schema(): - """Test JSON schema generation of FrozenDicts-containing pydantic models.""" +def test_frozen_dict_validation_with_args_invalid(): + """Test validation and type conversion using FrozenDict with type arguments in + pydantic models with invalid input. + """ class TestModel(BaseModel): - frozen_dict: FrozenDict + frozen_dict: FrozenDict[str, int] - schema_from_frozen_dict = TestModel.model_json_schema() + # values are strings but expected integers: + input_dict = {"a": "invalid", "b": "invalid"} - # redefine TestModel to not using standard dict: - class TestModel(BaseModel): # type: ignore - frozen_dict: dict - - schema_from_dict = TestModel.model_json_schema() - - assert schema_from_frozen_dict == schema_from_dict + with pytest.raises(pydantic.ValidationError): + TestModel.model_validate({"frozen_dict": input_dict}) def test_frozen_dict_hashing(): @@ -119,3 +118,68 @@ class TestModel(BaseModel): model = TestModel.model_validate({"inner": {"test": {"frozen_dict": input_dict}}}) assert isinstance(model.inner["test"], Inner) + + +def test_frozen_dict_json_schema(): + """Test JSON schema generation of FrozenDicts-containing pydantic models.""" + + class TestModel(BaseModel): + frozen_dict: FrozenDict + + schema_from_frozen_dict = TestModel.model_json_schema() + + # redefine TestModel to not using standard dict: + class TestModel(BaseModel): # type: ignore + frozen_dict: dict + + schema_from_dict = TestModel.model_json_schema() + + assert schema_from_frozen_dict == schema_from_dict + + +def test_frozen_dict_serialization(): + """Test serialization of FrozenDict used in pydantic models when using model_dump.""" + + class TestModel(BaseModel): + frozen_dict: FrozenDict + + input_dict = {"a": 1, "b": 2} + + model = TestModel.model_validate({"frozen_dict": input_dict}) + + dumped_data = model.model_dump() + observed_dict = dumped_data["frozen_dict"] + assert isinstance(observed_dict, FrozenDict) + assert dict(observed_dict) == input_dict + + dumped_json = model.model_dump_json() + observed_json_dict = json.loads(dumped_json)["frozen_dict"] + assert observed_json_dict == input_dict + + +def test_frozen_dict_serialization_when_not_announced(): + """Test serialization of FrozenDict when not announced in the model. Moreover, + check that also the serialization of children of the FrozenDict is working. + This test was added as response to a bug. + """ + + class TestModel(BaseModel): + """A model defining a tuple with arbitrary content (hence also FrozenDict) + should be valid. + """ + + frozen_tuple: tuple[Any] + + input_tuple = (FrozenDict({"a": (1, 2)}),) + model = TestModel.model_validate({"frozen_tuple": input_tuple}) + + # serialization to dict: + dumped_dict = model.model_dump() + observed_output = dumped_dict["frozen_tuple"] + assert observed_output == input_tuple + + # serialization to json: + dumped_json = model.model_dump_json() + expected_json_output = [{"a": [1, 2]}] + observed_json_output = json.loads(dumped_json)["frozen_tuple"] + assert observed_json_output == expected_json_output