Skip to content

Commit

Permalink
add (strict) typing to Record
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Mar 14, 2023
1 parent 56efa1b commit 8a14dfb
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions pytools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,8 @@ class RecordWithoutPickling:
__slots__: ClassVar[List[str]] = []
fields: ClassVar[Set[str]]

def __init__(self, valuedict=None, exclude=None, **kwargs):
def __init__(self, valuedict: Optional[Dict[str, Any]] = None,
exclude: Optional[List[str]] = None, **kwargs: Any) -> None:
assert self.__class__ is not Record

if exclude is None:
Expand All @@ -427,7 +428,7 @@ def __init__(self, valuedict=None, exclude=None, **kwargs):
fields.add(key)
setattr(self, key, value)

def get_copy_kwargs(self, **kwargs):
def get_copy_kwargs(self, **kwargs: Any) -> Any:
for f in self.__class__.fields:
if f not in kwargs:
try:
Expand All @@ -436,25 +437,25 @@ def get_copy_kwargs(self, **kwargs):
pass
return kwargs

def copy(self, **kwargs):
def copy(self, **kwargs: Any) -> "RecordWithoutPickling":
return self.__class__(**self.get_copy_kwargs(**kwargs))

def __repr__(self):
def __repr__(self) -> str:
return "{}({})".format(
self.__class__.__name__,
", ".join(f"{fld}={getattr(self, fld)!r}"
for fld in self.__class__.fields
if hasattr(self, fld)))

def register_fields(self, new_fields):
def register_fields(self, new_fields: Set[str]) -> None:
try:
fields = self.__class__.fields
except AttributeError:
self.__class__.fields = fields = set()

fields.update(new_fields)

def __getattr__(self, name):
def __getattr__(self, name: str) -> None:
# This method is implemented to avoid pylint 'no-member' errors for
# attribute access.
raise AttributeError(
Expand All @@ -465,13 +466,13 @@ def __getattr__(self, name):
class Record(RecordWithoutPickling):
__slots__: ClassVar[List[str]] = []

def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
return {
key: getattr(self, key)
for key in self.__class__.fields
if hasattr(self, key)}

def __setstate__(self, valuedict):
def __setstate__(self, valuedict: Dict[str, Any]) -> None:
try:
fields = self.__class__.fields
except AttributeError:
Expand All @@ -481,30 +482,33 @@ def __setstate__(self, valuedict):
fields.add(key)
setattr(self, key, value)

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if self is other:
return True
if not isinstance(other, Record):
return False
return (self.__class__ == other.__class__
and self.__getstate__() == other.__getstate__())

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)


class ImmutableRecordWithoutPickling(RecordWithoutPickling):
"""Hashable record. Does not explicitly enforce immutability."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
RecordWithoutPickling.__init__(self, *args, **kwargs)
self._cached_hash = None
self._cached_hash: Optional[int] = None

def __hash__(self):
def __hash__(self) -> int:
# This attribute may vanish during pickling.
if getattr(self, "_cached_hash", None) is None:
self._cached_hash = hash(
(type(self),) + tuple(getattr(self, field)
for field in self.__class__.fields))

return self._cached_hash
from typing import cast
return cast(int, self._cached_hash)


class ImmutableRecord(ImmutableRecordWithoutPickling, Record):
Expand Down

0 comments on commit 8a14dfb

Please sign in to comment.