Skip to content

Commit

Permalink
Fix serialization of dataclasses without __slots__
Browse files Browse the repository at this point in the history
The optimization of serializing the __dict__ attribute of dataclass
instances without __slots__ assumes that __dict__ containing K is a
necessary and sufficient condition of the dataclass having a
non-pseudo field named K.

However, this does not hold when the dataclass contains

- a field defined with init=False and a default value
- a field with a descriptor object as default value
- a cached property

This commit fixes the serialization in the above cases by changing the
code so that the __dict__ attribute is only used to get a field value,
falling back to checking the field type and getting the value with
PyObject_GetAttr if the value is not in __dict__. The price for
correctness is a ~20% slowdown.

Signed-off-by: Emanuele Giaquinta <[email protected]>
  • Loading branch information
exg committed Nov 22, 2024
1 parent 8f8bcde commit d23d362
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 76 deletions.
114 changes: 38 additions & 76 deletions src/serialize/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,70 +12,6 @@ use serde::ser::{Serialize, SerializeMap, Serializer};
use smallvec::SmallVec;
use std::ptr::NonNull;

pub struct Dataclass {
ptr: *mut pyo3::ffi::PyObject,
opts: Opt,
default_calls: u8,
recursion: u8,
default: Option<NonNull<pyo3::ffi::PyObject>>,
}

impl Dataclass {
pub fn new(
ptr: *mut pyo3::ffi::PyObject,
opts: Opt,
default_calls: u8,
recursion: u8,
default: Option<NonNull<pyo3::ffi::PyObject>>,
) -> Self {
Dataclass {
ptr: ptr,
opts: opts,
default_calls: default_calls,
recursion: recursion,
default: default,
}
}
}

impl Serialize for Dataclass {
#[inline(never)]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let ob_type = ob_type!(self.ptr);
if pydict_contains!(ob_type, SLOTS_STR) {
DataclassFields::new(
self.ptr,
self.opts,
self.default_calls,
self.recursion,
self.default,
)
.serialize(serializer)
} else {
match AttributeDict::new(
self.ptr,
self.opts,
self.default_calls,
self.recursion,
self.default,
) {
Ok(val) => val.serialize(serializer),
Err(AttributeDictError::DictMissing) => DataclassFields::new(
self.ptr,
self.opts,
self.default_calls,
self.recursion,
self.default,
)
.serialize(serializer),
}
}
}
}

pub enum AttributeDictError {
DictMissing,
}
Expand Down Expand Up @@ -159,23 +95,23 @@ impl Serialize for AttributeDict {
}
}

pub struct DataclassFields {
pub struct Dataclass {
ptr: *mut pyo3::ffi::PyObject,
opts: Opt,
default_calls: u8,
recursion: u8,
default: Option<NonNull<pyo3::ffi::PyObject>>,
}

impl DataclassFields {
impl Dataclass {
pub fn new(
ptr: *mut pyo3::ffi::PyObject,
opts: Opt,
default_calls: u8,
recursion: u8,
default: Option<NonNull<pyo3::ffi::PyObject>>,
) -> Self {
DataclassFields {
Dataclass {
ptr: ptr,
opts: opts,
default_calls: default_calls,
Expand All @@ -185,7 +121,13 @@ impl DataclassFields {
}
}

impl Serialize for DataclassFields {
fn is_pseudo_field(field: *mut pyo3::ffi::PyObject) -> bool {
let field_type = ffi!(PyObject_GetAttr(field, FIELD_TYPE_STR));
ffi!(Py_DECREF(field_type));
!is_type!(field_type as *mut pyo3::ffi::PyTypeObject, FIELD_TYPE)
}

impl Serialize for Dataclass {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
Expand All @@ -196,14 +138,21 @@ impl Serialize for DataclassFields {
if unlikely!(len == 0) {
return serializer.serialize_map(Some(0)).unwrap().end();
}

let dict = {
let ob_type = ob_type!(self.ptr);
if pydict_contains!(ob_type, SLOTS_STR) {
std::ptr::null_mut()
} else {
let dict = ffi!(PyObject_GetAttr(self.ptr, DICT_STR));
ffi!(Py_DECREF(dict));
dict
}
};

let mut items: SmallVec<[(&str, *mut pyo3::ffi::PyObject); 8]> =
SmallVec::with_capacity(len);
for (attr, field) in PyDictIter::from_pyobject(fields) {
let field_type = ffi!(PyObject_GetAttr(field.as_ptr(), FIELD_TYPE_STR));
ffi!(Py_DECREF(field_type));
if !is_type!(field_type as *mut pyo3::ffi::PyTypeObject, FIELD_TYPE) {
continue;
}
let data = unicode_to_str(attr.as_ptr());
if unlikely!(data.is_none()) {
err!(INVALID_STR);
Expand All @@ -213,9 +162,22 @@ impl Serialize for DataclassFields {
continue;
}

let value = ffi!(PyObject_GetAttr(self.ptr, attr.as_ptr()));
ffi!(Py_DECREF(value));
items.push((key_as_str, value));
if unlikely!(dict.is_null()) {
if !is_pseudo_field(field.as_ptr()) {
let value = ffi!(PyObject_GetAttr(self.ptr, attr.as_ptr()));
ffi!(Py_DECREF(value));
items.push((key_as_str, value));
}
} else {
let value = ffi!(PyDict_GetItem(dict, attr.as_ptr()));
if !value.is_null() {
items.push((key_as_str, value));
} else if !is_pseudo_field(field.as_ptr()) {
let value = ffi!(PyObject_GetAttr(self.ptr, attr.as_ptr()));
ffi!(Py_DECREF(value));
items.push((key_as_str, value));
}
}
}

let mut map = serializer.serialize_map(Some(items.len())).unwrap();
Expand Down
67 changes: 67 additions & 0 deletions tests/test_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
from dataclasses import InitVar, asdict, dataclass, field
from functools import cached_property
from typing import ClassVar, Optional

import msgpack
Expand Down Expand Up @@ -107,6 +108,72 @@ class Dataclass:
assert ormsgpack.packb(Dataclass()) == msgpack.packb({})


def test_dataclass_with_non_init_field() -> None:
@dataclass
class Dataclass:
a: str
b: int = field(default=1, init=False)

obj = Dataclass("a")
assert ormsgpack.packb(obj) == msgpack.packb(
{
"a": "a",
"b": 1,
}
)


def test_dataclass_with_descriptor_field() -> None:
class Descriptor:
def __init__(self, *, default: int) -> None:
self._default = default

def __set_name__(self, owner: object, name: str) -> None:
self._name = "_" + name

def __get__(self, instance: object, owner: object) -> int:
if instance is None:
return self._default

return getattr(instance, self._name, self._default)

def __set__(self, instance: object, value: int) -> None:
setattr(instance, self._name, value)

@dataclass
class Dataclass:
a: str
b: Descriptor = Descriptor(default=0)

obj = Dataclass("a", 1)
assert ormsgpack.packb(obj) == msgpack.packb(
{
"a": "a",
"b": 1,
}
)


def test_dataclass_with_cached_property() -> None:
@dataclass
class Dataclass:
a: str
b: int

@cached_property
def name(self) -> str:
return "dataclass"

obj = Dataclass("a", 1)
obj.name
assert ormsgpack.packb(obj) == msgpack.packb(
{
"a": "a",
"b": 1,
}
)


def test_dataclass_with_private_field() -> None:
@dataclass
class Dataclass:
Expand Down

0 comments on commit d23d362

Please sign in to comment.