Skip to content

Commit

Permalink
do not pickle cached hash of arrays (inducer#563)
Browse files Browse the repository at this point in the history
* do not pickle cached hash of arrays

* no clue why this works

* remove unneeded changes

* add an explanation

* add simple test

* add another test

* copy getstate/setstate implementation from dataclasses

* typo?
  • Loading branch information
matthiasdiener authored Dec 4, 2024
1 parent cd4922c commit d57cb0c
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 8 deletions.
46 changes: 38 additions & 8 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,35 +344,65 @@ def _augment_array_dataclass(
cls: type,
generate_hash: bool,
) -> None:
from dataclasses import fields
attr_tuple = ", ".join(f"self.{fld.name}"
for fld in fields(cls) if fld.name != "non_equality_tags")
if attr_tuple:
attr_tuple = f"({attr_tuple},)"
else:
attr_tuple = "()"

# {{{ hashing and hash caching

if generate_hash:
from dataclasses import fields

# Non-equality tags are automatically excluded from equality in
# EqualityComparer, and are excluded here from hashing.
attr_tuple_hash = ", ".join(f"self.{fld.name}"
for fld in fields(cls) if fld.name != "non_equality_tags")

if attr_tuple_hash:
attr_tuple_hash = f"({attr_tuple_hash},)"
else:
attr_tuple_hash = "()"

from pytools.codegen import remove_common_indentation
augment_code = remove_common_indentation(
f"""
from dataclasses import fields
def {cls.__name__}_hash(self):
try:
return self._hash_value
except AttributeError:
pass
h = hash(frozenset({attr_tuple}))
h = hash(frozenset({attr_tuple_hash}))
object.__setattr__(self, "_hash_value", h)
return h
cls.__hash__ = {cls.__name__}_hash
# By default (when slots=False), dataclasses do not have special
# handling for pickling, thus using pickle's default behavior that
# looks at obj.__dict__. This would also pickle the cached hash,
# which may change across invocations. Here, we override the
# pickling methods such that only fields are pickled.
# See also https://github.com/python/cpython/blob/5468d219df65d4fe3335e2bcc09d2f6032a32c70/Lib/dataclasses.py#L1267-L1272
def _dataclass_getstate(self):
return [getattr(self, f.name) for f in fields(self)]
def _dataclass_setstate(self, state):
for field, value in zip(fields(self), state, strict=True):
# use setattr because dataclass may be frozen
object.__setattr__(self, field.name, value)
cls.__getstate__ = _dataclass_getstate
cls.__setstate__ = _dataclass_setstate
""")
exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": augment_code}
exec(compile(augment_code,
f"<dataclass augmentation code for {cls}>", "exec"),
exec_dict)

# }}}

# {{{ assign mapper_method

mm_cls = cast(type[_HasMapperMethod], cls)
Expand Down
47 changes: 47 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,53 @@ def dtype(self):
assert _np_result_dtype(42.0, NotReallyAnArray()) == np.float64


def test_pickling_hash():
    # Regression test for https://github.com/inducer/pytato/pull/563:
    # a cached "_hash_value" attribute must never survive a pickle
    # round trip, since hash values may differ across interpreter runs.
    from pickle import dumps, loads

    def roundtrip(ary):
        return loads(dumps(ary))

    # {{{ Placeholder: hash gets cached, but the cache is not pickled

    ph = pt.make_placeholder("p", (4, 4), int)

    assert not hasattr(ph, "_hash_value")

    # Force hash creation:
    hash(ph)

    assert hasattr(ph, "_hash_value")

    ph_copy = roundtrip(ph)

    assert not hasattr(ph_copy, "_hash_value")

    assert ph == ph_copy

    # }}}

    # {{{ DataWrapper: no hash caching, payload survives pickling

    wrapped = pt.make_data_wrapper(np.zeros((4, 4), int))

    assert not hasattr(wrapped, "_hash_value")

    hash(wrapped)

    # DataWrappers have no hash caching
    assert not hasattr(wrapped, "_hash_value")

    wrapped_copy = roundtrip(wrapped)

    assert wrapped_copy.shape == wrapped.shape
    assert wrapped_copy.dtype == wrapped.dtype
    assert np.all(wrapped_copy.data == wrapped.data)

    # DataWrappers that are not the same object compare unequal
    assert wrapped != wrapped_copy

    # }}}


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit d57cb0c

Please sign in to comment.