From d0adcdae1250e8a219b1c84ed7e2045be80c2ee7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 9 Nov 2023 16:53:19 -0600 Subject: [PATCH 1/4] number_distributed_tags: non-set, non-sorted numbering --- pytato/distributed/__init__.py | 2 +- pytato/distributed/tags.py | 46 ++++++++++------------------------ test/test_distributed.py | 10 ++++++-- 3 files changed, 22 insertions(+), 36 deletions(-) diff --git a/pytato/distributed/__init__.py b/pytato/distributed/__init__.py index 4354b2f0f..ee0ff39de 100644 --- a/pytato/distributed/__init__.py +++ b/pytato/distributed/__init__.py @@ -23,7 +23,7 @@ .. class:: CommTagType A type representing a communication tag. Communication tags must be - hashable and totally ordered (and hence comparable). + hashable. .. class:: ShapeType diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index 41ae3273c..4b97b2d50 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -31,7 +31,7 @@ """ -from typing import TYPE_CHECKING, Tuple, FrozenSet, Optional, TypeVar +from typing import TYPE_CHECKING, Tuple, TypeVar from pytato.distributed.partition import DistributedGraphPartition @@ -62,53 +62,33 @@ def number_distributed_tags( This is a potentially heavyweight MPI-collective operation on *mpi_communicator*. - - .. note:: - - This function requires that symbolic tags are comparable. """ - tags = frozenset({ + from pytools import flatten + + tags = tuple([ recv.comm_tag for part in partition.parts.values() for recv in part.name_to_recv_node.values() - } | { + ] + [ send.comm_tag for part in partition.parts.values() for sends in part.name_to_send_nodes.values() - for send in sends}) - - from mpi4py import MPI - - def set_union( - set_a: FrozenSet[T], set_b: FrozenSet[T], - mpi_data_type: Optional[MPI.Datatype]) -> FrozenSet[T]: - assert mpi_data_type is None - assert isinstance(set_a, frozenset) - assert isinstance(set_b, frozenset) - - return set_a | set_b + for send in sends]) root_rank = 0 - set_union_mpi_op = MPI.Op.Create( - # type ignore reason: mpi4py misdeclares op functions as returning - # None. - set_union, # type: ignore[arg-type] - commute=True) - try: - all_tags = mpi_communicator.reduce( - tags, set_union_mpi_op, root=root_rank) - finally: - set_union_mpi_op.Free() + all_tags = mpi_communicator.gather(tags, root=root_rank) if mpi_communicator.rank == root_rank: sym_tag_to_int_tag = {} next_tag = base_tag - assert isinstance(all_tags, frozenset) + assert isinstance(all_tags, list) + assert len(all_tags) == mpi_communicator.size - for sym_tag in sorted(all_tags): - sym_tag_to_int_tag[sym_tag] = next_tag - next_tag += 1 + for sym_tag in flatten(all_tags): # type: ignore[no-untyped-call] + if sym_tag not in sym_tag_to_int_tag: + sym_tag_to_int_tag[sym_tag] = next_tag + next_tag += 1 mpi_communicator.bcast((sym_tag_to_int_tag, next_tag), root=root_rank) else: diff --git a/test/test_distributed.py b/test/test_distributed.py index f7a8e5b4c..c36f4caae 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -266,7 +266,7 @@ def _do_test_distributed_execution_random_dag(ctx_factory): ntests = 10 for i in range(ntests): seed = 120 + i - print(f"Step {i} {seed}") + print(f"Step {i} {seed=}") # {{{ compute value with communication @@ -278,7 +278,13 @@ def gen_comm(rdagc): nonlocal comm_tag comm_tag += 1 - tag = (comm_tag, _RandomDAGTag) # noqa: B023 + + if comm_tag % 5 == 1: + tag = (comm_tag, frozenset([_RandomDAGTag, _RandomDAGTag])) + elif comm_tag % 5 == 2: + tag = (comm_tag, (_RandomDAGTag,)) + else: + tag = (comm_tag, _RandomDAGTag) # noqa: B023 inner = make_random_dag(rdagc) return pt.staple_distributed_send( From 9527630722a5882617972c8b1a8586bcd5416b4f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 9 Nov 2023 17:33:10 -0600 Subject: [PATCH 2/4] make the test a bit more difficult --- test/test_distributed.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_distributed.py b/test/test_distributed.py index c36f4caae..6778a213a 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -263,7 +263,7 @@ def _do_test_distributed_execution_random_dag(ctx_factory): gen_comm_called = False - ntests = 10 + ntests = 20 for i in range(ntests): seed = 120 + i print(f"Step {i} {seed=}") @@ -280,9 +280,11 @@ def gen_comm(rdagc): comm_tag += 1 if comm_tag % 5 == 1: - tag = (comm_tag, frozenset([_RandomDAGTag, _RandomDAGTag])) + tag = (comm_tag, frozenset([_RandomDAGTag, "a", comm_tag])) elif comm_tag % 5 == 2: - tag = (comm_tag, (_RandomDAGTag,)) + tag = (comm_tag, (_RandomDAGTag, "b")) + elif comm_tag % 5 == 3: + tag = (_RandomDAGTag, comm_tag) else: tag = (comm_tag, _RandomDAGTag) # noqa: B023 From 998edbdde67da15802e178d492555be0c405e455 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 14:31:56 -0600 Subject: [PATCH 3/4] improve test --- test/test_distributed.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_distributed.py b/test/test_distributed.py index 33ef5a225..63e65eccd 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -889,7 +889,13 @@ def test_number_symbolic_tags_bare_classes(ctx_factory): outputs = pt.make_dict_of_named_arrays({"out": res}) partition = pt.find_distributed_partition(comm, outputs) - pt.number_distributed_tags(comm, partition, base_tag=4242) + (distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) + + assert next_tag == 4244 + + # For the next assertion, find_distributed_partition needs to be + # deterministic too (https://github.com/inducer/pytato/pull/465). + # assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # noqa: E501 # }}} From fab72fb5282a53e300e2ae375b22857bc7cc8e1b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 6 Feb 2024 10:56:28 -0600 Subject: [PATCH 4/4] add comments based on review --- pytato/distributed/tags.py | 7 +++++++ test/test_distributed.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index 4b97b2d50..4e33055a9 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -65,6 +65,11 @@ def number_distributed_tags( """ from pytools import flatten + # A potential optimization here could be to use a 'set' to collect the tags, + # but this would introduce non-determinism in the tag numbering. Another + # option would be use something like pytools.unique() to reduce the amount + # of data communicated, but since all sends and receives should each + # have unique tags, this would at most buy us a factor of 2. tags = tuple([ recv.comm_tag for part in partition.parts.values() @@ -77,6 +82,8 @@ def number_distributed_tags( root_rank = 0 + # We can't let MPI do a set union here, since the result would be + # non-deterministic. all_tags = mpi_communicator.gather(tags, root=root_rank) if mpi_communicator.rank == root_rank: diff --git a/test/test_distributed.py b/test/test_distributed.py index 63e65eccd..3da30de97 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -893,7 +893,7 @@ def test_number_symbolic_tags_bare_classes(ctx_factory): assert next_tag == 4244 - # For the next assertion, find_distributed_partition needs to be + # FIXME: For the next assertion, find_distributed_partition needs to be # deterministic too (https://github.com/inducer/pytato/pull/465). # assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # noqa: E501