From d23d36243366557fb4479012686a8a8b451a7111 Mon Sep 17 00:00:00 2001 From: Emanuele Giaquinta Date: Sun, 10 Nov 2024 18:20:30 +0200 Subject: [PATCH] Fix serialization of dataclasses without __slots__ The optimization of serializing the __dict__ attribute of dataclass instances without __slots__ assumes that __dict__ containing K is a necessary and sufficient condition of the dataclass having a non-pseudo field named K. However, this does not hold when the dataclass contains - a field defined with init=False and a default value - a field with a descriptor object as default value - a cached property This commit fixes the serialization in the above cases by changing the code so that the __dict__ attribute is only used to get a field value, falling back to checking the field type and getting the value with PyObject_GetAttr if the value is not in __dict__. The price for correctness is a ~20% slowdown. Signed-off-by: Emanuele Giaquinta --- src/serialize/dataclass.rs | 114 +++++++++++++------------------------ tests/test_dataclass.py | 67 ++++++++++++++++++++++ 2 files changed, 105 insertions(+), 76 deletions(-) diff --git a/src/serialize/dataclass.rs b/src/serialize/dataclass.rs index 0ce38759..9bc2ad38 100644 --- a/src/serialize/dataclass.rs +++ b/src/serialize/dataclass.rs @@ -12,70 +12,6 @@ use serde::ser::{Serialize, SerializeMap, Serializer}; use smallvec::SmallVec; use std::ptr::NonNull; -pub struct Dataclass { - ptr: *mut pyo3::ffi::PyObject, - opts: Opt, - default_calls: u8, - recursion: u8, - default: Option>, -} - -impl Dataclass { - pub fn new( - ptr: *mut pyo3::ffi::PyObject, - opts: Opt, - default_calls: u8, - recursion: u8, - default: Option>, - ) -> Self { - Dataclass { - ptr: ptr, - opts: opts, - default_calls: default_calls, - recursion: recursion, - default: default, - } - } -} - -impl Serialize for Dataclass { - #[inline(never)] - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let ob_type = ob_type!(self.ptr); - if pydict_contains!(ob_type, SLOTS_STR) { - DataclassFields::new( - self.ptr, - self.opts, - self.default_calls, - self.recursion, - self.default, - ) - .serialize(serializer) - } else { - match AttributeDict::new( - self.ptr, - self.opts, - self.default_calls, - self.recursion, - self.default, - ) { - Ok(val) => val.serialize(serializer), - Err(AttributeDictError::DictMissing) => DataclassFields::new( - self.ptr, - self.opts, - self.default_calls, - self.recursion, - self.default, - ) - .serialize(serializer), - } - } - } -} - pub enum AttributeDictError { DictMissing, } @@ -159,7 +95,7 @@ impl Serialize for AttributeDict { } } -pub struct DataclassFields { +pub struct Dataclass { ptr: *mut pyo3::ffi::PyObject, opts: Opt, default_calls: u8, @@ -167,7 +103,7 @@ pub struct DataclassFields { default: Option>, } -impl DataclassFields { +impl Dataclass { pub fn new( ptr: *mut pyo3::ffi::PyObject, opts: Opt, @@ -175,7 +111,7 @@ impl DataclassFields { recursion: u8, default: Option>, ) -> Self { - DataclassFields { + Dataclass { ptr: ptr, opts: opts, default_calls: default_calls, @@ -185,7 +121,13 @@ impl DataclassFields { } } -impl Serialize for DataclassFields { +fn is_pseudo_field(field: *mut pyo3::ffi::PyObject) -> bool { + let field_type = ffi!(PyObject_GetAttr(field, FIELD_TYPE_STR)); + ffi!(Py_DECREF(field_type)); + !is_type!(field_type as *mut pyo3::ffi::PyTypeObject, FIELD_TYPE) +} + +impl Serialize for Dataclass { fn serialize(&self, serializer: S) -> Result where S: Serializer, @@ -196,14 +138,21 @@ impl Serialize for DataclassFields { if unlikely!(len == 0) { return serializer.serialize_map(Some(0)).unwrap().end(); } + + let dict = { + let ob_type = ob_type!(self.ptr); + if pydict_contains!(ob_type, SLOTS_STR) { + std::ptr::null_mut() + } else { + let dict = ffi!(PyObject_GetAttr(self.ptr, DICT_STR)); + ffi!(Py_DECREF(dict)); + dict + } + }; + let mut items: SmallVec<[(&str, *mut pyo3::ffi::PyObject); 8]> = SmallVec::with_capacity(len); for (attr, field) in PyDictIter::from_pyobject(fields) { - let field_type = ffi!(PyObject_GetAttr(field.as_ptr(), FIELD_TYPE_STR)); - ffi!(Py_DECREF(field_type)); - if !is_type!(field_type as *mut pyo3::ffi::PyTypeObject, FIELD_TYPE) { - continue; - } let data = unicode_to_str(attr.as_ptr()); if unlikely!(data.is_none()) { err!(INVALID_STR); @@ -213,9 +162,22 @@ impl Serialize for DataclassFields { continue; } - let value = ffi!(PyObject_GetAttr(self.ptr, attr.as_ptr())); - ffi!(Py_DECREF(value)); - items.push((key_as_str, value)); + if unlikely!(dict.is_null()) { + if !is_pseudo_field(field.as_ptr()) { + let value = ffi!(PyObject_GetAttr(self.ptr, attr.as_ptr())); + ffi!(Py_DECREF(value)); + items.push((key_as_str, value)); + } + } else { + let value = ffi!(PyDict_GetItem(dict, attr.as_ptr())); + if !value.is_null() { + items.push((key_as_str, value)); + } else if !is_pseudo_field(field.as_ptr()) { + let value = ffi!(PyObject_GetAttr(self.ptr, attr.as_ptr())); + ffi!(Py_DECREF(value)); + items.push((key_as_str, value)); + } + } } let mut map = serializer.serialize_map(Some(items.len())).unwrap(); diff --git a/tests/test_dataclass.py b/tests/test_dataclass.py index f7198693..aa336849 100644 --- a/tests/test_dataclass.py +++ b/tests/test_dataclass.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: (Apache-2.0 OR MIT) from dataclasses import InitVar, asdict, dataclass, field +from functools import cached_property from typing import ClassVar, Optional import msgpack @@ -107,6 +108,72 @@ class Dataclass: assert ormsgpack.packb(Dataclass()) == msgpack.packb({}) +def test_dataclass_with_non_init_field() -> None: + @dataclass + class Dataclass: + a: str + b: int = field(default=1, init=False) + + obj = Dataclass("a") + assert ormsgpack.packb(obj) == msgpack.packb( + { + "a": "a", + "b": 1, + } + ) + + +def test_dataclass_with_descriptor_field() -> None: + class Descriptor: + def __init__(self, *, default: int) -> None: + self._default = default + + def __set_name__(self, owner: object, name: str) -> None: + self._name = "_" + name + + def __get__(self, instance: object, owner: object) -> int: + if instance is None: + return self._default + + return getattr(instance, self._name, self._default) + + def __set__(self, instance: object, value: int) -> None: + setattr(instance, self._name, value) + + @dataclass + class Dataclass: + a: str + b: Descriptor = Descriptor(default=0) + + obj = Dataclass("a", 1) + assert ormsgpack.packb(obj) == msgpack.packb( + { + "a": "a", + "b": 1, + } + ) + + +def test_dataclass_with_cached_property() -> None: + @dataclass + class Dataclass: + a: str + b: int + + @cached_property + def name(self) -> str: + return "dataclass" + + obj = Dataclass("a", 1) + obj.name + assert ormsgpack.packb(obj) == msgpack.packb( + { + "a": "a", + "b": 1, + } + ) + + def test_dataclass_with_private_field() -> None: @dataclass class Dataclass: