Skip to content

Commit

Permalink
number_distributed_tags: non-set, non-sorted numbering (#469)
Browse files Browse the repository at this point in the history
* number_distributed_tags: non-set, non-sorted numbering

* make the test a bit more difficult

* improve test

* add comments based on review
  • Loading branch information
matthiasdiener committed Feb 6, 2024
1 parent dc1421b commit c3fe047
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 33 deletions.
49 changes: 20 additions & 29 deletions pytato/distributed/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""


from typing import TYPE_CHECKING, Tuple, FrozenSet, Optional, TypeVar
from typing import TYPE_CHECKING, Tuple, TypeVar

from pytato.distributed.partition import DistributedGraphPartition

Expand Down Expand Up @@ -63,48 +63,39 @@ def number_distributed_tags(
This is a potentially heavyweight MPI-collective operation on
*mpi_communicator*.
"""
tags = frozenset({
from pytools import flatten

# A potential optimization here could be to use a 'set' to collect the tags,
# but this would introduce non-determinism in the tag numbering. Another
# option would be to use something like pytools.unique() to reduce the amount
# of data communicated, but since all sends and receives should each
# have unique tags, this would at most buy us a factor of 2.
tags = tuple([
recv.comm_tag
for part in partition.parts.values()
for recv in part.name_to_recv_node.values()
} | {
] + [
send.comm_tag
for part in partition.parts.values()
for sends in part.name_to_send_nodes.values()
for send in sends})

from mpi4py import MPI

def set_union(
set_a: FrozenSet[T], set_b: FrozenSet[T],
mpi_data_type: Optional[MPI.Datatype]) -> FrozenSet[T]:
assert mpi_data_type is None
assert isinstance(set_a, frozenset)
assert isinstance(set_b, frozenset)

return set_a | set_b
for send in sends])

root_rank = 0

set_union_mpi_op = MPI.Op.Create(
# type ignore reason: mpi4py misdeclares op functions as returning
# None.
set_union, # type: ignore[arg-type]
commute=True)
try:
all_tags = mpi_communicator.reduce(
tags, set_union_mpi_op, root=root_rank)
finally:
set_union_mpi_op.Free()
# We can't let MPI do a set union here, since the iteration order of the
# resulting set (and hence the tag numbering) would be non-deterministic.
all_tags = mpi_communicator.gather(tags, root=root_rank)

if mpi_communicator.rank == root_rank:
sym_tag_to_int_tag = {}
next_tag = base_tag
assert isinstance(all_tags, frozenset)
assert isinstance(all_tags, list)
assert len(all_tags) == mpi_communicator.size

for sym_tag in all_tags:
sym_tag_to_int_tag[sym_tag] = next_tag
next_tag += 1
for sym_tag in flatten(all_tags): # type: ignore[no-untyped-call]
if sym_tag not in sym_tag_to_int_tag:
sym_tag_to_int_tag[sym_tag] = next_tag
next_tag += 1

mpi_communicator.bcast((sym_tag_to_int_tag, next_tag), root=root_rank)
else:
Expand Down
22 changes: 18 additions & 4 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,10 @@ def _do_test_distributed_execution_random_dag(ctx_factory):

gen_comm_called = False

ntests = 10
ntests = 20
for i in range(ntests):
seed = 120 + i
print(f"Step {i} {seed}")
print(f"Step {i} {seed=}")

# {{{ compute value with communication

Expand All @@ -278,7 +278,15 @@ def gen_comm(rdagc):

nonlocal comm_tag
comm_tag += 1
tag = (comm_tag, _RandomDAGTag) # noqa: B023

if comm_tag % 5 == 1:
tag = (comm_tag, frozenset([_RandomDAGTag, "a", comm_tag]))
elif comm_tag % 5 == 2:
tag = (comm_tag, (_RandomDAGTag, "b"))
elif comm_tag % 5 == 3:
tag = (_RandomDAGTag, comm_tag)
else:
tag = (comm_tag, _RandomDAGTag) # noqa: B023

inner = make_random_dag(rdagc)
return pt.staple_distributed_send(
Expand Down Expand Up @@ -881,7 +889,13 @@ def test_number_symbolic_tags_bare_classes(ctx_factory):
outputs = pt.make_dict_of_named_arrays({"out": res})
partition = pt.find_distributed_partition(comm, outputs)

pt.number_distributed_tags(comm, partition, base_tag=4242)
(distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242)

assert next_tag == 4244

# FIXME: For the next assertion to hold, find_distributed_partition needs to be
# deterministic too (https://github.com/inducer/pytato/pull/465).
# assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # noqa: E501

# }}}

Expand Down

0 comments on commit c3fe047

Please sign in to comment.