From 8366cd2f557aa9c3ad30d18ef6090f9e67556d99 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Mon, 12 Aug 2024 23:36:34 -0500 Subject: [PATCH] coll: avoid extra intra bcast in bcast_intra_smp_new When root is not local rank 0, instead of adding a extra intra-node send/recv or bcast, construct an inter group that includes the root process. --- src/mpi/coll/bcast/bcast_intra_smp_new.c | 138 +++++++++++------------ 1 file changed, 63 insertions(+), 75 deletions(-) diff --git a/src/mpi/coll/bcast/bcast_intra_smp_new.c b/src/mpi/coll/bcast/bcast_intra_smp_new.c index 0bb7434dbcc..94f1daf9ec6 100644 --- a/src/mpi/coll/bcast/bcast_intra_smp_new.c +++ b/src/mpi/coll/bcast/bcast_intra_smp_new.c @@ -5,6 +5,35 @@ #include "mpiimpl.h" +/* TODO: move this to commutil.c */ +static void MPIR_Comm_construct_internode_roots_group(MPIR_Comm * comm, int root, + int *group_p, int *root_rank_p) +{ + int inter_size = comm->num_external; + int inter_rank = comm->internode_table[comm->rank]; + + MPIR_COMM_NEW_SUBGROUP(comm, MPIR_SUBGROUP_TEMP, inter_size, inter_rank); + int inter_group = MPIR_COMM_LAST_SUBGROUP(comm); + + int *proc_table = MPL_malloc(inter_size * sizeof(int), MPL_MEM_OTHER); + for (int i = 0; i < inter_size; i++) { + proc_table[i] = -1; + } + for (int i = 0; i < comm->remote_size; i++) { + int r = comm->internode_table[i]; + if (proc_table[r] == -1) { + proc_table[r] = i; + } + } + int inter_root_rank = comm->internode_table[root]; + proc_table[inter_root_rank] = root; + + comm->subgroups[inter_group].proc_table = proc_table; + + *group_p = inter_group; + *root_rank_p = inter_root_rank; +} + /* The sticky point in the old smp bcast is when root is not a local root, resulting * an extra send/recv. With the new MPIR_Subgroup, we don't have to, in principle. */ @@ -22,20 +51,19 @@ int MPIR_Bcast_intra_smp_new(void *buffer, MPI_Aint count, MPI_Datatype datatype #else status_p = MPI_STATUS_IGNORE; #endif + int node_group = -1, inter_group = -1; - MPIR_Assert(comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT); + MPIR_Assert(comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT || + comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__FLAT); MPIR_Assert(MPIR_COLL_ATTR_GET_SUBGROUP(coll_attr) == 0); - int node_group = -1, node_cross_group = -1; for (int i = 0; i < comm_ptr->num_subgroups; i++) { if (comm_ptr->subgroups[i].kind == MPIR_SUBGROUP_NODE) { node_group = i; - } - if (comm_ptr->subgroups[i].kind == MPIR_SUBGROUP_NODE_CROSS) { - node_cross_group = i; + break; } } - MPIR_Assert(node_group > 0 && node_cross_group > 0); + MPIR_Assert(node_group > 0); #define NODEGROUP(field) comm_ptr->subgroups[node_group].field int local_rank, local_size; @@ -43,9 +71,22 @@ int MPIR_Bcast_intra_smp_new(void *buffer, MPI_Aint count, MPI_Datatype datatype local_size = NODEGROUP(size); int local_root_rank = MPIR_Get_intranode_rank(comm_ptr, root); - int inter_root_rank = MPIR_Get_internode_rank(comm_ptr, root); + + int local_root = 0; + if (local_root_rank > 0) { + local_root = local_root_rank; + } + + /* Construct an internode group */ + int inter_root_rank; +#define INTERGROUP(field) comm_ptr->subgroups[inter_group].field + if (local_rank == local_root) { + MPIR_Comm_construct_internode_roots_group(comm_ptr, root, &inter_group, &inter_root_rank); + MPIR_Assert(inter_group > 0); + } + int node_attr = coll_attr | MPIR_COLL_ATTR_SUBGROUP(node_group); - int node_cross_attr = coll_attr | MPIR_COLL_ATTR_SUBGROUP(node_cross_group); + int inter_attr = coll_attr | MPIR_COLL_ATTR_SUBGROUP(inter_group); MPIR_Datatype_get_size_macro(datatype, type_size); @@ -53,84 +94,31 @@ int MPIR_Bcast_intra_smp_new(void *buffer, MPI_Aint count, MPI_Datatype datatype if (nbytes == 0) goto fn_exit; /* nothing to do */ - if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (local_size < MPIR_CVAR_BCAST_MIN_PROCS)) { - /* send to intranode-rank 0 on the root's node */ - if (local_size > 1 && local_root_rank > 0) { - if (root == comm_ptr->rank) { - int node_root = NODEGROUP(proc_table)[0]; - mpi_errno = MPIC_Send(buffer, count, datatype, node_root, MPIR_BCAST_TAG, comm_ptr, - coll_attr); - MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret); - } else if (local_rank == 0) { - mpi_errno = MPIC_Recv(buffer, count, datatype, root, MPIR_BCAST_TAG, comm_ptr, - coll_attr, status_p); - MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret); -#ifdef HAVE_ERROR_CHECKING - /* check that we received as much as we expected */ - MPIR_Get_count_impl(status_p, MPI_BYTE, &recvd_size); - MPIR_ERR_COLL_CHECK_SIZE(recvd_size, nbytes, coll_attr, mpi_errno_ret); -#endif - } - - } - + if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (local_size < MPIR_CVAR_BCAST_MIN_PROCS) || + (nbytes < MPIR_CVAR_BCAST_LONG_MSG_SIZE && MPL_is_pof2(local_size))) { /* perform the internode broadcast */ - if (local_rank == 0) { + if (local_rank == local_root) { mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, inter_root_rank, - comm_ptr, node_cross_attr); + comm_ptr, inter_attr); MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret); } /* perform the intranode broadcast */ if (local_size > 1) { - mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, 0, comm_ptr, node_attr); - MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret); - } - } else { /* (nbytes > MPIR_CVAR_BCAST_SHORT_MSG_SIZE) && (comm_ptr->size >= MPIR_CVAR_BCAST_MIN_PROCS) */ - - /* supposedly... - * smp+doubling good for pof2 - * reg+ring better for non-pof2 */ - if (nbytes < MPIR_CVAR_BCAST_LONG_MSG_SIZE && MPL_is_pof2(local_size)) { - /* medium-sized msg and pof2 np */ - - /* perform the intranode broadcast on the root's node */ - if (local_size > 1 && local_root_rank > 0) { /* is not the node root (0) and is on our node (!-1) */ - /* FIXME binomial may not be the best algorithm for on-node - * bcast. We need a more comprehensive system for selecting the - * right algorithms here. */ - mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, local_root_rank, - comm_ptr, node_attr); - MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret); - } - - /* perform the internode broadcast */ - if (local_rank == 0) { - mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, inter_root_rank, - comm_ptr, node_cross_attr); - MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret); - } - - /* perform the intranode broadcast on all except for the root's node */ - if (local_size > 1 && local_root_rank <= 0) { /* 0 if root was local root too, -1 if different node than root */ - /* FIXME binomial may not be the best algorithm for on-node - * bcast. We need a more comprehensive system for selecting the - * right algorithms here. */ - mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, 0, - comm_ptr, node_attr); - MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret); - } - } else { /* large msg or non-pof2 */ - - /* FIXME It would be good to have an SMP-aware version of this - * algorithm that (at least approximately) minimized internode - * communication. */ - mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, root, - comm_ptr, coll_attr); + mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, local_root, + comm_ptr, node_attr); MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret); } + } else { + /* large msg or non-pof2 */ + /* FIXME: better algorithm selection */ + mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, root, comm_ptr, coll_attr); + MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret); } fn_exit: + if (inter_group > 0) { + MPIR_COMM_POP_SUBGROUP(comm_ptr); + } return mpi_errno_ret; }