From 1a38e2df87a657ad964de674aa1b6acf874005b0 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 3 Nov 2022 10:07:26 -0700 Subject: [PATCH] tag communication by destination volume --- grudge/trace_pair.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 7358e5af..acc08650 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -570,10 +570,17 @@ def __init__(self, self.local_bdry_data = local_bdry_data self.remote_bdry_data_template = remote_bdry_data_template - self.comm_tag = self.base_comm_tag - comm_tag = _sym_tag_to_num_tag(comm_tag) - if comm_tag is not None: - self.comm_tag += comm_tag + def _generate_num_comm_tag(sym_comm_tag): + result = self.base_comm_tag + num_comm_tag = _sym_tag_to_num_tag(sym_comm_tag) + if num_comm_tag is not None: + result += num_comm_tag + return result + + send_sym_comm_tag = (remote_part_id.volume_tag, comm_tag) + recv_sym_comm_tag = (local_part_id.volume_tag, comm_tag) + self.send_comm_tag = _generate_num_comm_tag(send_sym_comm_tag) + self.recv_comm_tag = _generate_num_comm_tag(recv_sym_comm_tag) del comm_tag # NOTE: mpi4py currently (2021-11-03) holds a reference to the send @@ -588,7 +595,7 @@ def send_single_array(key, local_subary): if not isinstance(local_subary, Number): local_subary_np = to_numpy(local_subary, actx) self.send_reqs.append( - comm.Isend(local_subary_np, remote_rank, tag=self.comm_tag)) + comm.Isend(local_subary_np, remote_rank, tag=self.send_comm_tag)) self.send_data.append(local_subary_np) return local_subary @@ -601,7 +608,8 @@ def recv_single_array(key, remote_subary_template): remote_subary_template.shape, remote_subary_template.dtype) self.recv_reqs.append( - comm.Irecv(remote_subary_np, remote_rank, tag=self.comm_tag)) + comm.Irecv(remote_subary_np, remote_rank, + tag=self.recv_comm_tag)) self.recv_data[key] = remote_subary_np return remote_subary_template @@ -702,7 +710,7 @@ def send_single_array(key, local_subary): if isinstance(local_subary, Number): return else: - ary_tag = (comm_tag, key) + ary_tag = (remote_part_id.volume_tag, comm_tag, key) sends[key] = make_distributed_send( local_subary, dest_rank=remote_rank, comm_tag=ary_tag) @@ -711,7 +719,7 @@ def recv_single_array(key, remote_subary_template): # NOTE: Assumes that the same number is passed on every rank return remote_subary_template else: - ary_tag = (comm_tag, key) + ary_tag = (local_part_id.volume_tag, comm_tag, key) return DistributedSendRefHolder( sends[key], make_distributed_recv(