diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index a9803bf43..d2f8aeb0b 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -327,9 +327,6 @@ class CommTag: def __hash__(self): return hash(tuple(str(type(self)).encode("ascii"))) - def __eq__(self, other): - return isinstance(other, type(self)) - def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *, comm_tag: Optional[Hashable] = None, tag: Hashable = None, diff --git a/test/test_trace_pair.py b/test/test_trace_pair.py index ce84eba48..9e26032c4 100644 --- a/test/test_trace_pair.py +++ b/test/test_trace_pair.py @@ -25,6 +25,7 @@ from grudge.trace_pair import TracePair, CommTag import meshmode.mesh.generation as mgen from meshmode.dof_array import DOFArray +from dataclasses import dataclass from grudge import DiscretizationCollection @@ -74,15 +75,62 @@ def test_commtag(actx_factory): class DerivedCommTag(CommTag): pass - x = CommTag() - x2 = CommTag() - y = DerivedCommTag() + class DerivedDerivedCommTag(DerivedCommTag): + pass + + # {{{ test equality and hash consistency + + ct = CommTag() + ct2 = CommTag() + dct = DerivedCommTag() + dct2 = DerivedCommTag() + ddct = DerivedDerivedCommTag() + + assert ct == ct2 + assert ct != dct + assert dct == dct2 + assert dct != ddct + assert ddct != dct + assert (ct, dct) != (dct, ct) + + assert hash(ct) == hash(ct2) + assert hash(ct) != hash(dct) + assert hash(dct) != hash(ddct) + + # }}} + + # {{{ test hash stability + + assert hash(ct) == 4644528671524962420 + assert hash(dct) == -1013583671995716582 + assert hash(ddct) == 626392264874077479 + + assert hash((ct, 123)) == -578844573019921397 + assert hash((dct, 123)) == -8009406276367324841 + assert hash((dct, ct)) == 6599529611285265043 + + # }}} + + # {{{ test using derived dataclasses + + @dataclass(frozen=True) + class DataCommTag(CommTag): + data: int + + @dataclass(frozen=True) + class DataCommTag2(CommTag): + data: int + + d1 = DataCommTag(1) + d2 = DataCommTag(2) + d3 = DataCommTag(1) + + assert d1 != d2 + assert hash(d1) != hash(d2) + assert d1 == d3 + assert hash(d1) == hash(d3) - assert hash(x) == hash(x2) - assert hash(x) != hash(y) - assert hash(x) == 4644528671524962420 - assert hash(y) == -1013583671995716582 + d4 = DataCommTag2(1) + assert d1 != d4 - assert hash((x, 123)) == -578844573019921397 - assert hash((y, 123)) == -8009406276367324841 - assert hash((y, x)) == 6599529611285265043 + # }}}