Skip to content

Commit

Permalink
tag communication by destination volume
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Feb 15, 2023
1 parent d5576fb commit 1a38e2d
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions grudge/trace_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,10 +570,17 @@ def __init__(self,
self.local_bdry_data = local_bdry_data
self.remote_bdry_data_template = remote_bdry_data_template

self.comm_tag = self.base_comm_tag
comm_tag = _sym_tag_to_num_tag(comm_tag)
if comm_tag is not None:
self.comm_tag += comm_tag
def _generate_num_comm_tag(sym_comm_tag):
result = self.base_comm_tag
num_comm_tag = _sym_tag_to_num_tag(sym_comm_tag)
if num_comm_tag is not None:
result += num_comm_tag
return result

send_sym_comm_tag = (remote_part_id.volume_tag, comm_tag)
recv_sym_comm_tag = (local_part_id.volume_tag, comm_tag)
self.send_comm_tag = _generate_num_comm_tag(send_sym_comm_tag)
self.recv_comm_tag = _generate_num_comm_tag(recv_sym_comm_tag)
del comm_tag

# NOTE: mpi4py currently (2021-11-03) holds a reference to the send
Expand All @@ -588,7 +595,7 @@ def send_single_array(key, local_subary):
if not isinstance(local_subary, Number):
local_subary_np = to_numpy(local_subary, actx)
self.send_reqs.append(
comm.Isend(local_subary_np, remote_rank, tag=self.comm_tag))
comm.Isend(local_subary_np, remote_rank, tag=self.send_comm_tag))
self.send_data.append(local_subary_np)
return local_subary

Expand All @@ -601,7 +608,8 @@ def recv_single_array(key, remote_subary_template):
remote_subary_template.shape,
remote_subary_template.dtype)
self.recv_reqs.append(
comm.Irecv(remote_subary_np, remote_rank, tag=self.comm_tag))
comm.Irecv(remote_subary_np, remote_rank,
tag=self.recv_comm_tag))
self.recv_data[key] = remote_subary_np
return remote_subary_template

Expand Down Expand Up @@ -702,7 +710,7 @@ def send_single_array(key, local_subary):
if isinstance(local_subary, Number):
return
else:
ary_tag = (comm_tag, key)
ary_tag = (remote_part_id.volume_tag, comm_tag, key)
sends[key] = make_distributed_send(
local_subary, dest_rank=remote_rank, comm_tag=ary_tag)

Expand All @@ -711,7 +719,7 @@ def recv_single_array(key, remote_subary_template):
# NOTE: Assumes that the same number is passed on every rank
return remote_subary_template
else:
ary_tag = (comm_tag, key)
ary_tag = (local_part_id.volume_tag, comm_tag, key)
return DistributedSendRefHolder(
sends[key],
make_distributed_recv(
Expand Down

0 comments on commit 1a38e2d

Please sign in to comment.