Skip to content

Commit

Permalink
coll: avoid extra intra bcast in bcast_intra_smp_new
Browse files Browse the repository at this point in the history
When root is not local rank 0, instead of adding an extra intra-node
send/recv or bcast, construct an inter group that includes the root
process.
  • Loading branch information
hzhou committed Aug 13, 2024
1 parent 805f51f commit 8366cd2
Showing 1 changed file with 63 additions and 75 deletions.
138 changes: 63 additions & 75 deletions src/mpi/coll/bcast/bcast_intra_smp_new.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,35 @@

#include "mpiimpl.h"

/* TODO: move this to commutil.c */
/* Push a temporary internode subgroup onto comm that has one member per
 * internode index and is guaranteed to contain root.
 *
 * For each internode index r in [0, comm->num_external), the member is the
 * lowest comm rank i with comm->internode_table[i] == r, except that the slot
 * for root's internode index always holds root itself.
 *
 * comm        - communicator to attach the new subgroup to
 * root        - rank in comm that must be a member of the subgroup
 * group_p     - [out] index of the new subgroup, as reported by
 *               MPIR_COMM_LAST_SUBGROUP
 * root_rank_p - [out] root's rank within the new subgroup, i.e.
 *               comm->internode_table[root]
 *
 * The allocated proc_table is stored in the subgroup entry.
 * NOTE(review): presumably it is released when the subgroup is popped
 * (MPIR_COMM_POP_SUBGROUP) -- confirm, otherwise it leaks.
 */
static void MPIR_Comm_construct_internode_roots_group(MPIR_Comm * comm, int root,
                                                      int *group_p, int *root_rank_p)
{
    int inter_size = comm->num_external;
    int inter_rank = comm->internode_table[comm->rank];

    MPIR_COMM_NEW_SUBGROUP(comm, MPIR_SUBGROUP_TEMP, inter_size, inter_rank);
    int inter_group = MPIR_COMM_LAST_SUBGROUP(comm);

    int *proc_table = MPL_malloc(inter_size * sizeof(int), MPL_MEM_OTHER);
    /* This function has no error return, so fail loudly here instead of
     * dereferencing a NULL table in the loops below. */
    MPIR_Assert(proc_table != NULL);

    /* -1 marks an internode slot that has no representative yet */
    for (int i = 0; i < inter_size; i++) {
        proc_table[i] = -1;
    }
    /* first rank seen for each internode index becomes its representative */
    for (int i = 0; i < comm->remote_size; i++) {
        int r = comm->internode_table[i];
        if (proc_table[r] == -1) {
            proc_table[r] = i;
        }
    }

    /* root replaces the default representative of its own slot, so the
     * internode step can start directly from root */
    int inter_root_rank = comm->internode_table[root];
    proc_table[inter_root_rank] = root;

    comm->subgroups[inter_group].proc_table = proc_table;

    *group_p = inter_group;
    *root_rank_p = inter_root_rank;
}

/* The sticky point in the old smp bcast is when root is not a local root, resulting
 * in an extra send/recv. With the new MPIR_Subgroup, we can avoid it, in principle.
 */
Expand All @@ -22,115 +51,74 @@ int MPIR_Bcast_intra_smp_new(void *buffer, MPI_Aint count, MPI_Datatype datatype
#else
status_p = MPI_STATUS_IGNORE;
#endif
int node_group = -1, inter_group = -1;

MPIR_Assert(comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT);
MPIR_Assert(comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT ||
comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__FLAT);
MPIR_Assert(MPIR_COLL_ATTR_GET_SUBGROUP(coll_attr) == 0);

int node_group = -1, node_cross_group = -1;
for (int i = 0; i < comm_ptr->num_subgroups; i++) {
if (comm_ptr->subgroups[i].kind == MPIR_SUBGROUP_NODE) {
node_group = i;
}
if (comm_ptr->subgroups[i].kind == MPIR_SUBGROUP_NODE_CROSS) {
node_cross_group = i;
break;
}
}
MPIR_Assert(node_group > 0 && node_cross_group > 0);
MPIR_Assert(node_group > 0);

#define NODEGROUP(field) comm_ptr->subgroups[node_group].field
int local_rank, local_size;
local_rank = NODEGROUP(rank);
local_size = NODEGROUP(size);

int local_root_rank = MPIR_Get_intranode_rank(comm_ptr, root);
int inter_root_rank = MPIR_Get_internode_rank(comm_ptr, root);

int local_root = 0;
if (local_root_rank > 0) {
local_root = local_root_rank;
}

/* Construct an internode group */
int inter_root_rank;
#define INTERGROUP(field) comm_ptr->subgroups[inter_group].field
if (local_rank == local_root) {
MPIR_Comm_construct_internode_roots_group(comm_ptr, root, &inter_group, &inter_root_rank);
MPIR_Assert(inter_group > 0);
}

int node_attr = coll_attr | MPIR_COLL_ATTR_SUBGROUP(node_group);
int node_cross_attr = coll_attr | MPIR_COLL_ATTR_SUBGROUP(node_cross_group);
int inter_attr = coll_attr | MPIR_COLL_ATTR_SUBGROUP(inter_group);

MPIR_Datatype_get_size_macro(datatype, type_size);

nbytes = type_size * count;
if (nbytes == 0)
goto fn_exit; /* nothing to do */

if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (local_size < MPIR_CVAR_BCAST_MIN_PROCS)) {
/* send to intranode-rank 0 on the root's node */
if (local_size > 1 && local_root_rank > 0) {
if (root == comm_ptr->rank) {
int node_root = NODEGROUP(proc_table)[0];
mpi_errno = MPIC_Send(buffer, count, datatype, node_root, MPIR_BCAST_TAG, comm_ptr,
coll_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
} else if (local_rank == 0) {
mpi_errno = MPIC_Recv(buffer, count, datatype, root, MPIR_BCAST_TAG, comm_ptr,
coll_attr, status_p);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
#ifdef HAVE_ERROR_CHECKING
/* check that we received as much as we expected */
MPIR_Get_count_impl(status_p, MPI_BYTE, &recvd_size);
MPIR_ERR_COLL_CHECK_SIZE(recvd_size, nbytes, coll_attr, mpi_errno_ret);
#endif
}

}

if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (local_size < MPIR_CVAR_BCAST_MIN_PROCS) ||
(nbytes < MPIR_CVAR_BCAST_LONG_MSG_SIZE && MPL_is_pof2(local_size))) {
/* perform the internode broadcast */
if (local_rank == 0) {
if (local_rank == local_root) {
mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, inter_root_rank,
comm_ptr, node_cross_attr);
comm_ptr, inter_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}

/* perform the intranode broadcast */
if (local_size > 1) {
mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, 0, comm_ptr, node_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}
} else { /* (nbytes > MPIR_CVAR_BCAST_SHORT_MSG_SIZE) && (comm_ptr->size >= MPIR_CVAR_BCAST_MIN_PROCS) */

/* supposedly...
* smp+doubling good for pof2
* reg+ring better for non-pof2 */
if (nbytes < MPIR_CVAR_BCAST_LONG_MSG_SIZE && MPL_is_pof2(local_size)) {
/* medium-sized msg and pof2 np */

/* perform the intranode broadcast on the root's node */
if (local_size > 1 && local_root_rank > 0) { /* is not the node root (0) and is on our node (!-1) */
/* FIXME binomial may not be the best algorithm for on-node
* bcast. We need a more comprehensive system for selecting the
* right algorithms here. */
mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, local_root_rank,
comm_ptr, node_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}

/* perform the internode broadcast */
if (local_rank == 0) {
mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, inter_root_rank,
comm_ptr, node_cross_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}

/* perform the intranode broadcast on all except for the root's node */
if (local_size > 1 && local_root_rank <= 0) { /* 0 if root was local root too, -1 if different node than root */
/* FIXME binomial may not be the best algorithm for on-node
* bcast. We need a more comprehensive system for selecting the
* right algorithms here. */
mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, 0,
comm_ptr, node_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}
} else { /* large msg or non-pof2 */

/* FIXME It would be good to have an SMP-aware version of this
* algorithm that (at least approximately) minimized internode
* communication. */
mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, root,
comm_ptr, coll_attr);
mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, local_root,
comm_ptr, node_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}
} else {
/* large msg or non-pof2 */
/* FIXME: better algorithm selection */
mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, root, comm_ptr, coll_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}

fn_exit:
if (inter_group > 0) {
MPIR_COMM_POP_SUBGROUP(comm_ptr);
}
return mpi_errno_ret;
}

0 comments on commit 8366cd2

Please sign in to comment.