Skip to content

Commit

Permalink
Revert sorting of symbolic tags (#476)
Browse files Browse the repository at this point in the history
* Revert sorting of symbolic tags

Sorting fails when symbolic tags contain bare classes (which can not be
compared most of the time).

The attached test case fails without this PR.

* undo doc
  • Loading branch information
matthiasdiener authored Nov 28, 2023
1 parent e75b479 commit 6ccb338
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pytato/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
.. class:: CommTagType
A type representing a communication tag. Communication tags must be
hashable and totally ordered (and hence comparable).
hashable.
.. class:: ShapeType
Expand Down
6 changes: 1 addition & 5 deletions pytato/distributed/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ def number_distributed_tags(
This is a potentially heavyweight MPI-collective operation on
*mpi_communicator*.
.. note::
This function requires that symbolic tags are comparable.
"""
tags = frozenset({
recv.comm_tag
Expand Down Expand Up @@ -106,7 +102,7 @@ def set_union(
next_tag = base_tag
assert isinstance(all_tags, frozenset)

for sym_tag in sorted(all_tags):
for sym_tag in all_tags:
sym_tag_to_int_tag[sym_tag] = next_tag
next_tag += 1

Expand Down
42 changes: 42 additions & 0 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,48 @@ def _do_verify_distributed_partition(ctx_factory):
# }}}


# {{{ test symbolic tag numbering with bare classes

class FooTag1:
pass


class FooTag2:
pass


def test_number_symbolic_tags_bare_classes(ctx_factory):
from mpi4py import MPI # pylint: disable=import-error
comm = MPI.COMM_WORLD
from pytato.distributed.nodes import (staple_distributed_send,
make_distributed_recv)

rank = 0
size = 2

x = pt.make_placeholder("x", (4, 4), int)
y = pt.make_placeholder("y", (4, 4), int)

r1 = staple_distributed_send(x, dest_rank=(rank-1) % size,
comm_tag=FooTag1, stapled_to=make_distributed_recv(
src_rank=(rank+1) % size, comm_tag=FooTag1, shape=(4, 4),
dtype=int))

r2 = staple_distributed_send(y, dest_rank=(rank-1) % size,
comm_tag=FooTag2, stapled_to=make_distributed_recv(
src_rank=(rank+1) % size, comm_tag=FooTag2, shape=(4, 4),
dtype=int))

res = r1 + r2

outputs = pt.make_dict_of_named_arrays({"out": res})
partition = pt.find_distributed_partition(comm, outputs)

pt.number_distributed_tags(comm, partition, base_tag=4242)

# }}}


if __name__ == "__main__":
if "RUN_WITHIN_MPI" in os.environ:
run_test_with_mpi_inner()
Expand Down

0 comments on commit 6ccb338

Please sign in to comment.