From 1386433dc302be9a5c972cd27b17ba4f3364ca74 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 26 Feb 2024 19:30:43 -0600 Subject: [PATCH 01/10] Record: make field ordering deterministic --- pytools/__init__.py | 19 +++++++++------ pytools/test/test_pytools.py | 47 ++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 0f452a80..9fcc0d1b 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -411,7 +411,9 @@ class RecordWithoutPickling: """ __slots__: ClassVar[List[str]] = [] - fields: ClassVar[Set[str]] + + # A dict, not a set, to maintain a deterministic iteration order + fields: ClassVar[Dict[str, None]] def __init__(self, valuedict=None, exclude=None, **kwargs): assert self.__class__ is not Record @@ -422,14 +424,17 @@ def __init__(self, valuedict=None, exclude=None, **kwargs): try: fields = self.__class__.fields except AttributeError: - self.__class__.fields = fields = set() + self.__class__.fields = fields = {} + + if isinstance(fields, set): + self.__class__.fields = fields = dict.fromkeys(sorted(fields)) if valuedict is not None: kwargs.update(valuedict) for key, value in kwargs.items(): if key not in exclude: - fields.add(key) + fields[key] = None setattr(self, key, value) def get_copy_kwargs(self, **kwargs): @@ -455,9 +460,9 @@ def register_fields(self, new_fields): try: fields = self.__class__.fields except AttributeError: - self.__class__.fields = fields = set() + self.__class__.fields = fields = {} - fields.update(new_fields) + fields.update(dict.fromkeys(new_fields)) def __getattr__(self, name): # This method is implemented to avoid pylint 'no-member' errors for @@ -480,10 +485,10 @@ def __setstate__(self, valuedict): try: fields = self.__class__.fields except AttributeError: - self.__class__.fields = fields = set() + self.__class__.fields = fields = {} for key, value in valuedict.items(): - fields.add(key) + fields[key] = None setattr(self, key, value) def __eq__(self, other): diff --git a/pytools/test/test_pytools.py b/pytools/test/test_pytools.py index 7e1fe3ac..11e736b1 100644 --- a/pytools/test/test_pytools.py +++ b/pytools/test/test_pytools.py @@ -784,6 +784,53 @@ def test_unique(): assert next(unique([]), None) is None +def test_record(): + from pytools import Record + + class SimpleRecord(Record): + pass + + r = SimpleRecord(a=1, b=2) + assert r.a == 1 + assert r.b == 2 + + r2 = r.copy() + assert r2.a == 1 + assert r == r2 + + r3 = r.copy(b=3) + assert r3.b == 3 + assert r != r3 + + assert str(r) == "SimpleRecord(a=1, b=2)" + + class SimpleRecord2(Record): + pass + + r = SimpleRecord2(b=2, a=1) + assert r.a == 1 + assert r.b == 2 + + assert str(r) == "SimpleRecord2(b=2, a=1)" + + class SetBasedRecord(Record): + fields = {"c", "b", "a"} + + def __init__(self, c, b, a): + super().__init__(c=c, b=b, a=a) + + r = SetBasedRecord(3, 2, 1) + + # Fields are converted to a dict during __init__ + assert isinstance(r.fields, dict) + assert r.a == 1 + assert r.b == 2 + assert r.c == 3 + + # Fields are sorted alphabetically + assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) From 2e75f46f78400fbdfa0488e1fe3ae6688d0ec657 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 27 Feb 2024 09:48:38 -0600 Subject: [PATCH 02/10] add set conversion to __setstate__ --- pytools/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytools/__init__.py b/pytools/__init__.py index 9fcc0d1b..963054bd 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -487,6 +487,9 @@ def __setstate__(self, valuedict): except AttributeError: self.__class__.fields = fields = {} + if isinstance(fields, set): + self.__class__.fields = fields = dict.fromkeys(sorted(fields)) + for key, value in valuedict.items(): fields[key] = None setattr(self, key, value) From 35c3685c1c458f7aacc0f56d818ecad9506e1ab7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 27 Feb 2024 09:50:17 -0600 Subject: [PATCH 03/10] remove duplicate/unnecessary definitions --- pytools/__init__.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 963054bd..4bdcdef5 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -473,8 +473,6 @@ def __getattr__(self, name): class Record(RecordWithoutPickling): - __slots__: ClassVar[List[str]] = [] - def __getstate__(self): return { key: getattr(self, key) @@ -500,9 +498,6 @@ def __eq__(self, other): return (self.__class__ == other.__class__ and self.__getstate__() == other.__getstate__()) - def __ne__(self, other): - return not self.__eq__(other) - class ImmutableRecordWithoutPickling(RecordWithoutPickling): """Hashable record. Does not explicitly enforce immutability.""" From f6c762a3483e802900a0eaf21a9e503e86e0307e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 27 Feb 2024 10:02:10 -0600 Subject: [PATCH 04/10] add warning --- pytools/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytools/__init__.py b/pytools/__init__.py index 4bdcdef5..3bb14485 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -416,6 +416,9 @@ class RecordWithoutPickling: fields: ClassVar[Dict[str, None]] def __init__(self, valuedict=None, exclude=None, **kwargs): + from warnings import warn + warn("pytools.Record is deprecated. Use dataclasses instead.") + assert self.__class__ is not Record if exclude is None: From 0b43c94939aae15f57a2a6cd63c497a999fd126c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 27 Feb 2024 10:10:52 -0600 Subject: [PATCH 05/10] fix register_fields --- pytools/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 3bb14485..d83a74d3 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -465,7 +465,7 @@ def register_fields(self, new_fields): except AttributeError: self.__class__.fields = fields = {} - fields.update(dict.fromkeys(new_fields)) + fields.update(dict.fromkeys(sorted(new_fields))) def __getattr__(self, name): # This method is implemented to avoid pylint 'no-member' errors for From 6e3bb4f3c14702975d2257c8d2bb225a0d8fc153 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 27 Feb 2024 10:24:18 -0600 Subject: [PATCH 06/10] improve tests --- pytools/test/test_pytools.py | 61 ++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/pytools/test/test_pytools.py b/pytools/test/test_pytools.py index 11e736b1..28517bd7 100644 --- a/pytools/test/test_pytools.py +++ b/pytools/test/test_pytools.py @@ -790,28 +790,47 @@ def test_record(): class SimpleRecord(Record): pass - r = SimpleRecord(a=1, b=2) - assert r.a == 1 - assert r.b == 2 + r1 = SimpleRecord(a=1, b=2) + assert r1.a == 1 + assert r1.b == 2 - r2 = r.copy() + r2 = r1.copy() assert r2.a == 1 - assert r == r2 + assert r1 == r2 - r3 = r.copy(b=3) + r3 = r1.copy(b=3) assert r3.b == 3 - assert r != r3 + assert r1 != r3 + + assert str(r1) == str(r2) == "SimpleRecord(a=1, b=2)" + assert str(r3) == "SimpleRecord(a=1, b=3)" - assert str(r) == "SimpleRecord(a=1, b=2)" + # Unregistered fields are (silently) ignored for printing + r1.f = 6 + assert str(r1) == "SimpleRecord(a=1, b=2)" + + with pytest.raises(AttributeError): + r1.ff class SimpleRecord2(Record): pass - r = SimpleRecord2(b=2, a=1) - assert r.a == 1 - assert r.b == 2 + r_new = SimpleRecord2(b=2, a=1) + assert r_new.a == 1 + assert r_new.b == 2 + + assert str(r_new) == "SimpleRecord2(b=2, a=1)" + + assert r_new != r1 + + r1.register_fields({"d", "e"}) + assert str(r1) == "SimpleRecord(a=1, b=2)" - assert str(r) == "SimpleRecord2(b=2, a=1)" + r1.d = 4 + r1.e = 5 + assert str(r1) == "SimpleRecord(a=1, b=2, d=4, e=5)" + + # {{{ Legacy set-based record (used in Loopy) class SetBasedRecord(Record): fields = {"c", "b", "a"} @@ -827,9 +846,25 @@ def __init__(self, c, b, a): assert r.b == 2 assert r.c == 3 - # Fields are sorted alphabetically + # Fields are sorted alphabetically in set-based records + assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" + + r.register_fields({"d", "e"}) assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" + r.d = 4 + r.e = 5 + assert str(r) == "SetBasedRecord(a=1, b=2, c=3, d=4, e=5)" + + # Unregistered fields are (silently) ignored for printing + r.f = 6 + assert str(r) == "SetBasedRecord(a=1, b=2, c=3, d=4, e=5)" + + with pytest.raises(AttributeError): + r.ff + + # }}} + if __name__ == "__main__": if len(sys.argv) > 1: From 0608e1e994ac651fd905033408c6584b431e453c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 27 Feb 2024 13:25:43 -0600 Subject: [PATCH 07/10] add pickle test --- pytools/test/test_pytools.py | 59 ++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/pytools/test/test_pytools.py b/pytools/test/test_pytools.py index 28517bd7..f4262a47 100644 --- a/pytools/test/test_pytools.py +++ b/pytools/test/test_pytools.py @@ -26,6 +26,8 @@ import pytest +from pytools import Record + logger = logging.getLogger(__name__) from typing import FrozenSet @@ -784,11 +786,20 @@ def test_unique(): assert next(unique([]), None) is None -def test_record(): - from pytools import Record +# These classes must be defined global to be picklable +class SimpleRecord(Record): + pass - class SimpleRecord(Record): - pass + +class SetBasedRecord(Record): + fields = {"c", "b", "a"} # type: ignore[assignment] + + def __init__(self, c, b, a): + super().__init__(c=c, b=b, a=a) + + +def test_record(): + # {{{ New, dict-based Record r1 = SimpleRecord(a=1, b=2) assert r1.a == 1 @@ -809,9 +820,23 @@ class SimpleRecord(Record): r1.f = 6 assert str(r1) == "SimpleRecord(a=1, b=2)" + # Registered fields are printed + r1.register_fields({"d", "e"}) + assert str(r1) == "SimpleRecord(a=1, b=2)" + + r1.d = 4 + r1.e = 5 + assert str(r1) == "SimpleRecord(a=1, b=2, d=4, e=5)" + with pytest.raises(AttributeError): r1.ff + # Test pickling + + import pickle + r1_pickled = pickle.loads(pickle.dumps(r1)) + assert r1 == r1_pickled + class SimpleRecord2(Record): pass @@ -823,21 +848,10 @@ class SimpleRecord2(Record): assert r_new != r1 - r1.register_fields({"d", "e"}) - assert str(r1) == "SimpleRecord(a=1, b=2)" - - r1.d = 4 - r1.e = 5 - assert str(r1) == "SimpleRecord(a=1, b=2, d=4, e=5)" + # }}} # {{{ Legacy set-based record (used in Loopy) - class SetBasedRecord(Record): - fields = {"c", "b", "a"} - - def __init__(self, c, b, a): - super().__init__(c=c, b=b, a=a) - r = SetBasedRecord(3, 2, 1) # Fields are converted to a dict during __init__ @@ -849,6 +863,11 @@ def __init__(self, c, b, a): # Fields are sorted alphabetically in set-based records assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" + # Unregistered fields are (silently) ignored for printing + r.f = 6 + assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" + + # Registered fields are printed r.register_fields({"d", "e"}) assert str(r) == "SetBasedRecord(a=1, b=2, c=3)" @@ -856,13 +875,13 @@ def __init__(self, c, b, a): r.e = 5 assert str(r) == "SetBasedRecord(a=1, b=2, c=3, d=4, e=5)" - # Unregistered fields are (silently) ignored for printing - r.f = 6 - assert str(r) == "SetBasedRecord(a=1, b=2, c=3, d=4, e=5)" - with pytest.raises(AttributeError): r.ff + # Test pickling + r_pickled = pickle.loads(pickle.dumps(r)) + assert r == r_pickled + # }}} From 73327b42fb0639983db0cf9f9d38ab5f80e4b42d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 15 Mar 2024 15:37:23 -0500 Subject: [PATCH 08/10] restore __slots__ in Record, add to test --- pytools/__init__.py | 2 ++ pytools/test/test_pytools.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/pytools/__init__.py b/pytools/__init__.py index d83a74d3..1130cb36 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -476,6 +476,8 @@ def __getattr__(self, name): class Record(RecordWithoutPickling): + __slots__: ClassVar[List[str]] = [] + def __getstate__(self): return { key: getattr(self, key) diff --git a/pytools/test/test_pytools.py b/pytools/test/test_pytools.py index f4262a47..de6917cf 100644 --- a/pytools/test/test_pytools.py +++ b/pytools/test/test_pytools.py @@ -884,6 +884,40 @@ class SimpleRecord2(Record): # }}} + # {{{ __slots__, __dict__, __weakref__ handling + + class RecordWithEmptySlots(Record): + __slots__ = [] + + assert hasattr(RecordWithEmptySlots(), "__slots__") + assert not hasattr(RecordWithEmptySlots(), "__dict__") + assert not hasattr(RecordWithEmptySlots(), "__weakref__") + + class RecordWithUnsetSlots(Record): + pass + + assert hasattr(RecordWithUnsetSlots(), "__slots__") + assert hasattr(RecordWithUnsetSlots(), "__dict__") + assert hasattr(RecordWithUnsetSlots(), "__weakref__") + + from pytools import ImmutableRecord + + class ImmutableRecordWithEmptySlots(ImmutableRecord): + __slots__ = [] + + assert hasattr(ImmutableRecordWithEmptySlots(), "__slots__") + assert hasattr(ImmutableRecordWithEmptySlots(), "__dict__") + assert hasattr(ImmutableRecordWithEmptySlots(), "__weakref__") + + class ImmutableRecordWithUnsetSlots(ImmutableRecord): + pass + + assert hasattr(ImmutableRecordWithUnsetSlots(), "__slots__") + assert hasattr(ImmutableRecordWithUnsetSlots(), "__dict__") + assert hasattr(ImmutableRecordWithUnsetSlots(), "__weakref__") + + # }}} + if __name__ == "__main__": if len(sys.argv) > 1: From d091bac40013206702f7a63df01ff137787a9e13 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 15 Mar 2024 15:50:30 -0500 Subject: [PATCH 09/10] better warning --- pytools/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 1130cb36..47d282f2 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -417,7 +417,8 @@ class RecordWithoutPickling: def __init__(self, valuedict=None, exclude=None, **kwargs): from warnings import warn - warn("pytools.Record is deprecated. Use dataclasses instead.") + warn(f"{self.__class__.__bases__[0]} is deprecated and will be " + "removed in 2025. Use dataclasses instead.") assert self.__class__ is not Record From 0a508c40b4601258c10b5600ede4de2ad9b2cb1c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 15 Mar 2024 15:55:01 -0500 Subject: [PATCH 10/10] add type annotations --- pytools/__init__.py | 29 +++++++++++++++-------------- pytools/test/test_pytools.py | 2 +- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 47d282f2..33d5b206 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -37,7 +37,7 @@ from sys import intern from typing import ( Any, Callable, ClassVar, Dict, Generic, Hashable, Iterable, Iterator, List, - Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar, Union) + Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar, Union, cast) try: @@ -415,7 +415,8 @@ class RecordWithoutPickling: # A dict, not a set, to maintain a deterministic iteration order fields: ClassVar[Dict[str, None]] - def __init__(self, valuedict=None, exclude=None, **kwargs): + def __init__(self, valuedict: Optional[Mapping[str, Any]] = None, + exclude: Optional[Iterable[str]] = None, **kwargs: Any) -> None: from warnings import warn warn(f"{self.__class__.__bases__[0]} is deprecated and will be " "removed in 2025. Use dataclasses instead.") @@ -441,7 +442,7 @@ def __init__(self, valuedict=None, exclude=None, **kwargs): fields[key] = None setattr(self, key, value) - def get_copy_kwargs(self, **kwargs): + def get_copy_kwargs(self, **kwargs: Any) -> Dict[str, Any]: for f in self.__class__.fields: if f not in kwargs: try: @@ -450,17 +451,17 @@ def get_copy_kwargs(self, **kwargs): pass return kwargs - def copy(self, **kwargs): + def copy(self, **kwargs: Any) -> "RecordWithoutPickling": return self.__class__(**self.get_copy_kwargs(**kwargs)) - def __repr__(self): + def __repr__(self) -> str: return "{}({})".format( self.__class__.__name__, ", ".join(f"{fld}={getattr(self, fld)!r}" for fld in self.__class__.fields if hasattr(self, fld))) - def register_fields(self, new_fields): + def register_fields(self, new_fields: Iterable[str]) -> None: try: fields = self.__class__.fields except AttributeError: @@ -468,7 +469,7 @@ def register_fields(self, new_fields): fields.update(dict.fromkeys(sorted(new_fields))) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # This method is implemented to avoid pylint 'no-member' errors for # attribute access. raise AttributeError( @@ -479,13 +480,13 @@ def __getattr__(self, name): class Record(RecordWithoutPickling): __slots__: ClassVar[List[str]] = [] - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { key: getattr(self, key) for key in self.__class__.fields if hasattr(self, key)} - def __setstate__(self, valuedict): + def __setstate__(self, valuedict: Mapping[str, Any]) -> None: try: fields = self.__class__.fields except AttributeError: @@ -498,7 +499,7 @@ def __setstate__(self, valuedict): fields[key] = None setattr(self, key, value) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if self is other: return True return (self.__class__ == other.__class__ @@ -507,18 +508,18 @@ def __eq__(self, other): class ImmutableRecordWithoutPickling(RecordWithoutPickling): """Hashable record. Does not explicitly enforce immutability.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: RecordWithoutPickling.__init__(self, *args, **kwargs) - self._cached_hash = None + self._cached_hash: Optional[int] = None - def __hash__(self): + def __hash__(self) -> int: # This attribute may vanish during pickling. if getattr(self, "_cached_hash", None) is None: self._cached_hash = hash( (type(self),) + tuple(getattr(self, field) for field in self.__class__.fields)) - return self._cached_hash + return cast(int, self._cached_hash) class ImmutableRecord(ImmutableRecordWithoutPickling, Record): diff --git a/pytools/test/test_pytools.py b/pytools/test/test_pytools.py index de6917cf..294c3eb0 100644 --- a/pytools/test/test_pytools.py +++ b/pytools/test/test_pytools.py @@ -786,7 +786,7 @@ def test_unique(): assert next(unique([]), None) is None -# These classes must be defined global to be picklable +# These classes must be defined globally to be picklable class SimpleRecord(Record): pass