Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

number_distributed_tags: non-set, non-sorted numbering #469

Merged
merged 8 commits into from
Feb 6, 2024
49 changes: 20 additions & 29 deletions pytato/distributed/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""


from typing import TYPE_CHECKING, Tuple, FrozenSet, Optional, TypeVar
from typing import TYPE_CHECKING, Tuple, TypeVar

from pytato.distributed.partition import DistributedGraphPartition

Expand Down Expand Up @@ -63,48 +63,39 @@ def number_distributed_tags(
This is a potentially heavyweight MPI-collective operation on
*mpi_communicator*.
"""
tags = frozenset({
from pytools import flatten

# A potential optimization here could be to use a 'set' to collect the tags,
# but this would introduce non-determinism in the tag numbering. Another
# option would be to use something like pytools.unique() to reduce the
# amount of data communicated, but since all sends and receives should each
# have unique tags, this would at most buy us a factor of 2.
tags = tuple([
inducer marked this conversation as resolved.
Show resolved Hide resolved
recv.comm_tag
for part in partition.parts.values()
for recv in part.name_to_recv_node.values()
} | {
] + [
send.comm_tag
for part in partition.parts.values()
for sends in part.name_to_send_nodes.values()
for send in sends})

from mpi4py import MPI

def set_union(
set_a: FrozenSet[T], set_b: FrozenSet[T],
mpi_data_type: Optional[MPI.Datatype]) -> FrozenSet[T]:
assert mpi_data_type is None
assert isinstance(set_a, frozenset)
assert isinstance(set_b, frozenset)

return set_a | set_b
for send in sends])

root_rank = 0

set_union_mpi_op = MPI.Op.Create(
# type ignore reason: mpi4py misdeclares op functions as returning
# None.
set_union, # type: ignore[arg-type]
commute=True)
try:
all_tags = mpi_communicator.reduce(
tags, set_union_mpi_op, root=root_rank)
finally:
set_union_mpi_op.Free()
# We can't let MPI do a set union here, since the result would be
# non-deterministic.
all_tags = mpi_communicator.gather(tags, root=root_rank)
inducer marked this conversation as resolved.
Show resolved Hide resolved

if mpi_communicator.rank == root_rank:
sym_tag_to_int_tag = {}
next_tag = base_tag
assert isinstance(all_tags, frozenset)
assert isinstance(all_tags, list)
assert len(all_tags) == mpi_communicator.size

for sym_tag in all_tags:
sym_tag_to_int_tag[sym_tag] = next_tag
next_tag += 1
for sym_tag in flatten(all_tags): # type: ignore[no-untyped-call]
if sym_tag not in sym_tag_to_int_tag:
sym_tag_to_int_tag[sym_tag] = next_tag
next_tag += 1

mpi_communicator.bcast((sym_tag_to_int_tag, next_tag), root=root_rank)
else:
Expand Down
22 changes: 18 additions & 4 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,10 @@ def _do_test_distributed_execution_random_dag(ctx_factory):

gen_comm_called = False

ntests = 10
ntests = 20
for i in range(ntests):
seed = 120 + i
print(f"Step {i} {seed}")
print(f"Step {i} {seed=}")

# {{{ compute value with communication

Expand All @@ -278,7 +278,15 @@ def gen_comm(rdagc):

nonlocal comm_tag
comm_tag += 1
tag = (comm_tag, _RandomDAGTag) # noqa: B023

if comm_tag % 5 == 1:
tag = (comm_tag, frozenset([_RandomDAGTag, "a", comm_tag]))
elif comm_tag % 5 == 2:
tag = (comm_tag, (_RandomDAGTag, "b"))
elif comm_tag % 5 == 3:
tag = (_RandomDAGTag, comm_tag)
else:
tag = (comm_tag, _RandomDAGTag) # noqa: B023

inner = make_random_dag(rdagc)
return pt.staple_distributed_send(
Expand Down Expand Up @@ -881,7 +889,13 @@ def test_number_symbolic_tags_bare_classes(ctx_factory):
outputs = pt.make_dict_of_named_arrays({"out": res})
partition = pt.find_distributed_partition(comm, outputs)

pt.number_distributed_tags(comm, partition, base_tag=4242)
(distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242)

assert next_tag == 4244

# FIXME: For the next assertion, find_distributed_partition needs to be
# deterministic too (https://github.com/inducer/pytato/pull/465).
# assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # noqa: E501
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When combining this PR with #465, this assertion passes.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

File an issue to remind us to reenable it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in #482


# }}}

Expand Down
Loading