diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index 32c16b18a..fa0b37580 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -75,9 +75,7 @@ impl SerField { fn exclude_default(value: &Bound<'_, PyAny>, extra: &Extra, serializer: &CombinedSerializer) -> PyResult { if extra.exclude_defaults { if let Some(default) = serializer.get_default(value.py())? { - if value.eq(default)? { - return Ok(true); - } + return Ok(value.eq(default).unwrap_or(false)); } } Ok(false) diff --git a/tests/serializers/test_typed_dict.py b/tests/serializers/test_typed_dict.py index df507a248..608f84a55 100644 --- a/tests/serializers/test_typed_dict.py +++ b/tests/serializers/test_typed_dict.py @@ -170,6 +170,49 @@ def test_exclude_default(): assert v.to_json({'foo': 1, 'bar': b'[default]'}, exclude_defaults=True) == b'{"foo":1}' +def test_exclude_incomparable_default(): + """Values that can't be compared with eq are treated as not equal to the default""" + + def ser_x(*args): + return [1, 2, 3] + + cls_schema = core_schema.any_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_x)) + + class Incomparable: + __pydantic_serializer__ = SchemaSerializer(cls_schema) + + def __get_pydantic_core_schema__(*args): + return cls_schema + + def __eq__(self, other): + raise NotImplementedError("Can't be compared!") + + class NeqComparable(Incomparable): + def __eq__(self, other): + return False + + class EqComparable(Incomparable): + def __eq__(self, other): + return True + + v = SchemaSerializer( + core_schema.typed_dict_schema( + { + 'foo': core_schema.typed_dict_field( + core_schema.with_default_schema(core_schema.any_schema(), default=None) + ), + } + ) + ) + + assert v.to_python({'foo': Incomparable()}, exclude_defaults=True)['foo'] == [1, 2, 3] + assert v.to_json({'foo': Incomparable()}, exclude_defaults=True) == b'{"foo":[1,2,3]}' + assert v.to_python({'foo': NeqComparable()}, exclude_defaults=True)['foo'] == [1, 2, 3] + assert v.to_json({'foo': NeqComparable()}, exclude_defaults=True) == b'{"foo":[1,2,3]}' + assert v.to_python({'foo': EqComparable()}, exclude_defaults=True) == {} + assert v.to_json({'foo': EqComparable()}, exclude_defaults=True) == b'{}' + + def test_function_plain_field_serializer_to_python(): class Model(TypedDict): x: int