From c3fe047fcc5557484bb6c014a1e3d4e28dd93741 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 6 Feb 2024 15:52:15 -0600 Subject: [PATCH] number_distributed_tags: non-set, non-sorted numbering (#469) * number_distributed_tags: non-set, non-sorted numbering * make the test a bit more difficult * improve test * add comments based on review --- pytato/distributed/tags.py | 49 ++++++++++++++++---------------------- test/test_distributed.py | 22 +++++++++++++---- 2 files changed, 38 insertions(+), 33 deletions(-) diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index f9eb3b7ec..4e33055a9 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -31,7 +31,7 @@ """ -from typing import TYPE_CHECKING, Tuple, FrozenSet, Optional, TypeVar +from typing import TYPE_CHECKING, Tuple, TypeVar from pytato.distributed.partition import DistributedGraphPartition @@ -63,48 +63,39 @@ def number_distributed_tags( This is a potentially heavyweight MPI-collective operation on *mpi_communicator*. """ - tags = frozenset({ + from pytools import flatten + + # A potential optimization here could be to use a 'set' to collect the tags, + # but this would introduce non-determinism in the tag numbering. Another + # option would be use something like pytools.unique() to reduce the amount + # of data communicated, but since all sends and receives should each + # have unique tags, this would at most buy us a factor of 2. + tags = tuple([ recv.comm_tag for part in partition.parts.values() for recv in part.name_to_recv_node.values() - } | { + ] + [ send.comm_tag for part in partition.parts.values() for sends in part.name_to_send_nodes.values() - for send in sends}) - - from mpi4py import MPI - - def set_union( - set_a: FrozenSet[T], set_b: FrozenSet[T], - mpi_data_type: Optional[MPI.Datatype]) -> FrozenSet[T]: - assert mpi_data_type is None - assert isinstance(set_a, frozenset) - assert isinstance(set_b, frozenset) - - return set_a | set_b + for send in sends]) root_rank = 0 - set_union_mpi_op = MPI.Op.Create( - # type ignore reason: mpi4py misdeclares op functions as returning - # None. - set_union, # type: ignore[arg-type] - commute=True) - try: - all_tags = mpi_communicator.reduce( - tags, set_union_mpi_op, root=root_rank) - finally: - set_union_mpi_op.Free() + # We can't let MPI do a set union here, since the result would be + # non-deterministic. + all_tags = mpi_communicator.gather(tags, root=root_rank) if mpi_communicator.rank == root_rank: sym_tag_to_int_tag = {} next_tag = base_tag - assert isinstance(all_tags, frozenset) + assert isinstance(all_tags, list) + assert len(all_tags) == mpi_communicator.size - for sym_tag in all_tags: - sym_tag_to_int_tag[sym_tag] = next_tag - next_tag += 1 + for sym_tag in flatten(all_tags): # type: ignore[no-untyped-call] + if sym_tag not in sym_tag_to_int_tag: + sym_tag_to_int_tag[sym_tag] = next_tag + next_tag += 1 mpi_communicator.bcast((sym_tag_to_int_tag, next_tag), root=root_rank) else: diff --git a/test/test_distributed.py b/test/test_distributed.py index 925d2e070..3da30de97 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -263,10 +263,10 @@ def _do_test_distributed_execution_random_dag(ctx_factory): gen_comm_called = False - ntests = 10 + ntests = 20 for i in range(ntests): seed = 120 + i - print(f"Step {i} {seed}") + print(f"Step {i} {seed=}") # {{{ compute value with communication @@ -278,7 +278,15 @@ def gen_comm(rdagc): nonlocal comm_tag comm_tag += 1 - tag = (comm_tag, _RandomDAGTag) # noqa: B023 + + if comm_tag % 5 == 1: + tag = (comm_tag, frozenset([_RandomDAGTag, "a", comm_tag])) + elif comm_tag % 5 == 2: + tag = (comm_tag, (_RandomDAGTag, "b")) + elif comm_tag % 5 == 3: + tag = (_RandomDAGTag, comm_tag) + else: + tag = (comm_tag, _RandomDAGTag) # noqa: B023 inner = make_random_dag(rdagc) return pt.staple_distributed_send( @@ -881,7 +889,13 @@ def test_number_symbolic_tags_bare_classes(ctx_factory): outputs = pt.make_dict_of_named_arrays({"out": res}) partition = pt.find_distributed_partition(comm, outputs) - pt.number_distributed_tags(comm, partition, base_tag=4242) + (distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) + + assert next_tag == 4244 + + # FIXME: For the next assertion, find_distributed_partition needs to be + # deterministic too (https://github.com/inducer/pytato/pull/465). + # assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # noqa: E501 # }}}