Skip to content

Commit

Permalink
KeyBuilder: support np.bool{,_} for numpy<2
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored and inducer committed Nov 13, 2024
1 parent e49e8fb commit 2e9cf0d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pytools/persistent_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def rec(self, key_hash: Hash, key: Any) -> Hash:
method = self.update_for_specific_dtype

# Hashing numpy scalars
elif isinstance(key, np.number):
elif isinstance(key, np.number | np.bool_):
# Non-numpy scalars are handled above in the try block.
method = self.update_for_numpy_scalar

Expand Down
25 changes: 25 additions & 0 deletions pytools/test/test_persistent_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,31 @@ def test_dtype_hashing() -> None:
assert keyb(np.dtype(np.float32)) == keyb(np.dtype(np.float32))


def test_bool_hashing() -> None:
keyb = KeyBuilder()

assert keyb(True) == keyb(True)
assert keyb(False) == keyb(False)
assert keyb(True) != keyb(False)

np = pytest.importorskip("numpy")

bool_types = [np.bool_]
if hasattr(np, "bool"):
bool_types.append(np.bool)

for bool_type in bool_types:
assert keyb(bool_type) != keyb(bool)

assert keyb(bool_type(True)) == keyb(bool_type(True))
assert keyb(bool_type(False)) == keyb(bool_type(False))
assert keyb(bool_type(True)) != keyb(bool_type(False))

assert keyb(bool_type) != keyb(np.dtype(bool_type))
assert keyb(bool_type(True)) != keyb(np.dtype(bool_type(True)))
assert keyb(bool_type(False)) != keyb(np.dtype(bool_type(False)))


def test_scalar_hashing() -> None:
keyb = KeyBuilder()

Expand Down

0 comments on commit 2e9cf0d

Please sign in to comment.