From 914755c3c06bd49c2f5d4b85ba964a51ac2740e3 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Tue, 20 Aug 2024 10:59:07 -0500 Subject: [PATCH] mpir: replace subcomm usage with subgroups Read the node-local rank and size directly from MPIR_Process (local_rank, local_size) or from the comm_world subgroups, rather than from the node_comm and node_roots_comm subcommunicators of MPIR_Process.comm_world. One step toward removing subcomms. --- src/mpi/comm/comm_split_type_nbhd.c | 7 +------ src/mpi/init/init_async.c | 9 +++------ src/util/mpir_nodemap.c | 21 ++++++++++----------- 3 files changed, 14 insertions(+), 23 deletions(-) diff --git a/src/mpi/comm/comm_split_type_nbhd.c b/src/mpi/comm/comm_split_type_nbhd.c index 07c0e6d288c..9fbaa730ca0 100644 --- a/src/mpi/comm/comm_split_type_nbhd.c +++ b/src/mpi/comm/comm_split_type_nbhd.c @@ -474,12 +474,7 @@ static int network_split_by_min_memsize(MPIR_Comm * comm_ptr, int key, long min_ if (min_mem_size == 0 || topo_type == MPIR_NETTOPO_TYPE__INVALID) { *newcomm_ptr = NULL; } else { - int num_ranks_node; - if (MPIR_Process.comm_world->node_comm != NULL) { - num_ranks_node = MPIR_Comm_size(MPIR_Process.comm_world->node_comm); - } else { - num_ranks_node = 1; - } + int num_ranks_node = MPIR_Process.local_size; memory_per_process = total_memory_size / num_ranks_node; mpi_errno = network_split_by_minsize(comm_ptr, key, min_mem_size / memory_per_process, newcomm_ptr); diff --git a/src/mpi/init/init_async.c b/src/mpi/init/init_async.c index a7f4e1f4807..2e8c82de94a 100644 --- a/src/mpi/init/init_async.c +++ b/src/mpi/init/init_async.c @@ -179,17 +179,14 @@ static int get_thread_affinity(bool * apply_affinity, int **p_thread_affinity, i } global_rank = MPIR_Process.rank; - local_rank = - (MPIR_Process.comm_world->node_comm) ? MPIR_Process.comm_world->node_comm->rank : 0; + local_rank = MPIR_Process.local_rank; if (have_cliques) { - /* If local cliques > 1, using local_size from node_comm will have conflict on thread idx. + /* If local cliques > 1, using local_size will have conflict on thread idx. 
* In multiple nodes case, this would cost extra memory for allocating thread affinity on every * node, but it is okay to solve progress thread oversubscription. */ local_size = MPIR_Process.comm_world->local_size; } else { - local_size = - (MPIR_Process.comm_world->node_comm) ? MPIR_Process.comm_world-> - node_comm->local_size : 1; + local_size = MPIR_Process.local_size; } async_threads_per_node = local_size; diff --git a/src/util/mpir_nodemap.c b/src/util/mpir_nodemap.c index cf8eaca7efa..e8f6d1ee6ba 100644 --- a/src/util/mpir_nodemap.c +++ b/src/util/mpir_nodemap.c @@ -436,14 +436,14 @@ int MPIR_nodeid_init(void) utarray_resize(MPIR_Process.node_hostnames, MPIR_Process.num_nodes, MPL_MEM_OTHER); char *allhostnames = (char *) utarray_eltptr(MPIR_Process.node_hostnames, 0); - if (MPIR_Process.local_rank == 0) { - MPIR_Comm *node_roots_comm = MPIR_Process.comm_world->node_roots_comm; - if (node_roots_comm == NULL) { - /* num_external == comm->remote_size */ - node_roots_comm = MPIR_Process.comm_world; - } + MPIR_Comm *world_comm = MPIR_Process.comm_world; + int local_rank = world_comm->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = world_comm->subgroups[MPIR_SUBGROUP_NODE].size; + + if (local_rank == 0) { + int inter_rank = world_comm->subgroups[MPIR_SUBGROUP_NODE_CROSS].rank; - char *my_hostname = allhostnames + MAX_HOSTNAME_LEN * node_roots_comm->rank; + char *my_hostname = allhostnames + MAX_HOSTNAME_LEN * inter_rank; int ret = gethostname(my_hostname, MAX_HOSTNAME_LEN); char strerrbuf[MPIR_STRERROR_BUF_SIZE] ATTRIBUTE((unused)); MPIR_ERR_CHKANDJUMP2(ret == -1, mpi_errno, MPI_ERR_OTHER, @@ -453,14 +453,13 @@ int MPIR_nodeid_init(void) mpi_errno = MPIR_Allgather_impl(MPI_IN_PLACE, MAX_HOSTNAME_LEN, MPI_CHAR, allhostnames, MAX_HOSTNAME_LEN, MPI_CHAR, - node_roots_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + world_comm, MPIR_SUBGROUP_NODE_CROSS, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } - MPIR_Comm *node_comm = MPIR_Process.comm_world->node_comm; - if 
(node_comm) { + if (local_size > 1) { mpi_errno = MPIR_Bcast_impl(allhostnames, MAX_HOSTNAME_LEN * MPIR_Process.num_nodes, - MPI_CHAR, 0, node_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + MPI_CHAR, 0, world_comm, MPIR_SUBGROUP_NODE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); }