Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(NumpyModel): Equality operator when the fields are heterogeneous. #46

Merged
merged 1 commit into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
446 changes: 253 additions & 193 deletions poetry.lock

Large diffs are not rendered by default.

42 changes: 22 additions & 20 deletions pydantic_numpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,29 @@ class NumpyModel(BaseModel):
_directory_suffix: ClassVar[str] = ".pdnp"

def __eq__(self, other: Any) -> bool:
if isinstance(other, NumpyModel):
self_type = self.__pydantic_generic_metadata__["origin"] or self.__class__
other_type = other.__pydantic_generic_metadata__["origin"] or other.__class__
if not isinstance(other, BaseModel):
return NotImplemented # delegate to the other item in the comparison

self_type = self.__pydantic_generic_metadata__["origin"] or self.__class__
other_type = other.__pydantic_generic_metadata__["origin"] or other.__class__

if not (
self_type == other_type
and getattr(self, "__pydantic_private__", None) == getattr(other, "__pydantic_private__", None)
and self.__pydantic_extra__ == other.__pydantic_extra__
):
return False

if isinstance(other, NumpyModel):
self_ndarray_field_to_array, self_other_field_to_value = self._dump_numpy_split_dict()
other_ndarray_field_to_array, other_other_field_to_value = other._dump_numpy_split_dict()

return (
self_type == other_type
and self_other_field_to_value == other_other_field_to_value
and self.__pydantic_private__ == other.__pydantic_private__
and self.__pydantic_extra__ == other.__pydantic_extra__
and _compare_np_array_dicts(self_ndarray_field_to_array, other_ndarray_field_to_array)
return self_other_field_to_value == other_other_field_to_value and _compare_np_array_dicts(
self_ndarray_field_to_array, other_ndarray_field_to_array
)
elif isinstance(other, BaseModel):
return super().__eq__(other)
else:
return NotImplemented # delegate to the other item in the comparison

# Self is NumpyModel, other is not; likely unequal; checking anyway.
return super().__eq__(other)

@classmethod
@validate_call
Expand Down Expand Up @@ -156,10 +161,10 @@ def _dump_numpy_split_dict(self) -> tuple[dict, dict]:
ndarray_field_to_array = {}
other_field_to_value = {}

for k, v in self.model_dump(exclude_unset=True).items():
for k, v in self.model_dump().items():
if isinstance(v, np.ndarray):
ndarray_field_to_array[k] = v
else:
elif v:
other_field_to_value[k] = v

return ndarray_field_to_array, other_field_to_value
Expand Down Expand Up @@ -259,16 +264,13 @@ def _compare_np_array_dicts(
keys2 = frozenset(dict_b.keys())

if keys1 != keys2:
raise ValueError("Dictionaries have different keys")
return False

for key in keys1:
arr_a = dict_a[key]
arr_b = dict_b[key]

if arr_a.shape != arr_b.shape:
raise ValueError(f"Arrays for key '{key}' have different shapes")

if not np_general_all_close(arr_a, arr_b, rtol, atol):
if arr_a.shape != arr_b.shape or not np_general_all_close(arr_a, arr_b, rtol, atol):
return False

return True
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pydantic_numpy"
version = "5.0.1"
version = "5.0.2"
description = "Pydantic Model integration of the NumPy array"
authors = ["Can H. Tartanoglu", "Christoph Heindl"]
maintainers = ["Can H. Tartanoglu <[email protected]>"]
Expand Down Expand Up @@ -30,6 +30,7 @@ semver = "^3.0.1"
pytest = "^7.4.0"
parameterized = "^0.9.0"
orjson = "*"
coverage = "^7.5.1"

[tool.poetry.group.format.dependencies]
black = "^23.7.0"
Expand Down
4 changes: 2 additions & 2 deletions tests/helper/testing_groups.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import platform

import numpy as np

Expand Down Expand Up @@ -185,7 +185,7 @@
(np.array([[[0]]]), np.int64, Np3DArrayInt64, 3),
]

if os.name != "nt":
if platform.system() != "Windows":

def get_strict_data_type_nd_array_typing_dimensions_128_bit():
return [
Expand Down
54 changes: 36 additions & 18 deletions tests/test_np_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import platform
import tempfile
from pathlib import Path

import numpy as np
import pytest

from pydantic_numpy.model import model_agnostic_load
from pydantic_numpy.model import NumpyModel, model_agnostic_load
from pydantic_numpy.typing import NpNDArray
from tests.model import (
NpNDArrayModelWithNonArray,
Expand All @@ -17,28 +17,19 @@
NON_ARRAY_VALUE = 5


def _numpy_model():
@pytest.fixture
def numpy_model() -> NpNDArrayModelWithNonArray:
return NpNDArrayModelWithNonArray(array=np.array([0.0]), non_array=NON_ARRAY_VALUE)


@pytest.fixture
def numpy_model():
return _numpy_model()


@pytest.fixture(
params=[
_numpy_model(),
NpNDArrayModelWithNonArrayWithArbitrary(
array=np.array([0.0]), non_array=NON_ARRAY_VALUE, my_arbitrary_slice=slice(0, 10)
),
]
)
def numpy_model_with_arbitrary(request):
return request.param
def numpy_model_with_arbitrary() -> NpNDArrayModelWithNonArrayWithArbitrary:
return NpNDArrayModelWithNonArrayWithArbitrary(
array=np.array([0.0]), non_array=NON_ARRAY_VALUE, my_arbitrary_slice=slice(0, 1)
)


if os.name != "nt":
if platform.system() != "Windows":

def test_io_yaml(numpy_model: NpNDArrayModelWithNonArray) -> None:
with tempfile.TemporaryDirectory() as tmp_dirname:
Expand Down Expand Up @@ -80,3 +71,30 @@ class NumpyModelBForTest(NpNDArrayModelWithNonArray):
models = [NumpyModelAForTest, NumpyModelBForTest]
assert model_a == model_agnostic_load(tmp_dir_path, TEST_MODEL_OBJECT_ID, models=models)
assert model_b == model_agnostic_load(tmp_dir_path, OTHER_TEST_MODEL_OBJECT_ID, models=models)

def test_simple_eq(numpy_model: NpNDArrayModelWithNonArray) -> None:
assert numpy_model == numpy_model

def test_not_eq_different_fields(numpy_model, numpy_model_with_arbitrary) -> None:
assert numpy_model != numpy_model_with_arbitrary

class AnotherModel(NumpyModel):
yarra: NpNDArray

assert numpy_model != AnotherModel(yarra=np.array([0.0]))

def test_not_eq_different_inner(numpy_model: NpNDArrayModelWithNonArray) -> None:
assert numpy_model != NpNDArrayModelWithNonArray(array=np.array([1.0]), non_array=NON_ARRAY_VALUE)

def test_not_eq_different_shape(numpy_model: NpNDArrayModelWithNonArray) -> None:
assert numpy_model != NpNDArrayModelWithNonArray(array=np.array([0.0, 1.0]), non_array=NON_ARRAY_VALUE)

def test_random_not_eq(numpy_model: NpNDArrayModelWithNonArray) -> None:
for r in (0, 5, 1.0, "1"):
assert numpy_model != r

def test_serde_eq(numpy_model: NpNDArrayModelWithNonArray) -> None:
ser = numpy_model.model_dump_json()
reread_data = numpy_model.model_validate_json(ser)

assert numpy_model == reread_data
4 changes: 2 additions & 2 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import platform
import tempfile
from pathlib import Path
from typing import Optional
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_wrong_dimension():
get_numpy_type_model(Np1DArrayInt64)(array_field=np.array([[0]]))


if os.name != "nt":
if platform.system() == "Linux":
from tests.helper.testing_groups import (
get_strict_data_type_nd_array_typing_dimensions_128_bit,
)
Expand Down
Loading