From 9340ed5f37b22a5cb171a3daf612a75200c742e4 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Mon, 12 Aug 2024 23:43:15 -0500 Subject: [PATCH 01/27] comm: store num_local and num_external in MPIR_Comm Store num_local and num_external in MPIR_Comm. Along with internode_table, they help construct internode subgroups. --- src/include/mpir_comm.h | 2 ++ src/mpi/comm/commutil.c | 3 +++ 2 files changed, 5 insertions(+) diff --git a/src/include/mpir_comm.h b/src/include/mpir_comm.h index a72daa01722..210b76f2571 100644 --- a/src/include/mpir_comm.h +++ b/src/include/mpir_comm.h @@ -188,6 +188,8 @@ struct MPIR_Comm { * node_roots_comm of rank i in this comm. * It is of size 'local_size'. */ int node_count; /* number of nodes this comm is spread over */ + int num_local; /* number of procs in this comm on local node */ + int num_external; /* number of nodes this comm is spread over */ int is_low_group; /* For intercomms only, this boolean is * set for all members of one of the diff --git a/src/mpi/comm/commutil.c b/src/mpi/comm/commutil.c index 29d7d473ea7..7ed4a5cb2f5 100644 --- a/src/mpi/comm/commutil.c +++ b/src/mpi/comm/commutil.c @@ -659,6 +659,9 @@ int MPIR_Comm_create_subcomms(MPIR_Comm * comm) MPIR_Assert(num_local > 1 || external_rank >= 0); MPIR_Assert(external_rank < 0 || external_procs != NULL); + comm->num_local = num_local; + comm->num_external = num_external; + /* if the node_roots_comm and comm would be the same size, then creating * the second communicator is useless and wasteful. */ if (num_external == comm->remote_size) { From 417c2c55741beb87ef24f1724e6560325b39a70e Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Tue, 13 Aug 2024 08:46:09 -0500 Subject: [PATCH 02/27] comm: remove node_count This is the same as num_external. 
--- src/include/mpir_comm.h | 1 - src/mpi/coll/src/csel.c | 8 +++--- src/mpi/comm/commutil.c | 55 ----------------------------------------- 3 files changed, 4 insertions(+), 60 deletions(-) diff --git a/src/include/mpir_comm.h b/src/include/mpir_comm.h index 210b76f2571..4b8e8a74cc6 100644 --- a/src/include/mpir_comm.h +++ b/src/include/mpir_comm.h @@ -187,7 +187,6 @@ struct MPIR_Comm { int *internode_table; /* internode_table[i] gives the rank in * node_roots_comm of rank i in this comm. * It is of size 'local_size'. */ - int node_count; /* number of nodes this comm is spread over */ int num_local; /* number of procs in this comm on local node */ int num_external; /* number of nodes this comm is spread over */ diff --git a/src/mpi/coll/src/csel.c b/src/mpi/coll/src/csel.c index 1253f847f66..47bd4c9e1be 100644 --- a/src/mpi/coll/src/csel.c +++ b/src/mpi/coll/src/csel.c @@ -659,14 +659,14 @@ static csel_node_s *prune_tree(csel_node_s * root, MPIR_Comm * comm_ptr) break; case CSEL_NODE_TYPE__OPERATOR__COMM_AVG_PPN_LE: - if (comm_ptr->local_size <= node->u.comm_avg_ppn_le.val * comm_ptr->node_count) + if (comm_ptr->local_size <= node->u.comm_avg_ppn_le.val * comm_ptr->num_external) node = node->success; else node = node->failure; break; case CSEL_NODE_TYPE__OPERATOR__COMM_AVG_PPN_LT: - if (comm_ptr->local_size < node->u.comm_avg_ppn_le.val * comm_ptr->node_count) + if (comm_ptr->local_size < node->u.comm_avg_ppn_le.val * comm_ptr->num_external) node = node->success; else node = node->failure; @@ -1329,14 +1329,14 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) break; case CSEL_NODE_TYPE__OPERATOR__COMM_AVG_PPN_LE: - if (comm_ptr->local_size <= node->u.comm_avg_ppn_le.val * comm_ptr->node_count) + if (comm_ptr->local_size <= node->u.comm_avg_ppn_le.val * comm_ptr->num_external) node = node->success; else node = node->failure; break; case CSEL_NODE_TYPE__OPERATOR__COMM_AVG_PPN_LT: - if (comm_ptr->local_size < node->u.comm_avg_ppn_le.val * 
comm_ptr->node_count) + if (comm_ptr->local_size < node->u.comm_avg_ppn_le.val * comm_ptr->num_external) node = node->success; else node = node->failure; diff --git a/src/mpi/comm/commutil.c b/src/mpi/comm/commutil.c index 7ed4a5cb2f5..9cd70fb58a7 100644 --- a/src/mpi/comm/commutil.c +++ b/src/mpi/comm/commutil.c @@ -517,58 +517,6 @@ int MPIR_Comm_map_free(MPIR_Comm * comm) return mpi_errno; } -static int get_node_count(MPIR_Comm * comm, int *node_count) -{ - int mpi_errno = MPI_SUCCESS; - struct uniq_nodes { - int id; - UT_hash_handle hh; - } *node_list = NULL; - struct uniq_nodes *s, *tmp; - - if (comm->comm_kind != MPIR_COMM_KIND__INTRACOMM) { - *node_count = comm->local_size; - goto fn_exit; - } else if (comm->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__NODE) { - *node_count = 1; - goto fn_exit; - } else if (comm->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__NODE_ROOTS) { - *node_count = comm->local_size; - goto fn_exit; - } - - /* go through the list of ranks and add the unique ones to the - * node_list array */ - for (int i = 0; i < comm->local_size; i++) { - int node; - - mpi_errno = MPID_Get_node_id(comm, i, &node); - MPIR_ERR_CHECK(mpi_errno); - - HASH_FIND_INT(node_list, &node, s); - if (s == NULL) { - s = (struct uniq_nodes *) MPL_malloc(sizeof(struct uniq_nodes), MPL_MEM_COLL); - MPIR_Assert(s); - s->id = node; - HASH_ADD_INT(node_list, id, s, MPL_MEM_COLL); - } - } - - /* the final size of our hash table is our node count */ - *node_count = HASH_COUNT(node_list); - - /* free up everything */ - HASH_ITER(hh, node_list, s, tmp) { - HASH_DEL(node_list, s); - MPL_free(s); - } - - fn_exit: - return mpi_errno; - fn_fail: - goto fn_exit; -} - static int MPIR_Comm_commit_internal(MPIR_Comm * comm) { int mpi_errno = MPI_SUCCESS; @@ -578,9 +526,6 @@ static int MPIR_Comm_commit_internal(MPIR_Comm * comm) mpi_errno = MPID_Comm_commit_pre_hook(comm); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = get_node_count(comm, &comm->node_count); - MPIR_ERR_CHECK(mpi_errno); - 
MPIR_Comm_map_free(comm); fn_exit: From a24834b53aea4e6bb3a34e4bd74850131ae6bd19 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Mon, 19 Aug 2024 17:19:12 -0500 Subject: [PATCH 03/27] comm/csel: remove reference to subcomms in csel prune_tree As the title. --- src/mpi/coll/src/csel.c | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/mpi/coll/src/csel.c b/src/mpi/coll/src/csel.c index 47bd4c9e1be..dfcd2b75746 100644 --- a/src/mpi/coll/src/csel.c +++ b/src/mpi/coll/src/csel.c @@ -630,8 +630,8 @@ static csel_node_s *prune_tree(csel_node_s * root, MPIR_Comm * comm_ptr) break; case CSEL_NODE_TYPE__OPERATOR__COMM_SIZE_NODE_COMM_SIZE: - if (comm_ptr->node_comm != NULL && - MPIR_Comm_size(comm_ptr) == MPIR_Comm_size(comm_ptr->node_comm)) + /* comm_size equal to node_comm_size just means the inter-node size is 1 */ + if (comm_ptr->num_external == 1) node = node->success; else node = node->failure; @@ -1229,8 +1229,7 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) break; case CSEL_NODE_TYPE__OPERATOR__COMM_SIZE_NODE_COMM_SIZE: - if (comm_ptr->node_comm != NULL && - MPIR_Comm_size(comm_ptr) == MPIR_Comm_size(comm_ptr->node_comm)) + if (comm_ptr->num_external == 1) node = node->success; else node = node->failure; From 132c188233aa95499bdd853756cf3e82bd8c0936 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Wed, 21 Aug 2024 09:50:39 -0500 Subject: [PATCH 04/27] coll: remove coll.pof2 field It does not take many instructions to calculate pof2 on the fly. Use of hard coded pof2 prevents collective algorithms from being used with non-trivial coll_group.
--- maint/gen_coll.py | 4 ++-- src/include/mpir_comm.h | 3 --- .../allreduce/allreduce_intra_reduce_scatter_allgather.c | 2 +- .../iallreduce/iallreduce_intra_sched_recursive_doubling.c | 2 +- .../iallreduce_intra_sched_reduce_scatter_allgather.c | 2 +- .../coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c | 2 +- .../ireduce_scatter_intra_sched_recursive_halving.c | 2 +- src/mpi/coll/mpir_coll_sched_auto.c | 4 ++-- src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c | 2 +- src/mpi/coll/src/coll_impl.c | 2 -- src/mpi/coll/src/csel.c | 2 +- src/mpid/ch4/src/ch4_comm.c | 2 -- src/mpid/ch4/src/init_comm.c | 5 ++--- 13 files changed, 13 insertions(+), 21 deletions(-) diff --git a/maint/gen_coll.py b/maint/gen_coll.py index 10fd49d9086..4e3a858caff 100644 --- a/maint/gen_coll.py +++ b/maint/gen_coll.py @@ -552,9 +552,9 @@ def dump_fallback(algo): elif a == "noinplace": cond_list.append("sendbuf != MPI_IN_PLACE") elif a == "power-of-two": - cond_list.append("comm_ptr->local_size == comm_ptr->coll.pof2") + cond_list.append("MPL_is_pof2(comm_ptr->local_size)") elif a == "size-ge-pof2": - cond_list.append("count >= comm_ptr->coll.pof2") + cond_list.append("count >= MPL_pof2(comm_ptr->local_size)") elif a == "commutative": cond_list.append("MPIR_Op_is_commutative(op)") elif a== "builtin-op": diff --git a/src/include/mpir_comm.h b/src/include/mpir_comm.h index 4b8e8a74cc6..97f5102e06b 100644 --- a/src/include/mpir_comm.h +++ b/src/include/mpir_comm.h @@ -224,9 +224,6 @@ struct MPIR_Comm { * use int array for fast access */ struct { - int pof2; /* Nearest (smaller than or equal to) power of 2 - * to the number of ranks in the communicator. 
- * To be used during collective communication */ int pofk[MAX_RADIX - 1]; int k[MAX_RADIX - 1]; int step1_sendto[MAX_RADIX - 1]; diff --git a/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c b/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c index 327148196fd..1371b31de34 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c +++ b/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c @@ -72,7 +72,7 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_ptr->local_size); rem = comm_size - pof2; diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c index b0a08613efd..8f409293afe 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c @@ -38,7 +38,7 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_ptr->local_size); rem = comm_size - pof2; diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c index 75bab1d9a84..1fc14c6a367 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c @@ -45,7 +45,7 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_ptr->local_size); rem = comm_size - pof2; diff --git 
a/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c b/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c index 2f35e5d6f93..d6bc57637fd 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c @@ -64,7 +64,7 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re tmp_buf = (void *) ((char *) tmp_buf - true_lb); /* get nearest power-of-two less than or equal to comm_size */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_ptr->local_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(HANDLE_IS_BUILTIN(op)); diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c index 69814169823..2043b96686c 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c @@ -95,7 +95,7 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_ptr->local_size); rem = comm_size - pof2; diff --git a/src/mpi/coll/mpir_coll_sched_auto.c b/src/mpi/coll/mpir_coll_sched_auto.c index 8b118bda349..2b2e4e17daa 100644 --- a/src/mpi/coll/mpir_coll_sched_auto.c +++ b/src/mpi/coll/mpir_coll_sched_auto.c @@ -538,7 +538,7 @@ int MPIR_Ireduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint c MPIR_Datatype_get_size_macro(datatype, type_size); /* get nearest power-of-two less than or equal to number of ranks in the communicator */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_ptr->local_size); if ((count * type_size > MPIR_CVAR_REDUCE_SHORT_MSG_SIZE) && (HANDLE_IS_BUILTIN(op)) && (count >= pof2)) { @@ -595,7 +595,7 @@ int MPIR_Iallreduce_intra_sched_auto(const void *sendbuf, void *recvbuf, 
MPI_Ain MPIR_Datatype_get_size_macro(datatype, type_size); /* get nearest power-of-two less than or equal to number of ranks in the communicator */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_ptr->local_size); /* If op is user-defined or count is less than pof2, use * recursive doubling algorithm. Otherwise do a reduce-scatter diff --git a/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c b/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c index a8113ce6658..8c1cd0fe4f0 100644 --- a/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c +++ b/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c @@ -77,7 +77,7 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = comm_ptr->coll.pof2; + pof2 = MPL_pof2(comm_ptr->local_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(HANDLE_IS_BUILTIN(op)); diff --git a/src/mpi/coll/src/coll_impl.c b/src/mpi/coll/src/coll_impl.c index b4b70a10b7a..46cb38c4e5a 100644 --- a/src/mpi/coll/src/coll_impl.c +++ b/src/mpi/coll/src/coll_impl.c @@ -224,8 +224,6 @@ int MPIR_Coll_comm_init(MPIR_Comm * comm) { int mpi_errno = MPI_SUCCESS; - comm->coll.pof2 = MPL_pof2(comm->local_size); - /* initialize any stub algo related data structures */ mpi_errno = MPII_Stubalgo_comm_init(comm); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/src/csel.c b/src/mpi/coll/src/csel.c index dfcd2b75746..e0bb9dbf4ce 100644 --- a/src/mpi/coll/src/csel.c +++ b/src/mpi/coll/src/csel.c @@ -1285,7 +1285,7 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) break; case CSEL_NODE_TYPE__OPERATOR__COUNT_LT_POW2: - if (get_count(coll_info) < coll_info.comm_ptr->coll.pof2) + if (get_count(coll_info) < MPL_pof2(coll_info.comm_ptr->local_size)) node = node->success; else node = node->failure; diff --git a/src/mpid/ch4/src/ch4_comm.c b/src/mpid/ch4/src/ch4_comm.c index 808d6f6e21b..1af06962f61 100644 --- a/src/mpid/ch4/src/ch4_comm.c +++ 
b/src/mpid/ch4/src/ch4_comm.c @@ -780,8 +780,6 @@ int MPIDI_Comm_create_multi_leaders(MPIR_Comm * comm) MPIDI_COMM(comm, multi_leads_comm)); MPIDI_COMM(comm, multi_leads_comm)->local_size = num_external; - MPIDI_COMM(comm, multi_leads_comm)->coll.pof2 = - MPL_pof2(MPIDI_COMM(comm, multi_leads_comm)->local_size); MPIDI_COMM(comm, multi_leads_comm)->remote_size = num_external; MPIR_Comm_map_irregular(MPIDI_COMM(comm, multi_leads_comm), comm, diff --git a/src/mpid/ch4/src/init_comm.c b/src/mpid/ch4/src/init_comm.c index e546337bd6f..1ff8135e2c8 100644 --- a/src/mpid/ch4/src/init_comm.c +++ b/src/mpid/ch4/src/init_comm.c @@ -32,7 +32,6 @@ int MPIDI_create_init_comm(MPIR_Comm ** comm) init_comm->rank = node_roots_comm_rank; init_comm->remote_size = node_roots_comm_size; init_comm->local_size = node_roots_comm_size; - init_comm->coll.pof2 = MPL_pof2(node_roots_comm_size); MPIDI_COMM(init_comm, map).mode = MPIDI_RANK_MAP_LUT_INTRA; mpi_errno = MPIDIU_alloc_lut(&lut, node_roots_comm_size); MPIR_ERR_CHECK(mpi_errno); @@ -47,8 +46,8 @@ int MPIDI_create_init_comm(MPIR_Comm ** comm) mpi_errno = MPIDIG_init_comm(init_comm); MPIR_ERR_CHECK(mpi_errno); /* hacky, consider a separate MPIDI_{NM,SHM}_init_comm_hook - * to initialize the init_comm, e.g. to eliminate potential - * runtime features for stability during init */ + * to initialize the init_comm, e.g. to eliminate potential + * runtime features for stability during init */ mpi_errno = MPIDI_NM_mpi_comm_commit_pre_hook(init_comm); MPIR_ERR_CHECK(mpi_errno); From c941fbe67d7ea27714fcd8811baa5c341bca996e Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 11 Aug 2024 15:10:52 -0500 Subject: [PATCH 05/27] comm: add MPIR_Subgroup Lightweight struct to describe sub-groups of a communicator. They intend to replace the subcomms. Preset a set of reserved subgroups to simplify common usages such as intranode group and crossnode group. 
Since we only expect a limited number of dynamic subgroups and they should always be push/pop'ed within the scope, we don't need many dynamic slots. --- src/include/mpir_comm.h | 47 +++++++++++++++++++++++++++++++++++++++++ src/mpi/comm/commutil.c | 31 +++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/src/include/mpir_comm.h b/src/include/mpir_comm.h index 97f5102e06b..945f887eb39 100644 --- a/src/include/mpir_comm.h +++ b/src/include/mpir_comm.h @@ -101,6 +101,51 @@ enum MPIR_COMM_HINT_PREDEFINED_t { MPIR_COMM_HINT_PREDEFINED_COUNT }; +/* MPIR_Subgroup is similar to MPIR_Group, but only used to describe subgroups within + * an intra communicator. The proc_table refers to ranks within the communicator. + * It is only used internally for group collectives. + */ +typedef struct MPIR_Subgroup { + int size; + int rank; + int *proc_table; /* can be NULL if the group is trivial */ +} MPIR_Subgroup; + +#define MPIR_MAX_SUBGROUPS 16 + +/* reserved subgroup indexes */ +enum { + MPIR_SUBGROUP_THREADCOMM = -1, + MPIR_SUBGROUP_NONE = 0, + MPIR_SUBGROUP_NODE, /* i.e. nodecomm */ + MPIR_SUBGROUP_NODE_CROSS, /* node_roots_comm, node_rank_1_comm, ... */ + MPIR_SUBGROUP_NUMA1, /* 1-level below node in topology */ + MPIR_SUBGROUP_NUMA1_CROSS, /* cross-link group at NUMA1 within NODE */ + MPIR_SUBGROUP_NUMA2, /* and so on */ + MPIR_SUBGROUP_NUMA2_CROSS, + MPIR_SUBGROUP_NUM_RESERVED, +}; + +/* macros to create dynamic subgroups. + * The caller is expected to fill out the proc_table after MPIR_COMM_PUSH_SUBGROUP.
+ */ +#define MPIR_COMM_PUSH_SUBGROUP(comm, _size, _rank, newgrp, proc_table_out) \ + do { \ + (newgrp) = (comm)->num_subgroups++; \ + MPIR_Assert((comm)->num_subgroups < MPIR_MAX_SUBGROUPS); \ + (comm)->subgroups[newgrp].size = _size; \ + (comm)->subgroups[newgrp].rank = _rank; \ + (proc_table_out) = MPL_malloc((_size) * sizeof(int), MPL_MEM_OTHER); \ + (comm)->subgroups[newgrp].proc_table = (proc_table_out); \ + } while (0) + +#define MPIR_COMM_POP_SUBGROUP(comm) \ + do { \ + int i = --(comm)->num_subgroups; \ + MPIR_Assert(i > 0); \ + MPL_free((comm)->subgroups[i].proc_table); \ + } while (0) + /*S MPIR_Comm - Description of the Communicator data structure @@ -197,6 +242,8 @@ struct MPIR_Comm { * intercommunicator collective operations * that wish to use half-duplex operations * to implement a full-duplex operation */ + MPIR_Subgroup subgroups[MPIR_MAX_SUBGROUPS]; + int num_subgroups; struct MPIR_Comm *comm_next; /* Provides a chain through all active * communicators */ diff --git a/src/mpi/comm/commutil.c b/src/mpi/comm/commutil.c index 9cd70fb58a7..6e3ba2fcdeb 100644 --- a/src/mpi/comm/commutil.c +++ b/src/mpi/comm/commutil.c @@ -302,6 +302,7 @@ int MPII_Comm_init(MPIR_Comm * comm_p) comm_p->node_roots_comm = NULL; comm_p->intranode_table = NULL; comm_p->internode_table = NULL; + comm_p->num_subgroups = 0; /* abstractions bleed a bit here... 
:(*/ comm_p->next_sched_tag = MPIR_FIRST_NBC_TAG; @@ -607,6 +608,22 @@ int MPIR_Comm_create_subcomms(MPIR_Comm * comm) comm->num_local = num_local; comm->num_external = num_external; + /* node */ +#define NODE_GROUP(field) comm->subgroups[MPIR_SUBGROUP_NODE].field + NODE_GROUP(rank) = local_rank; + NODE_GROUP(size) = num_local; + NODE_GROUP(proc_table) = MPL_malloc(num_local * sizeof(int), MPL_MEM_OTHER); + for (int i = 0; i < num_local; i++) { + NODE_GROUP(proc_table)[i] = local_procs[i]; + } +#define NODE_CROSS_GROUP(field) comm->subgroups[MPIR_SUBGROUP_NODE_CROSS].field + NODE_CROSS_GROUP(rank) = external_rank; + NODE_CROSS_GROUP(size) = num_external; + NODE_CROSS_GROUP(proc_table) = MPL_malloc(num_external * sizeof(int), MPL_MEM_OTHER); + for (int i = 0; i < num_external; i++) { + NODE_CROSS_GROUP(proc_table)[i] = external_procs[i]; + } + /* if the node_roots_comm and comm would be the same size, then creating * the second communicator is useless and wasteful. */ if (num_external == comm->remote_size) { @@ -731,6 +748,14 @@ int MPIR_Comm_commit(MPIR_Comm * comm) MPIR_FUNC_ENTER; + /* preset reserved subgroups */ + comm->num_subgroups = MPIR_SUBGROUP_NUM_RESERVED; + for (int i = 0; i < comm->num_subgroups; i++) { + comm->subgroups[i].rank = -1; + comm->subgroups[i].size = 0; + comm->subgroups[i].proc_table = NULL; + } + /* It's OK to relax these assertions, but we should do so very * intentionally. For now this function is the only place that we create * our hierarchy of communicators */ @@ -1162,6 +1187,12 @@ int MPIR_Comm_delete_internal(MPIR_Comm * comm_ptr) MPL_free(comm_ptr->intranode_table); MPL_free(comm_ptr->internode_table); + /* free subgroups */ + for (int i = 0; i < comm_ptr->num_subgroups; i++) { + MPL_free(comm_ptr->subgroups[i].proc_table); + } + comm_ptr->num_subgroups = 0; + MPIR_stream_comm_free(comm_ptr); /* Free the context value. 
This should come after freeing the From d9bf21b384a4112a85fc93091a02337bbe899df0 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 22 Aug 2024 10:14:26 -0500 Subject: [PATCH 06/27] coll: add macros to get rank/size with coll_group Group collectives will have non-trivial coll_group that alter the rank and size of the communicator. Thease macros and functions will facilitate it. --- src/include/mpir_coll.h | 46 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/include/mpir_coll.h b/src/include/mpir_coll.h index 4038a272017..41f223c5741 100644 --- a/src/include/mpir_coll.h +++ b/src/include/mpir_coll.h @@ -8,6 +8,52 @@ #include "coll_impl.h" #include "coll_algos.h" +#include "mpir_threadcomm.h" + +#ifdef ENABLE_THREADCOMM +#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \ + MPIR_Threadcomm *threadcomm = (comm)->threadcomm; \ + MPIR_Assert(threadcomm); \ + int intracomm_size = (comm)->local_size; \ + size_ = threadcomm->rank_offset_table[intracomm_size - 1]; \ + rank_ = MPIR_THREADCOMM_TID_TO_RANK(threadcomm, MPIR_threadcomm_get_tid(threadcomm)); \ + } while (0) +#else +#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \ + MPIR_Assert(0); \ + size_ = 0; \ + rank_ = -1; \ + } while (0) +#endif + +#define MPIR_COLL_RANK_SIZE(comm, coll_group, rank_, size_) do { \ + if (coll_group == MPIR_SUBGROUP_NONE) { \ + rank_ = (comm)->rank; \ + size_ = (comm)->local_size; \ + } else if (coll_group == MPIR_SUBGROUP_THREADCOMM) { \ + MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_); \ + } else { \ + rank_ = (comm)->subgroups[coll_group].rank; \ + size_ = (comm)->subgroups[coll_group].size; \ + } \ + } while (0) + +/* sometime it is convenient to just get the rank or size */ +static inline int MPIR_Coll_size(MPIR_Comm * comm, int coll_group) +{ + int rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + (void) rank; + return size; +} + +static inline int MPIR_Coll_rank(MPIR_Comm * comm, int coll_group) +{ + int 
rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + (void) size; + return rank; +} /* During init, not all algorithms are safe to use. For example, the csel * may not have been initialized. We define a set of fallback routines that From 105c7e0c46a0e8439c55b97df566ad28fe9afe04 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 16 Aug 2024 08:48:01 -0500 Subject: [PATCH 07/27] coll: add coll_group argument to coll interfaces Add coll_group, index to comm->subgroups[], to all collectives except neighborhood collectives. --- maint/gen_coll.py | 10 +++ maint/local_python/binding_c.py | 5 ++ src/binding/c/comm_api.txt | 2 +- src/include/mpir_coll.h | 14 ++-- src/include/mpir_op.h | 4 +- src/mpid/ch4/ch4_api.txt | 137 ++++++++++++++++---------------- 6 files changed, 96 insertions(+), 76 deletions(-) diff --git a/maint/gen_coll.py b/maint/gen_coll.py index 4e3a858caff..6327480f41f 100644 --- a/maint/gen_coll.py +++ b/maint/gen_coll.py @@ -644,6 +644,9 @@ def get_algo_extra_params(algo): # additional wrappers def get_algo_args(args, algo, kind): algo_args = args + if not re.match(r'i?neighbor_', algo['func-commkind']): + algo_args += ", coll_group" + if 'extra_params' in algo: algo_args += ", " + get_algo_extra_args(algo, kind) @@ -658,6 +661,9 @@ def get_algo_args(args, algo, kind): def get_algo_params(params, algo): algo_params = params + if not re.match(r'i?neighbor_', algo['func-commkind']): + algo_params += ", int coll_group" + if 'extra_params' in algo: algo_params += ", " + get_algo_extra_params(algo) @@ -681,6 +687,8 @@ def get_algo_name(algo): def get_func_params(params, name, kind): func_params = params + if not name.startswith('neighbor_'): + func_params += ", int coll_group" if kind == "blocking": if not name.startswith('neighbor_'): func_params += ", MPIR_Errflag_t errflag" @@ -701,6 +709,8 @@ def get_func_params(params, name, kind): def get_func_args(args, name, kind): func_args = args + if not name.startswith('neighbor_'): + func_args += ", 
coll_group" if kind == "blocking": if not name.startswith('neighbor_'): func_args += ", errflag" diff --git a/maint/local_python/binding_c.py b/maint/local_python/binding_c.py index f219ee4e194..f712618232a 100644 --- a/maint/local_python/binding_c.py +++ b/maint/local_python/binding_c.py @@ -1686,6 +1686,8 @@ def push_impl_decl(func, impl_name=None): if func['_impl_param_list']: params = ', '.join(func['_impl_param_list']) if func['dir'] == 'coll': + if not RE.match(r'MPI_(Ineighbor|Neighbor)', func['name']): + params = params.replace('comm_ptr', 'comm_ptr, int coll_group') # block collective use an extra errflag if not RE.match(r'MPI_(I.*|Neighbor.*|.*_init)$', func['name']): params = params + ", MPIR_Errflag_t errflag" @@ -1726,6 +1728,8 @@ def dump_body_coll(func): mpir_name = re.sub(r'^MPIX?_', 'MPIR_', func['name']) args = ", ".join(func['_impl_arg_list']) + if not RE.match(r'MPI_(Ineighbor|Neighbor)', func['name']): + args = args.replace('comm_ptr', 'comm_ptr, MPIR_SUBGROUP_NONE') if RE.match(r'MPI_(I.*|.*_init)$', func['name'], re.IGNORECASE): # non-blocking collectives @@ -1956,6 +1960,7 @@ def dump_body_reduce_equal(func): args = ", ".join(func['_impl_arg_list']) args = re.sub(r'recvbuf, ', '', args) args = re.sub(r'op, ', 'recvbuf, ', args) + args += ", MPIR_SUBGROUP_NONE" dump_line_with_break("mpi_errno = %s(%s);" % (impl, args)) dump_error_check("") diff --git a/src/binding/c/comm_api.txt b/src/binding/c/comm_api.txt index 58dbabf4169..a827c66e7ff 100644 --- a/src/binding/c/comm_api.txt +++ b/src/binding/c/comm_api.txt @@ -301,7 +301,7 @@ MPI_Intercomm_merge: * error to make */ acthigh = high ? 
1 : 0; /* Clamp high into 1 or 0 */ mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, &acthigh, 1, MPI_INT, - MPI_SUM, intercomm_ptr->local_comm, MPIR_ERR_NONE); + MPI_SUM, intercomm_ptr->local_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* acthigh must either == 0 or the size of the local comm */ if (acthigh != 0 && acthigh != intercomm_ptr->local_size) { diff --git a/src/include/mpir_coll.h b/src/include/mpir_coll.h index 41f223c5741..05eb02c213e 100644 --- a/src/include/mpir_coll.h +++ b/src/include/mpir_coll.h @@ -94,16 +94,20 @@ int MPIC_Waitall(int numreq, MPIR_Request * requests[], MPI_Status * statuses); int MPIR_Reduce_local(const void *inbuf, void *inoutbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op); -int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag); +int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag); /* TSP auto */ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_TSP_sched_t sched); + MPIR_Comm * comm, int coll_group, + MPIR_TSP_sched_t sched); int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched); -int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sched); + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_TSP_sched_t sched); +int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, int coll_group, + MPIR_TSP_sched_t sched); int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched); + MPIR_Comm * comm_ptr, int coll_group, + MPIR_TSP_sched_t sched); #endif /* MPIR_COLL_H_INCLUDED */ diff --git a/src/include/mpir_op.h b/src/include/mpir_op.h index ec04103b3e1..c940aacc06d 100644 --- 
a/src/include/mpir_op.h +++ b/src/include/mpir_op.h @@ -235,8 +235,8 @@ int MPIR_Op_is_commutative(MPI_Op); MPI_Datatype MPIR_Op_get_alt_datatype(MPI_Op op, MPI_Datatype datatype); int MPIR_Reduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype, - int *is_equal, int root, MPIR_Comm * comm_ptr); + int *is_equal, int root, MPIR_Comm * comm_ptr, int coll_group); int MPIR_Allreduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype, - int *is_equal, MPIR_Comm * comm_ptr); + int *is_equal, MPIR_Comm * comm_ptr, int coll_group); #endif /* MPIR_OP_H_INCLUDED */ diff --git a/src/mpid/ch4/ch4_api.txt b/src/mpid/ch4/ch4_api.txt index 6b1ae91a4d7..a9b329d9c89 100644 --- a/src/mpid/ch4/ch4_api.txt +++ b/src/mpid/ch4/ch4_api.txt @@ -284,56 +284,56 @@ Native API: rank_is_local : int NM*: target, comm mpi_barrier : int - NM*: comm, errflag - SHM*: comm, errflag + NM*: comm, coll_group, errflag + SHM*: comm, coll_group, errflag mpi_bcast : int - NM*: buffer, count, datatype, root, comm, errflag - SHM*: buffer, count, datatype, root, comm, errflag + NM*: buffer, count, datatype, root, comm, coll_group, errflag + SHM*: buffer, count, datatype, root, comm, coll_group, errflag mpi_allreduce : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, errflag - SHM*: sendbuf, recvbuf, count, datatype, op, comm, errflag + NM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag mpi_allgather : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, errflag - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, errflag mpi_allgatherv : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, errflag - SHM*: sendbuf, 
sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, coll_group, errflag mpi_scatter : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, errflag - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag mpi_scatterv : int - NM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, errflag - SHM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, errflag + NM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag + SHM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag mpi_gather : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, errflag - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, errflag mpi_gatherv : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, errflag - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, errflag mpi_alltoall : int - NM*: sendbuf, sendcount, sendtype, recvbuf, 
recvcount, recvtype, comm, errflag - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, errflag + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, errflag + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, errflag mpi_alltoallv : int - NM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, errflag - SHM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, errflag + NM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, errflag + SHM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, errflag mpi_alltoallw : int - NM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, errflag - SHM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, errflag + NM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, errflag + SHM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, errflag mpi_reduce : int - NM*: sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag - SHM*: sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag + NM*: sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, errflag mpi_reduce_scatter : int - NM*: sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, errflag - SHM*: sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, errflag + NM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, errflag mpi_reduce_scatter_block : int - NM*: sendbuf, recvbuf, recvcount, datatype, op, comm_ptr, errflag - SHM*: sendbuf, recvbuf, recvcount, datatype, op, 
comm_ptr, errflag + NM*: sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, errflag mpi_scan : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, errflag - SHM*: sendbuf, recvbuf, count, datatype, op, comm, errflag + NM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag mpi_exscan : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, errflag - SHM*: sendbuf, recvbuf, count, datatype, op, comm, errflag + NM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag mpi_neighbor_allgather : int NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm @@ -365,56 +365,56 @@ Native API: NM*: sendbuf, sendcounts, sdispls-2, sendtypes, recvbuf, recvcounts, rdispls-2, recvtypes, comm, req_p SHM*: sendbuf, sendcounts, sdispls-2, sendtypes, recvbuf, recvcounts, rdispls-2, recvtypes, comm, req_p mpi_ibarrier : int - NM*: comm, req_p - SHM*: comm, req_p + NM*: comm, coll_group, req_p + SHM*: comm, coll_group, req_p mpi_ibcast : int - NM*: buffer, count, datatype, root, comm, req_p - SHM*: buffer, count, datatype, root, comm, req_p + NM*: buffer, count, datatype, root, comm, coll_group, req_p + SHM*: buffer, count, datatype, root, comm, coll_group, req_p mpi_iallgather : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, req_p - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, req_p mpi_iallgatherv : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, req_p - SHM*: sendbuf, 
sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, coll_group, req_p mpi_iallreduce : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, req_p - SHM*: sendbuf, recvbuf, count, datatype, op, comm, req_p + NM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, req_p + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, req_p mpi_ialltoall : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, req_p - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, req_p mpi_ialltoallv : int - NM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, req_p - SHM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, req_p + NM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, req_p + SHM*: sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, req_p mpi_ialltoallw : int - NM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, req_p - SHM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, req_p + NM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, req_p + SHM*: sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, req_p mpi_iexscan : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, req_p - SHM*: sendbuf, recvbuf, count, datatype, op, comm, req_p + NM*: sendbuf, recvbuf, count, datatype, op, comm, 
coll_group, req_p + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, req_p mpi_igather : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, req_p - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p mpi_igatherv : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, req_p - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, req_p mpi_ireduce_scatter_block : int - NM*: sendbuf, recvbuf, recvcount, datatype, op, comm, req_p - SHM*: sendbuf, recvbuf, recvcount, datatype, op, comm, req_p + NM*: sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, req_p + SHM*: sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, req_p mpi_ireduce_scatter : int - NM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, req_p - SHM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, req_p + NM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, req_p + SHM*: sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, req_p mpi_ireduce : int - NM*: sendbuf, recvbuf, count, datatype, op, root, comm_ptr, req_p - SHM*: sendbuf, recvbuf, count, datatype, op, root, comm_ptr, req_p + NM*: sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req_p + SHM*: sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req_p mpi_iscan : int - NM*: sendbuf, recvbuf, count, datatype, op, comm, req_p - SHM*: sendbuf, recvbuf, count, datatype, op, comm, req_p + NM*: sendbuf, recvbuf, count, datatype, op, comm, 
coll_group, req_p + SHM*: sendbuf, recvbuf, count, datatype, op, comm, coll_group, req_p mpi_iscatter : int - NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, req_p - SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, req_p + NM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p + SHM*: sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p mpi_iscatterv : int - NM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, req_p - SHM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, req_p + NM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p + SHM*: sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req_p mpi_type_commit_hook : int NM : datatype_p SHM : type @@ -449,6 +449,7 @@ PARAM: buf: const void * buf-2: void * buffer: void * + coll_group: int comm: MPIR_Comm * comm_ptr: MPIR_Comm * compare_addr: const void * From 3501b573996d03124dca2fbcf41dd75e72048f77 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 16 Aug 2024 15:40:04 -0500 Subject: [PATCH 08/27] continue: add coll_group to collective interfaces --- .../coll/algorithms/recexchalgo/recexchalgo.h | 12 +- src/mpi/coll/algorithms/treealgo/treeutil.c | 2 +- src/mpi/coll/allgather/allgather_allcomm_nb.c | 4 +- ...llgather_inter_local_gather_remote_bcast.c | 13 +- .../coll/allgather/allgather_intra_brucks.c | 3 +- .../coll/allgather/allgather_intra_k_brucks.c | 4 +- .../coll/allgather/allgather_intra_recexch.c | 2 +- .../allgather_intra_recursive_doubling.c | 3 +- src/mpi/coll/allgather/allgather_intra_ring.c | 3 +- .../coll/allgatherv/allgatherv_allcomm_nb.c | 5 +- ...lgatherv_inter_remote_gather_local_bcast.c | 13 +- .../coll/allgatherv/allgatherv_intra_brucks.c | 2 +- .../allgatherv_intra_recursive_doubling.c | 3 +- 
.../coll/allgatherv/allgatherv_intra_ring.c | 3 +- src/mpi/coll/allreduce/allreduce_allcomm_nb.c | 5 +- .../allreduce_inter_reduce_exchange_bcast.c | 8 +- ...lreduce_intra_k_reduce_scatter_allgather.c | 5 +- .../coll/allreduce/allreduce_intra_recexch.c | 4 +- .../allreduce_intra_recursive_doubling.c | 3 +- ...allreduce_intra_reduce_scatter_allgather.c | 3 +- src/mpi/coll/allreduce/allreduce_intra_ring.c | 4 +- src/mpi/coll/allreduce/allreduce_intra_smp.c | 13 +- src/mpi/coll/allreduce/allreduce_intra_tree.c | 2 +- src/mpi/coll/alltoall/alltoall_allcomm_nb.c | 4 +- .../alltoall_inter_pairwise_exchange.c | 2 +- src/mpi/coll/alltoall/alltoall_intra_brucks.c | 3 +- .../coll/alltoall/alltoall_intra_k_brucks.c | 2 +- .../coll/alltoall/alltoall_intra_pairwise.c | 2 +- ...alltoall_intra_pairwise_sendrecv_replace.c | 3 +- .../coll/alltoall/alltoall_intra_scattered.c | 2 +- src/mpi/coll/alltoallv/alltoallv_allcomm_nb.c | 5 +- .../alltoallv_inter_pairwise_exchange.c | 3 +- ...lltoallv_intra_pairwise_sendrecv_replace.c | 3 +- .../alltoallv/alltoallv_intra_scattered.c | 2 +- src/mpi/coll/alltoallw/alltoallw_allcomm_nb.c | 4 +- .../alltoallw_inter_pairwise_exchange.c | 3 +- ...lltoallw_intra_pairwise_sendrecv_replace.c | 3 +- .../alltoallw/alltoallw_intra_scattered.c | 2 +- src/mpi/coll/barrier/barrier_allcomm_nb.c | 4 +- src/mpi/coll/barrier/barrier_inter_bcast.c | 12 +- .../barrier/barrier_intra_k_dissemination.c | 7 +- src/mpi/coll/barrier/barrier_intra_recexch.c | 4 +- src/mpi/coll/barrier/barrier_intra_smp.c | 8 +- src/mpi/coll/bcast/bcast_allcomm_nb.c | 4 +- .../bcast_inter_remote_send_local_bcast.c | 6 +- src/mpi/coll/bcast/bcast_intra_binomial.c | 3 +- .../coll/bcast/bcast_intra_pipelined_tree.c | 2 +- ...tra_scatter_recursive_doubling_allgather.c | 2 +- .../bcast_intra_scatter_ring_allgather.c | 3 +- src/mpi/coll/bcast/bcast_intra_smp.c | 19 +- src/mpi/coll/bcast/bcast_intra_tree.c | 2 +- src/mpi/coll/exscan/exscan_allcomm_nb.c | 4 +- 
.../exscan/exscan_intra_recursive_doubling.c | 3 +- src/mpi/coll/gather/gather_allcomm_nb.c | 4 +- src/mpi/coll/gather/gather_inter_linear.c | 2 +- .../gather_inter_local_gather_remote_send.c | 5 +- src/mpi/coll/gather/gather_intra_binomial.c | 2 +- src/mpi/coll/gatherv/gatherv_allcomm_linear.c | 3 +- src/mpi/coll/gatherv/gatherv_allcomm_nb.c | 4 +- ...er_inter_sched_local_gather_remote_bcast.c | 13 +- .../iallgather_intra_sched_brucks.c | 3 +- ...allgather_intra_sched_recursive_doubling.c | 3 +- .../iallgather/iallgather_intra_sched_ring.c | 3 +- .../coll/iallgather/iallgather_tsp_brucks.c | 3 +- .../coll/iallgather/iallgather_tsp_recexch.c | 22 +- src/mpi/coll/iallgather/iallgather_tsp_ring.c | 2 +- ...rv_inter_sched_remote_gather_local_bcast.c | 17 +- .../iallgatherv_intra_sched_brucks.c | 3 +- ...llgatherv_intra_sched_recursive_doubling.c | 3 +- .../iallgatherv_intra_sched_ring.c | 3 +- .../coll/iallgatherv/iallgatherv_tsp_brucks.c | 2 +- .../iallgatherv/iallgatherv_tsp_recexch.c | 21 +- .../coll/iallgatherv/iallgatherv_tsp_ring.c | 2 +- ...ce_inter_sched_remote_reduce_local_bcast.c | 16 +- .../iallreduce/iallreduce_intra_sched_naive.c | 10 +- ...allreduce_intra_sched_recursive_doubling.c | 3 +- ...uce_intra_sched_reduce_scatter_allgather.c | 2 +- .../iallreduce/iallreduce_intra_sched_smp.c | 18 +- src/mpi/coll/iallreduce/iallreduce_tsp_auto.c | 24 +- .../coll/iallreduce/iallreduce_tsp_recexch.c | 6 +- ...ecexch_reduce_scatter_recexch_allgatherv.c | 9 +- ...iallreduce_tsp_recursive_exchange_common.c | 2 +- ...iallreduce_tsp_recursive_exchange_common.h | 2 +- src/mpi/coll/iallreduce/iallreduce_tsp_ring.c | 4 +- src/mpi/coll/iallreduce/iallreduce_tsp_tree.c | 5 +- .../ialltoall_inter_sched_pairwise_exchange.c | 3 +- .../ialltoall/ialltoall_intra_sched_brucks.c | 3 +- .../ialltoall/ialltoall_intra_sched_inplace.c | 3 +- .../ialltoall_intra_sched_pairwise.c | 3 +- .../ialltoall_intra_sched_permuted_sendrecv.c | 3 +- src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c | 2 
+- src/mpi/coll/ialltoall/ialltoall_tsp_ring.c | 2 +- .../coll/ialltoall/ialltoall_tsp_scattered.c | 4 +- ...ialltoallv_inter_sched_pairwise_exchange.c | 3 +- .../ialltoallv_intra_sched_blocked.c | 2 +- .../ialltoallv_intra_sched_inplace.c | 2 +- .../coll/ialltoallv/ialltoallv_tsp_blocked.c | 3 +- .../coll/ialltoallv/ialltoallv_tsp_inplace.c | 3 +- .../ialltoallv/ialltoallv_tsp_scattered.c | 4 +- ...ialltoallw_inter_sched_pairwise_exchange.c | 3 +- .../ialltoallw_intra_sched_blocked.c | 2 +- .../ialltoallw_intra_sched_inplace.c | 2 +- .../coll/ialltoallw/ialltoallw_tsp_blocked.c | 2 +- .../coll/ialltoallw/ialltoallw_tsp_inplace.c | 2 +- .../ibarrier/ibarrier_inter_sched_bcast.c | 12 +- .../ibarrier_intra_sched_recursive_doubling.c | 3 +- .../coll/ibarrier/ibarrier_intra_tsp_dissem.c | 3 +- .../ibarrier/ibarrier_intra_tsp_recexch.c | 5 +- src/mpi/coll/ibarrier/ibarrier_tsp_auto.c | 14 +- src/mpi/coll/ibcast/ibcast_inter_sched_flat.c | 5 +- .../coll/ibcast/ibcast_intra_sched_binomial.c | 2 +- ...hed_scatter_recursive_doubling_allgather.c | 2 +- ...bcast_intra_sched_scatter_ring_allgather.c | 3 +- src/mpi/coll/ibcast/ibcast_intra_sched_smp.c | 7 +- src/mpi/coll/ibcast/ibcast_tsp_auto.c | 32 +- .../ibcast/ibcast_tsp_scatterv_allgatherv.c | 12 +- .../ibcast_tsp_scatterv_ring_allgatherv.c | 6 +- src/mpi/coll/ibcast/ibcast_tsp_tree.c | 4 +- .../iexscan_intra_sched_recursive_doubling.c | 3 +- .../coll/igather/igather_inter_sched_long.c | 2 +- .../coll/igather/igather_inter_sched_short.c | 4 +- .../igather/igather_intra_sched_binomial.c | 7 +- src/mpi/coll/igather/igather_tsp_tree.c | 2 +- .../igatherv/igatherv_allcomm_sched_linear.c | 2 +- src/mpi/coll/igatherv/igatherv_tsp_linear.c | 2 +- ...uce_inter_sched_local_reduce_remote_send.c | 4 +- .../ireduce/ireduce_intra_sched_binomial.c | 2 +- ...reduce_intra_sched_reduce_scatter_gather.c | 3 +- .../coll/ireduce/ireduce_intra_sched_smp.c | 17 +- src/mpi/coll/ireduce/ireduce_tsp_auto.c | 21 +- 
src/mpi/coll/ireduce/ireduce_tsp_tree.c | 4 +- ...inter_sched_remote_reduce_local_scatterv.c | 12 +- ...educe_scatter_intra_sched_noncommutative.c | 3 +- .../ireduce_scatter_intra_sched_pairwise.c | 3 +- ...e_scatter_intra_sched_recursive_doubling.c | 3 +- ...ce_scatter_intra_sched_recursive_halving.c | 3 +- .../ireduce_scatter_tsp_recexch.c | 21 +- ...inter_sched_remote_reduce_local_scatterv.c | 12 +- ...scatter_block_intra_sched_noncommutative.c | 2 +- ...educe_scatter_block_intra_sched_pairwise.c | 3 +- ...ter_block_intra_sched_recursive_doubling.c | 3 +- ...tter_block_intra_sched_recursive_halving.c | 3 +- .../ireduce_scatter_block_tsp_recexch.c | 4 +- .../iscan_intra_sched_recursive_doubling.c | 2 +- src/mpi/coll/iscan/iscan_intra_sched_smp.c | 12 +- .../coll/iscan/iscan_tsp_recursive_doubling.c | 3 +- .../iscatter/iscatter_inter_sched_linear.c | 2 +- ...er_inter_sched_remote_send_local_scatter.c | 5 +- .../iscatter/iscatter_intra_sched_binomial.c | 2 +- src/mpi/coll/iscatter/iscatter_tsp_tree.c | 2 +- .../iscatterv_allcomm_sched_linear.c | 3 +- src/mpi/coll/iscatterv/iscatterv_tsp_linear.c | 2 +- src/mpi/coll/mpir_coll_sched_auto.c | 212 ++++++----- src/mpi/coll/op/opequal.c | 13 +- src/mpi/coll/reduce/reduce_allcomm_nb.c | 5 +- .../reduce_inter_local_reduce_remote_send.c | 6 +- src/mpi/coll/reduce/reduce_intra_binomial.c | 3 +- .../reduce_intra_reduce_scatter_gather.c | 3 +- src/mpi/coll/reduce/reduce_intra_smp.c | 12 +- .../reduce_scatter_allcomm_nb.c | 5 +- ...catter_inter_remote_reduce_local_scatter.c | 12 +- .../reduce_scatter_intra_noncommutative.c | 2 +- .../reduce_scatter_intra_pairwise.c | 3 +- .../reduce_scatter_intra_recursive_doubling.c | 2 +- .../reduce_scatter_intra_recursive_halving.c | 2 +- .../reduce_scatter_block_allcomm_nb.c | 5 +- ..._block_inter_remote_reduce_local_scatter.c | 11 +- ...educe_scatter_block_intra_noncommutative.c | 3 +- .../reduce_scatter_block_intra_pairwise.c | 3 +- ...e_scatter_block_intra_recursive_doubling.c | 3 +- 
...ce_scatter_block_intra_recursive_halving.c | 3 +- src/mpi/coll/scan/scan_allcomm_nb.c | 4 +- .../coll/scan/scan_intra_recursive_doubling.c | 3 +- src/mpi/coll/scan/scan_intra_smp.c | 12 +- src/mpi/coll/scatter/scatter_allcomm_nb.c | 4 +- src/mpi/coll/scatter/scatter_inter_linear.c | 2 +- .../scatter_inter_remote_send_local_scatter.c | 4 +- src/mpi/coll/scatter/scatter_intra_binomial.c | 2 +- .../coll/scatterv/scatterv_allcomm_linear.c | 2 +- src/mpi/coll/scatterv/scatterv_allcomm_nb.c | 4 +- src/mpi/comm/comm_impl.c | 24 +- src/mpi/comm/comm_split.c | 9 +- src/mpi/comm/comm_split_type_nbhd.c | 6 +- src/mpi/comm/commutil.c | 10 +- src/mpi/comm/contextid.c | 14 +- src/mpi/stream/stream_enqueue.c | 5 +- src/mpi/stream/stream_impl.c | 8 +- src/mpi/threadcomm/threadcomm_coll_impl.c | 41 ++- src/mpi/threadcomm/threadcomm_impl.c | 5 +- src/mpi/topo/dist_graph_create.c | 3 +- .../ch3/channels/nemesis/src/ch3_win_fns.c | 32 +- src/mpid/ch3/include/mpid_coll.h | 136 +++---- src/mpid/ch3/include/mpidpre.h | 34 +- src/mpid/ch3/src/ch3u_comm_spawn_multiple.c | 6 +- src/mpid/ch3/src/ch3u_port.c | 26 +- src/mpid/ch3/src/ch3u_rma_sync.c | 12 +- src/mpid/ch3/src/ch3u_win_fns.c | 2 +- src/mpid/ch3/src/mpid_startall.c | 68 ++-- src/mpid/ch3/src/mpid_vc.c | 10 +- src/mpid/ch3/src/mpidi_rma.c | 2 +- src/mpid/ch4/include/mpidch4.h | 195 +++++----- .../netmod/include/netmod_am_fallback_coll.h | 162 ++++---- .../ch4/netmod/ofi/coll/ofi_bcast_tree_rma.h | 10 +- .../netmod/ofi/coll/ofi_bcast_tree_tagged.h | 14 +- src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h | 5 +- src/mpid/ch4/netmod/ofi/init_addrxchg.c | 8 +- src/mpid/ch4/netmod/ofi/ofi_coll.h | 182 +++++---- src/mpid/ch4/netmod/ofi/ofi_comm.c | 3 +- src/mpid/ch4/netmod/ofi/ofi_init.c | 3 +- src/mpid/ch4/netmod/ofi/ofi_win.c | 14 +- src/mpid/ch4/netmod/ucx/ucx_coll.h | 170 +++++---- src/mpid/ch4/netmod/ucx/ucx_init.c | 3 +- src/mpid/ch4/netmod/ucx/ucx_win.c | 8 +- src/mpid/ch4/shm/ipc/src/ipc_win.c | 3 +- 
src/mpid/ch4/shm/posix/posix_coll.h | 192 ++++++---- src/mpid/ch4/shm/posix/posix_coll_gpu_ipc.h | 37 +- .../shm/posix/posix_coll_nb_release_gather.h | 3 +- .../ch4/shm/posix/posix_coll_release_gather.h | 14 +- src/mpid/ch4/shm/posix/posix_init.c | 3 +- .../posix/release_gather/nb_release_gather.c | 12 +- .../shm/posix/release_gather/release_gather.c | 12 +- src/mpid/ch4/shm/src/shm_am_fallback_coll.h | 157 ++++---- src/mpid/ch4/shm/src/shm_coll.h | 160 ++++---- src/mpid/ch4/shm/src/topotree.c | 10 +- src/mpid/ch4/src/ch4_coll.h | 346 ++++++++++-------- src/mpid/ch4/src/ch4_coll_impl.h | 240 +++++++----- src/mpid/ch4/src/ch4_comm.c | 24 +- src/mpid/ch4/src/ch4_init.c | 2 +- src/mpid/ch4/src/ch4_persist.c | 100 +++-- src/mpid/ch4/src/ch4_spawn.c | 14 +- src/mpid/ch4/src/mpidig_win.c | 18 +- src/mpid/ch4/src/mpidig_win.h | 2 +- src/mpid/common/bc/mpidu_bc.c | 2 +- src/mpid/common/shm/mpidu_shm_alloc.c | 34 +- src/util/mpir_nodemap.c | 4 +- 235 files changed, 2161 insertions(+), 1608 deletions(-) diff --git a/src/mpi/coll/algorithms/recexchalgo/recexchalgo.h b/src/mpi/coll/algorithms/recexchalgo/recexchalgo.h index 8b8fb8fc40d..eab92d1decb 100644 --- a/src/mpi/coll/algorithms/recexchalgo/recexchalgo.h +++ b/src/mpi/coll/algorithms/recexchalgo/recexchalgo.h @@ -27,15 +27,15 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(int step1_sendto, int step2_n size_t recv_extent, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, int is_dist_halving, MPIR_Comm * comm, - MPIR_TSP_sched_t sched); + int coll_group, MPIR_TSP_sched_t sched); int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void *tmp_recvbuf, const MPI_Aint * recvcounts, MPI_Aint * displs, MPI_Datatype datatype, MPI_Op op, size_t extent, int tag, - MPIR_Comm * comm, int k, int is_dist_halving, - int step2_nphases, int **step2_nbrs, - int rank, int nranks, int sink_id, - int is_out_vtcs, int *reduce_id_, - MPIR_TSP_sched_t sched); + MPIR_Comm * comm, int 
coll_group, int k, + int is_dist_halving, int step2_nphases, + int **step2_nbrs, int rank, int nranks, + int sink_id, int is_out_vtcs, + int *reduce_id_, MPIR_TSP_sched_t sched); #endif /* RECEXCHALGO_H_INCLUDED */ diff --git a/src/mpi/coll/algorithms/treealgo/treeutil.c b/src/mpi/coll/algorithms/treealgo/treeutil.c index ad9b003c369..7b522332b49 100644 --- a/src/mpi/coll/algorithms/treealgo/treeutil.c +++ b/src/mpi/coll/algorithms/treealgo/treeutil.c @@ -758,7 +758,7 @@ int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root, /* Do an allgather to know the current num_children on each rank */ MPIR_Errflag_t errflag = MPIR_ERR_NONE; MPIR_Allgather_impl(&(ct->num_children), 1, MPI_INT, num_childrens, 1, MPI_INT, - comm, errflag); + comm, MPIR_SUBGROUP_NONE, errflag); if (mpi_errno) { goto fn_fail; } diff --git a/src/mpi/coll/allgather/allgather_allcomm_nb.c b/src/mpi/coll/allgather/allgather_allcomm_nb.c index 37800564381..8a818710365 100644 --- a/src/mpi/coll/allgather/allgather_allcomm_nb.c +++ b/src/mpi/coll/allgather/allgather_allcomm_nb.c @@ -7,7 +7,7 @@ int MPIR_Allgather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -15,7 +15,7 @@ int MPIR_Allgather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datat /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Iallgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm_ptr, - &req_ptr); + coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/allgather/allgather_inter_local_gather_remote_bcast.c b/src/mpi/coll/allgather/allgather_inter_local_gather_remote_bcast.c index fedb0ff358e..4b0932aa5db 100644 --- 
a/src/mpi/coll/allgather/allgather_inter_local_gather_remote_bcast.c +++ b/src/mpi/coll/allgather/allgather_inter_local_gather_remote_bcast.c @@ -15,7 +15,8 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS, root; MPI_Aint sendtype_sz; @@ -47,7 +48,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint if (sendcount != 0) { mpi_errno = MPIR_Gather(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz, - MPI_BYTE, 0, newcomm_ptr, errflag); + MPI_BYTE, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -58,7 +59,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint if (sendcount != 0) { root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size, - MPI_BYTE, root, comm_ptr, errflag); + MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -66,7 +67,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint if (recvcount != 0) { root = 0; mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size, - recvtype, root, comm_ptr, errflag); + recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { @@ -74,7 +75,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint if (recvcount != 0) { root = 0; mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size, - recvtype, root, comm_ptr, errflag); + recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -82,7 +83,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint if (sendcount != 0) { root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size, - MPI_BYTE, root, comm_ptr, errflag); + MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/allgather/allgather_intra_brucks.c b/src/mpi/coll/allgather/allgather_intra_brucks.c index 4e22866ebeb..c3bd63bd0aa 100644 --- a/src/mpi/coll/allgather/allgather_intra_brucks.c +++ b/src/mpi/coll/allgather/allgather_intra_brucks.c @@ -19,7 +19,8 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/allgather/allgather_intra_k_brucks.c b/src/mpi/coll/allgather/allgather_intra_k_brucks.c index 010a5a8567d..6e86b7a23cd 100644 --- a/src/mpi/coll/allgather/allgather_intra_k_brucks.c +++ b/src/mpi/coll/allgather/allgather_intra_k_brucks.c @@ -22,8 +22,8 @@ int MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm, int k, - MPIR_Errflag_t errflag) + MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm, + int coll_group, int k, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int i, j; diff --git a/src/mpi/coll/allgather/allgather_intra_recexch.c b/src/mpi/coll/allgather/allgather_intra_recexch.c index 20a2f0501b7..2e4cba4ef66 100644 --- a/src/mpi/coll/allgather/allgather_intra_recexch.c +++ b/src/mpi/coll/allgather/allgather_intra_recexch.c @@ -18,7 +18,7 @@ * */ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, int recexch_type, int 
k, int single_phase_recv, MPIR_Errflag_t errflag) { diff --git a/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c b/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c index 3dd37b22ea5..6a4316fbf30 100644 --- a/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c +++ b/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c @@ -23,7 +23,8 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/allgather/allgather_intra_ring.c b/src/mpi/coll/allgather/allgather_intra_ring.c index 12f19b0b427..88df120075d 100644 --- a/src/mpi/coll/allgather/allgather_intra_ring.c +++ b/src/mpi/coll/allgather/allgather_intra_ring.c @@ -25,7 +25,8 @@ int MPIR_Allgather_intra_ring(const void *sendbuf, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/allgatherv/allgatherv_allcomm_nb.c b/src/mpi/coll/allgatherv/allgatherv_allcomm_nb.c index 1a7fc1430b3..c2ecc46466b 100644 --- a/src/mpi/coll/allgatherv/allgatherv_allcomm_nb.c +++ b/src/mpi/coll/allgatherv/allgatherv_allcomm_nb.c @@ -7,7 +7,8 @@ int MPIR_Allgatherv_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -15,7 +16,7 @@ int MPIR_Allgatherv_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, 
MPI_Data /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Iallgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, - comm_ptr, &req_ptr); + comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/allgatherv/allgatherv_inter_remote_gather_local_bcast.c b/src/mpi/coll/allgatherv/allgatherv_inter_remote_gather_local_bcast.c index a5c05fae9a2..4f4f1b3a15f 100644 --- a/src/mpi/coll/allgatherv/allgatherv_inter_remote_gather_local_bcast.c +++ b/src/mpi/coll/allgatherv/allgatherv_inter_remote_gather_local_bcast.c @@ -19,7 +19,8 @@ int MPIR_Allgatherv_inter_remote_gather_local_bcast(const void *sendbuf, MPI_Ain MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int remote_size, mpi_errno, root, rank; MPIR_Comm *newcomm_ptr = NULL; @@ -34,23 +35,23 @@ int MPIR_Allgatherv_inter_remote_gather_local_bcast(const void *sendbuf, MPI_Ain /* gatherv from right group */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* gatherv to right group */ root = 0; mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* gatherv to left group */ root = 0; mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* gatherv from left group */ root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -71,7 +72,7 @@ int MPIR_Allgatherv_inter_remote_gather_local_bcast(const void *sendbuf, MPI_Ain mpi_errno = MPIR_Type_commit_impl(&newtype); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Bcast_allcomm_auto(recvbuf, 1, newtype, 0, newcomm_ptr, errflag); + mpi_errno = MPIR_Bcast_allcomm_auto(recvbuf, 1, newtype, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_Type_free_impl(&newtype); diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c index 99f867f732d..1769221ac5d 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c @@ -23,7 +23,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, rank, j, i; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c b/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c index d083b43e411..390b81679e1 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c @@ -25,7 +25,8 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank, j, i; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_ring.c b/src/mpi/coll/allgatherv/allgatherv_intra_ring.c 
index 016c35b77c7..10fee8a9cb6 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_ring.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_ring.c @@ -28,7 +28,8 @@ int MPIR_Allgatherv_intra_ring(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank, i, left, right; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/allreduce/allreduce_allcomm_nb.c b/src/mpi/coll/allreduce/allreduce_allcomm_nb.c index c076b2bcd8e..bc7decc7290 100644 --- a/src/mpi/coll/allreduce/allreduce_allcomm_nb.c +++ b/src/mpi/coll/allreduce/allreduce_allcomm_nb.c @@ -7,13 +7,14 @@ int MPIR_Allreduce_allcomm_nb(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ - mpi_errno = MPIR_Iallreduce(sendbuf, recvbuf, count, datatype, op, comm_ptr, &req_ptr); + mpi_errno = + MPIR_Iallreduce(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c b/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c index ab199653c20..ae4839a3470 100644 --- a/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c +++ b/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c @@ -15,7 +15,8 @@ int MPIR_Allreduce_inter_reduce_exchange_bcast(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPI_Aint true_extent, true_lb, extent; @@ 
-39,7 +40,8 @@ int MPIR_Allreduce_inter_reduce_exchange_bcast(const void *sendbuf, void *recvbu newcomm_ptr = comm_ptr->local_comm; /* Do a local reduce on this intracommunicator */ - mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, op, 0, newcomm_ptr, errflag); + mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, op, 0, newcomm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* Do a exchange between local and remote rank 0 on this intercommunicator */ @@ -51,7 +53,7 @@ int MPIR_Allreduce_inter_reduce_exchange_bcast(const void *sendbuf, void *recvbu } /* Do a local broadcast on this intracommunicator */ - mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, newcomm_ptr, errflag); + mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c index cb46472865c..df552c6d85e 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c +++ b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c @@ -14,8 +14,9 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, int k, - int single_phase_recv, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm, int coll_group, + int k, int single_phase_recv, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int rank, nranks, nbr; diff --git a/src/mpi/coll/allreduce/allreduce_intra_recexch.c b/src/mpi/coll/allreduce/allreduce_intra_recexch.c index 5503a54ec91..adf6988a40f 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recexch.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recexch.c @@ -17,8 +17,8 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, int k, int 
single_phase_recv, - MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm, int coll_group, int k, + int single_phase_recv, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int is_commutative, rank, nranks, nbr, myidx; diff --git a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c index 896a8d5359a..a87d75cd553 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c @@ -22,7 +22,8 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { MPIR_CHKLMEM_DECL(1); int comm_size, rank; diff --git a/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c b/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c index 1371b31de34..5d7c3320bd4 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c +++ b/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c @@ -43,7 +43,8 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { MPIR_CHKLMEM_DECL(3); int comm_size, rank; diff --git a/src/mpi/coll/allreduce/allreduce_intra_ring.c b/src/mpi/coll/allreduce/allreduce_intra_ring.c index ca87f50b9ce..f05b393060c 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_ring.c +++ b/src/mpi/coll/allreduce/allreduce_intra_ring.c @@ -11,7 +11,7 @@ int MPIR_Allreduce_intra_ring(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int i, 
src, dst; @@ -96,7 +96,7 @@ int MPIR_Allreduce_intra_ring(const void *sendbuf, void *recvbuf, MPI_Aint count /* Phase 3: Allgatherv ring, so everyone has the reduced data */ mpi_errno = MPIR_Allgatherv_intra_ring(MPI_IN_PLACE, -1, MPI_DATATYPE_NULL, recvbuf, cnts, - displs, datatype, comm, errflag); + displs, datatype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); MPL_free(cnts); diff --git a/src/mpi/coll/allreduce/allreduce_intra_smp.c b/src/mpi/coll/allreduce/allreduce_intra_smp.c index 24d1ef57a47..4ad6920dcbf 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_smp.c +++ b/src/mpi/coll/allreduce/allreduce_intra_smp.c @@ -6,7 +6,7 @@ #include "mpiimpl.h" int MPIR_Allreduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -22,11 +22,13 @@ int MPIR_Allreduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, * allreduce is in recvbuf. Pass that as the sendbuf to reduce. 
*/ mpi_errno = - MPIR_Reduce(recvbuf, NULL, count, datatype, op, 0, comm_ptr->node_comm, errflag); + MPIR_Reduce(recvbuf, NULL, count, datatype, op, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = - MPIR_Reduce(sendbuf, recvbuf, count, datatype, op, 0, comm_ptr->node_comm, errflag); + MPIR_Reduce(sendbuf, recvbuf, count, datatype, op, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { @@ -41,13 +43,14 @@ int MPIR_Allreduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, if (comm_ptr->node_roots_comm != NULL) { mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, recvbuf, count, datatype, op, comm_ptr->node_roots_comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* now broadcast the result among local processes */ if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } goto fn_exit; diff --git a/src/mpi/coll/allreduce/allreduce_intra_tree.c b/src/mpi/coll/allreduce/allreduce_intra_tree.c index 7a6bb4f9709..bb09fcf09eb 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_tree.c +++ b/src/mpi/coll/allreduce/allreduce_intra_tree.c @@ -14,7 +14,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, int tree_type, int k, int chunk_size, int buffer_per_child, MPIR_Errflag_t errflag) { diff --git a/src/mpi/coll/alltoall/alltoall_allcomm_nb.c b/src/mpi/coll/alltoall/alltoall_allcomm_nb.c index ecb74cd135f..ec3a545e412 100644 --- a/src/mpi/coll/alltoall/alltoall_allcomm_nb.c +++ b/src/mpi/coll/alltoall/alltoall_allcomm_nb.c @@ -7,7 +7,7 @@ int MPIR_Alltoall_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, 
MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -15,7 +15,7 @@ int MPIR_Alltoall_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Dataty /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Ialltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm_ptr, - &req_ptr); + coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c b/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c index 932d7965c50..1f6c59b0eed 100644 --- a/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c +++ b/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c @@ -19,7 +19,7 @@ int MPIR_Alltoall_inter_pairwise_exchange(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int local_size, remote_size, max_size, i; MPI_Aint sendtype_extent, recvtype_extent; diff --git a/src/mpi/coll/alltoall/alltoall_intra_brucks.c b/src/mpi/coll/alltoall/alltoall_intra_brucks.c index 2aea8b1860d..e26df86f4fa 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_brucks.c +++ b/src/mpi/coll/alltoall/alltoall_intra_brucks.c @@ -23,7 +23,8 @@ int MPIR_Alltoall_intra_brucks(const void *sendbuf, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, i, pof2; MPI_Aint sendtype_extent, recvtype_extent; diff --git a/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c b/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c index 9286b2a5ee4..33ae4224dc4 100644 --- 
a/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c +++ b/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c @@ -108,7 +108,7 @@ int MPIR_Alltoall_intra_k_brucks(const void *sendbuf, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcnt, - MPI_Datatype recvtype, MPIR_Comm * comm, int k, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, int k, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/alltoall/alltoall_intra_pairwise.c b/src/mpi/coll/alltoall/alltoall_intra_pairwise.c index 28dcd7ed7d6..829b3a9f1f3 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_pairwise.c +++ b/src/mpi/coll/alltoall/alltoall_intra_pairwise.c @@ -28,7 +28,7 @@ int MPIR_Alltoall_intra_pairwise(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, i; MPI_Aint sendtype_extent, recvtype_extent; diff --git a/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c index 22604189e30..2746fe79771 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c @@ -25,7 +25,8 @@ int MPIR_Alltoall_intra_pairwise_sendrecv_replace(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, i, j; MPI_Aint recvtype_extent; diff --git a/src/mpi/coll/alltoall/alltoall_intra_scattered.c b/src/mpi/coll/alltoall/alltoall_intra_scattered.c index f1986f99ce6..f4437bbe9ca 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_scattered.c +++ b/src/mpi/coll/alltoall/alltoall_intra_scattered.c @@ -33,7 +33,7 @@ int MPIR_Alltoall_intra_scattered(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - 
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, i; MPI_Aint sendtype_extent, recvtype_extent; diff --git a/src/mpi/coll/alltoallv/alltoallv_allcomm_nb.c b/src/mpi/coll/alltoallv/alltoallv_allcomm_nb.c index 40854e91c20..288427a910a 100644 --- a/src/mpi/coll/alltoallv/alltoallv_allcomm_nb.c +++ b/src/mpi/coll/alltoallv/alltoallv_allcomm_nb.c @@ -8,7 +8,8 @@ int MPIR_Alltoallv_allcomm_nb(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -16,7 +17,7 @@ int MPIR_Alltoallv_allcomm_nb(const void *sendbuf, const MPI_Aint * sendcounts, /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Ialltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, - recvtype, comm_ptr, &req_ptr); + recvtype, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c b/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c index ea4cb5d1962..abc3b6eb39b 100644 --- a/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c +++ b/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c @@ -23,7 +23,8 @@ int MPIR_Alltoallv_inter_pairwise_exchange(const void *sendbuf, const MPI_Aint * const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int local_size, remote_size, max_size, i; MPI_Aint send_extent, recv_extent; diff --git 
a/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c index 7f6cc8d4814..6835bddc3cb 100644 --- a/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c @@ -22,7 +22,8 @@ int MPIR_Alltoallv_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, i, j; MPI_Aint recv_extent; diff --git a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c index 4adaeb83681..6fba549f28a 100644 --- a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c +++ b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c @@ -24,7 +24,7 @@ int MPIR_Alltoallv_intra_scattered(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, i; diff --git a/src/mpi/coll/alltoallw/alltoallw_allcomm_nb.c b/src/mpi/coll/alltoallw/alltoallw_allcomm_nb.c index e3e55da8a89..ca12d33031a 100644 --- a/src/mpi/coll/alltoallw/alltoallw_allcomm_nb.c +++ b/src/mpi/coll/alltoallw/alltoallw_allcomm_nb.c @@ -8,7 +8,7 @@ int MPIR_Alltoallw_allcomm_nb(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], - const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr, + const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int 
mpi_errno = MPI_SUCCESS; @@ -17,7 +17,7 @@ int MPIR_Alltoallw_allcomm_nb(const void *sendbuf, const MPI_Aint sendcounts[], /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Ialltoallw(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, - recvtypes, comm_ptr, &req_ptr); + recvtypes, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c b/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c index c1918c7be0a..fafa9502e00 100644 --- a/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c +++ b/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c @@ -23,7 +23,8 @@ int MPIR_Alltoallw_inter_pairwise_exchange(const void *sendbuf, const MPI_Aint s const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int local_size, remote_size, max_size, i; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c index d0ce2ccc10c..bd67b157b71 100644 --- a/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c @@ -23,7 +23,8 @@ int MPIR_Alltoallw_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, i, j; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c index f0063d4ad91..348c5528af0 
100644 --- a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c +++ b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c @@ -23,7 +23,7 @@ int MPIR_Alltoallw_intra_scattered(const void *sendbuf, const MPI_Aint sendcount const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, i; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/barrier/barrier_allcomm_nb.c b/src/mpi/coll/barrier/barrier_allcomm_nb.c index 72a579949fd..7c1f26956a6 100644 --- a/src/mpi/coll/barrier/barrier_allcomm_nb.c +++ b/src/mpi/coll/barrier/barrier_allcomm_nb.c @@ -5,13 +5,13 @@ #include "mpiimpl.h" -int MPIR_Barrier_allcomm_nb(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +int MPIR_Barrier_allcomm_nb(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ - mpi_errno = MPIR_Ibarrier(comm_ptr, &req_ptr); + mpi_errno = MPIR_Ibarrier(comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/barrier/barrier_inter_bcast.c b/src/mpi/coll/barrier/barrier_inter_bcast.c index e1d81c23443..3d775eb1272 100644 --- a/src/mpi/coll/barrier/barrier_inter_bcast.c +++ b/src/mpi/coll/barrier/barrier_inter_bcast.c @@ -17,7 +17,7 @@ * group. 
*/ -int MPIR_Barrier_inter_bcast(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +int MPIR_Barrier_inter_bcast(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, mpi_errno = MPI_SUCCESS, root; int i = 0; @@ -34,28 +34,28 @@ int MPIR_Barrier_inter_bcast(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) newcomm_ptr = comm_ptr->local_comm; /* do a barrier on the local intracommunicator */ - mpi_errno = MPIR_Barrier(newcomm_ptr, errflag); + mpi_errno = MPIR_Barrier(newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (comm_ptr->is_low_group) { /* bcast to right */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; - mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* receive bcast from right */ root = 0; - mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* receive bcast from left */ root = 0; - mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* bcast to left */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; - mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } fn_exit: diff --git a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c index 927c62843c4..849806e5658 100644 --- a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c +++ b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c @@ -16,7 +16,7 @@ * process i sends to process (i + 2^k) % p and receives from process * (i - 2^k + p) % p. 
*/ -int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int size, rank, src, dst, mask, mpi_errno = MPI_SUCCESS; @@ -42,7 +42,8 @@ int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errfla /* Algorithm: high radix dissemination * Similar to dissemination algorithm, but generalized with high radix k */ -int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int k, MPIR_Errflag_t errflag) +int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int coll_group, int k, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int i, j, nranks, rank; @@ -62,7 +63,7 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int k, MPIR_Errflag_t e k = nranks; if (k == 2) { - return MPIR_Barrier_intra_dissemination(comm, errflag); + return MPIR_Barrier_intra_dissemination(comm, coll_group, errflag); } /* If k value is greater than the maximum radix defined by MAX_RADIX macro, diff --git a/src/mpi/coll/barrier/barrier_intra_recexch.c b/src/mpi/coll/barrier/barrier_intra_recexch.c index a46a6e25d8e..72e0ef918e3 100644 --- a/src/mpi/coll/barrier/barrier_intra_recexch.c +++ b/src/mpi/coll/barrier/barrier_intra_recexch.c @@ -8,13 +8,13 @@ /* Algorithm: call Allreduce's recursive exchange algorithm */ -int MPIR_Barrier_intra_recexch(MPIR_Comm * comm, int k, int single_phase_recv, +int MPIR_Barrier_intra_recexch(MPIR_Comm * comm, int coll_group, int k, int single_phase_recv, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Allreduce_intra_recexch(MPI_IN_PLACE, NULL, 0, - MPI_BYTE, MPI_SUM, comm, + MPI_BYTE, MPI_SUM, comm, coll_group, k, single_phase_recv, errflag); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/barrier/barrier_intra_smp.c b/src/mpi/coll/barrier/barrier_intra_smp.c index f723be96165..36711f8560e 100644 --- a/src/mpi/coll/barrier/barrier_intra_smp.c +++ 
b/src/mpi/coll/barrier/barrier_intra_smp.c @@ -5,7 +5,7 @@ #include "mpiimpl.h" -int MPIR_Barrier_intra_smp(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +int MPIR_Barrier_intra_smp(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -13,13 +13,13 @@ int MPIR_Barrier_intra_smp(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) /* do the intranode barrier on all nodes */ if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIR_Barrier(comm_ptr->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* do the barrier across roots of all nodes */ if (comm_ptr->node_roots_comm != NULL) { - mpi_errno = MPIR_Barrier(comm_ptr->node_roots_comm, errflag); + mpi_errno = MPIR_Barrier(comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -28,7 +28,7 @@ int MPIR_Barrier_intra_smp(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) * anything) */ if (comm_ptr->node_comm != NULL) { int i = 0; - mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, 0, comm_ptr->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/bcast/bcast_allcomm_nb.c b/src/mpi/coll/bcast/bcast_allcomm_nb.c index 99c615ca658..5f2f19ff1f4 100644 --- a/src/mpi/coll/bcast/bcast_allcomm_nb.c +++ b/src/mpi/coll/bcast/bcast_allcomm_nb.c @@ -6,13 +6,13 @@ #include "mpiimpl.h" int MPIR_Bcast_allcomm_nb(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ - mpi_errno = MPIR_Ibcast(buffer, count, datatype, root, comm_ptr, &req_ptr); + mpi_errno = MPIR_Ibcast(buffer, count, datatype, root, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = 
MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c b/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c index b22a916ff32..4bace6c2a72 100644 --- a/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c +++ b/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c @@ -14,7 +14,8 @@ int MPIR_Bcast_inter_remote_send_local_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, mpi_errno; MPI_Status status; @@ -50,7 +51,8 @@ int MPIR_Bcast_inter_remote_send_local_bcast(void *buffer, /* now do the usual broadcast on this intracommunicator * with rank 0 as root. */ - mpi_errno = MPIR_Bcast_allcomm_auto(buffer, count, datatype, 0, newcomm_ptr, errflag); + mpi_errno = MPIR_Bcast_allcomm_auto(buffer, count, datatype, 0, newcomm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/bcast/bcast_intra_binomial.c b/src/mpi/coll/bcast/bcast_intra_binomial.c index 77d3bd73d24..0040b6a98ce 100644 --- a/src/mpi/coll/bcast/bcast_intra_binomial.c +++ b/src/mpi/coll/bcast/bcast_intra_binomial.c @@ -13,7 +13,8 @@ int MPIR_Bcast_intra_binomial(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size, src, dst; int relative_rank, mask; diff --git a/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c b/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c index f42d0515053..3ef926f3630 100644 --- a/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c +++ b/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c @@ -14,7 +14,7 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, int tree_type, + int root, MPIR_Comm * comm_ptr, int coll_group, int tree_type, int 
branching_factor, int is_nb, int chunk_size, int recv_pre_posted, MPIR_Errflag_t errflag) { diff --git a/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c b/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c index 95f23dd51e5..43abef6d1be 100644 --- a/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c +++ b/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c @@ -29,7 +29,7 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { MPI_Status status; diff --git a/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c b/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c index 41e17b99e3c..ad44f2f03dd 100644 --- a/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c +++ b/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c @@ -24,7 +24,8 @@ int MPIR_Bcast_intra_scatter_ring_allgather(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/bcast/bcast_intra_smp.c b/src/mpi/coll/bcast/bcast_intra_smp.c index 9ff9e684e54..402e2dbc860 100644 --- a/src/mpi/coll/bcast/bcast_intra_smp.c +++ b/src/mpi/coll/bcast/bcast_intra_smp.c @@ -11,7 +11,7 @@ * be able to make changes along these lines almost exclusively in this function * and some new functions. 
[goodell@ 2008/01/07] */ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPI_Aint type_size, nbytes = 0; @@ -63,13 +63,14 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in if (comm_ptr->node_roots_comm != NULL) { mpi_errno = MPIR_Bcast(buffer, count, datatype, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, errflag); + comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* perform the intranode broadcast on all except for the root's node */ if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { /* (nbytes > MPIR_CVAR_BCAST_SHORT_MSG_SIZE) && (comm_ptr->size >= MPIR_CVAR_BCAST_MIN_PROCS) */ @@ -87,7 +88,7 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in * right algorithms here. */ mpi_errno = MPIR_Bcast(buffer, count, datatype, MPIR_Get_intranode_rank(comm_ptr, root), - comm_ptr->node_comm, errflag); + comm_ptr->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -95,7 +96,7 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in if (comm_ptr->node_roots_comm != NULL) { mpi_errno = MPIR_Bcast(buffer, count, datatype, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, errflag); + comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -104,7 +105,8 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in /* FIXME binomial may not be the best algorithm for on-node * bcast. We need a more comprehensive system for selecting the * right algorithms here. 
*/ - mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { /* large msg or non-pof2 */ @@ -112,9 +114,8 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in /* FIXME It would be good to have an SMP-aware version of this * algorithm that (at least approximately) minimized internode * communication. */ - mpi_errno = - MPIR_Bcast_intra_scatter_ring_allgather(buffer, count, datatype, root, comm_ptr, - errflag); + mpi_errno = MPIR_Bcast_intra_scatter_ring_allgather(buffer, count, datatype, root, + comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/bcast/bcast_intra_tree.c b/src/mpi/coll/bcast/bcast_intra_tree.c index 6cd2f02ba7e..560232ea825 100644 --- a/src/mpi/coll/bcast/bcast_intra_tree.c +++ b/src/mpi/coll/bcast/bcast_intra_tree.c @@ -12,7 +12,7 @@ int MPIR_Bcast_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, int tree_type, + int root, MPIR_Comm * comm_ptr, int coll_group, int tree_type, int branching_factor, int is_nb, MPIR_Errflag_t errflag) { int rank, comm_size, src, dst, *p, j, k, lrank = -1, is_contig; diff --git a/src/mpi/coll/exscan/exscan_allcomm_nb.c b/src/mpi/coll/exscan/exscan_allcomm_nb.c index a1050eb428f..317745d2ad9 100644 --- a/src/mpi/coll/exscan/exscan_allcomm_nb.c +++ b/src/mpi/coll/exscan/exscan_allcomm_nb.c @@ -6,14 +6,14 @@ #include "mpiimpl.h" int MPIR_Exscan_allcomm_nb(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ - mpi_errno = MPIR_Iexscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, 
&req_ptr); + mpi_errno = MPIR_Iexscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c index b21c91636cd..3d10c8d08cf 100644 --- a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c +++ b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c @@ -48,7 +48,8 @@ int MPIR_Exscan_intra_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { MPI_Status status; int rank, comm_size; diff --git a/src/mpi/coll/gather/gather_allcomm_nb.c b/src/mpi/coll/gather/gather_allcomm_nb.c index 91f4f71b68a..6a81234d6cb 100644 --- a/src/mpi/coll/gather/gather_allcomm_nb.c +++ b/src/mpi/coll/gather/gather_allcomm_nb.c @@ -7,7 +7,7 @@ int MPIR_Gather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -15,7 +15,7 @@ int MPIR_Gather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Igather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, - &req_ptr); + coll_group, &req_ptr); mpi_errno = MPIC_Wait(req_ptr); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/gather/gather_inter_linear.c b/src/mpi/coll/gather/gather_inter_linear.c index fbf29f904e5..a1f7dbdec37 100644 --- a/src/mpi/coll/gather/gather_inter_linear.c +++ b/src/mpi/coll/gather/gather_inter_linear.c @@ -15,7 +15,7 @@ int MPIR_Gather_inter_linear(const void *sendbuf, MPI_Aint 
sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int remote_size, mpi_errno = MPI_SUCCESS; int i; diff --git a/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c b/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c index 934aca48fa3..22a3e71aa9d 100644 --- a/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c +++ b/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c @@ -16,7 +16,8 @@ int MPIR_Gather_inter_local_gather_remote_send(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS; MPI_Status status; @@ -67,7 +68,7 @@ int MPIR_Gather_inter_local_gather_remote_send(const void *sendbuf, MPI_Aint sen /* now do the a local gather on this intracommunicator */ mpi_errno = MPIR_Gather(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz, MPI_BYTE, 0, newcomm_ptr, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (rank == 0) { diff --git a/src/mpi/coll/gather/gather_intra_binomial.c b/src/mpi/coll/gather/gather_intra_binomial.c index d8915452cda..97b1dfd1d88 100644 --- a/src/mpi/coll/gather/gather_intra_binomial.c +++ b/src/mpi/coll/gather/gather_intra_binomial.c @@ -39,7 +39,7 @@ */ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int comm_size, rank; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c 
b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c index cabf1ef9bb8..bf571243595 100644 --- a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c +++ b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c @@ -22,7 +22,8 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int comm_size, rank; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/gatherv/gatherv_allcomm_nb.c b/src/mpi/coll/gatherv/gatherv_allcomm_nb.c index 3a49ce11a4e..433035a0850 100644 --- a/src/mpi/coll/gatherv/gatherv_allcomm_nb.c +++ b/src/mpi/coll/gatherv/gatherv_allcomm_nb.c @@ -7,7 +7,7 @@ int MPIR_Gatherv_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, + MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -16,7 +16,7 @@ int MPIR_Gatherv_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatyp /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, - comm_ptr, &req_ptr); + comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/iallgather/iallgather_inter_sched_local_gather_remote_bcast.c b/src/mpi/coll/iallgather/iallgather_inter_sched_local_gather_remote_bcast.c index 62e702e4679..2920955fd69 100644 --- a/src/mpi/coll/iallgather/iallgather_inter_sched_local_gather_remote_bcast.c +++ b/src/mpi/coll/iallgather/iallgather_inter_sched_local_gather_remote_bcast.c @@ -14,7 +14,8 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, 
void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, local_size, remote_size, root; @@ -46,7 +47,7 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, M if (sendcount != 0) { mpi_errno = MPIR_Igather_intra_sched_auto(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz, MPI_BYTE, 0, - newcomm_ptr, s); + newcomm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -58,7 +59,7 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, M if (sendcount != 0) { root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ibcast_inter_sched_auto(tmp_buf, sendcount * local_size * sendtype_sz, - MPI_BYTE, root, comm_ptr, s); + MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -68,7 +69,7 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, M if (recvcount != 0) { root = 0; mpi_errno = MPIR_Ibcast_inter_sched_auto(recvbuf, recvcount * remote_size, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); @@ -77,7 +78,7 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, M if (recvcount != 0) { root = 0; mpi_errno = MPIR_Ibcast_inter_sched_auto(recvbuf, recvcount * remote_size, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -87,7 +88,7 @@ int MPIR_Iallgather_inter_sched_local_gather_remote_bcast(const void *sendbuf, M if (sendcount != 0) { root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ibcast_inter_sched_auto(tmp_buf, sendcount * local_size * sendtype_sz, - MPI_BYTE, root, comm_ptr, s); + MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c b/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c index 955a1447ce5..65fb949e5b3 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c @@ -16,7 +16,8 @@ */ int MPIR_Iallgather_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int pof2, rem, src, dst; diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c b/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c index dd0ba4c217a..c035d04939b 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c @@ -46,7 +46,8 @@ static int reset_shared_state(MPIR_Comm * comm, int tag, void *state) int MPIR_Iallgather_intra_sched_recursive_doubling(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; struct shared_state *ss = NULL; diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c b/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c index 78bbbacce62..16f148f3187 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c @@ -22,7 +22,8 @@ */ int MPIR_Iallgather_intra_sched_ring(const void 
*sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size; diff --git a/src/mpi/coll/iallgather/iallgather_tsp_brucks.c b/src/mpi/coll/iallgather/iallgather_tsp_brucks.c index 8e526ac45b3..7743e2bb7ca 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_brucks.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_brucks.c @@ -10,7 +10,8 @@ int MPIR_TSP_Iallgather_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, int k, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int k, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, j; diff --git a/src/mpi/coll/iallgather/iallgather_tsp_recexch.c b/src/mpi/coll/iallgather/iallgather_tsp_recexch.c index 6f554bacecc..9b6d213fb73 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_recexch.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_recexch.c @@ -12,7 +12,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_data_exchange(int rank, int n MPI_Datatype recvtype, size_t recv_extent, MPI_Aint recvcount, int tag, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -65,7 +65,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step1(int step1_sendto, int * void *recvbuf, size_t recv_extent, MPI_Aint recvcount, MPI_Datatype recvtype, int n_invtcs, int *invtx, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i; @@ -111,7 +111,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step2(int step1_sendto, int s void *recvbuf, size_t recv_extent, MPI_Aint recvcount, MPI_Datatype recvtype, int is_dist_halving, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) 
+ int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int phase, i, j, count, nbr, offset, rank_for_offset; @@ -191,7 +191,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step3(int step1_sendto, int * int nranks, int k, int nrecvs, int *recv_id, int tag, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, vtx_id; @@ -232,8 +232,8 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step3(int step1_sendto, int * int MPIR_TSP_Iallgather_sched_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, int is_dist_halving, int k, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int is_dist_halving, + int k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_inplace, i; @@ -292,7 +292,7 @@ int MPIR_TSP_Iallgather_sched_intra_recexch(const void *sendbuf, MPI_Aint sendco MPIR_TSP_Iallgather_sched_intra_recexch_step1(step1_sendto, step1_recvfrom, step1_nrecvs, is_inplace, rank, tag, sendbuf, recvbuf, recv_extent, recvcount, recvtype, n_invtcs, - &invtx, comm, sched); + &invtx, comm, coll_group, sched); mpi_errno = MPIR_TSP_sched_fence(sched); MPIR_ERR_CHECK(mpi_errno); @@ -302,7 +302,8 @@ int MPIR_TSP_Iallgather_sched_intra_recexch(const void *sendbuf, MPI_Aint sendco if (step1_sendto == -1) { MPIR_TSP_Iallgather_sched_intra_recexch_data_exchange(rank, nranks, k, p_of_k, log_pofk, T, recvbuf, recvtype, recv_extent, - recvcount, tag, comm, sched); + recvcount, tag, comm, coll_group, + sched); mpi_errno = MPIR_TSP_sched_fence(sched); MPIR_ERR_CHECK(mpi_errno); } @@ -312,13 +313,14 @@ int MPIR_TSP_Iallgather_sched_intra_recexch(const void *sendbuf, MPI_Aint sendco MPIR_TSP_Iallgather_sched_intra_recexch_step2(step1_sendto, step2_nphases, step2_nbrs, rank, nranks, k, p_of_k, log_pofk, T, &nrecvs, &recv_id, tag, recvbuf, recv_extent, 
recvcount, recvtype, - is_dist_halving, comm, sched); + is_dist_halving, comm, coll_group, sched); /* Step 3: This is reverse of Step 1. Ranks that participated in Step 2 * send the data to non-partcipating ranks */ MPIR_TSP_Iallgather_sched_intra_recexch_step3(step1_sendto, step1_recvfrom, step1_nrecvs, step2_nphases, recvbuf, recvcount, nranks, k, - nrecvs, recv_id, tag, recvtype, comm, sched); + nrecvs, recv_id, tag, recvtype, comm, coll_group, + sched); /* free the memory */ for (i = 0; i < step2_nphases; i++) diff --git a/src/mpi/coll/iallgather/iallgather_tsp_ring.c b/src/mpi/coll/iallgather/iallgather_tsp_ring.c index f9b725ba9fc..9798e63aae7 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_ring.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_ring.c @@ -9,7 +9,7 @@ int MPIR_TSP_Iallgather_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, src, dst, copy_dst; diff --git a/src/mpi/coll/iallgatherv/iallgatherv_inter_sched_remote_gather_local_bcast.c b/src/mpi/coll/iallgatherv/iallgatherv_inter_sched_remote_gather_local_bcast.c index 575dbf7bedb..70d273e1074 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_inter_sched_remote_gather_local_bcast.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_inter_sched_remote_gather_local_bcast.c @@ -21,7 +21,8 @@ int MPIR_Iallgatherv_inter_sched_remote_gather_local_bcast(const void *sendbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int remote_size, root, rank; @@ -37,23 +38,27 @@ int MPIR_Iallgatherv_inter_sched_remote_gather_local_bcast(const void *sendbuf, /* gatherv from right group */ root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Igatherv_inter_sched_auto(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, s); + recvcounts, displs, recvtype, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* gatherv to right group */ root = 0; mpi_errno = MPIR_Igatherv_inter_sched_auto(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, s); + recvcounts, displs, recvtype, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* gatherv to left group */ root = 0; mpi_errno = MPIR_Igatherv_inter_sched_auto(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, s); + recvcounts, displs, recvtype, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* gatherv from left group */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Igatherv_inter_sched_auto(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, s); + recvcounts, displs, recvtype, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -76,7 +81,7 @@ int MPIR_Iallgatherv_inter_sched_remote_gather_local_bcast(const void *sendbuf, mpi_errno = MPIR_Type_commit_impl(&newtype); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, 1, newtype, 0, newcomm_ptr, s); + mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, 1, newtype, 0, newcomm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_Type_free_impl(&newtype); diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c index a849dd3522e..dd764574004 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c @@ -8,7 +8,8 @@ int MPIR_Iallgatherv_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], - 
MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank, j, i; diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c index a178c94799d..6fc7a618799 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c @@ -9,7 +9,8 @@ int MPIR_Iallgatherv_intra_sched_recursive_doubling(const void *sendbuf, MPI_Ain MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank, i, j, k; diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c index 76b390f94b3..74e366d9d60 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c @@ -8,7 +8,8 @@ int MPIR_Iallgatherv_intra_sched_ring(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c index 896cc02847e..3cbc1785426 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c @@ -29,7 +29,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint 
recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, int k, MPIR_TSP_sched_t sched) { int i, j, l; diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c index 48505646110..12c7e0e0130 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c @@ -13,7 +13,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_data_exchange(int rank, int size_t recv_extent, const MPI_Aint * recvcounts, const MPI_Aint * displs, int tag, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -74,7 +74,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_step1(int step1_sendto, int const MPI_Aint * displs, MPI_Datatype recvtype, int n_invtcs, int *invtx, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, vtx_id; @@ -120,7 +120,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(int step1_sendto, int step2_n size_t recv_extent, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, int is_dist_halving, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int phase, i, j, count, nbr, offset, rank_for_offset; @@ -206,7 +206,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_step3(int step1_sendto, int const MPI_Aint * recvcounts, int nranks, int k, int nrecvs, int *recv_id, int tag, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; MPI_Aint total_count = 0; @@ -244,7 +244,8 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype 
recvtype, MPIR_Comm * comm, - int is_dist_halving, int k, MPIR_TSP_sched_t sched) + int coll_group, int is_dist_halving, int k, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_inplace, i; @@ -300,7 +301,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc MPIR_TSP_Iallgatherv_sched_intra_recexch_step1(step1_sendto, step1_recvfrom, step1_nrecvs, is_inplace, rank, tag, sendbuf, recvbuf, recv_extent, recvcounts, displs, recvtype, - n_invtcs, &invtx, comm, sched); + n_invtcs, &invtx, comm, coll_group, sched); mpi_errno = MPIR_TSP_sched_fence(sched); MPIR_ERR_CHECK(mpi_errno); @@ -311,7 +312,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc MPIR_TSP_Iallgatherv_sched_intra_recexch_data_exchange(rank, nranks, k, p_of_k, log_pofk, T, recvbuf, recvtype, recv_extent, recvcounts, displs, - tag, comm, sched); + tag, comm, coll_group, sched); mpi_errno = MPIR_TSP_sched_fence(sched); MPIR_ERR_CHECK(mpi_errno); } @@ -321,13 +322,15 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(step1_sendto, step2_nphases, step2_nbrs, rank, nranks, k, p_of_k, log_pofk, T, &nrecvs, &recv_id, tag, recvbuf, recv_extent, recvcounts, - displs, recvtype, is_dist_halving, comm, sched); + displs, recvtype, is_dist_halving, comm, + coll_group, sched); /* Step 3: This is reverse of Step 1. 
Ranks that participated in Step 2 * send the data to non-partcipating ranks */ MPIR_TSP_Iallgatherv_sched_intra_recexch_step3(step1_sendto, step1_recvfrom, step1_nrecvs, step2_nphases, recvbuf, recvcounts, nranks, k, - nrecvs, recv_id, tag, recvtype, comm, sched); + nrecvs, recv_id, tag, recvtype, comm, coll_group, + sched); fn_exit: /* free the memory */ diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c index dfcdf0e81c1..129c1f67d94 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c @@ -10,7 +10,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { size_t extent; diff --git a/src/mpi/coll/iallreduce/iallreduce_inter_sched_remote_reduce_local_bcast.c b/src/mpi/coll/iallreduce/iallreduce_inter_sched_remote_reduce_local_bcast.c index 1e4e312c83f..6ce03a2255c 100644 --- a/src/mpi/coll/iallreduce/iallreduce_inter_sched_remote_reduce_local_bcast.c +++ b/src/mpi/coll/iallreduce/iallreduce_inter_sched_remote_reduce_local_bcast.c @@ -19,7 +19,7 @@ int MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, root; @@ -35,7 +35,8 @@ int MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(const void *sendbuf, v /* reduce from right group to rank 0 */ root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = - MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, s); + MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* no barrier, these reductions can be concurrent */ @@ -43,13 +44,15 @@ int MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(const void *sendbuf, v /* reduce to rank 0 of right group */ root = 0; mpi_errno = - MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, s); + MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* reduce to rank 0 of left group */ root = 0; mpi_errno = - MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, s); + MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* no barrier, these reductions can be concurrent */ @@ -57,7 +60,8 @@ int MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(const void *sendbuf, v /* reduce from right group to rank 0 */ root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = - MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, s); + MPIR_Ireduce_inter_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -72,7 +76,7 @@ int MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(const void *sendbuf, v } lcomm_ptr = comm_ptr->local_comm; - mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, lcomm_ptr, s); + mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, lcomm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c index 36a26b6e42c..da7d0186c76 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c @@ -8,7 +8,7 @@ /* implements the naive intracomm allreduce, that is, reduce followed by bcast */ int MPIR_Iallreduce_intra_sched_naive(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank; @@ -17,17 +17,19 @@ int MPIR_Iallreduce_intra_sched_naive(const void *sendbuf, void *recvbuf, MPI_Ai if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) { mpi_errno = - MPIR_Ireduce_intra_sched_auto(recvbuf, NULL, count, datatype, op, 0, comm_ptr, s); + MPIR_Ireduce_intra_sched_auto(recvbuf, NULL, count, datatype, op, 0, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = - MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, 0, comm_ptr, s); + MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, 0, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); - mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, comm_ptr, s); + mpi_errno = 
MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c index 8f409293afe..36278689541 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c @@ -7,7 +7,8 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int pof2, rem, comm_size, is_commutative, rank; diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c index 1fc14c6a367..44cc9175f1a 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c @@ -9,7 +9,7 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank, newrank, pof2, rem; diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c index 20fb4fa88c9..e0acb2954ff 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c @@ -8,7 +8,7 @@ int MPIR_Iallreduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int is_commutative; @@ 
-26,7 +26,8 @@ int MPIR_Iallreduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint if (!is_commutative) { /* use flat fallback */ mpi_errno = - MPIR_Iallreduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, comm_ptr, s); + MPIR_Iallreduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -39,11 +40,14 @@ int MPIR_Iallreduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint if ((sendbuf == MPI_IN_PLACE) && (comm_ptr->node_comm->rank != 0)) { /* IN_PLACE and not root of reduce. Data supplied to this * allreduce is in recvbuf. Pass that as the sendbuf to reduce. */ - mpi_errno = MPIR_Ireduce_intra_sched_auto(recvbuf, NULL, count, datatype, op, 0, nc, s); + mpi_errno = + MPIR_Ireduce_intra_sched_auto(recvbuf, NULL, count, datatype, op, 0, nc, + MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = - MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, 0, nc, s); + MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, 0, nc, + MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); @@ -59,14 +63,16 @@ int MPIR_Iallreduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint /* now do an IN_PLACE allreduce among the local roots of all nodes */ if (nrc != NULL) { mpi_errno = - MPIR_Iallreduce_intra_sched_auto(MPI_IN_PLACE, recvbuf, count, datatype, op, nrc, s); + MPIR_Iallreduce_intra_sched_auto(MPI_IN_PLACE, recvbuf, count, datatype, op, nrc, + MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } /* now broadcast the result among local processes */ if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, nc, s); + mpi_errno = + MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, nc, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git 
a/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c b/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c index be266211063..54129fdeaca 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c @@ -10,7 +10,8 @@ /* Routine to schedule a pipelined tree based allreduce */ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_commutative = MPIR_Op_is_commutative(op); @@ -35,7 +36,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPIR_CVAR_IALLREDUCE_INTRA_ALGORITHM_tsp_recexch_single_buffer: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(sendbuf, recvbuf, count, - datatype, op, comm, + datatype, op, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_SINGLE_BUFFER, MPIR_CVAR_IALLREDUCE_RECEXCH_KVAL, sched); break; @@ -43,7 +44,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPIR_CVAR_IALLREDUCE_INTRA_ALGORITHM_tsp_recexch_multiple_buffer: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(sendbuf, recvbuf, count, - datatype, op, comm, + datatype, op, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, MPIR_CVAR_IALLREDUCE_RECEXCH_KVAL, sched); break; @@ -56,7 +57,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, "Iallreduce gentran_tree cannot be applied.\n"); mpi_errno = MPIR_TSP_Iallreduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, - comm, MPIR_Iallreduce_tree_type, + comm, coll_group, MPIR_Iallreduce_tree_type, MPIR_CVAR_IALLREDUCE_TREE_KVAL, MPIR_CVAR_IALLREDUCE_TREE_PIPELINE_CHUNK_SIZE, MPIR_CVAR_IALLREDUCE_TREE_BUFFER_PER_CHILD, @@ -68,7 +69,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, "Iallreduce gentran_ring cannot 
be applied.\n"); mpi_errno = MPIR_TSP_Iallreduce_sched_intra_ring(sendbuf, recvbuf, count, datatype, - op, comm, sched); + op, comm, coll_group, sched); break; case MPIR_CVAR_IALLREDUCE_INTRA_ALGORITHM_tsp_recexch_reduce_scatter_recexch_allgatherv: /* This algorithm will work for commutative @@ -86,6 +87,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, datatype, op, comm, + coll_group, MPIR_CVAR_IALLREDUCE_RECEXCH_KVAL, sched); break; @@ -97,7 +99,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iallreduce_intra_tsp_recexch_single_buffer: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(sendbuf, recvbuf, count, - datatype, op, comm, + datatype, op, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_SINGLE_BUFFER, cnt->u. iallreduce.intra_tsp_recexch_single_buffer. @@ -107,7 +109,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iallreduce_intra_tsp_recexch_multiple_buffer: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(sendbuf, recvbuf, count, - datatype, op, comm, + datatype, op, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, cnt->u. iallreduce.intra_tsp_recexch_single_buffer. @@ -117,7 +119,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iallreduce_intra_tsp_tree: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, - comm, + comm, coll_group, cnt->u.iallreduce. 
intra_tsp_tree.tree_type, cnt->u.iallreduce.intra_tsp_tree.k, @@ -131,13 +133,13 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iallreduce_intra_tsp_ring: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_ring(sendbuf, recvbuf, count, datatype, op, - comm, sched); + comm, coll_group, sched); break; case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Iallreduce_intra_tsp_recexch_reduce_scatter_recexch_allgatherv: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv - (sendbuf, recvbuf, count, datatype, op, comm, + (sendbuf, recvbuf, count, datatype, op, comm, coll_group, cnt->u.iallreduce.intra_tsp_recexch_reduce_scatter_recexch_allgatherv.k, sched); break; @@ -155,7 +157,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, fallback: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(sendbuf, recvbuf, count, - datatype, op, comm, + datatype, op, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, MPIR_CVAR_IALLREDUCE_RECEXCH_KVAL, sched); fn_exit: diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c index 91aa69af236..c498685986a 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c @@ -11,8 +11,8 @@ /* Routine to schedule a recursive exchange based allreduce */ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, int per_nbr_buffer, int k, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int per_nbr_buffer, + int k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_inplace, i, j; @@ -76,7 +76,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, tag, extent, dtcopy_id, recv_id, reduce_id, vtcs, is_inplace, step1_sendto, in_step2, step1_nrecvs, 
step1_recvfrom, per_nbr_buffer, &step1_recvbuf, - comm, sched); + comm, coll_group, sched); mpi_errno = MPIR_TSP_sched_sink(sched, &step1_id); /* sink for all the tasks up to end of Step 1 */ if (mpi_errno) diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c index b741e42df12..5053de12eed 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c @@ -15,7 +15,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - int k, + int coll_group, int k, MPIR_TSP_sched_t sched) { @@ -86,7 +86,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co tag, extent, dtcopy_id, recv_id, reduce_id, vtcs, is_inplace, step1_sendto, in_step2, step1_nrecvs, step1_recvfrom, per_nbr_buffer, &step1_recvbuf, - comm, sched); + comm, coll_group, sched); mpi_errno = MPIR_TSP_sched_sink(sched, &sink_id); /* sink for all the tasks up to end of Step 1 */ if (mpi_errno) @@ -119,7 +119,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(recvbuf, tmp_recvbuf, cnts, displs, datatype, op, extent, tag, - comm, k, redscat_algo_type, + comm, coll_group, k, redscat_algo_type, step2_nphases, step2_nbrs, rank, nranks, sink_id, 0, NULL, sched); @@ -128,7 +128,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(step1_sendto, step2_nphases, step2_nbrs, rank, nranks, k, p_of_k, log_pofk, T, &nvtcs, &recv_id, tag, recvbuf, extent, cnts, displs, - datatype, allgather_algo_type, comm, sched); + datatype, allgather_algo_type, comm, + coll_group, sched); } diff --git 
a/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c index cb6a631cb6d..26902fd80b1 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c @@ -46,7 +46,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_step1(const void *sendbuf, int step1_sendto, bool in_step2, int step1_nrecvs, int *step1_recvfrom, int per_nbr_buffer, void ***step1_recvbuf_, MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, nvtcs, vtx_id; diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.h b/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.h index 47964cb2729..1e362a4c8c6 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.h +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.h @@ -16,5 +16,5 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_step1(const void *sendbuf, int step1_sendto, bool in_step2, int step1_nrecvs, int *step1_recvfrom, int per_nbr_buffer, void ***step1_recvbuf_, MPIR_Comm * comm, - MPIR_TSP_sched_t sched); + int coll_group, MPIR_TSP_sched_t sched); #endif /* IALLREDUCE_TSP_RECURSIVE_EXCHANGE_COMMON_H_INCLUDED */ diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c b/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c index fb136764b87..f8fec2f41f8 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c @@ -12,7 +12,7 @@ * explained here: http://andrew.gibiansky.com/ */ int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, src, dst; @@ -111,7 +111,7 @@ int 
MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI /* Phase 3: Allgatherv ring, so everyone has the reduced data */ MPIR_TSP_Iallgatherv_sched_intra_ring(MPI_IN_PLACE, -1, MPI_DATATYPE_NULL, recvbuf, cnts, - displs, datatype, comm, sched); + displs, datatype, comm, coll_group, sched); MPIR_CHKLMEM_FREEALL(); diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c b/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c index 64c84cc6af1..0a363224777 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c @@ -10,8 +10,9 @@ /* Routine to schedule a pipelined tree based allreduce */ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, int tree_type, int k, int chunk_size, - int buffer_per_child, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int tree_type, int k, + int chunk_size, int buffer_per_child, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, j, t; diff --git a/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c b/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c index b6c46c5b54a..27ed068ce8a 100644 --- a/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c +++ b/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c @@ -19,7 +19,8 @@ int MPIR_Ialltoall_inter_sched_pairwise_exchange(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int local_size, remote_size, max_size, i; diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c index 154176f55ff..3338aa5cd75 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c +++ 
b/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c @@ -20,7 +20,8 @@ */ int MPIR_Ialltoall_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c index a634d337005..abccd14055d 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c @@ -19,7 +19,8 @@ * scenario. */ int MPIR_Ialltoall_intra_sched_inplace(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; void *tmp_buf = NULL; diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c index 1f953411065..c6380474e80 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c @@ -24,7 +24,8 @@ */ int MPIR_Ialltoall_intra_sched_pairwise(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c index f8e03a093bd..a2aa1153099 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c +++ 
b/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c @@ -16,7 +16,8 @@ int MPIR_Ialltoall_intra_sched_permuted_sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c b/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c index 4d19c6fdf08..c86987531f2 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c @@ -117,7 +117,7 @@ brucks_sched_pup(int pack, void *rbuf, void *pupbuf, MPI_Datatype rtype, MPI_Ain int MPIR_TSP_Ialltoall_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, int k, int buffer_per_phase, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c b/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c index 7d79dab6394..7a1b8515c45 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c @@ -35,7 +35,7 @@ copy (buf1)<--recv (buf1) send (buf2) / /* Routine to schedule a ring based allgather */ int MPIR_TSP_Ialltoall_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c b/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c index bf6f6406eb4..86371daebed 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c @@ -37,8 +37,8 @@ int 
MPIR_TSP_Ialltoall_sched_intra_scattered(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, int batch_size, int bblock, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int batch_size, + int bblock, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int src, dst; diff --git a/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c b/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c index b3436310334..48f41bbb411 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c @@ -9,7 +9,8 @@ int MPIR_Ialltoallv_inter_sched_pairwise_exchange(const void *sendbuf, const MPI const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { /* Intercommunicator alltoallv. We use a pairwise exchange algorithm * similar to the one used in intracommunicator alltoallv. 
Since the diff --git a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c index 0c5b926111c..ce1b6706bf6 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c @@ -9,7 +9,7 @@ int MPIR_Ialltoallv_intra_sched_blocked(const void *sendbuf, const MPI_Aint send const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size; diff --git a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c index 8f039334481..4c51b670d06 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c @@ -9,7 +9,7 @@ int MPIR_Ialltoallv_intra_sched_inplace(const void *sendbuf, const MPI_Aint send const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { void *tmp_buf = NULL; int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c index 67ab8288b01..159fc019602 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c @@ -11,7 +11,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_blocked(const void *sendbuf, const MPI_Aint const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm, int bblock, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int bblock, + MPIR_TSP_sched_t sched) { 
int mpi_errno = MPI_SUCCESS; size_t recv_extent, send_extent, sendtype_size, recvtype_size; diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c index 25566a0041b..7fc0e5ccdfa 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c @@ -11,7 +11,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_inplace(const void *sendbuf, const MPI_Aint const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; size_t recv_extent; diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c index 4266aec8f32..d5339566593 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c @@ -11,8 +11,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const MPI_Ain const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm, int batch_size, int bblock, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int batch_size, + int bblock, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int src, dst; diff --git a/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c b/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c index 163aaf79f34..1b2b35cd414 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c @@ -11,7 +11,8 @@ int MPIR_Ialltoallw_inter_sched_pairwise_exchange(const void *sendbuf, const MPI const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm 
* comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { /* Intercommunicator alltoallw. We use a pairwise exchange algorithm similar to the one used in intracommunicator alltoallw. Since the local and diff --git a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c index 1523dd78f0d..13745bf24f9 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c @@ -23,7 +23,7 @@ int MPIR_Ialltoallw_intra_sched_blocked(const void *sendbuf, const MPI_Aint send const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, i; diff --git a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c index 23bdca07055..eab0e1cb524 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c @@ -21,7 +21,7 @@ int MPIR_Ialltoallw_intra_sched_inplace(const void *sendbuf, const MPI_Aint send const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, i, j; diff --git a/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c b/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c index 21c1667b1d1..04b46f2ff57 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c @@ -12,7 +12,7 @@ int MPIR_TSP_Ialltoallw_sched_intra_blocked(const void *sendbuf, const MPI_Aint const 
MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], MPIR_Comm * comm, - int bblock, MPIR_TSP_sched_t sched) + int coll_group, int bblock, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int tag, vtx_id; diff --git a/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c b/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c index fdac5344e56..22b1dde0e22 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c @@ -12,7 +12,7 @@ int MPIR_TSP_Ialltoallw_sched_intra_inplace(const void *sendbuf, const MPI_Aint const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], MPIR_Comm * comm, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int tag; diff --git a/src/mpi/coll/ibarrier/ibarrier_inter_sched_bcast.c b/src/mpi/coll/ibarrier/ibarrier_inter_sched_bcast.c index e98bff426f0..c74d09e7b0c 100644 --- a/src/mpi/coll/ibarrier/ibarrier_inter_sched_bcast.c +++ b/src/mpi/coll/ibarrier/ibarrier_inter_sched_bcast.c @@ -5,7 +5,7 @@ #include "mpiimpl.h" -int MPIR_Ibarrier_inter_sched_bcast(MPIR_Comm * comm_ptr, MPIR_Sched_t s) +int MPIR_Ibarrier_inter_sched_bcast(MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, root; @@ -23,7 +23,7 @@ int MPIR_Ibarrier_inter_sched_bcast(MPIR_Comm * comm_ptr, MPIR_Sched_t s) /* do a barrier on the local intracommunicator */ if (comm_ptr->local_size != 1) { - mpi_errno = MPIR_Ibarrier_intra_sched_auto(comm_ptr->local_comm, s); + mpi_errno = MPIR_Ibarrier_intra_sched_auto(comm_ptr->local_comm, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -40,26 +40,26 @@ int MPIR_Ibarrier_inter_sched_bcast(MPIR_Comm * comm_ptr, MPIR_Sched_t s) * left group */ if (comm_ptr->is_low_group) { root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; - mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, s); + mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); /* receive bcast from right */ root = 0; - mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, s); + mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* receive bcast from left */ root = 0; - mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, s); + mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); /* bcast to left */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; - mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, s); + mpi_errno = MPIR_Ibcast_inter_sched_auto(buf, 1, MPI_BYTE, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c b/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c index 243b54fc9f1..8ed47527124 100644 --- a/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c @@ -18,7 +18,8 @@ * process i sends to process (i + 2^k) % p and receives from process * (i - 2^k + p) % p. 
*/ -int MPIR_Ibarrier_intra_sched_recursive_doubling(MPIR_Comm * comm_ptr, MPIR_Sched_t s) +int MPIR_Ibarrier_intra_sched_recursive_doubling(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int size, rank, src, dst, mask; diff --git a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c index bb59afc82be..33145105d84 100644 --- a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c +++ b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c @@ -6,7 +6,8 @@ #include "mpiimpl.h" /* Routine to schedule a disdem based barrier with radix k */ -int MPIR_TSP_Ibarrier_sched_intra_k_dissemination(MPIR_Comm * comm, int k, MPIR_TSP_sched_t sched) +int MPIR_TSP_Ibarrier_sched_intra_k_dissemination(MPIR_Comm * comm, int coll_group, int k, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; MPIR_Errflag_t errflag ATTRIBUTE((unused)) = MPIR_ERR_NONE; diff --git a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_recexch.c b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_recexch.c index 31e2f1de567..29b814e9e56 100644 --- a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_recexch.c +++ b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_recexch.c @@ -6,7 +6,8 @@ #include "mpiimpl.h" /* Routine to schedule a disdem based barrier with radix k */ -int MPIR_TSP_Ibarrier_sched_intra_recexch(MPIR_Comm * comm, int k, MPIR_TSP_sched_t sched) +int MPIR_TSP_Ibarrier_sched_intra_recexch(MPIR_Comm * comm, int coll_group, int k, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; void *recvbuf = NULL; @@ -14,7 +15,7 @@ int MPIR_TSP_Ibarrier_sched_intra_recexch(MPIR_Comm * comm, int k, MPIR_TSP_sche mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(MPI_IN_PLACE, recvbuf, 0, MPI_BYTE, MPI_SUM, - comm, + comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, k, sched); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c b/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c index d98497b9ff0..bcb3a295d3a 
100644 --- a/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c +++ b/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c @@ -6,7 +6,7 @@ #include "mpiimpl.h" /* sched version of CVAR and json based collective selection. Meant only for gentran scheduler */ -int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sched) +int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -23,14 +23,15 @@ int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sc case MPIR_CVAR_IBARRIER_INTRA_ALGORITHM_tsp_recexch: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(MPI_IN_PLACE, recvbuf, 0, MPI_BYTE, MPI_SUM, - comm, + comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, MPIR_CVAR_IBARRIER_RECEXCH_KVAL, sched); break; case MPIR_CVAR_IBARRIER_INTRA_ALGORITHM_tsp_k_dissemination: mpi_errno = - MPIR_TSP_Ibarrier_sched_intra_k_dissemination(comm, MPIR_CVAR_IBARRIER_DISSEM_KVAL, + MPIR_TSP_Ibarrier_sched_intra_k_dissemination(comm, coll_group, + MPIR_CVAR_IBARRIER_DISSEM_KVAL, sched); break; @@ -42,7 +43,7 @@ int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sc case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ibarrier_intra_tsp_recexch: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(MPI_IN_PLACE, recvbuf, 0, MPI_BYTE, - MPI_SUM, comm, + MPI_SUM, comm, coll_group, MPIR_IALLREDUCE_RECEXCH_TYPE_MULTIPLE_BUFFER, cnt->u.ibarrier.intra_tsp_recexch.k, sched); @@ -50,7 +51,7 @@ int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sc case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ibarrier_intra_tsp_k_dissemination: mpi_errno = - MPIR_TSP_Ibarrier_sched_intra_k_dissemination(comm, + MPIR_TSP_Ibarrier_sched_intra_k_dissemination(comm, coll_group, cnt->u. ibarrier.intra_tsp_k_dissemination. 
k, sched); @@ -68,7 +69,8 @@ int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sc fallback: mpi_errno = MPIR_TSP_Iallreduce_sched_intra_recexch(MPI_IN_PLACE, NULL, 0, - MPI_BYTE, MPI_SUM, comm, 0, 2, sched); + MPI_BYTE, MPI_SUM, comm, coll_group, 0, 2, + sched); fn_exit: return mpi_errno; diff --git a/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c b/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c index d7f7108c970..dc79d1f5282 100644 --- a/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c +++ b/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c @@ -7,7 +7,7 @@ #include "ibcast.h" int MPIR_Ibcast_inter_sched_flat(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; @@ -39,7 +39,8 @@ int MPIR_Ibcast_inter_sched_flat(void *buffer, MPI_Aint count, MPI_Datatype data /* now do the usual broadcast on this intracommunicator * with rank 0 as root. */ mpi_errno = - MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, root, comm_ptr->local_comm, s); + MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, root, comm_ptr->local_comm, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c b/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c index 0a6a6d3d038..00943d9d824 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c @@ -13,7 +13,7 @@ * to build up a larger hierarchical broadcast from multiple invocations of this * function. 
*/ int MPIR_Ibcast_intra_sched_binomial(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int mask; diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c index 0307947da03..9f26d944e2d 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c @@ -49,7 +49,7 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, dst; diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c index 0418f6e28b8..b3f5103ce0c 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c @@ -26,7 +26,8 @@ */ int MPIR_Ibcast_intra_sched_scatter_ring_allgather(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank; diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c b/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c index d734e7d8c2d..810e12a1dbf 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c @@ -28,7 +28,7 @@ static int sched_test_length(MPIR_Comm * comm, int tag, void *state) * currently make any decision about which particular algorithm to use for any * subcommunicator. 
*/ int MPIR_Ibcast_intra_sched_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; MPI_Aint type_size; @@ -69,7 +69,7 @@ int MPIR_Ibcast_intra_sched_smp(void *buffer, MPI_Aint count, MPI_Datatype datat if (comm_ptr->node_roots_comm != NULL) { mpi_errno = MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, s); + comm_ptr->node_roots_comm, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* don't allow the local ops for the intranode phase to start until this has completed */ @@ -78,7 +78,8 @@ int MPIR_Ibcast_intra_sched_smp(void *buffer, MPI_Aint count, MPI_Datatype datat /* perform the intranode broadcast on all except for the root's node */ if (comm_ptr->node_comm != NULL) { mpi_errno = - MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, 0, comm_ptr->node_comm, s); + MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, 0, comm_ptr->node_comm, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ibcast/ibcast_tsp_auto.c b/src/mpi/coll/ibcast/ibcast_tsp_auto.c index 52c9016c083..4128a82372f 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_auto.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_auto.c @@ -13,7 +13,8 @@ /* Remove this function when gentran algos are in json file */ static int MPIR_Ibcast_sched_intra_tsp_flat_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int comm_size; @@ -30,13 +31,14 @@ static int MPIR_Ibcast_sched_intra_tsp_flat_auto(void *buffer, MPI_Aint count, /* simplistic implementation for now */ if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (comm_size < MPIR_CVAR_BCAST_MIN_PROCS)) { /* gentran tree with knomial tree type, radix 2 and pipeline block size 0 */ - 
mpi_errno = MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, - tree_type, radix, block_size, sched); + mpi_errno = + MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, coll_group, + tree_type, radix, block_size, sched); } else { /* gentran scatterv recexch allgather with radix 2 */ mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(buffer, count, datatype, root, - comm_ptr, + comm_ptr, coll_group, MPIR_CVAR_IALLGATHERV_INTRA_ALGORITHM_tsp_recexch_doubling, scatterv_k, allgatherv_k, sched); } @@ -51,7 +53,8 @@ static int MPIR_Ibcast_sched_intra_tsp_flat_auto(void *buffer, MPI_Aint count, /* sched version of CVAR and json based collective selection. Meant only for gentran scheduler */ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -72,7 +75,7 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat case MPIR_CVAR_IBCAST_INTRA_ALGORITHM_tsp_tree: mpi_errno = MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, - MPIR_Ibcast_tree_type, + coll_group, MPIR_Ibcast_tree_type, MPIR_CVAR_IBCAST_TREE_KVAL, MPIR_CVAR_IBCAST_TREE_PIPELINE_CHUNK_SIZE, sched); break; @@ -80,7 +83,7 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat case MPIR_CVAR_IBCAST_INTRA_ALGORITHM_tsp_scatterv_recexch_allgatherv: mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(buffer, count, datatype, - root, comm_ptr, + root, comm_ptr, coll_group, MPIR_CVAR_IALLGATHERV_INTRA_ALGORITHM_tsp_recexch_doubling, MPIR_CVAR_IBCAST_SCATTERV_KVAL, MPIR_CVAR_IBCAST_ALLGATHERV_RECEXCH_KVAL, @@ -90,13 +93,14 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat case MPIR_CVAR_IBCAST_INTRA_ALGORITHM_tsp_scatterv_ring_allgatherv: 
mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_ring_allgatherv(buffer, count, datatype, - root, comm_ptr, 1, sched); + root, comm_ptr, coll_group, 1, + sched); break; case MPIR_CVAR_IBCAST_INTRA_ALGORITHM_tsp_ring: mpi_errno = MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, - MPIR_TREE_TYPE_KARY, 1, + coll_group, MPIR_TREE_TYPE_KARY, 1, MPIR_CVAR_IBCAST_RING_CHUNK_SIZE, sched); break; @@ -108,6 +112,7 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ibcast_intra_tsp_tree: mpi_errno = MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, + coll_group, cnt->u.ibcast.intra_tsp_tree.tree_type, cnt->u.ibcast.intra_tsp_tree.k, cnt->u.ibcast.intra_tsp_tree.chunk_size, @@ -116,7 +121,7 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ibcast_intra_tsp_scatterv_recexch_allgatherv: mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(buffer, count, datatype, - root, comm_ptr, + root, comm_ptr, coll_group, MPIR_CVAR_IALLGATHERV_INTRA_ALGORITHM_tsp_recexch_doubling, cnt->u. 
ibcast.intra_tsp_scatterv_recexch_allgatherv.scatterv_k, @@ -129,13 +134,14 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_ring_allgatherv(buffer, count, datatype, root, - comm_ptr, 1, sched); + comm_ptr, coll_group, + 1, sched); break; case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ibcast_intra_tsp_ring: mpi_errno = MPIR_TSP_Ibcast_sched_intra_tree(buffer, count, datatype, root, comm_ptr, - MPIR_TREE_TYPE_KARY, 1, + coll_group, MPIR_TREE_TYPE_KARY, 1, cnt->u.ibcast.intra_tsp_tree.chunk_size, sched); break; @@ -150,7 +156,7 @@ int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datat fallback: mpi_errno = MPIR_Ibcast_sched_intra_tsp_flat_auto(buffer, count, datatype, root, - comm_ptr, sched); + comm_ptr, coll_group, sched); fn_exit: return mpi_errno; diff --git a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c index f7601720547..58c9eba80d5 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c @@ -10,9 +10,9 @@ /* Routine to schedule a scatter followed by recursive exchange based broadcast */ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, int allgatherv_algo, - int scatterv_k, int allgatherv_k, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, + int allgatherv_algo, int scatterv_k, + int allgatherv_k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; size_t extent, type_size; @@ -188,13 +188,13 @@ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count /* Schedule Allgatherv ring */ mpi_errno = MPIR_TSP_Iallgatherv_sched_intra_ring(MPI_IN_PLACE, cnts[rank], MPI_BYTE, tmp_buf, - cnts, displs, MPI_BYTE, comm, sched); + cnts, displs, MPI_BYTE, comm, coll_group, sched); else /* Schedule Allgatherv recexch 
*/ mpi_errno = MPIR_TSP_Iallgatherv_sched_intra_recexch(MPI_IN_PLACE, cnts[rank], MPI_BYTE, tmp_buf, - cnts, displs, MPI_BYTE, comm, 0, allgatherv_k, - sched); + cnts, displs, MPI_BYTE, comm, coll_group, 0, + allgatherv_k, sched); MPIR_ERR_CHECK(mpi_errno); if (!is_contig) { diff --git a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_ring_allgatherv.c b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_ring_allgatherv.c index 25016705afa..1ea3e9f0f50 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_ring_allgatherv.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_ring_allgatherv.c @@ -9,15 +9,15 @@ /* Routine to schedule a scatter followed by ring based broadcast */ int MPIR_TSP_Ibcast_sched_intra_scatterv_ring_allgatherv(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, int scatterv_k, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, + int scatterv_k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(buffer, count, datatype, root, - comm, + comm, coll_group, MPIR_CVAR_IALLGATHERV_INTRA_ALGORITHM_tsp_ring, scatterv_k, 0, sched); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ibcast/ibcast_tsp_tree.c b/src/mpi/coll/ibcast/ibcast_tsp_tree.c index f308714e294..b56b3b76cdf 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_tree.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_tree.c @@ -10,8 +10,8 @@ /* Routine to schedule a pipelined tree based broadcast */ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, int tree_type, int k, int chunk_size, - MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int tree_type, int k, + int chunk_size, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i; diff --git a/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c b/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c index 62148d159e4..10ff8466e68 100644 --- 
a/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c @@ -50,7 +50,8 @@ */ int MPIR_Iexscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size; diff --git a/src/mpi/coll/igather/igather_inter_sched_long.c b/src/mpi/coll/igather/igather_inter_sched_long.c index fade46c3594..fe38bc5694e 100644 --- a/src/mpi/coll/igather/igather_inter_sched_long.c +++ b/src/mpi/coll/igather/igather_inter_sched_long.c @@ -14,7 +14,7 @@ */ int MPIR_Igather_inter_sched_long(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; MPI_Aint remote_size; diff --git a/src/mpi/coll/igather/igather_inter_sched_short.c b/src/mpi/coll/igather/igather_inter_sched_short.c index 81dc2bc2ddf..e2d684b827e 100644 --- a/src/mpi/coll/igather/igather_inter_sched_short.c +++ b/src/mpi/coll/igather/igather_inter_sched_short.c @@ -15,7 +15,7 @@ */ int MPIR_Igather_inter_sched_short(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank; @@ -61,7 +61,7 @@ int MPIR_Igather_inter_sched_short(const void *sendbuf, MPI_Aint sendcount, MPI_ /* now do the a local gather on this intracommunicator */ mpi_errno = MPIR_Igather_intra_sched_auto(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz, MPI_BYTE, 0, - newcomm_ptr, s); + newcomm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); if (rank 
== 0) { diff --git a/src/mpi/coll/igather/igather_intra_sched_binomial.c b/src/mpi/coll/igather/igather_intra_sched_binomial.c index 0d15129bba7..d0a77292f85 100644 --- a/src/mpi/coll/igather/igather_intra_sched_binomial.c +++ b/src/mpi/coll/igather/igather_intra_sched_binomial.c @@ -28,7 +28,7 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank; @@ -93,15 +93,14 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, if (rank == root) { if (sendbuf != MPI_IN_PLACE) { mpi_errno = MPIR_Sched_copy(sendbuf, sendcount, sendtype, - ((char *) recvbuf + extent * recvcount * rank), + ((char *) recvbuf + extent * recvcount * rank), recvcount, recvtype, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } } else if (tmp_buf_size && (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE)) { /* copy from sendbuf into tmp_buf */ - mpi_errno = MPIR_Sched_copy(sendbuf, sendcount, sendtype, - tmp_buf, nbytes, MPI_BYTE, s); + mpi_errno = MPIR_Sched_copy(sendbuf, sendcount, sendtype, tmp_buf, nbytes, MPI_BYTE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/igather/igather_tsp_tree.c b/src/mpi/coll/igather/igather_tsp_tree.c index 45b4e6517ef..da395125e68 100644 --- a/src/mpi/coll/igather/igather_tsp_tree.c +++ b/src/mpi/coll/igather/igather_tsp_tree.c @@ -11,7 +11,7 @@ int MPIR_TSP_Igather_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - int k, MPIR_TSP_sched_t sched) + int coll_group, int k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int size, rank, lrank; diff --git a/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c 
b/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c index e2f6cf21dc5..b0154a27b28 100644 --- a/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c +++ b/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c @@ -15,7 +15,7 @@ int MPIR_Igatherv_allcomm_sched_linear(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; diff --git a/src/mpi/coll/igatherv/igatherv_tsp_linear.c b/src/mpi/coll/igatherv/igatherv_tsp_linear.c index 932efd3414d..cd4e6e6b6ac 100644 --- a/src/mpi/coll/igatherv/igatherv_tsp_linear.c +++ b/src/mpi/coll/igatherv/igatherv_tsp_linear.c @@ -21,7 +21,7 @@ int MPIR_TSP_Igatherv_sched_allcomm_linear(const void *sendbuf, MPI_Aint sendcou MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, vtx_id; diff --git a/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c b/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c index 683e18f3e6d..c0a6711a555 100644 --- a/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c +++ b/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c @@ -14,7 +14,7 @@ int MPIR_Ireduce_inter_sched_local_reduce_remote_send(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank; @@ -57,7 +57,7 @@ int MPIR_Ireduce_inter_sched_local_reduce_remote_send(const void *sendbuf, void mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, tmp_buf, count, datatype, op, 0, - comm_ptr->local_comm, s); + 
comm_ptr->local_comm, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c b/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c index 5ed414d2c09..31193f2d00b 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c @@ -7,7 +7,7 @@ int MPIR_Ireduce_intra_sched_binomial(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size, rank, is_commutative; diff --git a/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c b/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c index d6bc57637fd..9e997ae2cb5 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c @@ -34,7 +34,8 @@ */ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i, j, comm_size, rank, pof2, is_commutative ATTRIBUTE((unused)); diff --git a/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c b/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c index a1a8a63c72f..adb3e10868d 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c @@ -7,7 +7,7 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int is_commutative; @@ -26,7 +26,8 @@ int MPIR_Ireduce_intra_sched_smp(const void 
*sendbuf, void *recvbuf, MPI_Aint co is_commutative = MPIR_Op_is_commutative(op); if (!is_commutative) { mpi_errno = - MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, s); + MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -44,7 +45,9 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co /* do the intranode reduce on all nodes other than the root's node */ if (nc != NULL && MPIR_Get_intranode_rank(comm_ptr, root) == -1) { - mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, tmp_buf, count, datatype, op, 0, nc, s); + mpi_errno = + MPIR_Ireduce_intra_sched_auto(sendbuf, tmp_buf, count, datatype, op, 0, nc, + MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -57,7 +60,7 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co const void *buf = (nc == NULL ? sendbuf : tmp_buf); mpi_errno = MPIR_Ireduce_intra_sched_auto(buf, NULL, count, datatype, op, MPIR_Get_internode_rank(comm_ptr, root), - nrc, s); + nrc, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else { /* I am on root's node. I have not participated in the earlier reduce. 
*/ @@ -68,7 +71,7 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, tmp_buf, count, datatype, op, MPIR_Get_internode_rank(comm_ptr, root), nrc, - s); + MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -80,7 +83,7 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, MPIR_Get_internode_rank(comm_ptr, root), nrc, - s); + MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -94,7 +97,7 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co if (nc != NULL && MPIR_Get_intranode_rank(comm_ptr, root) != -1) { mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, MPIR_Get_intranode_rank(comm_ptr, root), nc, - s); + MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ireduce/ireduce_tsp_auto.c b/src/mpi/coll/ireduce/ireduce_tsp_auto.c index 1666e2dd029..3d558be3dab 100644 --- a/src/mpi/coll/ireduce/ireduce_tsp_auto.c +++ b/src/mpi/coll/ireduce/ireduce_tsp_auto.c @@ -11,7 +11,7 @@ /* Remove this function when gentran algos are in json file */ static int MPIR_Ireduce_sched_intra_tsp_flat_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -25,7 +25,7 @@ static int MPIR_Ireduce_sched_intra_tsp_flat_auto(const void *sendbuf, void *rec * gentran_tree algo */ /* gentran tree with knomial tree type, radix 2 and pipeline block size 0 */ mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, count, - datatype, op, root, comm_ptr, + datatype, op, root, comm_ptr, coll_group, tree_type, radix, block_size, buffer_per_child, sched); if (mpi_errno) @@ 
-40,7 +40,8 @@ static int MPIR_Ireduce_sched_intra_tsp_flat_auto(const void *sendbuf, void *rec /* sched version of CVAR and json based collective selection. Meant only for gentran scheduler */ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; @@ -67,7 +68,7 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP mpi_errno, "Ireduce gentran_tree cannot be applied.\n"); mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, root, - comm_ptr, MPIR_Ireduce_tree_type, + comm_ptr, coll_group, MPIR_Ireduce_tree_type, MPIR_CVAR_IREDUCE_TREE_KVAL, MPIR_CVAR_IREDUCE_TREE_PIPELINE_CHUNK_SIZE, MPIR_CVAR_IREDUCE_TREE_BUFFER_PER_CHILD, sched); @@ -76,7 +77,7 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP case MPIR_CVAR_IREDUCE_INTRA_ALGORITHM_tsp_ring: mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, root, - comm_ptr, MPIR_TREE_TYPE_KARY, 1, + comm_ptr, coll_group, MPIR_TREE_TYPE_KARY, 1, MPIR_CVAR_IREDUCE_RING_CHUNK_SIZE, MPIR_CVAR_IREDUCE_TREE_BUFFER_PER_CHILD, sched); break; @@ -89,7 +90,7 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ireduce_intra_tsp_tree: mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, - root, comm_ptr, + root, comm_ptr, coll_group, cnt->u.ireduce.intra_tsp_tree.tree_type, cnt->u.ireduce.intra_tsp_tree.k, cnt->u.ireduce.intra_tsp_tree.chunk_size, @@ -100,7 +101,8 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Ireduce_intra_tsp_ring: mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, 
count, datatype, op, - root, comm_ptr, MPIR_TREE_TYPE_KARY, 1, + root, comm_ptr, coll_group, + MPIR_TREE_TYPE_KARY, 1, cnt->u.ireduce.intra_tsp_ring.chunk_size, cnt->u.ireduce. intra_tsp_ring.buffer_per_child, sched); @@ -110,7 +112,8 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP /* Replace this call with MPIR_Assert(0) when json files have gentran algos */ mpi_errno = MPIR_Ireduce_sched_intra_tsp_flat_auto(sendbuf, recvbuf, count, - datatype, op, root, comm_ptr, sched); + datatype, op, root, comm_ptr, + coll_group, sched); break; } } @@ -120,7 +123,7 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP fallback: mpi_errno = MPIR_TSP_Ireduce_sched_intra_tree(sendbuf, recvbuf, count, datatype, op, root, - comm_ptr, MPIR_TREE_TYPE_KARY, 1, + comm_ptr, coll_group, MPIR_TREE_TYPE_KARY, 1, MPIR_CVAR_IREDUCE_RING_CHUNK_SIZE, MPIR_CVAR_IREDUCE_TREE_BUFFER_PER_CHILD, sched); diff --git a/src/mpi/coll/ireduce/ireduce_tsp_tree.c b/src/mpi/coll/ireduce/ireduce_tsp_tree.c index 0ce1f25852a..a2bb817bdc0 100644 --- a/src/mpi/coll/ireduce/ireduce_tsp_tree.c +++ b/src/mpi/coll/ireduce/ireduce_tsp_tree.c @@ -10,8 +10,8 @@ /* Routine to schedule a pipelined tree based reduce */ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, int tree_type, int k, int chunk_size, - int buffer_per_child, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, int tree_type, int k, + int chunk_size, int buffer_per_child, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int i, j, t; diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_inter_sched_remote_reduce_local_scatterv.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_inter_sched_remote_reduce_local_scatterv.c index f015ae31d5c..41c4d159d67 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_inter_sched_remote_reduce_local_scatterv.c +++ 
b/src/mpi/coll/ireduce_scatter/ireduce_scatter_inter_sched_remote_reduce_local_scatterv.c @@ -17,7 +17,7 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, root, local_size, total_count, i; @@ -62,7 +62,7 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sched barrier intentionally omitted here to allow both reductions to @@ -71,13 +71,13 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sched barrier intentionally omitted here to allow both reductions to @@ -86,7 +86,7 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se /* reduce from right group to rank 0 */ root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); @@ -101,7 +101,7 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se mpi_errno = MPIR_Iscatterv_intra_sched_auto(tmp_buf, recvcounts, disps, datatype, recvbuf, recvcounts[rank], datatype, 0, newcomm_ptr, - s); + coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c index 863467231c9..2ecf91f004e 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c @@ -23,7 +23,8 @@ int MPIR_Ireduce_scatter_intra_sched_noncommutative(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size = comm_ptr->local_size; diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c index 7b720f4951f..b9686961f5f 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c @@ -14,7 +14,8 @@ */ int MPIR_Ireduce_scatter_intra_sched_pairwise(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c 
b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c index 6b05fd0195a..aed3af7f144 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c @@ -19,7 +19,8 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c index 2043b96686c..93da7604868 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c @@ -36,7 +36,8 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c index c1e3bbc582b..304462eede5 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c @@ -41,11 +41,11 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void * const MPI_Aint * recvcounts, MPI_Aint * displs, MPI_Datatype datatype, MPI_Op op, size_t extent, int tag, - MPIR_Comm * comm, int k, int is_dist_halving, - int step2_nphases, int **step2_nbrs, - int rank, int nranks, int sink_id, - int is_out_vtcs, int *reduce_id_, - MPIR_TSP_sched_t 
sched) + MPIR_Comm * comm, int coll_group, int k, + int is_dist_halving, int step2_nphases, + int **step2_nbrs, int rank, int nranks, + int sink_id, int is_out_vtcs, + int *reduce_id_, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int x, i, j, phase; @@ -132,8 +132,8 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void * /* Routine to schedule a recursive exchange based reduce_scatter with distance halving in each phase */ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, int is_dist_halving, - int k, MPIR_TSP_sched_t sched) + MPI_Op op, MPIR_Comm * comm, int coll_group, + int is_dist_halving, int k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_inplace; @@ -246,9 +246,10 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv if (in_step2) { MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(tmp_results, tmp_recvbuf, recvcounts, displs, datatype, op, extent, - tag, comm, k, is_dist_halving, - step2_nphases, step2_nbrs, rank, nranks, - sink_id, 1, &reduce_id, sched); + tag, comm, coll_group, k, + is_dist_halving, step2_nphases, + step2_nbrs, rank, nranks, sink_id, 1, + &reduce_id, sched); /* copy data from tmp_results buffer correct position into recvbuf for all participating ranks */ nvtcs = 1; vtcs[0] = reduce_id; /* This assignment will also be used in step3 sends */ diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv.c index 793db9a8bd8..8c9629fe642 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv.c @@ -18,6 +18,7 @@ int 
MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(const vo MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; @@ -51,7 +52,7 @@ int MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(const vo /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sched barrier intentionally omitted here to allow both reductions to @@ -60,13 +61,13 @@ int MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(const vo /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sched barrier intentionally omitted here to allow both reductions to @@ -75,7 +76,7 @@ int MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(const vo /* reduce from right group to rank 0 */ root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Ireduce_inter_sched_auto(sendbuf, tmp_buf, total_count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); @@ -89,7 +90,8 @@ int MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(const vo newcomm_ptr = comm_ptr->local_comm; mpi_errno = MPIR_Iscatter_intra_sched_auto(tmp_buf, recvcount, datatype, - recvbuf, recvcount, datatype, 0, newcomm_ptr, s); + recvbuf, recvcount, datatype, 0, newcomm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c index 910c04314fe..afd6c96691e 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c @@ -12,7 +12,7 @@ int MPIR_Ireduce_scatter_block_intra_sched_noncommutative(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size = comm_ptr->local_size; diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c index b5091c80817..7c24f889773 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c @@ -9,7 +9,8 @@ * commutative op and is intended for use with large messages. 
*/ int MPIR_Ireduce_scatter_block_intra_sched_pairwise(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c index 071682730f4..0786592616b 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c @@ -10,7 +10,8 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c index 2a092a14721..2ce35875de5 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c @@ -10,7 +10,8 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, i; diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c index d0c2839ec81..c5cbf8c2a83 100644 --- 
a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c @@ -10,8 +10,8 @@ /* Routine to schedule a recursive exchange based reduce_scatter with distance halving in each phase */ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, int k, - MPIR_TSP_sched_t sched) + MPI_Op op, MPIR_Comm * comm, int coll_group, + int k, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; int is_inplace; diff --git a/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c b/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c index 4a914802380..885f1b88a5a 100644 --- a/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c @@ -7,7 +7,7 @@ int MPIR_Iscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; MPI_Aint true_extent, true_lb, extent; diff --git a/src/mpi/coll/iscan/iscan_intra_sched_smp.c b/src/mpi/coll/iscan/iscan_intra_sched_smp.c index 3cd74b72ed6..79b71954349 100644 --- a/src/mpi/coll/iscan/iscan_intra_sched_smp.c +++ b/src/mpi/coll/iscan/iscan_intra_sched_smp.c @@ -8,7 +8,7 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank = comm_ptr->rank; @@ -28,7 +28,7 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun /* We can't use the SMP-aware algorithm, use the non-SMP-aware * one */ return MPIR_Iscan_intra_sched_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm_ptr, s); + comm_ptr, 
coll_group, s); } node_comm = comm_ptr->node_comm; @@ -58,7 +58,8 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun * one process, just copy the raw data. */ if (node_comm != NULL) { mpi_errno = - MPIR_Iscan_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, node_comm, s); + MPIR_Iscan_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, node_comm, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else if (sendbuf != MPI_IN_PLACE) { @@ -96,7 +97,7 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun mpi_errno = MPIR_Iscan_intra_sched_auto(localfulldata, prefulldata, count, datatype, op, roots_comm, - s); + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -124,7 +125,8 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun * "prefulldata" from another leader into "tempbuf" */ if (node_comm != NULL) { - mpi_errno = MPIR_Ibcast_intra_sched_auto(tempbuf, count, datatype, 0, node_comm, s); + mpi_errno = + MPIR_Ibcast_intra_sched_auto(tempbuf, count, datatype, 0, node_comm, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c b/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c index b6010400be0..eae2bed1fcf 100644 --- a/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c +++ b/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c @@ -9,7 +9,8 @@ /* Routine to schedule a recursive exchange based scan */ int MPIR_TSP_Iscan_sched_intra_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_TSP_sched_t sched) + MPIR_Comm * comm, int coll_group, + MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; MPI_Aint extent, true_extent; diff --git a/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c b/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c index 93a564d543d..32f6c4b920b 100644 --- 
a/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c +++ b/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c @@ -15,7 +15,7 @@ int MPIR_Iscatter_inter_sched_linear(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int remote_size; diff --git a/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c b/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c index 1d0f7a3439d..b12e2db5367 100644 --- a/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c +++ b/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c @@ -18,7 +18,7 @@ int MPIR_Iscatter_inter_sched_remote_send_local_scatter(const void *sendbuf, MPI MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, local_size, remote_size; @@ -70,7 +70,8 @@ int MPIR_Iscatter_inter_sched_remote_send_local_scatter(const void *sendbuf, MPI /* now do the usual scatter on this intracommunicator */ mpi_errno = MPIR_Iscatter_intra_sched_auto(tmp_buf, recvcount * recvtype_sz, MPI_BYTE, - recvbuf, recvcount, recvtype, 0, newcomm_ptr, s); + recvbuf, recvcount, recvtype, 0, newcomm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c b/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c index 49cc7fc0159..0bd9bb1820d 100644 --- a/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c +++ b/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c @@ -66,7 +66,7 @@ static int calc_curr_count(MPIR_Comm * comm, int tag, void *state) int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, 
MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; MPI_Aint extent = 0; diff --git a/src/mpi/coll/iscatter/iscatter_tsp_tree.c b/src/mpi/coll/iscatter/iscatter_tsp_tree.c index b92ed46017d..a3955317444 100644 --- a/src/mpi/coll/iscatter/iscatter_tsp_tree.c +++ b/src/mpi/coll/iscatter/iscatter_tsp_tree.c @@ -10,7 +10,7 @@ int MPIR_TSP_Iscatter_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - int k, MPIR_TSP_sched_t sched) + int coll_group, int k, MPIR_TSP_sched_t sched) { MPIR_FUNC_ENTER; diff --git a/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c b/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c index 0257a6df1ec..3ed60e0f3b1 100644 --- a/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c +++ b/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c @@ -18,7 +18,8 @@ int MPIR_Iscatterv_allcomm_sched_linear(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size; diff --git a/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c b/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c index dbc0669299a..8313d41eb78 100644 --- a/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c +++ b/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c @@ -11,7 +11,7 @@ int MPIR_TSP_Iscatterv_sched_allcomm_linear(const void *sendbuf, const MPI_Aint const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm_ptr, - MPIR_TSP_sched_t sched) + int coll_group, MPIR_TSP_sched_t sched) { int mpi_errno 
= MPI_SUCCESS; int rank, comm_size; diff --git a/src/mpi/coll/mpir_coll_sched_auto.c b/src/mpi/coll/mpir_coll_sched_auto.c index 2b2e4e17daa..0060a10d38e 100644 --- a/src/mpi/coll/mpir_coll_sched_auto.c +++ b/src/mpi/coll/mpir_coll_sched_auto.c @@ -12,28 +12,28 @@ * defining them here. */ -int MPIR_Ibarrier_intra_sched_auto(MPIR_Comm * comm_ptr, MPIR_Sched_t s) +int MPIR_Ibarrier_intra_sched_auto(MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Ibarrier_intra_sched_recursive_doubling(comm_ptr, s); + mpi_errno = MPIR_Ibarrier_intra_sched_recursive_doubling(comm_ptr, coll_group, s); return mpi_errno; } /* It will choose between several different algorithms based on the given * parameters. */ -int MPIR_Ibarrier_inter_sched_auto(MPIR_Comm * comm_ptr, MPIR_Sched_t s) +int MPIR_Ibarrier_inter_sched_auto(MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno; - mpi_errno = MPIR_Ibarrier_inter_sched_bcast(comm_ptr, s); + mpi_errno = MPIR_Ibarrier_inter_sched_bcast(comm_ptr, coll_group, s); return mpi_errno; } int MPIR_Ibcast_intra_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size; @@ -42,7 +42,8 @@ int MPIR_Ibcast_intra_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype data MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); if (comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT) { - mpi_errno = MPIR_Ibcast_intra_sched_smp(buffer, count, datatype, root, comm_ptr, s); + mpi_errno = + MPIR_Ibcast_intra_sched_smp(buffer, count, datatype, root, comm_ptr, coll_group, s); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -55,7 +56,9 @@ int MPIR_Ibcast_intra_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype data /* simplistic implementation for now */ if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (comm_size < 
MPIR_CVAR_BCAST_MIN_PROCS)) { - mpi_errno = MPIR_Ibcast_intra_sched_binomial(buffer, count, datatype, root, comm_ptr, s); + mpi_errno = + MPIR_Ibcast_intra_sched_binomial(buffer, count, datatype, root, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } else { /* (nbytes >= MPIR_CVAR_BCAST_SHORT_MSG_SIZE) && (comm_size >= MPIR_CVAR_BCAST_MIN_PROCS) */ @@ -63,12 +66,13 @@ int MPIR_Ibcast_intra_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype data mpi_errno = MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(buffer, count, datatype, root, - comm_ptr, s); + comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIR_Ibcast_intra_sched_scatter_ring_allgather(buffer, count, datatype, root, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } @@ -83,24 +87,25 @@ int MPIR_Ibcast_intra_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype data * know anything about hierarchy. It will choose between several * different algorithms based on the given parameters. 
*/ int MPIR_Ibcast_inter_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Ibcast_inter_sched_flat(buffer, count, datatype, root, comm_ptr, s); + mpi_errno = + MPIR_Ibcast_inter_sched_flat(buffer, count, datatype, root, comm_ptr, coll_group, s); return mpi_errno; } int MPIR_Igather_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Igather_intra_sched_binomial(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -111,7 +116,7 @@ int MPIR_Igather_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_D int MPIR_Igather_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; MPI_Aint local_size, remote_size; @@ -137,11 +142,11 @@ int MPIR_Igather_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_D if (nbytes < MPIR_CVAR_GATHER_INTER_SHORT_MSG_SIZE) { mpi_errno = MPIR_Igather_inter_sched_short(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); } else { mpi_errno = MPIR_Igather_inter_sched_long(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); } fn_exit: @@ -151,13 +156,13 @@ int MPIR_Igather_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, 
MPI_D int MPIR_Igatherv_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Igatherv_allcomm_sched_linear(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm_ptr, s); + displs, recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -170,13 +175,13 @@ int MPIR_Igatherv_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_ int MPIR_Igatherv_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Igatherv_allcomm_sched_linear(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm_ptr, s); + displs, recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -188,13 +193,13 @@ int MPIR_Igatherv_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_ int MPIR_Iscatter_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iscatter_intra_sched_binomial(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm_ptr, s); + recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); @@ -206,7 +211,7 @@ int MPIR_Iscatter_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_ int MPIR_Iscatter_inter_sched_auto(const void *sendbuf, MPI_Aint 
sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int local_size, remote_size; MPI_Aint sendtype_size, recvtype_size, nbytes; @@ -227,11 +232,11 @@ int MPIR_Iscatter_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_ mpi_errno = MPIR_Iscatter_inter_sched_remote_send_local_scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, - comm_ptr, s); + comm_ptr, coll_group, s); } else { mpi_errno = MPIR_Iscatter_inter_sched_linear(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, - s); + coll_group, s); } MPIR_ERR_CHECK(mpi_errno); @@ -245,13 +250,13 @@ int MPIR_Iscatter_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_ int MPIR_Iscatterv_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iscatterv_allcomm_sched_linear(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, s); + recvcount, recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -264,13 +269,13 @@ int MPIR_Iscatterv_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcoun int MPIR_Iscatterv_inter_sched_auto(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iscatterv_allcomm_sched_linear(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, s); + recvcount, 
recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -282,7 +287,7 @@ int MPIR_Iscatterv_inter_sched_auto(const void *sendbuf, const MPI_Aint sendcoun int MPIR_Iallgather_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size; @@ -296,15 +301,16 @@ int MPIR_Iallgather_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MP if ((tot_bytes < MPIR_CVAR_ALLGATHER_LONG_MSG_SIZE) && !(comm_size & (comm_size - 1))) { mpi_errno = MPIR_Iallgather_intra_sched_recursive_doubling(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, s); + recvcount, recvtype, comm_ptr, + coll_group, s); } else if (tot_bytes < MPIR_CVAR_ALLGATHER_SHORT_MSG_SIZE) { mpi_errno = MPIR_Iallgather_intra_sched_brucks(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, coll_group, s); } else { mpi_errno = MPIR_Iallgather_intra_sched_ring(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, coll_group, s); } MPIR_ERR_CHECK(mpi_errno); @@ -316,13 +322,15 @@ int MPIR_Iallgather_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MP int MPIR_Iallgather_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallgather_inter_sched_local_gather_remote_bcast(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, + coll_group, s); return mpi_errno; } @@ -330,7 +338,8 @@ int MPIR_Iallgather_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, int 
MPIR_Iallgatherv_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i, comm_size; @@ -353,21 +362,21 @@ int MPIR_Iallgatherv_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, mpi_errno = MPIR_Iallgatherv_intra_sched_recursive_doubling(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm_ptr, - s); + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else if (total_count * recvtype_size < MPIR_CVAR_ALLGATHER_SHORT_MSG_SIZE) { /* Short message and non-power-of-two no. of processes. Use * Bruck algorithm (see description above). */ mpi_errno = MPIR_Iallgatherv_intra_sched_brucks(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm_ptr, s); + displs, recvtype, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* long message or medium-size message and non-power-of-two * no. of processes. Use ring algorithm. 
*/ mpi_errno = MPIR_Iallgatherv_intra_sched_ring(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm_ptr, s); + displs, recvtype, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -380,21 +389,22 @@ int MPIR_Iallgatherv_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, int MPIR_Iallgatherv_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallgatherv_inter_sched_remote_gather_local_bcast(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, - comm_ptr, s); + comm_ptr, coll_group, s); return mpi_errno; } int MPIR_Ialltoall_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int comm_size; @@ -408,19 +418,20 @@ int MPIR_Ialltoall_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI if (sendbuf == MPI_IN_PLACE) { mpi_errno = MPIR_Ialltoall_intra_sched_inplace(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, coll_group, s); } else if ((nbytes <= MPIR_CVAR_ALLTOALL_SHORT_MSG_SIZE) && (comm_size >= 8)) { mpi_errno = MPIR_Ialltoall_intra_sched_brucks(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, coll_group, s); } else if (nbytes <= MPIR_CVAR_ALLTOALL_MEDIUM_MSG_SIZE) { mpi_errno = MPIR_Ialltoall_intra_sched_permuted_sendrecv(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, s); + recvcount, recvtype, comm_ptr, coll_group, + s); } else { mpi_errno = MPIR_Ialltoall_intra_sched_pairwise(sendbuf, 
sendcount, sendtype, recvbuf, recvcount, - recvtype, comm_ptr, s); + recvtype, comm_ptr, coll_group, s); } MPIR_ERR_CHECK(mpi_errno); @@ -433,13 +444,14 @@ int MPIR_Ialltoall_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI int MPIR_Ialltoall_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoall_inter_sched_pairwise_exchange(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, s); + comm_ptr, coll_group, s); return mpi_errno; } @@ -447,7 +459,8 @@ int MPIR_Ialltoall_inter_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI int MPIR_Ialltoallv_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; @@ -456,11 +469,11 @@ int MPIR_Ialltoallv_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcou if (sendbuf == MPI_IN_PLACE) { mpi_errno = MPIR_Ialltoallv_intra_sched_inplace(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm_ptr, s); + rdispls, recvtype, comm_ptr, coll_group, s); } else { mpi_errno = MPIR_Ialltoallv_intra_sched_blocked(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm_ptr, s); + rdispls, recvtype, comm_ptr, coll_group, s); } return mpi_errno; @@ -469,13 +482,15 @@ int MPIR_Ialltoallv_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcou int MPIR_Ialltoallv_inter_sched_auto(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], MPI_Datatype sendtype, void 
*recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], - MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoallv_inter_sched_pairwise_exchange(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm_ptr, s); + rdispls, recvtype, comm_ptr, + coll_group, s); return mpi_errno; } @@ -484,18 +499,20 @@ int MPIR_Ialltoallw_intra_sched_auto(const void *sendbuf, const MPI_Aint sendcou const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; if (sendbuf == MPI_IN_PLACE) { mpi_errno = MPIR_Ialltoallw_intra_sched_inplace(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm_ptr, s); + rdispls, recvtypes, comm_ptr, coll_group, + s); } else { mpi_errno = MPIR_Ialltoallw_intra_sched_blocked(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm_ptr, s); + rdispls, recvtypes, comm_ptr, coll_group, + s); } return mpi_errno; @@ -505,20 +522,21 @@ int MPIR_Ialltoallw_inter_sched_auto(const void *sendbuf, const MPI_Aint sendcou const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoallw_inter_sched_pairwise_exchange(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm_ptr, s); + rdispls, recvtypes, comm_ptr, + coll_group, s); return mpi_errno; } int MPIR_Ireduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint 
count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int pof2; @@ -528,7 +546,7 @@ int MPIR_Ireduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint c if (comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT && MPIR_Op_is_commutative(op)) { mpi_errno = MPIR_Ireduce_intra_sched_smp(sendbuf, recvbuf, count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, coll_group, s); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -545,13 +563,13 @@ int MPIR_Ireduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint c /* do a reduce-scatter followed by gather to root. */ mpi_errno = MPIR_Ireduce_intra_sched_reduce_scatter_gather(sendbuf, recvbuf, count, datatype, op, - root, comm_ptr, s); + root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* use a binomial tree algorithm */ mpi_errno = MPIR_Ireduce_intra_sched_binomial(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, - s); + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -563,19 +581,20 @@ int MPIR_Ireduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint c int MPIR_Ireduce_inter_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_inter_sched_local_reduce_remote_send(sendbuf, recvbuf, count, - datatype, op, root, comm_ptr, s); + datatype, op, root, comm_ptr, + coll_group, s); return mpi_errno; } int MPIR_Iallreduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int pof2; @@ -584,7 +603,8 @@ int MPIR_Iallreduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Ain if 
(comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT && MPIR_Op_is_commutative(op)) { mpi_errno = - MPIR_Iallreduce_intra_sched_smp(sendbuf, recvbuf, count, datatype, op, comm_ptr, s); + MPIR_Iallreduce_intra_sched_smp(sendbuf, recvbuf, count, datatype, op, comm_ptr, + coll_group, s); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -611,13 +631,13 @@ int MPIR_Iallreduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Ain /* use recursive doubling */ mpi_errno = MPIR_Iallreduce_intra_sched_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* do a reduce-scatter followed by allgather */ mpi_errno = MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(sendbuf, recvbuf, count, datatype, - op, comm_ptr, s); + op, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } @@ -629,19 +649,21 @@ int MPIR_Iallreduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Ain int MPIR_Iallreduce_inter_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallreduce_inter_sched_remote_reduce_local_bcast(sendbuf, recvbuf, count, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, + coll_group, s); return mpi_errno; } int MPIR_Ireduce_scatter_intra_sched_auto(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int i; @@ -666,12 +688,13 @@ int MPIR_Ireduce_scatter_intra_sched_auto(const void *sendbuf, void *recvbuf, if (is_commutative && (nbytes < MPIR_CVAR_REDUCE_SCATTER_COMMUTATIVE_LONG_MSG_SIZE)) { mpi_errno = MPIR_Ireduce_scatter_intra_sched_recursive_halving(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, s); + datatype, 
op, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } else if (is_commutative && (nbytes >= MPIR_CVAR_REDUCE_SCATTER_COMMUTATIVE_LONG_MSG_SIZE)) { mpi_errno = MPIR_Ireduce_scatter_intra_sched_pairwise(sendbuf, recvbuf, recvcounts, datatype, op, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* (!is_commutative) */ @@ -687,13 +710,15 @@ int MPIR_Ireduce_scatter_intra_sched_auto(const void *sendbuf, void *recvbuf, /* noncommutative, pof2 size, and block regular */ mpi_errno = MPIR_Ireduce_scatter_intra_sched_noncommutative(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } else { /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */ mpi_errno = MPIR_Ireduce_scatter_intra_sched_recursive_doubling(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } @@ -706,20 +731,23 @@ int MPIR_Ireduce_scatter_intra_sched_auto(const void *sendbuf, void *recvbuf, int MPIR_Ireduce_scatter_inter_sched_auto(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, + coll_group, s); return mpi_errno; } int MPIR_Ireduce_scatter_block_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int is_commutative; @@ -740,12 +768,13 @@ int MPIR_Ireduce_scatter_block_intra_sched_auto(const void *sendbuf, void *recvb if (is_commutative && (nbytes < 
MPIR_CVAR_REDUCE_SCATTER_COMMUTATIVE_LONG_MSG_SIZE)) { mpi_errno = MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else if (is_commutative && (nbytes >= MPIR_CVAR_REDUCE_SCATTER_COMMUTATIVE_LONG_MSG_SIZE)) { mpi_errno = MPIR_Ireduce_scatter_block_intra_sched_pairwise(sendbuf, recvbuf, recvcount, datatype, - op, comm_ptr, s); + op, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* (!is_commutative) */ @@ -753,14 +782,15 @@ int MPIR_Ireduce_scatter_block_intra_sched_auto(const void *sendbuf, void *recvb /* noncommutative, pof2 size */ mpi_errno = MPIR_Ireduce_scatter_block_intra_sched_noncommutative(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, s); + datatype, op, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* noncommutative and non-pof2, use recursive doubling. */ mpi_errno = MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(sendbuf, recvbuf, recvcount, datatype, op, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } @@ -774,30 +804,34 @@ int MPIR_Ireduce_scatter_block_intra_sched_auto(const void *sendbuf, void *recvb int MPIR_Ireduce_scatter_block_inter_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Sched_t s) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_scatter_block_inter_sched_remote_reduce_local_scatterv(sendbuf, recvbuf, recvcount, datatype, op, - comm_ptr, s); + comm_ptr, coll_group, + s); return mpi_errno; } int MPIR_Iscan_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; if (comm_ptr->hierarchy_kind == 
MPIR_COMM_HIERARCHY_KIND__PARENT) { - mpi_errno = MPIR_Iscan_intra_sched_smp(sendbuf, recvbuf, count, datatype, op, comm_ptr, s); + mpi_errno = + MPIR_Iscan_intra_sched_smp(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, + s); } else { mpi_errno = MPIR_Iscan_intra_sched_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm_ptr, s); + comm_ptr, coll_group, s); } return mpi_errno; @@ -805,13 +839,13 @@ int MPIR_Iscan_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint cou int MPIR_Iexscan_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iexscan_intra_sched_recursive_doubling(sendbuf, recvbuf, count, datatype, op, comm_ptr, - s); + coll_group, s); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/op/opequal.c b/src/mpi/coll/op/opequal.c index cf6c7cb846f..fd9d57da163 100644 --- a/src/mpi/coll/op/opequal.c +++ b/src/mpi/coll/op/opequal.c @@ -55,7 +55,7 @@ int MPIR_EQUAL_check_dtype(MPI_Datatype type) MPIR_Assert(actual_pack_bytes == count * type_sz) int MPIR_Reduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype, - int *is_equal, int root, MPIR_Comm * comm_ptr) + int *is_equal, int root, MPIR_Comm * comm_ptr, int coll_group) { int mpi_errno = MPI_SUCCESS; @@ -64,10 +64,12 @@ int MPIR_Reduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype /* Not all algorithm will work. 
In particular, we can't split the message */ if (comm_ptr->rank == root) { mpi_errno = MPIR_Reduce_intra_binomial(MPI_IN_PLACE, local_buf, byte_count, MPI_BYTE, - MPIX_EQUAL, root, comm_ptr, MPIR_ERR_NONE); + MPIX_EQUAL, root, comm_ptr, coll_group, + MPIR_ERR_NONE); } else { mpi_errno = MPIR_Reduce_intra_binomial(local_buf, NULL, byte_count, MPI_BYTE, - MPIX_EQUAL, root, comm_ptr, MPIR_ERR_NONE); + MPIX_EQUAL, root, comm_ptr, coll_group, + MPIR_ERR_NONE); } MPIR_ERR_CHECK(mpi_errno); @@ -84,7 +86,7 @@ int MPIR_Reduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype int MPIR_Allreduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datatype, - int *is_equal, MPIR_Comm * comm_ptr) + int *is_equal, MPIR_Comm * comm_ptr, int coll_group) { int mpi_errno = MPI_SUCCESS; @@ -93,7 +95,8 @@ int MPIR_Allreduce_equal(const void *sendbuf, MPI_Aint count, MPI_Datatype datat /* Not all algorithm will work. In particular, we can't split the message */ mpi_errno = MPIR_Allreduce_intra_recursive_doubling(MPI_IN_PLACE, local_buf, byte_count, MPI_BYTE, - MPIX_EQUAL, comm_ptr, MPIR_ERR_NONE); + MPIX_EQUAL, comm_ptr, coll_group, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); *is_equal = local_buf->is_equal; diff --git a/src/mpi/coll/reduce/reduce_allcomm_nb.c b/src/mpi/coll/reduce/reduce_allcomm_nb.c index 2ab95bc494e..9c31c443cfe 100644 --- a/src/mpi/coll/reduce/reduce_allcomm_nb.c +++ b/src/mpi/coll/reduce/reduce_allcomm_nb.c @@ -7,13 +7,14 @@ int MPIR_Reduce_allcomm_nb(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ - mpi_errno = MPIR_Ireduce(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, &req_ptr); + mpi_errno = + MPIR_Ireduce(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, &req_ptr); 
MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c b/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c index ad54cce357b..8b90a7ebd05 100644 --- a/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c +++ b/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c @@ -19,7 +19,8 @@ int MPIR_Reduce_inter_local_reduce_remote_send(const void *sendbuf, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, mpi_errno; MPI_Status status; @@ -63,7 +64,8 @@ int MPIR_Reduce_inter_local_reduce_remote_send(const void *sendbuf, newcomm_ptr = comm_ptr->local_comm; /* now do a local reduce on this intracommunicator */ - mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, op, 0, newcomm_ptr, errflag); + mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, op, 0, newcomm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (rank == 0) { diff --git a/src/mpi/coll/reduce/reduce_intra_binomial.c b/src/mpi/coll/reduce/reduce_intra_binomial.c index 4363b5a7031..ff6cf3182fb 100644 --- a/src/mpi/coll/reduce/reduce_intra_binomial.c +++ b/src/mpi/coll/reduce/reduce_intra_binomial.c @@ -13,7 +13,8 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPI_Status status; diff --git a/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c b/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c index 8c1cd0fe4f0..ea6813457ff 100644 --- a/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c +++ b/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c @@ -37,7 +37,8 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const 
void *sendbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int comm_size, rank, pof2, rem, newrank; diff --git a/src/mpi/coll/reduce/reduce_intra_smp.c b/src/mpi/coll/reduce/reduce_intra_smp.c index 6ff46543185..e932611afc1 100644 --- a/src/mpi/coll/reduce/reduce_intra_smp.c +++ b/src/mpi/coll/reduce/reduce_intra_smp.c @@ -7,7 +7,7 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; void *tmp_buf = NULL; @@ -37,7 +37,7 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, /* do the intranode reduce on all nodes other than the root's node */ if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) == -1) { mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, - op, 0, comm_ptr->node_comm, errflag); + op, 0, comm_ptr->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -49,7 +49,7 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, const void *buf = (comm_ptr->node_comm == NULL ? sendbuf : tmp_buf); mpi_errno = MPIR_Reduce(buf, NULL, count, datatype, op, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, errflag); + comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* I am on root's node. I have not participated in the earlier reduce. 
*/ if (comm_ptr->rank != root) { @@ -58,7 +58,7 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, op, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, errflag); + comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* point sendbuf at tmp_buf to make final intranode reduce easy */ @@ -68,7 +68,7 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, mpi_errno = MPIR_Reduce(sendbuf, recvbuf, count, datatype, op, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, errflag); + comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* set sendbuf to MPI_IN_PLACE to make final intranode reduce easy. */ @@ -82,7 +82,7 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) != -1) { mpi_errno = MPIR_Reduce(sendbuf, recvbuf, count, datatype, op, MPIR_Get_intranode_rank(comm_ptr, root), - comm_ptr->node_comm, errflag); + comm_ptr->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_allcomm_nb.c b/src/mpi/coll/reduce_scatter/reduce_scatter_allcomm_nb.c index a29d82db025..51ae8ca3120 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_allcomm_nb.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_allcomm_nb.c @@ -7,14 +7,15 @@ int MPIR_Reduce_scatter_allcomm_nb(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ mpi_errno = - MPIR_Ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, &req_ptr); + MPIR_Ireduce_scatter(sendbuf, recvbuf, recvcounts, 
datatype, op, comm_ptr, coll_group, + &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_inter_remote_reduce_local_scatter.c b/src/mpi/coll/reduce_scatter/reduce_scatter_inter_remote_reduce_local_scatter.c index eba1c515df0..9b7770fccfd 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_inter_remote_reduce_local_scatter.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_inter_remote_reduce_local_scatter.c @@ -15,7 +15,7 @@ int MPIR_Reduce_scatter_inter_remote_reduce_local_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, mpi_errno, root, local_size, total_count, i; @@ -61,25 +61,25 @@ int MPIR_Reduce_scatter_inter_remote_reduce_local_scatter(const void *sendbuf, v /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* reduce to rank 0 of left group */ root = 0; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* reduce from right group to rank 0 */ root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -92,7 +92,7 @@ int MPIR_Reduce_scatter_inter_remote_reduce_local_scatter(const void *sendbuf, v newcomm_ptr = comm_ptr->local_comm; mpi_errno = MPIR_Scatterv(tmp_buf, recvcounts, disps, datatype, recvbuf, - recvcounts[rank], datatype, 0, newcomm_ptr, errflag); + recvcounts[rank], datatype, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c index d3085b27355..042e7dfc9f0 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c @@ -22,7 +22,7 @@ */ int MPIR_Reduce_scatter_intra_noncommutative(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c index 0ae8b879377..1ad6e16422f 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c @@ -15,7 +15,8 @@ */ int MPIR_Reduce_scatter_intra_pairwise(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size, i; MPI_Aint extent, true_extent, true_lb; diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c 
b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c index 6ab1d986302..e72c6384cc1 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c @@ -19,7 +19,7 @@ */ int MPIR_Reduce_scatter_intra_recursive_doubling(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, comm_size, i; diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c index 03f58c3d7c2..0009e54b5f0 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c @@ -36,7 +36,7 @@ */ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, comm_size, i; diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_allcomm_nb.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_allcomm_nb.c index 42a37790218..adf50cf765d 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_allcomm_nb.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_allcomm_nb.c @@ -7,14 +7,15 @@ int MPIR_Reduce_scatter_block_allcomm_nb(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking version and wait on it */ mpi_errno = - MPIR_Ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, op, comm_ptr, &req_ptr); + 
MPIR_Ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, op, comm_ptr, coll_group, + &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_inter_remote_reduce_local_scatter.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_inter_remote_reduce_local_scatter.c index ed15eafb02f..84d1b45758e 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_inter_remote_reduce_local_scatter.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_inter_remote_reduce_local_scatter.c @@ -18,6 +18,7 @@ int MPIR_Reduce_scatter_block_inter_remote_reduce_local_scatter(const void *send MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int rank, mpi_errno, root, local_size; @@ -51,25 +52,25 @@ int MPIR_Reduce_scatter_block_inter_remote_reduce_local_scatter(const void *send /* reduce from right group to rank 0 */ root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* reduce to rank 0 of right group */ root = 0; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* reduce to rank 0 of left group */ root = 0; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* reduce from right group to rank 0 */ root = (rank == 0) ? 
MPI_ROOT : MPI_PROC_NULL; mpi_errno = MPIR_Reduce_allcomm_auto(sendbuf, tmp_buf, total_count, datatype, op, - root, comm_ptr, errflag); + root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -80,7 +81,7 @@ int MPIR_Reduce_scatter_block_inter_remote_reduce_local_scatter(const void *send newcomm_ptr = comm_ptr->local_comm; mpi_errno = MPIR_Scatter(tmp_buf, recvcount, datatype, recvbuf, - recvcount, datatype, 0, newcomm_ptr, errflag); + recvcount, datatype, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c index b4edac03945..08319ceb12e 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c @@ -25,7 +25,8 @@ int MPIR_Reduce_scatter_block_intra_noncommutative(const void *sendbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int comm_size = comm_ptr->local_size; diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c index f7e5e636906..81d2ef3467b 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c @@ -25,7 +25,8 @@ int MPIR_Reduce_scatter_block_intra_pairwise(const void *sendbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size, i; MPI_Aint extent, true_extent, true_lb; diff --git 
a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c index d666ee884a1..0f237c26b40 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c @@ -25,7 +25,8 @@ int MPIR_Reduce_scatter_block_intra_recursive_doubling(const void *sendbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size, i; MPI_Aint extent, true_extent, true_lb; diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c index 05fede37670..29572e9be2f 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c @@ -40,7 +40,8 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int rank, comm_size, i; MPI_Aint extent, true_extent, true_lb; diff --git a/src/mpi/coll/scan/scan_allcomm_nb.c b/src/mpi/coll/scan/scan_allcomm_nb.c index 0528d599fa7..e957aa40664 100644 --- a/src/mpi/coll/scan/scan_allcomm_nb.c +++ b/src/mpi/coll/scan/scan_allcomm_nb.c @@ -6,13 +6,13 @@ #include "mpiimpl.h" int MPIR_Scan_allcomm_nb(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; /* just call the nonblocking 
version and wait on it */ - mpi_errno = MPIR_Iscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, &req_ptr); + mpi_errno = MPIR_Iscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/scan/scan_intra_recursive_doubling.c b/src/mpi/coll/scan/scan_intra_recursive_doubling.c index 6afdb2dfd34..6d9a2d40658 100644 --- a/src/mpi/coll/scan/scan_intra_recursive_doubling.c +++ b/src/mpi/coll/scan/scan_intra_recursive_doubling.c @@ -44,7 +44,8 @@ int MPIR_Scan_intra_recursive_doubling(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { MPI_Status status; int rank, comm_size; diff --git a/src/mpi/coll/scan/scan_intra_smp.c b/src/mpi/coll/scan/scan_intra_smp.c index 9ea89a81786..0414f3dfc67 100644 --- a/src/mpi/coll/scan/scan_intra_smp.c +++ b/src/mpi/coll/scan/scan_intra_smp.c @@ -7,7 +7,7 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -43,7 +43,8 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, /* perform intranode scan to get temporary result in recvbuf. if there is only * one process, just copy the raw data. 
*/ if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, errflag); + mpi_errno = MPIR_Scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (sendbuf != MPI_IN_PLACE) { mpi_errno = MPIR_Localcopy(sendbuf, count, datatype, recvbuf, count, datatype); @@ -75,7 +76,7 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, * main process of node 3. */ if (comm_ptr->node_roots_comm != NULL) { mpi_errno = MPIR_Scan(localfulldata, prefulldata, count, datatype, - op, comm_ptr->node_roots_comm, errflag); + op, comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (MPIR_Get_internode_rank(comm_ptr, rank) != comm_ptr->node_roots_comm->local_size - 1) { @@ -100,13 +101,14 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, * reduce it with recvbuf to get final result if necessary. */ if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIR_Bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } if (noneed == 0) { if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIR_Bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/scatter/scatter_allcomm_nb.c b/src/mpi/coll/scatter/scatter_allcomm_nb.c index e344a0ed25c..16a53460c1a 100644 --- a/src/mpi/coll/scatter/scatter_allcomm_nb.c +++ b/src/mpi/coll/scatter/scatter_allcomm_nb.c @@ -7,7 +7,7 @@ int MPIR_Scatter_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, 
MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -15,7 +15,7 @@ int MPIR_Scatter_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatyp /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Iscatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm_ptr, - &req_ptr); + coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/coll/scatter/scatter_inter_linear.c b/src/mpi/coll/scatter/scatter_inter_linear.c index 96e616b104e..9b5f5ec5800 100644 --- a/src/mpi/coll/scatter/scatter_inter_linear.c +++ b/src/mpi/coll/scatter/scatter_inter_linear.c @@ -14,7 +14,7 @@ int MPIR_Scatter_inter_linear(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int remote_size, mpi_errno = MPI_SUCCESS; int i; diff --git a/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c b/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c index 81f5ce30245..d02865b1117 100644 --- a/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c +++ b/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c @@ -16,7 +16,7 @@ int MPIR_Scatter_inter_remote_send_local_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS; @@ -69,7 +69,7 @@ int MPIR_Scatter_inter_remote_send_local_scatter(const void *sendbuf, MPI_Aint s /* now do the usual scatter on this intracommunicator */ mpi_errno = MPIR_Scatter(tmp_buf, recvcount * recvtype_sz, MPI_BYTE, - recvbuf, recvcount, recvtype, 0, newcomm_ptr, errflag); + recvbuf, 
recvcount, recvtype, 0, newcomm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/scatter/scatter_intra_binomial.c b/src/mpi/coll/scatter/scatter_intra_binomial.c index dd211126582..cd6a5fe13d6 100644 --- a/src/mpi/coll/scatter/scatter_intra_binomial.c +++ b/src/mpi/coll/scatter/scatter_intra_binomial.c @@ -28,7 +28,7 @@ /* not declared static because a machine-specific function may call this one in some cases */ int MPIR_Scatter_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { MPI_Status status; MPI_Aint extent = 0; diff --git a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c index 8612a34b09f..c34aa40b4a5 100644 --- a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c +++ b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c @@ -20,7 +20,7 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int rank, comm_size, mpi_errno = MPI_SUCCESS; MPI_Aint extent; diff --git a/src/mpi/coll/scatterv/scatterv_allcomm_nb.c b/src/mpi/coll/scatterv/scatterv_allcomm_nb.c index 953bb60b819..53bc8904361 100644 --- a/src/mpi/coll/scatterv/scatterv_allcomm_nb.c +++ b/src/mpi/coll/scatterv/scatterv_allcomm_nb.c @@ -8,7 +8,7 @@ int MPIR_Scatterv_allcomm_nb(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = 
MPI_SUCCESS; MPIR_Request *req_ptr = NULL; @@ -16,7 +16,7 @@ int MPIR_Scatterv_allcomm_nb(const void *sendbuf, const MPI_Aint * sendcounts, /* just call the nonblocking version and wait on it */ mpi_errno = MPIR_Iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, - comm_ptr, &req_ptr); + comm_ptr, coll_group, &req_ptr); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Wait(req_ptr); diff --git a/src/mpi/comm/comm_impl.c b/src/mpi/comm/comm_impl.c index 7fcb9828f74..0e76fc19109 100644 --- a/src/mpi/comm/comm_impl.c +++ b/src/mpi/comm/comm_impl.c @@ -490,15 +490,19 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co MPIR_ERR_CHECK(mpi_errno); /* Broadcast to the other members of the local group */ - mpi_errno = MPIR_Bcast(rinfo, 2, MPI_INT, 0, comm_ptr->local_comm, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(rinfo, 2, MPI_INT, 0, comm_ptr->local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Bcast(remote_mapping, remote_size, MPI_INT, 0, - comm_ptr->local_comm, MPIR_ERR_NONE); + comm_ptr->local_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } else { /* The other processes */ /* Broadcast to the other members of the local group */ - mpi_errno = MPIR_Bcast(rinfo, 2, MPI_INT, 0, comm_ptr->local_comm, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(rinfo, 2, MPI_INT, 0, comm_ptr->local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (*newcomm_ptr != NULL) { (*newcomm_ptr)->context_id = rinfo[0]; @@ -508,7 +512,7 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co remote_size * sizeof(int), mpi_errno, "remote_mapping", MPL_MEM_ADDRESS); mpi_errno = MPIR_Bcast(remote_mapping, remote_size, MPI_INT, 0, - comm_ptr->local_comm, MPIR_ERR_NONE); + comm_ptr->local_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } @@ -1038,14 +1042,18 @@ int MPIR_Intercomm_create_impl(MPIR_Comm * 
local_comm_ptr, int local_leader, * along with the final context id */ comm_info[0] = final_context_id; MPL_DBG_MSG(MPIR_DBG_COMM, VERBOSE, "About to bcast on local_comm"); - mpi_errno = MPIR_Bcast(comm_info, 1, MPI_INT, local_leader, local_comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(comm_info, 1, MPI_INT, local_leader, local_comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_D(MPIR_DBG_COMM, VERBOSE, "end of bcast on local_comm of size %d", local_comm_ptr->local_size); } else { /* we're the other processes */ MPL_DBG_MSG(MPIR_DBG_COMM, VERBOSE, "About to receive bcast on local_comm"); - mpi_errno = MPIR_Bcast(comm_info, 1, MPI_INT, local_leader, local_comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(comm_info, 1, MPI_INT, local_leader, local_comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Extract the context and group sign information */ @@ -1217,7 +1225,9 @@ int MPIR_Intercomm_merge_impl(MPIR_Comm * comm_ptr, int high, MPIR_Comm ** new_i * value of local_high, which may have changed if both groups * of processes had the same value for high */ - mpi_errno = MPIR_Bcast(&local_high, 1, MPI_INT, 0, comm_ptr->local_comm, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(&local_high, 1, MPI_INT, 0, comm_ptr->local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* diff --git a/src/mpi/comm/comm_split.c b/src/mpi/comm/comm_split.c index 94f722f4ef9..c929598767b 100644 --- a/src/mpi/comm/comm_split.c +++ b/src/mpi/comm/comm_split.c @@ -114,7 +114,8 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm ** } /* Gather information on the local group of processes */ mpi_errno = - MPIR_Allgather(MPI_IN_PLACE, 2, MPI_INT, table, 2, MPI_INT, local_comm_ptr, MPIR_ERR_NONE); + MPIR_Allgather(MPI_IN_PLACE, 2, MPI_INT, table, 2, MPI_INT, local_comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Step 2: How many processes have our same 
color? */ @@ -161,7 +162,7 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm ** mypair.color = color; mypair.key = key; mpi_errno = MPIR_Allgather(&mypair, 2, MPI_INT, remotetable, 2, MPI_INT, - comm_ptr, MPIR_ERR_NONE); + comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Each process can now match its color with the entries in the table */ @@ -220,7 +221,7 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm ** MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Bcast(&remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, local_comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (!in_newcomm) { @@ -230,7 +231,7 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm ** /* Broadcast to the other members of the local group */ mpi_errno = MPIR_Bcast(&remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, local_comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/comm/comm_split_type_nbhd.c b/src/mpi/comm/comm_split_type_nbhd.c index 9e76911aafb..07c0e6d288c 100644 --- a/src/mpi/comm/comm_split_type_nbhd.c +++ b/src/mpi/comm/comm_split_type_nbhd.c @@ -277,7 +277,7 @@ static int network_split_by_minsize(MPIR_Comm * comm_ptr, int key, int subcomm_m /* Send the count to processes */ mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, num_processes_at_node, num_nodes, MPI_INT, - MPI_SUM, comm_ptr, MPIR_ERR_NONE); + MPI_SUM, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (topo_type == MPIR_NETTOPO_TYPE__FAT_TREE || topo_type == MPIR_NETTOPO_TYPE__CLOS_NETWORK) { @@ -377,7 +377,7 @@ static int network_split_by_minsize(MPIR_Comm * comm_ptr, int key, int subcomm_m /* get min tree depth to all processes */ MPIR_Allreduce(&tree_depth, &min_tree_depth, 1, MPI_INT, MPI_MIN, node_comm, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (min_tree_depth) { int 
num_hwloc_objs_at_depth; @@ -391,7 +391,7 @@ static int network_split_by_minsize(MPIR_Comm * comm_ptr, int key, int subcomm_m /* get parent_idx to all processes */ MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, parent_idx, 1, MPI_INT, - node_comm, MPIR_ERR_NONE); + node_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); /* reorder parent indices */ for (i = 0; i < num_procs - 1; i++) { diff --git a/src/mpi/comm/commutil.c b/src/mpi/comm/commutil.c index 6e3ba2fcdeb..ac6d22b4164 100644 --- a/src/mpi/comm/commutil.c +++ b/src/mpi/comm/commutil.c @@ -717,7 +717,8 @@ static int init_comm_seq(MPIR_Comm * comm) /* Every rank need share the same seq from root. NOTE: it is possible for * different communicators to have the same seq. It is only used as an * opportunistic optimization */ - mpi_errno = MPIR_Bcast_allcomm_auto(&tmp, 1, MPI_INT, 0, comm, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast_allcomm_auto(&tmp, 1, MPI_INT, 0, comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); comm->seq = tmp; @@ -1276,11 +1277,14 @@ int MPII_collect_info_key(MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, const char } int is_equal; - mpi_errno = MPIR_Allreduce_equal(&hint_str_size, 1, MPI_INT, &is_equal, comm_ptr); + mpi_errno = + MPIR_Allreduce_equal(&hint_str_size, 1, MPI_INT, &is_equal, comm_ptr, MPIR_SUBGROUP_NONE); MPIR_ERR_CHECK(mpi_errno); if (is_equal && hint_str_size > 0) { - mpi_errno = MPIR_Allreduce_equal(hint_str, hint_str_size, MPI_CHAR, &is_equal, comm_ptr); + mpi_errno = + MPIR_Allreduce_equal(hint_str, hint_str_size, MPI_CHAR, &is_equal, comm_ptr, + MPIR_SUBGROUP_NONE); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/comm/contextid.c b/src/mpi/comm/contextid.c index 8bee3890caf..25faa6f4fba 100644 --- a/src/mpi/comm/contextid.c +++ b/src/mpi/comm/contextid.c @@ -462,7 +462,8 @@ int MPIR_Get_contextid_sparse_group(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr MPIR_ERR_NONE); } else { mpi_errno = MPIR_Allreduce_impl(MPI_IN_PLACE, st.local_mask, MPIR_MAX_CONTEXT_MASK 
+ 1, - MPI_INT, MPI_BAND, comm_ptr, MPIR_ERR_NONE); + MPI_INT, MPI_BAND, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); } MPIR_ERR_CHECK(mpi_errno); @@ -562,7 +563,8 @@ int MPIR_Get_contextid_sparse_group(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr comm_ptr, group_ptr, coll_tag, MPIR_ERR_NONE); } else { mpi_errno = MPIR_Allreduce_impl(MPI_IN_PLACE, &minfree, 1, MPI_INT, - MPI_MIN, comm_ptr, MPIR_ERR_NONE); + MPI_MIN, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); } if (minfree > 0) { @@ -655,7 +657,7 @@ static int sched_cb_gcn_bcast(MPIR_Comm * comm, int tag, void *state) mpi_errno = MPIR_Ibcast_intra_sched_auto(st->ctx1, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, st->comm_ptr, - st->s); + MPIR_SUBGROUP_NONE, st->s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(st->s); } @@ -733,7 +735,7 @@ static int sched_cb_gcn_allocate_cid(MPIR_Comm * comm, int tag, void *state) */ /* FIXME: study and resolve */ /* - * mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, &minfree, 1, MPI_INT, MPI_MIN, st->comm_ptr, MPIR_ERR_NONE); + * mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, &minfree, 1, MPI_INT, MPI_MIN, st->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); * MPIR_ERR_CHECK(mpi_errno); */ if (minfree > 0) { @@ -837,7 +839,7 @@ static int sched_cb_gcn_copy_mask(MPIR_Comm * comm, int tag, void *state) mpi_errno = MPIR_Iallreduce_intra_sched_auto(MPI_IN_PLACE, st->local_mask, MPIR_MAX_CONTEXT_MASK + 1, MPI_UINT32_T, MPI_BAND, - st->comm_ptr, st->s); + st->comm_ptr, MPIR_SUBGROUP_NONE, st->s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(st->s); @@ -1063,7 +1065,7 @@ int MPIR_Get_intercomm_contextid(MPIR_Comm * comm_ptr, MPIR_Context_id_t * conte /* Make sure that all of the local processes now have this * id */ mpi_errno = MPIR_Bcast_impl(&remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, - 0, comm_ptr->local_comm, MPIR_ERR_NONE); + 0, comm_ptr->local_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* The recvcontext_id must be the one that was allocated out of the local * 
group, not the remote group. Otherwise we could end up posting two diff --git a/src/mpi/stream/stream_enqueue.c b/src/mpi/stream/stream_enqueue.c index fe50b4b693f..b8367ed5599 100644 --- a/src/mpi/stream/stream_enqueue.c +++ b/src/mpi/stream/stream_enqueue.c @@ -611,8 +611,9 @@ static void allreduce_enqueue_cb(void *data) } } - mpi_errno = MPIR_Allreduce(sendbuf, recvbuf, p->count, p->datatype, p->op, p->comm_ptr, - MPIR_ERR_NONE); + mpi_errno = + MPIR_Allreduce(sendbuf, recvbuf, p->count, p->datatype, p->op, p->comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_Assertp(mpi_errno == MPI_SUCCESS); if (p->host_recvbuf) { diff --git a/src/mpi/stream/stream_impl.c b/src/mpi/stream/stream_impl.c index 6c71de24046..e5b441127ee 100644 --- a/src/mpi/stream/stream_impl.c +++ b/src/mpi/stream/stream_impl.c @@ -269,7 +269,8 @@ int MPIR_Stream_comm_create_impl(MPIR_Comm * comm_ptr, MPIR_Stream * stream_ptr, MPIR_ERR_CHKANDJUMP(!vci_table, mpi_errno, MPI_ERR_OTHER, "**nomem"); mpi_errno = - MPIR_Allgather_impl(&vci, 1, MPI_INT, vci_table, 1, MPI_INT, comm_ptr, MPIR_ERR_NONE); + MPIR_Allgather_impl(&vci, 1, MPI_INT, vci_table, 1, MPI_INT, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); (*newcomm_ptr)->stream_comm_type = MPIR_STREAM_COMM_SINGLE; @@ -313,7 +314,8 @@ int MPIR_Stream_comm_create_multiplex_impl(MPIR_Comm * comm_ptr, MPI_Aint num_tmp = num_streams; mpi_errno = MPIR_Allgather_impl(&num_tmp, 1, MPI_AINT, - num_table, 1, MPI_AINT, comm_ptr, MPIR_ERR_NONE); + num_table, 1, MPI_AINT, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); MPI_Aint num_total = 0; @@ -347,7 +349,7 @@ int MPIR_Stream_comm_create_multiplex_impl(MPIR_Comm * comm_ptr, mpi_errno = MPIR_Allgatherv_impl(local_vcis, num_streams, MPI_INT, vci_table, num_table, displs, MPI_INT, comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); (*newcomm_ptr)->stream_comm_type = MPIR_STREAM_COMM_MULTIPLEX; diff --git 
a/src/mpi/threadcomm/threadcomm_coll_impl.c b/src/mpi/threadcomm/threadcomm_coll_impl.c index 23951e5ccbf..fde81702361 100644 --- a/src/mpi/threadcomm/threadcomm_coll_impl.c +++ b/src/mpi/threadcomm/threadcomm_coll_impl.c @@ -34,7 +34,7 @@ int MPIR_Threadcomm_barrier_impl(MPIR_Comm * comm) if (comm->local_size == 1) { thread_barrier(comm->threadcomm); } else { - mpi_errno = MPIR_Barrier_intra_dissemination(comm, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier_intra_dissemination(comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); } return mpi_errno; @@ -45,7 +45,9 @@ int MPIR_Threadcomm_bcast_impl(void *buffer, MPI_Aint count, MPI_Datatype dataty { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Bcast_intra_binomial(buffer, count, datatype, root, comm, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast_intra_binomial(buffer, count, datatype, root, comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); return mpi_errno; } @@ -57,7 +59,8 @@ int MPIR_Threadcomm_gather_impl(const void *sendbuf, MPI_Aint sendcount, MPI_Dat int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Gather_intra_binomial(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, root, comm, MPIR_ERR_NONE); + recvbuf, recvcount, recvtype, root, comm, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); return mpi_errno; } @@ -71,7 +74,7 @@ int MPIR_Threadcomm_gatherv_impl(const void *sendbuf, MPI_Aint sendcount, MPI_Da mpi_errno = MPIR_Gatherv_allcomm_linear(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); return mpi_errno; } @@ -84,7 +87,7 @@ int MPIR_Threadcomm_scatter_impl(const void *sendbuf, MPI_Aint sendcount, MPI_Da mpi_errno = MPIR_Scatter_intra_binomial(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); return mpi_errno; } @@ -98,7 +101,7 @@ int MPIR_Threadcomm_scatterv_impl(const void *sendbuf, const MPI_Aint * sendcoun mpi_errno = 
MPIR_Scatterv_allcomm_linear(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); return mpi_errno; } @@ -110,7 +113,8 @@ int MPIR_Threadcomm_allgather_impl(const void *sendbuf, MPI_Aint sendcount, MPI_ int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Allgather_intra_brucks(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, MPIR_ERR_NONE); + recvbuf, recvcount, recvtype, comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); return mpi_errno; } @@ -124,7 +128,7 @@ int MPIR_Threadcomm_allgatherv_impl(const void *sendbuf, MPI_Aint sendcount, MPI mpi_errno = MPIR_Allgatherv_intra_brucks(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); return mpi_errno; } @@ -137,7 +141,8 @@ int MPIR_Threadcomm_alltoall_impl(const void *sendbuf, MPI_Aint sendcount, MPI_D MPIR_Assert(sendbuf != MPI_IN_PLACE); mpi_errno = MPIR_Alltoall_intra_brucks(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, MPIR_ERR_NONE); + recvbuf, recvcount, recvtype, comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); return mpi_errno; } @@ -153,7 +158,7 @@ int MPIR_Threadcomm_alltoallv_impl(const void *sendbuf, const MPI_Aint * sendcou MPIR_Assert(sendbuf != MPI_IN_PLACE); mpi_errno = MPIR_Alltoallv_intra_scattered(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); return mpi_errno; } @@ -169,7 +174,7 @@ int MPIR_Threadcomm_alltoallw_impl(const void *sendbuf, const MPI_Aint * sendcou MPIR_Assert(sendbuf != MPI_IN_PLACE); mpi_errno = MPIR_Alltoallw_intra_scattered(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); return mpi_errno; } @@ -181,7 +186,7 @@ int MPIR_Threadcomm_allreduce_impl(const void *sendbuf, void *recvbuf, int 
mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Allreduce_intra_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); return mpi_errno; } @@ -193,7 +198,7 @@ int MPIR_Threadcomm_reduce_impl(const void *sendbuf, void *recvbuf, int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Reduce_intra_binomial(sendbuf, recvbuf, count, datatype, op, root, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); return mpi_errno; } @@ -206,7 +211,8 @@ int MPIR_Threadcomm_reduce_scatter_impl(const void *sendbuf, void *recvbuf, MPIR_Assert(MPIR_Op_is_commutative(op)); mpi_errno = MPIR_Reduce_scatter_intra_recursive_halving(sendbuf, recvbuf, recvcounts, - datatype, op, comm, MPIR_ERR_NONE); + datatype, op, comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); return mpi_errno; } @@ -220,7 +226,8 @@ int MPIR_Threadcomm_reduce_scatter_block_impl(const void *sendbuf, void *recvbuf MPIR_Assert(MPIR_Op_is_commutative(op)); mpi_errno = MPIR_Reduce_scatter_block_intra_recursive_halving(sendbuf, recvbuf, recvcount, datatype, op, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); return mpi_errno; } @@ -231,7 +238,7 @@ int MPIR_Threadcomm_scan_impl(const void *sendbuf, void *recvbuf, int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Scan_intra_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); return mpi_errno; } @@ -242,7 +249,7 @@ int MPIR_Threadcomm_exscan_impl(const void *sendbuf, void *recvbuf, int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Exscan_intra_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); return mpi_errno; } diff --git a/src/mpi/threadcomm/threadcomm_impl.c b/src/mpi/threadcomm/threadcomm_impl.c index 8161ecad5fa..8b96ddcf95c 100644 --- a/src/mpi/threadcomm/threadcomm_impl.c +++ b/src/mpi/threadcomm/threadcomm_impl.c @@ -34,8 +34,9 @@ int 
MPIR_Threadcomm_init_impl(MPIR_Comm * comm, int num_threads, MPIR_Comm ** co threads_table = MPL_malloc(comm_size * sizeof(int), MPL_MEM_OTHER); MPIR_ERR_CHKANDJUMP(!threads_table, mpi_errno, MPI_ERR_OTHER, "**nomem"); - mpi_errno = MPIR_Allgather_impl(&num_threads, 1, MPI_INT, threads_table, 1, MPI_INT, comm, - MPIR_ERR_NONE); + mpi_errno = + MPIR_Allgather_impl(&num_threads, 1, MPI_INT, threads_table, 1, MPI_INT, comm, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); int *rank_offset_table;; diff --git a/src/mpi/topo/dist_graph_create.c b/src/mpi/topo/dist_graph_create.c index ffe9c27c850..e42e61c65c2 100644 --- a/src/mpi/topo/dist_graph_create.c +++ b/src/mpi/topo/dist_graph_create.c @@ -133,7 +133,8 @@ int MPIR_Dist_graph_create_impl(MPIR_Comm * comm_ptr, /* compute the number of peers I will recv from */ int in_out_peers[2] = { -1, 1 }; mpi_errno = - MPIR_Reduce_scatter_block(rs, in_out_peers, 2, MPI_INT, MPI_SUM, comm_ptr, MPIR_ERR_NONE); + MPIR_Reduce_scatter_block(rs, in_out_peers, 2, MPI_INT, MPI_SUM, comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); MPIR_Assert(in_out_peers[0] <= comm_size && in_out_peers[0] >= 0); diff --git a/src/mpid/ch3/channels/nemesis/src/ch3_win_fns.c b/src/mpid/ch3/channels/nemesis/src/ch3_win_fns.c index ca4039a9d9d..1760f00fce8 100644 --- a/src/mpid/ch3/channels/nemesis/src/ch3_win_fns.c +++ b/src/mpid/ch3/channels/nemesis/src/ch3_win_fns.c @@ -201,7 +201,7 @@ static int MPIDI_CH3I_SHM_Wins_match(MPIR_Win ** win_ptr, MPIR_Win ** matched_wi base_shm_offs[node_rank] = (MPI_Aint) ((*win_ptr)->base) - (MPI_Aint) (shm_win->shm_base_addr); mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, - base_shm_offs, 1, MPI_AINT, node_comm_ptr, MPIR_ERR_NONE); + base_shm_offs, 1, MPI_AINT, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); base_diff = 0; @@ -345,12 +345,12 @@ static int MPIDI_CH3I_Win_gather_info(void *base, MPI_Aint size, int disp_unit, 
MPIR_ERR_CHECK(mpi_errno); mpi_errno = - MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* wait for other processes to attach to win */ - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* unlink shared memory region so it gets deleted when all processes exit */ @@ -362,7 +362,7 @@ static int MPIDI_CH3I_Win_gather_info(void *base, MPI_Aint size, int disp_unit, /* get serialized handle from rank 0 and deserialize it */ mpi_errno = - MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -376,7 +376,7 @@ static int MPIDI_CH3I_Win_gather_info(void *base, MPI_Aint size, int disp_unit, &(*win_ptr)->info_shm_base_addr, 0); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**attach_shar_mem"); - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } @@ -391,7 +391,7 @@ static int MPIDI_CH3I_Win_gather_info(void *base, MPI_Aint size, int disp_unit, tmp_buf[4 * comm_rank + 3] = (MPI_Aint) (*win_ptr)->handle; mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, tmp_buf, 4, MPI_AINT, - (*win_ptr)->comm_ptr, MPIR_ERR_NONE); + (*win_ptr)->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (node_rank == 0) { @@ -406,7 +406,7 @@ static int MPIDI_CH3I_Win_gather_info(void *base, MPI_Aint size, int disp_unit, } /* Make sure that all local processes see the results written by node_rank == 0 */ - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, 
MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -479,7 +479,7 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, node_sizes, sizeof(MPI_Aint), MPI_BYTE, - node_comm_ptr, MPIR_ERR_NONE); + node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_T_PVAR_TIMER_END(RMA, rma_wincreate_allgather); MPIR_ERR_CHECK(mpi_errno); @@ -518,12 +518,12 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * MPIR_ERR_CHECK(mpi_errno); mpi_errno = - MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* wait for other processes to attach to win */ - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* unlink shared memory region so it gets deleted when all processes exit */ @@ -535,7 +535,7 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * /* get serialized handle from rank 0 and deserialize it */ mpi_errno = - MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -550,7 +550,7 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * &(*win_ptr)->shm_base_addr, 0); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**attach_shar_mem"); - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } @@ -577,12 +577,12 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * MPIR_ERR_CHECK(mpi_errno); mpi_errno = - 
MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd_ptr, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* wait for other processes to attach to win */ - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* unlink shared memory region so it gets deleted when all processes exit */ @@ -594,7 +594,7 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * /* get serialized handle from rank 0 and deserialize it */ mpi_errno = - MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, + MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -609,7 +609,7 @@ static int MPIDI_CH3I_Win_allocate_shm(MPI_Aint size, int disp_unit, MPIR_Info * (void **) &(*win_ptr)->shm_mutex, 0); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**attach_shar_mem"); - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpid/ch3/include/mpid_coll.h b/src/mpid/ch3/include/mpid_coll.h index 7de2103a3f9..7a2c47cb23a 100644 --- a/src/mpid/ch3/include/mpid_coll.h +++ b/src/mpid/ch3/include/mpid_coll.h @@ -11,39 +11,39 @@ #include "../../common/hcoll/hcoll.h" #endif -static inline int MPID_Barrier(MPIR_Comm * comm, MPIR_Errflag_t errflag) +static inline int MPID_Barrier(MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { #ifdef HAVE_HCOLL if (MPI_SUCCESS == hcoll_Barrier(comm, errflag)) return MPI_SUCCESS; #endif - return MPIR_Barrier_impl(comm, errflag); + return MPIR_Barrier_impl(comm, coll_group, errflag); } static inline int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * 
comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { #ifdef HAVE_HCOLL if (MPI_SUCCESS == hcoll_Bcast(buffer, count, datatype, root, comm, errflag)) return MPI_SUCCESS; #endif - return MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); + return MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); } static inline int MPID_Allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { #ifdef HAVE_HCOLL if (MPI_SUCCESS == hcoll_Allreduce(sendbuf, recvbuf, count, datatype, op, comm, errflag)) return MPI_SUCCESS; #endif - return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); } static inline int MPID_Allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { #ifdef HAVE_HCOLL if (MPI_SUCCESS == hcoll_Allgather(sendbuf, sendcount, sendtype, recvbuf, @@ -51,17 +51,17 @@ static inline int MPID_Allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Da return MPI_SUCCESS; #endif return MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + recvcount, recvtype, comm, coll_group, errflag); } static inline int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, 
displs, recvtype, comm, + recvcounts, displs, recvtype, comm, coll_group, errflag); return mpi_errno; @@ -69,25 +69,25 @@ static inline int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendcount, MPI_D static inline int MPID_Scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); return mpi_errno; } static inline int MPID_Scatterv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, + recvbuf, recvcount, recvtype, root, comm, coll_group, errflag); return mpi_errno; @@ -95,25 +95,25 @@ static inline int MPID_Scatterv(const void *sendbuf, const MPI_Aint * sendcounts static inline int MPID_Gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); return mpi_errno; } static inline int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * 
displs, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm, + recvcounts, displs, recvtype, root, comm, coll_group, errflag); return mpi_errno; @@ -121,26 +121,26 @@ static inline int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Data static inline int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + recvcount, recvtype, comm, coll_group, errflag); return mpi_errno; } static inline int MPID_Alltoallv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, - const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, + const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, - comm, errflag); + comm, coll_group, errflag); return mpi_errno; } @@ -148,75 +148,75 @@ static inline int MPID_Alltoallv(const void *sendbuf, const MPI_Aint * sendcount static inline int MPID_Alltoallw(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], - const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr, + const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int 
mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); return mpi_errno; } static inline int MPID_Reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, - comm, errflag); + comm, coll_group, errflag); return mpi_errno; } static inline int MPID_Reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, errflag); + datatype, op, comm_ptr, coll_group, errflag); return mpi_errno; } static inline int MPID_Reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, + datatype, op, comm_ptr, coll_group, errflag); return mpi_errno; } static inline int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, + mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); return mpi_errno; } static inline int MPID_Exscan(const void 
*sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, + mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); return mpi_errno; @@ -366,70 +366,70 @@ static inline int MPID_Ineighbor_alltoallw(const void *sendbuf, const MPI_Aint s return mpi_errno; } -static inline int MPID_Ibarrier(MPIR_Comm * comm, MPIR_Request **request) +static inline int MPID_Ibarrier(MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Ibarrier_impl(comm, request); + mpi_errno = MPIR_Ibarrier_impl(comm, coll_group, request); return mpi_errno; } static inline int MPID_Ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, MPIR_Request **request) + MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, request); + mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, coll_group, request); return mpi_errno; } static inline int MPID_Iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request **request) + MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, request); + recvcount, recvtype, comm, coll_group, request); return mpi_errno; } static inline int MPID_Iallgatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm * comm, MPIR_Request **request) + 
MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm, + recvcounts, displs, recvtype, comm, coll_group, request); return mpi_errno; } static inline int MPID_Iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, - comm, request); + comm, coll_group, request); return mpi_errno; } static inline int MPID_Ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request **request) + MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, request); + recvcount, recvtype, comm, coll_group, request); return mpi_errno; } @@ -438,13 +438,13 @@ static inline int MPID_Ialltoallv(const void *sendbuf, const MPI_Aint sendcounts const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request **request) + MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, - comm, request); + comm, coll_group, request); return mpi_errno; } @@ -453,24 +453,24 @@ static inline int MPID_Ialltoallw(const void *sendbuf, const MPI_Aint sendcounts const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype 
recvtypes[], - MPIR_Comm * comm, MPIR_Request **request) + MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, - comm, request); + comm, coll_group, request); return mpi_errno; } static inline int MPID_Iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, + mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); return mpi_errno; @@ -478,71 +478,71 @@ static inline int MPID_Iexscan(const void *sendbuf, void *recvbuf, MPI_Aint coun static inline int MPID_Igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request **request) + int root, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); return mpi_errno; } static inline int MPID_Igatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm, + recvcounts, displs, recvtype, root, comm, coll_group, request); return mpi_errno; } static inline int MPID_Ireduce_scatter_block(const void *sendbuf, 
void *recvbuf, MPI_Aint recvcount, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm, request); + datatype, op, comm, coll_group, request); return mpi_errno; } static inline int MPID_Ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, - datatype, op, comm, request); + datatype, op, comm, coll_group, request); return mpi_errno; } static inline int MPID_Ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, int root, MPIR_Comm * comm, MPIR_Request **request) + MPI_Op op, int root, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, - comm, request); + comm, coll_group, request); return mpi_errno; } static inline int MPID_Iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, MPIR_Request **request) + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, + mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); return mpi_errno; @@ -550,12 +550,12 @@ static inline int MPID_Iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, static inline int MPID_Iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, 
MPIR_Request **request) + int root, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); return mpi_errno; } @@ -563,12 +563,12 @@ static inline int MPID_Iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Dat static inline int MPID_Iscatterv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request **request) + int root, MPIR_Comm * comm, int coll_group, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, + recvbuf, recvcount, recvtype, root, comm, coll_group, request); return mpi_errno; diff --git a/src/mpid/ch3/include/mpidpre.h b/src/mpid/ch3/include/mpidpre.h index 9a98093d328..63835586fdf 100644 --- a/src/mpid/ch3/include/mpidpre.h +++ b/src/mpid/ch3/include/mpidpre.h @@ -614,71 +614,71 @@ int MPID_Recv_init( void *buf, MPI_Aint count, MPI_Datatype datatype, int MPID_Startall(int count, MPIR_Request *requests[]); int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request **request); + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request **request); int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request); int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, int root, MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, + MPI_Op op, int root, MPIR_Comm *comm_ptr, 
int coll_group, MPIR_Info* info_ptr, MPIR_Request **request); int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request** request); + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request); int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], - const MPI_Aint rdispls[], MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + const MPI_Aint rdispls[], MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request); int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request); int MPID_Allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request** request); + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request); int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, + MPI_Datatype recvtype, MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request); int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, 
MPIR_Request ** request); int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); + int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, MPIR_Info * info, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); -int MPID_Barrier_init(MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); +int MPID_Barrier_init(MPIR_Comm * comm, int 
coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Exscan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Neighbor_allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, diff --git a/src/mpid/ch3/src/ch3u_comm_spawn_multiple.c b/src/mpid/ch3/src/ch3u_comm_spawn_multiple.c index 7c1feb92dae..c1c9d94c940 100644 --- a/src/mpid/ch3/src/ch3u_comm_spawn_multiple.c +++ b/src/mpid/ch3/src/ch3u_comm_spawn_multiple.c @@ -75,13 +75,13 @@ int MPIDI_Comm_spawn_multiple(int count, char **commands, } if (errcodes != MPI_ERRCODES_IGNORE) { - mpi_errno = MPIR_Bcast(&should_accept, 1, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast(&should_accept, 1, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Bcast(&total_num_processes, 1, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast(&total_num_processes, 1, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Bcast(errcodes, total_num_processes, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast(errcodes, total_num_processes, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpid/ch3/src/ch3u_port.c b/src/mpid/ch3/src/ch3u_port.c index 3ef43d48aab..9ca1fc5bfde 100644 --- a/src/mpid/ch3/src/ch3u_port.c +++ b/src/mpid/ch3/src/ch3u_port.c @@ -657,7 +657,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, /* broadcast the received info to local processes */ MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"broadcasting the received 3 ints"); - mpi_errno = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, 
comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* check if root was unable to connect to the port */ @@ -711,7 +711,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, /* Broadcast out the remote rank translation array */ MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"Broadcasting remote translation"); mpi_errno = MPIR_Bcast_allcomm_auto(remote_translation, remote_comm_size * 2, MPI_INT, - root, comm_ptr, MPIR_ERR_NONE); + root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); #ifdef MPICH_DBG_OUTPUT @@ -749,7 +749,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, } /*printf("connect:barrier\n");fflush(stdout);*/ - mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Free new_vc. 
It was explicitly allocated in MPIDI_CH3_Connect_to_root.*/ @@ -795,7 +795,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, /* notify other processes to return an error */ MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"broadcasting 3 ints: error case"); - mpi_errno2 = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno2 = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (mpi_errno2) MPIR_ERR_ADD(mpi_errno, mpi_errno2); goto fn_fail; } @@ -943,7 +943,7 @@ static int ReceivePGAndDistribute( MPIR_Comm *tmp_comm, MPIR_Comm *comm_ptr, /* Broadcast the size and data to the local communicator */ /*printf("accept:broadcasting 1 int\n");fflush(stdout);*/ - mpi_errno = MPIR_Bcast_allcomm_auto(&j, 1, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast_allcomm_auto(&j, 1, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (rank != root) { @@ -954,7 +954,7 @@ static int ReceivePGAndDistribute( MPIR_Comm *tmp_comm, MPIR_Comm *comm_ptr, } } /*printf("accept:broadcasting string of length %d\n", j);fflush(stdout);*/ - mpi_errno = MPIR_Bcast_allcomm_auto(pg_str, j, MPI_CHAR, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast_allcomm_auto(pg_str, j, MPI_CHAR, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Then reconstruct the received process group. 
This step also initializes the created process group */ @@ -998,7 +998,7 @@ int MPID_PG_BCast( MPIR_Comm *peercomm_p, MPIR_Comm *comm_p, int root ) } /* Now, broadcast the number of local pgs */ - mpi_errno = MPIR_Bcast( &n_local_pgs, 1, MPI_INT, root, comm_p, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast( &n_local_pgs, 1, MPI_INT, root, comm_p, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); pg_list = pg_head; @@ -1018,7 +1018,7 @@ int MPID_PG_BCast( MPIR_Comm *peercomm_p, MPIR_Comm *comm_p, int root ) len = pg_list->lenStr; pg_list = pg_list->next; } - mpi_errno = MPIR_Bcast( &len, 1, MPI_INT, root, comm_p, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast( &len, 1, MPI_INT, root, comm_p, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (rank != root) { pg_str = (char *)MPL_malloc(len, MPL_MEM_DYNAMIC); @@ -1027,7 +1027,7 @@ int MPID_PG_BCast( MPIR_Comm *peercomm_p, MPIR_Comm *comm_p, int root ) goto fn_exit; } } - mpi_errno = MPIR_Bcast( pg_str, len, MPI_CHAR, root, comm_p, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast( pg_str, len, MPI_CHAR, root, comm_p, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (mpi_errno) { if (rank != root) MPL_free( pg_str ); @@ -1189,7 +1189,7 @@ int MPIDI_Comm_accept(const char *port_name, MPIR_Info *info, int root, /* broadcast the received info to local processes */ /*printf("accept:broadcasting 2 ints - %d and %d\n", recv_ints[0], recv_ints[1]);fflush(stdout);*/ - mpi_errno = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Bcast_allcomm_auto(recv_ints, 3, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -1244,7 +1244,7 @@ int MPIDI_Comm_accept(const char *port_name, MPIR_Info *info, int root, /* Broadcast out the remote rank translation array */ MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"Broadcast remote_translation"); mpi_errno = MPIR_Bcast_allcomm_auto(remote_translation, remote_comm_size * 2, MPI_INT, - root, comm_ptr, 
MPIR_ERR_NONE); + root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); #ifdef MPICH_DBG_OUTPUT MPL_DBG_MSG_D(MPIDI_CH3_DBG_OTHER,TERSE,"[%d]accept:Received remote_translation after broadcast:\n", rank); @@ -1280,7 +1280,7 @@ int MPIDI_Comm_accept(const char *port_name, MPIR_Info *info, int root, } MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"Barrier"); - mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Free new_vc once the connection is completed. It was explicitly @@ -1360,7 +1360,7 @@ static int SetupNewIntercomm( MPIR_Comm *comm_ptr, int remote_comm_size, MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"Barrier"); - mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier_allcomm_auto(comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpid/ch3/src/ch3u_rma_sync.c b/src/mpid/ch3/src/ch3u_rma_sync.c index 6463c17ab4a..8510fb6a10d 100644 --- a/src/mpid/ch3/src/ch3u_rma_sync.c +++ b/src/mpid/ch3/src/ch3u_rma_sync.c @@ -489,11 +489,11 @@ int MPID_Win_fence(int assert, MPIR_Win * win_ptr) if (win_ptr->shm_allocated == TRUE) { MPIR_Comm *node_comm_ptr = win_ptr->comm_ptr->node_comm; - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Ibarrier(win_ptr->comm_ptr, &fence_sync_req_ptr); + mpi_errno = MPIR_Ibarrier(win_ptr->comm_ptr, MPIR_SUBGROUP_NONE, &fence_sync_req_ptr); MPIR_ERR_CHECK(mpi_errno); if (fence_sync_req_ptr == NULL) { @@ -539,7 +539,7 @@ int MPID_Win_fence(int assert, MPIR_Win * win_ptr) win_ptr->at_completion_counter += comm_size; mpi_errno = MPIR_Reduce_scatter_block(MPI_IN_PLACE, rma_target_marks, 1, - MPI_INT, MPI_SUM, win_ptr->comm_ptr, MPIR_ERR_NONE); + MPI_INT, 
MPI_SUM, win_ptr->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); win_ptr->at_completion_counter -= comm_size; @@ -579,7 +579,7 @@ int MPID_Win_fence(int assert, MPIR_Win * win_ptr) MPIR_ERR_CHECK(mpi_errno); if (scalable_fence_enabled) { - mpi_errno = MPIR_Barrier(win_ptr->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(win_ptr->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Set window access state properly. */ @@ -604,7 +604,7 @@ int MPID_Win_fence(int assert, MPIR_Win * win_ptr) MPIR_Request* fence_sync_req_ptr; /* Prepare for the next possible epoch */ - mpi_errno = MPIR_Ibarrier(win_ptr->comm_ptr, &fence_sync_req_ptr); + mpi_errno = MPIR_Ibarrier(win_ptr->comm_ptr, MPIR_SUBGROUP_NONE, &fence_sync_req_ptr); MPIR_ERR_CHECK(mpi_errno); if (fence_sync_req_ptr == NULL) { @@ -629,7 +629,7 @@ int MPID_Win_fence(int assert, MPIR_Win * win_ptr) if (win_ptr->shm_allocated == TRUE) { MPIR_Comm *node_comm_ptr = win_ptr->comm_ptr->node_comm; - mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(node_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpid/ch3/src/ch3u_win_fns.c b/src/mpid/ch3/src/ch3u_win_fns.c index 0891b21d723..5e1b3083016 100644 --- a/src/mpid/ch3/src/ch3u_win_fns.c +++ b/src/mpid/ch3/src/ch3u_win_fns.c @@ -62,7 +62,7 @@ int MPIDI_CH3U_Win_gather_info(void *base, MPI_Aint size, int disp_unit, tmp_buf[4 * rank + 3] = (MPI_Aint) (*win_ptr)->handle; mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, - tmp_buf, 4, MPI_AINT, (*win_ptr)->comm_ptr, MPIR_ERR_NONE); + tmp_buf, 4, MPI_AINT, (*win_ptr)->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_T_PVAR_TIMER_END(RMA, rma_wincreate_allgather); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch3/src/mpid_startall.c b/src/mpid/ch3/src/mpid_startall.c index cba93847a43..034d0c95bb1 100644 --- a/src/mpid/ch3/src/mpid_startall.c +++ b/src/mpid/ch3/src/mpid_startall.c @@ 
-317,12 +317,12 @@ int MPID_Recv_init(void * buf, MPI_Aint count, MPI_Datatype datatype, int rank, } int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request **request) + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Bcast_init_impl(buffer, count, datatype, root, comm_ptr, info_ptr, request); + mpi_errno = MPIR_Bcast_init_impl(buffer, count, datatype, root, comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -334,13 +334,13 @@ int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, int roo } int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Allreduce_init_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, info_ptr, + mpi_errno = MPIR_Allreduce_init_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -353,13 +353,13 @@ int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_ } int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, int root, MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, + MPI_Op op, int root, MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request **request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_init_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + mpi_errno = MPIR_Reduce_init_impl(sendbuf, recvbuf, count, datatype, op, root, 
comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -373,13 +373,13 @@ int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Dat int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request** request) + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoall_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, info_ptr, request); + comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -392,14 +392,14 @@ int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sen int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], - const MPI_Aint rdispls[], MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + const MPI_Aint rdispls[], MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoallv_init_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, - recvcounts, rdispls, recvtype, comm_ptr, info_ptr, + recvcounts, rdispls, recvtype, comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -414,13 +414,13 @@ int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], const int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], 
const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoallw_init_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, - recvcounts, rdispls, recvtypes, comm_ptr, info_ptr, + recvcounts, rdispls, recvtypes, comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -434,13 +434,13 @@ int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], const int MPID_Allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, MPIR_Request** request) + MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgather_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, info_ptr, request); + comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -453,14 +453,14 @@ int MPID_Allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype se int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, - MPI_Datatype recvtype, MPIR_Comm *comm_ptr, MPIR_Info* info_ptr, + MPI_Datatype recvtype, MPIR_Comm *comm_ptr, int coll_group, MPIR_Info* info_ptr, MPIR_Request** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgatherv_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm_ptr, info_ptr, request); + displs, recvtype, comm_ptr, coll_group, info_ptr, request); MPIR_ERR_CHECK(mpi_errno); 
MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -472,13 +472,13 @@ int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype s } int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_scatter_block_init_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, + mpi_errno = MPIR_Reduce_scatter_block_init_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -491,13 +491,13 @@ int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, MPI_Aint } int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_scatter_init_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, + mpi_errno = MPIR_Reduce_scatter_init_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -510,12 +510,12 @@ int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, const MPI_Aint } int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = 
MPIR_Scan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, info, request); + mpi_errno = MPIR_Scan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -527,14 +527,14 @@ int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datat } int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gather_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - root, comm, info, request); + root, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -547,13 +547,13 @@ int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendt int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gatherv_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, root, comm, info, request); + recvtype, root, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -565,14 +565,14 @@ int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype send } int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, - MPI_Aint 
recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatter_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - root, comm, info, request); + root, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -585,14 +585,14 @@ int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype send int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, MPIR_Info * info, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatterv_init_impl(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, - recvtype, root, comm, info, request); + recvtype, root, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -603,12 +603,12 @@ int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], const M goto fn_exit; } -int MPID_Barrier_init(MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) +int MPID_Barrier_init(MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Barrier_init_impl(comm, info, request); + mpi_errno = MPIR_Barrier_init_impl(comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); @@ -620,12 +620,12 @@ int MPID_Barrier_init(MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** reques } int MPID_Exscan_init(const void 
*sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Exscan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, info, request); + mpi_errno = MPIR_Exscan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, info, request); MPIR_ERR_CHECK(mpi_errno); MPIDI_Request_set_type(*request, MPIDI_REQUEST_TYPE_PERSISTENT_COLL); diff --git a/src/mpid/ch3/src/mpid_vc.c b/src/mpid/ch3/src/mpid_vc.c index 81cb71c91e6..cf64f4063f3 100644 --- a/src/mpid/ch3/src/mpid_vc.c +++ b/src/mpid/ch3/src/mpid_vc.c @@ -554,10 +554,10 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader, comm_info[0] = *remote_size; comm_info[1] = *is_low_group; MPL_DBG_MSG(MPIDI_CH3_DBG_OTHER,VERBOSE,"About to bcast on local_comm"); - mpi_errno = MPIR_Bcast( comm_info, 2, MPI_INT, local_leader, local_comm_ptr, MPIR_ERR_NONE ); + mpi_errno = MPIR_Bcast( comm_info, 2, MPI_INT, local_leader, local_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Bcast( remote_gpids, (*remote_size)*sizeof(MPIDI_Gpid), MPI_BYTE, local_leader, - local_comm_ptr, MPIR_ERR_NONE ); + local_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_D(MPIDI_CH3_DBG_OTHER,VERBOSE,"end of bcast on local_comm of size %d", local_comm_ptr->local_size ); @@ -566,13 +566,13 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader, { /* we're the other processes */ MPL_DBG_MSG(MPIDI_CH3_DBG_OTHER,VERBOSE,"About to receive bcast on local_comm"); - mpi_errno = MPIR_Bcast( comm_info, 2, MPI_INT, local_leader, local_comm_ptr, MPIR_ERR_NONE ); + mpi_errno = MPIR_Bcast( comm_info, 2, MPI_INT, local_leader, local_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE ); 
MPIR_ERR_CHECK(mpi_errno); *remote_size = comm_info[0]; MPIR_CHKLMEM_MALLOC(remote_gpids,MPIDI_Gpid*,(*remote_size)*sizeof(MPIDI_Gpid), mpi_errno,"remote_gpids", MPL_MEM_DYNAMIC); *remote_lpids = (uint64_t*) MPL_malloc((*remote_size)*sizeof(uint64_t), MPL_MEM_ADDRESS); mpi_errno = MPIR_Bcast( remote_gpids, (*remote_size)*sizeof(MPIDI_Gpid), MPI_BYTE, local_leader, - local_comm_ptr, MPIR_ERR_NONE ); + local_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); /* Extract the context and group sign information */ @@ -736,7 +736,7 @@ int MPIDI_PG_ForwardPGInfo( MPIR_Comm *peer_ptr, MPIR_Comm *comm_ptr, } /* See if everyone is happy */ - mpi_errno = MPIR_Allreduce( MPI_IN_PLACE, &allfound, 1, MPI_INT, MPI_LAND, comm_ptr, MPIR_ERR_NONE ); + mpi_errno = MPIR_Allreduce( MPI_IN_PLACE, &allfound, 1, MPI_INT, MPI_LAND, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); if (allfound) return MPI_SUCCESS; diff --git a/src/mpid/ch3/src/mpidi_rma.c b/src/mpid/ch3/src/mpidi_rma.c index 8d86d47ff91..2715e87d8b7 100644 --- a/src/mpid/ch3/src/mpidi_rma.c +++ b/src/mpid/ch3/src/mpidi_rma.c @@ -164,7 +164,7 @@ int MPID_Win_free(MPIR_Win ** win_ptr) MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Barrier((*win_ptr)->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier((*win_ptr)->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Free window resources in lower layer. 
*/ diff --git a/src/mpid/ch4/include/mpidch4.h b/src/mpid/ch4/include/mpidch4.h index 3dd3528efbc..4fc7ad8e999 100644 --- a/src/mpid/ch4/include/mpidch4.h +++ b/src/mpid/ch4/include/mpidch4.h @@ -175,52 +175,62 @@ int MPID_Comm_set_hints(MPIR_Comm *, MPIR_Info *); int MPID_Comm_commit_post_hook(MPIR_Comm *); int MPID_Stream_create_hook(MPIR_Stream * stream); int MPID_Stream_free_hook(MPIR_Stream * stream); -MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; +MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *, MPI_Aint, MPI_Datatype, int, MPIR_Comm *, - MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, MPIR_Comm *, + MPI_Datatype, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *, MPI_Aint, MPI_Datatype, void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, int, MPIR_Comm *, + MPI_Datatype, int, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, void *, MPI_Aint, MPI_Datatype, int, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + 
MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Gather(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, int, MPIR_Comm *, + MPI_Datatype, int, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *, MPI_Aint, MPI_Datatype, void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, int, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, MPIR_Comm *, + MPI_Datatype, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, MPIR_Comm *, + int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *, const MPI_Aint[], const MPI_Aint[], const MPI_Datatype[], void *, const MPI_Aint[], const MPI_Aint[], const MPI_Datatype[], MPIR_Comm *, + int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, int, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *, void *, const MPI_Aint[], - MPI_Datatype, MPI_Op, MPIR_Comm *, + MPI_Datatype, MPI_Op, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *, void *, MPI_Aint, MPI_Datatype, - MPI_Op, MPIR_Comm *, + MPI_Op, MPIR_Comm *, int coll_group, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, 
- MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Exscan(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, - MPIR_Comm *, MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Errflag_t) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Neighbor_allgather(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, MPI_Datatype, MPIR_Comm *) MPL_STATIC_INLINE_SUFFIX; @@ -261,118 +271,119 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ineighbor_alltoallw(const void *, const MPI_Ai void *, const MPI_Aint[], const MPI_Aint[], const MPI_Datatype[], MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; -MPL_STATIC_INLINE_PREFIX int MPID_Ibarrier(MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; +MPL_STATIC_INLINE_PREFIX int MPID_Ibarrier(MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ibcast(void *, MPI_Aint, MPI_Datatype, int, MPIR_Comm *, - MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iallgather(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, MPIR_Comm *, + MPI_Datatype, MPIR_Comm *, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iallgatherv(const void *, MPI_Aint, MPI_Datatype, void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, - MPIR_Comm *, + MPIR_Comm *, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iallreduce(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ialltoall(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, MPIR_Comm *, + MPI_Datatype, MPIR_Comm 
*, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallv(const void *, const MPI_Aint[], const MPI_Aint[], MPI_Datatype, void *, const MPI_Aint[], const MPI_Aint[], MPI_Datatype, MPIR_Comm *, + int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallw(const void *, const MPI_Aint[], const MPI_Aint[], const MPI_Datatype[], void *, const MPI_Aint[], const MPI_Aint[], const MPI_Datatype[], MPIR_Comm *, + int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iexscan(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Igather(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, int, MPIR_Comm *, + MPI_Datatype, int, MPIR_Comm *, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Igatherv(const void *, MPI_Aint, MPI_Datatype, void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, int, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter_block(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, MPIR_Comm *, + int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter(const void *, void *, const MPI_Aint[], - MPI_Datatype, MPI_Op, MPIR_Comm *, + MPI_Datatype, MPI_Op, MPIR_Comm *, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Ireduce(const void *, void *, MPI_Aint, MPI_Datatype, MPI_Op, int, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iscan(const void *, void *, 
MPI_Aint, MPI_Datatype, MPI_Op, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; + MPIR_Comm *, int coll_group, + MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iscatter(const void *, MPI_Aint, MPI_Datatype, void *, MPI_Aint, - MPI_Datatype, int, MPIR_Comm *, + MPI_Datatype, int, MPIR_Comm *, int coll_group, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; MPL_STATIC_INLINE_PREFIX int MPID_Iscatterv(const void *, const MPI_Aint *, const MPI_Aint *, MPI_Datatype, void *, MPI_Aint, MPI_Datatype, int, - MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX; -int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], - const MPI_Aint sdispls[], MPI_Datatype sendtype, - void *recvbuf, const MPI_Aint recvcounts[], - const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], - const MPI_Aint sdispls[], - const MPI_Datatype sendtypes[], - void *recvbuf, const MPI_Aint recvcounts[], - const MPI_Aint rdispls[], - const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Allgather_init(const void 
*sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - const MPI_Aint * recvcounts, - const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request); -int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, - MPI_Aint recvcount, - MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); -int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, - const MPI_Aint recvcounts[], - MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); -int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Info * info, MPIR_Request ** request); -int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); -int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Info * info, MPIR_Request ** request); -int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, - MPI_Datatype sendtype, void *recvbuf, - MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); -int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], - const MPI_Aint displs[], MPI_Datatype sendtype, - void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPIR_Comm *, int coll_group, + MPIR_Request **) 
MPL_STATIC_INLINE_SUFFIX; +int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request); +int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request); +int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, + MPI_Op op, int root, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Info * info_ptr, MPIR_Request ** request); +int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, + void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request); +int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], + MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], + const MPI_Aint rdispls[], MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Info * info_ptr, MPIR_Request ** request); +int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], + const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], + const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request); +int MPID_Allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, + void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request); +int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, + void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, + MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int coll_group, + MPIR_Info * info_ptr, MPIR_Request ** request); 
+int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + int coll_group, MPIR_Info * info, MPIR_Request ** request); +int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, + MPIR_Info * info, MPIR_Request ** request); +int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, + MPIR_Request ** request); +int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Info * info, MPIR_Request ** request); +int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, + const MPI_Aint recvcounts[], const MPI_Aint displs[], MPI_Datatype recvtype, + int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, + MPIR_Request ** request); +int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, + MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Info * info, MPIR_Request ** request); +int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], + MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); -int MPID_Barrier_init(MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request); +int MPID_Barrier_init(MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Exscan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int 
coll_group, MPIR_Info * info, MPIR_Request ** request); int MPID_Neighbor_allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, diff --git a/src/mpid/ch4/netmod/include/netmod_am_fallback_coll.h b/src/mpid/ch4/netmod/include/netmod_am_fallback_coll.h index 9ea0e1f6048..7515d06e583 100644 --- a/src/mpid/ch4/netmod/include/netmod_am_fallback_coll.h +++ b/src/mpid/ch4/netmod/include/netmod_am_fallback_coll.h @@ -6,93 +6,99 @@ #ifndef NETMOD_AM_FALLBACK_COLL_H_INCLUDED #define NETMOD_AM_FALLBACK_COLL_H_INCLUDED -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Barrier_impl(comm_ptr, errflag); + return MPIR_Barrier_impl(comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return 
MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, 
recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatterv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm_ptr, errflag); + recvbuf, recvcount, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, @@ -100,10 +106,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtype, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, @@ -113,50 +121,57 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, 
MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag); + return MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, errflag); + return MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, + coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, errflag); + datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Scan_impl(sendbuf, recvbuf, count, 
datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_neighbor_allgather(const void *sendbuf, @@ -288,25 +303,28 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ineighbor_alltoallw(const void *sendbu rdispls, recvtypes, comm_ptr, req); } -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm_ptr, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Ibarrier_impl(comm_ptr, req); + return MPIR_Ibarrier_impl(comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, req); + return MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { return MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_Aint sendcount, @@ -314,27 +332,28 @@ MPL_STATIC_INLINE_PREFIX int 
MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { return MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, req); + recvcounts, displs, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { - return MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, request); + return MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { return MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallv(const void *sendbuf, @@ -344,10 +363,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { return MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, @@ -357,81 +376,87 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * 
rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { return MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + return MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { return MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, req); + recvcount, recvtype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { return MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, req); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { return 
MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, req); + datatype, op, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, req); + return MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, + coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { - return MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, req); + return MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + return MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { return MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatterv(const void *sendbuf, @@ -439,10 
+464,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** request) { return MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, request); + recvbuf, recvcount, recvtype, root, comm, coll_group, request); } #endif /* NETMOD_AM_FALLBACK_COLL_H_INCLUDED */ diff --git a/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_rma.h b/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_rma.h index cd32092ccd7..acf822b005e 100644 --- a/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_rma.h +++ b/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_rma.h @@ -29,7 +29,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Bcast_intra_triggered_rma(void *buffer, int count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, - int tree_type, + int coll_group, int tree_type, int branching_factor) { int mpi_errno = MPI_SUCCESS; @@ -53,13 +53,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Bcast_intra_triggered_rma(void *buffer, i /* Invoke the helper function to perform one-sided knomial tree-based Ibcast */ mpi_errno = MPIDI_OFI_Ibcast_knomial_triggered_rma(buffer, count, datatype, root, comm_ptr, - tree_type, branching_factor, &num_children, - &snd_cntr, &rcv_cntr, &r_mr, &works, &my_tree, - myrank, nranks, &num_works); + coll_group, tree_type, branching_factor, + &num_children, &snd_cntr, &rcv_cntr, &r_mr, + &works, &my_tree, myrank, nranks, &num_works); } else { /* Invoke the helper function to perform one-sided kary tree-based Ibcast */ mpi_errno = - MPIDI_OFI_Ibcast_kary_triggered_rma(buffer, count, datatype, root, comm_ptr, + MPIDI_OFI_Ibcast_kary_triggered_rma(buffer, count, datatype, root, comm_ptr, coll_group, branching_factor, &leaf, &num_children, &snd_cntr, &rcv_cntr, &r_mr, &works, myrank, nranks, &num_works); diff --git 
a/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_tagged.h b/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_tagged.h index 46fdc79a163..a1077eb93d7 100644 --- a/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_tagged.h +++ b/src/mpid/ch4/netmod/ofi/coll/ofi_bcast_tree_tagged.h @@ -28,7 +28,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Bcast_intra_triggered_tagged(void *buffer, int count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, - int tree_type, + int coll_group, int tree_type, int branching_factor) { int mpi_errno = MPI_SUCCESS; @@ -52,16 +52,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Bcast_intra_triggered_tagged(void *buffer if (tree_type == MPIR_TREE_TYPE_KNOMIAL_1 || tree_type == MPIR_TREE_TYPE_KNOMIAL_2) { mpi_errno = MPIDI_OFI_Ibcast_knomial_triggered_tagged(buffer, count, datatype, root, comm_ptr, - tree_type, branching_factor, &num_children, - &snd_cntr, &rcv_cntr, &works, &my_tree, - myrank, nranks, &num_works); + coll_group, tree_type, branching_factor, + &num_children, &snd_cntr, &rcv_cntr, &works, + &my_tree, myrank, nranks, &num_works); } else { /* Invoke the helper function to perform kary tree-based Ibcast */ mpi_errno = MPIDI_OFI_Ibcast_kary_triggered_tagged(buffer, count, datatype, root, comm_ptr, - branching_factor, &leaf, &num_children, - &snd_cntr, &rcv_cntr, &works, myrank, nranks, - &num_works); + coll_group, branching_factor, &leaf, + &num_children, &snd_cntr, &rcv_cntr, &works, + myrank, nranks, &num_works); } /* Wait for the completion counters to reach their desired values */ diff --git a/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h b/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h index a9de6f689d6..957a67d11ef 100644 --- a/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h +++ b/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h @@ -36,6 +36,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, int tree_type, int branching_factor, int 
*num_children, @@ -193,6 +194,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer, int count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, int branching_factor, int *is_leaf, int *children, struct fid_cntr **snd_cntr, @@ -347,7 +349,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_rma(void *buffer, int count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, - int tree_type, + int coll_group, int tree_type, int branching_factor, int *num_children, struct fid_cntr **snd_cntr, @@ -502,6 +504,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_rma(void *buffer MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_rma(void *buffer, int count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, int branching_factor, int *is_leaf, int *children, struct fid_cntr **snd_cntr, diff --git a/src/mpid/ch4/netmod/ofi/init_addrxchg.c b/src/mpid/ch4/netmod/ofi/init_addrxchg.c index f323565b663..5a3fec99785 100644 --- a/src/mpid/ch4/netmod/ofi/init_addrxchg.c +++ b/src/mpid/ch4/netmod/ofi/init_addrxchg.c @@ -222,7 +222,8 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) MPIR_CHKLMEM_MALLOC(all_num_vcis, void *, sizeof(int) * size, mpi_errno, "all_num_vcis", MPL_MEM_ADDRESS); mpi_errno = MPIR_Allgather_fallback(&MPIDI_OFI_global.num_vcis, 1, MPI_INT, - all_num_vcis, 1, MPI_INT, comm, MPIR_ERR_NONE); + all_num_vcis, 1, MPI_INT, comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); max_vcis = 0; @@ -261,7 +262,8 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) } /* Allgather */ mpi_errno = MPIR_Allgather_fallback(MPI_IN_PLACE, 0, MPI_BYTE, - all_names, my_len, MPI_BYTE, comm, MPIR_ERR_NONE); + all_names, my_len, MPI_BYTE, comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); /* Step 2: insert and 
store non-root nic/vci on the root context */ int root_ctx_idx = MPIDI_OFI_get_ctx_index(0, 0); @@ -335,7 +337,7 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) } } } - mpi_errno = MPIR_Barrier_fallback(comm, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier_fallback(comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* check */ diff --git a/src/mpid/ch4/netmod/ofi/ofi_coll.h b/src/mpid/ch4/netmod/ofi/ofi_coll.h index 36a680845b2..baa39ddc85c 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_coll.h +++ b/src/mpid/ch4/netmod/ofi/ofi_coll.h @@ -30,13 +30,14 @@ === END_MPI_T_CVAR_INFO_BLOCK === */ -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Barrier_impl(comm, errflag); + mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -49,11 +50,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm, MPIR_Errflag } static inline int MPIDI_OFI_bcast_json(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -63,7 +65,8 @@ static inline int MPIDI_OFI_bcast_json(void *buffer, MPI_Aint count, MPI_Datatyp } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; enum fi_datatype fi_dt; @@ -79,7 +82,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void 
*buffer, MPI_Aint count, MP "Bcast triggered_tagged cannot be applied.\n"); mpi_errno = MPIDI_OFI_Bcast_intra_triggered_tagged(buffer, count, datatype, root, comm, - MPIR_Bcast_tree_type, + coll_group, MPIR_Bcast_tree_type, MPIR_CVAR_BCAST_TREE_KVAL); break; case MPIR_CVAR_BCAST_OFI_INTRA_ALGORITHM_trigger_tree_rma: @@ -89,14 +92,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP NULL) != -1, mpi_errno, "Bcast triggered_rma cannot be applied.\n"); mpi_errno = - MPIDI_OFI_Bcast_intra_triggered_rma(buffer, count, datatype, root, comm, + MPIDI_OFI_Bcast_intra_triggered_rma(buffer, count, datatype, root, comm, coll_group, MPIR_Bcast_tree_type, MPIR_CVAR_BCAST_TREE_KVAL); break; case MPIR_CVAR_BCAST_OFI_INTRA_ALGORITHM_mpir: goto fallback; case MPIR_CVAR_BCAST_OFI_INTRA_ALGORITHM_auto: - mpi_errno = MPIDI_OFI_bcast_json(buffer, count, datatype, root, comm, errflag); + mpi_errno = + MPIDI_OFI_bcast_json(buffer, count, datatype, root, comm, coll_group, errflag); break; default: MPIR_Assert(0); @@ -105,7 +109,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP goto fn_exit; fallback: - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_FUNC_EXIT; @@ -118,14 +122,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -140,14 +145,15 @@ 
MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *r MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -163,14 +169,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgatherv(const void *sendbuf, MPI_Ai MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, comm, errflag); + recvbuf, recvcounts, displs, recvtype, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); @@ -185,14 +193,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgatherv(const void *sendbuf, MPI_Ai MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -208,7 +217,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * 
recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -216,7 +225,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint MPIR_FUNC_ENTER; mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm, errflag); + recvcounts, displs, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -232,7 +241,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -240,7 +249,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatter(const void *sendbuf, MPI_Aint MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -256,14 +265,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatterv(const void *sendbuf, const MP const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -278,14 +288,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatterv(const void *sendbuf, const MP MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoall(const void *sendbuf, 
MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -302,7 +313,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -310,7 +322,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm, errflag); + rdispls, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -329,7 +341,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -337,7 +350,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm, errflag); + rdispls, recvtypes, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -351,13 +364,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * 
comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, errflag); + mpi_errno = + MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -372,13 +387,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recv MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, errflag); + mpi_errno = + MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); @@ -393,14 +411,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter(const void *sendbuf, vo MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, errflag); + MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); @@ -414,13 +433,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter_block(const void *sendb MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - 
mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -434,13 +453,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scan(const void *sendbuf, void *recvbu MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -634,12 +653,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ineighbor_alltoallw(const void *sendbu return mpi_errno; } -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibarrier_impl(comm, req); + mpi_errno = MPIR_Ibarrier_impl(comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -647,12 +667,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm, MPIR_Reques MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, req); + mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -661,13 +681,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibcast(void *buffer, MPI_Aint count, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, 
MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -678,13 +699,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, comm, req); + recvbuf, recvcounts, displs, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -692,13 +713,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_A MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, request); + mpi_errno = + MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -707,13 +729,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallreduce(const void *sendbuf, void * MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - 
recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -726,13 +749,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, - sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, req); + sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, + coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -745,13 +769,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, - sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, req); + sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, + coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -759,12 +785,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, req); + mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -773,13 +800,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iexscan(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, 
void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, req); + recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -789,13 +817,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igatherv(const void *sendbuf, MPI_Aint MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, root, comm, req); + recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -804,14 +834,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igatherv(const void *sendbuf, MPI_Aint MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm, req); + datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -820,12 +850,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter_block(const void *send MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ireduce_scatter_impl(sendbuf, 
recvbuf, recvcounts, datatype, op, comm, req); + mpi_errno = + MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -833,12 +866,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter(const void *sendbuf, v MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, req); + mpi_errno = + MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -846,12 +881,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, req); + mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -860,14 +895,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscan(const void *sendbuf, void *recvb MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -878,13 
+913,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, request); + recvbuf, recvcount, recvtype, root, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; diff --git a/src/mpid/ch4/netmod/ofi/ofi_comm.c b/src/mpid/ch4/netmod/ofi/ofi_comm.c index 57b9cb131de..984b2412a31 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_comm.c +++ b/src/mpid/ch4/netmod/ofi/ofi_comm.c @@ -106,7 +106,8 @@ static int update_nic_preferences(MPIR_Comm * comm) /* Collect the NIC IDs set for the other ranks. We always expect to receive a single * NIC id from each rank, i.e., one MPI_INT. */ mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPI_INT, - pref_nic_copy, 1, MPI_INT, comm, MPIR_ERR_NONE); + pref_nic_copy, 1, MPI_INT, comm, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (MPIDI_OFI_COMM(comm).pref_nic == NULL) { diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.c b/src/mpid/ch4/netmod/ofi/ofi_init.c index 634d3b7facb..a2e27a3db8b 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.c +++ b/src/mpid/ch4/netmod/ofi/ofi_init.c @@ -867,7 +867,8 @@ static int check_num_nics(void) /* Confirm that all processes have the same number of NICs */ mpi_errno = MPIR_Allreduce_allcomm_auto(&tmp_num_nics, &num_nics, 1, MPI_INT, - MPI_MIN, MPIR_Process.comm_world, MPIR_ERR_NONE); + MPI_MIN, MPIR_Process.comm_world, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIDI_OFI_global.num_vcis = tmp_num_vcis; MPIDI_OFI_global.num_nics = tmp_num_nics; MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch4/netmod/ofi/ofi_win.c b/src/mpid/ch4/netmod/ofi/ofi_win.c index 997a4dfd111..35501d40dfe 100644 --- 
a/src/mpid/ch4/netmod/ofi/ofi_win.c +++ b/src/mpid/ch4/netmod/ofi/ofi_win.c @@ -137,7 +137,8 @@ static int win_allgather(MPIR_Win * win, void *base, int disp_unit) * available to the processes involved in the RMA window. Use the current maximum + 1 * to ensure that the key is available for all processes. */ mpi_errno = MPIR_Allreduce(&MPIDI_OFI_global.global_max_optimized_mr_key, &local_key, 1, - MPI_UNSIGNED, MPI_MAX, comm_ptr, MPIR_ERR_NONE); + MPI_UNSIGNED, MPI_MAX, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (local_key + 1 < MPIDI_OFI_NUM_OPTIMIZED_MEMORY_REGIONS) { @@ -220,7 +221,7 @@ static int win_allgather(MPIR_Win * win, void *base, int disp_unit) } /* Check if any process fails to register. If so, release local MR and force AM path. */ - MPIR_Allreduce(&rc, &allrc, 1, MPI_INT, MPI_MIN, comm_ptr, MPIR_ERR_NONE); + MPIR_Allreduce(&rc, &allrc, 1, MPI_INT, MPI_MIN, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (allrc < 0) { if (rc >= 0 && MPIDI_OFI_WIN(win).mr) MPIDI_OFI_CALL(fi_close(&MPIDI_OFI_WIN(win).mr->fid), fi_close); @@ -244,7 +245,8 @@ static int win_allgather(MPIR_Win * win, void *base, int disp_unit) mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, - winfo, sizeof(*winfo), MPI_BYTE, comm_ptr, MPIR_ERR_NONE); + winfo, sizeof(*winfo), MPI_BYTE, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (!MPIDI_OFI_ENABLE_MR_PROV_KEY && !MPIDI_OFI_ENABLE_MR_VIRT_ADDRESS) { @@ -969,7 +971,7 @@ int MPIDI_OFI_mpi_win_attach_hook(MPIR_Win * win, void *base, MPI_Aint size) } /* Check if any process fails to register. If so, release local MR and force AM path. 
*/ - MPIR_Allreduce(&rc, &allrc, 1, MPI_INT, MPI_MIN, comm_ptr, MPIR_ERR_NONE); + MPIR_Allreduce(&rc, &allrc, 1, MPI_INT, MPI_MIN, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (allrc < 0) { if (rc >= 0) MPIDI_OFI_CALL(fi_close(&mr->fid), fi_close); @@ -995,7 +997,7 @@ int MPIDI_OFI_mpi_win_attach_hook(MPIR_Win * win, void *base, MPI_Aint size) mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, target_mrs, sizeof(dwin_target_mr_t), MPI_BYTE, comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Insert each remote MR which will be searched when issuing an RMA operation @@ -1053,7 +1055,7 @@ int MPIDI_OFI_mpi_win_detach_hook(MPIR_Win * win, const void *base) target_bases[comm_ptr->rank] = base; mpi_errno = MPIR_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, target_bases, sizeof(const void *), MPI_BYTE, comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Search and delete each remote MR */ diff --git a/src/mpid/ch4/netmod/ucx/ucx_coll.h b/src/mpid/ch4/netmod/ucx/ucx_coll.h index 6a1d4759958..dde47058a15 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_coll.h +++ b/src/mpid/ch4/netmod/ucx/ucx_coll.h @@ -11,7 +11,8 @@ #include "../../../common/hcoll/hcoll.h" #endif -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; @@ -21,7 +22,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, MPIR_Err if (mpi_errno != MPI_SUCCESS) #endif { - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, coll_group, errflag); } MPIR_FUNC_EXIT; @@ -29,7 +30,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_barrier(MPIR_Comm * comm_ptr, MPIR_Err } MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int 
root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; @@ -40,7 +41,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP if (mpi_errno != MPI_SUCCESS) #endif { - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); } MPIR_FUNC_EXIT; @@ -49,7 +50,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_bcast(void *buffer, MPI_Aint count, MP MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm_ptr, + MPI_Op op, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; @@ -60,7 +61,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *r if (mpi_errno != MPI_SUCCESS) #endif { - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = + MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, + errflag); } MPIR_FUNC_EXIT; @@ -70,7 +73,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allreduce(const void *sendbuf, void *r MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; @@ -82,7 +86,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgather(const void *sendbuf, MPI_Ain #endif { mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); } MPIR_FUNC_EXIT; @@ -93,13 +97,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgatherv(const void *sendbuf, MPI_Ai MPI_Datatype sendtype, 
void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -108,14 +113,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_allgatherv(const void *sendbuf, MPI_Ai MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -125,14 +130,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, + errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -141,14 +147,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_gatherv(const void *sendbuf, MPI_Aint MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int 
root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -158,13 +164,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatterv(const void *sendbuf, const MP const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm_ptr, errflag); + recvbuf, recvcount, recvtype, root, comm_ptr, coll_group, + errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -173,7 +181,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scatterv(const void *sendbuf, const MP MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; @@ -185,7 +194,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoall(const void *sendbuf, MPI_Aint #endif { mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); } MPIR_FUNC_EXIT; return mpi_errno; @@ -196,7 +205,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t 
errflag) { int mpi_errno; MPIR_FUNC_ENTER; @@ -208,7 +218,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallv(const void *sendbuf, #endif { mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtype, comm_ptr, + coll_group, errflag); } MPIR_FUNC_EXIT; return mpi_errno; @@ -221,13 +232,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, + errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -235,7 +248,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_alltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; @@ -245,8 +259,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recv if (mpi_errno != MPI_SUCCESS) #endif { - mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, - errflag); + mpi_errno = + MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + errflag); } MPIR_FUNC_EXIT; return mpi_errno; @@ -255,14 +270,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce(const void *sendbuf, void *recv MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, 
- MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, errflag); + datatype, op, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -271,14 +286,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter(const void *sendbuf, vo MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, errflag); + datatype, op, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -286,12 +301,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_reduce_scatter_block(const void *sendb MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = + MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -299,12 +316,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_scan(const void *sendbuf, void *recvbu MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = + 
MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return mpi_errno; @@ -501,12 +520,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ineighbor_alltoallw(const void *sendbu return mpi_errno; } -MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm_ptr, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibarrier_impl(comm_ptr, req); + mpi_errno = MPIR_Ibarrier_impl(comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -514,12 +534,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibarrier(MPIR_Comm * comm_ptr, MPIR_Re MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, req); + mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -528,13 +549,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ibcast(void *buffer, MPI_Aint count, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -545,13 +567,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * 
comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, req); + recvcounts, displs, recvtype, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -559,13 +581,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallgatherv(const void *sendbuf, MPI_A MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, request); + mpi_errno = + MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -574,13 +597,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iallreduce(const void *sendbuf, void * MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -593,13 +617,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtype, 
comm_ptr, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -612,13 +637,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -626,12 +653,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ialltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -640,14 +668,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iexscan(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, req); + recvcount, recvtype, root, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -657,14 +685,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_igatherv(const void *sendbuf, MPI_Aint MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * 
recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, req); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -674,13 +702,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter_block(const void *send MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, req); + datatype, op, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -689,13 +717,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter_block(const void *send MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, - datatype, op, comm_ptr, req); + datatype, op, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -703,13 +732,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce_scatter(const void *sendbuf, v MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, req); + mpi_errno = + MPIR_Ireduce_impl(sendbuf, recvbuf, count, 
datatype, op, root, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -717,12 +747,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_ireduce(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -731,14 +762,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscan(const void *sendbuf, void *recvb MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -749,13 +780,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, request); + recvbuf, recvcount, recvtype, root, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; diff --git a/src/mpid/ch4/netmod/ucx/ucx_init.c b/src/mpid/ch4/netmod/ucx/ucx_init.c index 56701043f1d..0a727e44ce7 100644 --- 
a/src/mpid/ch4/netmod/ucx/ucx_init.c +++ b/src/mpid/ch4/netmod/ucx/ucx_init.c @@ -170,7 +170,8 @@ static int all_vcis_address_exchange(void) /* Allgather */ MPIR_Comm *comm = MPIR_Process.comm_world; mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPI_BYTE, - all_names, my_len, MPI_BYTE, comm, MPIR_ERR_NONE); + all_names, my_len, MPI_BYTE, comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* insert the addresses */ diff --git a/src/mpid/ch4/netmod/ucx/ucx_win.c b/src/mpid/ch4/netmod/ucx/ucx_win.c index 6a4a51fa85c..837a748e219 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_win.c +++ b/src/mpid/ch4/netmod/ucx/ucx_win.c @@ -83,7 +83,8 @@ static int win_allgather(MPIR_Win * win, size_t length, uint32_t disp_unit, void rkey_sizes = (MPI_Aint *) MPL_malloc(sizeof(MPI_Aint) * comm_ptr->local_size, MPL_MEM_OTHER); rkey_sizes[comm_ptr->rank] = (MPI_Aint) rkey_size; mpi_errno = - MPIR_Allgather(MPI_IN_PLACE, 1, MPI_AINT, rkey_sizes, 1, MPI_AINT, comm_ptr, MPIR_ERR_NONE); + MPIR_Allgather(MPI_IN_PLACE, 1, MPI_AINT, rkey_sizes, 1, MPI_AINT, comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -100,7 +101,7 @@ static int win_allgather(MPIR_Win * win, size_t length, uint32_t disp_unit, void /* allgather */ mpi_errno = MPIR_Allgatherv(rkey_buffer, rkey_size, MPI_BYTE, rkey_recv_buff, rkey_sizes, recv_disps, MPI_BYTE, comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -141,7 +142,8 @@ static int win_allgather(MPIR_Win * win, size_t length, uint32_t disp_unit, void mpi_errno = MPIR_Allgather(MPI_IN_PLACE, sizeof(struct ucx_share), MPI_BYTE, share_data, - sizeof(struct ucx_share), MPI_BYTE, comm_ptr, MPIR_ERR_NONE); + sizeof(struct ucx_share), MPI_BYTE, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); for (i = 0; i < comm_ptr->local_size; i++) { diff --git a/src/mpid/ch4/shm/ipc/src/ipc_win.c b/src/mpid/ch4/shm/ipc/src/ipc_win.c index 03b72f8c682..e921b3493b8 
100644 --- a/src/mpid/ch4/shm/ipc/src/ipc_win.c +++ b/src/mpid/ch4/shm/ipc/src/ipc_win.c @@ -154,7 +154,8 @@ int MPIDI_IPC_mpi_win_create_hook(MPIR_Win * win) 0, MPI_DATATYPE_NULL, ipc_shared_table, - sizeof(win_shared_info_t), MPI_BYTE, shm_comm_ptr, MPIR_ERR_NONE); + sizeof(win_shared_info_t), MPI_BYTE, shm_comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_T_PVAR_TIMER_END(RMA, rma_wincreate_allgather); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch4/shm/posix/posix_coll.h b/src/mpid/ch4/shm/posix/posix_coll.h index 3a5ca87f0fc..219508f1c08 100644 --- a/src/mpid/ch4/shm/posix/posix_coll.h +++ b/src/mpid/ch4/shm/posix/posix_coll.h @@ -148,7 +148,8 @@ */ -MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Csel_coll_sig_s coll_sig = { @@ -163,7 +164,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, MPIR_Errf case MPIR_CVAR_BARRIER_POSIX_INTRA_ALGORITHM_release_gather: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, !MPIR_IS_THREADED, mpi_errno, "Barrier release_gather cannot be applied.\n"); - mpi_errno = MPIDI_POSIX_mpi_barrier_release_gather(comm, errflag); + mpi_errno = MPIDI_POSIX_mpi_barrier_release_gather(comm, coll_group, errflag); break; case MPIR_CVAR_BARRIER_POSIX_INTRA_ALGORITHM_mpir: @@ -177,7 +178,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, MPIR_Errf switch (cnt->id) { case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_barrier_release_gather: mpi_errno = - MPIDI_POSIX_mpi_barrier_release_gather(comm, errflag); + MPIDI_POSIX_mpi_barrier_release_gather(comm, coll_group, errflag); break; case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Barrier_impl: goto fallback; @@ -194,7 +195,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, MPIR_Errf goto fn_exit; fallback: 
- mpi_errno = MPIR_Barrier_impl(comm, errflag); + mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -206,7 +207,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, MPIR_Errf MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Csel_coll_sig_s coll_sig = { @@ -226,12 +228,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast(void *buffer, MPI_Aint count, MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, !MPIR_IS_THREADED, mpi_errno, "Bcast release_gather cannot be applied.\n"); mpi_errno = - MPIDI_POSIX_mpi_bcast_release_gather(buffer, count, datatype, root, comm, errflag); + MPIDI_POSIX_mpi_bcast_release_gather(buffer, count, datatype, root, comm, + coll_group, errflag); break; case MPIR_CVAR_BCAST_POSIX_INTRA_ALGORITHM_ipc_read: mpi_errno = - MPIDI_POSIX_mpi_bcast_gpu_ipc_read(buffer, count, datatype, root, comm, errflag); + MPIDI_POSIX_mpi_bcast_gpu_ipc_read(buffer, count, datatype, root, comm, coll_group, + errflag); break; case MPIR_CVAR_BCAST_POSIX_INTRA_ALGORITHM_mpir: @@ -257,12 +261,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast(void *buffer, MPI_Aint count, case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_bcast_release_gather: mpi_errno = MPIDI_POSIX_mpi_bcast_release_gather(buffer, count, datatype, root, comm, - errflag); + coll_group, errflag); break; case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_bcast_ipc_read: mpi_errno = MPIDI_POSIX_mpi_bcast_gpu_ipc_read(buffer, count, datatype, root, comm, - errflag); + coll_group, errflag); break; case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Bcast_impl: goto fallback; @@ -279,7 +283,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast(void *buffer, MPI_Aint count, goto fn_exit; fallback: - 
mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -291,7 +295,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast(void *buffer, MPI_Aint count, MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -315,7 +319,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce(const void *sendbuf, void "Allreduce release_gather cannot be applied.\n"); mpi_errno = MPIDI_POSIX_mpi_allreduce_release_gather(sendbuf, recvbuf, count, datatype, op, - comm, errflag); + comm, coll_group, errflag); break; case MPIR_CVAR_ALLREDUCE_POSIX_INTRA_ALGORITHM_mpir: @@ -330,7 +334,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce(const void *sendbuf, void case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_allreduce_release_gather: mpi_errno = MPIDI_POSIX_mpi_allreduce_release_gather(sendbuf, recvbuf, count, datatype, - op, comm, errflag); + op, comm, coll_group, errflag); break; case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Allreduce_impl: @@ -349,7 +353,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce(const void *sendbuf, void goto fn_exit; fallback: - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -362,7 +367,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce(const void *sendbuf, void MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t 
errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -372,7 +378,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather(const void *sendbuf, MPI_ case MPIR_CVAR_ALLGATHER_POSIX_INTRA_ALGORITHM_ipc_read: mpi_errno = MPIDI_POSIX_mpi_allgather_gpu_ipc_read(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm, errflag); + comm, coll_group, errflag); break; case MPIR_CVAR_ALLGATHER_POSIX_INTRA_ALGORITHM_mpir: @@ -387,7 +393,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather(const void *sendbuf, MPI_ fallback: mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -402,7 +408,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv(const void *sendbuf, MPI const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -412,7 +418,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv(const void *sendbuf, MPI case MPIR_CVAR_ALLGATHERV_POSIX_INTRA_ALGORITHM_ipc_read: mpi_errno = MPIDI_POSIX_mpi_allgatherv_gpu_ipc_read(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, comm, errflag); + recvtype, comm, coll_group, + errflag); break; case MPIR_CVAR_ALLGATHERV_POSIX_INTRA_ALGORITHM_mpir: @@ -427,7 +434,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv(const void *sendbuf, MPI fallback: mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, comm, errflag); + recvbuf, recvcounts, displs, recvtype, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -440,7 +448,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv(const void *sendbuf, MPI MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_gather(const void 
*sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -448,7 +456,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_gather(const void *sendbuf, MPI_Ain MPIR_FUNC_ENTER; mpi_errno = MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -464,7 +472,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_gatherv(const void *sendbuf, MPI_Ai MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -472,7 +480,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_gatherv(const void *sendbuf, MPI_Ai MPIR_FUNC_ENTER; mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm, errflag); + displs, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -488,7 +496,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_gatherv(const void *sendbuf, MPI_Ai MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -496,7 +504,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scatter(const void *sendbuf, MPI_Ai MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -513,7 +521,7 @@ MPL_STATIC_INLINE_PREFIX int 
MPIDI_POSIX_mpi_scatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -521,7 +529,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scatterv(const void *sendbuf, MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, errflag); + recvbuf, recvcount, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -536,7 +544,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scatterv(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -546,7 +555,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall(const void *sendbuf, MPI_A case MPIR_CVAR_ALLTOALL_POSIX_INTRA_ALGORITHM_ipc_read: mpi_errno = MPIDI_POSIX_mpi_alltoall_gpu_ipc_read(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm, errflag); + comm, coll_group, errflag); break; case MPIR_CVAR_ALLTOALL_POSIX_INTRA_ALGORITHM_mpir: @@ -561,7 +570,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall(const void *sendbuf, MPI_A fallback: mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -578,7 +587,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno; @@ -586,7 
+595,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoallv(const void *sendbuf, mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm, errflag); + rdispls, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -605,7 +614,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoallw(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno; @@ -613,7 +623,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoallw(const void *sendbuf, mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm, errflag); + rdispls, recvtypes, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -628,7 +638,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_Csel_coll_sig_s coll_sig = { @@ -652,7 +662,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce(const void *sendbuf, void *r "Reduce release_gather cannot be applied.\n"); mpi_errno = MPIDI_POSIX_mpi_reduce_release_gather(sendbuf, recvbuf, count, datatype, op, root, - comm, errflag); + comm, coll_group, errflag); break; case MPIR_CVAR_REDUCE_POSIX_INTRA_ALGORITHM_mpir: @@ -667,7 +677,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce(const void *sendbuf, void *r case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIDI_POSIX_mpi_reduce_release_gather: mpi_errno = MPIDI_POSIX_mpi_reduce_release_gather(sendbuf, recvbuf, count, datatype, op, - root, comm, errflag); + root, comm, coll_group, errflag); 
break; case MPIDI_POSIX_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_Reduce_impl: @@ -686,7 +696,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce(const void *sendbuf, void *r goto fn_exit; fallback: - mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, errflag); + mpi_errno = + MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -699,14 +710,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce(const void *sendbuf, void *r MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, errflag); + mpi_errno = + MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); @@ -721,14 +734,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_scatter(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, errflag); + mpi_errno = MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -742,13 +756,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_scatter_block(const void *se MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, +
MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -762,14 +777,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_scan(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -973,12 +988,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ineighbor_alltoallw(const void *sen return mpi_errno; } -MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ibarrier(MPIR_Comm * comm, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ibarrier(MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibarrier_impl(comm, req); + mpi_errno = MPIR_Ibarrier_impl(comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -986,12 +1002,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ibarrier(MPIR_Comm * comm, MPIR_Req MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, req); + mpi_errno = MPIR_Ibcast_impl(buffer, count, datatype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1000,13 +1017,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ibcast(void 
*buffer, MPI_Aint count MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1017,13 +1035,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iallgatherv(const void *sendbuf, MP const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, comm, req); + recvbuf, recvcounts, displs, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1032,13 +1050,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iallgatherv(const void *sendbuf, MP MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1051,13 +1070,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, - 
sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, req); + sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, + coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1070,13 +1090,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ialltoallw(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, - sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, req); + sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, + coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1084,13 +1106,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ialltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, req); + mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1099,14 +1121,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iexscan(const void *sendbuf, void * MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, req); + recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1117,13 +1139,15 @@ MPL_STATIC_INLINE_PREFIX int 
MPIDI_POSIX_mpi_igatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcounts, displs, recvtype, root, comm, req); + recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1133,14 +1157,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce_scatter_block(const void *s void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = - MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, req); + MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1149,12 +1174,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce_scatter_block(const void *s MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, req); + mpi_errno = + MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1163,12 +1191,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce_scatter(const void *sendbuf MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { 
int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, req); + mpi_errno = + MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1176,13 +1205,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_ireduce(const void *sendbuf, void * MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, req); + mpi_errno = MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1190,12 +1219,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iallreduce(const void *sendbuf, voi MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int mpi_errno; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, req); + mpi_errno = MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return mpi_errno; @@ -1204,14 +1234,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iscan(const void *sendbuf, void *re MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, 
coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -1222,14 +1252,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { int mpi_errno; MPIR_FUNC_ENTER; mpi_errno = MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, request); + recvbuf, recvcount, recvtype, root, comm, coll_group, request); MPIR_FUNC_EXIT; return mpi_errno; diff --git a/src/mpid/ch4/shm/posix/posix_coll_gpu_ipc.h b/src/mpid/ch4/shm/posix/posix_coll_gpu_ipc.h index 466d575dff9..a28b0a40920 100644 --- a/src/mpid/ch4/shm/posix/posix_coll_gpu_ipc.h +++ b/src/mpid/ch4/shm/posix/posix_coll_gpu_ipc.h @@ -64,7 +64,8 @@ #ifdef MPIDI_CH4_SHM_ENABLE_GPU static int allgather_ipc_handles(const void *buf, MPI_Aint count, MPI_Datatype datatype, - MPIR_Comm * comm, int threshold, MPI_Aint * data_sz_out, + MPIR_Comm * comm, int coll_group, + int threshold, MPI_Aint * data_sz_out, void **mem_addr_out, MPIDI_IPCI_ipc_handle_t ** ipc_handles_out) { int mpi_errno = MPI_SUCCESS; @@ -101,7 +102,7 @@ static int allgather_ipc_handles(const void *buf, MPI_Aint count, MPI_Datatype d /* allgather is needed to exchange all the IPC handles */ mpi_errno = MPIR_Allgather_impl(&my_ipc_handle, sizeof(MPIDI_IPCI_ipc_handle_t), MPI_BYTE, ipc_handles, sizeof(MPIDI_IPCI_ipc_handle_t), MPI_BYTE, - comm, MPIR_ERR_NONE); + comm, coll_group, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* check the ipc_handles to make sure all the buffers are on GPU */ @@ -131,6 +132,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast_gpu_ipc_read(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { MPIR_FUNC_ENTER; @@ -141,7 +143,7 @@ MPL_STATIC_INLINE_PREFIX int 
MPIDI_POSIX_mpi_bcast_gpu_ipc_read(void *buffer, MPI_Aint data_sz; void *mem_addr; MPIDI_IPCI_ipc_handle_t *ipc_handles = NULL; - mpi_errno = allgather_ipc_handles(buffer, count, datatype, comm_ptr, + mpi_errno = allgather_ipc_handles(buffer, count, datatype, comm_ptr, coll_group, MPIR_CVAR_BCAST_IPC_READ_MSG_SIZE_THRESHOLD, &data_sz, &mem_addr, &ipc_handles); MPIR_ERR_CHECK(mpi_errno); @@ -186,7 +188,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast_gpu_ipc_read(void *buffer, goto fn_exit; fallback: /* Fall back to other algorithms as gpu ipc bcast cannot be used */ - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -198,6 +200,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall_gpu_ipc_read(const void *s MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { MPIR_FUNC_ENTER; @@ -216,7 +219,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall_gpu_ipc_read(const void *s MPI_Aint data_sz; void *send_mem_addr; MPIDI_IPCI_ipc_handle_t *ipc_handles = NULL; - mpi_errno = allgather_ipc_handles(sendbuf, sendcount, sendtype, comm_ptr, + mpi_errno = allgather_ipc_handles(sendbuf, sendcount, sendtype, comm_ptr, coll_group, MPIR_CVAR_ALLTOALL_IPC_READ_MSG_SIZE_THRESHOLD, &data_sz, &send_mem_addr, &ipc_handles); MPIR_ERR_CHECK(mpi_errno); @@ -280,7 +283,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall_gpu_ipc_read(const void *s fallback: /* Fall back to other algorithms as gpu ipc alltoall cannot be used */ mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -292,6 +295,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather_gpu_ipc_read(const void * MPI_Aint recvcount, MPI_Datatype 
recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { MPIR_FUNC_ENTER; @@ -310,7 +314,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather_gpu_ipc_read(const void * MPI_Aint data_sz; void *send_mem_addr; MPIDI_IPCI_ipc_handle_t *ipc_handles = NULL; - mpi_errno = allgather_ipc_handles(sendbuf, sendcount, sendtype, comm_ptr, + mpi_errno = allgather_ipc_handles(sendbuf, sendcount, sendtype, comm_ptr, coll_group, MPIR_CVAR_ALLGATHER_IPC_READ_MSG_SIZE_THRESHOLD, &data_sz, &send_mem_addr, &ipc_handles); MPIR_ERR_CHECK(mpi_errno); @@ -374,7 +378,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather_gpu_ipc_read(const void * fallback: /* Fall back to other algorithms as gpu ipc allgather cannot be used */ mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -387,6 +391,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv_gpu_ipc_read(const void const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { MPIR_FUNC_ENTER; @@ -406,7 +411,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv_gpu_ipc_read(const void MPI_Aint data_sz; void *send_mem_addr; MPIDI_IPCI_ipc_handle_t *ipc_handles = NULL; - mpi_errno = allgather_ipc_handles(sendbuf, sendcount, sendtype, comm_ptr, + mpi_errno = allgather_ipc_handles(sendbuf, sendcount, sendtype, comm_ptr, coll_group, MPIR_CVAR_ALLGATHERV_IPC_READ_MSG_SIZE_THRESHOLD, &data_sz, &send_mem_addr, &ipc_handles); MPIR_ERR_CHECK(mpi_errno); @@ -473,7 +478,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv_gpu_ipc_read(const void fallback: /* Fall back to other algorithms as gpu ipc allgatherv cannot be used */ mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, comm_ptr, errflag); + recvtype, comm_ptr, coll_group, errflag); 
MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -483,9 +488,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast_gpu_ipc_read(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall_gpu_ipc_read(const void *sendbuf, @@ -495,10 +501,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_alltoall_gpu_ipc_read(const void *s MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather_gpu_ipc_read(const void *sendbuf, @@ -508,10 +515,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgather_gpu_ipc_read(const void * MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv_gpu_ipc_read(const void *sendbuf, @@ -522,10 +530,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allgatherv_gpu_ipc_read(const void const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); } #endif /* !MPIDI_CH4_SHM_ENABLE_GPU */ diff --git a/src/mpid/ch4/shm/posix/posix_coll_nb_release_gather.h b/src/mpid/ch4/shm/posix/posix_coll_nb_release_gather.h index b1cd8b83d8b..63e7e2ed786 100644 --- 
a/src/mpid/ch4/shm/posix/posix_coll_nb_release_gather.h +++ b/src/mpid/ch4/shm/posix/posix_coll_nb_release_gather.h @@ -19,7 +19,7 @@ */ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_ibcast_release_gather(void *buffer, int count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t sched) { MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_POSIX_MPI_IBCAST_RELEASE_GATHER); @@ -69,6 +69,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_ireduce_release_gather(const void *send MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, + int coll_group, MPIR_TSP_sched_t sched) { MPIR_FUNC_ENTER; diff --git a/src/mpid/ch4/shm/posix/posix_coll_release_gather.h b/src/mpid/ch4/shm/posix/posix_coll_release_gather.h index 72e53cc89c4..0c8b8f7fd50 100644 --- a/src/mpid/ch4/shm/posix/posix_coll_release_gather.h +++ b/src/mpid/ch4/shm/posix/posix_coll_release_gather.h @@ -39,6 +39,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast_release_gather(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { MPIR_FUNC_ENTER; @@ -152,7 +153,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_bcast_release_gather(void *buffer, goto fn_exit; fallback: /* Fall back to other algo as release_gather based bcast cannot be used */ - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -168,6 +169,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_release_gather(const void *s MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int i; @@ -251,7 +253,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_reduce_release_gather(const void *s goto fn_exit; fallback: /* Fall back to other algo as release_gather algo cannot be used */ - mpi_errno = 
MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag); + mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -266,6 +269,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce_release_gather(const void MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int i; @@ -346,7 +350,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce_release_gather(const void goto fn_exit; fallback: - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -355,6 +360,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_allreduce_release_gather(const void * framework. */ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier_release_gather(MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -395,7 +401,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier_release_gather(MPIR_Comm * goto fn_exit; fallback: - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } diff --git a/src/mpid/ch4/shm/posix/posix_init.c b/src/mpid/ch4/shm/posix/posix_init.c index ecb64c7d2e6..63e2f982a52 100644 --- a/src/mpid/ch4/shm/posix/posix_init.c +++ b/src/mpid/ch4/shm/posix/posix_init.c @@ -294,7 +294,8 @@ int MPIDI_POSIX_post_init(void) local_rank_topo = MPL_calloc(MPIR_Process.local_size, topo_info_size, MPL_MEM_SHM); mpi_errno = MPIR_Allgather_fallback(&MPIDI_POSIX_global.topo, topo_info_size, MPI_BYTE, local_rank_topo, topo_info_size, MPI_BYTE, - MPIR_Process.comm_world->node_comm, MPIR_ERR_NONE); + MPIR_Process.comm_world->node_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); 
MPIR_ERR_CHECK(mpi_errno); for (int i = 0; i < MPIR_Process.local_size; i++) { if (local_rank_topo[i].l3_cache_id == -1 || local_rank_topo[i].numa_id == -1) { diff --git a/src/mpid/ch4/shm/posix/release_gather/nb_release_gather.c b/src/mpid/ch4/shm/posix/release_gather/nb_release_gather.c index 9458a53d95f..d68f2b0d8bb 100644 --- a/src/mpid/ch4/shm/posix/release_gather/nb_release_gather.c +++ b/src/mpid/ch4/shm/posix/release_gather/nb_release_gather.c @@ -121,17 +121,19 @@ int MPIDI_POSIX_nb_release_gather_comm_init(MPIR_Comm * comm_ptr, to other algorithms.\n"); } fallback = 1; - MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_NO_MEM, "**nomem"); } else { /* More shm can be created, update the shared counter */ MPL_atomic_fetch_add_uint64(MPIDI_POSIX_shm_limit_counter, memory_to_be_allocated); fallback = 0; - mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, + errflag); MPIR_ERR_CHECK(mpi_errno); } } else { - mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, + errflag); MPIR_ERR_CHECK(mpi_errno); if (fallback) { MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_NO_MEM, "**nomem"); @@ -168,7 +170,7 @@ int MPIDI_POSIX_nb_release_gather_comm_init(MPIR_Comm * comm_ptr, topotree_fail[1] = -1; } mpi_errno = MPIR_Allreduce_impl(MPI_IN_PLACE, topotree_fail, 2, MPI_INT, - MPI_MAX, comm_ptr, errflag); + MPI_MAX, comm_ptr, MPIR_SUBGROUP_NONE, errflag); } else { topotree_fail[0] = -1; topotree_fail[1] = -1; @@ -266,7 +268,7 @@ int MPIDI_POSIX_nb_release_gather_comm_init(MPIR_Comm * comm_ptr, if (initialize_ibcast_buf || initialize_ireduce_buf) { /* Make sure all the flags are set before ranks start reading each other's flags from shm */ 
- mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpid/ch4/shm/posix/release_gather/release_gather.c b/src/mpid/ch4/shm/posix/release_gather/release_gather.c index 8bcc1301332..6bb1507c886 100644 --- a/src/mpid/ch4/shm/posix/release_gather/release_gather.c +++ b/src/mpid/ch4/shm/posix/release_gather/release_gather.c @@ -303,17 +303,19 @@ int MPIDI_POSIX_mpi_release_gather_comm_init(MPIR_Comm * comm_ptr, } fallback = 1; - MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_NO_MEM, "**nomem"); } else { /* More shm can be created, update the shared counter */ MPL_atomic_fetch_add_uint64(MPIDI_POSIX_shm_limit_counter, memory_to_be_allocated); fallback = 0; - mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, + errflag); MPIR_ERR_CHECK(mpi_errno); } } else { - mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, errflag); + mpi_errno = MPIR_Bcast_impl(&fallback, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NONE, + errflag); MPIR_ERR_CHECK(mpi_errno); if (fallback) { MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_NO_MEM, "**nomem"); @@ -359,7 +361,7 @@ int MPIDI_POSIX_mpi_release_gather_comm_init(MPIR_Comm * comm_ptr, topotree_fail[1] = -1; } mpi_errno = MPIR_Allreduce_impl(MPI_IN_PLACE, topotree_fail, 2, MPI_INT, - MPI_MAX, comm_ptr, errflag); + MPI_MAX, comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); } else { topotree_fail[0] = -1; @@ -423,7 +425,7 @@ int MPIDI_POSIX_mpi_release_gather_comm_init(MPIR_Comm * comm_ptr, release_gather_info_ptr->release_state); /* Make sure all the flags are set before ranks start reading each other's flags from shm */ - mpi_errno = MPIR_Barrier_impl(comm_ptr, 
errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpid/ch4/shm/src/shm_am_fallback_coll.h b/src/mpid/ch4/shm/src/shm_am_fallback_coll.h index b9e4d53ef40..58c9de5524c 100644 --- a/src/mpid/ch4/shm/src/shm_am_fallback_coll.h +++ b/src/mpid/ch4/shm/src/shm_am_fallback_coll.h @@ -6,33 +6,37 @@ #ifndef SHM_AM_FALLBACK_COLL_H_INCLUDED #define SHM_AM_FALLBACK_COLL_H_INCLUDED -MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_barrier(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_barrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Barrier_impl(comm_ptr, errflag); + return MPIR_Barrier_impl(comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, errflag); + return MPIR_Bcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + 
recvcount, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgatherv(const void *sendbuf, MPI_Aint sendcount, @@ -40,41 +44,41 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, errflag); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { return MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, 
comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatterv(const void *sendbuf, @@ -82,19 +86,21 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm_ptr, errflag); + recvbuf, recvcount, recvtype, root, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { return MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallv(const void *sendbuf, @@ -104,10 +110,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtype, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallw(const void *sendbuf, @@ -117,51 +124,58 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallw(const void *sendbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) 
{ return MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, errflag); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag); + return MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { - return MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, errflag); + return MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, + coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { return MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, errflag); + datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_exscan(const void *sendbuf, void 
*recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { - return MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + return MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_neighbor_allgather(const void *sendbuf, @@ -295,25 +309,28 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ineighbor_alltoallw(const void *sendb rdispls, recvtypes, comm_ptr, req); } -MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibarrier(MPIR_Comm * comm_ptr, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibarrier(MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Ibarrier_impl(comm_ptr, req); + return MPIR_Ibarrier_impl(comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, req); + return MPIR_Ibcast_impl(buffer, count, datatype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { return MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallgatherv(const void *sendbuf, MPI_Aint sendcount, @@ -321,27 +338,28 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallgatherv(const void *sendbuf, MPI_ const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype 
recvtype, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { return MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, req); + recvcounts, displs, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { - return MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, request); + return MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, request); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { return MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, req); + recvcount, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallv(const void *sendbuf, @@ -351,10 +369,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { return MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtype, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallw(const void *sendbuf, @@ -364,82 +382,88 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, 
int coll_group, + MPIR_Request ** req) { return MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, req); + recvbuf, recvcounts, rdispls, recvtypes, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + return MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { return MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, req); + recvcount, recvtype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_igatherv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { return MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm_ptr, req); + recvcounts, displs, recvtype, root, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Request ** req) { return MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, - datatype, op, comm_ptr, req); + datatype, op, comm_ptr, coll_group, req); } 
MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { - return MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, req); + return MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, + coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { - return MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, req); + return MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { - return MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, req); + return MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Request ** request) { return MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, request); + recvcount, recvtype, root, comm, coll_group, request); } MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscatterv(const void *sendbuf, @@ -447,10 +471,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint 
recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** request) { return MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, request); + recvbuf, recvcount, recvtype, root, comm, coll_group, request); } #endif /* SHM_AM_FALLBACK_COLL_H_INCLUDED */ diff --git a/src/mpid/ch4/shm/src/shm_coll.h b/src/mpid/ch4/shm/src/shm_coll.h index 737cc921de2..dbde9186e24 100644 --- a/src/mpid/ch4/shm/src/shm_coll.h +++ b/src/mpid/ch4/shm/src/shm_coll.h @@ -9,13 +9,14 @@ #include #include "../posix/shm_inline.h" -MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_barrier(MPIR_Comm * comm, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_barrier(MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_barrier(comm, errflag); + ret = MPIDI_POSIX_mpi_barrier(comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -23,13 +24,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_barrier(MPIR_Comm * comm, MPIR_Errfla MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_bcast(buffer, count, datatype, root, comm, errflag); + ret = MPIDI_POSIX_mpi_bcast(buffer, count, datatype, root, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -37,14 +38,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_bcast(void *buffer, MPI_Aint count, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, errflag); + ret = + 
MPIDI_POSIX_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -53,14 +55,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allreduce(const void *sendbuf, void * MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, errflag); + recvtype, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -71,14 +74,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgatherv(const void *sendbuf, MPI_A const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_allgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm, errflag); + displs, recvtype, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -87,7 +90,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_allgatherv(const void *sendbuf, MPI_A MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int ret; @@ -95,7 +98,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatter(const void *sendbuf, MPI_Aint MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -106,14 +109,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatterv(const void *sendbuf, const MPI_Aint * 
displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_scatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, errflag); + recvcount, recvtype, root, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -122,7 +126,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scatterv(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int ret; @@ -130,7 +134,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gather(const void *sendbuf, MPI_Aint MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -140,7 +144,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gatherv(const void *sendbuf, MPI_Aint MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, + int root, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int ret; @@ -148,7 +152,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gatherv(const void *sendbuf, MPI_Aint MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_gatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm, errflag); + displs, recvtype, root, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -157,14 +161,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_gatherv(const void *sendbuf, MPI_Aint MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype 
sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, errflag); + recvtype, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -177,14 +182,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, - recvcounts, rdispls, recvtype, comm, errflag); + recvcounts, rdispls, recvtype, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -197,14 +202,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_alltoallw(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, - recvcounts, rdispls, recvtypes, comm, errflag); + recvcounts, rdispls, recvtypes, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -212,14 +218,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_alltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_reduce(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, errflag); + ret = + MPIDI_POSIX_mpi_reduce(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + 
errflag); MPIR_FUNC_EXIT; return ret; @@ -228,7 +236,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int ret; @@ -236,7 +244,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter(const void *sendbuf, v MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_reduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, - comm_ptr, errflag); + comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -246,6 +254,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter_block(const void *send void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int ret; @@ -253,7 +262,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter_block(const void *send MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_reduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, - op, comm_ptr, errflag); + op, comm_ptr, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -261,13 +270,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_reduce_scatter_block(const void *send MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm, errflag); + ret = MPIDI_POSIX_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -275,13 +285,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_scan(const void *sendbuf, void *recvb MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, 
MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_exscan(sendbuf, recvbuf, count, datatype, op, comm, errflag); + ret = MPIDI_POSIX_mpi_exscan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_FUNC_EXIT; return ret; @@ -486,13 +497,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ineighbor_alltoallw(const void *sendb return ret; } -MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibarrier(MPIR_Comm * comm, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibarrier(MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_ibarrier(comm, req); + ret = MPIDI_POSIX_mpi_ibarrier(comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -500,13 +512,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibarrier(MPIR_Comm * comm, MPIR_Reque MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_ibcast(buffer, count, datatype, root, comm, req); + ret = MPIDI_POSIX_mpi_ibcast(buffer, count, datatype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -515,14 +527,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ibcast(void *buffer, MPI_Aint count, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_iallgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, req); + recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -533,14 +546,14 @@ MPL_STATIC_INLINE_PREFIX int 
MPIDI_SHM_mpi_iallgatherv(const void *sendbuf, MPI_ const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_iallgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm, req); + displs, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -548,14 +561,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallgatherv(const void *sendbuf, MPI_ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_iallreduce(sendbuf, recvbuf, count, datatype, op, comm, req); + ret = MPIDI_POSIX_mpi_iallreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -564,14 +577,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iallreduce(const void *sendbuf, void MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_ialltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, req); + recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -584,14 +598,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallv(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_ialltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, - recvcounts, rdispls, recvtype, comm, req); + recvcounts, 
rdispls, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -604,14 +618,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallw(const void *sendbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype recvtypes[], - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_ialltoallw(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, - recvcounts, rdispls, recvtypes, comm, req); + recvcounts, rdispls, recvtypes, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -619,13 +634,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ialltoallw(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_iexscan(sendbuf, recvbuf, count, datatype, op, comm, req); + ret = MPIDI_POSIX_mpi_iexscan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -634,14 +650,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iexscan(const void *sendbuf, void *re MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_igather(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, req); + recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -651,14 +668,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_igatherv(const void *sendbuf, MPI_Ain MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - int root, 
MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm, req); + displs, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -667,7 +685,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_igatherv(const void *sendbuf, MPI_Ain MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; @@ -675,7 +693,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter_block(const void *sen MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, - op, comm, req); + op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -684,13 +702,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter_block(const void *sen MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm, req); + ret = + MPIDI_POSIX_mpi_ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm, + coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -698,14 +719,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce_scatter(const void *sendbuf, MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm_ptr, + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_ireduce(sendbuf, 
recvbuf, count, datatype, op, root, comm_ptr, req); + ret = + MPIDI_POSIX_mpi_ireduce(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + req); MPIR_FUNC_EXIT; return ret; @@ -713,13 +736,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_ireduce(const void *sendbuf, void *re MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_POSIX_mpi_iscan(sendbuf, recvbuf, count, datatype, op, comm, req); + ret = MPIDI_POSIX_mpi_iscan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -728,14 +752,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscan(const void *sendbuf, void *recv MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_iscatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, req); + recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -746,14 +771,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iscatterv(const void *sendbuf, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm_ptr, MPIR_Request ** req) + MPIR_Comm * comm_ptr, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_POSIX_mpi_iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm_ptr, req); + recvcount, recvtype, root, comm_ptr, coll_group, req); MPIR_FUNC_EXIT; return ret; diff --git a/src/mpid/ch4/shm/src/topotree.c b/src/mpid/ch4/shm/src/topotree.c 
index 2f71547962c..5ccc6fd62be 100644 --- a/src/mpid/ch4/shm/src/topotree.c +++ b/src/mpid/ch4/shm/src/topotree.c @@ -500,7 +500,7 @@ int MPIDI_SHM_topology_tree_init(MPIR_Comm * comm_ptr, int root, int bcast_k, in shared_region_ptr[rank][depth++] = MPIR_hwtopo_get_lid(gid); gid = MPIR_hwtopo_get_ancestor(gid, topo_depth - depth - 1); } - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* STEP 3. Root has all the bind_map information, now build tree */ @@ -558,7 +558,7 @@ int MPIDI_SHM_topology_tree_init(MPIR_Comm * comm_ptr, int root, int bcast_k, in 0 /*left_skewed */ , bcast_tree_type); MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* Every rank copies their tree out from shared memory */ @@ -567,7 +567,7 @@ int MPIDI_SHM_topology_tree_init(MPIR_Comm * comm_ptr, int root, int bcast_k, in MPIDI_SHM_print_topotree_file("BCAST", comm_ptr->context_id, rank, bcast_tree); /* Wait until shared memory is available */ - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* Generate the reduce tree */ /* For Reduce, package leaders are added after the package local ranks, and the per_package @@ -581,7 +581,7 @@ int MPIDI_SHM_topology_tree_init(MPIR_Comm * comm_ptr, int root, int bcast_k, in MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* each rank copy the reduce tree out */ @@ -590,7 +590,7 @@ int MPIDI_SHM_topology_tree_init(MPIR_Comm * comm_ptr, int root, int bcast_k, in if (MPIDI_SHM_TOPOTREE_DEBUG) MPIDI_SHM_print_topotree_file("REDUCE", comm_ptr->context_id, rank, reduce_tree); /* Wait for all ranks to copy 
out the tree */ - mpi_errno = MPIR_Barrier_impl(comm_ptr, errflag); + mpi_errno = MPIR_Barrier_impl(comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* Cleanup */ if (rank == root) { diff --git a/src/mpid/ch4/src/ch4_coll.h b/src/mpid/ch4/src/ch4_coll.h index 5b586f76c09..9df2bc60a68 100644 --- a/src/mpid/ch4/src/ch4_coll.h +++ b/src/mpid/ch4/src/ch4_coll.h @@ -100,6 +100,7 @@ */ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_allcomm_composition_json(MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -113,17 +114,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_allcomm_composition_json(MPIR_Comm * cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Barrier_impl(comm, errflag); + mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Barrier_intra_composition_alpha: - mpi_errno = MPIDI_Barrier_intra_composition_alpha(comm, errflag); + mpi_errno = MPIDI_Barrier_intra_composition_alpha(comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Barrier_intra_composition_beta: - mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, errflag); + mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, coll_group, errflag); break; default: MPIR_Assert(0); @@ -137,7 +138,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_allcomm_composition_json(MPIR_Comm * goto fn_exit; } -MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm * comm, MPIR_Errflag_t errflag) +MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -151,16 +152,16 @@ MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm * comm, MPIR_Errflag_t errfl (comm->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, "Barrier composition alpha cannot be applied.\n"); - mpi_errno = 
MPIDI_Barrier_intra_composition_alpha(comm, errflag); + mpi_errno = MPIDI_Barrier_intra_composition_alpha(comm, coll_group, errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM, mpi_errno, "Barrier composition beta cannot be applied.\n"); - mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, errflag); + mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, coll_group, errflag); break; default: - mpi_errno = MPIDI_Barrier_allcomm_composition_json(comm, errflag); + mpi_errno = MPIDI_Barrier_allcomm_composition_json(comm, coll_group, errflag); break; } @@ -169,9 +170,9 @@ MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm * comm, MPIR_Errflag_t errfl fallback: if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = MPIR_Barrier_impl(comm, errflag); + mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); else - mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, errflag); + mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -182,7 +183,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm * comm, MPIR_Errflag_t errfl MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_allcomm_composition_json(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -210,7 +211,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_allcomm_composition_json(void *buffer, } } if (cnt == NULL) { - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -218,19 +219,23 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_allcomm_composition_json(void *buffer, switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Bcast_intra_composition_alpha: mpi_errno = - MPIDI_Bcast_intra_composition_alpha(buffer, 
count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_alpha(buffer, count, datatype, root, comm, coll_group, + errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Bcast_intra_composition_beta: mpi_errno = - MPIDI_Bcast_intra_composition_beta(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_beta(buffer, count, datatype, root, comm, coll_group, + errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Bcast_intra_composition_gamma: mpi_errno = - MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, coll_group, + errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Bcast_intra_composition_delta: mpi_errno = - MPIDI_Bcast_intra_composition_delta(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_delta(buffer, count, datatype, root, comm, coll_group, + errflag); break; default: MPIR_Assert(0); @@ -245,7 +250,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_allcomm_composition_json(void *buffer, } MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -259,7 +265,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, "Bcast composition alpha cannot be applied.\n"); mpi_errno = - MPIDI_Bcast_intra_composition_alpha(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_alpha(buffer, count, datatype, root, comm, coll_group, + errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -268,14 +275,16 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, "Bcast composition beta 
cannot be applied.\n"); mpi_errno = - MPIDI_Bcast_intra_composition_beta(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_beta(buffer, count, datatype, root, comm, coll_group, + errflag); break; case 3: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM, mpi_errno, "Bcast composition gamma cannot be applied.\n"); mpi_errno = - MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, coll_group, + errflag); break; case 4: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -284,11 +293,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, "Bcast composition delta cannot be applied.\n"); mpi_errno = - MPIDI_Bcast_intra_composition_delta(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_delta(buffer, count, datatype, root, comm, coll_group, + errflag); break; default: mpi_errno = - MPIDI_Bcast_allcomm_composition_json(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_allcomm_composition_json(buffer, count, datatype, root, comm, + coll_group, errflag); break; } @@ -297,10 +308,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty fallback: if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); else mpi_errno = - MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, errflag); + MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, coll_group, + errflag); fn_exit: MPIR_FUNC_EXIT; @@ -309,7 +321,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty goto fn_exit; } -MPL_STATIC_INLINE_PREFIX void 
MPIDI_Allreduce_fill_multi_leads_info(MPIR_Comm * comm) +MPL_STATIC_INLINE_PREFIX void MPIDI_Allreduce_fill_multi_leads_info(MPIR_Comm * comm, + int coll_group) { int node_comm_size = 0, num_nodes; bool node_balanced = false; @@ -344,6 +357,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -364,7 +378,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -373,21 +388,21 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allreduce_intra_composition_alpha: mpi_errno = MPIDI_Allreduce_intra_composition_alpha(sendbuf, recvbuf, count, datatype, op, - comm, errflag); + comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allreduce_intra_composition_beta: mpi_errno = MPIDI_Allreduce_intra_composition_beta(sendbuf, recvbuf, count, datatype, op, - comm, errflag); + comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allreduce_intra_composition_gamma: mpi_errno = MPIDI_Allreduce_intra_composition_gamma(sendbuf, recvbuf, count, datatype, op, - comm, errflag); + comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allreduce_intra_composition_delta: if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) { - MPIDI_Allreduce_fill_multi_leads_info(comm); + MPIDI_Allreduce_fill_multi_leads_info(comm, coll_group); if (comm->node_comm) node_comm_size = 
MPIR_Comm_size(comm->node_comm); /* Reset number of leaders, so that (node_comm_size % num_leads) is zero. The new number of @@ -402,10 +417,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void count >= num_leads && MPIR_Op_is_commutative(op)) { mpi_errno = MPIDI_Allreduce_intra_composition_delta(sendbuf, recvbuf, count, datatype, op, - num_leads, comm, errflag); + num_leads, comm, coll_group, errflag); } else mpi_errno = - MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, + errflag); break; default: MPIR_Assert(0); @@ -421,7 +437,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int is_commutative = -1; @@ -441,7 +457,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, "Allreduce composition alpha cannot be applied.\n"); mpi_errno = MPIDI_Allreduce_intra_composition_alpha(sendbuf, recvbuf, count, datatype, op, comm, - errflag); + coll_group, errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -449,7 +465,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, "Allreduce composition beta cannot be applied.\n"); mpi_errno = MPIDI_Allreduce_intra_composition_beta(sendbuf, recvbuf, count, datatype, op, comm, - errflag); + coll_group, errflag); break; case 3: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -460,11 +476,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, "Allreduce composition gamma cannot be applied.\n"); mpi_errno = MPIDI_Allreduce_intra_composition_gamma(sendbuf, recvbuf, count, datatype, op, comm, - errflag); + coll_group, 
errflag); break; case 4: if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) { - MPIDI_Allreduce_fill_multi_leads_info(comm); + MPIDI_Allreduce_fill_multi_leads_info(comm, coll_group); if (comm->node_comm) node_comm_size = MPIR_Comm_size(comm->node_comm); /* Reset number of leaders, so that (node_comm_size % num_leads) is zero. The new number of @@ -485,13 +501,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, mpi_errno = MPIDI_Allreduce_intra_composition_delta(sendbuf, recvbuf, count, datatype, op, - num_leads, comm, errflag); + num_leads, comm, coll_group, errflag); break; default: mpi_errno = MPIDI_Allreduce_allcomm_composition_json(sendbuf, recvbuf, count, datatype, op, - comm, errflag); + comm, coll_group, errflag); break; } @@ -500,11 +516,12 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, fallback: if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); else mpi_errno = MPIDI_Allreduce_intra_composition_beta(sendbuf, recvbuf, count, datatype, op, comm, - errflag); + coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -513,7 +530,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, goto fn_exit; } -MPL_STATIC_INLINE_PREFIX void MPIDI_Allgather_fill_multi_leads_info(MPIR_Comm * comm) +MPL_STATIC_INLINE_PREFIX void MPIDI_Allgather_fill_multi_leads_info(MPIR_Comm * comm, + int coll_group) { int node_comm_size = 0, num_nodes; bool node_balanced = false; @@ -548,6 +566,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -579,7 +598,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void 
if (cnt == NULL) { mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -588,7 +607,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allgather_intra_composition_alpha: /* make sure that the algo can be run */ if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) - MPIDI_Allgather_fill_multi_leads_info(comm); + MPIDI_Allgather_fill_multi_leads_info(comm, coll_group); MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, MPIDI_COMM_ALLGATHER(comm, use_multi_leads) == 1 && data_size <= MPIR_CVAR_ALLGATHER_SHM_PER_RANK, mpi_errno, @@ -596,7 +615,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void mpi_errno = MPIDI_Allgather_intra_composition_alpha(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm, errflag); + comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Allgather_intra_composition_beta: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -604,7 +623,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void "Allgather composition beta cannot be applied.\n"); mpi_errno = MPIDI_Allgather_intra_composition_beta(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, + coll_group, errflag); break; default: MPIR_Assert(0); @@ -617,11 +637,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); + coll_group, errflag); else mpi_errno = MPIDI_Allgather_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, errflag); + recvtype, comm, coll_group, errflag); fn_exit: return mpi_errno; @@ 
-632,7 +652,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPI_Aint type_size, data_size; @@ -650,7 +671,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *sendbuf, MPI_Aint sendco switch (MPIR_CVAR_ALLGATHER_COMPOSITION) { case 1: if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) - MPIDI_Allgather_fill_multi_leads_info(comm); + MPIDI_Allgather_fill_multi_leads_info(comm, coll_group); MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, MPIDI_COMM_ALLGATHER(comm, use_multi_leads) == 1 && @@ -664,7 +685,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *sendbuf, MPI_Aint sendco mpi_errno = MPIDI_Allgather_intra_composition_alpha(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm, errflag); + comm, coll_group, errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -672,12 +693,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *sendbuf, MPI_Aint sendco "Allgather composition beta cannot be applied.\n"); mpi_errno = MPIDI_Allgather_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + recvcount, recvtype, comm, coll_group, + errflag); break; default: mpi_errno = MPIDI_Allgather_allcomm_composition_json(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); + coll_group, errflag); break; } @@ -688,11 +710,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *sendbuf, MPI_Aint sendco if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); + coll_group, errflag); else mpi_errno = 
MPIDI_Allgather_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, errflag); + recvtype, comm, coll_group, errflag); fn_exit: MPIR_FUNC_ENTER; @@ -705,7 +727,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; @@ -730,7 +752,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc if (cnt == NULL) { mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, comm, errflag); + recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -740,7 +762,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc mpi_errno = MPIDI_Allgatherv_intra_composition_alpha(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, comm, errflag); + recvtype, comm, coll_group, errflag); break; default: MPIR_Assert(0); @@ -758,7 +780,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; @@ -782,7 +804,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *sendbuf, MPI_Aint sendcoun if (cnt == NULL) { MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -791,7 +813,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *sendbuf, 
MPI_Aint sendcoun case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Scatter_intra_composition_alpha: mpi_errno = MPIDI_Scatter_intra_composition_alpha(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, + errflag); break; default: MPIR_Assert(0); @@ -809,7 +832,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *sendbuf, MPI_Aint sendcoun MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * sendcounts, const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; @@ -834,7 +858,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * if (cnt == NULL) { MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, - root, comm, errflag); + root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -844,7 +868,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * mpi_errno = MPIDI_Scatterv_intra_composition_alpha(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, - comm, errflag); + comm, coll_group, errflag); break; default: MPIR_Assert(0); @@ -862,7 +886,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * MPL_STATIC_INLINE_PREFIX int MPID_Gather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; @@ -887,7 +911,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gather(const void *sendbuf, MPI_Aint sendcount if (cnt == NULL) { 
mpi_errno = MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -896,7 +920,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gather(const void *sendbuf, MPI_Aint sendcount case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Gather_intra_composition_alpha: mpi_errno = MPIDI_Gather_intra_composition_alpha(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, + errflag); break; default: MPIR_Assert(0); @@ -915,7 +940,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcoun MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; @@ -940,7 +965,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcoun if (cnt == NULL) { mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm, errflag); + displs, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -950,7 +975,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcoun mpi_errno = MPIDI_Gatherv_intra_composition_alpha(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, - comm, errflag); + comm, coll_group, errflag); break; default: MPIR_Assert(0); @@ -965,7 +990,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcoun goto fn_exit; } -MPL_STATIC_INLINE_PREFIX void MPIDI_Alltoall_fill_multi_leads_info(MPIR_Comm * comm) +MPL_STATIC_INLINE_PREFIX void MPIDI_Alltoall_fill_multi_leads_info(MPIR_Comm * comm, int coll_group) { int node_comm_size = 0, num_nodes; bool node_balanced = false; @@ -1000,6 
+1025,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1031,7 +1057,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void if (cnt == NULL) { mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -1039,14 +1065,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void switch (cnt->id) { case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Alltoall_intra_composition_alpha: if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) - MPIDI_Alltoall_fill_multi_leads_info(comm); + MPIDI_Alltoall_fill_multi_leads_info(comm, coll_group); MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM && MPIDI_COMM_ALLTOALL(comm, use_multi_leads) == 1 && data_size <= MPIR_CVAR_ALLTOALL_SHM_PER_RANK, mpi_errno, "Alltoall composition alpha cannot be applied.\n"); mpi_errno = MPIDI_Alltoall_intra_composition_alpha(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, + comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Alltoall_intra_composition_beta: @@ -1055,7 +1082,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void "Alltoall composition beta cannot be applied.\n"); mpi_errno = MPIDI_Alltoall_intra_composition_beta(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, + coll_group, errflag); break; default: @@ -1069,10 +1097,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) mpi_errno = 
MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); + coll_group, errflag); else mpi_errno = MPIDI_Alltoall_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + recvcount, recvtype, comm, coll_group, + errflag); fn_exit: return mpi_errno; @@ -1082,7 +1111,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, MPIR_Comm * comm, + MPI_Datatype recvtype, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1101,7 +1130,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcou switch (MPIR_CVAR_ALLTOALL_COMPOSITION) { case 1: if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) - MPIDI_Alltoall_fill_multi_leads_info(comm); + MPIDI_Alltoall_fill_multi_leads_info(comm, coll_group); MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM && MPIDI_COMM_ALLTOALL(comm, use_multi_leads) == 1 && @@ -1114,7 +1143,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcou mpi_errno = MPIDI_Alltoall_intra_composition_alpha(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, errflag); + recvbuf, recvcount, recvtype, comm, + coll_group, errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -1122,12 +1152,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcou "Alltoall composition beta cannot be applied.\n"); mpi_errno = MPIDI_Alltoall_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + recvcount, recvtype, comm, coll_group, + errflag); break; default: mpi_errno = MPIDI_Alltoall_allcomm_composition_json(sendbuf, sendcount, sendtype, recvbuf, recvcount, 
recvtype, comm, - errflag); + coll_group, errflag); break; } @@ -1138,10 +1169,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcou if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - errflag); + coll_group, errflag); else mpi_errno = MPIDI_Alltoall_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, errflag); + recvcount, recvtype, comm, coll_group, + errflag); fn_exit: MPIR_FUNC_ENTER; @@ -1154,7 +1186,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *sendbuf, const MPI_Aint const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Errflag_t errflag) + MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; @@ -1180,7 +1213,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *sendbuf, const MPI_Aint if (cnt == NULL) { mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm, errflag); + rdispls, recvtype, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -1190,7 +1223,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *sendbuf, const MPI_Aint mpi_errno = MPIDI_Alltoallv_intra_composition_alpha(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm, errflag); + rdispls, recvtype, comm, coll_group, + errflag); break; default: MPIR_Assert(0); @@ -1210,7 +1244,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *sendbuf, const MPI_Aint const MPI_Datatype sendtypes[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno 
= MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; @@ -1236,7 +1270,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *sendbuf, const MPI_Aint if (cnt == NULL) { mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm, errflag); + rdispls, recvtypes, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -1246,7 +1280,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *sendbuf, const MPI_Aint mpi_errno = MPIDI_Alltoallw_intra_composition_alpha(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm, errflag); + rdispls, recvtypes, comm, coll_group, + errflag); break; default: MPIR_Assert(0); @@ -1265,6 +1300,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_allcomm_composition_json(const void *s void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1285,7 +1321,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_allcomm_composition_json(const void *s cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, errflag); + mpi_errno = + MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -1294,17 +1332,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_allcomm_composition_json(const void *s case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Reduce_intra_composition_alpha: mpi_errno = MPIDI_Reduce_intra_composition_alpha(sendbuf, recvbuf, count, datatype, op, - root, comm, errflag); + root, comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Reduce_intra_composition_beta: mpi_errno = MPIDI_Reduce_intra_composition_beta(sendbuf, recvbuf, count, datatype, op, - root, comm, errflag); + root, comm, coll_group, 
errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Reduce_intra_composition_gamma: mpi_errno = MPIDI_Reduce_intra_composition_gamma(sendbuf, recvbuf, count, datatype, op, - root, comm, errflag); + root, comm, coll_group, errflag); break; default: MPIR_Assert(0); @@ -1320,7 +1358,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_allcomm_composition_json(const void *s MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - int root, MPIR_Comm * comm, MPIR_Errflag_t errflag) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1335,7 +1374,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, "Reduce composition alpha cannot be applied.\n"); mpi_errno = MPIDI_Reduce_intra_composition_alpha(sendbuf, recvbuf, count, datatype, op, root, - comm, errflag); + comm, coll_group, errflag); break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM @@ -1345,7 +1384,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, "Reduce composition beta cannot be applied.\n"); mpi_errno = MPIDI_Reduce_intra_composition_beta(sendbuf, recvbuf, count, datatype, op, root, - comm, errflag); + comm, coll_group, errflag); break; case 3: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -1353,12 +1392,12 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, "Reduce composition gamma cannot be applied.\n"); mpi_errno = MPIDI_Reduce_intra_composition_gamma(sendbuf, recvbuf, count, datatype, op, root, - comm, errflag); + comm, coll_group, errflag); break; default: mpi_errno = MPIDI_Reduce_allcomm_composition_json(sendbuf, recvbuf, count, datatype, op, - root, comm, errflag); + root, comm, coll_group, errflag); break; } @@ -1367,11 +1406,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, fallback: if 
(comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, errflag); + mpi_errno = + MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, + errflag); else mpi_errno = MPIDI_Reduce_intra_composition_gamma(sendbuf, recvbuf, count, datatype, op, root, comm, - errflag); + coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1382,7 +1423,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1405,7 +1446,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *sendbuf, void *recv if (cnt == NULL) { mpi_errno = - MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, errflag); + MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -1414,7 +1456,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *sendbuf, void *recv case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Reduce_scatter_intra_composition_alpha: mpi_errno = MPIDI_Reduce_scatter_intra_composition_alpha(sendbuf, recvbuf, recvcounts, - datatype, op, comm, errflag); + datatype, op, comm, coll_group, + errflag); break; default: MPIR_Assert(0); @@ -1431,7 +1474,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *sendbuf, void *recv MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1455,7 +1498,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void 
*sendbuf, void if (cnt == NULL) { mpi_errno = MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -1464,7 +1507,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *sendbuf, void case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Reduce_scatter_block_intra_composition_alpha: mpi_errno = MPIDI_Reduce_scatter_block_intra_composition_alpha(sendbuf, recvbuf, recvcount, - datatype, op, comm, errflag); + datatype, op, comm, coll_group, + errflag); break; default: MPIR_Assert(0); @@ -1481,7 +1525,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *sendbuf, void MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; @@ -1502,7 +1546,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_A cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -1511,12 +1556,12 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_A case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Scan_intra_composition_alpha: mpi_errno = MPIDI_Scan_intra_composition_alpha(sendbuf, recvbuf, count, - datatype, op, comm, errflag); + datatype, op, comm, coll_group, errflag); break; case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Scan_intra_composition_beta: mpi_errno = MPIDI_Scan_intra_composition_beta(sendbuf, recvbuf, count, - datatype, op, comm, errflag); + datatype, op, comm, coll_group, errflag); break; default: 
MPIR_Assert(0); @@ -1533,7 +1578,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_A MPL_STATIC_INLINE_PREFIX int MPID_Exscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Errflag_t errflag) + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; @@ -1554,7 +1599,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Exscan(const void *sendbuf, void *recvbuf, MPI cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, errflag);; + mpi_errno = + MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag);; MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -1563,7 +1609,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Exscan(const void *sendbuf, void *recvbuf, MPI case MPIDI_CSEL_CONTAINER_TYPE__COMPOSITION__MPIDI_Exscan_intra_composition_alpha: mpi_errno = MPIDI_Exscan_intra_composition_alpha(sendbuf, recvbuf, count, - datatype, op, comm, errflag); + datatype, op, comm, coll_group, errflag); break; default: MPIR_Assert(0); @@ -1758,26 +1804,27 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ineighbor_alltoallw(const void *sendbuf, return ret; } -MPL_STATIC_INLINE_PREFIX int MPID_Ibarrier(MPIR_Comm * comm, MPIR_Request ** req) +MPL_STATIC_INLINE_PREFIX int MPID_Ibarrier(MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_ibarrier(comm, req); + ret = MPIDI_NM_mpi_ibarrier(comm, coll_group, req); MPIR_FUNC_EXIT; return ret; } MPL_STATIC_INLINE_PREFIX int MPID_Ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm, MPIR_Request ** req) + int root, MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_ibcast(buffer, count, datatype, root, comm, req); + ret = MPIDI_NM_mpi_ibcast(buffer, count, 
datatype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1786,14 +1833,14 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ibcast(void *buffer, MPI_Aint count, MPI_Datat MPL_STATIC_INLINE_PREFIX int MPID_Iallgather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_NM_mpi_iallgather(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1803,14 +1850,14 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iallgatherv(const void *sendbuf, MPI_Aint send MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_NM_mpi_iallgatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm, req); + recvcounts, displs, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1818,13 +1865,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iallgatherv(const void *sendbuf, MPI_Aint send MPL_STATIC_INLINE_PREFIX int MPID_Iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_iallreduce(sendbuf, recvbuf, count, datatype, op, comm, req); + ret = MPIDI_NM_mpi_iallreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1833,14 +1880,14 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iallreduce(const void *sendbuf, void *recvbuf, MPL_STATIC_INLINE_PREFIX int MPID_Ialltoall(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - 
MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_NM_mpi_ialltoall(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, req); + recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1850,14 +1897,14 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallv(const void *sendbuf, const MPI_Aint const MPI_Aint * sdispls, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, MPI_Datatype recvtype, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_NM_mpi_ialltoallv(sendbuf, sendcounts, sdispls, sendtype, - recvbuf, recvcounts, rdispls, recvtype, comm, req); + recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1868,14 +1915,14 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallw(const void *sendbuf, const MPI_Aint const MPI_Datatype * sendtypes, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * rdispls, const MPI_Datatype * recvtypes, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_NM_mpi_ialltoallw(sendbuf, sendcounts, sdispls, sendtypes, - recvbuf, recvcounts, rdispls, recvtypes, comm, req); + recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1883,13 +1930,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallw(const void *sendbuf, const MPI_Aint MPL_STATIC_INLINE_PREFIX int MPID_Iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_iexscan(sendbuf, recvbuf, count, datatype, op, comm, req); + ret = MPIDI_NM_mpi_iexscan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return 
ret; @@ -1898,14 +1945,14 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iexscan(const void *sendbuf, void *recvbuf, MP MPL_STATIC_INLINE_PREFIX int MPID_Igather(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_NM_mpi_igather(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, req); + recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1915,14 +1962,14 @@ MPL_STATIC_INLINE_PREFIX int MPID_Igatherv(const void *sendbuf, MPI_Aint sendcou MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_NM_mpi_igatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, root, comm, req); + recvcounts, displs, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1930,14 +1977,16 @@ MPL_STATIC_INLINE_PREFIX int MPID_Igatherv(const void *sendbuf, MPI_Aint sendcou MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter_block(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, - MPI_Op op, MPIR_Comm * comm, + MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, op, comm, req); + ret = + MPIDI_NM_mpi_ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, op, comm, + coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1946,13 +1995,16 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter_block(const void *sendbuf, voi MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter(const void *sendbuf, void *recvbuf, const MPI_Aint * recvcounts, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * 
comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, + MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm, req); + ret = + MPIDI_NM_mpi_ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + req); MPIR_FUNC_EXIT; return ret; @@ -1960,13 +2012,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter(const void *sendbuf, void *rec MPL_STATIC_INLINE_PREFIX int MPID_Ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, MPIR_Request ** req) + MPIR_Comm * comm, int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, req); + ret = MPIDI_NM_mpi_ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1974,13 +2026,13 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ireduce(const void *sendbuf, void *recvbuf, MP MPL_STATIC_INLINE_PREFIX int MPID_Iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; - ret = MPIDI_NM_mpi_iscan(sendbuf, recvbuf, count, datatype, op, comm, req); + ret = MPIDI_NM_mpi_iscan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -1989,14 +2041,14 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iscan(const void *sendbuf, void *recvbuf, MPI_ MPL_STATIC_INLINE_PREFIX int MPID_Iscatter(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_NM_mpi_iscatter(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, root, comm, req); + recvcount, recvtype, root, comm, 
coll_group, req); MPIR_FUNC_EXIT; return ret; @@ -2006,14 +2058,14 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iscatterv(const void *sendbuf, const MPI_Aint const MPI_Aint * displs, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, - MPIR_Request ** req) + int coll_group, MPIR_Request ** req) { int ret; MPIR_FUNC_ENTER; ret = MPIDI_NM_mpi_iscatterv(sendbuf, sendcounts, displs, sendtype, - recvbuf, recvcount, recvtype, root, comm, req); + recvbuf, recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; diff --git a/src/mpid/ch4/src/ch4_coll_impl.h b/src/mpid/ch4/src/ch4_coll_impl.h index fd3e193dddf..36dd4424bce 100644 --- a/src/mpid/ch4/src/ch4_coll_impl.h +++ b/src/mpid/ch4/src/ch4_coll_impl.h @@ -162,7 +162,7 @@ static void MPIDI_Coll_calculate_size_shift(MPI_Aint count, MPI_Datatype datatyp } } -MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_alpha(MPIR_Comm * comm, +MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_alpha(MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -170,17 +170,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_alpha(MPIR_Comm * c /* do the intranode barrier on all nodes */ if (comm->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } /* do the barrier across roots of all nodes */ if (comm->node_roots_comm != NULL) { - mpi_errno = MPIDI_NM_mpi_barrier(comm->node_roots_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -190,10 +190,10 @@ 
MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_alpha(MPIR_Comm * c if (comm->node_comm != NULL) { int i = 0; #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(&i, 1, MPI_BYTE, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(&i, 1, MPI_BYTE, 0, comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(&i, 1, MPI_BYTE, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(&i, 1, MPI_BYTE, 0, comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -204,12 +204,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_alpha(MPIR_Comm * c goto fn_exit; } -MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_beta(MPIR_Comm * comm, +MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_beta(MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIDI_NM_mpi_barrier(comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -221,6 +221,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_intra_composition_beta(MPIR_Comm * co MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_alpha(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -281,17 +282,18 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_alpha(void *buffer, M } if (comm->node_roots_comm != NULL) { - mpi_errno = - MPIDI_NM_mpi_bcast(buffer, count, datatype, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, MPIR_Get_internode_rank(comm, root), + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } if (comm->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, 
comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -315,6 +317,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_alpha(void *buffer, M MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_beta(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -340,27 +343,29 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_beta(void *buffer, MP #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, MPIR_Get_intranode_rank(comm, root), - comm->node_comm, errflag); + comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, MPIR_Get_intranode_rank(comm, root), - comm->node_comm, errflag); + comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } if (comm->node_roots_comm != NULL) { mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } if (comm->node_comm != NULL && MPIR_Get_intranode_rank(comm, root) <= 0) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = 
MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -384,6 +389,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_beta(void *buffer, MP MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_gamma(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -405,7 +411,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_gamma(void *buffer, M } } - mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, root, comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (host_buffer != NULL && comm->rank != root) { @@ -435,6 +441,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_gamma(void *buffer, M MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_delta(void *buffer, MPI_Aint count, MPI_Datatype datatype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -499,7 +506,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_delta(void *buffer, M if (comm->node_roots_comm != NULL) { mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* Node leaders copy data to GPU */ @@ -517,10 +524,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_delta(void *buffer, M /* intra-node Bcast */ if (comm->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, 
comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(buffer, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -536,6 +545,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_alpha(const void MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -566,24 +576,24 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_alpha(const void #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_reduce(recvbuf, NULL, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else mpi_errno = MPIDI_NM_mpi_reduce(recvbuf, NULL, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } else { #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_reduce(sendbuf, recvbuf, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, recvbuf, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -597,16 +607,18 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_alpha(const void if (comm->node_roots_comm != NULL) { mpi_errno = MPIDI_NM_mpi_allreduce(MPI_IN_PLACE, recvbuf, count, datatype, op, - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } if (comm->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(recvbuf, count, datatype, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(recvbuf, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(recvbuf, count, datatype, 0, comm->node_comm, 
errflag); + mpi_errno = MPIDI_NM_mpi_bcast(recvbuf, count, datatype, 0, comm->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif } @@ -632,6 +644,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_beta(const void * MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -657,7 +670,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_beta(const void * recvbuf = host_recvbuf; } - mpi_errno = MPIDI_NM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIDI_NM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (host_recvbuf != NULL) { @@ -681,6 +695,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_gamma(const void MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -706,9 +721,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_gamma(const void recvbuf = host_recvbuf; } #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIDI_SHM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, errflag); + mpi_errno = + MPIDI_NM_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -741,6 +758,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_delta(const void MPI_Op op, int num_leads, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -817,9 +835,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_delta(const void /* Step 0: Barrier to make sure the shm_buffer can be reused after the previous call */ #ifndef 
MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -830,12 +848,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_delta(const void MPIDI_SHM_mpi_reduce((char *) sendbuf + offset * extent, (char *) shm_addr + my_leader_rank * shm_size_per_lead, chunk_count, datatype, op, 0, MPIDI_COMM(comm_ptr, sub_node_comm), - errflag); + MPIR_SUBGROUP_NONE, errflag); #else mpi_errno = MPIDI_NM_mpi_reduce((char *) sendbuf + offset * extent, (char *) shm_addr + my_leader_rank * shm_size_per_lead, chunk_count, - datatype, op, 0, MPIDI_COMM(comm_ptr, sub_node_comm), errflag); + datatype, op, 0, MPIDI_COMM(comm_ptr, sub_node_comm), + MPIR_SUBGROUP_NONE, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -843,9 +862,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_delta(const void * buffers. 
*/ if (MPIDI_COMM(comm_ptr, intra_node_leads_comm) != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(MPIDI_COMM(comm_ptr, intra_node_leads_comm), errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(MPIDI_COMM(comm_ptr, intra_node_leads_comm), + MPIR_SUBGROUP_NONE, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(MPIDI_COMM(comm_ptr, intra_node_leads_comm), errflag); + mpi_errno = MPIDI_NM_mpi_barrier(MPIDI_COMM(comm_ptr, intra_node_leads_comm), + MPIR_SUBGROUP_NONE, errflag); #endif MPIR_ERR_CHECK(mpi_errno); } @@ -892,16 +913,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_intra_composition_delta(const void extent), per_leader_count, datatype, op, MPIDI_COMM(comm_ptr, inter_node_leads_comm), - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* Step 5: Barrier to make sure non-leaders wait for leaders to finish reducing the data * from other nodes */ #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -941,7 +962,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_alpha(const void *se void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -978,11 +999,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_alpha(const void *se #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_reduce(intra_sendbuf, recvbuf, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); #else mpi_errno = MPIDI_NM_mpi_reduce(intra_sendbuf, recvbuf, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ 
MPIR_ERR_CHECK(mpi_errno); @@ -1002,7 +1023,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_alpha(const void *se } mpi_errno = MPIDI_NM_mpi_reduce(inter_sendbuf, recvbuf, count, datatype, op, 0, - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -1026,7 +1047,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1055,10 +1076,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_reduce(sendbuf, tmp_buf, count, datatype, op, 0, comm->node_comm, - errflag); + coll_group, errflag); #else - mpi_errno = - MPIDI_NM_mpi_reduce(sendbuf, tmp_buf, count, datatype, op, 0, comm->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, tmp_buf, count, datatype, op, 0, comm->node_comm, + coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ MPIR_ERR_CHECK(mpi_errno); } @@ -1072,7 +1093,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen mpi_errno = MPIDI_NM_mpi_reduce(buf, NULL, count, datatype, op, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* I am on root's node. I have not participated in the earlier reduce. 
*/ if (comm->rank != root) { @@ -1081,7 +1102,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, tmp_buf, count, datatype, op, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -1092,7 +1113,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, recvbuf, count, datatype, op, MPIR_Get_internode_rank(comm, root), - comm->node_roots_comm, errflag); + comm->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* set sendbuf to MPI_IN_PLACE to make final intranode reduce easy. */ @@ -1107,11 +1128,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_beta(const void *sen #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_reduce(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_intranode_rank(comm, root), comm->node_comm, errflag); + op, MPIR_Get_intranode_rank(comm, root), comm->node_comm, + coll_group, errflag); #else mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_intranode_rank(comm, root), comm->node_comm, errflag); + op, MPIR_Get_intranode_rank(comm, root), comm->node_comm, + coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ MPIR_ERR_CHECK(mpi_errno); } @@ -1129,12 +1152,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_gamma(const void *se void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm, + MPIR_Comm * comm, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIDI_NM_mpi_reduce(sendbuf, recvbuf, count, datatype, op, root, comm, errflag); + mpi_errno = + MPIDI_NM_mpi_reduce(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1153,6 +1177,7 @@ MPL_STATIC_INLINE_PREFIX int 
MPIDI_Alltoall_intra_composition_alpha(const void * int recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1196,9 +1221,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_intra_composition_alpha(const void * /* Barrier to make sure that the shm buffer can be reused after the previous call to Alltoall */ #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -1226,9 +1251,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_intra_composition_alpha(const void * /* Barrier to make sure each rank has copied the data to the shm buf */ #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -1244,7 +1269,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_intra_composition_alpha(const void * my_node_comm_rank * num_nodes * node_comm_size * type_size * sendcount), node_comm_size * sendcount, sendtype, recvbuf, sendcount * node_comm_size, sendtype, - MPIDI_COMM(comm_ptr, multi_leads_comm), errflag); + MPIDI_COMM(comm_ptr, multi_leads_comm), + MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1260,6 +1286,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_intra_composition_beta(const void *s MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1273,17 +1300,17 @@ MPL_STATIC_INLINE_PREFIX 
int MPIDI_Alltoall_intra_composition_beta(const void *s #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_alltoall(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); #else mpi_errno = MPIDI_NM_mpi_alltoall(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIDI_NM_mpi_alltoall(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -1302,13 +1329,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoallv_intra_composition_alpha(const void const MPI_Aint * rdispls, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIDI_NM_mpi_alltoallv(sendbuf, sendcounts, sdispls, - sendtype, recvbuf, recvcounts, rdispls, recvtype, comm_ptr, errflag); + sendtype, recvbuf, recvcounts, rdispls, recvtype, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1328,6 +1357,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoallw_intra_composition_alpha(const void const MPI_Datatype recvtypes[], MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1335,7 +1365,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoallw_intra_composition_alpha(const void mpi_errno = MPIDI_NM_mpi_alltoallw(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm_ptr, errflag); + rdispls, recvtypes, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1351,6 +1381,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_alpha(const void int recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = 
MPI_SUCCESS; @@ -1408,9 +1439,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_alpha(const void /* Barrier to make sure that the shm buffer can be reused after the previous call to Allgather */ #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif MPIR_ERR_CHECK(mpi_errno); @@ -1424,9 +1455,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_alpha(const void /* Barrier to make sure all the ranks in a node_comm copied data to shm buffer */ #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #else - mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_barrier(comm_ptr->node_comm, coll_group, errflag); #endif /* Perform inter-node allgather on the multi leader comms */ @@ -1434,7 +1465,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_alpha(const void MPIDI_NM_mpi_allgather((char *) MPIDI_COMM_ALLGATHER(comm_ptr, shm_addr), sendcount * node_comm_size, sendtype, recvbuf, recvcount * node_comm_size, recvtype, - MPIDI_COMM(comm_ptr, multi_leads_comm), errflag); + MPIDI_COMM(comm_ptr, multi_leads_comm), MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1450,6 +1481,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_beta(const void * MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1463,17 +1495,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_intra_composition_beta(const void * #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_allgather(sendbuf, 
sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); #else mpi_errno = MPIDI_NM_mpi_allgather(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIDI_NM_mpi_allgather(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm_ptr, errflag); + recvcount, recvtype, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -1491,6 +1523,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgatherv_intra_composition_alpha(const void const MPI_Aint * displs, MPI_Datatype recvtype, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1504,17 +1537,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgatherv_intra_composition_alpha(const void #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_mpi_allgatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); #else mpi_errno = MPIDI_NM_mpi_allgatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); #endif /* MPIDI_CH4_DIRECT_NETMOD */ MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIDI_NM_mpi_allgatherv(sendbuf, sendcount, sendtype, recvbuf, - recvcounts, displs, recvtype, comm_ptr, errflag); + recvcounts, displs, recvtype, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -1530,13 +1563,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Gather_intra_composition_alpha(const void *se void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIDI_NM_mpi_gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, - 
recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1553,13 +1587,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Gatherv_intra_composition_alpha(const void *s const MPI_Aint * displs, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIDI_NM_mpi_gatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm, errflag); + displs, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1575,13 +1610,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scatter_intra_composition_alpha(const void *s MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIDI_NM_mpi_scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, root, comm, errflag); + recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1598,13 +1634,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scatterv_intra_composition_alpha(const void * MPI_Aint recvcount, MPI_Datatype recvtype, int root, MPIR_Comm * comm, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = MPIDI_NM_mpi_scatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, - recvcount, recvtype, root, comm, errflag); + recvcount, recvtype, root, comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1620,12 +1657,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_scatter_intra_composition_alpha(const MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; mpi_errno = - MPIDI_NM_mpi_reduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, errflag); + MPIDI_NM_mpi_reduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1643,6 
+1682,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_scatter_block_intra_composition_alpha( MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { @@ -1650,7 +1690,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_scatter_block_intra_composition_alpha( mpi_errno = MPIDI_NM_mpi_reduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, - op, comm_ptr, errflag); + op, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1665,6 +1705,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; @@ -1706,12 +1747,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send * one process, just copy the raw data. */ if (comm_ptr->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = - MPIDI_SHM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = - MPIDI_NM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } else if (sendbuf != MPI_IN_PLACE) { @@ -1743,7 +1784,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send if (comm_ptr->node_roots_comm != NULL) { mpi_errno = MPIDI_NM_mpi_scan(localfulldata, prefulldata, count, datatype, - op, comm_ptr->node_roots_comm, errflag); + op, comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); if (MPIR_Get_internode_rank(comm_ptr, rank) != comm_ptr->node_roots_comm->local_size - 1) { @@ -1769,10 +1810,12 @@ MPL_STATIC_INLINE_PREFIX int 
MPIDI_Scan_intra_composition_alpha(const void *send if (comm_ptr->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_mpi_bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = MPIDI_NM_mpi_bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -1780,12 +1823,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send if (noneed == 0) { if (comm_ptr->node_comm != NULL) { #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = - MPIDI_SHM_mpi_bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_SHM_mpi_bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #else - mpi_errno = - MPIDI_NM_mpi_bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, errflag); + mpi_errno = MPIDI_NM_mpi_bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); #endif /* MPIDI_CH4_DIRECT_NETMOD */ } @@ -1806,12 +1849,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_beta(const void *sendb MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIDI_NM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = + MPIDI_NM_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -1826,11 +1870,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Exscan_intra_composition_alpha(const void *se MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm_ptr, + int coll_group, MPIR_Errflag_t 
errflag) { int mpi_errno = MPI_SUCCESS; - mpi_errno = MPIDI_NM_mpi_exscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, errflag); + mpi_errno = + MPIDI_NM_mpi_exscan(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpid/ch4/src/ch4_comm.c b/src/mpid/ch4/src/ch4_comm.c index 1af06962f61..9de7a1df977 100644 --- a/src/mpid/ch4/src/ch4_comm.c +++ b/src/mpid/ch4/src/ch4_comm.c @@ -622,23 +622,28 @@ int MPIDIU_Intercomm_map_bcast_intra(MPIR_Comm * local_comm, int local_leader, i map_info[2] = *is_low_group; map_info[3] = pure_intracomm; mpi_errno = - MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT, local_leader, local_comm, MPIR_ERR_NONE); + MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT, local_leader, local_comm, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (!pure_intracomm) { mpi_errno = MPIR_Bcast_allcomm_auto(remote_upid_size, *remote_size, MPI_INT, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Bcast_allcomm_auto(remote_upids, upid_recv_size, MPI_BYTE, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIR_Bcast_allcomm_auto(*remote_gpids, *remote_size, MPI_UINT64_T, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); } } else { mpi_errno = - MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT, local_leader, local_comm, MPIR_ERR_NONE); + MPIR_Bcast_allcomm_auto(map_info, 4, MPI_INT, local_leader, local_comm, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); *remote_size = map_info[0]; upid_recv_size = map_info[1]; @@ -651,18 +656,21 @@ int MPIDIU_Intercomm_map_bcast_intra(MPIR_Comm * local_comm, int local_leader, i MPIR_CHKLMEM_MALLOC(_remote_upid_size, int *, (*remote_size) * sizeof(int), 
mpi_errno, "_remote_upid_size", MPL_MEM_COMM); mpi_errno = MPIR_Bcast_allcomm_auto(_remote_upid_size, *remote_size, MPI_INT, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); MPIR_CHKLMEM_MALLOC(_remote_upids, char *, upid_recv_size * sizeof(char), mpi_errno, "_remote_upids", MPL_MEM_COMM); mpi_errno = MPIR_Bcast_allcomm_auto(_remote_upids, upid_recv_size, MPI_BYTE, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); MPIDIU_upids_to_gpids(*remote_size, _remote_upid_size, _remote_upids, *remote_gpids); } else { mpi_errno = MPIR_Bcast_allcomm_auto(*remote_gpids, *remote_size, MPI_UINT64_T, - local_leader, local_comm, MPIR_ERR_NONE); + local_leader, local_comm, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); } } diff --git a/src/mpid/ch4/src/ch4_init.c b/src/mpid/ch4/src/ch4_init.c index 365a12b37ad..45fb824be43 100644 --- a/src/mpid/ch4/src/ch4_init.c +++ b/src/mpid/ch4/src/ch4_init.c @@ -704,7 +704,7 @@ int MPIDI_world_post_init(void) mpi_errno = MPIR_Allgather_fallback(&MPIDI_global.n_vcis, 1, MPI_INT, MPIDI_global.all_num_vcis, 1, MPI_INT, - MPIR_Process.comm_world, MPIR_ERR_NONE); + MPIR_Process.comm_world, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); #endif diff --git a/src/mpid/ch4/src/ch4_persist.c b/src/mpid/ch4/src/ch4_persist.c index 20ffa647d43..457dd6827cb 100644 --- a/src/mpid/ch4/src/ch4_persist.c +++ b/src/mpid/ch4/src/ch4_persist.c @@ -168,12 +168,15 @@ int MPID_Recv_init(void *buf, } int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + int root, MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Bcast_init_impl(buffer, count, datatype, root, comm_ptr, info_ptr, request); 
+ mpi_errno = + MPIR_Bcast_init_impl(buffer, count, datatype, root, comm_ptr, coll_group, info_ptr, + request); MPIR_FUNC_EXIT; return mpi_errno; @@ -181,14 +184,16 @@ int MPID_Bcast_init(void *buffer, MPI_Aint count, MPI_Datatype datatype, int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Allreduce_init_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, info_ptr, - request); + mpi_errno = + MPIR_Allreduce_init_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, + info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -196,14 +201,16 @@ int MPID_Allreduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, int root, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_init_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, - info_ptr, request); + mpi_errno = + MPIR_Reduce_init_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, + info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -212,14 +219,15 @@ int MPID_Reduce_init(const void *sendbuf, void *recvbuf, MPI_Aint count, int MPID_Alltoall_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = 
MPIR_Alltoall_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, info_ptr, request); + comm_ptr, coll_group, info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -229,14 +237,15 @@ int MPID_Alltoallv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint sdispls[], MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoallv_init_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, - recvcounts, rdispls, recvtype, comm_ptr, info_ptr, - request); + recvcounts, rdispls, recvtype, comm_ptr, coll_group, + info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -248,14 +257,15 @@ int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint rdispls[], const MPI_Datatype recvtypes[], - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Alltoallw_init_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, - recvcounts, rdispls, recvtypes, comm_ptr, info_ptr, - request); + recvcounts, rdispls, recvtypes, comm_ptr, coll_group, + info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -264,13 +274,14 @@ int MPID_Alltoallw_init(const void *sendbuf, const MPI_Aint sendcounts[], int MPID_Allgather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = 
MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgather_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - comm_ptr, info_ptr, request); + comm_ptr, coll_group, info_ptr, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -280,13 +291,15 @@ int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs, MPI_Datatype recvtype, - MPIR_Comm * comm_ptr, MPIR_Info * info_ptr, MPIR_Request ** request) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Info * info_ptr, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Allgatherv_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, comm_ptr, info_ptr, request); + displs, recvtype, comm_ptr, coll_group, info_ptr, + request); return mpi_errno; } @@ -294,13 +307,15 @@ int MPID_Allgatherv_init(const void *sendbuf, MPI_Aint sendcount, int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, MPI_Aint recvcount, MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, MPIR_Info * info, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Reduce_scatter_block_init_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, - info, request); + mpi_errno = + MPIR_Reduce_scatter_block_init_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, + coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -309,26 +324,29 @@ int MPID_Reduce_scatter_block_init(const void *sendbuf, void *recvbuf, int MPID_Reduce_scatter_init(const void *sendbuf, void *recvbuf, const MPI_Aint recvcounts[], MPI_Datatype datatype, MPI_Op op, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, MPIR_Info * info, + MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = 
MPIR_Reduce_scatter_init_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, - info, request); + mpi_errno = + MPIR_Reduce_scatter_init_impl(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, + info, request); MPIR_FUNC_EXIT; return mpi_errno; } int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Scan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, info, request); + mpi_errno = + MPIR_Scan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -337,13 +355,13 @@ int MPID_Scan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gather_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - root, comm, info, request); + root, comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -352,14 +370,14 @@ int MPID_Gather_init(const void *sendbuf, MPI_Aint sendcount, int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, const MPI_Aint recvcounts[], const MPI_Aint displs[], - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Gatherv_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, root, comm, info, 
request); + recvtype, root, comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -368,13 +386,13 @@ int MPID_Gatherv_init(const void *sendbuf, MPI_Aint sendcount, int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int root, - MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) + MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatter_init_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, - root, comm, info, request); + root, comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; @@ -383,38 +401,40 @@ int MPID_Scatter_init(const void *sendbuf, MPI_Aint sendcount, int MPID_Scatterv_init(const void *sendbuf, const MPI_Aint sendcounts[], const MPI_Aint displs[], MPI_Datatype sendtype, void *recvbuf, MPI_Aint recvcount, - MPI_Datatype recvtype, int root, MPIR_Comm * comm, + MPI_Datatype recvtype, int root, MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; mpi_errno = MPIR_Scatterv_init_impl(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, - recvtype, root, comm, info, request); + recvtype, root, comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; } -int MPID_Barrier_init(MPIR_Comm * comm, MPIR_Info * info, MPIR_Request ** request) +int MPID_Barrier_init(MPIR_Comm * comm, int coll_group, MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Barrier_init_impl(comm, info, request); + mpi_errno = MPIR_Barrier_init_impl(comm, coll_group, info, request); MPIR_FUNC_EXIT; return mpi_errno; } int MPID_Exscan_init(const void *sendbuf, void *recvbuf, MPI_Aint count, - MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, + MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm, int coll_group, 
MPIR_Info * info, MPIR_Request ** request) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - mpi_errno = MPIR_Exscan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, info, request); + mpi_errno = + MPIR_Exscan_init_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, info, + request); MPIR_FUNC_EXIT; return mpi_errno; diff --git a/src/mpid/ch4/src/ch4_spawn.c b/src/mpid/ch4/src/ch4_spawn.c index 7928fe6feec..dade6274d62 100644 --- a/src/mpid/ch4/src/ch4_spawn.c +++ b/src/mpid/ch4/src/ch4_spawn.c @@ -76,7 +76,8 @@ int MPID_Comm_spawn_multiple(int count, char *commands[], char **argvs[], const bcast_ints[0] = total_num_processes; bcast_ints[1] = spawn_error; } - mpi_errno = MPIR_Bcast(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (comm_ptr->rank != root) { total_num_processes = bcast_ints[0]; @@ -90,7 +91,8 @@ int MPID_Comm_spawn_multiple(int count, char *commands[], char **argvs[], const int should_accept = 1; if (errcodes != MPI_ERRCODES_IGNORE) { mpi_errno = - MPIR_Bcast(pmi_errcodes, total_num_processes, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + MPIR_Bcast(pmi_errcodes, total_num_processes, MPI_INT, root, comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); for (int i = 0; i < total_num_processes; i++) { @@ -391,12 +393,16 @@ static int dynamic_intercomm_create(const char *port_name, MPIR_Info * info, int bcast_tag_and_errno: bcast_ints[0] = tag; bcast_ints[1] = mpi_errno; - mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); mpi_errno = bcast_ints[1]; MPIR_ERR_CHECK(mpi_errno); } else { - mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = + 
MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (bcast_ints[1]) { /* errno from root cannot be directly returned */ diff --git a/src/mpid/ch4/src/mpidig_win.c b/src/mpid/ch4/src/mpidig_win.c index 1438c198de0..2b3a5c0f32b 100644 --- a/src/mpid/ch4/src/mpidig_win.c +++ b/src/mpid/ch4/src/mpidig_win.c @@ -393,7 +393,7 @@ static int win_init(MPI_Aint length, int disp_unit, MPIR_Win ** win_ptr, MPIR_In no_local = true; mpi_errno = MPIR_Allreduce(&no_local, &all_no_local, 1, MPI_C_BOOL, - MPI_LAND, comm_ptr, MPIR_ERR_NONE); + MPI_LAND, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (all_no_local) MPIDI_WIN(win, winattr) |= MPIDI_WINATTR_ACCU_NO_SHM; @@ -524,7 +524,7 @@ static int win_shm_alloc_impl(MPI_Aint size, int disp_unit, MPIR_Comm * comm_ptr MPI_DATATYPE_NULL, shared_table, sizeof(MPIDIG_win_shared_info_t), MPI_BYTE, shm_comm_ptr, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_T_PVAR_TIMER_END(RMA, rma_wincreate_allgather); if (mpi_errno != MPI_SUCCESS) goto fn_fail; @@ -561,7 +561,7 @@ static int win_shm_alloc_impl(MPI_Aint size, int disp_unit, MPIR_Comm * comm_ptr * - user sets alloc_shared_noncontig=true, thus we can internally make * the size aligned on each process. */ mpi_errno = MPIR_Allreduce(&symheap_flag, &global_symheap_flag, 1, MPI_C_BOOL, - MPI_LAND, comm_ptr, MPIR_ERR_NONE); + MPI_LAND, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } else global_symheap_flag = false; @@ -692,7 +692,7 @@ int MPIDIG_mpi_win_set_info(MPIR_Win * win, MPIR_Info * info) /* Do not update winattr except for info set at window creation. * Because it will change RMA's behavior which requires collective synchronization. 
*/ - mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); fn_exit: MPIR_FUNC_EXIT; return mpi_errno; @@ -857,7 +857,7 @@ int MPIDIG_mpi_win_free(MPIR_Win ** win_ptr) MPIDIG_ACCESS_EPOCH_CHECK_NONE(win, mpi_errno, return mpi_errno); MPIDIG_EXPOSURE_EPOCH_CHECK_NONE(win, mpi_errno, return mpi_errno); - mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (mpi_errno != MPI_SUCCESS) goto fn_fail; @@ -894,7 +894,7 @@ int MPIDIG_mpi_win_create(void *base, MPI_Aint length, int disp_unit, MPIR_Info MPIR_ERR_CHECK(mpi_errno); #endif - mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (mpi_errno != MPI_SUCCESS) goto fn_fail; @@ -962,7 +962,7 @@ int MPIDIG_mpi_win_allocate_shared(MPI_Aint size, int disp_unit, MPIR_Info * inf MPIR_ERR_CHECK(mpi_errno); #endif - mpi_errno = MPIR_Barrier(comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); fn_exit: MPIR_FUNC_EXIT; @@ -1026,7 +1026,7 @@ int MPIDIG_mpi_win_allocate(MPI_Aint size, int disp_unit, MPIR_Info * info, MPIR MPIR_ERR_CHECK(mpi_errno); #endif - mpi_errno = MPIR_Barrier(comm, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (mpi_errno != MPI_SUCCESS) goto fn_fail; @@ -1065,7 +1065,7 @@ int MPIDIG_mpi_win_create_dynamic(MPIR_Info * info, MPIR_Comm * comm, MPIR_Win * MPIR_ERR_CHECK(mpi_errno); #endif - mpi_errno = MPIR_Barrier(comm, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); fn_exit: MPIR_FUNC_EXIT; diff --git a/src/mpid/ch4/src/mpidig_win.h b/src/mpid/ch4/src/mpidig_win.h index 6353bf7def3..4e5936a5262 100644 --- a/src/mpid/ch4/src/mpidig_win.h +++ b/src/mpid/ch4/src/mpidig_win.h @@ -522,7 +522,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_fence(int 
massert, MPIR_Win * win) * the VCI lock internally. */ MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); need_unlock = 0; - mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); fn_exit: if (need_unlock) { diff --git a/src/mpid/common/bc/mpidu_bc.c b/src/mpid/common/bc/mpidu_bc.c index f61a6485cef..9c2223d2ed7 100644 --- a/src/mpid/common/bc/mpidu_bc.c +++ b/src/mpid/common/bc/mpidu_bc.c @@ -108,7 +108,7 @@ int MPIDU_bc_allgather(MPIR_Comm * allgather_comm, void *bc, int bc_len, int sam if (rank == node_root) { mpi_errno = MPIR_Allgatherv_fallback(segment, local_size * recv_bc_len, MPI_BYTE, recv_buf, recv_cnts, recv_offs, MPI_BYTE, allgather_comm, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpid/common/shm/mpidu_shm_alloc.c b/src/mpid/common/shm/mpidu_shm_alloc.c index 72fdaa0fa55..7720a1134cd 100644 --- a/src/mpid/common/shm/mpidu_shm_alloc.c +++ b/src/mpid/common/shm/mpidu_shm_alloc.c @@ -227,7 +227,7 @@ static int allreduce_maxloc(size_t mysz, int myloc, MPIR_Comm * comm, size_t * m mpi_errno = MPIR_Allreduce(&maxloc, &maxloc_result, 1, maxloc_type, maxloc_op->handle, comm, - MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); *maxsz_loc = maxloc_result.loc; @@ -282,21 +282,25 @@ static int map_symm_shm(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg, int root_sync: /* broadcast the mapping result on rank 0 */ - mpi_errno = MPIR_Bcast(map_result_ptr, 1, MPI_INT, 0, shm_comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(map_result_ptr, 1, MPI_INT, 0, shm_comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (*map_result_ptr != SYMSHM_SUCCESS) goto map_fail; mpi_errno = MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_BYTE, 0, - shm_comm_ptr, MPIR_ERR_NONE); + shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } else { char 
serialized_hnd[MPL_SHM_GHND_SZ] = { 0 }; /* receive the mapping result of rank 0 */ - mpi_errno = MPIR_Bcast(map_result_ptr, 1, MPI_INT, 0, shm_comm_ptr, MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(map_result_ptr, 1, MPI_INT, 0, shm_comm_ptr, MPIR_SUBGROUP_NONE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (*map_result_ptr != SYMSHM_SUCCESS) @@ -306,7 +310,7 @@ static int map_symm_shm(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg, int /* get serialized handle from rank 0 and deserialize it */ mpi_errno = MPIR_Bcast(serialized_hnd, MPL_SHM_GHND_SZ, MPI_BYTE, 0, - shm_comm_ptr, MPIR_ERR_NONE); + shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); mpl_err = @@ -331,7 +335,7 @@ static int map_symm_shm(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg, int * return SYMSHM_OTHER_FAIL if anyone reports it (max result == 2). * Otherwise return SYMSHM_MAP_FAIL (max result == 1). */ mpi_errno = MPIR_Allreduce(map_result_ptr, &all_map_result, 1, MPI_INT, - MPI_MAX, shm_comm_ptr, MPIR_ERR_NONE); + MPI_MAX, shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (all_map_result != SYMSHM_SUCCESS) @@ -423,8 +427,9 @@ static int shm_alloc_symm_all(MPIR_Comm * comm_ptr, size_t offset, MPIDU_shm_seg map_pointer = generate_random_addr(shm_seg->segment_len); /* broadcast fixed address to the other processes in comm */ - mpi_errno = MPIR_Bcast(&map_pointer, sizeof(char *), MPI_CHAR, maxsz_loc, comm_ptr, - MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast(&map_pointer, sizeof(char *), MPI_CHAR, maxsz_loc, comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* optimization: make sure every process memory in the shared segment is mapped @@ -442,7 +447,7 @@ static int shm_alloc_symm_all(MPIR_Comm * comm_ptr, size_t offset, MPIDU_shm_seg /* check if any mapping failure occurs */ mpi_errno = MPIR_Allreduce(&map_result, &all_map_result, 1, MPI_INT, - MPI_MAX, comm_ptr, MPIR_ERR_NONE); + MPI_MAX, comm_ptr, 
MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* cleanup local shm segment if mapping failed on other process */ @@ -492,8 +497,9 @@ static int shm_alloc(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg) if (shm_fail_flag) serialized_hnd = &mpl_err_hnd[0]; - mpi_errno = MPIR_Bcast_impl(serialized_hnd, MPL_SHM_GHND_SZ, MPI_BYTE, 0, shm_comm_ptr, - MPIR_ERR_NONE); + mpi_errno = + MPIR_Bcast_impl(serialized_hnd, MPL_SHM_GHND_SZ, MPI_BYTE, 0, shm_comm_ptr, + MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (shm_fail_flag) @@ -501,7 +507,7 @@ static int shm_alloc(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg) /* ensure all other processes have mapped successfully */ mpi_errno = MPIR_Allreduce_impl(&shm_fail_flag, &any_shm_fail_flag, 1, MPI_C_BOOL, - MPI_LOR, shm_comm_ptr, MPIR_ERR_NONE); + MPI_LOR, shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* unlink shared memory region so it gets deleted when all processes exit */ @@ -516,7 +522,7 @@ static int shm_alloc(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg) /* get serialized handle from rank 0 and deserialize it */ mpi_errno = MPIR_Bcast_impl(serialized_hnd, MPL_SHM_GHND_SZ, MPI_CHAR, 0, - shm_comm_ptr, MPIR_ERR_NONE); + shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* empty handler means root fails */ @@ -539,7 +545,7 @@ static int shm_alloc(MPIR_Comm * shm_comm_ptr, MPIDU_shm_seg_t * shm_seg) result_sync: mpi_errno = MPIR_Allreduce_impl(&shm_fail_flag, &any_shm_fail_flag, 1, MPI_C_BOOL, - MPI_LOR, shm_comm_ptr, MPIR_ERR_NONE); + MPI_LOR, shm_comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (any_shm_fail_flag) diff --git a/src/util/mpir_nodemap.c b/src/util/mpir_nodemap.c index a3d5fea409a..cf8eaca7efa 100644 --- a/src/util/mpir_nodemap.c +++ b/src/util/mpir_nodemap.c @@ -453,14 +453,14 @@ int MPIR_nodeid_init(void) mpi_errno = MPIR_Allgather_impl(MPI_IN_PLACE, MAX_HOSTNAME_LEN, 
MPI_CHAR, allhostnames, MAX_HOSTNAME_LEN, MPI_CHAR, - node_roots_comm, MPIR_ERR_NONE); + node_roots_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } MPIR_Comm *node_comm = MPIR_Process.comm_world->node_comm; if (node_comm) { mpi_errno = MPIR_Bcast_impl(allhostnames, MAX_HOSTNAME_LEN * MPIR_Process.num_nodes, - MPI_CHAR, 0, node_comm, MPIR_ERR_NONE); + MPI_CHAR, 0, node_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } From 1c40a122de35068252949cf91db3a31d87353dcd Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 16 Aug 2024 16:38:40 -0500 Subject: [PATCH 09/27] coll: add coll_group argument to MPIC/sched/TSP routines --- src/include/mpir_coll.h | 21 +++--- src/include/mpir_nbc.h | 10 +-- src/mpi/coll/helper_fns.c | 67 ++++++++++++------- .../coll/transports/gentran/gentran_types.h | 4 ++ .../coll/transports/gentran/gentran_utils.c | 12 ++-- src/mpi/coll/transports/gentran/tsp_gentran.c | 18 +++-- src/mpi/coll/transports/tsp_impl.h | 12 ++-- src/mpid/common/sched/mpidu_sched.c | 25 ++++--- src/mpid/common/sched/mpidu_sched.h | 6 +- 9 files changed, 107 insertions(+), 68 deletions(-) diff --git a/src/include/mpir_coll.h b/src/include/mpir_coll.h index 05eb02c213e..8e8cd0111cc 100644 --- a/src/include/mpir_coll.h +++ b/src/include/mpir_coll.h @@ -74,21 +74,22 @@ int MPIC_Wait(MPIR_Request * request_ptr); int MPIC_Probe(int source, int tag, MPI_Comm comm, MPI_Status * status); int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag); + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag); int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int tag, - MPIR_Comm * comm_ptr, MPI_Status * status); + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status); int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, int dest, int sendtag, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, 
int source, int recvtag, - MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag); -int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, - int dest, int sendtag, - int source, int recvtag, - MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag); + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status, + MPIR_Errflag_t errflag); +int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int sendtag, + int source, int recvtag, MPIR_Comm * comm_ptr, int coll_group, + MPI_Status * status, MPIR_Errflag_t errflag); int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_Request ** request, MPIR_Errflag_t errflag); -int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, - int tag, MPIR_Comm * comm_ptr, MPIR_Request ** request); + MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** request, + MPIR_Errflag_t errflag); +int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int tag, + MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** request); int MPIC_Waitall(int numreq, MPIR_Request * requests[], MPI_Status * statuses); int MPIR_Reduce_local(const void *inbuf, void *inoutbuf, MPI_Aint count, MPI_Datatype datatype, diff --git a/src/include/mpir_nbc.h b/src/include/mpir_nbc.h index eb08995fe04..9320bf5b633 100644 --- a/src/include/mpir_nbc.h +++ b/src/include/mpir_nbc.h @@ -69,13 +69,13 @@ int MPIR_Sched_start(MPIR_Sched_t s, MPIR_Comm * comm, MPIR_Request ** req); /* send and recv take a comm ptr to enable hierarchical collectives */ int MPIR_Sched_send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, - MPIR_Comm * comm, MPIR_Sched_t s); + MPIR_Comm * comm, int coll_group, MPIR_Sched_t s); int MPIR_Sched_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, MPIR_Comm * comm, - MPIR_Sched_t s); + int coll_group, MPIR_Sched_t s); /* just like MPI_Issend, can't 
complete until the matching recv is posted */ int MPIR_Sched_ssend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, - MPIR_Comm * comm, MPIR_Sched_t s); + MPIR_Comm * comm, int coll_group, MPIR_Sched_t s); int MPIR_Sched_reduce(const void *inbuf, void *inoutbuf, MPI_Aint count, MPI_Datatype datatype, MPI_Op op, MPIR_Sched_t s); @@ -104,12 +104,12 @@ int MPIR_Sched_barrier(MPIR_Sched_t s); * is no known use case. The recv count is just an upper bound, not an exact * amount to be received, so an oversized recv is used instead of deferral. */ int MPIR_Sched_send_defer(const void *buf, const MPI_Aint * count, MPI_Datatype datatype, int dest, - MPIR_Comm * comm, MPIR_Sched_t s); + MPIR_Comm * comm, int coll_group, MPIR_Sched_t s); /* Just like MPIR_Sched_recv except it populates the given status object with * the received count and error information, much like a normal recv. Often * useful in conjunction with MPIR_Sched_send_defer. */ int MPIR_Sched_recv_status(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, - MPIR_Comm * comm, MPI_Status * status, MPIR_Sched_t s); + MPIR_Comm * comm, int coll_group, MPI_Status * status, MPIR_Sched_t s); /* buffer management, fancy reductions, etc */ int MPIR_Sched_cb(MPIR_Sched_cb_t * cb_p, void *cb_state, MPIR_Sched_t s); diff --git a/src/mpi/coll/helper_fns.c b/src/mpi/coll/helper_fns.c index 66fe0632f52..dd9adb6d444 100644 --- a/src/mpi/coll/helper_fns.c +++ b/src/mpi/coll/helper_fns.c @@ -14,39 +14,53 @@ sends/receives by setting the context offset MPIR_CONTEXT_COLL_OFFSET. 
*/ +static int get_coll_group_rank(MPIR_Comm * comm, int coll_group, int group_rank) +{ + if (coll_group > 0) { + return comm->subgroups[coll_group].proc_table[group_rank]; + } else { + return group_rank; + } +} + #ifdef ENABLE_THREADCOMM -#define DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, attr, req) \ +#define DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, coll_group, attr, req) \ do { \ + int rank = get_coll_group_rank(comm_ptr, coll_group, dest); \ if (comm_ptr->threadcomm) { \ - mpi_errno = MPIR_Threadcomm_isend_attr(buf, count, datatype, dest, tag, \ + mpi_errno = MPIR_Threadcomm_isend_attr(buf, count, datatype, rank, tag, \ comm_ptr->threadcomm, attr, req); \ } else { \ - mpi_errno = MPID_Isend(buf, count, datatype, dest, tag, comm_ptr, attr, req); \ + mpi_errno = MPID_Isend(buf, count, datatype, rank, tag, comm_ptr, attr, req); \ } \ } while (0) -#define DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, attr, req) \ +#define DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, coll_group, attr, req) \ do { \ + int rank = get_coll_group_rank(comm_ptr, coll_group, source); \ if (comm_ptr->threadcomm) { \ - mpi_errno = MPIR_Threadcomm_irecv_attr(buf, count, datatype, source, tag, \ + mpi_errno = MPIR_Threadcomm_irecv_attr(buf, count, datatype, rank, tag, \ comm_ptr->threadcomm, attr, req, true); \ } else { \ - mpi_errno = MPID_Irecv(buf, count, datatype, source, tag, comm_ptr, attr, req); \ + mpi_errno = MPID_Irecv(buf, count, datatype, rank, tag, comm_ptr, attr, req); \ } \ } while (0) #else -#define DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, attr, req) \ +#define DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, coll_group, attr, req) \ do { \ - mpi_errno = MPID_Isend(buf, count, datatype, dest, tag, comm_ptr, attr, req); \ + int rank = get_coll_group_rank(comm_ptr, coll_group, dest); \ + mpi_errno = MPID_Isend(buf, count, datatype, rank, tag, comm_ptr, attr, req); \ } while (0) -#define 
DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, attr, req) \ +#define DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, coll_group, attr, req) \ do { \ - mpi_errno = MPID_Irecv(buf, count, datatype, source, tag, comm_ptr, attr, req); \ + int rank = get_coll_group_rank(comm_ptr, coll_group, source); \ + mpi_errno = MPID_Irecv(buf, count, datatype, rank, tag, comm_ptr, attr, req); \ } while (0) #endif +/* NOTE: MPIC_Probe is never used in group collectives */ int MPIC_Probe(int source, int tag, MPI_Comm comm, MPI_Status * status) { int mpi_errno = MPI_SUCCESS; @@ -127,7 +141,7 @@ int MPIC_Wait(MPIR_Request * request_ptr) this is OK since there is no data that can be received corrupted. */ int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int attr = 0; @@ -146,7 +160,7 @@ int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, MPIR_PT2PT_ATTR_SET_CONTEXT_OFFSET(attr, MPIR_CONTEXT_COLL_OFFSET); MPIR_PT2PT_ATTR_SET_ERRFLAG(attr, errflag); - DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, attr, &request_ptr); + DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, coll_group, attr, &request_ptr); MPIR_ERR_CHECK(mpi_errno); if (request_ptr) { mpi_errno = MPIC_Wait(request_ptr); @@ -168,7 +182,7 @@ int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, } int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int tag, - MPIR_Comm * comm_ptr, MPI_Status * status) + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status) { int mpi_errno = MPI_SUCCESS; int attr = 0; @@ -191,7 +205,7 @@ int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int if (status == MPI_STATUS_IGNORE) status = &mystatus; - DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, attr, &request_ptr); +
DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, coll_group, attr, &request_ptr); MPIR_ERR_CHECK(mpi_errno); if (request_ptr) { mpi_errno = MPIC_Wait(request_ptr); @@ -218,7 +232,7 @@ int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype, int dest, int sendtag, void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype, int source, int recvtag, - MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int attr = 0; @@ -244,7 +258,8 @@ int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype "**nomemreq"); MPIR_Status_set_procnull(&recv_req_ptr->status); } else { - DO_MPID_IRECV(recvbuf, recvcount, recvtype, source, recvtag, comm_ptr, attr, &recv_req_ptr); + DO_MPID_IRECV(recvbuf, recvcount, recvtype, source, recvtag, comm_ptr, coll_group, attr, + &recv_req_ptr); MPIR_ERR_CHECK(mpi_errno); } @@ -255,7 +270,8 @@ int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype "**nomemreq"); } else { MPIR_PT2PT_ATTR_SET_ERRFLAG(attr, errflag); - DO_MPID_ISEND(sendbuf, sendcount, sendtype, dest, sendtag, comm_ptr, attr, &send_req_ptr); + DO_MPID_ISEND(sendbuf, sendcount, sendtype, dest, sendtag, comm_ptr, coll_group, attr, + &send_req_ptr); MPIR_ERR_CHECK(mpi_errno); } @@ -297,7 +313,8 @@ int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int sendtag, int source, int recvtag, - MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; MPI_Status mystatus; @@ -336,7 +353,7 @@ int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, 
MPIR_ERR_CHKANDSTMT(rreq == NULL, mpi_errno, MPIX_ERR_NOREQ, goto fn_fail, "**nomemreq"); MPIR_Status_set_procnull(&rreq->status); } else { - DO_MPID_IRECV(buf, count, datatype, source, recvtag, comm_ptr, attr, &rreq); + DO_MPID_IRECV(buf, count, datatype, source, recvtag, comm_ptr, coll_group, attr, &rreq); MPIR_ERR_CHECK(mpi_errno); } @@ -346,7 +363,8 @@ int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, MPIR_ERR_CHKANDSTMT(sreq == NULL, mpi_errno, MPIX_ERR_NOREQ, goto fn_fail, "**nomemreq"); } else { MPIR_PT2PT_ATTR_SET_ERRFLAG(attr, errflag); - DO_MPID_ISEND(tmpbuf, actual_pack_bytes, MPI_PACKED, dest, sendtag, comm_ptr, attr, &sreq); + DO_MPID_ISEND(tmpbuf, actual_pack_bytes, MPI_PACKED, dest, sendtag, comm_ptr, coll_group, + attr, &sreq); MPIR_ERR_CHECK(mpi_errno); if (mpi_errno != MPI_SUCCESS) { /* --BEGIN ERROR HANDLING-- */ @@ -391,7 +409,8 @@ int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, } int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_Request ** request_ptr, MPIR_Errflag_t errflag) + MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** request_ptr, + MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; int attr = 0; @@ -412,7 +431,7 @@ int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, MPIR_PT2PT_ATTR_SET_CONTEXT_OFFSET(attr, MPIR_CONTEXT_COLL_OFFSET); MPIR_PT2PT_ATTR_SET_ERRFLAG(attr, errflag); - DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, attr, request_ptr); + DO_MPID_ISEND(buf, count, datatype, dest, tag, comm_ptr, coll_group, attr, request_ptr); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -425,7 +444,7 @@ int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, } int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, - int tag, MPIR_Comm * comm_ptr, MPIR_Request ** request_ptr) + int tag, MPIR_Comm * comm_ptr, int coll_group, MPIR_Request 
** request_ptr) { int mpi_errno = MPI_SUCCESS; int attr = 0; @@ -446,7 +465,7 @@ int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, MPIR_PT2PT_ATTR_SET_CONTEXT_OFFSET(attr, MPIR_CONTEXT_COLL_OFFSET); - DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, attr, request_ptr); + DO_MPID_IRECV(buf, count, datatype, source, tag, comm_ptr, coll_group, attr, request_ptr); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpi/coll/transports/gentran/gentran_types.h b/src/mpi/coll/transports/gentran/gentran_types.h index 52063db1b5f..24a8891b7c3 100644 --- a/src/mpi/coll/transports/gentran/gentran_types.h +++ b/src/mpi/coll/transports/gentran/gentran_types.h @@ -49,6 +49,7 @@ typedef struct MPII_Genutil_vtx_t { int dest; int tag; MPIR_Comm *comm; + int coll_group; MPIR_Request *req; } isend; struct { @@ -58,6 +59,7 @@ typedef struct MPII_Genutil_vtx_t { int src; int tag; MPIR_Comm *comm; + int coll_group; MPIR_Request *req; } irecv; struct { @@ -67,6 +69,7 @@ typedef struct MPII_Genutil_vtx_t { int src; int tag; MPIR_Comm *comm; + int coll_group; MPIR_Request *req; MPI_Status *status; } irecv_status; @@ -78,6 +81,7 @@ typedef struct MPII_Genutil_vtx_t { int num_dests; int tag; MPIR_Comm *comm; + int coll_group; MPIR_Request **req; int last_complete; } imcast; diff --git a/src/mpi/coll/transports/gentran/gentran_utils.c b/src/mpi/coll/transports/gentran/gentran_utils.c index 94c75c56ecf..18f27da9b56 100644 --- a/src/mpi/coll/transports/gentran/gentran_utils.c +++ b/src/mpi/coll/transports/gentran/gentran_utils.c @@ -43,8 +43,8 @@ static int vtx_issue(int vtxid, MPII_Genutil_vtx_t * vtxp, MPII_Genutil_sched_t vtxp->u.isend.count, vtxp->u.isend.dt, vtxp->u.isend.dest, - vtxp->u.isend.tag, vtxp->u.isend.comm, &vtxp->u.isend.req, - r->u.nbc.errflag); + vtxp->u.isend.tag, vtxp->u.isend.comm, vtxp->u.isend.coll_group, + &vtxp->u.isend.req, r->u.nbc.errflag); if (MPIR_Request_is_complete(vtxp->u.isend.req)) { MPIR_Request_free(vtxp->u.isend.req); 
@@ -75,7 +75,7 @@ static int vtx_issue(int vtxid, MPII_Genutil_vtx_t * vtxp, MPII_Genutil_sched_t vtxp->u.irecv.count, vtxp->u.irecv.dt, vtxp->u.irecv.src, vtxp->u.irecv.tag, vtxp->u.irecv.comm, - &vtxp->u.irecv.req); + vtxp->u.irecv.coll_group, &vtxp->u.irecv.req); if (MPIR_Request_is_complete(vtxp->u.irecv.req)) { MPIR_Request_free(vtxp->u.irecv.req); @@ -104,7 +104,8 @@ static int vtx_issue(int vtxid, MPII_Genutil_vtx_t * vtxp, MPII_Genutil_sched_t vtxp->u.irecv_status.count, vtxp->u.irecv_status.dt, vtxp->u.irecv_status.src, vtxp->u.irecv_status.tag, - vtxp->u.irecv_status.comm, &vtxp->u.irecv_status.req); + vtxp->u.irecv_status.comm, vtxp->u.irecv_status.coll_group, + &vtxp->u.irecv_status.req); if (MPIR_Request_is_complete(vtxp->u.irecv_status.req)) { if (vtxp->u.irecv_status.status != MPI_STATUS_IGNORE) { @@ -143,7 +144,8 @@ static int vtx_issue(int vtxid, MPII_Genutil_vtx_t * vtxp, MPII_Genutil_sched_t vtxp->u.imcast.count, vtxp->u.imcast.dt, dests[i], - vtxp->u.imcast.tag, vtxp->u.imcast.comm, &vtxp->u.imcast.req[i], + vtxp->u.imcast.tag, vtxp->u.imcast.comm, + vtxp->u.imcast.coll_group, &vtxp->u.imcast.req[i], r->u.nbc.errflag); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, diff --git a/src/mpi/coll/transports/gentran/tsp_gentran.c b/src/mpi/coll/transports/gentran/tsp_gentran.c index 76c20855847..6aa253d5239 100644 --- a/src/mpi/coll/transports/gentran/tsp_gentran.c +++ b/src/mpi/coll/transports/gentran/tsp_gentran.c @@ -187,8 +187,8 @@ int MPIR_TSP_sched_isend(const void *buf, MPI_Datatype dt, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t s, int n_in_vtcs, int *in_vtcs, - int *vtx_id) + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t s, int n_in_vtcs, + int *in_vtcs, int *vtx_id) { MPII_Genutil_sched_t *sched = s; vtx_t *vtxp; @@ -205,6 +205,7 @@ int MPIR_TSP_sched_isend(const void *buf, vtxp->u.isend.dest = dest; vtxp->u.isend.tag = tag; vtxp->u.isend.comm = comm_ptr; + vtxp->u.isend.coll_group = coll_group; /* the user may free the comm 
& type after initiating but before the * underlying send is actually posted, so we must add a reference here and @@ -224,8 +225,8 @@ int MPIR_TSP_sched_irecv(void *buf, MPI_Datatype dt, int source, int tag, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t s, int n_in_vtcs, int *in_vtcs, - int *vtx_id) + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t s, int n_in_vtcs, + int *in_vtcs, int *vtx_id) { MPII_Genutil_sched_t *sched = s; vtx_t *vtxp; @@ -243,6 +244,7 @@ int MPIR_TSP_sched_irecv(void *buf, vtxp->u.irecv.src = source; vtxp->u.irecv.tag = tag; vtxp->u.irecv.comm = comm_ptr; + vtxp->u.irecv.coll_group = coll_group; MPIR_Comm_add_ref(comm_ptr); MPIR_Datatype_add_ref_if_not_builtin(dt); @@ -258,7 +260,7 @@ int MPIR_TSP_sched_irecv_status(void *buf, MPI_Datatype dt, int source, int tag, - MPIR_Comm * comm_ptr, MPI_Status * status, + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status, MPIR_TSP_sched_t sched, int n_in_vtcs, int *in_vtcs, int *vtx_id) { vtx_t *vtxp; @@ -277,6 +279,7 @@ int MPIR_TSP_sched_irecv_status(void *buf, vtxp->u.irecv_status.src = source; vtxp->u.irecv_status.tag = tag; vtxp->u.irecv_status.comm = comm_ptr; + vtxp->u.irecv_status.coll_group = coll_group; vtxp->u.irecv_status.status = status; MPIR_Comm_add_ref(comm_ptr); @@ -296,8 +299,8 @@ int MPIR_TSP_sched_imcast(const void *buf, int *dests, int num_dests, int tag, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t s, int n_in_vtcs, int *in_vtcs, - int *vtx_id) + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t s, int n_in_vtcs, + int *in_vtcs, int *vtx_id) { MPII_Genutil_sched_t *sched = s; vtx_t *vtxp; @@ -317,6 +320,7 @@ int MPIR_TSP_sched_imcast(const void *buf, memcpy(ut_int_array(&vtxp->u.imcast.dests), dests, num_dests * sizeof(int)); vtxp->u.imcast.tag = tag; vtxp->u.imcast.comm = comm_ptr; + vtxp->u.imcast.coll_group = coll_group; vtxp->u.imcast.req = (struct MPIR_Request **) MPL_malloc(sizeof(struct MPIR_Request *) * num_dests, MPL_MEM_COLL); diff --git 
a/src/mpi/coll/transports/tsp_impl.h b/src/mpi/coll/transports/tsp_impl.h index fb83f6d6a74..d18828dfc6c 100644 --- a/src/mpi/coll/transports/tsp_impl.h +++ b/src/mpi/coll/transports/tsp_impl.h @@ -49,8 +49,8 @@ int MPIR_TSP_sched_isend(const void *buf, MPI_Datatype dt, int dest, int tag, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched, int n_in_vtcs, int *in_vtcs, - int *vtx_id); + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t sched, + int n_in_vtcs, int *in_vtcs, int *vtx_id); /* Transport function to schedule an irecv vertex */ int MPIR_TSP_sched_irecv(void *buf, @@ -58,8 +58,8 @@ int MPIR_TSP_sched_irecv(void *buf, MPI_Datatype dt, int source, int tag, - MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched, int n_in_vtcs, int *in_vtcs, - int *vtx_id); + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t sched, + int n_in_vtcs, int *in_vtcs, int *vtx_id); /* Transport function to schedule a irecv with status vertex */ int MPIR_TSP_sched_irecv_status(void *buf, @@ -67,7 +67,7 @@ int MPIR_TSP_sched_irecv_status(void *buf, MPI_Datatype dt, int source, int tag, - MPIR_Comm * comm_ptr, MPI_Status * status, + MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status, MPIR_TSP_sched_t sched, int n_in_vtcs, int *in_vtcs, int *vtx_id); /* Transport function to schedule an imcast vertex */ @@ -77,7 +77,7 @@ int MPIR_TSP_sched_imcast(const void *buf, int *dests, int num_dests, int tag, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPIR_TSP_sched_t sched, int n_in_vtcs, int *in_vtcs, int *vtx_id); diff --git a/src/mpid/common/sched/mpidu_sched.c b/src/mpid/common/sched/mpidu_sched.c index 891ced6b46c..e2fa385330f 100644 --- a/src/mpid/common/sched/mpidu_sched.c +++ b/src/mpid/common/sched/mpidu_sched.c @@ -226,12 +226,12 @@ static int MPIDU_Sched_start_entry(struct MPIDU_Sched *s, size_t idx, struct MPI * &send.count, but this requires patching up the pointers * during realloc of entries, so this is easier */ ret_errno = MPIC_Isend(e->u.send.buf, 
*e->u.send.count_p, e->u.send.datatype, - e->u.send.dest, s->tag, comm, &e->u.send.sreq, - r->u.nbc.errflag); + e->u.send.dest, s->tag, comm, e->u.send.coll_group, + &e->u.send.sreq, r->u.nbc.errflag); } else { ret_errno = MPIC_Isend(e->u.send.buf, e->u.send.count, e->u.send.datatype, - e->u.send.dest, s->tag, comm, &e->u.send.sreq, - r->u.nbc.errflag); + e->u.send.dest, s->tag, comm, e->u.send.coll_group, + &e->u.send.sreq, r->u.nbc.errflag); } /* Check if the error is actually fatal to the NBC or we can continue. */ if (unlikely(ret_errno)) { @@ -256,7 +256,8 @@ static int MPIDU_Sched_start_entry(struct MPIDU_Sched *s, size_t idx, struct MPI MPL_DBG_MSG_D(MPIR_DBG_COMM, VERBOSE, "starting RECV entry %d\n", (int) idx); comm = e->u.recv.comm; ret_errno = MPIC_Irecv(e->u.recv.buf, e->u.recv.count, e->u.recv.datatype, - e->u.recv.src, s->tag, comm, &e->u.recv.rreq); + e->u.recv.src, s->tag, comm, e->u.recv.coll_group, + &e->u.recv.rreq); /* Check if the error is actually fatal to the NBC or we can continue. */ if (unlikely(ret_errno)) { if (MPIR_ERR_NONE == r->u.nbc.errflag) { @@ -655,7 +656,7 @@ static int MPIDU_Sched_add_entry(struct MPIDU_Sched *s, int *idx, struct MPIDU_S /* do these ops need an entry handle returned? 
*/ int MPIDU_Sched_send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, - MPIR_Comm * comm, MPIR_Sched_t s) + MPIR_Comm * comm, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; struct MPIDU_Sched_entry *e = NULL; @@ -674,6 +675,7 @@ int MPIDU_Sched_send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int e->u.send.dest = dest; e->u.send.sreq = NULL; /* will be populated by _start_entry */ e->u.send.comm = comm; + e->u.send.coll_group = coll_group; /* the user may free the comm & type after initiating but before the * underlying send is actually posted, so we must add a reference here and @@ -711,6 +713,7 @@ int MPIDU_Sched_pt2pt_send(const void *buf, MPI_Aint count, MPI_Datatype datatyp e->u.send.dest = dest; e->u.send.sreq = NULL; /* will be populated by _start_entry */ e->u.send.comm = comm; + e->u.send.coll_group = MPIR_SUBGROUP_NONE; e->u.send.tag = tag; /* the user may free the comm & type after initiating but before the @@ -730,7 +733,7 @@ int MPIDU_Sched_pt2pt_send(const void *buf, MPI_Aint count, MPI_Datatype datatyp } int MPIDU_Sched_send_defer(const void *buf, const MPI_Aint * count, MPI_Datatype datatype, int dest, - MPIR_Comm * comm, MPIR_Sched_t s) + MPIR_Comm * comm, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; struct MPIDU_Sched_entry *e = NULL; @@ -749,6 +752,7 @@ int MPIDU_Sched_send_defer(const void *buf, const MPI_Aint * count, MPI_Datatype e->u.send.dest = dest; e->u.send.sreq = NULL; /* will be populated by _start_entry */ e->u.send.comm = comm; + e->u.send.coll_group = coll_group; /* the user may free the comm & type after initiating but before the * underlying send is actually posted, so we must add a reference here and @@ -767,7 +771,7 @@ int MPIDU_Sched_send_defer(const void *buf, const MPI_Aint * count, MPI_Datatype } int MPIDU_Sched_recv_status(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, - MPIR_Comm * comm, MPI_Status * status, MPIR_Sched_t s) + MPIR_Comm * comm, int 
coll_group, MPI_Status * status, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; struct MPIDU_Sched_entry *e = NULL; @@ -785,6 +789,7 @@ int MPIDU_Sched_recv_status(void *buf, MPI_Aint count, MPI_Datatype datatype, in e->u.recv.src = src; e->u.recv.rreq = NULL; /* will be populated by _start_entry */ e->u.recv.comm = comm; + e->u.recv.coll_group = coll_group; e->u.recv.status = status; status->MPI_ERROR = MPI_SUCCESS; MPIR_Comm_add_ref(comm); @@ -801,7 +806,7 @@ int MPIDU_Sched_recv_status(void *buf, MPI_Aint count, MPI_Datatype datatype, in } int MPIDU_Sched_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, MPIR_Comm * comm, - MPIR_Sched_t s) + int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; struct MPIDU_Sched_entry *e = NULL; @@ -819,6 +824,7 @@ int MPIDU_Sched_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, e->u.recv.src = src; e->u.recv.rreq = NULL; /* will be populated by _start_entry */ e->u.recv.comm = comm; + e->u.recv.coll_group = coll_group; e->u.recv.status = MPI_STATUS_IGNORE; MPIR_Comm_add_ref(comm); @@ -853,6 +859,7 @@ int MPIDU_Sched_pt2pt_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, e->u.recv.src = src; e->u.recv.rreq = NULL; /* will be populated by _start_entry */ e->u.recv.comm = comm; + e->u.recv.coll_group = MPIR_SUBGROUP_NONE; e->u.recv.status = MPI_STATUS_IGNORE; e->u.recv.tag = tag; diff --git a/src/mpid/common/sched/mpidu_sched.h b/src/mpid/common/sched/mpidu_sched.h index 2454ad60c84..a61e64c6b69 100644 --- a/src/mpid/common/sched/mpidu_sched.h +++ b/src/mpid/common/sched/mpidu_sched.h @@ -43,6 +43,7 @@ struct MPIDU_Sched_send { int tag; /* only used for _PT2PT_SEND */ int dest; struct MPIR_Comm *comm; + int coll_group; struct MPIR_Request *sreq; }; @@ -53,6 +54,7 @@ struct MPIDU_Sched_recv { int tag; /* only used for _PT2PT_RECV */ int src; struct MPIR_Comm *comm; + int coll_group; struct MPIR_Request *rreq; MPI_Status *status; }; @@ -141,9 +143,9 @@ int MPIDU_Sched_reset(MPIR_Sched_t 
s); void *MPIDU_Sched_alloc_state(MPIR_Sched_t s, MPI_Aint size); int MPIDU_Sched_start(MPIR_Sched_t sp, struct MPIR_Comm *comm, struct MPIR_Request **req); int MPIDU_Sched_send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, - struct MPIR_Comm *comm, MPIR_Sched_t s); + struct MPIR_Comm *comm, int coll_group, MPIR_Sched_t s); int MPIDU_Sched_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int src, - struct MPIR_Comm *comm, MPIR_Sched_t s); + struct MPIR_Comm *comm, int coll_group, MPIR_Sched_t s); int MPIDU_Sched_pt2pt_send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int tag, int dest, struct MPIR_Comm *comm, MPIR_Sched_t s); int MPIDU_Sched_pt2pt_recv(void *buf, MPI_Aint count, MPI_Datatype datatype, From b4104351547d457c92c522badd3b1db0df327f04 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 16 Aug 2024 17:47:10 -0500 Subject: [PATCH 10/27] continue: add coll_group in MPIC/sched/TSP routines --- .../coll/allgather/allgather_intra_brucks.c | 6 +++-- .../coll/allgather/allgather_intra_k_brucks.c | 4 +-- .../coll/allgather/allgather_intra_recexch.c | 20 +++++++------- .../allgather_intra_recursive_doubling.c | 8 +++--- src/mpi/coll/allgather/allgather_intra_ring.c | 3 ++- .../coll/allgatherv/allgatherv_intra_brucks.c | 5 ++-- .../allgatherv_intra_recursive_doubling.c | 7 ++--- .../coll/allgatherv/allgatherv_intra_ring.c | 10 +++---- .../allreduce_inter_reduce_exchange_bcast.c | 2 +- ...lreduce_intra_k_reduce_scatter_allgather.c | 27 ++++++++++--------- .../coll/allreduce/allreduce_intra_recexch.c | 20 +++++++------- .../allreduce_intra_recursive_doubling.c | 13 +++++---- ...allreduce_intra_reduce_scatter_allgather.c | 16 ++++++----- src/mpi/coll/allreduce/allreduce_intra_ring.c | 5 ++-- src/mpi/coll/allreduce/allreduce_intra_tree.c | 9 ++++--- .../coll/allreduce_group/allreduce_group.c | 23 +++++++++++----- .../alltoall_inter_pairwise_exchange.c | 2 +- src/mpi/coll/alltoall/alltoall_intra_brucks.c | 3 ++- 
.../coll/alltoall/alltoall_intra_k_brucks.c | 4 +-- .../coll/alltoall/alltoall_intra_pairwise.c | 2 +- ...alltoall_intra_pairwise_sendrecv_replace.c | 6 +++-- .../coll/alltoall/alltoall_intra_scattered.c | 5 ++-- .../alltoallv_inter_pairwise_exchange.c | 3 ++- ...lltoallv_intra_pairwise_sendrecv_replace.c | 4 +-- .../alltoallv/alltoallv_intra_scattered.c | 5 ++-- .../alltoallw_inter_pairwise_exchange.c | 2 +- ...lltoallw_intra_pairwise_sendrecv_replace.c | 4 +-- .../alltoallw/alltoallw_intra_scattered.c | 4 +-- .../barrier/barrier_intra_k_dissemination.c | 10 +++---- src/mpi/coll/bcast/bcast.h | 4 +-- .../bcast_inter_remote_send_local_bcast.c | 6 +++-- src/mpi/coll/bcast/bcast_intra_binomial.c | 8 +++--- .../coll/bcast/bcast_intra_pipelined_tree.c | 14 +++++----- ...tra_scatter_recursive_doubling_allgather.c | 15 +++++------ .../bcast_intra_scatter_ring_allgather.c | 5 ++-- src/mpi/coll/bcast/bcast_intra_smp.c | 4 +-- src/mpi/coll/bcast/bcast_intra_tree.c | 15 +++++++---- src/mpi/coll/bcast/bcast_utils.c | 8 +++--- .../exscan/exscan_intra_recursive_doubling.c | 2 +- src/mpi/coll/gather/gather_inter_linear.c | 9 +++---- .../gather_inter_local_gather_remote_send.c | 4 +-- src/mpi/coll/gather/gather_intra_binomial.c | 16 ++++++----- src/mpi/coll/gatherv/gatherv_allcomm_linear.c | 5 ++-- .../iallgather_intra_sched_brucks.c | 12 ++++++--- ...allgather_intra_sched_recursive_doubling.c | 10 ++++--- .../iallgather/iallgather_intra_sched_ring.c | 4 +-- .../coll/iallgather/iallgather_tsp_brucks.c | 10 ++++--- .../coll/iallgather/iallgather_tsp_recexch.c | 20 +++++++------- src/mpi/coll/iallgather/iallgather_tsp_ring.c | 9 ++++--- .../iallgatherv_intra_sched_brucks.c | 12 ++++++--- ...llgatherv_intra_sched_recursive_doubling.c | 10 ++++--- .../iallgatherv_intra_sched_ring.c | 4 +-- .../coll/iallgatherv/iallgatherv_tsp_brucks.c | 6 ++--- .../iallgatherv/iallgatherv_tsp_recexch.c | 19 ++++++------- .../coll/iallgatherv/iallgatherv_tsp_ring.c | 12 ++++----- 
...allreduce_intra_sched_recursive_doubling.c | 16 ++++++----- ...uce_intra_sched_reduce_scatter_allgather.c | 20 ++++++++------ .../coll/iallreduce/iallreduce_tsp_recexch.c | 20 +++++++------- ...ecexch_reduce_scatter_recexch_allgatherv.c | 8 +++--- ...iallreduce_tsp_recursive_exchange_common.c | 8 +++--- src/mpi/coll/iallreduce/iallreduce_tsp_ring.c | 7 ++--- src/mpi/coll/iallreduce/iallreduce_tsp_tree.c | 12 +++++---- .../ialltoall_inter_sched_pairwise_exchange.c | 4 +-- .../ialltoall/ialltoall_intra_sched_brucks.c | 4 +-- .../ialltoall/ialltoall_intra_sched_inplace.c | 5 ++-- .../ialltoall_intra_sched_pairwise.c | 4 +-- .../ialltoall_intra_sched_permuted_sendrecv.c | 4 +-- src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c | 4 +-- src/mpi/coll/ialltoall/ialltoall_tsp_ring.c | 8 +++--- .../coll/ialltoall/ialltoall_tsp_scattered.c | 14 +++++----- ...ialltoallv_inter_sched_pairwise_exchange.c | 4 +-- .../ialltoallv_intra_sched_blocked.c | 6 +++-- .../ialltoallv_intra_sched_inplace.c | 5 ++-- .../coll/ialltoallv/ialltoallv_tsp_blocked.c | 8 +++--- .../coll/ialltoallv/ialltoallv_tsp_inplace.c | 4 +-- .../ialltoallv/ialltoallv_tsp_scattered.c | 16 +++++------ ...ialltoallw_inter_sched_pairwise_exchange.c | 4 +-- .../ialltoallw_intra_sched_blocked.c | 6 +++-- .../ialltoallw_intra_sched_inplace.c | 5 ++-- .../coll/ialltoallw/ialltoallw_tsp_blocked.c | 4 +-- .../coll/ialltoallw/ialltoallw_tsp_inplace.c | 6 ++--- .../ibarrier_intra_sched_recursive_doubling.c | 4 +-- .../coll/ibarrier/ibarrier_intra_tsp_dissem.c | 6 ++--- src/mpi/coll/ibcast/ibcast.h | 4 +-- src/mpi/coll/ibcast/ibcast_inter_sched_flat.c | 4 +-- .../coll/ibcast/ibcast_intra_sched_binomial.c | 9 ++++--- ...hed_scatter_recursive_doubling_allgather.c | 12 +++++---- ...bcast_intra_sched_scatter_ring_allgather.c | 6 ++--- src/mpi/coll/ibcast/ibcast_intra_sched_smp.c | 5 ++-- .../ibcast/ibcast_tsp_scatterv_allgatherv.c | 10 +++---- src/mpi/coll/ibcast/ibcast_tsp_tree.c | 9 ++++--- src/mpi/coll/ibcast/ibcast_utils.c | 
8 +++--- .../iexscan_intra_sched_recursive_doubling.c | 5 ++-- .../coll/igather/igather_inter_sched_long.c | 4 +-- .../coll/igather/igather_inter_sched_short.c | 5 ++-- .../igather/igather_intra_sched_binomial.c | 18 ++++++++----- src/mpi/coll/igather/igather_tsp_tree.c | 9 ++++--- .../igatherv/igatherv_allcomm_sched_linear.c | 6 +++-- src/mpi/coll/igatherv/igatherv_tsp_linear.c | 4 +-- ...ineighbor_allgather_allcomm_sched_linear.c | 6 +++-- .../ineighbor_allgather_tsp_linear.c | 8 +++--- ...neighbor_allgatherv_allcomm_sched_linear.c | 6 +++-- .../ineighbor_allgatherv_tsp_linear.c | 8 +++--- .../ineighbor_alltoall_allcomm_sched_linear.c | 6 +++-- .../ineighbor_alltoall_tsp_linear.c | 8 +++--- ...ineighbor_alltoallv_allcomm_sched_linear.c | 6 +++-- .../ineighbor_alltoallv_tsp_linear.c | 8 +++--- ...ineighbor_alltoallw_allcomm_sched_linear.c | 8 ++++-- .../ineighbor_alltoallw_tsp_linear.c | 8 +++--- ...uce_inter_sched_local_reduce_remote_send.c | 4 +-- .../ireduce/ireduce_intra_sched_binomial.c | 9 ++++--- ...reduce_intra_sched_reduce_scatter_gather.c | 19 +++++++------ src/mpi/coll/ireduce/ireduce_tsp_tree.c | 13 ++++----- ...educe_scatter_intra_sched_noncommutative.c | 4 +-- .../ireduce_scatter_intra_sched_pairwise.c | 7 ++--- ...e_scatter_intra_sched_recursive_doubling.c | 10 ++++--- ...ce_scatter_intra_sched_recursive_halving.c | 18 ++++++++----- .../ireduce_scatter_tsp_recexch.c | 18 +++++++------ ...scatter_block_intra_sched_noncommutative.c | 4 +-- ...educe_scatter_block_intra_sched_pairwise.c | 6 ++--- ...ter_block_intra_sched_recursive_doubling.c | 10 ++++--- ...tter_block_intra_sched_recursive_halving.c | 17 +++++++----- .../ireduce_scatter_block_tsp_recexch.c | 22 ++++++++------- .../iscan_intra_sched_recursive_doubling.c | 5 ++-- src/mpi/coll/iscan/iscan_intra_sched_smp.c | 11 +++++--- .../coll/iscan/iscan_tsp_recursive_doubling.c | 8 +++--- .../iscatter/iscatter_inter_sched_linear.c | 4 +-- ...er_inter_sched_remote_send_local_scatter.c | 5 ++-- 
.../iscatter/iscatter_intra_sched_binomial.c | 10 ++++--- src/mpi/coll/iscatter/iscatter_tsp_tree.c | 4 +-- .../iscatterv_allcomm_sched_linear.c | 6 +++-- src/mpi/coll/iscatterv/iscatterv_tsp_linear.c | 6 ++--- .../reduce_inter_local_reduce_remote_send.c | 5 ++-- src/mpi/coll/reduce/reduce_intra_binomial.c | 10 ++++--- .../reduce_intra_reduce_scatter_gather.c | 18 ++++++++----- .../reduce_scatter_intra_noncommutative.c | 2 +- .../reduce_scatter_intra_pairwise.c | 4 +-- .../reduce_scatter_intra_recursive_doubling.c | 8 +++--- .../reduce_scatter_intra_recursive_halving.c | 17 +++++++----- ...educe_scatter_block_intra_noncommutative.c | 2 +- .../reduce_scatter_block_intra_pairwise.c | 4 +-- ...e_scatter_block_intra_recursive_doubling.c | 7 ++--- ...ce_scatter_block_intra_recursive_halving.c | 18 ++++++++----- .../coll/scan/scan_intra_recursive_doubling.c | 2 +- src/mpi/coll/scan/scan_intra_smp.c | 8 +++--- src/mpi/coll/scatter/scatter_inter_linear.c | 6 ++--- .../scatter_inter_remote_send_local_scatter.c | 4 +-- src/mpi/coll/scatter/scatter_intra_binomial.c | 10 ++++--- .../coll/scatterv/scatterv_allcomm_linear.c | 5 ++-- src/mpi/comm/comm_impl.c | 10 +++---- src/mpi/comm/comm_split.c | 3 ++- src/mpi/comm/contextid.c | 6 ++--- src/mpi/topo/dist_graph_create.c | 8 +++--- src/mpid/ch3/src/ch3u_port.c | 20 +++++++------- .../ch3/src/mpid_comm_get_all_failed_procs.c | 8 +++--- src/mpid/ch3/src/mpid_vc.c | 4 +-- .../release_gather/nb_bcast_release_gather.h | 6 +++-- .../release_gather/nb_reduce_release_gather.h | 5 ++-- .../shm/posix/release_gather/release_gather.h | 14 +++++----- src/mpid/ch4/src/ch4_coll_impl.h | 25 ++++++++--------- src/mpid/ch4/src/ch4_comm.c | 13 +++++---- 161 files changed, 764 insertions(+), 589 deletions(-) diff --git a/src/mpi/coll/allgather/allgather_intra_brucks.c b/src/mpi/coll/allgather/allgather_intra_brucks.c index c3bd63bd0aa..ac2b8a01dce 100644 --- a/src/mpi/coll/allgather/allgather_intra_brucks.c +++ 
b/src/mpi/coll/allgather/allgather_intra_brucks.c @@ -67,7 +67,8 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf, MPIR_ALLGATHER_TAG, ((char *) tmp_buf + curr_cnt * recvtype_sz), curr_cnt * recvtype_sz, MPI_BYTE, - src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + src, MPIR_ALLGATHER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); curr_cnt *= 2; pof2 *= 2; @@ -84,7 +85,8 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf, dst, MPIR_ALLGATHER_TAG, ((char *) tmp_buf + curr_cnt * recvtype_sz), rem * recvcount * recvtype_sz, MPI_BYTE, - src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + src, MPIR_ALLGATHER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/allgather/allgather_intra_k_brucks.c b/src/mpi/coll/allgather/allgather_intra_k_brucks.c index 6e86b7a23cd..4fd62333c3e 100644 --- a/src/mpi/coll/allgather/allgather_intra_k_brucks.c +++ b/src/mpi/coll/allgather/allgather_intra_k_brucks.c @@ -140,7 +140,7 @@ MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount, /* Receive at the exact location. */ mpi_errno = MPIC_Irecv((char *) tmp_recvbuf + j * recvcount * delta * recvtype_extent, - count, recvtype, src, MPIR_ALLGATHER_TAG, comm, + count, recvtype, src, MPIR_ALLGATHER_TAG, comm, coll_group, &reqs[num_reqs++]); MPIR_ERR_CHECK(mpi_errno); @@ -152,7 +152,7 @@ MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount, /* Send from the start of recv till `count` amount of data. 
*/ mpi_errno = - MPIC_Isend(tmp_recvbuf, count, recvtype, dst, MPIR_ALLGATHER_TAG, comm, + MPIC_Isend(tmp_recvbuf, count, recvtype, dst, MPIR_ALLGATHER_TAG, comm, coll_group, &reqs[num_reqs++], errflag); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/allgather/allgather_intra_recexch.c b/src/mpi/coll/allgather/allgather_intra_recexch.c index 2e4cba4ef66..bd83fed5cac 100644 --- a/src/mpi/coll/allgather/allgather_intra_recexch.c +++ b/src/mpi/coll/allgather/allgather_intra_recexch.c @@ -117,7 +117,7 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, buf_to_send = (void *) sendbuf; mpi_errno = MPIC_Send(buf_to_send, recvcount, recvtype, step1_sendto, MPIR_ALLGATHER_TAG, comm, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { if (step1_nrecvs) { @@ -125,7 +125,7 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets the data from non-participating rank */ recv_offset = step1_recvfrom[i] * recv_extent * recvcount; mpi_errno = MPIC_Irecv(((char *) recvbuf + recv_offset), recvcount, recvtype, - step1_recvfrom[i], MPIR_ALLGATHER_TAG, comm, + step1_recvfrom[i], MPIR_ALLGATHER_TAG, comm, coll_group, &recv_reqs[num_rreq++]); MPIR_ERR_CHECK(mpi_errno); } @@ -159,8 +159,8 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPIC_Sendrecv(((char *) recvbuf + send_offset), send_count * recvcount, recvtype, partner, MPIR_ALLGATHER_TAG, ((char *) recvbuf + recv_offset), recv_count * recvcount, - recvtype, partner, MPIR_ALLGATHER_TAG, comm, MPI_STATUS_IGNORE, - errflag); + recvtype, partner, MPIR_ALLGATHER_TAG, comm, coll_group, + MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); } } @@ -191,7 +191,7 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, recv_offset = offset * recv_extent * recvcount; mpi_errno = MPIC_Irecv(((char *) recvbuf + recv_offset), count * recvcount, recvtype, nbr, - 
MPIR_ALLGATHER_TAG, comm, &recv_reqs[num_rreq++]); + MPIR_ALLGATHER_TAG, comm, coll_group, &recv_reqs[num_rreq++]); MPIR_ERR_CHECK(mpi_errno); } if (recexch_type == MPIR_ALLGATHER_RECEXCH_TYPE_DISTANCE_HALVING) @@ -210,7 +210,8 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPII_Recexchalgo_get_count_and_offset(rank_for_offset, j, k, nranks, &count, &offset); send_offset = offset * recv_extent * recvcount; mpi_errno = MPIC_Isend(((char *) recvbuf + send_offset), count * recvcount, recvtype, - nbr, MPIR_ALLGATHER_TAG, comm, &send_reqs[num_sreq++], errflag); + nbr, MPIR_ALLGATHER_TAG, comm, coll_group, + &send_reqs[num_sreq++], errflag); MPIR_ERR_CHECK(mpi_errno); } /* wait on prev recvs */ @@ -236,7 +237,7 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, send_offset = offset * recv_extent * recvcount; mpi_errno = MPIC_Isend(((char *) recvbuf + send_offset), count * recvcount, - recvtype, nbr, MPIR_ALLGATHER_TAG, comm, + recvtype, nbr, MPIR_ALLGATHER_TAG, comm, coll_group, &send_reqs[num_sreq++], errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -260,13 +261,14 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, if (step1_sendto != -1) { mpi_errno = MPIC_Recv(recvbuf, recvcount * nranks, recvtype, step1_sendto, MPIR_ALLGATHER_TAG, - comm, MPI_STATUS_IGNORE); + comm, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < step1_nrecvs; i++) { mpi_errno = MPIC_Isend(recvbuf, recvcount * nranks, recvtype, step1_recvfrom[i], - MPIR_ALLGATHER_TAG, comm, &recv_reqs[num_rreq++], errflag); + MPIR_ALLGATHER_TAG, comm, coll_group, &recv_reqs[num_rreq++], + errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c b/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c index 6a4316fbf30..107b2d2ab9a 100644 --- a/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c +++ 
b/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c @@ -82,7 +82,7 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf, ((char *) recvbuf + recv_offset), (comm_size - dst_tree_root) * recvcount, recvtype, dst, - MPIR_ALLGATHER_TAG, comm_ptr, &status, errflag); + MPIR_ALLGATHER_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_Get_count_impl(&status, recvtype, &last_recv_cnt); curr_cnt += last_recv_cnt; @@ -136,7 +136,8 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf, && (dst >= tree_root + nprocs_completed)) { mpi_errno = MPIC_Send(((char *) recvbuf + offset), last_recv_cnt, - recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr, errflag); + recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); } /* recv only if this proc. doesn't have data and sender @@ -146,7 +147,8 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf, (rank >= tree_root + nprocs_completed)) { mpi_errno = MPIC_Recv(((char *) recvbuf + offset), (comm_size - (my_tree_root + mask)) * recvcount, - recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr, &status); + recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr, coll_group, + &status); MPIR_ERR_CHECK(mpi_errno); /* nprocs_completed is also equal to the * no. 
of processes whose data we don't have */ diff --git a/src/mpi/coll/allgather/allgather_intra_ring.c b/src/mpi/coll/allgather/allgather_intra_ring.c index 88df120075d..a9715f502f6 100644 --- a/src/mpi/coll/allgather/allgather_intra_ring.c +++ b/src/mpi/coll/allgather/allgather_intra_ring.c @@ -64,7 +64,8 @@ int MPIR_Allgather_intra_ring(const void *sendbuf, ((char *) recvbuf + jnext * recvcount * recvtype_extent), recvcount, recvtype, left, - MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + MPIR_ALLGATHER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); j = jnext; jnext = (comm_size + jnext - 1) % comm_size; diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c index 1769221ac5d..780ae714f25 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c @@ -76,7 +76,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf, MPIR_ALLGATHERV_TAG, ((char *) tmp_buf + curr_cnt * recvtype_sz), (total_count - curr_cnt) * recvtype_sz, MPI_BYTE, - src, MPIR_ALLGATHERV_TAG, comm_ptr, &status, errflag); + src, MPIR_ALLGATHERV_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); if (mpi_errno) { recv_cnt = 0; @@ -103,7 +103,8 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf, dst, MPIR_ALLGATHERV_TAG, ((char *) tmp_buf + curr_cnt * recvtype_sz), (total_count - curr_cnt) * recvtype_sz, MPI_BYTE, - src, MPIR_ALLGATHERV_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + src, MPIR_ALLGATHERV_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c b/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c index 390b81679e1..21df0baf8ef 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c @@ -113,7 
+113,7 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf, MPIR_ALLGATHERV_TAG, ((char *) tmp_buf + recv_offset * recvtype_sz), (total_count - recv_offset) * recvtype_sz, MPI_BYTE, dst, - MPIR_ALLGATHERV_TAG, comm_ptr, &status, errflag); + MPIR_ALLGATHERV_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); if (mpi_errno) { last_recv_cnt = 0; @@ -176,7 +176,8 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Send(((char *) tmp_buf + offset * recvtype_sz), last_recv_cnt * recvtype_sz, - MPI_BYTE, dst, MPIR_ALLGATHERV_TAG, comm_ptr, errflag); + MPI_BYTE, dst, MPIR_ALLGATHERV_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* last_recv_cnt was set in the previous * receive. that's the amount of data to be @@ -194,7 +195,7 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Recv(((char *) tmp_buf + offset * recvtype_sz), (total_count - offset) * recvtype_sz, MPI_BYTE, - dst, MPIR_ALLGATHERV_TAG, comm_ptr, &status); + dst, MPIR_ALLGATHERV_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); if (mpi_errno) { last_recv_cnt = 0; diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_ring.c b/src/mpi/coll/allgatherv/allgatherv_intra_ring.c index 10fee8a9cb6..3daafdee90b 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_ring.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_ring.c @@ -109,19 +109,19 @@ int MPIR_Allgatherv_intra_ring(const void *sendbuf, /* Don't do anything. This case is possible if two * consecutive processes contribute 0 bytes each. 
*/ } else if (!sendnow) { /* If there's no data to send, just do a recv call */ - mpi_errno = - MPIC_Recv(rbuf, recvnow, recvtype, left, MPIR_ALLGATHERV_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(rbuf, recvnow, recvtype, left, MPIR_ALLGATHERV_TAG, + comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); torecv -= recvnow; } else if (!recvnow) { /* If there's no data to receive, just do a send call */ - mpi_errno = - MPIC_Send(sbuf, sendnow, recvtype, right, MPIR_ALLGATHERV_TAG, comm_ptr, errflag); + mpi_errno = MPIC_Send(sbuf, sendnow, recvtype, right, MPIR_ALLGATHERV_TAG, + comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); tosend -= sendnow; } else { /* There's data to be sent and received */ mpi_errno = MPIC_Sendrecv(sbuf, sendnow, recvtype, right, MPIR_ALLGATHERV_TAG, rbuf, recvnow, recvtype, left, MPIR_ALLGATHERV_TAG, - comm_ptr, &status, errflag); + comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); tosend -= sendnow; torecv -= recvnow; diff --git a/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c b/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c index ae4839a3470..cb594f415e1 100644 --- a/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c +++ b/src/mpi/coll/allreduce/allreduce_inter_reduce_exchange_bcast.c @@ -48,7 +48,7 @@ int MPIR_Allreduce_inter_reduce_exchange_bcast(const void *sendbuf, void *recvbu if (comm_ptr->rank == 0) { mpi_errno = MPIC_Sendrecv(tmp_buf, count, datatype, 0, MPIR_REDUCE_TAG, recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, - comm_ptr, MPI_STATUS_IGNORE, errflag); + comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c index df552c6d85e..81451aba812 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c +++ 
b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c @@ -107,13 +107,14 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, if (!in_step2) { /* even */ /* non-participating rank sends the data to a participating rank */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, errflag); + datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* odd */ for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets data from non-partcipating ranks */ mpi_errno = MPIC_Recv(tmp_recvbuf, count, datatype, step1_recvfrom[i], MPIR_ALLREDUCE_TAG, comm, - MPI_STATUS_IGNORE); + coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* Do reduction of reduced data */ mpi_errno = MPIR_Reduce_local(tmp_recvbuf, recvbuf, count, datatype, op); @@ -163,8 +164,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, send_cnt += cnts[offset + x]; mpi_errno = MPIC_Isend((char *) recvbuf + send_offset, send_cnt, - datatype, dst, MPIR_ALLREDUCE_TAG, comm, &recv_reqs[num_rreq++], - errflag); + datatype, dst, MPIR_ALLREDUCE_TAG, comm, coll_group, + &recv_reqs[num_rreq++], errflag); MPIR_ERR_CHECK(mpi_errno); rank_for_offset = MPII_Recexchalgo_reverse_digits_step2(rank, nranks, k); @@ -177,7 +178,7 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, recv_cnt += cnts[offset + x]; mpi_errno = MPIC_Irecv((char *) tmp_recvbuf + recv_offset, recv_cnt, datatype, - dst, MPIR_ALLREDUCE_TAG, comm, &recv_reqs[num_rreq++]); + dst, MPIR_ALLREDUCE_TAG, comm, coll_group, &recv_reqs[num_rreq++]); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Waitall(num_rreq, recv_reqs, MPI_STATUSES_IGNORE); @@ -210,7 +211,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, for (x = 0; x < current_cnt; x++) recv_count += cnts[offset + x]; mpi_errno = MPIC_Irecv(((char *) recvbuf + recv_offset), recv_count, datatype, - 
nbr, MPIR_ALLREDUCE_TAG, comm, &recv_reqs[num_rreq++]); + nbr, MPIR_ALLREDUCE_TAG, comm, coll_group, + &recv_reqs[num_rreq++]); MPIR_ERR_CHECK(mpi_errno); } recv_phase--; @@ -226,8 +228,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, for (x = 0; x < current_cnt; x++) send_count += cnts[offset + x]; mpi_errno = MPIC_Isend(((char *) recvbuf + send_offset), send_count, datatype, - nbr, MPIR_ALLREDUCE_TAG, comm, &send_reqs[num_sreq++], - errflag); + nbr, MPIR_ALLREDUCE_TAG, comm, coll_group, + &send_reqs[num_sreq++], errflag); MPIR_ERR_CHECK(mpi_errno); } /* wait on prev recvs */ @@ -250,7 +252,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, send_count += cnts[offset + x]; mpi_errno = MPIC_Isend(((char *) recvbuf + send_offset), send_count, datatype, nbr, - MPIR_ALLREDUCE_TAG, comm, &send_reqs[num_sreq++], errflag); + MPIR_ALLREDUCE_TAG, comm, coll_group, &send_reqs[num_sreq++], + errflag); MPIR_ERR_CHECK(mpi_errno); } /* wait on prev recvs */ @@ -268,15 +271,15 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, /* Step 3: This is reverse of Step 1. 
Rans that participated in Step 2 * send the data to non-partcipating rans */ if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ - mpi_errno = MPIC_Recv(recvbuf, count, datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, - MPI_STATUS_IGNORE); + mpi_errno = MPIC_Recv(recvbuf, count, datatype, step1_sendto, MPIR_ALLREDUCE_TAG, + comm, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } else { if (step1_nrecvs > 0) { for (i = 0; i < step1_nrecvs; i++) { mpi_errno = MPIC_Isend(recvbuf, count, datatype, step1_recvfrom[i], MPIR_ALLREDUCE_TAG, - comm, &send_reqs[i], errflag); + comm, coll_group, &send_reqs[i], errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/allreduce/allreduce_intra_recexch.c b/src/mpi/coll/allreduce/allreduce_intra_recexch.c index adf6988a40f..72d3320055b 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recexch.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recexch.c @@ -154,14 +154,16 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, if (!in_step2) { /* even */ /* non-participating rank sends the data to a participating rank */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, errflag); + datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* odd */ if (step1_nrecvs) { for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets data from non-partcipating ranks */ mpi_errno = MPIC_Irecv(nbr_buffer[i], count, datatype, step1_recvfrom[i], - MPIR_ALLREDUCE_TAG, comm, &recv_reqs[recv_nreq++]); + MPIR_ALLREDUCE_TAG, comm, coll_group, + &recv_reqs[recv_nreq++]); MPIR_ERR_CHECK(mpi_errno); } mpi_errno = MPIC_Waitall(recv_nreq, recv_reqs, MPI_STATUSES_IGNORE); @@ -187,7 +189,7 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, nbr = step2_nbrs[phase + j][i]; mpi_errno = MPIC_Irecv(nbr_buffer[buf++], count, datatype, nbr, MPIR_ALLREDUCE_TAG, - comm, &recv_reqs[recv_nreq++]); + comm, coll_group, 
&recv_reqs[recv_nreq++]); MPIR_ERR_CHECK(mpi_errno); } } @@ -196,8 +198,8 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, /* send data to all the neighbors */ for (i = 0; i < k - 1; i++) { nbr = step2_nbrs[phase][i]; - mpi_errno = MPIC_Isend(recvbuf, count, datatype, nbr, MPIR_ALLREDUCE_TAG, comm, - &send_reqs[send_nreq++], errflag); + mpi_errno = MPIC_Isend(recvbuf, count, datatype, nbr, MPIR_ALLREDUCE_TAG, + comm, coll_group, &send_reqs[send_nreq++], errflag); MPIR_ERR_CHECK(mpi_errno); if (rank > nbr) { } @@ -227,7 +229,7 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, mpi_errno = MPIC_Isend(recvbuf, count, datatype, nbr, MPIR_ALLREDUCE_TAG, comm, - &send_reqs[send_nreq++], errflag); + coll_group, &send_reqs[send_nreq++], errflag); MPIR_ERR_CHECK(mpi_errno); } @@ -251,14 +253,14 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, /* Step 3: This is reverse of Step 1. Rans that participated in Step 2 * send the data to non-partcipating rans */ if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ - mpi_errno = MPIC_Recv(recvbuf, count, datatype, step1_sendto, MPIR_ALLREDUCE_TAG, comm, - MPI_STATUS_IGNORE); + mpi_errno = MPIC_Recv(recvbuf, count, datatype, step1_sendto, MPIR_ALLREDUCE_TAG, + comm, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } else { for (i = 0; i < step1_nrecvs; i++) { mpi_errno = MPIC_Isend(recvbuf, count, datatype, step1_recvfrom[i], MPIR_ALLREDUCE_TAG, - comm, &send_reqs[i], errflag); + comm, coll_group, &send_reqs[i], errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c index a87d75cd553..18486831078 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c @@ -66,7 +66,8 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ 
mpi_errno = MPIC_Send(recvbuf, count, - datatype, rank + 1, MPIR_ALLREDUCE_TAG, comm_ptr, errflag); + datatype, rank + 1, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set the rank to -1 so that this @@ -76,7 +77,7 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, } else { /* odd */ mpi_errno = MPIC_Recv(tmp_buf, count, datatype, rank - 1, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. since the @@ -112,7 +113,8 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Sendrecv(recvbuf, count, datatype, dst, MPIR_ALLREDUCE_TAG, tmp_buf, count, datatype, dst, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); /* tmp_buf contains data received in this step. 
@@ -140,11 +142,12 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, if (rank < 2 * rem) { if (rank % 2) /* odd */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, rank - 1, MPIR_ALLREDUCE_TAG, comm_ptr, errflag); + datatype, rank - 1, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, + errflag); else /* even */ mpi_errno = MPIC_Recv(recvbuf, count, datatype, rank + 1, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } fn_exit: diff --git a/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c b/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c index 5d7c3320bd4..9bdd3bec10b 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c +++ b/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c @@ -86,7 +86,8 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, rank + 1, MPIR_ALLREDUCE_TAG, comm_ptr, errflag); + datatype, rank + 1, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set the rank to -1 so that this @@ -96,7 +97,7 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, } else { /* odd */ mpi_errno = MPIC_Recv(tmp_buf, count, datatype, rank - 1, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. since the @@ -176,7 +177,8 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, (char *) tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); /* tmp_buf contains data received in this step. 
@@ -235,7 +237,8 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, (char *) recvbuf + disps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); if (newrank > newdst) @@ -250,11 +253,12 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, if (rank < 2 * rem) { if (rank % 2) /* odd */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, rank - 1, MPIR_ALLREDUCE_TAG, comm_ptr, errflag); + datatype, rank - 1, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, + errflag); else /* even */ mpi_errno = MPIC_Recv(recvbuf, count, datatype, rank + 1, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } fn_exit: diff --git a/src/mpi/coll/allreduce/allreduce_intra_ring.c b/src/mpi/coll/allreduce/allreduce_intra_ring.c index f05b393060c..3791abfc262 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_ring.c +++ b/src/mpi/coll/allreduce/allreduce_intra_ring.c @@ -78,11 +78,12 @@ int MPIR_Allreduce_intra_ring(const void *sendbuf, void *recvbuf, MPI_Aint count mpi_errno = MPIR_Sched_next_tag(comm, &tag); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIC_Irecv(tmpbuf, cnts[recv_rank], datatype, src, tag, comm, &reqs[0]); + mpi_errno = MPIC_Irecv(tmpbuf, cnts[recv_rank], datatype, src, tag, + comm, coll_group, &reqs[0]); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Isend((char *) recvbuf + displs[send_rank] * extent, cnts[send_rank], - datatype, dst, tag, comm, &reqs[1], errflag); + datatype, dst, tag, comm, coll_group, &reqs[1], errflag); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Waitall(2, reqs, MPI_STATUSES_IGNORE); diff --git a/src/mpi/coll/allreduce/allreduce_intra_tree.c b/src/mpi/coll/allreduce/allreduce_intra_tree.c index bb09fcf09eb..91293b75904 100644 --- 
a/src/mpi/coll/allreduce/allreduce_intra_tree.c +++ b/src/mpi/coll/allreduce/allreduce_intra_tree.c @@ -150,7 +150,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, mpi_errno = MPIC_Recv(recv_address, msgsize, datatype, child, MPIR_ALLREDUCE_TAG, comm_ptr, - MPI_STATUS_IGNORE); + coll_group, MPI_STATUS_IGNORE); /* for communication errors, just record the error but continue */ MPIR_ERR_CHECK(mpi_errno); @@ -172,14 +172,14 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, if (rank != root) { /* send data to the parent */ mpi_errno = MPIC_Isend(reduce_address, msgsize, datatype, my_tree.parent, MPIR_ALLREDUCE_TAG, - comm_ptr, &reqs[num_reqs++], errflag); + comm_ptr, coll_group, &reqs[num_reqs++], errflag); MPIR_ERR_CHECK(mpi_errno); } if (my_tree.parent != -1) { mpi_errno = MPIC_Recv(reduce_address, msgsize, datatype, my_tree.parent, MPIR_ALLREDUCE_TAG, comm_ptr, - MPI_STATUS_IGNORE); + coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } if (num_children) { @@ -189,7 +189,8 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, MPIR_Assert(child != 0); mpi_errno = MPIC_Isend(reduce_address, msgsize, datatype, child, - MPIR_ALLREDUCE_TAG, comm_ptr, &reqs[num_reqs++], errflag); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, &reqs[num_reqs++], + errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/allreduce_group/allreduce_group.c b/src/mpi/coll/allreduce_group/allreduce_group.c index 68a959f17f7..0fa313ededa 100644 --- a/src/mpi/coll/allreduce_group/allreduce_group.c +++ b/src/mpi/coll/allreduce_group/allreduce_group.c @@ -67,7 +67,8 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, if (group_rank < 2 * rem) { if (group_rank % 2 == 0) { /* even */ to_comm_rank(cdst, group_rank + 1); - mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, errflag); + mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, + MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set 
the rank to -1 so that this @@ -76,7 +77,8 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, newrank = -1; } else { /* odd */ to_comm_rank(csrc, group_rank - 1); - mpi_errno = MPIC_Recv(tmp_buf, count, datatype, csrc, tag, comm_ptr, MPI_STATUS_IGNORE); + mpi_errno = MPIC_Recv(tmp_buf, count, datatype, csrc, tag, comm_ptr, + MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. since the @@ -116,7 +118,8 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, mpi_errno = MPIC_Sendrecv(recvbuf, count, datatype, cdst, tag, tmp_buf, count, datatype, cdst, - tag, comm_ptr, MPI_STATUS_IGNORE, errflag); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); if (!mpi_errno) { /* tmp_buf contains data received in this step. @@ -197,7 +200,8 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, (char *) tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, cdst, - tag, comm_ptr, MPI_STATUS_IGNORE, errflag); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); /* tmp_buf contains data received in this step. 
@@ -257,7 +261,8 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, (char *) recvbuf + disps[recv_idx] * extent, recv_cnt, datatype, cdst, - tag, comm_ptr, MPI_STATUS_IGNORE, errflag); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); if (newrank > newdst) @@ -274,10 +279,14 @@ int MPII_Allreduce_group_intra(void *sendbuf, void *recvbuf, MPI_Aint count, if (group_rank < 2 * rem) { if (group_rank % 2) { /* odd */ to_comm_rank(cdst, group_rank - 1); - mpi_errno = MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, errflag); + mpi_errno = + MPIC_Send(recvbuf, count, datatype, cdst, tag, comm_ptr, MPIR_SUBGROUP_NONE, + errflag); } else { /* even */ to_comm_rank(csrc, group_rank + 1); - mpi_errno = MPIC_Recv(recvbuf, count, datatype, csrc, tag, comm_ptr, MPI_STATUS_IGNORE); + mpi_errno = + MPIC_Recv(recvbuf, count, datatype, csrc, tag, comm_ptr, MPIR_SUBGROUP_NONE, + MPI_STATUS_IGNORE); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c b/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c index 1f6c59b0eed..0af0ce90fb1 100644 --- a/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c +++ b/src/mpi/coll/alltoall/alltoall_inter_pairwise_exchange.c @@ -57,7 +57,7 @@ int MPIR_Alltoall_inter_pairwise_exchange(const void *sendbuf, MPI_Aint sendcoun mpi_errno = MPIC_Sendrecv(sendaddr, sendcount, sendtype, dst, MPIR_ALLTOALL_TAG, recvaddr, recvcount, recvtype, src, - MPIR_ALLTOALL_TAG, comm_ptr, &status, errflag); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/alltoall/alltoall_intra_brucks.c b/src/mpi/coll/alltoall/alltoall_intra_brucks.c index e26df86f4fa..41b98808dfa 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_brucks.c +++ b/src/mpi/coll/alltoall/alltoall_intra_brucks.c @@ -107,7 +107,8 @@ int MPIR_Alltoall_intra_brucks(const void *sendbuf, mpi_errno = 
MPIC_Sendrecv(tmp_buf, newtype_sz, MPI_BYTE, dst, MPIR_ALLTOALL_TAG, recvbuf, 1, newtype, - src, MPIR_ALLTOALL_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + src, MPIR_ALLTOALL_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_Type_free_impl(&newtype); diff --git a/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c b/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c index 33ae4224dc4..dc43a99196e 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c +++ b/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c @@ -251,12 +251,12 @@ int MPIR_Alltoall_intra_k_brucks(const void *sendbuf, mpi_errno = MPIC_Irecv(tmp_rbuf[j - 1], packsize, MPI_BYTE, src, MPIR_ALLTOALL_TAG, comm, - &reqs[num_reqs++]); + coll_group, &reqs[num_reqs++]); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Isend(tmp_sbuf[j - 1], packsize, MPI_BYTE, dst, MPIR_ALLTOALL_TAG, comm, - &reqs[num_reqs++], errflag); + coll_group, &reqs[num_reqs++], errflag); if (mpi_errno) { MPIR_ERR_POP(mpi_errno); } diff --git a/src/mpi/coll/alltoall/alltoall_intra_pairwise.c b/src/mpi/coll/alltoall/alltoall_intra_pairwise.c index 829b3a9f1f3..c9bfb158286 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_pairwise.c +++ b/src/mpi/coll/alltoall/alltoall_intra_pairwise.c @@ -75,7 +75,7 @@ int MPIR_Alltoall_intra_pairwise(const void *sendbuf, ((char *) recvbuf + src * recvcount * recvtype_extent), recvcount, recvtype, src, - MPIR_ALLTOALL_TAG, comm_ptr, &status, errflag); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c index 2746fe79771..e267b7af94d 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c @@ -56,14 +56,16 @@ int MPIR_Alltoall_intra_pairwise_sendrecv_replace(const void *sendbuf, mpi_errno = 
MPIC_Sendrecv_replace(((char *) recvbuf + j * recvcount * recvtype_extent), recvcount, recvtype, j, MPIR_ALLTOALL_TAG, j, - MPIR_ALLTOALL_TAG, comm_ptr, &status, errflag); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &status, + errflag); MPIR_ERR_CHECK(mpi_errno); } else if (rank == j) { /* same as above with i/j args reversed */ mpi_errno = MPIC_Sendrecv_replace(((char *) recvbuf + i * recvcount * recvtype_extent), recvcount, recvtype, i, MPIR_ALLTOALL_TAG, i, - MPIR_ALLTOALL_TAG, comm_ptr, &status, errflag); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &status, + errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/alltoall/alltoall_intra_scattered.c b/src/mpi/coll/alltoall/alltoall_intra_scattered.c index f4437bbe9ca..a95031f3461 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_scattered.c +++ b/src/mpi/coll/alltoall/alltoall_intra_scattered.c @@ -72,7 +72,7 @@ int MPIR_Alltoall_intra_scattered(const void *sendbuf, mpi_errno = MPIC_Irecv((char *) recvbuf + dst * recvcount * recvtype_extent, recvcount, recvtype, dst, - MPIR_ALLTOALL_TAG, comm_ptr, &reqarray[i]); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &reqarray[i]); MPIR_ERR_CHECK(mpi_errno); } @@ -81,7 +81,8 @@ int MPIR_Alltoall_intra_scattered(const void *sendbuf, mpi_errno = MPIC_Isend((char *) sendbuf + dst * sendcount * sendtype_extent, sendcount, sendtype, dst, - MPIR_ALLTOALL_TAG, comm_ptr, &reqarray[i + ss], errflag); + MPIR_ALLTOALL_TAG, comm_ptr, coll_group, &reqarray[i + ss], + errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c b/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c index abc3b6eb39b..341ff3153f0 100644 --- a/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c +++ b/src/mpi/coll/alltoallv/alltoallv_inter_pairwise_exchange.c @@ -66,7 +66,8 @@ int MPIR_Alltoallv_inter_pairwise_exchange(const void *sendbuf, const MPI_Aint * mpi_errno = MPIC_Sendrecv(sendaddr, sendcount, sendtype, dst, MPIR_ALLTOALLV_TAG, 
recvaddr, recvcount, - recvtype, src, MPIR_ALLTOALLV_TAG, comm_ptr, &status, errflag); + recvtype, src, MPIR_ALLTOALLV_TAG, comm_ptr, coll_group, &status, + errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c index 6835bddc3cb..7f93eeaedb5 100644 --- a/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c @@ -59,7 +59,7 @@ int MPIR_Alltoallv_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP recvcounts[j], recvtype, j, MPIR_ALLTOALLV_TAG, j, MPIR_ALLTOALLV_TAG, - comm_ptr, &status, errflag); + comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (rank == j) { @@ -68,7 +68,7 @@ int MPIR_Alltoallv_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP recvcounts[i], recvtype, i, MPIR_ALLTOALLV_TAG, i, MPIR_ALLTOALLV_TAG, - comm_ptr, &status, errflag); + comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c index 6fba549f28a..c33ec8e4816 100644 --- a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c +++ b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c @@ -71,7 +71,8 @@ int MPIR_Alltoallv_intra_scattered(const void *sendbuf, const MPI_Aint * sendcou if (type_size) { mpi_errno = MPIC_Irecv((char *) recvbuf + rdispls[dst] * recv_extent, recvcounts[dst], recvtype, dst, - MPIR_ALLTOALLV_TAG, comm_ptr, &reqarray[req_cnt]); + MPIR_ALLTOALLV_TAG, comm_ptr, coll_group, + &reqarray[req_cnt]); MPIR_ERR_CHECK(mpi_errno); req_cnt++; } @@ -86,7 +87,7 @@ int MPIR_Alltoallv_intra_scattered(const void *sendbuf, const MPI_Aint * sendcou if (type_size) { mpi_errno = MPIC_Isend((char *) sendbuf + sdispls[dst] * send_extent, sendcounts[dst], sendtype, dst, - MPIR_ALLTOALLV_TAG, comm_ptr, + 
MPIR_ALLTOALLV_TAG, comm_ptr, coll_group, &reqarray[req_cnt], errflag); MPIR_ERR_CHECK(mpi_errno); req_cnt++; diff --git a/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c b/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c index fafa9502e00..f7d6d6ba967 100644 --- a/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c +++ b/src/mpi/coll/alltoallw/alltoallw_inter_pairwise_exchange.c @@ -67,7 +67,7 @@ int MPIR_Alltoallw_inter_pairwise_exchange(const void *sendbuf, const MPI_Aint s mpi_errno = MPIC_Sendrecv(sendaddr, sendcount, sendtype, dst, MPIR_ALLTOALLW_TAG, recvaddr, recvcount, recvtype, src, - MPIR_ALLTOALLW_TAG, comm_ptr, &status, errflag); + MPIR_ALLTOALLW_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } fn_exit: diff --git a/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c index bd67b157b71..3e367430969 100644 --- a/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c @@ -56,7 +56,7 @@ int MPIR_Alltoallw_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP recvcounts[j], recvtypes[j], j, MPIR_ALLTOALLW_TAG, j, MPIR_ALLTOALLW_TAG, - comm_ptr, &status, errflag); + comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (rank == j) { /* same as above with i/j args reversed */ @@ -64,7 +64,7 @@ int MPIR_Alltoallw_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP recvcounts[i], recvtypes[i], i, MPIR_ALLTOALLW_TAG, i, MPIR_ALLTOALLW_TAG, - comm_ptr, &status, errflag); + comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c index 348c5528af0..cc6d6960063 100644 --- a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c +++ 
b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c @@ -68,7 +68,7 @@ int MPIR_Alltoallw_intra_scattered(const void *sendbuf, const MPI_Aint sendcount if (type_size) { mpi_errno = MPIC_Irecv((char *) recvbuf + rdispls[dst], recvcounts[dst], recvtypes[dst], dst, - MPIR_ALLTOALLW_TAG, comm_ptr, + MPIR_ALLTOALLW_TAG, comm_ptr, coll_group, &reqarray[outstanding_requests]); MPIR_ERR_CHECK(mpi_errno); @@ -84,7 +84,7 @@ int MPIR_Alltoallw_intra_scattered(const void *sendbuf, const MPI_Aint sendcount if (type_size) { mpi_errno = MPIC_Isend((char *) sendbuf + sdispls[dst], sendcounts[dst], sendtypes[dst], dst, - MPIR_ALLTOALLW_TAG, comm_ptr, + MPIR_ALLTOALLW_TAG, comm_ptr, coll_group, &reqarray[outstanding_requests], errflag); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c index 849806e5658..6ae5ba6c00f 100644 --- a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c +++ b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c @@ -28,7 +28,8 @@ int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_group, MPIR_ src = (rank - mask + size) % size; mpi_errno = MPIC_Sendrecv(NULL, 0, MPI_BYTE, dst, MPIR_BARRIER_TAG, NULL, 0, MPI_BYTE, - src, MPIR_BARRIER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + src, MPIR_BARRIER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); mask <<= 1; } @@ -98,7 +99,7 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int coll_group, int k, /* recv from (k-1) nbrs */ mpi_errno = - MPIC_Irecv(NULL, 0, MPI_BYTE, from, MPIR_BARRIER_TAG, comm, + MPIC_Irecv(NULL, 0, MPI_BYTE, from, MPIR_BARRIER_TAG, comm, coll_group, &recv_reqs[(j - 1) + ((k - 1) * (i & 1))]); MPIR_ERR_CHECK(mpi_errno); /* wait on recvs from prev phase */ @@ -108,9 +109,8 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int coll_group, int k, MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = - MPIC_Isend(NULL, 0, MPI_BYTE, to, 
MPIR_BARRIER_TAG, comm, &send_reqs[j - 1], - errflag); + mpi_errno = MPIC_Isend(NULL, 0, MPI_BYTE, to, MPIR_BARRIER_TAG, comm, coll_group, + &send_reqs[j - 1], errflag); MPIR_ERR_CHECK(mpi_errno); } mpi_errno = MPIC_Waitall(k - 1, send_reqs, MPI_STATUSES_IGNORE); diff --git a/src/mpi/coll/bcast/bcast.h b/src/mpi/coll/bcast/bcast.h index 23d15fc1325..5e2d5b3194e 100644 --- a/src/mpi/coll/bcast/bcast.h +++ b/src/mpi/coll/bcast/bcast.h @@ -9,7 +9,7 @@ #include "mpiimpl.h" int MPII_Scatter_for_bcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, - int root, MPIR_Comm * comm_ptr, MPI_Aint nbytes, void *tmp_buf, - int is_contig, MPIR_Errflag_t errflag); + int root, MPIR_Comm * comm_ptr, int coll_group, MPI_Aint nbytes, + void *tmp_buf, int is_contig, MPIR_Errflag_t errflag); #endif /* BCAST_H_INCLUDED */ diff --git a/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c b/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c index 4bace6c2a72..087718375b8 100644 --- a/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c +++ b/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c @@ -29,7 +29,8 @@ int MPIR_Bcast_inter_remote_send_local_bcast(void *buffer, mpi_errno = MPI_SUCCESS; } else if (root == MPI_ROOT) { /* root sends to rank 0 on remote group and returns */ - mpi_errno = MPIC_Send(buffer, count, datatype, 0, MPIR_BCAST_TAG, comm_ptr, errflag); + mpi_errno = + MPIC_Send(buffer, count, datatype, 0, MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* remote group. 
rank 0 on remote group receives from root */ @@ -37,7 +38,8 @@ int MPIR_Bcast_inter_remote_send_local_bcast(void *buffer, rank = comm_ptr->rank; if (rank == 0) { - mpi_errno = MPIC_Recv(buffer, count, datatype, root, MPIR_BCAST_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(buffer, count, datatype, root, MPIR_BCAST_TAG, + comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/bcast/bcast_intra_binomial.c b/src/mpi/coll/bcast/bcast_intra_binomial.c index 0040b6a98ce..e291927703f 100644 --- a/src/mpi/coll/bcast/bcast_intra_binomial.c +++ b/src/mpi/coll/bcast/bcast_intra_binomial.c @@ -92,10 +92,10 @@ int MPIR_Bcast_intra_binomial(void *buffer, src += comm_size; if (!is_contig) mpi_errno = MPIC_Recv(tmp_buf, nbytes, MPI_BYTE, src, - MPIR_BCAST_TAG, comm_ptr, status_p); + MPIR_BCAST_TAG, comm_ptr, coll_group, status_p); else mpi_errno = MPIC_Recv(buffer, count, datatype, src, - MPIR_BCAST_TAG, comm_ptr, status_p); + MPIR_BCAST_TAG, comm_ptr, coll_group, status_p); MPIR_ERR_CHECK(mpi_errno); #ifdef HAVE_ERROR_CHECKING /* check that we received as much as we expected */ @@ -129,10 +129,10 @@ int MPIR_Bcast_intra_binomial(void *buffer, dst -= comm_size; if (!is_contig) mpi_errno = MPIC_Send(tmp_buf, nbytes, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); else mpi_errno = MPIC_Send(buffer, count, datatype, dst, - MPIR_BCAST_TAG, comm_ptr, errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } mask >>= 1; diff --git a/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c b/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c index 3ef926f3630..e09dc797bc0 100644 --- a/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c +++ b/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c @@ -118,7 +118,7 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, if (src != -1) { /* post receive from parent */ mpi_errno = MPIC_Irecv((char *) sendbuf + offset, msgsize, MPI_BYTE, - src, 
MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++]); + src, MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++]); MPIR_ERR_CHECK(mpi_errno); } offset += msgsize; @@ -131,7 +131,7 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, if (src != -1) { mpi_errno = MPIC_Irecv((char *) sendbuf + offset, msgsize, MPI_BYTE, - src, MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++]); + src, MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++]); MPIR_ERR_CHECK(mpi_errno); } offset += msgsize; @@ -170,7 +170,7 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, if (src != -1) { mpi_errno = MPIC_Recv((char *) sendbuf + offset, msgsize, MPI_BYTE, - src, MPIR_BCAST_TAG, comm_ptr, &status); + src, MPIR_BCAST_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); MPIR_Get_count_impl(&status, MPI_BYTE, &recvd_size); MPIR_ERR_CHKANDJUMP2(recvd_size != nbytes, mpi_errno, MPI_ERR_OTHER, @@ -191,11 +191,11 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, if (!is_nb) { mpi_errno = MPIC_Send((char *) sendbuf + offset, msgsize, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); } else { mpi_errno = MPIC_Isend((char *) sendbuf + offset, msgsize, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++], errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++], errflag); } MPIR_ERR_CHECK(mpi_errno); @@ -207,11 +207,11 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, dst = *p; if (!is_nb) { mpi_errno = MPIC_Send((char *) sendbuf + offset, msgsize, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); } else { mpi_errno = MPIC_Isend((char *) sendbuf + offset, msgsize, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++], errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++], errflag); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c 
b/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c index 43abef6d1be..53c968276bc 100644 --- a/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c +++ b/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c @@ -78,7 +78,7 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, MPI_Aint scatter_size; scatter_size = (nbytes + comm_size - 1) / comm_size; /* ceiling division */ - mpi_errno = MPII_Scatter_for_bcast(buffer, count, datatype, root, comm_ptr, + mpi_errno = MPII_Scatter_for_bcast(buffer, count, datatype, root, comm_ptr, coll_group, nbytes, tmp_buf, is_contig, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -119,12 +119,10 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, curr_size, MPI_BYTE, dst, MPIR_BCAST_TAG, ((char *) tmp_buf + recv_offset), (nbytes - recv_offset < 0 ? 0 : nbytes - recv_offset), - MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, &status, errflag); + MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, coll_group, &status, + errflag); MPIR_ERR_CHECK(mpi_errno); - if (mpi_errno) { - recv_size = 0; - } else - MPIR_Get_count_impl(&status, MPI_BYTE, &recv_size); + MPIR_Get_count_impl(&status, MPI_BYTE, &recv_size); curr_size += recv_size; } @@ -184,7 +182,7 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, * fflush(stdout); */ mpi_errno = MPIC_Send(((char *) tmp_buf + offset), recv_size, MPI_BYTE, dst, - MPIR_BCAST_TAG, comm_ptr, errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, errflag); /* recv_size was set in the previous * receive. that's the amount of data to be * sent now. */ @@ -199,7 +197,8 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, * relative_rank, dst); */ mpi_errno = MPIC_Recv(((char *) tmp_buf + offset), nbytes - offset < 0 ? 
0 : nbytes - offset, - MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, &status); + MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, coll_group, + &status); /* nprocs_completed is also equal to the no. of processes * whose data we don't have */ MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c b/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c index ad44f2f03dd..bc64139c6c9 100644 --- a/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c +++ b/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c @@ -70,7 +70,7 @@ int MPIR_Bcast_intra_scatter_ring_allgather(void *buffer, scatter_size = (nbytes + comm_size - 1) / comm_size; /* ceiling division */ - mpi_errno = MPII_Scatter_for_bcast(buffer, count, datatype, root, comm_ptr, + mpi_errno = MPII_Scatter_for_bcast(buffer, count, datatype, root, comm_ptr, coll_group, nbytes, tmp_buf, is_contig, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -104,7 +104,8 @@ int MPIR_Bcast_intra_scatter_ring_allgather(void *buffer, mpi_errno = MPIC_Sendrecv((char *) tmp_buf + right_disp, right_count, MPI_BYTE, right, MPIR_BCAST_TAG, (char *) tmp_buf + left_disp, left_count, - MPI_BYTE, left, MPIR_BCAST_TAG, comm_ptr, &status, errflag); + MPI_BYTE, left, MPIR_BCAST_TAG, comm_ptr, coll_group, &status, + errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_Get_count_impl(&status, MPI_BYTE, &recvd_size); curr_size += recvd_size; diff --git a/src/mpi/coll/bcast/bcast_intra_smp.c b/src/mpi/coll/bcast/bcast_intra_smp.c index 402e2dbc860..5881213d638 100644 --- a/src/mpi/coll/bcast/bcast_intra_smp.c +++ b/src/mpi/coll/bcast/bcast_intra_smp.c @@ -40,12 +40,12 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) > 0) { /* is not the node root (0) and is on our node (!-1) */ if (root == comm_ptr->rank) { mpi_errno = MPIC_Send(buffer, count, datatype, 0, - MPIR_BCAST_TAG, comm_ptr->node_comm, errflag); + MPIR_BCAST_TAG, 
comm_ptr->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (0 == comm_ptr->node_comm->rank) { mpi_errno = MPIC_Recv(buffer, count, datatype, MPIR_Get_intranode_rank(comm_ptr, root), - MPIR_BCAST_TAG, comm_ptr->node_comm, status_p); + MPIR_BCAST_TAG, comm_ptr->node_comm, coll_group, status_p); MPIR_ERR_CHECK(mpi_errno); #ifdef HAVE_ERROR_CHECKING /* check that we received as much as we expected */ diff --git a/src/mpi/coll/bcast/bcast_intra_tree.c b/src/mpi/coll/bcast/bcast_intra_tree.c index 560232ea825..c7a27a02db2 100644 --- a/src/mpi/coll/bcast/bcast_intra_tree.c +++ b/src/mpi/coll/bcast/bcast_intra_tree.c @@ -129,7 +129,8 @@ int MPIR_Bcast_intra_tree(void *buffer, if ((parent != -1 && tree_type != MPIR_TREE_TYPE_KARY) || (!is_root && tree_type == MPIR_TREE_TYPE_KARY)) { src = parent; - mpi_errno = MPIC_Recv(send_buf, count, dtype, src, MPIR_BCAST_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(send_buf, count, dtype, src, MPIR_BCAST_TAG, + comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); /* check that we received as much as we expected */ MPIR_Get_count_impl(&status, MPI_BYTE, &recvd_size); @@ -147,10 +148,12 @@ int MPIR_Bcast_intra_tree(void *buffer, if (!is_nb) { mpi_errno = - MPIC_Send(send_buf, count, dtype, dst, MPIR_BCAST_TAG, comm_ptr, errflag); + MPIC_Send(send_buf, count, dtype, dst, MPIR_BCAST_TAG, comm_ptr, coll_group, + errflag); } else { mpi_errno = MPIC_Isend(send_buf, count, dtype, dst, - MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++], errflag); + MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++], + errflag); } MPIR_ERR_CHECK(mpi_errno); } @@ -161,10 +164,12 @@ int MPIR_Bcast_intra_tree(void *buffer, if (!is_nb) { mpi_errno = - MPIC_Send(send_buf, count, dtype, dst, MPIR_BCAST_TAG, comm_ptr, errflag); + MPIC_Send(send_buf, count, dtype, dst, MPIR_BCAST_TAG, comm_ptr, coll_group, + errflag); } else { mpi_errno = MPIC_Isend(send_buf, count, dtype, dst, - MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++], errflag); + 
MPIR_BCAST_TAG, comm_ptr, coll_group, &reqs[num_req++], + errflag); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/bcast/bcast_utils.c b/src/mpi/coll/bcast/bcast_utils.c index 4d5900385da..1105b434481 100644 --- a/src/mpi/coll/bcast/bcast_utils.c +++ b/src/mpi/coll/bcast/bcast_utils.c @@ -20,7 +20,7 @@ int MPII_Scatter_for_bcast(void *buffer ATTRIBUTE((unused)), MPI_Aint count ATTRIBUTE((unused)), MPI_Datatype datatype ATTRIBUTE((unused)), int root, - MPIR_Comm * comm_ptr, + MPIR_Comm * comm_ptr, int coll_group, MPI_Aint nbytes, void *tmp_buf, int is_contig, MPIR_Errflag_t errflag) { MPI_Status status; @@ -66,7 +66,8 @@ int MPII_Scatter_for_bcast(void *buffer ATTRIBUTE((unused)), } else { mpi_errno = MPIC_Recv(((char *) tmp_buf + relative_rank * scatter_size), - recv_size, MPI_BYTE, src, MPIR_BCAST_TAG, comm_ptr, &status); + recv_size, MPI_BYTE, src, MPIR_BCAST_TAG, comm_ptr, + coll_group, &status); MPIR_ERR_CHECK(mpi_errno); /* query actual size of data received */ MPIR_Get_count_impl(&status, MPI_BYTE, &curr_size); @@ -93,7 +94,8 @@ int MPII_Scatter_for_bcast(void *buffer ATTRIBUTE((unused)), dst -= comm_size; mpi_errno = MPIC_Send(((char *) tmp_buf + scatter_size * (relative_rank + mask)), - send_size, MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, errflag); + send_size, MPI_BYTE, dst, MPIR_BCAST_TAG, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); curr_size -= send_size; diff --git a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c index 3d10c8d08cf..62dbb6502c4 100644 --- a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c +++ b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c @@ -93,7 +93,7 @@ int MPIR_Exscan_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Sendrecv(partial_scan, count, datatype, dst, MPIR_EXSCAN_TAG, tmp_buf, count, datatype, dst, - MPIR_EXSCAN_TAG, comm_ptr, &status, errflag); + MPIR_EXSCAN_TAG, comm_ptr, coll_group, &status, errflag); 
MPIR_ERR_CHECK(mpi_errno); if (rank > dst) { diff --git a/src/mpi/coll/gather/gather_inter_linear.c b/src/mpi/coll/gather/gather_inter_linear.c index a1f7dbdec37..2b6b76b6dbb 100644 --- a/src/mpi/coll/gather/gather_inter_linear.c +++ b/src/mpi/coll/gather/gather_inter_linear.c @@ -33,14 +33,13 @@ int MPIR_Gather_inter_linear(const void *sendbuf, MPI_Aint sendcount, MPI_Dataty MPIR_Datatype_get_extent_macro(recvtype, extent); for (i = 0; i < remote_size; i++) { - mpi_errno = - MPIC_Recv(((char *) recvbuf + recvcount * i * extent), recvcount, recvtype, i, - MPIR_GATHER_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(((char *) recvbuf + recvcount * i * extent), recvcount, recvtype, + i, MPIR_GATHER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } } else { - mpi_errno = - MPIC_Send(sendbuf, sendcount, sendtype, root, MPIR_GATHER_TAG, comm_ptr, errflag); + mpi_errno = MPIC_Send(sendbuf, sendcount, sendtype, root, MPIR_GATHER_TAG, + comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c b/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c index 22a3e71aa9d..6f35bb27dde 100644 --- a/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c +++ b/src/mpi/coll/gather/gather_inter_local_gather_remote_send.c @@ -36,7 +36,7 @@ int MPIR_Gather_inter_local_gather_remote_send(const void *sendbuf, MPI_Aint sen /* root receives data from rank 0 on remote group */ mpi_errno = MPIC_Recv(recvbuf, recvcount * remote_size, recvtype, 0, MPIR_GATHER_TAG, comm_ptr, - &status); + coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } else { /* remote group. 
Rank 0 allocates temporary buffer, does @@ -73,7 +73,7 @@ int MPIR_Gather_inter_local_gather_remote_send(const void *sendbuf, MPI_Aint sen if (rank == 0) { mpi_errno = MPIC_Send(tmp_buf, sendcount * local_size * sendtype_sz, MPI_BYTE, - root, MPIR_GATHER_TAG, comm_ptr, errflag); + root, MPIR_GATHER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/gather/gather_intra_binomial.c b/src/mpi/coll/gather/gather_intra_binomial.c index 97b1dfd1d88..5dce21e04df 100644 --- a/src/mpi/coll/gather/gather_intra_binomial.c +++ b/src/mpi/coll/gather/gather_intra_binomial.c @@ -137,13 +137,15 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data (((rank + mask) % comm_size) * (MPI_Aint) recvcount * extent)), (MPI_Aint) recvblks * recvcount, - recvtype, src, MPIR_GATHER_TAG, comm_ptr, &status); + recvtype, src, MPIR_GATHER_TAG, comm_ptr, coll_group, + &status); MPIR_ERR_CHECK(mpi_errno); } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) { /* small transfer size case. 
cast ok */ MPIR_Assert(recvblks * nbytes == (int) (recvblks * nbytes)); mpi_errno = MPIC_Recv(tmp_buf, (int) (recvblks * nbytes), - MPI_BYTE, src, MPIR_GATHER_TAG, comm_ptr, &status); + MPI_BYTE, src, MPIR_GATHER_TAG, comm_ptr, coll_group, + &status); MPIR_ERR_CHECK(mpi_errno); copy_offset = rank + mask; copy_blks = recvblks; @@ -163,7 +165,7 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Recv(recvbuf, 1, tmp_type, src, - MPIR_GATHER_TAG, comm_ptr, &status); + MPIR_GATHER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); MPIR_Type_free_impl(&tmp_type); @@ -184,7 +186,7 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data offset = (mask - 1) * nbytes; mpi_errno = MPIC_Recv(((char *) tmp_buf + offset), recvblks * nbytes, MPI_BYTE, src, - MPIR_GATHER_TAG, comm_ptr, &status); + MPIR_GATHER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); curr_cnt += (recvblks * nbytes); } @@ -196,11 +198,11 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data if (!tmp_buf_size) { /* leaf nodes send directly from sendbuf */ mpi_errno = MPIC_Send(sendbuf, sendcount, sendtype, dst, - MPIR_GATHER_TAG, comm_ptr, errflag); + MPIR_GATHER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) { mpi_errno = MPIC_Send(tmp_buf, curr_cnt, MPI_BYTE, dst, - MPIR_GATHER_TAG, comm_ptr, errflag); + MPIR_GATHER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else { MPI_Aint blocks[2]; @@ -225,7 +227,7 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Send(MPI_BOTTOM, 1, tmp_type, dst, - MPIR_GATHER_TAG, comm_ptr, errflag); + MPIR_GATHER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); MPIR_Type_free_impl(&tmp_type); if (types[1] != MPI_BYTE) diff --git 
a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c index bf571243595..906f1ed6bb2 100644 --- a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c +++ b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c @@ -61,7 +61,8 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf, } else { mpi_errno = MPIC_Irecv(((char *) recvbuf + displs[i] * extent), recvcounts[i], recvtype, i, - MPIR_GATHERV_TAG, comm_ptr, &reqarray[reqs++]); + MPIR_GATHERV_TAG, comm_ptr, coll_group, + &reqarray[reqs++]); MPIR_ERR_CHECK(mpi_errno); } } @@ -74,7 +75,7 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf, else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */ if (sendcount) { mpi_errno = MPIC_Send(sendbuf, sendcount, sendtype, root, - MPIR_GATHERV_TAG, comm_ptr, errflag); + MPIR_GATHERV_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c b/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c index 65fb949e5b3..24d6d3afa9b 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c @@ -57,11 +57,13 @@ int MPIR_Iallgather_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, src = (rank + pof2) % comm_size; dst = (rank - pof2 + comm_size) % comm_size; - mpi_errno = MPIR_Sched_send(tmp_buf, curr_cnt * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, curr_cnt * recvtype_sz, MPI_BYTE, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); /* logically sendrecv, so no barrier here */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + curr_cnt * recvtype_sz), - curr_cnt * recvtype_sz, MPI_BYTE, src, comm_ptr, s); + curr_cnt * recvtype_sz, MPI_BYTE, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -77,11 +79,13 @@ int MPIR_Iallgather_intra_sched_brucks(const void 
*sendbuf, MPI_Aint sendcount, dst = (rank - pof2 + comm_size) % comm_size; mpi_errno = - MPIR_Sched_send(tmp_buf, rem * recvcount * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + MPIR_Sched_send(tmp_buf, rem * recvcount * recvtype_sz, MPI_BYTE, dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* logically sendrecv, so no barrier here */ mpi_errno = MPIR_Sched_recv((char *) tmp_buf + curr_cnt * recvtype_sz, - rem * recvcount * recvtype_sz, MPI_BYTE, src, comm_ptr, s); + rem * recvcount * recvtype_sz, MPI_BYTE, src, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c b/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c index c035d04939b..ef41fe4658b 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c @@ -105,12 +105,13 @@ int MPIR_Iallgather_intra_sched_recursive_doubling(const void *sendbuf, MPI_Aint if (dst < comm_size) { mpi_errno = MPIR_Sched_send_defer(((char *) recvbuf + send_offset), - &ss->curr_count, recvtype, dst, comm_ptr, s); + &ss->curr_count, recvtype, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); /* send-recv, no sched barrier here */ mpi_errno = MPIR_Sched_recv_status(((char *) recvbuf + recv_offset), ((comm_size - dst_tree_root) * recvcount), - recvtype, dst, comm_ptr, &ss->status, s); + recvtype, dst, comm_ptr, coll_group, &ss->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -169,7 +170,7 @@ int MPIR_Iallgather_intra_sched_recursive_doubling(const void *sendbuf, MPI_Aint * sent now. 
*/ mpi_errno = MPIR_Sched_send_defer(((char *) recvbuf + offset), &ss->last_recv_count, - recvtype, dst, comm_ptr, s); + recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -183,7 +184,8 @@ int MPIR_Iallgather_intra_sched_recursive_doubling(const void *sendbuf, MPI_Aint mpi_errno = MPIR_Sched_recv_status(((char *) recvbuf + offset), ((comm_size - (my_tree_root + mask)) * recvcount), - recvtype, dst, comm_ptr, &ss->status, s); + recvtype, dst, comm_ptr, coll_group, + &ss->status, s); MPIR_SCHED_BARRIER(s); mpi_errno = MPIR_Sched_cb(&get_count, ss, s); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c b/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c index 16f148f3187..48c8c2c6e05 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c @@ -53,11 +53,11 @@ int MPIR_Iallgather_intra_sched_ring(const void *sendbuf, MPI_Aint sendcount, MP jnext = left; for (i = 1; i < comm_size; i++) { mpi_errno = MPIR_Sched_send(((char *) recvbuf + j * recvcount * recvtype_extent), - recvcount, recvtype, right, comm_ptr, s); + recvcount, recvtype, right, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* concurrent, no barrier here */ mpi_errno = MPIR_Sched_recv(((char *) recvbuf + jnext * recvcount * recvtype_extent), - recvcount, recvtype, left, comm_ptr, s); + recvcount, recvtype, left, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/iallgather/iallgather_tsp_brucks.c b/src/mpi/coll/iallgather/iallgather_tsp_brucks.c index 7743e2bb7ca..750bcf41b64 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_brucks.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_brucks.c @@ -117,18 +117,20 @@ MPIR_TSP_Iallgather_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* Receive at the exact location. 
*/ mpi_errno = MPIR_TSP_sched_irecv((char *) tmp_recvbuf + j * recvcount * delta * recvtype_extent, - count, recvtype, src, tag, comm, sched, 0, NULL, &vtx_id); + count, recvtype, src, tag, comm, coll_group, sched, 0, NULL, + &vtx_id); MPIR_ERR_CHECK(mpi_errno); recv_id[i_recv++] = vtx_id; /* Send from the start of recv till `count` amount of data. */ if (i == 0) mpi_errno = - MPIR_TSP_sched_isend(tmp_recvbuf, count, recvtype, dst, tag, comm, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(tmp_recvbuf, count, recvtype, dst, tag, comm, coll_group, + sched, 0, NULL, &vtx_id); else mpi_errno = MPIR_TSP_sched_isend(tmp_recvbuf, count, recvtype, dst, tag, - comm, sched, n_invtcs, recv_id, &vtx_id); + comm, coll_group, sched, n_invtcs, recv_id, + &vtx_id); MPIR_ERR_CHECK(mpi_errno); } n_invtcs += (k - 1); diff --git a/src/mpi/coll/iallgather/iallgather_tsp_recexch.c b/src/mpi/coll/iallgather/iallgather_tsp_recexch.c index 9b6d213fb73..1a6092c18ff 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_recexch.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_recexch.c @@ -33,7 +33,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_data_exchange(int rank, int n /* send my data to partner */ mpi_errno = MPIR_TSP_sched_isend(((char *) recvbuf + send_offset), count * recvcount, recvtype, - partner, tag, comm, sched, 0, NULL, &vtx_id); + partner, tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); @@ -46,7 +46,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_data_exchange(int rank, int n /* recv data from my partner */ mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), count * recvcount, recvtype, - partner, tag, comm, sched, 0, NULL, &vtx_id); + partner, tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -83,15 +83,15 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step1(int step1_sendto, int * else buf_to_send = (void *) sendbuf; mpi_errno = - MPIR_TSP_sched_isend(buf_to_send, recvcount, 
recvtype, step1_sendto, tag, comm, sched, - 0, NULL, &vtx_id); + MPIR_TSP_sched_isend(buf_to_send, recvcount, recvtype, step1_sendto, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets the data from non-participating rank */ MPI_Aint recv_offset = step1_recvfrom[i] * recv_extent * recvcount; mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), recvcount, recvtype, - step1_recvfrom[i], tag, comm, sched, n_invtcs, invtx, - &vtx_id); + step1_recvfrom[i], tag, comm, coll_group, sched, + n_invtcs, invtx, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } @@ -140,7 +140,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step2(int step1_sendto, int s MPI_Aint send_offset = offset * recv_extent * recvcount; mpi_errno = MPIR_TSP_sched_isend(((char *) recvbuf + send_offset), count * recvcount, recvtype, - nbr, tag, comm, sched, nrecvs, recv_id, &vtx_id); + nbr, tag, comm, coll_group, sched, nrecvs, recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, @@ -158,7 +158,7 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step2(int step1_sendto, int s MPI_Aint recv_offset = offset * recv_extent * recvcount; mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), count * recvcount, recvtype, - nbr, tag, comm, sched, 0, NULL, &vtx_id); + nbr, tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); @@ -202,13 +202,13 @@ static int MPIR_TSP_Iallgather_sched_intra_recexch_step3(int step1_sendto, int * if (step1_sendto != -1) { mpi_errno = MPIR_TSP_sched_irecv(recvbuf, recvcount * nranks, recvtype, step1_sendto, tag, comm, - sched, 0, NULL, &vtx_id); + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < step1_nrecvs; i++) { mpi_errno = MPIR_TSP_sched_isend(recvbuf, recvcount * nranks, recvtype, step1_recvfrom[i], - tag, comm, sched, nrecvs, recv_id, &vtx_id); + tag, 
comm, coll_group, sched, nrecvs, recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/iallgather/iallgather_tsp_ring.c b/src/mpi/coll/iallgather/iallgather_tsp_ring.c index 9798e63aae7..b29f64afb81 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_ring.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_ring.c @@ -90,14 +90,16 @@ int MPIR_TSP_Iallgather_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount nvtcs = 1; vtcs[0] = dtcopy_id[0]; mpi_errno = MPIR_TSP_sched_isend((char *) sbuf, recvcount, recvtype, - dst, tag, comm, sched, nvtcs, vtcs, &send_id[0]); + dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &send_id[0]); nvtcs = 0; } else { nvtcs = 2; vtcs[0] = recv_id[(i - 1) % 3]; vtcs[1] = send_id[(i - 1) % 3]; mpi_errno = MPIR_TSP_sched_isend((char *) sbuf, recvcount, recvtype, - dst, tag, comm, sched, nvtcs, vtcs, &send_id[i % 3]); + dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &send_id[i % 3]); if (i == 1) { nvtcs = 2; vtcs[0] = send_id[0]; @@ -112,7 +114,8 @@ int MPIR_TSP_Iallgather_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_TSP_sched_irecv((char *) rbuf, recvcount, recvtype, - src, tag, comm, sched, nvtcs, vtcs, &recv_id[i % 3]); + src, tag, comm, coll_group, sched, nvtcs, vtcs, + &recv_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c index dd764574004..d3d0cb073af 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c @@ -70,11 +70,14 @@ int MPIR_Iallgatherv_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, incoming_count += recvcounts[(src + i) % comm_size]; } - mpi_errno = MPIR_Sched_send(tmp_buf, curr_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, curr_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, coll_group, + s); 
MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + curr_count * recvtype_sz), - incoming_count * recvtype_sz, MPI_BYTE, src, comm_ptr, s); + incoming_count * recvtype_sz, MPI_BYTE, src, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -93,12 +96,13 @@ int MPIR_Iallgatherv_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, for (i = 0; i < rem; i++) cnt += recvcounts[(rank + i) % comm_size]; - mpi_errno = MPIR_Sched_send(tmp_buf, cnt * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, cnt * recvtype_sz, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + curr_count * recvtype_sz), (total_count - curr_count) * recvtype_sz, MPI_BYTE, - src, comm_ptr, s); + src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c index 6fc7a618799..7a03d90d534 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c @@ -109,11 +109,13 @@ int MPIR_Iallgatherv_intra_sched_recursive_doubling(const void *sendbuf, MPI_Ain incoming_count += recvcounts[j]; mpi_errno = MPIR_Sched_send(((char *) tmp_buf + send_offset * recvtype_sz), - curr_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + curr_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + recv_offset * recvtype_sz), - incoming_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, s); + incoming_count * recvtype_sz, MPI_BYTE, dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -178,7 +180,7 @@ int 
MPIR_Iallgatherv_intra_sched_recursive_doubling(const void *sendbuf, MPI_Ain * sent now. */ mpi_errno = MPIR_Sched_send(((char *) tmp_buf + offset), incoming_count * recvtype_sz, MPI_BYTE, - dst, comm_ptr, s); + dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -200,7 +202,7 @@ int MPIR_Iallgatherv_intra_sched_recursive_doubling(const void *sendbuf, MPI_Ain mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + offset * recvtype_sz), incoming_count * recvtype_sz, MPI_BYTE, - dst, comm_ptr, s); + dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); curr_count += incoming_count; diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c index 74e366d9d60..3572924050a 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c @@ -78,12 +78,12 @@ int MPIR_Iallgatherv_intra_sched_ring(const void *sendbuf, MPI_Aint sendcount, /* Communicate */ if (recvnow) { /* If there's no data to send, just do a recv call */ - mpi_errno = MPIR_Sched_recv(rbuf, recvnow, recvtype, left, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(rbuf, recvnow, recvtype, left, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); torecv -= recvnow; } if (sendnow) { /* If there's no data to receive, just do a send call */ - mpi_errno = MPIR_Sched_send(sbuf, sendnow, recvtype, right, comm_ptr, s); + mpi_errno = MPIR_Sched_send(sbuf, sendnow, recvtype, right, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); tosend -= sendnow; } diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c index 3cbc1785426..baf490396a7 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c @@ -218,8 +218,8 @@ MPIR_TSP_Iallgatherv_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* Recv at the exact 
location */ mpi_errno = MPIR_TSP_sched_irecv((char *) tmp_recvbuf + recv_index[idx] * recvtype_extent, - r_counts[i][j - 1], recvtype, src, tag, comm, sched, 0, NULL, - &vtx_id); + r_counts[i][j - 1], recvtype, src, tag, comm, coll_group, + sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); recv_id[idx] = vtx_id; @@ -228,7 +228,7 @@ MPIR_TSP_Iallgatherv_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* Send from the start of recv till the count amount of data */ mpi_errno = MPIR_TSP_sched_isend(tmp_recvbuf, s_counts[i][j - 1], recvtype, dst, tag, comm, - sched, n_invtcs, recv_id, &vtx_id); + coll_group, sched, n_invtcs, recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } n_invtcs += (k - 1); diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c index 12c7e0e0130..fbfba53e3d1 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c @@ -40,7 +40,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_data_exchange(int rank, int /* send my data to partner */ mpi_errno = MPIR_TSP_sched_isend(((char *) recvbuf + send_offset), send_count, recvtype, partner, - tag, comm, sched, 0, NULL, &vtx_id); + tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); /* calculate offset and count of the data to be received from the partner */ @@ -54,7 +54,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_data_exchange(int rank, int recv_offset, recv_count)); /* recv data from my partner */ mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), recv_count, recvtype, - partner, tag, comm, sched, 0, NULL, &vtx_id); + partner, tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -92,7 +92,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_step1(int step1_sendto, int buf_to_send = (void *) sendbuf; mpi_errno = MPIR_TSP_sched_isend(buf_to_send, recvcounts[rank], recvtype, 
step1_sendto, tag, comm, - sched, 0, NULL, &vtx_id); + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets the data from non-participating rank */ @@ -100,7 +100,7 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_step1(int step1_sendto, int mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), recvcounts[step1_recvfrom[i]], recvtype, step1_recvfrom[i], - tag, comm, sched, n_invtcs, invtx, &vtx_id); + tag, comm, coll_group, sched, n_invtcs, invtx, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } @@ -152,7 +152,8 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(int step1_sendto, int step2_n for (x = 0; x < count; x++) send_count += recvcounts[offset + x]; mpi_errno = MPIR_TSP_sched_isend(((char *) recvbuf + send_offset), send_count, recvtype, - nbr, tag, comm, sched, nrecvs, recv_id, &vtx_id); + nbr, tag, comm, coll_group, sched, nrecvs, recv_id, + &vtx_id); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, @@ -173,7 +174,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(int step1_sendto, int step2_n recv_count += recvcounts[offset + x]; mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + recv_offset), recv_count, recvtype, - nbr, tag, comm, sched, 0, NULL, &vtx_id); + nbr, tag, comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); recv_id[j * (k - 1) + i] = vtx_id; @@ -221,14 +222,14 @@ static int MPIR_TSP_Iallgatherv_sched_intra_recexch_step3(int step1_sendto, int if (step1_sendto != -1) { mpi_errno = - MPIR_TSP_sched_irecv(recvbuf, total_count, recvtype, step1_sendto, tag, comm, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_irecv(recvbuf, total_count, recvtype, step1_sendto, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < step1_nrecvs; i++) { mpi_errno = MPIR_TSP_sched_isend(recvbuf, total_count, recvtype, step1_recvfrom[i], - tag, comm, sched, nrecvs, recv_id, 
&vtx_id); + tag, comm, coll_group, sched, nrecvs, recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c index 129c1f67d94..f6252210917 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c @@ -94,8 +94,8 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcoun vtcs[0] = dtcopy_id[0]; mpi_errno = - MPIR_TSP_sched_isend(sbuf, recvcounts[send_rank], recvtype, dst, tag, comm, sched, - nvtcs, vtcs, &send_id[i % 3]); + MPIR_TSP_sched_isend(sbuf, recvcounts[send_rank], recvtype, dst, tag, comm, + coll_group, sched, nvtcs, vtcs, &send_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); nvtcs = 0; } else { @@ -104,8 +104,8 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcoun vtcs[1] = send_id[(i - 1) % 3]; mpi_errno = - MPIR_TSP_sched_isend(sbuf, recvcounts[send_rank], recvtype, dst, tag, comm, sched, - nvtcs, vtcs, &send_id[i % 3]); + MPIR_TSP_sched_isend(sbuf, recvcounts[send_rank], recvtype, dst, tag, comm, + coll_group, sched, nvtcs, vtcs, &send_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); if (i == 1) { @@ -121,8 +121,8 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcoun } mpi_errno = - MPIR_TSP_sched_irecv(rbuf, recvcounts[recv_rank], recvtype, src, tag, comm, sched, - nvtcs, vtcs, &recv_id[i % 3]); + MPIR_TSP_sched_irecv(rbuf, recvcounts[recv_rank], recvtype, src, tag, comm, coll_group, + sched, nvtcs, vtcs, &recv_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); /* Copy to correct position in recvbuf */ mpi_errno = diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c index 36278689541..9ea8c1da8dd 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c 
@@ -51,7 +51,8 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, count, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -60,7 +61,8 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re * doubling */ newrank = -1; } else { /* odd */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_buf, count, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -86,10 +88,10 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re /* Send the most current data, which is in recvbuf. Recv * into tmp_buf */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -121,10 +123,12 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re * (rank-1), the ranks who didn't participate above. 
*/ if (rank < 2 * rem) { if (rank % 2) { /* odd */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* even */ - mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, count, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c index 44cc9175f1a..627ccb2440a 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c @@ -57,7 +57,8 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, count, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -66,7 +67,8 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo * doubling */ newrank = -1; } else { /* odd */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_buf, count, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -133,11 +135,11 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo /* Send data from recvbuf. 
Recv into tmp_buf */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + disps[recv_idx] * extent), - recv_cnt, datatype, dst, comm_ptr, s); + recv_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[send_idx] * extent), - send_cnt, datatype, dst, comm_ptr, s); + send_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -192,11 +194,11 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo } mpi_errno = MPIR_Sched_recv(((char *) recvbuf + disps[recv_idx] * extent), - recv_cnt, datatype, dst, comm_ptr, s); + recv_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[send_idx] * extent), - send_cnt, datatype, dst, comm_ptr, s); + send_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -212,10 +214,12 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo * (rank-1), the ranks who didn't participate above. 
*/ if (rank < 2 * rem) { if (rank % 2) { /* odd */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* even */ - mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, count, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c index c498685986a..ab88975ffa4 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c @@ -152,8 +152,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, nbr = step2_nbrs[phase][i]; mpi_errno = - MPIR_TSP_sched_isend(tmp_buf, count, datatype, nbr, tag, comm, sched, nvtcs, vtcs, - &send_id[i]); + MPIR_TSP_sched_isend(tmp_buf, count, datatype, nbr, tag, comm, coll_group, sched, + nvtcs, vtcs, &send_id[i]); MPIR_ERR_CHECK(mpi_errno); if (rank > nbr) { myidx = i + 1; @@ -168,8 +168,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, vtcs[nvtcs++] = (counter == 0) ? reduce_id[k - 2] : reduce_id[counter - 1]; } mpi_errno = - MPIR_TSP_sched_irecv(nbr_buffer[buf], count, datatype, nbr, tag, comm, sched, nvtcs, - vtcs, &recv_id[buf]); + MPIR_TSP_sched_irecv(nbr_buffer[buf], count, datatype, nbr, tag, comm, coll_group, + sched, nvtcs, vtcs, &recv_id[buf]); MPIR_ERR_CHECK(mpi_errno); if (count != 0) { @@ -196,8 +196,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, vtcs[nvtcs++] = (counter == 0) ? 
reduce_id[k - 2] : reduce_id[counter - 1]; } mpi_errno = - MPIR_TSP_sched_irecv(nbr_buffer[buf], count, datatype, nbr, tag, comm, sched, nvtcs, - vtcs, &recv_id[buf]); + MPIR_TSP_sched_irecv(nbr_buffer[buf], count, datatype, nbr, tag, comm, coll_group, + sched, nvtcs, vtcs, &recv_id[buf]); MPIR_ERR_CHECK(mpi_errno); @@ -233,8 +233,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, * send the data to non-partcipating ranks */ if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ mpi_errno = - MPIR_TSP_sched_irecv(recvbuf, count, datatype, step1_sendto, tag, comm, sched, 0, NULL, - &vtx_id); + MPIR_TSP_sched_irecv(recvbuf, count, datatype, step1_sendto, tag, comm, coll_group, + sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { for (i = 0; i < step1_nrecvs; i++) { @@ -253,8 +253,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, vtcs[0] = reduce_id[k - 2]; } mpi_errno = - MPIR_TSP_sched_isend(recvbuf, count, datatype, step1_recvfrom[i], tag, comm, sched, - nvtcs, vtcs, &vtx_id); + MPIR_TSP_sched_isend(recvbuf, count, datatype, step1_recvfrom[i], tag, comm, + coll_group, sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c index 5053de12eed..67cfde36fb1 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c @@ -139,14 +139,14 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co * send the data to non-partcipating ranks */ if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ mpi_errno = - MPIR_TSP_sched_irecv(recvbuf, count, datatype, step1_sendto, tag, comm, sched, 1, - &sink_id, &vtx_id); + 
MPIR_TSP_sched_irecv(recvbuf, count, datatype, step1_sendto, tag, comm, coll_group, + sched, 1, &sink_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { for (i = 0; i < step1_nrecvs; i++) { mpi_errno = - MPIR_TSP_sched_isend(recvbuf, count, datatype, step1_recvfrom[i], tag, comm, sched, - nvtcs, recv_id, &vtx_id); + MPIR_TSP_sched_isend(recvbuf, count, datatype, step1_recvfrom[i], tag, comm, + coll_group, sched, nvtcs, recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c index 26902fd80b1..9e6632673fe 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recursive_exchange_common.c @@ -62,8 +62,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_step1(const void *sendbuf, else buf_to_send = sendbuf; mpi_errno = - MPIR_TSP_sched_isend(buf_to_send, count, datatype, step1_sendto, tag, comm, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(buf_to_send, count, datatype, step1_sendto, tag, comm, coll_group, + sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { /* Step 2 participating rank */ step1_recvbuf = *step1_recvbuf_ = @@ -89,8 +89,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_step1(const void *sendbuf, reduce_id[i - 1])); } mpi_errno = MPIR_TSP_sched_irecv(step1_recvbuf[i], count, datatype, - step1_recvfrom[i], tag, comm, sched, nvtcs, vtcs, - &recv_id[i]); + step1_recvfrom[i], tag, comm, coll_group, sched, nvtcs, + vtcs, &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); if (count != 0) { /* Reduce only if data is present */ /* setup reduce dependencies */ diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c b/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c index f8fec2f41f8..8cee22ff4f0 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c @@ -88,8 +88,8 @@ int 
MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI nvtcs = (i == 0) ? 0 : 1; vtcs = (i == 0) ? 0 : reduce_id[(i - 1) % 2]; mpi_errno = - MPIR_TSP_sched_irecv(tmpbuf, cnts[recv_rank], datatype, src, tag, comm, sched, nvtcs, - &vtcs, &recv_id); + MPIR_TSP_sched_irecv(tmpbuf, cnts[recv_rank], datatype, src, tag, comm, coll_group, + sched, nvtcs, &vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); mpi_errno = @@ -101,7 +101,8 @@ int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI mpi_errno = MPIR_TSP_sched_isend((char *) recvbuf + displs[send_rank] * extent, cnts[send_rank], - datatype, dst, tag, comm, sched, nvtcs, &vtcs, &vtx_id); + datatype, dst, tag, comm, coll_group, sched, nvtcs, &vtcs, + &vtx_id); MPIR_ERR_CHECK(mpi_errno); } MPIR_CHKLMEM_MALLOC(reduce_id, int *, 2 * sizeof(int), mpi_errno, "reduce_id", MPL_MEM_COLL); diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c b/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c index 0a363224777..b4f51312f1f 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c @@ -139,8 +139,9 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI nvtcs = 1; } - mpi_errno = MPIR_TSP_sched_irecv(recv_address, msgsize, datatype, child, tag, comm, - sched, nvtcs, vtcs, &recv_id[i]); + mpi_errno = + MPIR_TSP_sched_irecv(recv_address, msgsize, datatype, child, tag, comm, coll_group, + sched, nvtcs, vtcs, &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); /* Setup dependencies for reduction. 
Reduction depends on the corresponding recv to complete */ @@ -187,7 +188,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI if (rank != root) { mpi_errno = MPIR_TSP_sched_isend(reduce_address, msgsize, datatype, my_tree.parent, tag, comm, - sched, nvtcs, vtcs, &vtx_id); + coll_group, sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -201,7 +202,8 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI if (my_tree.parent != -1) { mpi_errno = MPIR_TSP_sched_irecv(reduce_address, msgsize, datatype, - my_tree.parent, tag, comm, sched, 1, &sink_id, &bcast_recv_id); + my_tree.parent, tag, comm, coll_group, sched, 1, &sink_id, + &bcast_recv_id); MPIR_ERR_CHECK(mpi_errno); } @@ -211,7 +213,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI vtcs[0] = bcast_recv_id; mpi_errno = MPIR_TSP_sched_imcast(reduce_address, msgsize, datatype, ut_int_array(my_tree.children), num_children, tag, - comm, sched, nvtcs, vtcs, &vtx_id); + comm, coll_group, sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c b/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c index 27ed068ce8a..271e0721b11 100644 --- a/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c +++ b/src/mpi/coll/ialltoall/ialltoall_inter_sched_pairwise_exchange.c @@ -54,9 +54,9 @@ int MPIR_Ialltoall_inter_sched_pairwise_exchange(const void *sendbuf, MPI_Aint s sendaddr = (char *) sendbuf + dst * sendcount * sendtype_extent; } - mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, coll_group, s); 
MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c index 3338aa5cd75..40ec46c3629 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c @@ -108,9 +108,9 @@ int MPIR_Ialltoall_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPIR_SCHED_BARRIER(s); /* now send and recv in parallel */ - mpi_errno = MPIR_Sched_send(tmp_buf, newtype_size, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(tmp_buf, newtype_size, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(recvbuf, 1, newtype, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvbuf, 1, newtype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c index abccd14055d..fba7405785c 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c @@ -61,10 +61,11 @@ int MPIR_Ialltoall_intra_sched_inplace(const void *sendbuf, MPI_Aint sendcount, MPIR_SCHED_BARRIER(s); /* now simultaneously send from tmp_buf and recv to recvbuf */ - mpi_errno = MPIR_Sched_send(tmp_buf, nbytes, MPI_BYTE, peer, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, nbytes, MPI_BYTE, peer, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(((char *) recvbuf + peer * recvcount * recvtype_extent), - recvcount, recvtype, peer, comm_ptr, s); + recvcount, recvtype, peer, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c index c6380474e80..4e2c0df6667 100644 --- 
a/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c @@ -63,10 +63,10 @@ int MPIR_Ialltoall_intra_sched_pairwise(const void *sendbuf, MPI_Aint sendcount, } mpi_errno = MPIR_Sched_send(((char *) sendbuf + dst * sendcount * sendtype_extent), - sendcount, sendtype, dst, comm_ptr, s); + sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(((char *) recvbuf + src * recvcount * recvtype_extent), - recvcount, recvtype, src, comm_ptr, s); + recvcount, recvtype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c index a2aa1153099..a772b406270 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c @@ -45,14 +45,14 @@ int MPIR_Ialltoall_intra_sched_permuted_sendrecv(const void *sendbuf, MPI_Aint s for (i = 0; i < ss; i++) { dst = (rank + i + ii) % comm_size; mpi_errno = MPIR_Sched_recv(((char *) recvbuf + dst * recvcount * recvtype_extent), - recvcount, recvtype, dst, comm_ptr, s); + recvcount, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < ss; i++) { dst = (rank - i - ii + comm_size) % comm_size; mpi_errno = MPIR_Sched_send(((char *) sendbuf + dst * sendcount * sendtype_extent), - sendcount, sendtype, dst, comm_ptr, s); + sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c b/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c index c86987531f2..c6e09c66963 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c @@ -287,7 +287,7 @@ MPIR_TSP_Ialltoall_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, mpi_errno 
= MPIR_TSP_sched_isend(tmp_sbuf[i][j - 1], packsize, MPI_BYTE, dst, tag, - comm, sched, 1, &packids[j - 1], &sendids[j - 1]); + comm, coll_group, sched, 1, &packids[j - 1], &sendids[j - 1]); MPIR_ERR_CHECK(mpi_errno); if (i != 0 && buffer_per_phase == 0) { /* this dependency holds only when we don't have dedicated recv buffer per phase */ @@ -296,7 +296,7 @@ MPIR_TSP_Ialltoall_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, } mpi_errno = MPIR_TSP_sched_irecv(tmp_rbuf[i][j - 1], packsize, MPI_BYTE, - src, tag, comm, sched, recv_ninvtcs, recv_invtcs, + src, tag, comm, coll_group, sched, recv_ninvtcs, recv_invtcs, &recvids[j - 1]); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c b/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c index 7a1b8515c45..cfede050026 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c @@ -131,8 +131,8 @@ int MPIR_TSP_Ialltoall_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, } mpi_errno = - MPIR_TSP_sched_isend((char *) sbuf, size * recvcount, recvtype, dst, tag, comm, sched, - nvtcs, vtcs, &send_id[i % 3]); + MPIR_TSP_sched_isend((char *) sbuf, size * recvcount, recvtype, dst, tag, comm, + coll_group, sched, nvtcs, vtcs, &send_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); /* schedule recv */ if (i == 0) @@ -149,8 +149,8 @@ int MPIR_TSP_Ialltoall_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, } mpi_errno = - MPIR_TSP_sched_irecv((char *) rbuf, size * recvcount, recvtype, src, tag, comm, sched, - nvtcs, vtcs, &recv_id[i % 3]); + MPIR_TSP_sched_irecv((char *) rbuf, size * recvcount, recvtype, src, tag, comm, + coll_group, sched, nvtcs, vtcs, &recv_id[i % 3]); MPIR_ERR_CHECK(mpi_errno); /* destination offset of the copy */ diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c b/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c index 86371daebed..7efd13aed44 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c +++ 
b/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c @@ -110,13 +110,15 @@ int MPIR_TSP_Ialltoall_sched_intra_scattered(const void *sendbuf, MPI_Aint sendc src = (rank + i) % size; mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + src * recvcount * recvtype_extent, - recvcount, recvtype, src, tag, comm, sched, 0, NULL, &recv_id[i]); + recvcount, recvtype, src, tag, comm, coll_group, sched, 0, NULL, + &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); dst = (rank - i + size) % size; mpi_errno = MPIR_TSP_sched_isend((char *) data_buf + dst * sendcount * sendtype_extent, - sendcount, sendtype, dst, tag, comm, sched, 0, NULL, &send_id[i]); + sendcount, sendtype, dst, tag, comm, coll_group, sched, 0, NULL, + &send_id[i]); MPIR_ERR_CHECK(mpi_errno); } @@ -137,15 +139,15 @@ int MPIR_TSP_Ialltoall_sched_intra_scattered(const void *sendbuf, MPI_Aint sendc src = (rank + i + j) % size; mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + src * recvcount * recvtype_extent, - recvcount, recvtype, src, tag, comm, sched, 1, &invtcs, - &recv_id[(i + j) % bblock]); + recvcount, recvtype, src, tag, comm, coll_group, sched, 1, + &invtcs, &recv_id[(i + j) % bblock]); MPIR_ERR_CHECK(mpi_errno); dst = (rank - i - j + size) % size; mpi_errno = MPIR_TSP_sched_isend((char *) data_buf + dst * sendcount * sendtype_extent, - sendcount, sendtype, dst, tag, comm, sched, 1, &invtcs, - &send_id[(i + j) % bblock]); + sendcount, sendtype, dst, tag, comm, coll_group, sched, 1, + &invtcs, &send_id[(i + j) % bblock]); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c b/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c index 48f41bbb411..6c933b56841 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_inter_sched_pairwise_exchange.c @@ -67,9 +67,9 @@ int MPIR_Ialltoallv_inter_sched_pairwise_exchange(const void *sendbuf, const MPI if (recvcount * recvtype_size == 0) src 
= MPI_PROC_NULL; - mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c index ce1b6706bf6..fb42149f818 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c @@ -46,7 +46,8 @@ int MPIR_Ialltoallv_intra_sched_blocked(const void *sendbuf, const MPI_Aint send dst = (rank + i + ii) % comm_size; if (recvcounts[dst] && recvtype_size) { mpi_errno = MPIR_Sched_recv((char *) recvbuf + rdispls[dst] * recv_extent, - recvcounts[dst], recvtype, dst, comm_ptr, s); + recvcounts[dst], recvtype, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } } @@ -55,7 +56,8 @@ int MPIR_Ialltoallv_intra_sched_blocked(const void *sendbuf, const MPI_Aint send dst = (rank - i - ii + comm_size) % comm_size; if (sendcounts[dst] && sendtype_size) { mpi_errno = MPIR_Sched_send((char *) sendbuf + sdispls[dst] * send_extent, - sendcounts[dst], sendtype, dst, comm_ptr, s); + sendcounts[dst], sendtype, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c index 4c51b670d06..d08ddfa2be4 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c @@ -58,10 +58,11 @@ int MPIR_Ialltoallv_intra_sched_inplace(const void *sendbuf, const MPI_Aint send dst = i; mpi_errno = MPIR_Sched_send(((char 
*) recvbuf + rdispls[dst] * recvtype_extent), - recvcounts[dst], recvtype, dst, comm_ptr, s); + recvcounts[dst], recvtype, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(tmp_buf, recvcounts[dst] * recvtype_sz, MPI_BYTE, - dst, comm_ptr, s); + dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c index 159fc019602..0f0d1cba5ac 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c @@ -56,8 +56,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_blocked(const void *sendbuf, const MPI_Aint dst = (rank + j + i) % nranks; if (recvcounts[dst] && recvtype_size) { mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[dst] * recv_extent, - recvcounts[dst], recvtype, dst, tag, comm, sched, - 0, NULL, &vtx_id); + recvcounts[dst], recvtype, dst, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } @@ -66,8 +66,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_blocked(const void *sendbuf, const MPI_Aint dst = (rank - j - i + nranks) % nranks; if (sendcounts[dst] && sendtype_size) { mpi_errno = MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst] * send_extent, - sendcounts[dst], sendtype, dst, tag, comm, sched, - 0, NULL, &vtx_id); + sendcounts[dst], sendtype, dst, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c index 7fc0e5ccdfa..afd3ba5108e 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c @@ -52,12 +52,12 @@ int MPIR_TSP_Ialltoallv_sched_intra_inplace(const void *sendbuf, const MPI_Aint vtcs[0] = dtcopy_id; mpi_errno = MPIR_TSP_sched_isend((char *) recvbuf + rdispls[dst] * recv_extent, - recvcounts[dst], 
recvtype, dst, tag, comm, + recvcounts[dst], recvtype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, &send_id); MPIR_ERR_CHECK(mpi_errno); mpi_errno = - MPIR_TSP_sched_irecv(tmp_buf, recvcounts[dst], recvtype, dst, tag, comm, + MPIR_TSP_sched_irecv(tmp_buf, recvcounts[dst], recvtype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c index d5339566593..62c3cf55e54 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c @@ -63,15 +63,15 @@ int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const MPI_Ain src = (rank + i) % size; mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[src] * recvtype_extent, - recvcounts[src], recvtype, src, tag, comm, sched, 0, NULL, - &recv_id[i]); + recvcounts[src], recvtype, src, tag, comm, coll_group, sched, 0, + NULL, &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); dst = (rank - i + size) % size; mpi_errno = MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst] * sendtype_extent, - sendcounts[dst], sendtype, dst, tag, comm, sched, 0, NULL, - &send_id[i]); + sendcounts[dst], sendtype, dst, tag, comm, coll_group, sched, 0, + NULL, &send_id[i]); MPIR_ERR_CHECK(mpi_errno); } @@ -93,15 +93,15 @@ int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const MPI_Ain src = (rank + i + j) % size; mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[src] * recvtype_extent, - recvcounts[src], recvtype, src, tag, comm, sched, 1, &invtcs, - &recv_id[(i + j) % bblock]); + recvcounts[src], recvtype, src, tag, comm, coll_group, sched, + 1, &invtcs, &recv_id[(i + j) % bblock]); MPIR_ERR_CHECK(mpi_errno); dst = (rank - i - j + size) % size; mpi_errno = MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst] * sendtype_extent, - sendcounts[dst], sendtype, dst, tag, comm, sched, 1, &invtcs, - &send_id[(i + 
j) % bblock]); + sendcounts[dst], sendtype, dst, tag, comm, coll_group, sched, + 1, &invtcs, &send_id[(i + j) % bblock]); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c b/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c index 1b2b35cd414..9ff9cfb9868 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_inter_sched_pairwise_exchange.c @@ -60,10 +60,10 @@ int MPIR_Ialltoallw_inter_sched_pairwise_exchange(const void *sendbuf, const MPI sendtype = sendtypes[dst]; } - mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(sendaddr, sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ - mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvaddr, recvcount, recvtype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c index 13745bf24f9..e203f119915 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c @@ -53,7 +53,8 @@ int MPIR_Ialltoallw_intra_sched_blocked(const void *sendbuf, const MPI_Aint send MPIR_Datatype_get_size_macro(recvtypes[dst], type_size); if (type_size) { mpi_errno = MPIR_Sched_recv((char *) recvbuf + rdispls[dst], - recvcounts[dst], recvtypes[dst], dst, comm_ptr, s); + recvcounts[dst], recvtypes[dst], dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } @@ -66,7 +67,8 @@ int MPIR_Ialltoallw_intra_sched_blocked(const void *sendbuf, const MPI_Aint send MPIR_Datatype_get_size_macro(sendtypes[dst], type_size); if (type_size) { mpi_errno = MPIR_Sched_send((char *) sendbuf + sdispls[dst], - 
sendcounts[dst], sendtypes[dst], dst, comm_ptr, s); + sendcounts[dst], sendtypes[dst], dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c index eab0e1cb524..1f865665ad9 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c @@ -67,10 +67,11 @@ int MPIR_Ialltoallw_intra_sched_inplace(const void *sendbuf, const MPI_Aint send MPIR_Datatype_get_size_macro(recvtypes[dst], recvtype_sz); mpi_errno = MPIR_Sched_send(((char *) recvbuf + rdispls[dst]), - recvcounts[dst], recvtypes[dst], dst, comm_ptr, s); + recvcounts[dst], recvtypes[dst], dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(tmp_buf, recvcounts[dst] * recvtype_sz, MPI_BYTE, - dst, comm_ptr, s); + dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c b/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c index 04b46f2ff57..9eaf904d2bf 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c @@ -48,7 +48,7 @@ int MPIR_TSP_Ialltoallw_sched_intra_blocked(const void *sendbuf, const MPI_Aint if (recvtype_size) { mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[dst], recvcounts[dst], recvtypes[dst], dst, tag, - comm, sched, 0, NULL, &vtx_id); + comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } @@ -61,7 +61,7 @@ int MPIR_TSP_Ialltoallw_sched_intra_blocked(const void *sendbuf, const MPI_Aint if (sendtype_size) { mpi_errno = MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst], sendcounts[dst], sendtypes[dst], dst, tag, - comm, sched, 0, NULL, &vtx_id); + comm, coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c 
b/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c index 22b1dde0e22..e3984434c71 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c @@ -62,12 +62,12 @@ int MPIR_TSP_Ialltoallw_sched_intra_inplace(const void *sendbuf, const MPI_Aint adj_tmp_buf = (void *) ((char *) tmp_buf - true_lb); mpi_errno = MPIR_TSP_sched_isend((char *) recvbuf + rdispls[dst], - recvcounts[dst], recvtypes[dst], dst, tag, comm, sched, - nvtcs, vtcs, &send_id); + recvcounts[dst], recvtypes[dst], dst, tag, comm, + coll_group, sched, nvtcs, vtcs, &send_id); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_TSP_sched_irecv(adj_tmp_buf, recvcounts[dst], recvtypes[dst], dst, tag, comm, - sched, nvtcs, vtcs, &recv_id); + coll_group, sched, nvtcs, vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); nvtcs = 2; diff --git a/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c b/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c index 8ed47527124..eafee4f7a5f 100644 --- a/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c @@ -34,10 +34,10 @@ int MPIR_Ibarrier_intra_sched_recursive_doubling(MPIR_Comm * comm_ptr, int coll_ dst = (rank + mask) % size; src = (rank - mask + size) % size; - mpi_errno = MPIR_Sched_send(NULL, 0, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(NULL, 0, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(NULL, 0, MPI_BYTE, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(NULL, 0, MPI_BYTE, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); diff --git a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c index 33145105d84..2dd5ceb3668 100644 --- a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c +++ b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c @@ -51,15 +51,15 @@ int 
MPIR_TSP_Ibarrier_sched_intra_k_dissemination(MPIR_Comm * comm, int coll_gro MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "dissem barrier - scheduling recv from %d\n", from)); mpi_errno = - MPIR_TSP_sched_irecv(NULL, 0, MPI_BYTE, from, tag, comm, sched, 0, NULL, + MPIR_TSP_sched_irecv(NULL, 0, MPI_BYTE, from, tag, comm, coll_group, sched, 0, NULL, &recv_ids[i * (k - 1) + j - 1]); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "dissem barrier - scheduling send to %d\n", to)); mpi_errno = - MPIR_TSP_sched_isend(NULL, 0, MPI_BYTE, to, tag, comm, sched, i * (k - 1), recv_ids, - &vtx_id); + MPIR_TSP_sched_isend(NULL, 0, MPI_BYTE, to, tag, comm, coll_group, sched, + i * (k - 1), recv_ids, &vtx_id); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, diff --git a/src/mpi/coll/ibcast/ibcast.h b/src/mpi/coll/ibcast/ibcast.h index cc21d5645ff..e13e3f3f1c2 100644 --- a/src/mpi/coll/ibcast/ibcast.h +++ b/src/mpi/coll/ibcast/ibcast.h @@ -20,7 +20,7 @@ int MPII_Ibcast_sched_test_length(MPIR_Comm * comm, int tag, void *state); int MPII_Ibcast_sched_test_curr_length(MPIR_Comm * comm, int tag, void *state); int MPII_Ibcast_sched_init_length(MPIR_Comm * comm, int tag, void *state); int MPII_Ibcast_sched_add_length(MPIR_Comm * comm, int tag, void *state); -int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, MPI_Aint nbytes, - MPIR_Sched_t s); +int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, int coll_group, + MPI_Aint nbytes, MPIR_Sched_t s); #endif /* IBCAST_H_INCLUDED */ diff --git a/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c b/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c index dc79d1f5282..2af50df7ded 100644 --- a/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c +++ b/src/mpi/coll/ibcast/ibcast_inter_sched_flat.c @@ -21,12 +21,12 @@ int MPIR_Ibcast_inter_sched_flat(void *buffer, MPI_Aint count, MPI_Datatype data mpi_errno = MPI_SUCCESS; } else if (root == 
MPI_ROOT) { /* root sends to rank 0 on remote group and returns */ - mpi_errno = MPIR_Sched_send(buffer, count, datatype, 0, comm_ptr, s); + mpi_errno = MPIR_Sched_send(buffer, count, datatype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { /* remote group. rank 0 on remote group receives from root */ if (comm_ptr->rank == 0) { - mpi_errno = MPIR_Sched_recv(buffer, count, datatype, root, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(buffer, count, datatype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c b/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c index 00943d9d824..32c8ade7e6f 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c @@ -92,10 +92,10 @@ int MPIR_Ibcast_intra_sched_binomial(void *buffer, MPI_Aint count, MPI_Datatype src += comm_size; if (!is_contig) mpi_errno = MPIR_Sched_recv_status(tmp_buf, nbytes, MPI_BYTE, src, - comm_ptr, &ibcast_state->status, s); + comm_ptr, coll_group, &ibcast_state->status, s); else mpi_errno = MPIR_Sched_recv_status(buffer, count, datatype, src, - comm_ptr, &ibcast_state->status, s); + comm_ptr, coll_group, &ibcast_state->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -125,9 +125,10 @@ int MPIR_Ibcast_intra_sched_binomial(void *buffer, MPI_Aint count, MPI_Datatype if (dst >= comm_size) dst -= comm_size; if (!is_contig) - mpi_errno = MPIR_Sched_send(tmp_buf, nbytes, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, nbytes, MPI_BYTE, dst, comm_ptr, coll_group, s); else - mpi_errno = MPIR_Sched_send(buffer, count, datatype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(buffer, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* NOTE: This is departure from MPIR_Bcast_intra_binomial. 
A true analog diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c index 9f26d944e2d..e088a7d0932 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c @@ -110,7 +110,7 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, M } - mpi_errno = MPII_Iscatter_for_bcast_sched(tmp_buf, root, comm_ptr, nbytes, s); + mpi_errno = MPII_Iscatter_for_bcast_sched(tmp_buf, root, comm_ptr, coll_group, nbytes, s); MPIR_ERR_CHECK(mpi_errno); MPI_Aint scatter_size, curr_size, incoming_count; @@ -162,12 +162,13 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, M incoming_count = 0; mpi_errno = MPIR_Sched_send(((char *) tmp_buf + send_offset), - curr_size, MPI_BYTE, dst, comm_ptr, s); + curr_size, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier */ mpi_errno = MPIR_Sched_recv_status(((char *) tmp_buf + recv_offset), incoming_count, - MPI_BYTE, dst, comm_ptr, &ibcast_state->status, s); + MPI_BYTE, dst, comm_ptr, coll_group, + &ibcast_state->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); mpi_errno = MPIR_Sched_cb(&MPII_Ibcast_sched_add_length, ibcast_state, s); @@ -228,7 +229,8 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, M * receive. that's the amount of data to be * sent now. 
*/ mpi_errno = MPIR_Sched_send(((char *) tmp_buf + offset), - incoming_count, MPI_BYTE, dst, comm_ptr, s); + incoming_count, MPI_BYTE, dst, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -248,7 +250,7 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, M * whose data we don't have */ mpi_errno = MPIR_Sched_recv_status(((char *) tmp_buf + offset), incoming_count, MPI_BYTE, dst, comm_ptr, - &ibcast_state->status, s); + coll_group, &ibcast_state->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); mpi_errno = MPIR_Sched_cb(&MPII_Ibcast_sched_add_length, ibcast_state, s); diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c index b3f5103ce0c..3d78dff1851 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c @@ -79,7 +79,7 @@ int MPIR_Ibcast_intra_sched_scatter_ring_allgather(void *buffer, MPI_Aint count, } } - mpi_errno = MPII_Iscatter_for_bcast_sched(tmp_buf, root, comm_ptr, nbytes, s); + mpi_errno = MPII_Iscatter_for_bcast_sched(tmp_buf, root, comm_ptr, coll_group, nbytes, s); MPIR_ERR_CHECK(mpi_errno); MPI_Aint scatter_size, curr_size; @@ -120,11 +120,11 @@ int MPIR_Ibcast_intra_sched_scatter_ring_allgather(void *buffer, MPI_Aint count, right_disp = rel_j * scatter_size; mpi_errno = MPIR_Sched_send(((char *) tmp_buf + right_disp), - right_count, MPI_BYTE, right, comm_ptr, s); + right_count, MPI_BYTE, right, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ mpi_errno = MPIR_Sched_recv_status(((char *) tmp_buf + left_disp), - left_count, MPI_BYTE, left, comm_ptr, + left_count, MPI_BYTE, left, comm_ptr, coll_group, &ibcast_state->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c 
b/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c index 810e12a1dbf..028004f4863 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c @@ -48,13 +48,14 @@ int MPIR_Ibcast_intra_sched_smp(void *buffer, MPI_Aint count, MPI_Datatype datat /* send to intranode-rank 0 on the root's node */ if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) > 0) { /* is not the node root (0) *//* and is on our node (!-1) */ if (root == comm_ptr->rank) { - mpi_errno = MPIR_Sched_send(buffer, count, datatype, 0, comm_ptr->node_comm, s); + mpi_errno = + MPIR_Sched_send(buffer, count, datatype, 0, comm_ptr->node_comm, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else if (0 == comm_ptr->node_comm->rank) { mpi_errno = MPIR_Sched_recv_status(buffer, count, datatype, MPIR_Get_intranode_rank(comm_ptr, root), comm_ptr->node_comm, - &ibcast_state->status, s); + coll_group, &ibcast_state->status, s); MPIR_ERR_CHECK(mpi_errno); #ifdef HAVE_ERROR_CHECKING MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c index 58c9eba80d5..7324a96100e 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c @@ -150,14 +150,14 @@ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count ibcast_state->n_bytes = recv_size; mpi_errno = MPIR_TSP_sched_irecv_status((char *) tmp_buf + displs[rank], recv_size, MPI_BYTE, - my_tree.parent, tag, comm, &ibcast_state->status, sched, 0, - NULL, &recv_id); + my_tree.parent, tag, comm, coll_group, + &ibcast_state->status, sched, 0, NULL, &recv_id); MPIR_TSP_sched_cb(&MPII_Ibcast_sched_test_length, ibcast_state, sched, 1, &recv_id, &vtx_id); #else mpi_errno = MPIR_TSP_sched_irecv((char *) tmp_buf + displs[rank], recv_size, MPI_BYTE, - my_tree.parent, tag, comm, sched, 0, NULL, &recv_id); + my_tree.parent, tag, comm, coll_group, 
sched, 0, NULL, &recv_id); #endif MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "rank:%d posts recv", rank)); @@ -174,8 +174,8 @@ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count mpi_errno = MPIR_TSP_sched_isend((char *) tmp_buf + displs[child], child_subtree_size[i], MPI_BYTE, - child, tag, comm, sched, num_send_dependencies, &recv_id, - &vtx_id); + child, tag, comm, coll_group, sched, num_send_dependencies, + &recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ibcast/ibcast_tsp_tree.c b/src/mpi/coll/ibcast/ibcast_tsp_tree.c index b56b3b76cdf..44e76e4caef 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_tree.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_tree.c @@ -70,7 +70,7 @@ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype #ifdef HAVE_ERROR_CHECKING mpi_errno = MPIR_TSP_sched_irecv_status((char *) buffer + offset * extent, msgsize, - datatype, my_tree.parent, tag, comm, + datatype, my_tree.parent, tag, comm, coll_group, &ibcast_state->status, sched, 0, NULL, &recv_id); MPIR_ERR_CHECK(mpi_errno); MPIR_TSP_sched_cb(&MPII_Ibcast_sched_test_length, ibcast_state, sched, 1, &recv_id, @@ -78,7 +78,8 @@ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype #else mpi_errno = MPIR_TSP_sched_irecv((char *) buffer + offset * extent, msgsize, datatype, - my_tree.parent, tag, comm, sched, 0, NULL, &recv_id); + my_tree.parent, tag, comm, coll_group, sched, 0, NULL, + &recv_id); MPIR_ERR_CHECK(mpi_errno); #endif } @@ -87,8 +88,8 @@ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype /* Multicast data to the children */ mpi_errno = MPIR_TSP_sched_imcast((char *) buffer + offset * extent, msgsize, datatype, ut_int_array(my_tree.children), num_children, tag, - comm, sched, (my_tree.parent != -1) ? 1 : 0, &recv_id, - &vtx_id); + comm, coll_group, sched, + (my_tree.parent != -1) ? 
1 : 0, &recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } offset += msgsize; diff --git a/src/mpi/coll/ibcast/ibcast_utils.c b/src/mpi/coll/ibcast/ibcast_utils.c index 9bfaf925b3a..a600a638cc1 100644 --- a/src/mpi/coll/ibcast/ibcast_utils.c +++ b/src/mpi/coll/ibcast/ibcast_utils.c @@ -68,8 +68,8 @@ int MPII_Ibcast_sched_add_length(MPIR_Comm * comm, int tag, void *state) /* This is a binomial scatter operation, but it does *not* take * typical scatter arguments. At the moment this function always * scatters a buffer of nbytes starting at tmp_buf address. */ -int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, MPI_Aint nbytes, - MPIR_Sched_t s) +int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, int coll_group, + MPI_Aint nbytes, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; int rank, comm_size, src, dst; @@ -110,7 +110,7 @@ int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, if (recv_size > 0) { mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + relative_rank * scatter_size), - recv_size, MPI_BYTE, src, comm_ptr, s); + recv_size, MPI_BYTE, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -135,7 +135,7 @@ int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, dst -= comm_size; mpi_errno = MPIR_Sched_send(((char *) tmp_buf + scatter_size * (relative_rank + mask)), - send_size, MPI_BYTE, dst, comm_ptr, s); + send_size, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); curr_size -= send_size; diff --git a/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c b/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c index 10ff8466e68..ac42822de90 100644 --- a/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c @@ -90,10 +90,11 @@ int MPIR_Iexscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvb dst = 
rank ^ mask; if (dst < comm_size) { /* Send partial_scan to dst. Recv into tmp_buf */ - mpi_errno = MPIR_Sched_send(partial_scan, count, datatype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(partial_scan, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/igather/igather_inter_sched_long.c b/src/mpi/coll/igather/igather_inter_sched_long.c index fe38bc5694e..6da72eb5823 100644 --- a/src/mpi/coll/igather/igather_inter_sched_long.c +++ b/src/mpi/coll/igather/igather_inter_sched_long.c @@ -29,11 +29,11 @@ int MPIR_Igather_inter_sched_long(const void *sendbuf, MPI_Aint sendcount, MPI_D for (i = 0; i < remote_size; i++) { mpi_errno = MPIR_Sched_recv(((char *) recvbuf + recvcount * i * extent), - recvcount, recvtype, i, comm_ptr, s); + recvcount, recvtype, i, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } else { - mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, root, comm_ptr, s); + mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/igather/igather_inter_sched_short.c b/src/mpi/coll/igather/igather_inter_sched_short.c index e2d684b827e..f6d9b55f412 100644 --- a/src/mpi/coll/igather/igather_inter_sched_short.c +++ b/src/mpi/coll/igather/igather_inter_sched_short.c @@ -30,7 +30,8 @@ int MPIR_Igather_inter_sched_short(const void *sendbuf, MPI_Aint sendcount, MPI_ mpi_errno = MPI_SUCCESS; } else if (root == MPI_ROOT) { /* root receives data from rank 0 on remote group */ - mpi_errno = MPIR_Sched_recv(recvbuf, recvcount * remote_size, recvtype, 0, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, recvcount * remote_size, recvtype, 0, comm_ptr, coll_group, s); 
MPIR_ERR_CHECK(mpi_errno); } else { /* remote group. Rank 0 allocates temporary buffer, does @@ -66,7 +67,7 @@ int MPIR_Igather_inter_sched_short(const void *sendbuf, MPI_Aint sendcount, MPI_ if (rank == 0) { mpi_errno = MPIR_Sched_send(tmp_buf, sendcount * local_size * sendtype_sz, MPI_BYTE, - root, comm_ptr, s); + root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/igather/igather_intra_sched_binomial.c b/src/mpi/coll/igather/igather_intra_sched_binomial.c index d0a77292f85..42b0f6d8550 100644 --- a/src/mpi/coll/igather/igather_intra_sched_binomial.c +++ b/src/mpi/coll/igather/igather_intra_sched_binomial.c @@ -127,14 +127,15 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, char *rp = (char *) recvbuf + (((rank + mask) % comm_size) * recvcount * extent); mpi_errno = - MPIR_Sched_recv(rp, (recvblks * recvcount), recvtype, src, comm_ptr, s); + MPIR_Sched_recv(rp, (recvblks * recvcount), recvtype, src, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) { mpi_errno = MPIR_Sched_recv(tmp_buf, (recvblks * nbytes), MPI_BYTE, src, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -153,7 +154,8 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, mpi_errno = MPIR_Type_commit_impl(&tmp_type); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(recvbuf, 1, tmp_type, src, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, 1, tmp_type, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -176,7 +178,7 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, offset = (mask - 1) * nbytes; mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + offset), (recvblks * nbytes), - MPI_BYTE, 
src, comm_ptr, s); + MPI_BYTE, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -189,12 +191,14 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, if (!tmp_buf_size) { /* leaf nodes send directly from sendbuf */ - mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sendbuf, sendcount, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); } else if (nbytes < MPIR_CVAR_GATHER_VSMALL_MSG_SIZE) { - mpi_errno = MPIR_Sched_send(tmp_buf, curr_cnt, MPI_BYTE, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_buf, curr_cnt, MPI_BYTE, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -219,7 +223,7 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, mpi_errno = MPIR_Type_commit_impl(&tmp_type); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_send(MPI_BOTTOM, 1, tmp_type, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(MPI_BOTTOM, 1, tmp_type, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/igather/igather_tsp_tree.c b/src/mpi/coll/igather/igather_tsp_tree.c index da395125e68..b08d2021384 100644 --- a/src/mpi/coll/igather/igather_tsp_tree.c +++ b/src/mpi/coll/igather/igather_tsp_tree.c @@ -135,8 +135,8 @@ int MPIR_TSP_Igather_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, /* Leaf nodes send to parent */ if (num_children == 0) { mpi_errno = - MPIR_TSP_sched_isend(tmp_buf, sendcount, sendtype, my_tree.parent, tag, comm, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(tmp_buf, sendcount, sendtype, my_tree.parent, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "rank:%d posts 
recv\n", rank)); } else { @@ -160,13 +160,14 @@ int MPIR_TSP_Igather_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, mpi_errno = MPIR_TSP_sched_irecv((char *) tmp_buf + child_data_offset[i] * recvtype_extent, child_subtree_size[i] * recvcount, recvtype, child, tag, comm, - sched, num_dependencies, &dtcopy_id, &recv_id[i]); + coll_group, sched, num_dependencies, &dtcopy_id, &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); } if (my_tree.parent != -1) { mpi_errno = MPIR_TSP_sched_isend(tmp_buf, recv_size, recvtype, my_tree.parent, - tag, comm, sched, num_children, recv_id, &vtx_id); + tag, comm, coll_group, sched, num_children, recv_id, + &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c b/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c index b0154a27b28..47148214fa6 100644 --- a/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c +++ b/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c @@ -45,7 +45,8 @@ int MPIR_Igatherv_allcomm_sched_linear(const void *sendbuf, MPI_Aint sendcount, } } else { mpi_errno = MPIR_Sched_recv(((char *) recvbuf + displs[i] * extent), - recvcounts[i], recvtype, i, comm_ptr, s); + recvcounts[i], recvtype, i, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } } @@ -53,7 +54,8 @@ int MPIR_Igatherv_allcomm_sched_linear(const void *sendbuf, MPI_Aint sendcount, } else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. 
case, non-root nodes on remote side */ if (sendcount) { - mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, root, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sendbuf, sendcount, sendtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/igatherv/igatherv_tsp_linear.c b/src/mpi/coll/igatherv/igatherv_tsp_linear.c index cd4e6e6b6ac..ffa878af186 100644 --- a/src/mpi/coll/igatherv/igatherv_tsp_linear.c +++ b/src/mpi/coll/igatherv/igatherv_tsp_linear.c @@ -58,7 +58,7 @@ int MPIR_TSP_Igatherv_sched_allcomm_linear(const void *sendbuf, MPI_Aint sendcou } else { mpi_errno = MPIR_TSP_sched_irecv(((char *) recvbuf + displs[i] * extent), recvcounts[i], recvtype, i, tag, comm_ptr, - sched, 0, NULL, &vtx_id); + coll_group, sched, 0, NULL, &vtx_id); } MPIR_ERR_CHECK(mpi_errno); } @@ -67,7 +67,7 @@ int MPIR_TSP_Igatherv_sched_allcomm_linear(const void *sendbuf, MPI_Aint sendcou /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */ if (sendcount) { mpi_errno = MPIR_TSP_sched_isend(sendbuf, sendcount, sendtype, root, tag, comm_ptr, - sched, 0, NULL, &vtx_id); + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_allcomm_sched_linear.c b/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_allcomm_sched_linear.c index c4037e5b4e2..6d4a0dd84b4 100644 --- a/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_allcomm_sched_linear.c +++ b/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_allcomm_sched_linear.c @@ -36,13 +36,15 @@ int MPIR_Ineighbor_allgather_allcomm_sched_linear(const void *sendbuf, MPI_Aint MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { - mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, dsts[k], comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sendbuf, sendcount, sendtype, dsts[k], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } for (l = 0; l < indegree; ++l) { char *rb = 
((char *) recvbuf) + l * recvcount * recvtype_extent; - mpi_errno = MPIR_Sched_recv(rb, recvcount, recvtype, srcs[l], comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(rb, recvcount, recvtype, srcs[l], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c b/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c index 237275ce321..6f3218c2c62 100644 --- a/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c +++ b/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c @@ -43,16 +43,16 @@ int MPIR_TSP_Ineighbor_allgather_sched_allcomm_linear(const void *sendbuf, MPI_A for (k = 0; k < outdegree; ++k) { mpi_errno = - MPIR_TSP_sched_isend(sendbuf, sendcount, sendtype, dsts[k], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(sendbuf, sendcount, sendtype, dsts[k], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (l = 0; l < indegree; ++l) { char *rb = ((char *) recvbuf) + l * recvcount * recvtype_extent; mpi_errno = - MPIR_TSP_sched_irecv(rb, recvcount, recvtype, srcs[l], tag, comm_ptr, sched, 0, NULL, - &vtx_id); + MPIR_TSP_sched_irecv(rb, recvcount, recvtype, srcs[l], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_allcomm_sched_linear.c b/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_allcomm_sched_linear.c index b4c51feb741..d61720abedb 100644 --- a/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_allcomm_sched_linear.c +++ b/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_allcomm_sched_linear.c @@ -37,13 +37,15 @@ int MPIR_Ineighbor_allgatherv_allcomm_sched_linear(const void *sendbuf, MPI_Aint MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { - mpi_errno = MPIR_Sched_send(sendbuf, sendcount, sendtype, dsts[k], comm_ptr, s); + mpi_errno = + 
MPIR_Sched_send(sendbuf, sendcount, sendtype, dsts[k], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } for (l = 0; l < indegree; ++l) { char *rb = ((char *) recvbuf) + displs[l] * recvtype_extent; - mpi_errno = MPIR_Sched_recv(rb, recvcounts[l], recvtype, srcs[l], comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(rb, recvcounts[l], recvtype, srcs[l], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c b/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c index f33f1cf7693..c8a3d70867e 100644 --- a/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c +++ b/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c @@ -44,16 +44,16 @@ int MPIR_TSP_Ineighbor_allgatherv_sched_allcomm_linear(const void *sendbuf, MPI_ for (k = 0; k < outdegree; ++k) { mpi_errno = - MPIR_TSP_sched_isend(sendbuf, sendcount, sendtype, dsts[k], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(sendbuf, sendcount, sendtype, dsts[k], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (l = 0; l < indegree; ++l) { char *rb = ((char *) recvbuf) + displs[l] * recvtype_extent; mpi_errno = - MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtype, srcs[l], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtype, srcs[l], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_allcomm_sched_linear.c b/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_allcomm_sched_linear.c index c593a5fe200..0c1a6333baf 100644 --- a/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_allcomm_sched_linear.c +++ b/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_allcomm_sched_linear.c @@ -38,7 +38,8 @@ int MPIR_Ineighbor_alltoall_allcomm_sched_linear(const void *sendbuf, MPI_Aint s for (k 
= 0; k < outdegree; ++k) { char *sb = ((char *) sendbuf) + k * sendcount * sendtype_extent; - mpi_errno = MPIR_Sched_send(sb, sendcount, sendtype, dsts[k], comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sb, sendcount, sendtype, dsts[k], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } @@ -57,7 +58,8 @@ int MPIR_Ineighbor_alltoall_allcomm_sched_linear(const void *sendbuf, MPI_Aint s */ for (l = indegree - 1; l >= 0; l--) { char *rb = ((char *) recvbuf) + l * recvcount * recvtype_extent; - mpi_errno = MPIR_Sched_recv(rb, recvcount, recvtype, srcs[l], comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(rb, recvcount, recvtype, srcs[l], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c b/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c index 3131e163f0d..74e3caa9ca6 100644 --- a/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c +++ b/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c @@ -46,8 +46,8 @@ int MPIR_TSP_Ineighbor_alltoall_sched_allcomm_linear(const void *sendbuf, MPI_Ai for (k = 0; k < outdegree; ++k) { char *sb = ((char *) sendbuf) + k * sendcount * sendtype_extent; mpi_errno = - MPIR_TSP_sched_isend(sb, sendcount, sendtype, dsts[k], tag, comm_ptr, sched, 0, NULL, - &vtx_id); + MPIR_TSP_sched_isend(sb, sendcount, sendtype, dsts[k], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -57,8 +57,8 @@ int MPIR_TSP_Ineighbor_alltoall_sched_allcomm_linear(const void *sendbuf, MPI_Ai for (l = indegree - 1; l >= 0; l--) { char *rb = ((char *) recvbuf) + l * recvcount * recvtype_extent; mpi_errno = - MPIR_TSP_sched_irecv(rb, recvcount, recvtype, srcs[l], tag, comm_ptr, sched, 0, NULL, - &vtx_id); + MPIR_TSP_sched_irecv(rb, recvcount, recvtype, srcs[l], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git 
a/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_allcomm_sched_linear.c b/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_allcomm_sched_linear.c index 95713065499..27f200a3933 100644 --- a/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_allcomm_sched_linear.c +++ b/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_allcomm_sched_linear.c @@ -39,7 +39,8 @@ int MPIR_Ineighbor_alltoallv_allcomm_sched_linear(const void *sendbuf, const MPI for (k = 0; k < outdegree; ++k) { char *sb = ((char *) sendbuf) + sdispls[k] * sendtype_extent; - mpi_errno = MPIR_Sched_send(sb, sendcounts[k], sendtype, dsts[k], comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sb, sendcounts[k], sendtype, dsts[k], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } @@ -48,7 +49,8 @@ int MPIR_Ineighbor_alltoallv_allcomm_sched_linear(const void *sendbuf, const MPI * ref. ineighbor_alltoall_allcomm_sched_linear.c */ for (l = indegree - 1; l >= 0; l--) { char *rb = ((char *) recvbuf) + rdispls[l] * recvtype_extent; - mpi_errno = MPIR_Sched_recv(rb, recvcounts[l], recvtype, srcs[l], comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(rb, recvcounts[l], recvtype, srcs[l], comm_ptr, MPIR_SUBGROUP_NONE, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c b/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c index cdf14c9a6d7..6ed728dc8a5 100644 --- a/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c +++ b/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c @@ -49,8 +49,8 @@ int MPIR_TSP_Ineighbor_alltoallv_sched_allcomm_linear(const void *sendbuf, for (k = 0; k < outdegree; ++k) { char *sb = ((char *) sendbuf) + sdispls[k] * sendtype_extent; mpi_errno = - MPIR_TSP_sched_isend(sb, sendcounts[k], sendtype, dsts[k], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(sb, sendcounts[k], sendtype, dsts[k], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); 
MPIR_ERR_CHECK(mpi_errno); } @@ -60,8 +60,8 @@ int MPIR_TSP_Ineighbor_alltoallv_sched_allcomm_linear(const void *sendbuf, for (l = indegree - 1; l >= 0; l--) { char *rb = ((char *) recvbuf) + rdispls[l] * recvtype_extent; mpi_errno = - MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtype, srcs[l], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtype, srcs[l], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_allcomm_sched_linear.c b/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_allcomm_sched_linear.c index dd9f143bcf2..5ccde6b4a7f 100644 --- a/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_allcomm_sched_linear.c +++ b/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_allcomm_sched_linear.c @@ -39,7 +39,9 @@ int MPIR_Ineighbor_alltoallw_allcomm_sched_linear(const void *sendbuf, const MPI char *sb; sb = ((char *) sendbuf) + sdispls[k]; - mpi_errno = MPIR_Sched_send(sb, sendcounts[k], sendtypes[k], dsts[k], comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sb, sendcounts[k], sendtypes[k], dsts[k], comm_ptr, MPIR_SUBGROUP_NONE, + s); MPIR_ERR_CHECK(mpi_errno); } @@ -50,7 +52,9 @@ int MPIR_Ineighbor_alltoallw_allcomm_sched_linear(const void *sendbuf, const MPI char *rb; rb = ((char *) recvbuf) + rdispls[l]; - mpi_errno = MPIR_Sched_recv(rb, recvcounts[l], recvtypes[l], srcs[l], comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(rb, recvcounts[l], recvtypes[l], srcs[l], comm_ptr, MPIR_SUBGROUP_NONE, + s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c b/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c index ee5b5b3872e..bc6501a51ba 100644 --- a/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c +++ b/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c @@ -47,8 +47,8 @@ int 
MPIR_TSP_Ineighbor_alltoallw_sched_allcomm_linear(const void *sendbuf, sb = ((char *) sendbuf) + sdispls[k]; mpi_errno = - MPIR_TSP_sched_isend(sb, sendcounts[k], sendtypes[k], dsts[k], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_isend(sb, sendcounts[k], sendtypes[k], dsts[k], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -60,8 +60,8 @@ int MPIR_TSP_Ineighbor_alltoallw_sched_allcomm_linear(const void *sendbuf, rb = ((char *) recvbuf) + rdispls[l]; mpi_errno = - MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtypes[l], srcs[l], tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_irecv(rb, recvcounts[l], recvtypes[l], srcs[l], tag, comm_ptr, + MPIR_SUBGROUP_NONE, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c b/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c index c0a6711a555..46af4fbb22f 100644 --- a/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c +++ b/src/mpi/coll/ireduce/ireduce_inter_sched_local_reduce_remote_send.c @@ -30,7 +30,7 @@ int MPIR_Ireduce_inter_sched_local_reduce_remote_send(const void *sendbuf, void if (root == MPI_ROOT) { /* root receives data from rank 0 on remote group */ - mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, 0, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -63,7 +63,7 @@ int MPIR_Ireduce_inter_sched_local_reduce_remote_send(const void *sendbuf, void MPIR_ERR_CHECK(mpi_errno); if (rank == 0) { - mpi_errno = MPIR_Sched_send(tmp_buf, count, datatype, root, comm_ptr, s); + mpi_errno = MPIR_Sched_send(tmp_buf, count, datatype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); diff --git 
a/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c b/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c index 31193f2d00b..3ab93c5c664 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c @@ -93,7 +93,8 @@ int MPIR_Ireduce_intra_sched_binomial(const void *sendbuf, void *recvbuf, MPI_Ai source = (relrank | mask); if (source < comm_size) { source = (source + lroot) % comm_size; - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, source, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_buf, count, datatype, source, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -119,7 +120,7 @@ int MPIR_Ireduce_intra_sched_binomial(const void *sendbuf, void *recvbuf, MPI_Ai /* I've received all that I'm going to. Send my result to * my parent */ source = ((relrank & (~mask)) + lroot) % comm_size; - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, source, comm_ptr, s); + mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, source, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); @@ -131,12 +132,12 @@ int MPIR_Ireduce_intra_sched_binomial(const void *sendbuf, void *recvbuf, MPI_Ai if (!is_commutative && (root != 0)) { if (rank == 0) { - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, root, comm_ptr, s); + mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); } else if (rank == root) { - mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, 0, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvbuf, count, datatype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_barrier(s); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c 
b/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c index 9e997ae2cb5..2fc41a8bd9f 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c @@ -105,7 +105,8 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re if (rank < 2 * rem) { if (rank % 2 != 0) { /* odd */ - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, count, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -114,7 +115,8 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re * doubling */ newrank = -1; } else { /* even */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_buf, count, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -183,11 +185,11 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re /* Send data from recvbuf. 
Recv into tmp_buf */ mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[send_idx] * extent), - send_cnt, datatype, dst, comm_ptr, s); + send_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ mpi_errno = MPIR_Sched_recv(((char *) tmp_buf + disps[recv_idx] * extent), - recv_cnt, datatype, dst, comm_ptr, s); + recv_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -232,7 +234,7 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re for (i = 1; i < pof2; i++) disps[i] = disps[i - 1] + cnts[i - 1]; - mpi_errno = MPIR_Sched_recv(recvbuf, cnts[0], datatype, 0, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvbuf, cnts[0], datatype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -240,7 +242,8 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re send_idx = 0; last_idx = 2; } else if (newrank == 0) { /* send */ - mpi_errno = MPIR_Sched_send(recvbuf, cnts[0], datatype, root, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(recvbuf, cnts[0], datatype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); newrank = -1; @@ -305,14 +308,14 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re /* send and exit */ /* Send data from recvbuf. 
Recv into tmp_buf */ mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[send_idx] * extent), - send_cnt, datatype, dst, comm_ptr, s); + send_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); break; } else { /* recv and continue */ mpi_errno = MPIR_Sched_recv(((char *) recvbuf + disps[recv_idx] * extent), - recv_cnt, datatype, dst, comm_ptr, s); + recv_cnt, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ireduce/ireduce_tsp_tree.c b/src/mpi/coll/ireduce/ireduce_tsp_tree.c index a2bb817bdc0..f92c54ddfb7 100644 --- a/src/mpi/coll/ireduce/ireduce_tsp_tree.c +++ b/src/mpi/coll/ireduce/ireduce_tsp_tree.c @@ -213,8 +213,9 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai nvtcs = 1; } - mpi_errno = MPIR_TSP_sched_irecv(recv_address, msgsize, datatype, child, tag, comm, - sched, nvtcs, vtcs, &recv_id[i]); + mpi_errno = + MPIR_TSP_sched_irecv(recv_address, msgsize, datatype, child, tag, comm, coll_group, + sched, nvtcs, vtcs, &recv_id[i]); MPIR_ERR_CHECK(mpi_errno); /* Setup dependencies for reduction. 
Reduction depends on the corresponding recv to complete */ @@ -260,7 +261,7 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai if (!is_tree_root) { mpi_errno = MPIR_TSP_sched_isend(reduce_address, msgsize, datatype, my_tree.parent, tag, comm, - sched, nvtcs, vtcs, &vtx_id); + coll_group, sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } @@ -268,12 +269,12 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai if (tree_root != root) { if (is_tree_root) { /* tree_root sends data to root */ mpi_errno = - MPIR_TSP_sched_isend(reduce_address, msgsize, datatype, root, tag, comm, sched, - nvtcs, vtcs, &vtx_id); + MPIR_TSP_sched_isend(reduce_address, msgsize, datatype, root, tag, comm, + coll_group, sched, nvtcs, vtcs, &vtx_id); } else if (is_root) { /* root receives data from tree_root */ mpi_errno = MPIR_TSP_sched_irecv((char *) recvbuf + offset * extent, msgsize, datatype, - tree_root, tag, comm, sched, 0, NULL, &vtx_id); + tree_root, tag, comm, coll_group, sched, 0, NULL, &vtx_id); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c index 2ecf91f004e..3f0ef676834 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c @@ -99,10 +99,10 @@ int MPIR_Ireduce_scatter_intra_sched_noncommutative(const void *sendbuf, void *r } mpi_errno = MPIR_Sched_send((outgoing_data + send_offset * true_extent), - size, datatype, peer, comm_ptr, s); + size, datatype, peer, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv((incoming_data + recv_offset * true_extent), - size, datatype, peer, comm_ptr, s); + size, datatype, peer, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git 
a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c index b9686961f5f..41b5dbdec9b 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c @@ -73,14 +73,15 @@ int MPIR_Ireduce_scatter_intra_sched_pairwise(const void *sendbuf, void *recvbuf * needs from src into tmp_recvbuf */ if (sendbuf != MPI_IN_PLACE) { mpi_errno = MPIR_Sched_send(((char *) sendbuf + disps[dst] * extent), - recvcounts[dst], datatype, dst, comm_ptr, s); + recvcounts[dst], datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[dst] * extent), - recvcounts[dst], datatype, dst, comm_ptr, s); + recvcounts[dst], datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, recvcounts[rank], datatype, src, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_recvbuf, recvcounts[rank], datatype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c index aed3af7f144..c65bf91d165 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c @@ -141,9 +141,9 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_doubling(const void *sendbuf, voi * received in tmp_recvbuf and then accumulated into * tmp_results. accumulation is done later below. 
*/ - mpi_errno = MPIR_Sched_send(tmp_results, 1, sendtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(tmp_results, 1, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); received = 1; @@ -184,7 +184,8 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_doubling(const void *sendbuf, voi if ((dst > rank) && (rank < tree_root + nprocs_completed) && (dst >= tree_root + nprocs_completed)) { /* send the current result */ - mpi_errno = MPIR_Sched_send(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -193,7 +194,8 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_doubling(const void *sendbuf, voi else if ((dst < rank) && (dst < tree_root + nprocs_completed) && (rank >= tree_root + nprocs_completed)) { - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); received = 1; diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c index 93da7604868..20460dc2fc1 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c @@ -108,7 +108,9 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ - mpi_errno = MPIR_Sched_send(tmp_results, total_count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_results, total_count, datatype, rank + 1, 
comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -117,7 +119,9 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void * doubling */ newrank = -1; } else { /* odd */ - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, total_count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_recvbuf, total_count, datatype, rank - 1, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -191,10 +195,10 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void int recv_dst = (recv_cnt ? dst : MPI_PROC_NULL); mpi_errno = MPIR_Sched_send(((char *) tmp_results + newdisps[send_idx] * extent), - send_cnt, datatype, send_dst, comm_ptr, s); + send_cnt, datatype, send_dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(((char *) tmp_recvbuf + newdisps[recv_idx] * extent), - recv_cnt, datatype, recv_dst, comm_ptr, s); + recv_cnt, datatype, recv_dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -232,14 +236,16 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void if (rank % 2) { /* odd */ if (recvcounts[rank - 1]) { mpi_errno = MPIR_Sched_send(((char *) tmp_results + disps[rank - 1] * extent), - recvcounts[rank - 1], datatype, rank - 1, comm_ptr, s); + recvcounts[rank - 1], datatype, rank - 1, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } } else { /* even */ if (recvcounts[rank]) { mpi_errno = - MPIR_Sched_recv(recvbuf, recvcounts[rank], datatype, rank + 1, comm_ptr, s); + MPIR_Sched_recv(recvbuf, recvcounts[rank], datatype, rank + 1, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c index 304462eede5..b49136672d2 100644 --- 
a/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c @@ -86,7 +86,8 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void * send_offset, send_cnt)); mpi_errno = MPIR_TSP_sched_isend((char *) tmp_results + send_offset, send_cnt, - datatype, dst, tag, comm, sched, nvtcs, vtcs, &send_id); + datatype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &send_id); MPIR_ERR_CHECK(mpi_errno); rank_for_offset = @@ -103,7 +104,8 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void * recv_offset, recv_cnt)); mpi_errno = MPIR_TSP_sched_irecv((char *) tmp_recvbuf + recv_offset, recv_cnt, - datatype, dst, tag, comm, sched, nvtcs, vtcs, &recv_id); + datatype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &recv_id); MPIR_ERR_CHECK(mpi_errno); @@ -216,8 +218,8 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv else buf_to_send = (void *) sendbuf; mpi_errno = - MPIR_TSP_sched_isend(buf_to_send, total_count, datatype, step1_sendto, tag, comm, sched, - 0, NULL, &vtx_id); + MPIR_TSP_sched_isend(buf_to_send, total_count, datatype, step1_sendto, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { /* Step 2 participating rank */ @@ -226,8 +228,8 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv nvtcs = 1; vtcs[0] = (i == 0) ? 
dtcopy_id : reduce_id; mpi_errno = MPIR_TSP_sched_irecv(tmp_recvbuf, total_count, datatype, - step1_recvfrom[i], tag, comm, sched, nvtcs, vtcs, - &recv_id); + step1_recvfrom[i], tag, comm, coll_group, sched, nvtcs, + vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); nvtcs++; vtcs[1] = recv_id; @@ -266,7 +268,7 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ mpi_errno = MPIR_TSP_sched_irecv(recvbuf, recvcounts[rank], datatype, step1_sendto, tag, comm, - sched, 1, &sink_id, &vtx_id); + coll_group, sched, 1, &sink_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < step1_nrecvs; i++) { @@ -274,7 +276,7 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv /* vtcs will be assigned to last reduce_id in step2 function */ mpi_errno = MPIR_TSP_sched_isend((char *) tmp_results + displs[step1_recvfrom[i]] * extent, recvcounts[step1_recvfrom[i]], datatype, step1_recvfrom[i], - tag, comm, sched, nvtcs, vtcs, &vtx_id); + tag, comm, coll_group, sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c index afd6c96691e..fef8400f794 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c @@ -84,10 +84,10 @@ int MPIR_Ireduce_scatter_block_intra_sched_noncommutative(const void *sendbuf, v } mpi_errno = MPIR_Sched_send((outgoing_data + send_offset * true_extent), - size, datatype, peer, comm_ptr, s); + size, datatype, peer, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv((incoming_data + recv_offset * true_extent), - size, datatype, peer, comm_ptr, s); + size, datatype, peer, comm_ptr, coll_group, 
s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c index 7c24f889773..339268ddbb5 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c @@ -63,14 +63,14 @@ int MPIR_Ireduce_scatter_block_intra_sched_pairwise(const void *sendbuf, void *r * needs from src into tmp_recvbuf */ if (sendbuf != MPI_IN_PLACE) { mpi_errno = MPIR_Sched_send(((char *) sendbuf + disps[dst] * extent), - recvcount, datatype, dst, comm_ptr, s); + recvcount, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIR_Sched_send(((char *) recvbuf + disps[dst] * extent), - recvcount, datatype, dst, comm_ptr, s); + recvcount, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, recvcount, datatype, src, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_recvbuf, recvcount, datatype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c index 0786592616b..0e7785ef242 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c @@ -124,9 +124,9 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(const void *sendbu /* tmp_results contains data to be sent in each step. Data is * received in tmp_recvbuf and then accumulated into * tmp_results. accumulation is done later below. 
*/ - mpi_errno = MPIR_Sched_send(tmp_results, 1, sendtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_send(tmp_results, 1, sendtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); received = 1; @@ -167,7 +167,8 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(const void *sendbu if ((dst > rank) && (rank < tree_root + nprocs_completed) && (dst >= tree_root + nprocs_completed)) { /* send the current result */ - mpi_errno = MPIR_Sched_send(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -176,7 +177,8 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(const void *sendbu else if ((dst < rank) && (dst < tree_root + nprocs_completed) && (rank >= tree_root + nprocs_completed)) { - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_recvbuf, 1, recvtype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); received = 1; diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c index 2ce35875de5..b6673cd19de 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c @@ -77,7 +77,9 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ - mpi_errno = MPIR_Sched_send(tmp_results, total_count, datatype, rank + 1, comm_ptr, s); + mpi_errno = + 
MPIR_Sched_send(tmp_results, total_count, datatype, rank + 1, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -86,7 +88,9 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf * doubling */ newrank = -1; } else { /* odd */ - mpi_errno = MPIR_Sched_recv(tmp_recvbuf, total_count, datatype, rank - 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(tmp_recvbuf, total_count, datatype, rank - 1, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -159,10 +163,10 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf int recv_dst = (recv_cnt ? dst : MPI_PROC_NULL); mpi_errno = MPIR_Sched_send(((char *) tmp_results + newdisps[send_idx] * extent), - send_cnt, datatype, send_dst, comm_ptr, s); + send_cnt, datatype, send_dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_recv(((char *) tmp_recvbuf + newdisps[recv_idx] * extent), - recv_cnt, datatype, recv_dst, comm_ptr, s); + recv_cnt, datatype, recv_dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -197,11 +201,12 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf if (rank < 2 * rem) { if (rank % 2) { /* odd */ mpi_errno = MPIR_Sched_send(((char *) tmp_results + disps[rank - 1] * extent), - recvcount, datatype, rank - 1, comm_ptr, s); + recvcount, datatype, rank - 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else { /* even */ - mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, datatype, rank + 1, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, recvcount, datatype, rank + 1, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c index c5cbf8c2a83..14a25188b0b 100644 --- 
a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c @@ -75,8 +75,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void else buf_to_send = (void *) sendbuf; mpi_errno = - MPIR_TSP_sched_isend(buf_to_send, total_count, datatype, step1_sendto, tag, comm, sched, - 0, NULL, &vtx_id); + MPIR_TSP_sched_isend(buf_to_send, total_count, datatype, step1_sendto, tag, comm, + coll_group, sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } else { /* Step 2 participating rank */ for (i = 0; i < step1_nrecvs; i++) { /* participating rank gets data from non-partcipating ranks */ @@ -84,8 +84,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void nvtcs = 1; vtcs[0] = (i == 0) ? dtcopy_id : reduce_id; mpi_errno = MPIR_TSP_sched_irecv(tmp_recvbuf, total_count, datatype, - step1_recvfrom[i], tag, comm, sched, nvtcs, vtcs, - &recv_id); + step1_recvfrom[i], tag, comm, coll_group, sched, nvtcs, + vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); nvtcs++; vtcs[1] = recv_id; @@ -118,7 +118,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void mpi_errno = MPIR_TSP_sched_isend((char *) tmp_results + send_offset, send_cnt * recvcount, - datatype, dst, tag, comm, sched, nvtcs, vtcs, &send_id); + datatype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &send_id); MPIR_ERR_CHECK(mpi_errno); MPII_Recexchalgo_get_count_and_offset(rank, phase, k, nranks, &recv_cnt, &offset); @@ -126,7 +127,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void mpi_errno = MPIR_TSP_sched_irecv(tmp_recvbuf, recv_cnt * recvcount, - datatype, dst, tag, comm, sched, nvtcs, vtcs, &recv_id); + datatype, dst, tag, comm, coll_group, sched, nvtcs, vtcs, + &recv_id); MPIR_ERR_CHECK(mpi_errno); nvtcs = 2; @@ -155,8 +157,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void * 
send the data to non-partcipating ranks */ if (step1_sendto != -1) { /* I am a Step 2 non-participating rank */ mpi_errno = - MPIR_TSP_sched_irecv(recvbuf, recvcount, datatype, step1_sendto, tag, comm, sched, 1, - &step1_id, &vtx_id); + MPIR_TSP_sched_irecv(recvbuf, recvcount, datatype, step1_sendto, tag, comm, coll_group, + sched, 1, &step1_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } for (i = 0; i < step1_nrecvs; i++) { @@ -164,8 +166,8 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void vtcs[0] = reduce_id; mpi_errno = MPIR_TSP_sched_isend((char *) tmp_results + recvcount * step1_recvfrom[i] * extent, - recvcount, datatype, step1_recvfrom[i], tag, comm, sched, nvtcs, - vtcs, &vtx_id); + recvcount, datatype, step1_recvfrom[i], tag, comm, coll_group, + sched, nvtcs, vtcs, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c b/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c index 885f1b88a5a..dc742296782 100644 --- a/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c @@ -56,10 +56,11 @@ int MPIR_Iscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf dst = rank ^ mask; if (dst < comm_size) { /* Send partial_scan to dst. 
Recv into tmp_buf */ - mpi_errno = MPIR_Sched_send(partial_scan, count, datatype, dst, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(partial_scan, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); /* sendrecv, no barrier here */ - mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); diff --git a/src/mpi/coll/iscan/iscan_intra_sched_smp.c b/src/mpi/coll/iscan/iscan_intra_sched_smp.c index 79b71954349..9e3b4ed0b86 100644 --- a/src/mpi/coll/iscan/iscan_intra_sched_smp.c +++ b/src/mpi/coll/iscan/iscan_intra_sched_smp.c @@ -74,12 +74,12 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun * reduced data of rank 1,2,3. */ if (roots_comm != NULL && node_comm != NULL) { mpi_errno = MPIR_Sched_recv(localfulldata, count, datatype, - (node_comm->local_size - 1), node_comm, s); + (node_comm->local_size - 1), node_comm, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else if (roots_comm == NULL && node_comm != NULL && node_comm->rank == node_comm->local_size - 1) { - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, 0, node_comm, s); + mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, 0, node_comm, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else if (roots_comm != NULL) { @@ -103,12 +103,15 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun if (roots_rank != roots_comm->local_size - 1) { mpi_errno = - MPIR_Sched_send(prefulldata, count, datatype, (roots_rank + 1), roots_comm, s); + MPIR_Sched_send(prefulldata, count, datatype, (roots_rank + 1), roots_comm, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } if (roots_rank != 0) { - mpi_errno = MPIR_Sched_recv(tempbuf, count, datatype, (roots_rank - 1), roots_comm, s); + mpi_errno = + MPIR_Sched_recv(tempbuf, count, 
datatype, (roots_rank - 1), roots_comm, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c b/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c index eae2bed1fcf..eba2405a9aa 100644 --- a/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c +++ b/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c @@ -75,8 +75,8 @@ int MPIR_TSP_Iscan_sched_intra_recursive_doubling(const void *sendbuf, void *rec nvtcs = 1; vtcs[0] = (loop_count == 0) ? dtcopy_id : reduce_id; mpi_errno = - MPIR_TSP_sched_isend(partial_scan, count, datatype, dst, tag, comm, sched, nvtcs, - vtcs, &send_id); + MPIR_TSP_sched_isend(partial_scan, count, datatype, dst, tag, comm, coll_group, + sched, nvtcs, vtcs, &send_id); MPIR_ERR_CHECK(mpi_errno); @@ -85,8 +85,8 @@ int MPIR_TSP_Iscan_sched_intra_recursive_doubling(const void *sendbuf, void *rec vtcs[1] = recv_reduce; } mpi_errno = - MPIR_TSP_sched_irecv(tmp_buf, count, datatype, dst, tag, comm, sched, nvtcs, vtcs, - &recv_id); + MPIR_TSP_sched_irecv(tmp_buf, count, datatype, dst, tag, comm, coll_group, sched, + nvtcs, vtcs, &recv_id); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c b/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c index 32f6c4b920b..0c9a97581b0 100644 --- a/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c +++ b/src/mpi/coll/iscatter/iscatter_inter_sched_linear.c @@ -34,12 +34,12 @@ int MPIR_Iscatter_inter_sched_linear(const void *sendbuf, MPI_Aint sendcount, MP for (i = 0; i < remote_size; i++) { mpi_errno = MPIR_Sched_send(((char *) sendbuf + sendcount * i * extent), sendcount, sendtype, i, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); } else { - mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, recvtype, root, comm_ptr, s); + mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); 
} diff --git a/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c b/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c index b12e2db5367..7b381fc5fee 100644 --- a/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c +++ b/src/mpi/coll/iscatter/iscatter_inter_sched_remote_send_local_scatter.c @@ -34,7 +34,8 @@ int MPIR_Iscatter_inter_sched_remote_send_local_scatter(const void *sendbuf, MPI if (root == MPI_ROOT) { /* root sends all data to rank 0 on remote group and returns */ - mpi_errno = MPIR_Sched_send(sendbuf, sendcount * remote_size, sendtype, 0, comm_ptr, s); + mpi_errno = + MPIR_Sched_send(sendbuf, sendcount * remote_size, sendtype, 0, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); goto fn_exit; @@ -53,7 +54,7 @@ int MPIR_Iscatter_inter_sched_remote_send_local_scatter(const void *sendbuf, MPI mpi_errno = MPIR_Sched_recv(tmp_buf, recvcount * local_size * recvtype_sz, MPI_BYTE, - root, comm_ptr, s); + root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else { diff --git a/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c b/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c index 0bd9bb1820d..e6d0866ed09 100644 --- a/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c +++ b/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c @@ -158,7 +158,8 @@ int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, * they don't have to forward data to anyone. Others * receive data into a temporary buffer. */ if (relative_rank % 2) { - mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, recvtype, src, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, recvcount, recvtype, src, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else { @@ -167,7 +168,7 @@ int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, * some cases. 
query amount of data actually received */ mpi_errno = MPIR_Sched_recv_status(tmp_buf, tmp_buf_size, MPI_BYTE, src, comm_ptr, - &ss->status, s); + coll_group, &ss->status, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); mpi_errno = MPIR_Sched_cb(&get_count, ss, s); @@ -205,7 +206,8 @@ int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, /* mask is also the size of this process's subtree */ mpi_errno = MPIR_Sched_send_defer(((char *) sendbuf + extent * sendcount * mask), - &ss->send_subtree_count, sendtype, dst, comm_ptr, s); + &ss->send_subtree_count, sendtype, dst, comm_ptr, + coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else { @@ -218,7 +220,7 @@ int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, /* mask is also the size of this process's subtree */ mpi_errno = MPIR_Sched_send_defer(((char *) tmp_buf + ss->nbytes * mask), &ss->send_subtree_count, MPI_BYTE, dst, - comm_ptr, s); + comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iscatter/iscatter_tsp_tree.c b/src/mpi/coll/iscatter/iscatter_tsp_tree.c index a3955317444..e063be91d87 100644 --- a/src/mpi/coll/iscatter/iscatter_tsp_tree.c +++ b/src/mpi/coll/iscatter/iscatter_tsp_tree.c @@ -148,7 +148,7 @@ int MPIR_TSP_Iscatter_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, /* receive data from the parent */ if (my_tree.parent != -1) { mpi_errno = MPIR_TSP_sched_irecv(tmp_buf, recv_size, recvtype, my_tree.parent, - tag, comm, sched, 0, NULL, &recv_id); + tag, comm, coll_group, sched, 0, NULL, &recv_id); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "rank:%d posts recv", rank)); } @@ -158,7 +158,7 @@ int MPIR_TSP_Iscatter_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, int child = *(int *) utarray_eltptr(my_tree.children, i); mpi_errno = MPIR_TSP_sched_isend((char *) tmp_buf + child_data_offset[i] * sendtype_extent, 
child_subtree_size[i] * sendcount, sendtype, - child, tag, comm, sched, num_send_dependencies, + child, tag, comm, coll_group, sched, num_send_dependencies, (lrank == 0) ? dtcopy_id : &recv_id, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c b/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c index 3ed60e0f3b1..32cfe4f9320 100644 --- a/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c +++ b/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c @@ -49,7 +49,8 @@ int MPIR_Iscatterv_allcomm_sched_linear(const void *sendbuf, const MPI_Aint send } } else { mpi_errno = MPIR_Sched_send(((char *) sendbuf + displs[i] * extent), - sendcounts[i], sendtype, i, comm_ptr, s); + sendcounts[i], sendtype, i, comm_ptr, coll_group, + s); MPIR_ERR_CHECK(mpi_errno); } } @@ -59,7 +60,8 @@ int MPIR_Iscatterv_allcomm_sched_linear(const void *sendbuf, const MPI_Aint send else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */ if (recvcount) { - mpi_errno = MPIR_Sched_recv(recvbuf, recvcount, recvtype, root, comm_ptr, s); + mpi_errno = + MPIR_Sched_recv(recvbuf, recvcount, recvtype, root, comm_ptr, coll_group, s); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c b/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c index 8313d41eb78..883dcef4d1b 100644 --- a/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c +++ b/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c @@ -57,7 +57,7 @@ int MPIR_TSP_Iscatterv_sched_allcomm_linear(const void *sendbuf, const MPI_Aint } else { mpi_errno = MPIR_TSP_sched_isend(((char *) sendbuf + displs[i] * extent), sendcounts[i], sendtype, i, tag, comm_ptr, - sched, 0, NULL, &vtx_id); + coll_group, sched, 0, NULL, &vtx_id); } } MPIR_ERR_CHECK(mpi_errno); @@ -68,8 +68,8 @@ int MPIR_TSP_Iscatterv_sched_allcomm_linear(const void *sendbuf, const MPI_Aint /* non-root nodes, and in the intercomm. 
case, non-root nodes on remote side */ if (recvcount) { mpi_errno = - MPIR_TSP_sched_irecv(recvbuf, recvcount, recvtype, root, tag, comm_ptr, sched, 0, - NULL, &vtx_id); + MPIR_TSP_sched_irecv(recvbuf, recvcount, recvtype, root, tag, comm_ptr, coll_group, + sched, 0, NULL, &vtx_id); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c b/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c index 8b90a7ebd05..7b111d611c2 100644 --- a/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c +++ b/src/mpi/coll/reduce/reduce_inter_local_reduce_remote_send.c @@ -36,7 +36,8 @@ int MPIR_Reduce_inter_local_reduce_remote_send(const void *sendbuf, if (root == MPI_ROOT) { /* root receives data from rank 0 on remote group */ - mpi_errno = MPIC_Recv(recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, + comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } else { /* remote group. 
Rank 0 allocates temporary buffer, does @@ -70,7 +71,7 @@ int MPIR_Reduce_inter_local_reduce_remote_send(const void *sendbuf, if (rank == 0) { mpi_errno = MPIC_Send(tmp_buf, count, datatype, root, - MPIR_REDUCE_TAG, comm_ptr, errflag); + MPIR_REDUCE_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/reduce/reduce_intra_binomial.c b/src/mpi/coll/reduce/reduce_intra_binomial.c index ff6cf3182fb..48748373a54 100644 --- a/src/mpi/coll/reduce/reduce_intra_binomial.c +++ b/src/mpi/coll/reduce/reduce_intra_binomial.c @@ -97,7 +97,7 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf, if (source < comm_size) { source = (source + lroot) % comm_size; mpi_errno = MPIC_Recv(tmp_buf, count, datatype, source, - MPIR_REDUCE_TAG, comm_ptr, &status); + MPIR_REDUCE_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); /* The sender is above us, so the received buffer must be @@ -118,7 +118,7 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf, * my parent */ source = ((relrank & (~mask)) + lroot) % comm_size; mpi_errno = MPIC_Send(recvbuf, count, datatype, - source, MPIR_REDUCE_TAG, comm_ptr, errflag); + source, MPIR_REDUCE_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); break; } @@ -128,9 +128,11 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf, if (!is_commutative && (root != 0)) { if (rank == 0) { mpi_errno = MPIC_Send(recvbuf, count, datatype, root, - MPIR_REDUCE_TAG, comm_ptr, errflag); + MPIR_REDUCE_TAG, comm_ptr, coll_group, errflag); } else if (rank == root) { - mpi_errno = MPIC_Recv(recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm_ptr, &status); + mpi_errno = + MPIC_Recv(recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm_ptr, coll_group, + &status); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c b/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c index ea6813457ff..452564d3d1e 100644 --- 
a/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c +++ b/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c @@ -106,7 +106,8 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, if (rank < 2 * rem) { if (rank % 2 != 0) { /* odd */ mpi_errno = MPIC_Send(recvbuf, count, - datatype, rank - 1, MPIR_REDUCE_TAG, comm_ptr, errflag); + datatype, rank - 1, MPIR_REDUCE_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set the rank to -1 so that this @@ -115,7 +116,8 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, newrank = -1; } else { /* even */ mpi_errno = MPIC_Recv(tmp_buf, count, - datatype, rank + 1, MPIR_REDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + datatype, rank + 1, MPIR_REDUCE_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. */ @@ -190,7 +192,8 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, (char *) tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag); + MPIR_REDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, + errflag); MPIR_ERR_CHECK(mpi_errno); /* tmp_buf contains data received in this step. @@ -237,14 +240,14 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, disps[i] = disps[i - 1] + cnts[i - 1]; mpi_errno = MPIC_Recv(recvbuf, cnts[0], datatype, - 0, MPIR_REDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + 0, MPIR_REDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); newrank = 0; send_idx = 0; last_idx = 2; } else if (newrank == 0) { /* send */ mpi_errno = MPIC_Send(recvbuf, cnts[0], datatype, - root, MPIR_REDUCE_TAG, comm_ptr, errflag); + root, MPIR_REDUCE_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); newrank = -1; } @@ -310,7 +313,8 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, /* Send data from recvbuf. 
Recv into tmp_buf */ mpi_errno = MPIC_Send((char *) recvbuf + disps[send_idx] * extent, - send_cnt, datatype, dst, MPIR_REDUCE_TAG, comm_ptr, errflag); + send_cnt, datatype, dst, MPIR_REDUCE_TAG, comm_ptr, + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); break; } else { @@ -320,7 +324,7 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, mpi_errno = MPIC_Recv((char *) recvbuf + disps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c index 042e7dfc9f0..40c6230304e 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c @@ -101,7 +101,7 @@ int MPIR_Reduce_scatter_intra_noncommutative(const void *sendbuf, void *recvbuf, size, datatype, peer, MPIR_REDUCE_SCATTER_TAG, incoming_data + recv_offset * true_extent, size, datatype, peer, MPIR_REDUCE_SCATTER_TAG, - comm_ptr, MPI_STATUS_IGNORE, errflag); + comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); /* always perform the reduction at recv_offset, the data at send_offset * is now our peer's responsibility */ diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c index 1ad6e16422f..bedaa05e980 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c @@ -82,14 +82,14 @@ int MPIR_Reduce_scatter_intra_pairwise(const void *sendbuf, void *recvbuf, recvcounts[dst], datatype, dst, MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf, recvcounts[rank], datatype, src, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, 
errflag); else mpi_errno = MPIC_Sendrecv(((char *) recvbuf + disps[dst] * extent), recvcounts[dst], datatype, dst, MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf, recvcounts[rank], datatype, src, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c index e72c6384cc1..807588fef84 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c @@ -148,7 +148,7 @@ int MPIR_Reduce_scatter_intra_recursive_doubling(const void *sendbuf, void *recv mpi_errno = MPIC_Sendrecv(tmp_results, 1, sendtype, dst, MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf, 1, recvtype, dst, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); received = 1; MPIR_ERR_CHECK(mpi_errno); @@ -190,7 +190,8 @@ int MPIR_Reduce_scatter_intra_recursive_doubling(const void *sendbuf, void *recv && (dst >= tree_root + nprocs_completed)) { /* send the current result */ mpi_errno = MPIC_Send(tmp_recvbuf, 1, recvtype, - dst, MPIR_REDUCE_SCATTER_TAG, comm_ptr, errflag); + dst, MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); } /* recv only if this proc. 
doesn't have data and sender @@ -199,7 +200,8 @@ int MPIR_Reduce_scatter_intra_recursive_doubling(const void *sendbuf, void *recv (dst < tree_root + nprocs_completed) && (rank >= tree_root + nprocs_completed)) { mpi_errno = MPIC_Recv(tmp_recvbuf, 1, recvtype, dst, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); received = 1; MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c index 0009e54b5f0..578e73e7448 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c @@ -113,7 +113,8 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb if (rank < 2 * rem) { if (rank % 2 == 0) { /* even */ mpi_errno = MPIC_Send(tmp_results, total_count, - datatype, rank + 1, MPIR_REDUCE_SCATTER_TAG, comm_ptr, errflag); + datatype, rank + 1, MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set the rank to -1 so that this @@ -123,7 +124,7 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb } else { /* odd */ mpi_errno = MPIC_Recv(tmp_recvbuf, total_count, datatype, rank - 1, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. 
since the @@ -199,18 +200,19 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb (char *) tmp_recvbuf + newdisps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); else if ((send_cnt == 0) && (recv_cnt != 0)) mpi_errno = MPIC_Recv((char *) tmp_recvbuf + newdisps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); else if ((recv_cnt == 0) && (send_cnt != 0)) mpi_errno = MPIC_Send((char *) tmp_results + newdisps[send_idx] * extent, send_cnt, datatype, - dst, MPIR_REDUCE_SCATTER_TAG, comm_ptr, errflag); + dst, MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); @@ -250,14 +252,15 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb mpi_errno = MPIC_Send((char *) tmp_results + disps[rank - 1] * extent, recvcounts[rank - 1], datatype, rank - 1, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, errflag); + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { /* even */ if (recvcounts[rank]) { mpi_errno = MPIC_Recv(recvbuf, recvcounts[rank], datatype, rank + 1, - MPIR_REDUCE_SCATTER_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c index 08319ceb12e..db1a88edc70 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c @@ -100,7 +100,7 @@ int MPIR_Reduce_scatter_block_intra_noncommutative(const void *sendbuf, size, datatype, peer, 
MPIR_REDUCE_SCATTER_BLOCK_TAG, incoming_data + recv_offset * true_extent, size, datatype, peer, MPIR_REDUCE_SCATTER_BLOCK_TAG, - comm_ptr, MPI_STATUS_IGNORE, errflag); + comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); /* always perform the reduction at recv_offset, the data at send_offset * is now our peer's responsibility */ diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c index 81d2ef3467b..7cbb663b0ac 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c @@ -84,14 +84,14 @@ int MPIR_Reduce_scatter_block_intra_pairwise(const void *sendbuf, recvcount, datatype, dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, tmp_recvbuf, recvcount, datatype, src, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); else mpi_errno = MPIC_Sendrecv(((char *) recvbuf + disps[dst] * extent), recvcount, datatype, dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, tmp_recvbuf, recvcount, datatype, src, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c index 0f237c26b40..b8ed30050c7 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c @@ -146,7 +146,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Sendrecv(tmp_results, 1, sendtype, dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, tmp_recvbuf, 1, recvtype, dst, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, + 
MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); received = 1; MPIR_ERR_CHECK(mpi_errno); @@ -188,7 +188,8 @@ int MPIR_Reduce_scatter_block_intra_recursive_doubling(const void *sendbuf, && (dst >= tree_root + nprocs_completed)) { /* send the current result */ mpi_errno = MPIC_Send(tmp_recvbuf, 1, recvtype, - dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, errflag); + dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); } /* recv only if this proc. doesn't have data and sender @@ -198,7 +199,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_doubling(const void *sendbuf, (rank >= tree_root + nprocs_completed)) { mpi_errno = MPIC_Recv(tmp_recvbuf, 1, recvtype, dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, - comm_ptr, MPI_STATUS_IGNORE); + comm_ptr, coll_group, MPI_STATUS_IGNORE); received = 1; MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c index 29572e9be2f..eedae204057 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c @@ -115,7 +115,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, if (rank % 2 == 0) { /* even */ mpi_errno = MPIC_Send(tmp_results, total_count, datatype, rank + 1, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, errflag); + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* temporarily set the rank to -1 so that this @@ -125,7 +125,8 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, } else { /* odd */ mpi_errno = MPIC_Recv(tmp_recvbuf, total_count, datatype, rank - 1, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); 
MPIR_ERR_CHECK(mpi_errno); /* do the reduction on received data. since the @@ -203,18 +204,20 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, (char *) tmp_recvbuf + newdisps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE, errflag); else if ((send_cnt == 0) && (recv_cnt != 0)) mpi_errno = MPIC_Recv((char *) tmp_recvbuf + newdisps[recv_idx] * extent, recv_cnt, datatype, dst, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); else if ((recv_cnt == 0) && (send_cnt != 0)) mpi_errno = MPIC_Send((char *) tmp_results + newdisps[send_idx] * extent, send_cnt, datatype, - dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, errflag); + dst, MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, + errflag); MPIR_ERR_CHECK(mpi_errno); @@ -249,11 +252,12 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, mpi_errno = MPIC_Send((char *) tmp_results + disps[rank - 1] * extent, recvcount, datatype, rank - 1, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, errflag); + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, errflag); } else { /* even */ mpi_errno = MPIC_Recv(recvbuf, recvcount, datatype, rank + 1, - MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_REDUCE_SCATTER_BLOCK_TAG, comm_ptr, coll_group, + MPI_STATUS_IGNORE); } MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/scan/scan_intra_recursive_doubling.c b/src/mpi/coll/scan/scan_intra_recursive_doubling.c index 6d9a2d40658..a89650b9083 100644 --- a/src/mpi/coll/scan/scan_intra_recursive_doubling.c +++ b/src/mpi/coll/scan/scan_intra_recursive_doubling.c @@ -97,7 +97,7 @@ int MPIR_Scan_intra_recursive_doubling(const void *sendbuf, mpi_errno = MPIC_Sendrecv(partial_scan, count, datatype, dst, MPIR_SCAN_TAG, tmp_buf, count, datatype, dst, - MPIR_SCAN_TAG, 
comm_ptr, &status, errflag); + MPIR_SCAN_TAG, comm_ptr, coll_group, &status, errflag); MPIR_ERR_CHECK(mpi_errno); if (rank > dst) { diff --git a/src/mpi/coll/scan/scan_intra_smp.c b/src/mpi/coll/scan/scan_intra_smp.c index 0414f3dfc67..efd5e09ac68 100644 --- a/src/mpi/coll/scan/scan_intra_smp.c +++ b/src/mpi/coll/scan/scan_intra_smp.c @@ -58,13 +58,13 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, if (comm_ptr->node_roots_comm != NULL && comm_ptr->node_comm != NULL) { mpi_errno = MPIC_Recv(localfulldata, count, datatype, comm_ptr->node_comm->local_size - 1, MPIR_SCAN_TAG, - comm_ptr->node_comm, &status); + comm_ptr->node_comm, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } else if (comm_ptr->node_roots_comm == NULL && comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, rank) == comm_ptr->node_comm->local_size - 1) { mpi_errno = MPIC_Send(recvbuf, count, datatype, - 0, MPIR_SCAN_TAG, comm_ptr->node_comm, errflag); + 0, MPIR_SCAN_TAG, comm_ptr->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (comm_ptr->node_roots_comm != NULL) { localfulldata = recvbuf; @@ -82,13 +82,13 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, if (MPIR_Get_internode_rank(comm_ptr, rank) != comm_ptr->node_roots_comm->local_size - 1) { mpi_errno = MPIC_Send(prefulldata, count, datatype, MPIR_Get_internode_rank(comm_ptr, rank) + 1, - MPIR_SCAN_TAG, comm_ptr->node_roots_comm, errflag); + MPIR_SCAN_TAG, comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } if (MPIR_Get_internode_rank(comm_ptr, rank) != 0) { mpi_errno = MPIC_Recv(tempbuf, count, datatype, MPIR_Get_internode_rank(comm_ptr, rank) - 1, - MPIR_SCAN_TAG, comm_ptr->node_roots_comm, &status); + MPIR_SCAN_TAG, comm_ptr->node_roots_comm, coll_group, &status); noneed = 0; MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/scatter/scatter_inter_linear.c b/src/mpi/coll/scatter/scatter_inter_linear.c index 
9b5f5ec5800..0fcd47a06a5 100644 --- a/src/mpi/coll/scatter/scatter_inter_linear.c +++ b/src/mpi/coll/scatter/scatter_inter_linear.c @@ -33,12 +33,12 @@ int MPIR_Scatter_inter_linear(const void *sendbuf, MPI_Aint sendcount, MPI_Datat for (i = 0; i < remote_size; i++) { mpi_errno = MPIC_Send(((char *) sendbuf + sendcount * i * extent), sendcount, sendtype, i, - MPIR_SCATTER_TAG, comm_ptr, errflag); + MPIR_SCATTER_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { - mpi_errno = - MPIC_Recv(recvbuf, recvcount, recvtype, root, MPIR_SCATTER_TAG, comm_ptr, &status); + mpi_errno = MPIC_Recv(recvbuf, recvcount, recvtype, root, MPIR_SCATTER_TAG, + comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c b/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c index d02865b1117..def8d1ade0f 100644 --- a/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c +++ b/src/mpi/coll/scatter/scatter_inter_remote_send_local_scatter.c @@ -36,7 +36,7 @@ int MPIR_Scatter_inter_remote_send_local_scatter(const void *sendbuf, MPI_Aint s /* root sends all data to rank 0 on remote group and returns */ mpi_errno = MPIC_Send(sendbuf, sendcount * remote_size, sendtype, 0, MPIR_SCATTER_TAG, comm_ptr, - errflag); + coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } else { @@ -54,7 +54,7 @@ int MPIR_Scatter_inter_remote_send_local_scatter(const void *sendbuf, MPI_Aint s "tmp_buf", MPL_MEM_BUFFER); mpi_errno = MPIC_Recv(tmp_buf, recvcount * local_size * recvtype_sz, MPI_BYTE, - root, MPIR_SCATTER_TAG, comm_ptr, &status); + root, MPIR_SCATTER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } else { /* silience -Wmaybe-uninitialized due to MPIR_Scatter by non-zero ranks */ diff --git a/src/mpi/coll/scatter/scatter_intra_binomial.c b/src/mpi/coll/scatter/scatter_intra_binomial.c index cd6a5fe13d6..e662cee1570 100644 --- 
a/src/mpi/coll/scatter/scatter_intra_binomial.c +++ b/src/mpi/coll/scatter/scatter_intra_binomial.c @@ -116,11 +116,11 @@ int MPIR_Scatter_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Dat * receive data into a temporary buffer. */ if (relative_rank % 2) { mpi_errno = MPIC_Recv(recvbuf, recvcount, recvtype, - src, MPIR_SCATTER_TAG, comm_ptr, &status); + src, MPIR_SCATTER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } else { mpi_errno = MPIC_Recv(tmp_buf, tmp_buf_size, MPI_BYTE, src, - MPIR_SCATTER_TAG, comm_ptr, &status); + MPIR_SCATTER_TAG, comm_ptr, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); if (mpi_errno) { curr_cnt = 0; @@ -152,14 +152,16 @@ int MPIR_Scatter_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Dat mpi_errno = MPIC_Send(((char *) sendbuf + extent * sendcount * mask), send_subtree_cnt, - sendtype, dst, MPIR_SCATTER_TAG, comm_ptr, errflag); + sendtype, dst, MPIR_SCATTER_TAG, comm_ptr, coll_group, + errflag); } else { /* non-zero root and others */ send_subtree_cnt = curr_cnt - nbytes * mask; /* mask is also the size of this process's subtree */ mpi_errno = MPIC_Send(((char *) tmp_buf + nbytes * mask), send_subtree_cnt, - MPI_BYTE, dst, MPIR_SCATTER_TAG, comm_ptr, errflag); + MPI_BYTE, dst, MPIR_SCATTER_TAG, comm_ptr, coll_group, + errflag); } MPIR_ERR_CHECK(mpi_errno); curr_cnt -= send_subtree_cnt; diff --git a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c index c34aa40b4a5..9c0e5cc2bd4 100644 --- a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c +++ b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c @@ -57,7 +57,8 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcount } else { mpi_errno = MPIC_Isend(((char *) sendbuf + displs[i] * extent), sendcounts[i], sendtype, i, - MPIR_SCATTERV_TAG, comm_ptr, &reqarray[reqs++], errflag); + MPIR_SCATTERV_TAG, comm_ptr, coll_group, + &reqarray[reqs++], errflag); 
MPIR_ERR_CHECK(mpi_errno); } } @@ -70,7 +71,7 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcount else if (root != MPI_PROC_NULL) { /* non-root nodes, and in the intercomm. case, non-root nodes on remote side */ if (recvcount) { mpi_errno = MPIC_Recv(recvbuf, recvcount, recvtype, root, - MPIR_SCATTERV_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_SCATTERV_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpi/comm/comm_impl.c b/src/mpi/comm/comm_impl.c index 0e76fc19109..6094f82b4d2 100644 --- a/src/mpi/comm/comm_impl.c +++ b/src/mpi/comm/comm_impl.c @@ -471,8 +471,8 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co info[1] = group_ptr->size; mpi_errno = MPIC_Sendrecv(info, 2, MPI_INT, 0, 0, - rinfo, 2, MPI_INT, 0, 0, comm_ptr, MPI_STATUS_IGNORE, - MPIR_ERR_NONE); + rinfo, 2, MPI_INT, 0, 0, comm_ptr, MPIR_SUBGROUP_NONE, + MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (*newcomm_ptr != NULL) { (*newcomm_ptr)->context_id = rinfo[0]; @@ -486,7 +486,7 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co /* Populate and exchange the ranks */ mpi_errno = MPIC_Sendrecv(mapping, group_ptr->size, MPI_INT, 0, 0, remote_mapping, remote_size, MPI_INT, 0, 0, - comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Broadcast to the other members of the local group */ @@ -1033,7 +1033,7 @@ int MPIR_Intercomm_create_impl(MPIR_Comm * local_comm_ptr, int local_leader, mpi_errno = MPIC_Sendrecv(&recvcontext_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, remote_leader, tag, &remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, remote_leader, tag, - peer_comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + peer_comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); final_context_id = remote_context_id; @@ -1206,7 +1206,7 @@ int 
MPIR_Intercomm_merge_impl(MPIR_Comm * comm_ptr, int high, MPIR_Comm ** new_i /* This routine allows use to use the collective communication * context rather than the point-to-point context. */ mpi_errno = MPIC_Sendrecv(&local_high, 1, MPI_INT, 0, 0, - &remote_high, 1, MPI_INT, 0, 0, comm_ptr, + &remote_high, 1, MPI_INT, 0, 0, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/comm/comm_split.c b/src/mpi/comm/comm_split.c index c929598767b..f8685836054 100644 --- a/src/mpi/comm/comm_split.c +++ b/src/mpi/comm/comm_split.c @@ -217,7 +217,8 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm ** if (comm_ptr->rank == 0) { mpi_errno = MPIC_Sendrecv(&new_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, 0, &remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, - 0, 0, comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + 0, 0, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Bcast(&remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, local_comm_ptr, diff --git a/src/mpi/comm/contextid.c b/src/mpi/comm/contextid.c index 25faa6f4fba..d1d2833a435 100644 --- a/src/mpi/comm/contextid.c +++ b/src/mpi/comm/contextid.c @@ -646,11 +646,11 @@ static int sched_cb_gcn_bcast(MPIR_Comm * comm, int tag, void *state) if (st->comm_ptr_inter->rank == 0) { mpi_errno = MPIR_Sched_recv(st->ctx1, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, st->comm_ptr_inter, - st->s); + MPIR_SUBGROUP_NONE, st->s); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_send(st->ctx0, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, st->comm_ptr_inter, - st->s); + MPIR_SUBGROUP_NONE, st->s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(st->s); } @@ -1058,7 +1058,7 @@ int MPIR_Get_intercomm_contextid(MPIR_Comm * comm_ptr, MPIR_Context_id_t * conte if (comm_ptr->rank == 0) { mpi_errno = MPIC_Sendrecv(&mycontext_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, tag, &remote_context_id, 1, MPIR_CONTEXT_ID_T_DATATYPE, 0, 
tag, - comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/topo/dist_graph_create.c b/src/mpi/topo/dist_graph_create.c index e42e61c65c2..07b7270fdee 100644 --- a/src/mpi/topo/dist_graph_create.c +++ b/src/mpi/topo/dist_graph_create.c @@ -151,14 +151,14 @@ int MPIR_Dist_graph_create_impl(MPIR_Comm * comm_ptr, /* send edges where i is a destination to process i */ mpi_errno = MPIC_Isend(&rin[i][0], rin_sizes[i], MPI_INT, i, MPIR_TOPO_A_TAG, comm_ptr, - &reqs[idx++], MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, &reqs[idx++], MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } if (rout_sizes[i]) { /* send edges where i is a source to process i */ mpi_errno = MPIC_Isend(&rout[i][0], rout_sizes[i], MPI_INT, i, MPIR_TOPO_B_TAG, comm_ptr, - &reqs[idx++], MPIR_ERR_NONE); + MPIR_SUBGROUP_NONE, &reqs[idx++], MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } } @@ -204,7 +204,7 @@ int MPIR_Dist_graph_create_impl(MPIR_Comm * comm_ptr, MPIR_ERR_CHKANDJUMP(!buf, mpi_errno, MPI_ERR_OTHER, "**nomem"); mpi_errno = MPIC_Recv(buf, count, MPI_INT, MPI_ANY_SOURCE, MPIR_TOPO_A_TAG, - comm_ptr, MPI_STATUS_IGNORE); + comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); for (int j = 0; j < count / 2; ++j) { @@ -237,7 +237,7 @@ int MPIR_Dist_graph_create_impl(MPIR_Comm * comm_ptr, MPIR_ERR_CHKANDJUMP(!buf, mpi_errno, MPI_ERR_OTHER, "**nomem"); mpi_errno = MPIC_Recv(buf, count, MPI_INT, MPI_ANY_SOURCE, MPIR_TOPO_B_TAG, - comm_ptr, MPI_STATUS_IGNORE); + comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); for (int j = 0; j < count / 2; ++j) { diff --git a/src/mpid/ch3/src/ch3u_port.c b/src/mpid/ch3/src/ch3u_port.c index 9ca1fc5bfde..390eedfd0f3 100644 --- a/src/mpid/ch3/src/ch3u_port.c +++ b/src/mpid/ch3/src/ch3u_port.c @@ -646,7 +646,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, send_ints[0], send_ints[1], send_ints[2])); 
mpi_errno = MPIC_Sendrecv(send_ints, 3, MPI_INT, 0, sendtag++, recv_ints, 3, MPI_INT, - 0, recvtag++, tmp_comm, + 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); if (mpi_errno != MPI_SUCCESS) { /* this is a no_port error because we may fail to connect @@ -689,7 +689,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, mpi_errno = MPIC_Sendrecv(local_translation, local_comm_size * 2, MPI_INT, 0, sendtag++, remote_translation, remote_comm_size * 2, - MPI_INT, 0, recvtag++, tmp_comm, + MPI_INT, 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -740,7 +740,7 @@ int MPIDI_Comm_connect(const char *port_name, MPIR_Info *info, int root, MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"sync with peer"); mpi_errno = MPIC_Sendrecv(&i, 0, MPI_INT, 0, sendtag++, &j, 0, MPI_INT, - 0, recvtag++, tmp_comm, + 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -928,7 +928,7 @@ static int ReceivePGAndDistribute( MPIR_Comm *tmp_comm, MPIR_Comm *comm_ptr, if (rank == root) { /* First, receive the pg description from the partner */ mpi_errno = MPIC_Recv(&j, 1, MPI_INT, 0, recvtag++, - tmp_comm, MPI_STATUS_IGNORE); + tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); *recvtag_p = recvtag; MPIR_ERR_CHECK(mpi_errno); pg_str = (char*)MPL_malloc(j, MPL_MEM_DYNAMIC); @@ -936,7 +936,7 @@ static int ReceivePGAndDistribute( MPIR_Comm *tmp_comm, MPIR_Comm *comm_ptr, MPIR_ERR_POP(mpi_errno); } mpi_errno = MPIC_Recv(pg_str, j, MPI_CHAR, 0, recvtag++, - tmp_comm, MPI_STATUS_IGNORE); + tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); *recvtag_p = recvtag; MPIR_ERR_CHECK(mpi_errno); } @@ -1083,13 +1083,13 @@ static int SendPGtoPeerAndFree( MPIR_Comm *tmp_comm, int *sendtag_p, pg_iter = pg_list; i = pg_iter->lenStr; /*printf("connect:sending 1 int: %d\n", i);fflush(stdout);*/ - mpi_errno = MPIC_Send(&i, 1, MPI_INT, 0, sendtag++, tmp_comm, 
MPIR_ERR_NONE); + mpi_errno = MPIC_Send(&i, 1, MPI_INT, 0, sendtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); *sendtag_p = sendtag; MPIR_ERR_CHECK(mpi_errno); /* printf("connect:sending string length %d\n", i);fflush(stdout); */ mpi_errno = MPIC_Send(pg_iter->str, i, MPI_CHAR, 0, sendtag++, - tmp_comm, MPIR_ERR_NONE); + tmp_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); *sendtag_p = sendtag; MPIR_ERR_CHECK(mpi_errno); @@ -1182,7 +1182,7 @@ int MPIDI_Comm_accept(const char *port_name, MPIR_Info *info, int root, /*printf("accept:sending 3 ints, %d, %d, %d, and receiving 2 ints\n", send_ints[0], send_ints[1], send_ints[2]);fflush(stdout);*/ mpi_errno = MPIC_Sendrecv(send_ints, 3, MPI_INT, 0, sendtag++, recv_ints, 3, MPI_INT, - 0, recvtag++, tmp_comm, + 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } @@ -1221,7 +1221,7 @@ int MPIDI_Comm_accept(const char *port_name, MPIR_Info *info, int root, mpi_errno = MPIC_Sendrecv(local_translation, local_comm_size * 2, MPI_INT, 0, sendtag++, remote_translation, remote_comm_size * 2, - MPI_INT, 0, recvtag++, tmp_comm, + MPI_INT, 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); @@ -1271,7 +1271,7 @@ int MPIDI_Comm_accept(const char *port_name, MPIR_Info *info, int root, MPL_DBG_MSG(MPIDI_CH3_DBG_CONNECT,VERBOSE,"sync with peer"); mpi_errno = MPIC_Sendrecv(&i, 0, MPI_INT, 0, sendtag++, &j, 0, MPI_INT, - 0, recvtag++, tmp_comm, + 0, recvtag++, tmp_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch3/src/mpid_comm_get_all_failed_procs.c b/src/mpid/ch3/src/mpid_comm_get_all_failed_procs.c index 0c1ed9377c2..588512a9d23 100644 --- a/src/mpid/ch3/src/mpid_comm_get_all_failed_procs.c +++ b/src/mpid/ch3/src/mpid_comm_get_all_failed_procs.c @@ -107,7 +107,7 @@ int MPID_Comm_get_all_failed_procs(MPIR_Comm *comm_ptr, MPIR_Group **failed_grou for (i = 1; i < 
comm_ptr->local_size; i++) { /* Get everyone's list of failed processes to aggregate */ ret_errno = MPIC_Recv(remote_bitarray, bitarray_size, MPI_INT, - i, tag, comm_ptr, MPI_STATUS_IGNORE); + i, tag, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); if (ret_errno) continue; /* Combine the received bitarray with my own */ @@ -121,7 +121,7 @@ int MPID_Comm_get_all_failed_procs(MPIR_Comm *comm_ptr, MPIR_Group **failed_grou for (i = 1; i < comm_ptr->local_size; i++) { /* Send the list to each rank to be processed locally */ ret_errno = MPIC_Send(bitarray, bitarray_size, MPI_INT, i, - tag, comm_ptr, MPIR_ERR_NONE); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); if (ret_errno) continue; } @@ -130,11 +130,11 @@ int MPID_Comm_get_all_failed_procs(MPIR_Comm *comm_ptr, MPIR_Group **failed_grou } else { /* Send my bitarray to rank 0 */ mpi_errno = MPIC_Send(bitarray, bitarray_size, MPI_INT, 0, - tag, comm_ptr, MPIR_ERR_NONE); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); /* Get the resulting bitarray back from rank 0 */ mpi_errno = MPIC_Recv(remote_bitarray, bitarray_size, MPI_INT, 0, - tag, comm_ptr, MPI_STATUS_IGNORE); + tag, comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); /* Convert the bitarray into a group */ *failed_group = bitarray_to_group(comm_ptr, remote_bitarray); diff --git a/src/mpid/ch3/src/mpid_vc.c b/src/mpid/ch3/src/mpid_vc.c index cf64f4063f3..3fd867ab269 100644 --- a/src/mpid/ch3/src/mpid_vc.c +++ b/src/mpid/ch3/src/mpid_vc.c @@ -492,7 +492,7 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int local_leader, remote_leader, cts_tag, remote_size, 1, MPI_INT, remote_leader, cts_tag, - peer_comm_ptr, MPI_STATUS_IGNORE, MPIR_ERR_NONE ); + peer_comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); MPL_DBG_MSG_FMT(MPIDI_CH3_DBG_OTHER,VERBOSE,(MPL_DBG_FDEST, "local size = %d, remote size = %d", local_size, @@ -511,7 +511,7 @@ int MPID_Intercomm_exchange_map(MPIR_Comm *local_comm_ptr, int 
local_leader, mpi_errno = MPIC_Sendrecv( local_gpids, local_size*sizeof(MPIDI_Gpid), MPI_BYTE, remote_leader, cts_tag, remote_gpids, (*remote_size)*sizeof(MPIDI_Gpid), MPI_BYTE, - remote_leader, cts_tag, peer_comm_ptr, + remote_leader, cts_tag, peer_comm_ptr, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, MPIR_ERR_NONE ); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h b/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h index f11907a4d5b..b8e9197ce73 100644 --- a/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h +++ b/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h @@ -94,12 +94,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_NB_RG_root_datacopy_completion(void *v, /* Root sends data to rank 0 */ if (rank == root) { MPIC_Isend(per_call_data->local_buf, per_call_data->count, per_call_data->datatype, - 0, per_call_data->tag, comm_ptr, &(per_call_data->sreq), MPIR_ERR_NONE); + 0, per_call_data->tag, comm_ptr, MPIR_SUBGROUP_NONE, + &(per_call_data->sreq), MPIR_ERR_NONE); *done = 1; } else if (rank == 0) { MPIC_Irecv(MPIDI_POSIX_RELEASE_GATHER_NB_IBCAST_DATA_ADDR(segment), per_call_data->count, per_call_data->datatype, per_call_data->root, - per_call_data->tag, comm_ptr, &(per_call_data->rreq)); + per_call_data->tag, comm_ptr, MPIR_SUBGROUP_NONE, + &(per_call_data->rreq)); *done = 1; } } else { diff --git a/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h b/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h index 62f67a56039..a3b7144f9f6 100644 --- a/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h +++ b/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h @@ -248,11 +248,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_NB_RG_reduce_start_sendrecv_completion( if (root != 0) { if (rank == root) { MPIC_Irecv(per_call_data->recv_buf, per_call_data->count, per_call_data->datatype, - 0, per_call_data->tag, comm_ptr, 
&(per_call_data->rreq)); + 0, per_call_data->tag, comm_ptr, MPIR_SUBGROUP_NONE, &(per_call_data->rreq)); } else if (rank == 0) { MPIC_Isend(MPIDI_POSIX_RELEASE_GATHER_NB_REDUCE_DATA_ADDR(rank, segment), per_call_data->count, per_call_data->datatype, per_call_data->root, - per_call_data->tag, comm_ptr, &(per_call_data->sreq), MPIR_ERR_NONE); + per_call_data->tag, comm_ptr, MPIR_SUBGROUP_NONE, &(per_call_data->sreq), + MPIR_ERR_NONE); } } diff --git a/src/mpid/ch4/shm/posix/release_gather/release_gather.h b/src/mpid/ch4/shm/posix/release_gather/release_gather.h index ac966cb9772..cf3ea4eaee4 100644 --- a/src/mpid/ch4/shm/posix/release_gather/release_gather.h +++ b/src/mpid/ch4/shm/posix/release_gather/release_gather.h @@ -105,8 +105,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_release_gather_release(void *local_ if (root != 0) { /* Root sends data to rank 0 */ if (rank == root) { - mpi_errno = - MPIC_Send(local_buf, count, datatype, 0, MPIR_BCAST_TAG, comm_ptr, errflag); + mpi_errno = MPIC_Send(local_buf, count, datatype, 0, MPIR_BCAST_TAG, + comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (rank == 0) { #ifdef HAVE_ERROR_CHECKING @@ -118,8 +118,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_release_gather_release(void *local_ MPI_Aint recv_bytes; mpi_errno = MPIC_Recv((char *) bcast_data_addr + 2 * MPIDU_SHM_CACHE_LINE_LEN, count, - datatype, root, MPIR_BCAST_TAG, comm_ptr, &status); - MPIR_ERR_CHECK(mpi_errno); + datatype, root, MPIR_BCAST_TAG, comm_ptr, MPIR_SUBGROUP_NONE, + &status); MPIR_Get_count_impl(&status, MPI_BYTE, &recv_bytes); MPIR_Typerep_copy(bcast_data_addr, &recv_bytes, sizeof(int), MPIR_TYPEREP_FLAG_NONE); @@ -137,7 +137,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_release_gather_release(void *local_ /* When error checking is disabled, MPI_STATUS_IGNORE is used */ mpi_errno = MPIC_Recv(bcast_data_addr, count, datatype, root, MPIR_BCAST_TAG, comm_ptr, - MPI_STATUS_IGNORE); + MPIR_SUBGROUP_NONE, 
MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); #endif } @@ -373,13 +373,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_release_gather_gather(const void *i if (rank == root) { mpi_errno = MPIC_Recv(outbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm_ptr, - MPI_STATUS_IGNORE); + MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } else if (rank == 0) { MPIR_ERR_CHKANDJUMP(!reduce_data_addr, mpi_errno, MPI_ERR_OTHER, "**nomem"); mpi_errno = MPIC_Send((void *) reduce_data_addr, count, datatype, root, MPIR_REDUCE_TAG, - comm_ptr, errflag); + comm_ptr, MPIR_SUBGROUP_NONE, errflag); MPIR_ERR_CHECK(mpi_errno); } } diff --git a/src/mpid/ch4/src/ch4_coll_impl.h b/src/mpid/ch4/src/ch4_coll_impl.h index 36dd4424bce..2e4531d4a76 100644 --- a/src/mpid/ch4/src/ch4_coll_impl.h +++ b/src/mpid/ch4/src/ch4_coll_impl.h @@ -241,7 +241,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_alpha(void *buffer, M /* root sends message to local leader (node_comm rank 0) */ if (comm->rank == root) { mpi_errno = MPIC_Send(buffer, count, datatype, 0, MPIR_BCAST_TAG, - comm->node_comm, errflag); + comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* local leader receives message from root */ @@ -249,12 +249,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_alpha(void *buffer, M #ifndef HAVE_ERROR_CHECKING mpi_errno = MPIC_Recv(buffer, count, datatype, intra_root, MPIR_BCAST_TAG, comm->node_comm, - MPI_STATUS_IGNORE); + coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); #else mpi_errno = MPIC_Recv(buffer, count, datatype, intra_root, MPIR_BCAST_TAG, comm->node_comm, - &status); + coll_group, &status); MPIR_ERR_CHECK(mpi_errno); MPIR_Datatype_get_size_macro(datatype, type_size); @@ -461,7 +461,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_delta(void *buffer, M /* root sends message to local leader (node_comm rank 0) */ if (comm->rank == root) { mpi_errno = MPIC_Send(buffer, count, datatype, 0, MPIR_BCAST_TAG, - 
comm->node_comm, errflag); + comm->node_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* local leader receives message from root */ @@ -469,12 +469,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_intra_composition_delta(void *buffer, M #ifndef HAVE_ERROR_CHECKING mpi_errno = MPIC_Recv(buffer, count, datatype, intra_root, MPIR_BCAST_TAG, comm->node_comm, - MPI_STATUS_IGNORE); + coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); #else mpi_errno = MPIC_Recv(buffer, count, datatype, intra_root, MPIR_BCAST_TAG, comm->node_comm, - &status); + coll_group, &status); MPIR_ERR_CHECK(mpi_errno); MPIR_Datatype_get_size_macro(datatype, type_size); @@ -1030,9 +1030,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_intra_composition_alpha(const void *se /* Send data to root via point-to-point message if root is not rank 0 in comm */ if (root != 0) { if (comm->rank == 0) { - MPIC_Send(recvbuf, count, datatype, root, MPIR_REDUCE_TAG, comm, errflag); + MPIC_Send(recvbuf, count, datatype, root, MPIR_REDUCE_TAG, comm, coll_group, errflag); } else if (comm->rank == root) { - MPIC_Recv(ori_recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm, MPI_STATUS_IGNORE); + MPIC_Recv(ori_recvbuf, count, datatype, 0, MPIR_REDUCE_TAG, comm, coll_group, + MPI_STATUS_IGNORE); } } @@ -1766,13 +1767,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send if (comm_ptr->node_roots_comm != NULL && comm_ptr->node_comm != NULL) { mpi_errno = MPIC_Recv(localfulldata, count, datatype, comm_ptr->node_comm->local_size - 1, MPIR_SCAN_TAG, - comm_ptr->node_comm, &status); + comm_ptr->node_comm, coll_group, &status); MPIR_ERR_CHECK(mpi_errno); } else if (comm_ptr->node_roots_comm == NULL && comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, rank) == comm_ptr->node_comm->local_size - 1) { mpi_errno = MPIC_Send(recvbuf, count, datatype, - 0, MPIR_SCAN_TAG, comm_ptr->node_comm, errflag); + 0, MPIR_SCAN_TAG, comm_ptr->node_comm, coll_group, errflag); 
MPIR_ERR_CHECK(mpi_errno); } else if (comm_ptr->node_roots_comm != NULL) { localfulldata = recvbuf; @@ -1790,13 +1791,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Scan_intra_composition_alpha(const void *send if (MPIR_Get_internode_rank(comm_ptr, rank) != comm_ptr->node_roots_comm->local_size - 1) { mpi_errno = MPIC_Send(prefulldata, count, datatype, MPIR_Get_internode_rank(comm_ptr, rank) + 1, - MPIR_SCAN_TAG, comm_ptr->node_roots_comm, errflag); + MPIR_SCAN_TAG, comm_ptr->node_roots_comm, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); } if (MPIR_Get_internode_rank(comm_ptr, rank) != 0) { mpi_errno = MPIC_Recv(tempbuf, count, datatype, MPIR_Get_internode_rank(comm_ptr, rank) - 1, - MPIR_SCAN_TAG, comm_ptr->node_roots_comm, &status); + MPIR_SCAN_TAG, comm_ptr->node_roots_comm, coll_group, &status); noneed = 0; MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpid/ch4/src/ch4_comm.c b/src/mpid/ch4/src/ch4_comm.c index 9de7a1df977..a925231244f 100644 --- a/src/mpid/ch4/src/ch4_comm.c +++ b/src/mpid/ch4/src/ch4_comm.c @@ -450,8 +450,8 @@ int MPID_Intercomm_exchange_map(MPIR_Comm * local_comm, int local_leader, MPIR_C mpi_errno = MPIC_Sendrecv(&local_size_send, 1, MPI_INT, remote_leader, cts_tag, &remote_size_recv, 1, MPI_INT, - remote_leader, cts_tag, peer_comm, MPI_STATUS_IGNORE, - MPIR_ERR_NONE); + remote_leader, cts_tag, peer_comm, MPIR_SUBGROUP_NONE, + MPI_STATUS_IGNORE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); if (remote_size_recv & MPIDI_DYNPROC_MASK) @@ -488,7 +488,8 @@ int MPID_Intercomm_exchange_map(MPIR_Comm * local_comm, int local_leader, MPIR_C remote_leader, cts_tag, remote_upid_size, *remote_size, MPI_INT, remote_leader, cts_tag, - peer_comm, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + peer_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); upid_send_size = 0; for (i = 0; i < local_size; i++) @@ -502,7 +503,8 @@ int MPID_Intercomm_exchange_map(MPIR_Comm * local_comm, int local_leader, MPIR_C remote_leader, cts_tag, 
remote_upids, upid_recv_size, MPI_BYTE, remote_leader, cts_tag, - peer_comm, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + peer_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* Stage 1.2 convert remote UPID to GPID and get GPID for local group */ @@ -513,7 +515,8 @@ int MPID_Intercomm_exchange_map(MPIR_Comm * local_comm, int local_leader, MPIR_C remote_leader, cts_tag, *remote_gpids, *remote_size, MPI_UINT64_T, remote_leader, cts_tag, - peer_comm, MPI_STATUS_IGNORE, MPIR_ERR_NONE); + peer_comm, MPIR_SUBGROUP_NONE, MPI_STATUS_IGNORE, + MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } /* Stage 1.3 check if local/remote groups are disjoint */ From ca5917db976a39a2103dc48f1dbc68fc9beaf17e Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Mon, 19 Aug 2024 14:55:16 -0500 Subject: [PATCH 11/27] ch4: fallback to mpir if coll_group is non-zero Assuming the device layer collectives are not able to handle a non-trivial coll_group, always fall back when coll_group != MPIR_SUBGROUP_NONE, for now. Also normalize the code style to use the fallback label. We should always fall back to the mpir impl routines rather than the netmod routines (composition_beta). The composition_beta routines may themselves fall back in the future when the netmod collectives become more sophisticated, resulting in an infinite loop.
--- src/mpid/ch4/src/ch4_coll.h | 399 ++++++++++++++++++++++++++---------- 1 file changed, 288 insertions(+), 111 deletions(-) diff --git a/src/mpid/ch4/src/ch4_coll.h b/src/mpid/ch4/src/ch4_coll.h index 9df2bc60a68..15678832485 100644 --- a/src/mpid/ch4/src/ch4_coll.h +++ b/src/mpid/ch4/src/ch4_coll.h @@ -114,9 +114,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_allcomm_composition_json(MPIR_Comm * cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -131,6 +129,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_allcomm_composition_json(MPIR_Comm * } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); fn_exit: return mpi_errno; @@ -144,6 +146,9 @@ MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm * comm, int coll_group, MPIR MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } switch (MPIR_CVAR_BARRIER_COMPOSITION) { case 1: @@ -169,10 +174,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Barrier(MPIR_Comm * comm, int coll_group, MPIR goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); - else - mpi_errno = MPIDI_Barrier_intra_composition_beta(comm, coll_group, errflag); + mpi_errno = MPIR_Barrier_impl(comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -211,9 +213,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_allcomm_composition_json(void *buffer, } } if (cnt == NULL) { - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -242,6 +242,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_allcomm_composition_json(void *buffer, } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = 
MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); fn_exit: return mpi_errno; @@ -257,6 +261,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + switch (MPIR_CVAR_BCAST_COMPOSITION) { case 1: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, @@ -307,12 +315,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); - else - mpi_errno = - MPIDI_Bcast_intra_composition_gamma(buffer, count, datatype, root, comm, coll_group, - errflag); + mpi_errno = MPIR_Bcast_impl(buffer, count, datatype, root, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -378,10 +381,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -428,6 +428,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Allreduce_impl(sendbuf, recvbuf, count, datatype, op, + comm, coll_group, errflag); fn_exit: return mpi_errno; @@ -445,6 +450,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + is_commutative = MPIR_Op_is_commutative(op); switch (MPIR_CVAR_ALLREDUCE_COMPOSITION) { @@ -596,11 +605,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - 
mpi_errno = - MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -634,14 +639,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = - MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - coll_group, errflag); - else - mpi_errno = - MPIDI_Allgather_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, coll_group, errflag); + mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, + comm, coll_group, errflag); fn_exit: return mpi_errno; @@ -660,6 +659,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *sendbuf, MPI_Aint sendco MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + if (sendbuf != MPI_IN_PLACE) { MPIR_Datatype_get_size_macro(sendtype, type_size); data_size = sendcount * type_size; @@ -707,14 +710,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgather(const void *sendbuf, MPI_Aint sendco goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = - MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - coll_group, errflag); - else - mpi_errno = - MPIDI_Allgather_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, recvcount, - recvtype, comm, coll_group, errflag); + mpi_errno = MPIR_Allgather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_ENTER; @@ -732,6 +729,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { 
.coll_type = MPIR_CSEL_COLL_TYPE__ALLGATHERV, .comm_ptr = comm, @@ -750,11 +751,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, - recvtype, comm, coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -769,7 +766,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + fallback: + mpi_errno = MPIR_Allgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, + displs, recvtype, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; return mpi_errno; @@ -785,6 +786,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *sendbuf, MPI_Aint sendcoun int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__SCATTER, .comm_ptr = comm, @@ -803,10 +808,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *sendbuf, MPI_Aint sendcoun cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, - coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -821,6 +823,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *sendbuf, MPI_Aint sendcoun } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Scatter_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -838,6 +845,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = 
NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__SCATTERV, .comm_ptr = comm, @@ -857,10 +868,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, - root, comm, coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -875,6 +883,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Scatterv_impl(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, + recvtype, root, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -891,6 +904,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gather(const void *sendbuf, MPI_Aint sendcount int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__GATHER, .comm_ptr = comm, @@ -909,11 +926,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gather(const void *sendbuf, MPI_Aint sendcount cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, - coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -928,6 +941,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gather(const void *sendbuf, MPI_Aint sendcount } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Gather_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -945,6 +963,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void 
*sendbuf, MPI_Aint sendcoun int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__GATHERV, .comm_ptr = comm, @@ -964,10 +986,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcoun cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, - displs, recvtype, root, comm, coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -982,6 +1001,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcoun } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Gatherv_impl(sendbuf, sendcount, sendtype, recvbuf, recvcounts, + displs, recvtype, root, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1056,10 +1080,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -1094,14 +1115,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = - MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - coll_group, errflag); - else - mpi_errno = MPIDI_Alltoall_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, coll_group, - errflag); + mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, + comm, coll_group, errflag); fn_exit: return mpi_errno; @@ -1119,6 
+1134,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcou MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + if (sendbuf != MPI_IN_PLACE) { MPIR_Datatype_get_size_macro(sendtype, type_size); data_size = sendcount * type_size; @@ -1166,14 +1185,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoall(const void *sendbuf, MPI_Aint sendcou goto fn_exit; fallback: - if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) - mpi_errno = - MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, - coll_group, errflag); - else - mpi_errno = MPIDI_Alltoall_intra_composition_beta(sendbuf, sendcount, sendtype, recvbuf, - recvcount, recvtype, comm, coll_group, - errflag); + mpi_errno = MPIR_Alltoall_impl(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_ENTER; @@ -1192,6 +1205,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *sendbuf, const MPI_Aint int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLTOALLV, .comm_ptr = comm, @@ -1211,11 +1228,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *sendbuf, const MPI_Aint cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, - sendtype, recvbuf, recvcounts, - rdispls, recvtype, comm, coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -1231,6 +1244,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *sendbuf, const MPI_Aint } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Alltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, + rdispls, recvtype, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1249,6 +1267,10 
@@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *sendbuf, const MPI_Aint int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLTOALLW, .comm_ptr = comm, @@ -1268,11 +1290,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *sendbuf, const MPI_Aint cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, - sendtypes, recvbuf, recvcounts, - rdispls, recvtypes, comm, coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -1288,6 +1306,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *sendbuf, const MPI_Aint } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Alltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, + rdispls, recvtypes, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1365,6 +1388,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + switch (MPIR_CVAR_REDUCE_COMPOSITION) { case 1: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM @@ -1429,6 +1456,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *sendbuf, void *recv int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__REDUCE_SCATTER, .comm_ptr = comm, @@ -1445,11 +1476,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *sendbuf, void *recv cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, 
comm, coll_group, - errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -1464,6 +1491,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *sendbuf, void *recv } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Reduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1480,6 +1512,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *sendbuf, void int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__REDUCE_SCATTER_BLOCK, .comm_ptr = comm, @@ -1496,11 +1532,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *sendbuf, void cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, comm, - coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -1515,6 +1547,11 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *sendbuf, void } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Reduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, + comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1530,6 +1567,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_A int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__SCAN, .comm_ptr = comm, @@ -1546,10 +1587,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_A cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - 
MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -1568,6 +1606,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_A } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Scan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1583,6 +1625,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Exscan(const void *sendbuf, void *recvbuf, MPI int mpi_errno = MPI_SUCCESS; const MPIDI_Csel_container_s *cnt = NULL; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__EXSCAN, .comm_ptr = comm, @@ -1599,10 +1645,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Exscan(const void *sendbuf, void *recvbuf, MPI cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); if (cnt == NULL) { - mpi_errno = - MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag);; - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; + goto fallback; } switch (cnt->id) { @@ -1616,6 +1659,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Exscan(const void *sendbuf, void *recvbuf, MPI } MPIR_ERR_CHECK(mpi_errno); + goto fn_exit; + + fallback: + mpi_errno = MPIR_Exscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, errflag); fn_exit: MPIR_FUNC_EXIT; @@ -1810,10 +1857,17 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ibarrier(MPIR_Comm * comm, int coll_group, MPI MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_ibarrier(comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ibarrier_impl(comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ibcast(void *buffer, MPI_Aint count, MPI_Datatype datatype, @@ -1824,10 +1878,17 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ibcast(void *buffer, MPI_Aint count, MPI_Datat MPIR_FUNC_ENTER; + if 
(coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_ibcast(buffer, count, datatype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ibcast_impl(buffer, count, datatype, root, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iallgather(const void *sendbuf, MPI_Aint sendcount, @@ -1839,11 +1900,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iallgather(const void *sendbuf, MPI_Aint sendc MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_iallgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iallgather_impl(sendbuf, sendcount, sendtype, recvbuf, + recvcount, recvtype, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iallgatherv(const void *sendbuf, MPI_Aint sendcount, @@ -1856,11 +1925,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iallgatherv(const void *sendbuf, MPI_Aint send MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_iallgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iallgatherv_impl(sendbuf, sendcount, sendtype, recvbuf, + recvcounts, displs, recvtype, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iallreduce(const void *sendbuf, void *recvbuf, MPI_Aint count, @@ -1871,10 +1948,17 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iallreduce(const void *sendbuf, void *recvbuf, MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_iallreduce(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iallreduce_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ialltoall(const void *sendbuf, MPI_Aint sendcount, @@ -1886,11 
+1970,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ialltoall(const void *sendbuf, MPI_Aint sendco MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_ialltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ialltoall_impl(sendbuf, sendcount, sendtype, recvbuf, + recvcount, recvtype, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallv(const void *sendbuf, const MPI_Aint * sendcounts, @@ -1903,11 +1995,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallv(const void *sendbuf, const MPI_Aint MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_ialltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ialltoallv_impl(sendbuf, sendcounts, sdispls, sendtype, + recvbuf, recvcounts, rdispls, recvtype, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallw(const void *sendbuf, const MPI_Aint * sendcounts, @@ -1921,11 +2021,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ialltoallw(const void *sendbuf, const MPI_Aint MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_ialltoallw(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ialltoallw_impl(sendbuf, sendcounts, sdispls, sendtypes, + recvbuf, recvcounts, rdispls, recvtypes, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iexscan(const void *sendbuf, void *recvbuf, MPI_Aint count, @@ -1936,10 +2044,17 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iexscan(const void *sendbuf, void *recvbuf, MP MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_iexscan(sendbuf, recvbuf, count, datatype, op, comm,
coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Igather(const void *sendbuf, MPI_Aint sendcount, @@ -1951,11 +2066,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Igather(const void *sendbuf, MPI_Aint sendcoun MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_igather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Igather_impl(sendbuf, sendcount, sendtype, recvbuf, + recvcount, recvtype, root, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Igatherv(const void *sendbuf, MPI_Aint sendcount, @@ -1968,11 +2091,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Igatherv(const void *sendbuf, MPI_Aint sendcou MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Igatherv_impl(sendbuf, sendcount, sendtype, recvbuf, + recvcounts, displs, recvtype, root, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter_block(const void *sendbuf, void *recvbuf, @@ -1984,12 +2115,20 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter_block(const void *sendbuf, voi MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ireduce_scatter_block_impl(sendbuf, recvbuf, recvcount, datatype, op, + comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter(const void *sendbuf, void *recvbuf, @@ -2002,12 +2141,20 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ireduce_scatter(const void *sendbuf, void *rec
MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ireduce_scatter_impl(sendbuf, recvbuf, recvcounts, datatype, op, + comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Ireduce(const void *sendbuf, void *recvbuf, MPI_Aint count, @@ -2018,10 +2165,17 @@ MPL_STATIC_INLINE_PREFIX int MPID_Ireduce(const void *sendbuf, void *recvbuf, MP MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Ireduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iscan(const void *sendbuf, void *recvbuf, MPI_Aint count, @@ -2032,10 +2186,17 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iscan(const void *sendbuf, void *recvbuf, MPI_ MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_iscan(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iscan_impl(sendbuf, recvbuf, count, datatype, op, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iscatter(const void *sendbuf, MPI_Aint sendcount, @@ -2047,11 +2208,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iscatter(const void *sendbuf, MPI_Aint sendcou MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_iscatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iscatter_impl(sendbuf, sendcount, sendtype, recvbuf, + recvcount, recvtype, root, comm, coll_group, req); } MPL_STATIC_INLINE_PREFIX int MPID_Iscatterv(const void *sendbuf, const MPI_Aint * sendcounts, @@ 
-2064,11 +2233,19 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iscatterv(const void *sendbuf, const MPI_Aint MPIR_FUNC_ENTER; + if (coll_group != MPIR_SUBGROUP_NONE) { + goto fallback; + } + ret = MPIDI_NM_mpi_iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, coll_group, req); MPIR_FUNC_EXIT; return ret; + + fallback: + return MPIR_Iscatterv_impl(sendbuf, sendcounts, displs, sendtype, + recvbuf, recvcount, recvtype, root, comm, coll_group, req); } #endif /* CH4_COLL_H_INCLUDED */ From a683ec6fe8c553baab9b9e98b18231fd1562c8cc Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 18 Aug 2024 10:50:50 -0500 Subject: [PATCH 12/27] coll: add coll_group to csel signature Make csel coll_group aware. --- maint/gen_coll.py | 4 ++++ src/include/mpir_csel.h | 1 + src/mpi/coll/allreduce/allreduce_intra_tree.c | 1 + src/mpi/coll/bcast/bcast_intra_tree.c | 1 + src/mpi/coll/iallreduce/iallreduce_tsp_auto.c | 1 + src/mpi/coll/ibarrier/ibarrier_tsp_auto.c | 1 + src/mpi/coll/ireduce/ireduce_tsp_auto.c | 1 + src/mpi/coll/ireduce/ireduce_tsp_tree.c | 1 + src/mpid/ch4/shm/posix/posix_coll.h | 1 + src/mpid/ch4/src/ch4_coll.h | 17 +++++++++++++++++ 10 files changed, 29 insertions(+) diff --git a/maint/gen_coll.py b/maint/gen_coll.py index 6327480f41f..1c15408af10 100644 --- a/maint/gen_coll.py +++ b/maint/gen_coll.py @@ -94,6 +94,8 @@ def dump_allcomm_auto_blocking(name): dump_open("MPIR_Csel_coll_sig_s coll_sig = {") G.out.append(".coll_type = MPIR_CSEL_COLL_TYPE__%s," % NAME) G.out.append(".comm_ptr = comm_ptr,") + if not re.match(r'i?neighbor_', func_name, re.IGNORECASE): + G.out.append(".coll_group = coll_group,") for p in func['parameters']: if not re.match(r'comm$', p['name']): G.out.append(".u.%s.%s = %s," % (func_name, p['name'], p['name'])) @@ -169,6 +171,8 @@ def dump_allcomm_sched_auto(name): dump_open("MPIR_Csel_coll_sig_s coll_sig = {") G.out.append(".coll_type = MPIR_CSEL_COLL_TYPE__%s," % NAME) G.out.append(".comm_ptr = comm_ptr,") + if not 
re.match(r'i?neighbor_', func_name, re.IGNORECASE): + G.out.append(".coll_group = coll_group,") for p in func['parameters']: if not re.match(r'comm$', p['name']): G.out.append(".u.%s.%s = %s," % (func_name, p['name'], p['name'])) diff --git a/src/include/mpir_csel.h b/src/include/mpir_csel.h index 07f061e98ef..c25193bb2fe 100644 --- a/src/include/mpir_csel.h +++ b/src/include/mpir_csel.h @@ -60,6 +60,7 @@ typedef enum { typedef struct { MPIR_Csel_coll_type_e coll_type; MPIR_Comm *comm_ptr; + int coll_group; union { struct { diff --git a/src/mpi/coll/allreduce/allreduce_intra_tree.c b/src/mpi/coll/allreduce/allreduce_intra_tree.c index 91293b75904..73cc8851911 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_tree.c +++ b/src/mpi/coll/allreduce/allreduce_intra_tree.c @@ -73,6 +73,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLREDUCE, .comm_ptr = comm_ptr, + .coll_group = coll_group, .u.allreduce.sendbuf = sendbuf, .u.allreduce.recvbuf = recvbuf, .u.allreduce.count = count, diff --git a/src/mpi/coll/bcast/bcast_intra_tree.c b/src/mpi/coll/bcast/bcast_intra_tree.c index c7a27a02db2..f9abecdf6ea 100644 --- a/src/mpi/coll/bcast/bcast_intra_tree.c +++ b/src/mpi/coll/bcast/bcast_intra_tree.c @@ -82,6 +82,7 @@ int MPIR_Bcast_intra_tree(void *buffer, MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__BCAST, .comm_ptr = comm_ptr, + .coll_group = coll_group, .u.bcast.buffer = buffer, .u.bcast.count = count, .u.bcast.datatype = datatype, diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c b/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c index 54129fdeaca..47de5324fe6 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_auto.c @@ -22,6 +22,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__IALLREDUCE, .comm_ptr = comm, + 
.coll_group = coll_group, .u.iallreduce.sendbuf = sendbuf, .u.iallreduce.recvbuf = recvbuf, diff --git a/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c b/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c index bcb3a295d3a..d9a685a2c35 100644 --- a/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c +++ b/src/mpi/coll/ibarrier/ibarrier_tsp_auto.c @@ -13,6 +13,7 @@ int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, int coll_group, MPI MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__IBARRIER, .comm_ptr = comm, + .coll_group = coll_group, }; MPII_Csel_container_s *cnt; void *recvbuf = NULL; diff --git a/src/mpi/coll/ireduce/ireduce_tsp_auto.c b/src/mpi/coll/ireduce/ireduce_tsp_auto.c index 3d558be3dab..13d7aae7181 100644 --- a/src/mpi/coll/ireduce/ireduce_tsp_auto.c +++ b/src/mpi/coll/ireduce/ireduce_tsp_auto.c @@ -48,6 +48,7 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__IREDUCE, .comm_ptr = comm_ptr, + .coll_group = coll_group, .u.ireduce.sendbuf = sendbuf, .u.ireduce.recvbuf = recvbuf, diff --git a/src/mpi/coll/ireduce/ireduce_tsp_tree.c b/src/mpi/coll/ireduce/ireduce_tsp_tree.c index f92c54ddfb7..47f9c0abd19 100644 --- a/src/mpi/coll/ireduce/ireduce_tsp_tree.c +++ b/src/mpi/coll/ireduce/ireduce_tsp_tree.c @@ -76,6 +76,7 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__IREDUCE, .comm_ptr = comm, + .coll_group = coll_group, .u.ireduce.sendbuf = sendbuf, .u.ireduce.recvbuf = recvbuf, .u.ireduce.count = count, diff --git a/src/mpid/ch4/shm/posix/posix_coll.h b/src/mpid/ch4/shm/posix/posix_coll.h index 219508f1c08..ec681802fc0 100644 --- a/src/mpid/ch4/shm/posix/posix_coll.h +++ b/src/mpid/ch4/shm/posix/posix_coll.h @@ -155,6 +155,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_barrier(MPIR_Comm * comm, int coll_ MPIR_Csel_coll_sig_s coll_sig = { .coll_type 
= MPIR_CSEL_COLL_TYPE__BARRIER, .comm_ptr = comm, + .coll_group = coll_group, }; MPIDI_POSIX_csel_container_s *cnt; diff --git a/src/mpid/ch4/src/ch4_coll.h b/src/mpid/ch4/src/ch4_coll.h index 15678832485..68d017a7e0d 100644 --- a/src/mpid/ch4/src/ch4_coll.h +++ b/src/mpid/ch4/src/ch4_coll.h @@ -109,6 +109,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Barrier_allcomm_composition_json(MPIR_Comm * MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__BARRIER, .comm_ptr = comm, + .coll_group = coll_group, }; cnt = MPIR_Csel_search(MPIDI_COMM(comm, csel_comm), coll_sig); @@ -192,6 +193,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Bcast_allcomm_composition_json(void *buffer, MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__BCAST, .comm_ptr = comm, + .coll_group = coll_group, .u.bcast.buffer = buffer, .u.bcast.count = count, .u.bcast.datatype = datatype, @@ -370,6 +372,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allreduce_allcomm_composition_json(const void MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLREDUCE, .comm_ptr = comm, + .coll_group = coll_group, .u.allreduce.sendbuf = sendbuf, .u.allreduce.recvbuf = recvbuf, @@ -593,6 +596,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Allgather_allcomm_composition_json(const void MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLGATHER, .comm_ptr = comm, + .coll_group = coll_group, .u.allgather.sendbuf = sendbuf, .u.allgather.sendcount = sendcount, @@ -736,6 +740,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allgatherv(const void *sendbuf, MPI_Aint sendc MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLGATHERV, .comm_ptr = comm, + .coll_group = coll_group, .u.allgatherv.sendbuf = sendbuf, .u.allgatherv.sendcount = sendcount, @@ -793,6 +798,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatter(const void *sendbuf, MPI_Aint sendcoun MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__SCATTER, .comm_ptr = comm, + .coll_group = coll_group, 
.u.scatter.sendbuf = sendbuf, .u.scatter.sendcount = sendcount, @@ -852,6 +858,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scatterv(const void *sendbuf, const MPI_Aint * MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__SCATTERV, .comm_ptr = comm, + .coll_group = coll_group, .u.scatterv.sendbuf = sendbuf, .u.scatterv.sendcounts = sendcounts, @@ -911,6 +918,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gather(const void *sendbuf, MPI_Aint sendcount MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__GATHER, .comm_ptr = comm, + .coll_group = coll_group, .u.gather.sendbuf = sendbuf, .u.gather.sendcount = sendcount, @@ -970,6 +978,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Gatherv(const void *sendbuf, MPI_Aint sendcoun MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__GATHERV, .comm_ptr = comm, + .coll_group = coll_group, .u.gatherv.sendbuf = sendbuf, .u.gatherv.sendcount = sendcount, @@ -1068,6 +1077,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Alltoall_allcomm_composition_json(const void MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLTOALL, .comm_ptr = comm, + .coll_group = coll_group, .u.alltoall.sendbuf = sendbuf, .u.alltoall.sendcount = sendcount, @@ -1212,6 +1222,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallv(const void *sendbuf, const MPI_Aint MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLTOALLV, .comm_ptr = comm, + .coll_group = coll_group, .u.alltoallv.sendbuf = sendbuf, .u.alltoallv.sendcounts = sendcounts, @@ -1274,6 +1285,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Alltoallw(const void *sendbuf, const MPI_Aint MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__ALLTOALLW, .comm_ptr = comm, + .coll_group = coll_group, .u.alltoallw.sendbuf = sendbuf, .u.alltoallw.sendcounts = sendcounts, @@ -1332,6 +1344,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_Reduce_allcomm_composition_json(const void *s MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__REDUCE, 
.comm_ptr = comm, + .coll_group = coll_group, .u.reduce.sendbuf = sendbuf, .u.reduce.recvbuf = recvbuf, @@ -1463,6 +1476,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter(const void *sendbuf, void *recv MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__REDUCE_SCATTER, .comm_ptr = comm, + .coll_group = coll_group, .u.reduce_scatter.sendbuf = sendbuf, .u.reduce_scatter.recvbuf = recvbuf, @@ -1519,6 +1533,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce_scatter_block(const void *sendbuf, void MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__REDUCE_SCATTER_BLOCK, .comm_ptr = comm, + .coll_group = coll_group, .u.reduce_scatter_block.sendbuf = sendbuf, .u.reduce_scatter_block.recvbuf = recvbuf, @@ -1574,6 +1589,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Scan(const void *sendbuf, void *recvbuf, MPI_A MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__SCAN, .comm_ptr = comm, + .coll_group = coll_group, .u.scan.sendbuf = sendbuf, .u.scan.recvbuf = recvbuf, @@ -1632,6 +1648,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Exscan(const void *sendbuf, void *recvbuf, MPI MPIR_Csel_coll_sig_s coll_sig = { .coll_type = MPIR_CSEL_COLL_TYPE__EXSCAN, .comm_ptr = comm, + .coll_group = coll_group, .u.exscan.sendbuf = sendbuf, .u.exscan.recvbuf = recvbuf, From 287aeb4dd0cc672225a10ed3197fffad7ba27ee1 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 17 Aug 2024 23:20:59 -0500 Subject: [PATCH 13/27] coll: threadcomm coll to use MPIR_SUBGROUP_THREADCOMM Use coll_group=MPIR_SUBGROUP_THREADCOMM for threadcomm collectives. This allows compositional collectives under threadcomm. 
--- src/include/mpir_threadcomm.h | 22 ---------- .../coll/allgather/allgather_intra_brucks.c | 2 +- .../coll/allgatherv/allgatherv_intra_brucks.c | 2 +- .../allreduce_intra_recursive_doubling.c | 2 +- src/mpi/coll/alltoall/alltoall_intra_brucks.c | 2 +- .../alltoallv/alltoallv_intra_scattered.c | 2 +- .../alltoallw/alltoallw_intra_scattered.c | 2 +- .../barrier/barrier_intra_k_dissemination.c | 2 +- src/mpi/coll/bcast/bcast_intra_binomial.c | 2 +- .../exscan/exscan_intra_recursive_doubling.c | 2 +- src/mpi/coll/gather/gather_intra_binomial.c | 2 +- src/mpi/coll/gatherv/gatherv_allcomm_linear.c | 2 +- src/mpi/coll/reduce/reduce_intra_binomial.c | 2 +- .../reduce_scatter_intra_recursive_halving.c | 2 +- ...ce_scatter_block_intra_recursive_halving.c | 2 +- .../coll/scan/scan_intra_recursive_doubling.c | 2 +- src/mpi/coll/scatter/scatter_intra_binomial.c | 2 +- .../coll/scatterv/scatterv_allcomm_linear.c | 2 +- src/mpi/threadcomm/threadcomm_coll_impl.c | 40 ++++++++++--------- 19 files changed, 38 insertions(+), 58 deletions(-) diff --git a/src/include/mpir_threadcomm.h b/src/include/mpir_threadcomm.h index cda298f1f9e..90aefcf0a9b 100644 --- a/src/include/mpir_threadcomm.h +++ b/src/include/mpir_threadcomm.h @@ -110,28 +110,6 @@ MPL_STATIC_INLINE_PREFIX #endif /* ENABLE_THREADCOMM */ } -#ifdef ENABLE_THREADCOMM -#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \ - MPIR_Threadcomm *threadcomm = (comm)->threadcomm; \ - if (threadcomm) { \ - int intracomm_size = (comm)->local_size; \ - size_ = threadcomm->rank_offset_table[intracomm_size - 1]; \ - rank_ = MPIR_THREADCOMM_TID_TO_RANK(threadcomm, MPIR_threadcomm_get_tid(threadcomm)); \ - } else { \ - rank_ = (comm)->rank; \ - size_ = (comm)->local_size; \ - } \ - } while (0) - -#else -#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \ - MPIR_Assert((comm)->threadcomm == NULL); \ - rank_ = (comm)->rank; \ - size_ = (comm)->local_size; \ - } while (0) - -#endif - #ifdef ENABLE_THREADCOMM typedef struct 
MPIR_threadcomm_tls_t { MPIR_Threadcomm *threadcomm; diff --git a/src/mpi/coll/allgather/allgather_intra_brucks.c b/src/mpi/coll/allgather/allgather_intra_brucks.c index ac2b8a01dce..21d7b1655db 100644 --- a/src/mpi/coll/allgather/allgather_intra_brucks.c +++ b/src/mpi/coll/allgather/allgather_intra_brucks.c @@ -34,7 +34,7 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf, if (((sendcount == 0) && (sendbuf != MPI_IN_PLACE)) || (recvcount == 0)) goto fn_exit; - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); MPIR_Datatype_get_size_macro(recvtype, recvtype_sz); diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c index 780ae714f25..b838fb4eb2f 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_brucks.c @@ -34,7 +34,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf, void *tmp_buf; MPIR_CHKLMEM_DECL(1); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); total_count = 0; for (i = 0; i < comm_size; i++) diff --git a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c index 18486831078..85f555466a9 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c @@ -32,7 +32,7 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf, MPI_Aint true_extent, true_lb, extent; void *tmp_buf; - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); diff --git a/src/mpi/coll/alltoall/alltoall_intra_brucks.c b/src/mpi/coll/alltoall/alltoall_intra_brucks.c index 41b98808dfa..ecbaef9990c 100644 --- 
a/src/mpi/coll/alltoall/alltoall_intra_brucks.c +++ b/src/mpi/coll/alltoall/alltoall_intra_brucks.c @@ -37,7 +37,7 @@ int MPIR_Alltoall_intra_brucks(const void *sendbuf, void *tmp_buf; MPIR_CHKLMEM_DECL(6); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(sendbuf != MPI_IN_PLACE); diff --git a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c index c33ec8e4816..2c0a34728fe 100644 --- a/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c +++ b/src/mpi/coll/alltoallv/alltoallv_intra_scattered.c @@ -37,7 +37,7 @@ int MPIR_Alltoallv_intra_scattered(const void *sendbuf, const MPI_Aint * sendcou MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Get extent of recv type, but send type is only valid if (sendbuf!=MPI_IN_PLACE) */ MPIR_Datatype_get_extent_macro(recvtype, recv_extent); diff --git a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c index cc6d6960063..936fcf9e040 100644 --- a/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c +++ b/src/mpi/coll/alltoallw/alltoallw_intra_scattered.c @@ -35,7 +35,7 @@ int MPIR_Alltoallw_intra_scattered(const void *sendbuf, const MPI_Aint sendcount MPI_Aint type_size; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING /* When MPI_IN_PLACE, we use pair-wise sendrecv_replace in order to conserve memory usage, diff --git a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c index 6ae5ba6c00f..3d1f884ebf7 100644 --- a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c +++ b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c @@ -20,7 +20,7 @@ int 
MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_group, MPIR_ { int size, rank, src, dst, mask, mpi_errno = MPI_SUCCESS; - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, size); mask = 0x1; while (mask < size) { diff --git a/src/mpi/coll/bcast/bcast_intra_binomial.c b/src/mpi/coll/bcast/bcast_intra_binomial.c index e291927703f..e154ce259c6 100644 --- a/src/mpi/coll/bcast/bcast_intra_binomial.c +++ b/src/mpi/coll/bcast/bcast_intra_binomial.c @@ -33,7 +33,7 @@ int MPIR_Bcast_intra_binomial(void *buffer, void *tmp_buf = NULL; MPIR_CHKLMEM_DECL(1); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); if (HANDLE_IS_BUILTIN(datatype)) is_contig = 1; diff --git a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c index 62dbb6502c4..5f57f2a91b6 100644 --- a/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c +++ b/src/mpi/coll/exscan/exscan_intra_recursive_doubling.c @@ -59,7 +59,7 @@ int MPIR_Exscan_intra_recursive_doubling(const void *sendbuf, void *partial_scan, *tmp_buf; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); diff --git a/src/mpi/coll/gather/gather_intra_binomial.c b/src/mpi/coll/gather/gather_intra_binomial.c index 5dce21e04df..876b4529e64 100644 --- a/src/mpi/coll/gather/gather_intra_binomial.c +++ b/src/mpi/coll/gather/gather_intra_binomial.c @@ -57,7 +57,7 @@ int MPIR_Gather_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Data MPIR_CHKLMEM_DECL(1); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Use binomial tree algorithm.
*/ diff --git a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c index 906f1ed6bb2..123be4ed548 100644 --- a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c +++ b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c @@ -33,7 +33,7 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf, MPI_Status *starray; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* If rank == root, then I recv lots, otherwise I send */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || diff --git a/src/mpi/coll/reduce/reduce_intra_binomial.c b/src/mpi/coll/reduce/reduce_intra_binomial.c index 48748373a54..93f6e233199 100644 --- a/src/mpi/coll/reduce/reduce_intra_binomial.c +++ b/src/mpi/coll/reduce/reduce_intra_binomial.c @@ -24,7 +24,7 @@ int MPIR_Reduce_intra_binomial(const void *sendbuf, void *tmp_buf; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Create a temporary buffer */ diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c index 578e73e7448..4af12cc5135 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_halving.c @@ -50,7 +50,7 @@ int MPIR_Reduce_scatter_intra_recursive_halving(const void *sendbuf, void *recvb int pof2, old_i, newrank; MPIR_CHKLMEM_DECL(5); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING { diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c index eedae204057..85b10866140 100644 --- 
a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_halving.c @@ -53,7 +53,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_halving(const void *sendbuf, int pof2, old_i, newrank; MPIR_CHKLMEM_DECL(5); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING { diff --git a/src/mpi/coll/scan/scan_intra_recursive_doubling.c b/src/mpi/coll/scan/scan_intra_recursive_doubling.c index a89650b9083..31109b14d89 100644 --- a/src/mpi/coll/scan/scan_intra_recursive_doubling.c +++ b/src/mpi/coll/scan/scan_intra_recursive_doubling.c @@ -55,7 +55,7 @@ int MPIR_Scan_intra_recursive_doubling(const void *sendbuf, void *partial_scan, *tmp_buf; MPIR_CHKLMEM_DECL(2); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); diff --git a/src/mpi/coll/scatter/scatter_intra_binomial.c b/src/mpi/coll/scatter/scatter_intra_binomial.c index e662cee1570..1db7390e44a 100644 --- a/src/mpi/coll/scatter/scatter_intra_binomial.c +++ b/src/mpi/coll/scatter/scatter_intra_binomial.c @@ -41,7 +41,7 @@ int MPIR_Scatter_intra_binomial(const void *sendbuf, MPI_Aint sendcount, MPI_Dat int mpi_errno = MPI_SUCCESS; MPIR_CHKLMEM_DECL(4); - MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); if (rank == root) MPIR_Datatype_get_extent_macro(sendtype, extent); diff --git a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c index 9c0e5cc2bd4..da5be5a0b1f 100644 --- a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c +++ b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c @@ -29,7 +29,7 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcount MPI_Status *starray; MPIR_CHKLMEM_DECL(2); 
- MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* If I'm the root, then scatter */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || diff --git a/src/mpi/threadcomm/threadcomm_coll_impl.c b/src/mpi/threadcomm/threadcomm_coll_impl.c index fde81702361..529b415d17d 100644 --- a/src/mpi/threadcomm/threadcomm_coll_impl.c +++ b/src/mpi/threadcomm/threadcomm_coll_impl.c @@ -34,7 +34,7 @@ int MPIR_Threadcomm_barrier_impl(MPIR_Comm * comm) if (comm->local_size == 1) { thread_barrier(comm->threadcomm); } else { - mpi_errno = MPIR_Barrier_intra_dissemination(comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + mpi_errno = MPIR_Barrier_intra_dissemination(comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); } return mpi_errno; @@ -46,7 +46,7 @@ int MPIR_Threadcomm_bcast_impl(void *buffer, MPI_Aint count, MPI_Datatype dataty int mpi_errno = MPI_SUCCESS; mpi_errno = - MPIR_Bcast_intra_binomial(buffer, count, datatype, root, comm, MPIR_SUBGROUP_NONE, + MPIR_Bcast_intra_binomial(buffer, count, datatype, root, comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; @@ -60,7 +60,7 @@ int MPIR_Threadcomm_gather_impl(const void *sendbuf, MPI_Aint sendcount, MPI_Dat mpi_errno = MPIR_Gather_intra_binomial(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, - MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -74,7 +74,7 @@ int MPIR_Threadcomm_gatherv_impl(const void *sendbuf, MPI_Aint sendcount, MPI_Da mpi_errno = MPIR_Gatherv_allcomm_linear(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, - comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -87,7 +87,7 @@ int MPIR_Threadcomm_scatter_impl(const void *sendbuf, MPI_Aint sendcount, MPI_Da mpi_errno = MPIR_Scatter_intra_binomial(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, 
root, - comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -101,7 +101,7 @@ int MPIR_Threadcomm_scatterv_impl(const void *sendbuf, const MPI_Aint * sendcoun mpi_errno = MPIR_Scatterv_allcomm_linear(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, - comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -113,8 +113,8 @@ int MPIR_Threadcomm_allgather_impl(const void *sendbuf, MPI_Aint sendcount, MPI_ int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Allgather_intra_brucks(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, MPIR_SUBGROUP_NONE, - MPIR_ERR_NONE); + recvbuf, recvcount, recvtype, comm, + MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -128,7 +128,7 @@ int MPIR_Threadcomm_allgatherv_impl(const void *sendbuf, MPI_Aint sendcount, MPI mpi_errno = MPIR_Allgatherv_intra_brucks(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, - comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -141,8 +141,8 @@ int MPIR_Threadcomm_alltoall_impl(const void *sendbuf, MPI_Aint sendcount, MPI_D MPIR_Assert(sendbuf != MPI_IN_PLACE); mpi_errno = MPIR_Alltoall_intra_brucks(sendbuf, sendcount, sendtype, - recvbuf, recvcount, recvtype, comm, MPIR_SUBGROUP_NONE, - MPIR_ERR_NONE); + recvbuf, recvcount, recvtype, comm, + MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -158,7 +158,7 @@ int MPIR_Threadcomm_alltoallv_impl(const void *sendbuf, const MPI_Aint * sendcou MPIR_Assert(sendbuf != MPI_IN_PLACE); mpi_errno = MPIR_Alltoallv_intra_scattered(sendbuf, sendcounts, sdispls, sendtype, recvbuf, recvcounts, rdispls, recvtype, - comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -174,7 +174,7 @@ int MPIR_Threadcomm_alltoallw_impl(const void *sendbuf, const 
MPI_Aint * sendcou MPIR_Assert(sendbuf != MPI_IN_PLACE); mpi_errno = MPIR_Alltoallw_intra_scattered(sendbuf, sendcounts, sdispls, sendtypes, recvbuf, recvcounts, rdispls, recvtypes, - comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -186,7 +186,8 @@ int MPIR_Threadcomm_allreduce_impl(const void *sendbuf, void *recvbuf, int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Allreduce_intra_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, + MPIR_ERR_NONE); return mpi_errno; } @@ -198,7 +199,7 @@ int MPIR_Threadcomm_reduce_impl(const void *sendbuf, void *recvbuf, int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Reduce_intra_binomial(sendbuf, recvbuf, count, datatype, op, root, - comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -211,7 +212,8 @@ int MPIR_Threadcomm_reduce_scatter_impl(const void *sendbuf, void *recvbuf, MPIR_Assert(MPIR_Op_is_commutative(op)); mpi_errno = MPIR_Reduce_scatter_intra_recursive_halving(sendbuf, recvbuf, recvcounts, - datatype, op, comm, MPIR_SUBGROUP_NONE, + datatype, op, comm, + MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; @@ -226,7 +228,7 @@ int MPIR_Threadcomm_reduce_scatter_block_impl(const void *sendbuf, void *recvbuf MPIR_Assert(MPIR_Op_is_commutative(op)); mpi_errno = MPIR_Reduce_scatter_block_intra_recursive_halving(sendbuf, recvbuf, recvcount, datatype, op, - comm, MPIR_SUBGROUP_NONE, + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; @@ -238,7 +240,7 @@ int MPIR_Threadcomm_scan_impl(const void *sendbuf, void *recvbuf, int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Scan_intra_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } @@ -249,7 +251,7 @@ int 
MPIR_Threadcomm_exscan_impl(const void *sendbuf, void *recvbuf, int mpi_errno = MPI_SUCCESS; mpi_errno = MPIR_Exscan_intra_recursive_doubling(sendbuf, recvbuf, count, datatype, op, - comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + comm, MPIR_SUBGROUP_THREADCOMM, MPIR_ERR_NONE); return mpi_errno; } From f7f6ae150b4c9b15aac8c54c8945575367c540ae Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 18 Aug 2024 00:09:04 -0500 Subject: [PATCH 14/27] coll: check coll_group in MPIR_Comm_is_parent_comm We call MPIR_Comm_is_parent_comm to prevent recursively entering compositional algorithms such as the _smp algorithms. Check coll_group as well as we will switch to use subgroup rather than subcomms. Also check num_external directly for trivial comm. Subcomms and comm->hierarchy_kind will be removed in the future. --- maint/gen_coll.py | 2 +- src/include/mpir_comm.h | 2 +- src/mpi/coll/barrier/barrier_intra_smp.c | 2 +- src/mpi/coll/bcast/bcast_intra_smp.c | 2 +- .../iallreduce/iallreduce_intra_sched_smp.c | 2 +- src/mpi/coll/ibcast/ibcast_intra_sched_smp.c | 2 +- .../coll/ireduce/ireduce_intra_sched_smp.c | 2 +- src/mpi/coll/mpir_coll_sched_auto.c | 8 +++---- src/mpi/comm/commutil.c | 10 +++++---- src/mpid/ch3/src/ch3u_recvq.c | 4 ++-- src/mpid/ch4/src/ch4_coll.h | 21 +++++++------------ 11 files changed, 26 insertions(+), 31 deletions(-) diff --git a/maint/gen_coll.py b/maint/gen_coll.py index 1c15408af10..19bfd18dbd9 100644 --- a/maint/gen_coll.py +++ b/maint/gen_coll.py @@ -564,7 +564,7 @@ def dump_fallback(algo): elif a== "builtin-op": cond_list.append("HANDLE_IS_BUILTIN(op)") elif a == "parent-comm": - cond_list.append("MPIR_Comm_is_parent_comm(comm_ptr)") + cond_list.append("MPIR_Comm_is_parent_comm(comm_ptr, coll_group)") elif a == "node-consecutive": cond_list.append("MPII_Comm_is_node_consecutive(comm_ptr)") elif a == "displs-ordered": diff --git a/src/include/mpir_comm.h b/src/include/mpir_comm.h index 945f887eb39..40d533237b0 100644 --- a/src/include/mpir_comm.h +++ 
b/src/include/mpir_comm.h @@ -420,7 +420,7 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co int MPIR_Comm_create_subcomms(MPIR_Comm * comm); int MPIR_Comm_commit(MPIR_Comm *); -int MPIR_Comm_is_parent_comm(MPIR_Comm *); +int MPIR_Comm_is_parent_comm(MPIR_Comm * comm, int coll_group); /* peer intercomm is an internal 1-to-1 intercomm used for connecting dynamic processes */ int MPIR_peer_intercomm_create(MPIR_Context_id_t context_id, MPIR_Context_id_t recvcontext_id, diff --git a/src/mpi/coll/barrier/barrier_intra_smp.c b/src/mpi/coll/barrier/barrier_intra_smp.c index 36711f8560e..bafb2d34600 100644 --- a/src/mpi/coll/barrier/barrier_intra_smp.c +++ b/src/mpi/coll/barrier/barrier_intra_smp.c @@ -9,7 +9,7 @@ int MPIR_Barrier_intra_smp(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t { int mpi_errno = MPI_SUCCESS; - MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr)); + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); /* do the intranode barrier on all nodes */ if (comm_ptr->node_comm != NULL) { diff --git a/src/mpi/coll/bcast/bcast_intra_smp.c b/src/mpi/coll/bcast/bcast_intra_smp.c index 5881213d638..04deb237d2f 100644 --- a/src/mpi/coll/bcast/bcast_intra_smp.c +++ b/src/mpi/coll/bcast/bcast_intra_smp.c @@ -25,7 +25,7 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in #endif #ifdef HAVE_ERROR_CHECKING - MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr)); + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); #endif MPIR_Datatype_get_size_macro(datatype, type_size); diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c index e0acb2954ff..ad363bc8e54 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c @@ -15,7 +15,7 @@ int MPIR_Iallreduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint MPIR_Comm *nc; MPIR_Comm *nrc; - 
MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr)); + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); nc = comm_ptr->node_comm; nrc = comm_ptr->node_roots_comm; diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c b/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c index 028004f4863..33b7dd5fa64 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c @@ -35,7 +35,7 @@ int MPIR_Ibcast_intra_sched_smp(void *buffer, MPI_Aint count, MPI_Datatype datat struct MPII_Ibcast_state *ibcast_state; #ifdef HAVE_ERROR_CHECKING - MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr)); + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); #endif ibcast_state = MPIR_Sched_alloc_state(s, sizeof(struct MPII_Ibcast_state)); MPIR_ERR_CHKANDJUMP(!ibcast_state, mpi_errno, MPI_ERR_OTHER, "**nomem"); diff --git a/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c b/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c index adb3e10868d..13915416320 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c @@ -16,7 +16,7 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co MPIR_Comm *nc; MPIR_Comm *nrc; - MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr)); + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); nc = comm_ptr->node_comm; diff --git a/src/mpi/coll/mpir_coll_sched_auto.c b/src/mpi/coll/mpir_coll_sched_auto.c index 0060a10d38e..2caad70c96c 100644 --- a/src/mpi/coll/mpir_coll_sched_auto.c +++ b/src/mpi/coll/mpir_coll_sched_auto.c @@ -41,7 +41,7 @@ int MPIR_Ibcast_intra_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype data MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - if (comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT) { + if (MPIR_Comm_is_parent_comm(comm_ptr, coll_group)) { mpi_errno = MPIR_Ibcast_intra_sched_smp(buffer, count, datatype, 
root, comm_ptr, coll_group, s); if (mpi_errno) @@ -544,7 +544,7 @@ int MPIR_Ireduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint c MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - if (comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT && MPIR_Op_is_commutative(op)) { + if (MPIR_Comm_is_parent_comm(comm_ptr, coll_group) && MPIR_Op_is_commutative(op)) { mpi_errno = MPIR_Ireduce_intra_sched_smp(sendbuf, recvbuf, count, datatype, op, root, comm_ptr, coll_group, s); if (mpi_errno) @@ -601,7 +601,7 @@ int MPIR_Iallreduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Ain MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - if (comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT && MPIR_Op_is_commutative(op)) { + if (MPIR_Comm_is_parent_comm(comm_ptr, coll_group) && MPIR_Op_is_commutative(op)) { mpi_errno = MPIR_Iallreduce_intra_sched_smp(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, s); @@ -824,7 +824,7 @@ int MPIR_Iscan_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint cou { int mpi_errno = MPI_SUCCESS; - if (comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT) { + if (MPIR_Comm_is_parent_comm(comm_ptr, coll_group)) { mpi_errno = MPIR_Iscan_intra_sched_smp(sendbuf, recvbuf, count, datatype, op, comm_ptr, coll_group, s); diff --git a/src/mpi/comm/commutil.c b/src/mpi/comm/commutil.c index ac6d22b4164..b8c91b5d3ae 100644 --- a/src/mpi/comm/commutil.c +++ b/src/mpi/comm/commutil.c @@ -815,9 +815,11 @@ int MPIR_Comm_commit(MPIR_Comm * comm) /* Returns true if the given communicator is aware of node topology information, false otherwise. Such information could be used to implement more efficient collective communication, for example. 
*/ -int MPIR_Comm_is_parent_comm(MPIR_Comm * comm) +int MPIR_Comm_is_parent_comm(MPIR_Comm * comm, int coll_group) { - return (comm->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT); + return (coll_group == MPIR_SUBGROUP_NONE && + comm->num_external > 1 && comm->num_external != comm->remote_size && + comm->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__PARENT); } /* Returns true if the communicator is node-aware and processes in all the nodes @@ -828,7 +830,7 @@ int MPII_Comm_is_node_consecutive(MPIR_Comm * comm) int i = 0, curr_nodeidx = 0; int *internode_table = comm->internode_table; - if (!MPIR_Comm_is_parent_comm(comm)) + if (!MPIR_Comm_is_parent_comm(comm, MPIR_SUBGROUP_NONE)) return 0; for (; i < comm->local_size; i++) { @@ -1311,7 +1313,7 @@ int MPII_Comm_is_node_balanced(MPIR_Comm * comm, int *num_nodes, bool * node_bal MPIR_CHKPMEM_DECL(1); - if (!MPIR_Comm_is_parent_comm(comm)) { + if (!MPIR_Comm_is_parent_comm(comm, MPIR_SUBGROUP_NONE)) { *node_balanced = false; goto fn_exit; } diff --git a/src/mpid/ch3/src/ch3u_recvq.c b/src/mpid/ch3/src/ch3u_recvq.c index d30938ae36a..10742ae4e98 100644 --- a/src/mpid/ch3/src/ch3u_recvq.c +++ b/src/mpid/ch3/src/ch3u_recvq.c @@ -921,7 +921,7 @@ int MPIDI_CH3U_Clean_recvq(MPIR_Comm *comm_ptr) } } - if (MPIR_Comm_is_parent_comm(comm_ptr)) { + if (MPIR_Comm_is_parent_comm(comm_ptr, MPIR_SUBGROUP_NONE)) { /* node_comm pt2pt */ match.parts.context_id = comm_ptr->recvcontext_id + MPIR_CONTEXT_INTRANODE_OFFSET; @@ -1014,7 +1014,7 @@ int MPIDI_CH3U_Clean_recvq(MPIR_Comm *comm_ptr) } } - if (MPIR_Comm_is_parent_comm(comm_ptr)) { + if (MPIR_Comm_is_parent_comm(comm_ptr, MPIR_SUBGROUP_NONE)) { /* node_comm coll */ match.parts.context_id = comm_ptr->recvcontext_id + MPIR_CONTEXT_INTRANODE_OFFSET; diff --git a/src/mpid/ch4/src/ch4_coll.h b/src/mpid/ch4/src/ch4_coll.h index 68d017a7e0d..ae068abadbc 100644 --- a/src/mpid/ch4/src/ch4_coll.h +++ b/src/mpid/ch4/src/ch4_coll.h @@ -155,8 +155,7 @@ MPL_STATIC_INLINE_PREFIX int 
MPID_Barrier(MPIR_Comm * comm, int coll_group, MPIR case 1: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) && - (comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, + MPIR_Comm_is_parent_comm(comm, coll_group), mpi_errno, "Barrier composition alpha cannot be applied.\n"); mpi_errno = MPIDI_Barrier_intra_composition_alpha(comm, coll_group, errflag); break; @@ -271,8 +270,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty case 1: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) && - (comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, + MPIR_Comm_is_parent_comm(comm, coll_group), mpi_errno, "Bcast composition alpha cannot be applied.\n"); mpi_errno = MPIDI_Bcast_intra_composition_alpha(buffer, count, datatype, root, comm, coll_group, @@ -281,8 +279,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) && - (comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, + MPIR_Comm_is_parent_comm(comm, coll_group), mpi_errno, "Bcast composition beta cannot be applied.\n"); mpi_errno = MPIDI_Bcast_intra_composition_beta(buffer, count, datatype, root, comm, coll_group, @@ -299,8 +296,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Bcast(void *buffer, MPI_Aint count, MPI_Dataty case 4: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) && - (comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT), mpi_errno, + MPIR_Comm_is_parent_comm(comm, coll_group), mpi_errno, "Bcast composition delta cannot be applied.\n"); mpi_errno = MPIDI_Bcast_intra_composition_delta(buffer, count, datatype, root, comm, coll_group, @@ -463,8 +459,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Allreduce(const void *sendbuf, void *recvbuf, case 1: 
MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) && - (comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT) && + MPIR_Comm_is_parent_comm(comm, coll_group) && is_commutative, mpi_errno, "Allreduce composition alpha cannot be applied.\n"); mpi_errno = @@ -1408,8 +1403,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, switch (MPIR_CVAR_REDUCE_COMPOSITION) { case 1: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM - && comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT && + && MPIR_Comm_is_parent_comm(comm, coll_group) && MPIR_Op_is_commutative(op), mpi_errno, "Reduce composition alpha cannot be applied.\n"); mpi_errno = @@ -1418,8 +1412,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Reduce(const void *sendbuf, void *recvbuf, break; case 2: MPII_COLLECTIVE_FALLBACK_CHECK(comm->rank, comm->comm_kind == MPIR_COMM_KIND__INTRACOMM - && comm->hierarchy_kind == - MPIR_COMM_HIERARCHY_KIND__PARENT && + && MPIR_Comm_is_parent_comm(comm, coll_group) && MPIR_Op_is_commutative(op), mpi_errno, "Reduce composition beta cannot be applied.\n"); mpi_errno = From 75a0e6732acb0bd7dbe491c2ad7ea9fe26be32de Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 18 Aug 2024 10:44:51 -0500 Subject: [PATCH 15/27] coll: make non-compositional algorithm coll_group aware Use MPIR_COLL_RANK_SIZE if the algorithm is topology neutral. Use MPIR_COLL_RANK_SIZE_NO_GROUP if the algorithm is topology dependent. It adds an assertion on coll_group == MPIR_SUBGROUP_NONE since coll_group may alter the topology assumptions. Intercomm does not work with non-zero coll_group.
--- maint/gen_coll.py | 4 +- .../coll/allgather/allgather_intra_k_brucks.c | 5 ++- .../coll/allgather/allgather_intra_recexch.c | 3 +- .../allgather_intra_recursive_doubling.c | 3 +- src/mpi/coll/allgather/allgather_intra_ring.c | 3 +- .../allgatherv_intra_recursive_doubling.c | 3 +- .../coll/allgatherv/allgatherv_intra_ring.c | 3 +- ...lreduce_intra_k_reduce_scatter_allgather.c | 4 +- .../coll/allreduce/allreduce_intra_recexch.c | 4 +- ...allreduce_intra_reduce_scatter_allgather.c | 5 +-- src/mpi/coll/allreduce/allreduce_intra_ring.c | 3 +- src/mpi/coll/allreduce/allreduce_intra_tree.c | 3 +- .../coll/alltoall/alltoall_intra_k_brucks.c | 3 +- .../coll/alltoall/alltoall_intra_pairwise.c | 3 +- ...alltoall_intra_pairwise_sendrecv_replace.c | 3 +- .../coll/alltoall/alltoall_intra_scattered.c | 3 +- ...lltoallv_intra_pairwise_sendrecv_replace.c | 3 +- ...lltoallw_intra_pairwise_sendrecv_replace.c | 3 +- .../barrier/barrier_intra_k_dissemination.c | 5 +-- .../bcast_inter_remote_send_local_bcast.c | 2 +- .../coll/bcast/bcast_intra_pipelined_tree.c | 3 +- ...tra_scatter_recursive_doubling_allgather.c | 3 +- .../bcast_intra_scatter_ring_allgather.c | 3 +- src/mpi/coll/bcast/bcast_intra_tree.c | 6 +-- src/mpi/coll/bcast/bcast_utils.c | 4 +- src/mpi/coll/gatherv/gatherv_allcomm_linear.c | 11 ++++-- .../iallgather_intra_sched_brucks.c | 3 +- ...allgather_intra_sched_recursive_doubling.c | 3 +- .../iallgather/iallgather_intra_sched_ring.c | 3 +- .../coll/iallgather/iallgather_tsp_brucks.c | 5 ++- .../coll/iallgather/iallgather_tsp_recexch.c | 3 +- src/mpi/coll/iallgather/iallgather_tsp_ring.c | 5 ++- .../iallgatherv_intra_sched_brucks.c | 3 +- ...llgatherv_intra_sched_recursive_doubling.c | 3 +- .../iallgatherv_intra_sched_ring.c | 4 +- .../coll/iallgatherv/iallgatherv_tsp_brucks.c | 3 +- .../iallgatherv/iallgatherv_tsp_recexch.c | 3 +- .../coll/iallgatherv/iallgatherv_tsp_ring.c | 3 +- .../iallreduce/iallreduce_intra_sched_naive.c | 6 ++- 
...allreduce_intra_sched_recursive_doubling.c | 5 +-- ...uce_intra_sched_reduce_scatter_allgather.c | 5 +-- .../coll/iallreduce/iallreduce_tsp_recexch.c | 3 +- ...ecexch_reduce_scatter_recexch_allgatherv.c | 3 +- src/mpi/coll/iallreduce/iallreduce_tsp_ring.c | 3 +- src/mpi/coll/iallreduce/iallreduce_tsp_tree.c | 3 +- .../ialltoall/ialltoall_intra_sched_brucks.c | 3 +- .../ialltoall/ialltoall_intra_sched_inplace.c | 4 +- .../ialltoall_intra_sched_pairwise.c | 3 +- .../ialltoall_intra_sched_permuted_sendrecv.c | 3 +- src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c | 3 +- src/mpi/coll/ialltoall/ialltoall_tsp_ring.c | 5 ++- .../coll/ialltoall/ialltoall_tsp_scattered.c | 3 +- .../ialltoallv_intra_sched_blocked.c | 3 +- .../ialltoallv_intra_sched_inplace.c | 3 +- .../coll/ialltoallv/ialltoallv_tsp_blocked.c | 3 +- .../coll/ialltoallv/ialltoallv_tsp_inplace.c | 3 +- .../ialltoallv/ialltoallv_tsp_scattered.c | 4 +- .../ialltoallw_intra_sched_blocked.c | 3 +- .../ialltoallw_intra_sched_inplace.c | 3 +- .../coll/ialltoallw/ialltoallw_tsp_blocked.c | 3 +- .../coll/ialltoallw/ialltoallw_tsp_inplace.c | 3 +- .../ibarrier_intra_sched_recursive_doubling.c | 3 +- .../coll/ibarrier/ibarrier_intra_tsp_dissem.c | 3 +- .../coll/ibcast/ibcast_intra_sched_binomial.c | 3 +- ...hed_scatter_recursive_doubling_allgather.c | 4 +- ...bcast_intra_sched_scatter_ring_allgather.c | 3 +- src/mpi/coll/ibcast/ibcast_tsp_auto.c | 6 ++- .../ibcast/ibcast_tsp_scatterv_allgatherv.c | 9 ++--- src/mpi/coll/ibcast/ibcast_tsp_tree.c | 3 +- src/mpi/coll/ibcast/ibcast_utils.c | 4 +- .../iexscan_intra_sched_recursive_doubling.c | 3 +- .../igather/igather_intra_sched_binomial.c | 3 +- src/mpi/coll/igather/igather_tsp_tree.c | 3 +- .../igatherv/igatherv_allcomm_sched_linear.c | 12 +++--- src/mpi/coll/igatherv/igatherv_tsp_linear.c | 12 +++--- .../ireduce/ireduce_intra_sched_binomial.c | 3 +- ...reduce_intra_sched_reduce_scatter_gather.c | 5 +-- src/mpi/coll/ireduce/ireduce_tsp_auto.c | 8 +++- 
src/mpi/coll/ireduce/ireduce_tsp_tree.c | 3 +- ...educe_scatter_intra_sched_noncommutative.c | 5 ++- .../ireduce_scatter_intra_sched_pairwise.c | 3 +- ...e_scatter_intra_sched_recursive_doubling.c | 3 +- ...ce_scatter_intra_sched_recursive_halving.c | 5 +-- .../ireduce_scatter_tsp_recexch.c | 3 +- ...scatter_block_intra_sched_noncommutative.c | 5 ++- ...educe_scatter_block_intra_sched_pairwise.c | 3 +- ...ter_block_intra_sched_recursive_doubling.c | 3 +- ...tter_block_intra_sched_recursive_halving.c | 3 +- .../ireduce_scatter_block_tsp_recexch.c | 3 +- .../iscan_intra_sched_recursive_doubling.c | 3 +- .../coll/iscan/iscan_tsp_recursive_doubling.c | 3 +- .../iscatter/iscatter_intra_sched_binomial.c | 3 +- src/mpi/coll/iscatter/iscatter_tsp_tree.c | 3 +- .../iscatterv_allcomm_sched_linear.c | 12 +++--- src/mpi/coll/iscatterv/iscatterv_tsp_linear.c | 12 +++--- src/mpi/coll/mpir_coll_sched_auto.c | 39 ++++++++++++------- .../reduce_intra_reduce_scatter_gather.c | 5 +-- .../reduce_scatter_intra_noncommutative.c | 2 + .../reduce_scatter_intra_pairwise.c | 3 +- .../reduce_scatter_intra_recursive_doubling.c | 3 +- ...educe_scatter_block_intra_noncommutative.c | 2 + .../reduce_scatter_block_intra_pairwise.c | 3 +- ...e_scatter_block_intra_recursive_doubling.c | 3 +- .../coll/scatterv/scatterv_allcomm_linear.c | 11 ++++-- src/mpi/coll/src/csel.c | 13 +++++-- 105 files changed, 220 insertions(+), 245 deletions(-) diff --git a/maint/gen_coll.py b/maint/gen_coll.py index 19bfd18dbd9..83323a951da 100644 --- a/maint/gen_coll.py +++ b/maint/gen_coll.py @@ -556,9 +556,9 @@ def dump_fallback(algo): elif a == "noinplace": cond_list.append("sendbuf != MPI_IN_PLACE") elif a == "power-of-two": - cond_list.append("MPL_is_pof2(comm_ptr->local_size)") + cond_list.append("MPL_is_pof2(MPIR_Coll_size(comm_ptr, coll_group))") elif a == "size-ge-pof2": - cond_list.append("count >= MPL_pof2(comm_ptr->local_size)") + cond_list.append("count >= MPL_pof2(MPIR_Coll_size(comm_ptr, coll_group))") elif 
a == "commutative": cond_list.append("MPIR_Op_is_commutative(op)") elif a== "builtin-op": diff --git a/src/mpi/coll/allgather/allgather_intra_k_brucks.c b/src/mpi/coll/allgather/allgather_intra_k_brucks.c index 4fd62333c3e..cf138bf5e96 100644 --- a/src/mpi/coll/allgather/allgather_intra_k_brucks.c +++ b/src/mpi/coll/allgather/allgather_intra_k_brucks.c @@ -30,8 +30,9 @@ MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount, int nphases = 0; int src, dst, p_of_k = 0; /* Largest power of k that is smaller than 'size' */ - int rank = MPIR_Comm_rank(comm); - int size = MPIR_Comm_size(comm); + int rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + int is_inplace = (sendbuf == MPI_IN_PLACE); int max = size - 1; MPIR_Request **reqs; diff --git a/src/mpi/coll/allgather/allgather_intra_recexch.c b/src/mpi/coll/allgather/allgather_intra_recexch.c index bd83fed5cac..5f8fe672e44 100644 --- a/src/mpi/coll/allgather/allgather_intra_recexch.c +++ b/src/mpi/coll/allgather/allgather_intra_recexch.c @@ -37,8 +37,7 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPIR_Request **recv_reqs = NULL, **send_reqs = NULL; is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(recvtype, recv_extent); MPIR_Type_get_true_extent_impl(recvtype, &recv_lb, &true_extent); diff --git a/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c b/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c index 107b2d2ab9a..be80c8774b5 100644 --- a/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c +++ b/src/mpi/coll/allgather/allgather_intra_recursive_doubling.c @@ -35,8 +35,7 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf, MPI_Status status; int mask, dst_tree_root, my_tree_root, nprocs_completed, k, tmp_mask, tree_root; - comm_size = comm_ptr->local_size; - rank = 
comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING /* Currently this algorithm can only handle power-of-2 comm_size. diff --git a/src/mpi/coll/allgather/allgather_intra_ring.c b/src/mpi/coll/allgather/allgather_intra_ring.c index a9715f502f6..d763d70942e 100644 --- a/src/mpi/coll/allgather/allgather_intra_ring.c +++ b/src/mpi/coll/allgather/allgather_intra_ring.c @@ -34,8 +34,7 @@ int MPIR_Allgather_intra_ring(const void *sendbuf, int j, i; int left, right, jnext; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c b/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c index 21df0baf8ef..0603924a960 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c @@ -39,8 +39,7 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf, MPI_Aint position, send_offset, recv_offset, offset; MPIR_CHKLMEM_DECL(1); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING /* Currently this algorithm can only handle power-of-2 comm_size. 
diff --git a/src/mpi/coll/allgatherv/allgatherv_intra_ring.c b/src/mpi/coll/allgatherv/allgatherv_intra_ring.c index 3daafdee90b..18c55b5172d 100644 --- a/src/mpi/coll/allgatherv/allgatherv_intra_ring.c +++ b/src/mpi/coll/allgatherv/allgatherv_intra_ring.c @@ -37,8 +37,7 @@ int MPIR_Allgatherv_intra_ring(const void *sendbuf, MPI_Aint recvtype_extent; MPI_Aint total_count; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); total_count = 0; for (i = 0; i < comm_size; i++) diff --git a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c index 81451aba812..842e0914005 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c +++ b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c @@ -37,8 +37,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, MPIR_Assert(k > 1); - rank = comm->rank; - nranks = comm->local_size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); + MPIR_Assert(MPIR_Op_is_commutative(op)); /* need to allocate temporary buffer to store incoming data */ diff --git a/src/mpi/coll/allreduce/allreduce_intra_recexch.c b/src/mpi/coll/allreduce/allreduce_intra_recexch.c index 72d3320055b..9d54f4524c0 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recexch.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recexch.c @@ -34,8 +34,8 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, MPIR_Request **send_reqs = NULL, **recv_reqs = NULL; int send_nreq = 0, recv_nreq = 0, total_phases = 0; - rank = comm->rank; - nranks = comm->local_size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); + is_commutative = MPIR_Op_is_commutative(op); bool is_float; diff --git a/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c b/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c index 9bdd3bec10b..c072067b870 100644 --- 
a/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c +++ b/src/mpi/coll/allreduce/allreduce_intra_reduce_scatter_allgather.c @@ -53,8 +53,7 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, MPI_Aint true_extent, true_lb, extent; void *tmp_buf; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* need to allocate temporary buffer to store incoming data */ MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -73,7 +72,7 @@ int MPIR_Allreduce_intra_reduce_scatter_allgather(const void *sendbuf, } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = MPL_pof2(comm_ptr->local_size); + pof2 = MPL_pof2(comm_size); rem = comm_size - pof2; diff --git a/src/mpi/coll/allreduce/allreduce_intra_ring.c b/src/mpi/coll/allreduce/allreduce_intra_ring.c index 3791abfc262..fadd14d6b98 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_ring.c +++ b/src/mpi/coll/allreduce/allreduce_intra_ring.c @@ -25,8 +25,7 @@ int MPIR_Allreduce_intra_ring(const void *sendbuf, void *recvbuf, MPI_Aint count MPIR_Request *reqs[2]; /* one send and one recv per transfer */ is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); diff --git a/src/mpi/coll/allreduce/allreduce_intra_tree.c b/src/mpi/coll/allreduce/allreduce_intra_tree.c index 73cc8851911..226e40d6bfc 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_tree.c +++ b/src/mpi/coll/allreduce/allreduce_intra_tree.c @@ -37,8 +37,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, MPIR_Request **reqs; int num_reqs = 0; - comm_size = MPIR_Comm_size(comm_ptr); - rank = MPIR_Comm_rank(comm_ptr); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_size_macro(datatype, 
type_size); MPIR_Datatype_get_extent_macro(datatype, extent); diff --git a/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c b/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c index dc43a99196e..cbe766b9008 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c +++ b/src/mpi/coll/alltoall/alltoall_intra_k_brucks.c @@ -134,8 +134,7 @@ int MPIR_Alltoall_intra_k_brucks(const void *sendbuf, is_inplace = (sendbuf == MPI_IN_PLACE); - rank = MPIR_Comm_rank(comm); - size = MPIR_Comm_size(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); nphases = 0; max = size - 1; diff --git a/src/mpi/coll/alltoall/alltoall_intra_pairwise.c b/src/mpi/coll/alltoall/alltoall_intra_pairwise.c index c9bfb158286..736f71c6750 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_pairwise.c +++ b/src/mpi/coll/alltoall/alltoall_intra_pairwise.c @@ -35,8 +35,7 @@ int MPIR_Alltoall_intra_pairwise(const void *sendbuf, int mpi_errno = MPI_SUCCESS, src, dst, rank; MPI_Status status; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(sendbuf != MPI_IN_PLACE); diff --git a/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c index e267b7af94d..9c83d5af935 100644 --- a/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoall/alltoall_intra_pairwise_sendrecv_replace.c @@ -33,8 +33,7 @@ int MPIR_Alltoall_intra_pairwise_sendrecv_replace(const void *sendbuf, int mpi_errno = MPI_SUCCESS, rank; MPI_Status status; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Get extent of send and recv types */ MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); diff --git a/src/mpi/coll/alltoall/alltoall_intra_scattered.c b/src/mpi/coll/alltoall/alltoall_intra_scattered.c index a95031f3461..ee31a35400a 100644 
--- a/src/mpi/coll/alltoall/alltoall_intra_scattered.c +++ b/src/mpi/coll/alltoall/alltoall_intra_scattered.c @@ -42,8 +42,7 @@ int MPIR_Alltoall_intra_scattered(const void *sendbuf, MPI_Status *starray; MPIR_CHKLMEM_DECL(6); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(sendbuf != MPI_IN_PLACE); diff --git a/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c index 7f93eeaedb5..bbe2ba5a1cc 100644 --- a/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoallv/alltoallv_intra_pairwise_sendrecv_replace.c @@ -31,8 +31,7 @@ int MPIR_Alltoallv_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP MPI_Status status; int rank; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Get extent of recv type, but send type is only valid if (sendbuf!=MPI_IN_PLACE) */ MPIR_Datatype_get_extent_macro(recvtype, recv_extent); diff --git a/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c b/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c index 3e367430969..2d3423f2aa4 100644 --- a/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c +++ b/src/mpi/coll/alltoallw/alltoallw_intra_pairwise_sendrecv_replace.c @@ -31,8 +31,7 @@ int MPIR_Alltoallw_intra_pairwise_sendrecv_replace(const void *sendbuf, const MP MPI_Status status; int rank; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(sendbuf == MPI_IN_PLACE); diff --git a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c index 3d1f884ebf7..5e71475c466 100644 --- 
a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c +++ b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c @@ -20,7 +20,7 @@ int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_group, MPIR_ { int size, rank, src, dst, mask, mpi_errno = MPI_SUCCESS; - MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, size); mask = 0x1; while (mask < size) { @@ -54,8 +54,7 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int coll_group, int k, MPIR_Request *sreqs[MAX_RADIX], *rreqs[MAX_RADIX * 2]; MPIR_Request **send_reqs = NULL, **recv_reqs = NULL; - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); if (nranks == 1) goto fn_exit; diff --git a/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c b/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c index 087718375b8..266c48eb02a 100644 --- a/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c +++ b/src/mpi/coll/bcast/bcast_inter_remote_send_local_bcast.c @@ -22,7 +22,7 @@ int MPIR_Bcast_inter_remote_send_local_bcast(void *buffer, MPIR_Comm *newcomm_ptr = NULL; MPIR_FUNC_ENTER; - + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); if (root == MPI_PROC_NULL) { /* local processes other than root do nothing */ diff --git a/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c b/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c index e09dc797bc0..98498d7ca77 100644 --- a/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c +++ b/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c @@ -31,8 +31,7 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, MPIR_Treealgo_tree_t my_tree; MPIR_CHKLMEM_DECL(3); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* If there is only one process, return */ if (comm_size == 1) diff --git a/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c 
b/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c index 53c968276bc..4867712378d 100644 --- a/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c +++ b/src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c @@ -45,8 +45,7 @@ int MPIR_Bcast_intra_scatter_recursive_doubling_allgather(void *buffer, MPI_Aint true_extent, true_lb; void *tmp_buf; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); relative_rank = (rank >= root) ? rank - root : rank - root + comm_size; if (HANDLE_IS_BUILTIN(datatype)) diff --git a/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c b/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c index bc64139c6c9..4c4641c347a 100644 --- a/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c +++ b/src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c @@ -39,8 +39,7 @@ int MPIR_Bcast_intra_scatter_ring_allgather(void *buffer, MPI_Aint true_extent, true_lb; MPIR_CHKLMEM_DECL(1); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); if (HANDLE_IS_BUILTIN(datatype)) is_contig = 1; diff --git a/src/mpi/coll/bcast/bcast_intra_tree.c b/src/mpi/coll/bcast/bcast_intra_tree.c index f9abecdf6ea..eaf48370014 100644 --- a/src/mpi/coll/bcast/bcast_intra_tree.c +++ b/src/mpi/coll/bcast/bcast_intra_tree.c @@ -15,7 +15,7 @@ int MPIR_Bcast_intra_tree(void *buffer, int root, MPIR_Comm * comm_ptr, int coll_group, int tree_type, int branching_factor, int is_nb, MPIR_Errflag_t errflag) { - int rank, comm_size, src, dst, *p, j, k, lrank = -1, is_contig; + int rank, comm_size, src, dst, *p, j, k, is_contig; int parent = -1, num_children = 0, num_req = 0, is_root = 0; int mpi_errno = MPI_SUCCESS; MPI_Aint nbytes = 0, type_size, actual_packed_unpacked_bytes, recvd_size; @@ -29,8 +29,7 @@ int MPIR_Bcast_intra_tree(void *buffer, MPIR_Treealgo_tree_t my_tree; 
MPIR_CHKLMEM_DECL(3); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* If there is only one process, return */ if (comm_size == 1) @@ -64,6 +63,7 @@ int MPIR_Bcast_intra_tree(void *buffer, dtype = MPI_BYTE; } + int lrank = 0; if (tree_type == MPIR_TREE_TYPE_KARY) { if (rank == root) is_root = 1; diff --git a/src/mpi/coll/bcast/bcast_utils.c b/src/mpi/coll/bcast/bcast_utils.c index 1105b434481..8b614ae6df3 100644 --- a/src/mpi/coll/bcast/bcast_utils.c +++ b/src/mpi/coll/bcast/bcast_utils.c @@ -30,8 +30,8 @@ int MPII_Scatter_for_bcast(void *buffer ATTRIBUTE((unused)), MPI_Aint scatter_size, recv_size = 0; MPI_Aint curr_size, send_size; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + relative_rank = (rank >= root) ? rank - root : rank - root + comm_size; /* use long message algorithm: binomial tree scatter followed by an allgather */ diff --git a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c index 123be4ed548..62a3fda6ef2 100644 --- a/src/mpi/coll/gatherv/gatherv_allcomm_linear.c +++ b/src/mpi/coll/gatherv/gatherv_allcomm_linear.c @@ -33,14 +33,17 @@ int MPIR_Gatherv_allcomm_linear(const void *sendbuf, MPI_Status *starray; MPIR_CHKLMEM_DECL(2); - MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } /* If rank == root, then I recv lots, otherwise I send */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) - comm_size = comm_ptr->remote_size; - 
MPIR_Datatype_get_extent_macro(recvtype, extent); MPIR_CHKLMEM_MALLOC(reqarray, MPIR_Request **, comm_size * sizeof(MPIR_Request *), diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c b/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c index 24d6d3afa9b..44a5019b115 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_brucks.c @@ -25,8 +25,7 @@ int MPIR_Iallgather_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPI_Aint recvtype_extent, recvtype_sz; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); /* allocate a temporary buffer of the same size as recvbuf. */ diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c b/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c index ef41fe4658b..42bcb719edd 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_recursive_doubling.c @@ -57,8 +57,7 @@ int MPIR_Iallgather_intra_sched_recursive_doubling(const void *sendbuf, MPI_Aint int dst_tree_root, my_tree_root, tree_root; MPI_Aint recvtype_extent; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef HAVE_ERROR_CHECKING /* Currently this algorithm can only handle power-of-2 comm_size. 
diff --git a/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c b/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c index 48c8c2c6e05..da715eb1c1a 100644 --- a/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c +++ b/src/mpi/coll/iallgather/iallgather_intra_sched_ring.c @@ -30,8 +30,7 @@ int MPIR_Iallgather_intra_sched_ring(const void *sendbuf, MPI_Aint sendcount, MP int i, j, jnext, left, right; MPI_Aint recvtype_extent; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); diff --git a/src/mpi/coll/iallgather/iallgather_tsp_brucks.c b/src/mpi/coll/iallgather/iallgather_tsp_brucks.c index 750bcf41b64..d30e85ceca8 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_brucks.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_brucks.c @@ -21,8 +21,9 @@ MPIR_TSP_Iallgather_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, int src, dst, p_of_k = 0; /* Largest power of k that is (strictly) smaller than 'size' */ MPIR_Errflag_t errflag ATTRIBUTE((unused)) = MPIR_ERR_NONE; - int rank = MPIR_Comm_rank(comm); - int size = MPIR_Comm_size(comm); + int rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + int is_inplace = (sendbuf == MPI_IN_PLACE); int max = size - 1; int vtx_id; diff --git a/src/mpi/coll/iallgather/iallgather_tsp_recexch.c b/src/mpi/coll/iallgather/iallgather_tsp_recexch.c index 1a6092c18ff..6663ba9c96c 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_recexch.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_recexch.c @@ -259,8 +259,7 @@ int MPIR_TSP_Iallgather_sched_intra_recexch(const void *sendbuf, MPI_Aint sendco MPIR_ERR_CHECK(mpi_errno); is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(recvtype, recv_extent); MPIR_Type_get_true_extent_impl(recvtype, &recv_lb, 
&true_extent); diff --git a/src/mpi/coll/iallgather/iallgather_tsp_ring.c b/src/mpi/coll/iallgather/iallgather_tsp_ring.c index b29f64afb81..700f972a083 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_ring.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_ring.c @@ -16,8 +16,9 @@ int MPIR_TSP_Iallgather_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount /* Temporary buffers to execute the ring algorithm */ void *buf1, *buf2, *data_buf, *rbuf, *sbuf; - int size = MPIR_Comm_size(comm); - int rank = MPIR_Comm_rank(comm); + int rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + int is_inplace = (sendbuf == MPI_IN_PLACE); int tag; int vtx_id; diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c index d3d0cb073af..e978058c7d7 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_brucks.c @@ -17,8 +17,7 @@ int MPIR_Iallgatherv_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, int dst, pof2, src, rem; void *tmp_buf; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); MPIR_Datatype_get_size_macro(recvtype, recvtype_sz); diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c index 7a03d90d534..da0f616f048 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_recursive_doubling.c @@ -18,8 +18,7 @@ int MPIR_Iallgatherv_intra_sched_recursive_doubling(const void *sendbuf, MPI_Ain MPI_Aint recvtype_extent, recvtype_sz, position, offset; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); #ifdef 
HAVE_ERROR_CHECKING /* Currently this algorithm can only handle power-of-2 comm_size. diff --git a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c index 3572924050a..5eefd3de298 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_intra_sched_ring.c @@ -19,8 +19,8 @@ int MPIR_Iallgatherv_intra_sched_ring(const void *sendbuf, MPI_Aint sendcount, char *sbuf = NULL; char *rbuf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); total_count = 0; diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c index baf490396a7..4eb3f04579d 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c @@ -68,8 +68,7 @@ MPIR_TSP_Iallgatherv_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, MPIR_ERR_CHECK(mpi_errno); is_inplace = (sendbuf == MPI_IN_PLACE); - rank = MPIR_Comm_rank(comm); - size = MPIR_Comm_size(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); max = size - 1; if (is_inplace) { diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c index fbfba53e3d1..a22694e0c6a 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c @@ -265,8 +265,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc MPIR_FUNC_ENTER; is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(recvtype, recv_extent); MPIR_Type_get_true_extent_impl(recvtype, &recv_lb, &true_extent); diff --git 
a/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c index f6252210917..bc5cdd96978 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c @@ -26,8 +26,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcoun MPIR_FUNC_ENTER; is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); /* find out the buffer which has the send data and point data_buf to it */ if (is_inplace) { diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c index da7d0186c76..eb5544c70bf 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_naive.c @@ -11,9 +11,11 @@ int MPIR_Iallreduce_intra_sched_naive(const void *sendbuf, void *recvbuf, MPI_Ai int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int rank; + int rank, comm_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + if (comm_size == 1) + goto fn_exit; if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) { mpi_errno = diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c index 9ea8c1da8dd..bdb79e3fd0a 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_recursive_doubling.c @@ -16,8 +16,7 @@ int MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re MPI_Aint true_lb, true_extent, extent; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); @@ -39,7 +38,7 @@ int 
MPIR_Iallreduce_intra_sched_recursive_doubling(const void *sendbuf, void *re } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = MPL_pof2(comm_ptr->local_size); + pof2 = MPL_pof2(comm_size); rem = comm_size - pof2; diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c index 627ccb2440a..803961e4bb4 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_reduce_scatter_allgather.c @@ -24,8 +24,7 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo MPIR_Assert(HANDLE_IS_BUILTIN(op)); #endif - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* need to allocate temporary buffer to store incoming data */ MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -45,7 +44,7 @@ int MPIR_Iallreduce_intra_sched_reduce_scatter_allgather(const void *sendbuf, vo } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = MPL_pof2(comm_ptr->local_size); + pof2 = MPL_pof2(comm_size); rem = comm_size - pof2; diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c index ab88975ffa4..85646e3d637 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c @@ -39,8 +39,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, MPIR_FUNC_ENTER; is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c 
b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c index 67cfde36fb1..284ca9e5b19 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c @@ -45,8 +45,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co MPIR_FUNC_ENTER; is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c b/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c index 8cee22ff4f0..a461635964c 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c @@ -30,8 +30,7 @@ int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI MPIR_CHKLMEM_DECL(4); is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c b/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c index b4f51312f1f..4c5d3d79b50 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c @@ -38,8 +38,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI MPIR_FUNC_ENTER; - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); MPIR_Datatype_get_size_macro(datatype, type_size); MPIR_Datatype_get_extent_macro(datatype, extent); diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c 
b/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c index 40ec46c3629..25584afd8aa 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_brucks.c @@ -38,8 +38,7 @@ int MPIR_Ialltoall_intra_sched_brucks(const void *sendbuf, MPI_Aint sendcount, MPIR_Assert(sendbuf != MPI_IN_PLACE); /* we do not handle in-place */ #endif - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(sendtype, sendtype_extent); MPIR_Datatype_get_size_macro(recvtype, recvtype_sz); diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c index fba7405785c..dbeab81ec5d 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_inplace.c @@ -34,8 +34,8 @@ int MPIR_Ialltoall_intra_sched_inplace(const void *sendbuf, MPI_Aint sendcount, MPIR_Assert(sendbuf == MPI_IN_PLACE); #endif - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Datatype_get_size_macro(recvtype, recvtype_size); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); nbytes = recvtype_size * recvcount; diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c index 4e2c0df6667..54580b0588a 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_pairwise.c @@ -37,8 +37,7 @@ int MPIR_Ialltoall_intra_sched_pairwise(const void *sendbuf, MPI_Aint sendcount, MPIR_Assert(sendbuf != MPI_IN_PLACE); /* we do not handle in-place */ #endif - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(sendtype, sendtype_extent); 
MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); diff --git a/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c b/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c index a772b406270..23c909a4bcc 100644 --- a/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c +++ b/src/mpi/coll/ialltoall/ialltoall_intra_sched_permuted_sendrecv.c @@ -29,8 +29,7 @@ int MPIR_Ialltoall_intra_sched_permuted_sendrecv(const void *sendbuf, MPI_Aint s MPIR_Assert(sendbuf != MPI_IN_PLACE); /* we do not handle in-place */ #endif - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(sendtype, sendtype_extent); MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c b/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c index c6e09c66963..9ab18d690e9 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c @@ -159,8 +159,7 @@ MPIR_TSP_Ialltoall_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, is_inplace = (sendbuf == MPI_IN_PLACE); - rank = MPIR_Comm_rank(comm); - size = MPIR_Comm_size(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); max = size - 1; diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c b/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c index cfede050026..989769feb15 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c @@ -46,8 +46,9 @@ int MPIR_TSP_Ialltoall_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount, void *buf1, *buf2, *data_buf, *sbuf, *rbuf; int tag, vtx_id; - int size = MPIR_Comm_size(comm); - int rank = MPIR_Comm_rank(comm); + int size, rank; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + int is_inplace = (sendbuf == MPI_IN_PLACE); MPI_Aint recvtype_lb, recvtype_extent; diff --git 
a/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c b/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c index 7efd13aed44..1fbf7d2b383 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c @@ -61,8 +61,7 @@ int MPIR_TSP_Ialltoall_sched_intra_scattered(const void *sendbuf, MPI_Aint sendc mpi_errno = MPIR_Sched_next_tag(comm, &tag); MPIR_ERR_CHECK(mpi_errno); - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); is_inplace = (sendbuf == MPI_IN_PLACE); /* vtcs is twice the batch size to store both send and recv ids */ diff --git a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c index fb42149f818..5cdbf428f09 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_blocked.c @@ -22,8 +22,7 @@ int MPIR_Ialltoallv_intra_sched_blocked(const void *sendbuf, const MPI_Aint send MPIR_Assert(sendbuf != MPI_IN_PLACE); #endif /* HAVE_ERROR_CHECKING */ - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Get extent and size of recvtype, don't look at sendtype for MPI_IN_PLACE */ MPIR_Datatype_get_extent_macro(recvtype, recv_extent); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c index d08ddfa2be4..3b8444bf5be 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_intra_sched_inplace.c @@ -18,8 +18,7 @@ int MPIR_Ialltoallv_intra_sched_inplace(const void *sendbuf, const MPI_Aint send MPI_Aint recvtype_extent, recvtype_sz; int dst, rank; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Get extent and size of recvtype, don't look at sendtype for 
MPI_IN_PLACE */ MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c index 0f0d1cba5ac..2e3b4d31829 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c @@ -31,8 +31,7 @@ int MPIR_TSP_Ialltoallv_sched_intra_blocked(const void *sendbuf, const MPI_Aint mpi_errno = MPIR_Sched_next_tag(comm, &tag); MPIR_ERR_CHECK(mpi_errno); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(recvtype, recv_extent); MPIR_Type_get_true_extent_impl(recvtype, &recv_lb, &true_extent); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c index afd3ba5108e..80feead1da9 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c @@ -30,8 +30,7 @@ int MPIR_TSP_Ialltoallv_sched_intra_inplace(const void *sendbuf, const MPI_Aint mpi_errno = MPIR_Sched_next_tag(comm, &tag); MPIR_ERR_CHECK(mpi_errno); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(recvtype, recv_extent); MPIR_Type_get_true_extent_impl(recvtype, &recv_lb, &true_extent); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c index 62c3cf55e54..57ec5f136d9 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c @@ -25,8 +25,8 @@ int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const MPI_Ain MPIR_Assert(!(sendbuf == MPI_IN_PLACE)); - int size = MPIR_Comm_size(comm); - int rank = MPIR_Comm_rank(comm); + int rank, size; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); MPI_Aint recvtype_lb, 
recvtype_extent; MPI_Aint sendtype_lb, sendtype_extent; diff --git a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c index e203f119915..fe99d505cc2 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_blocked.c @@ -34,8 +34,7 @@ int MPIR_Ialltoallw_intra_sched_blocked(const void *sendbuf, const MPI_Aint send MPIR_Assert(sendbuf != MPI_IN_PLACE); #endif /* HAVE_ERROR_CHECKING */ - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); bblock = MPIR_CVAR_ALLTOALL_THROTTLE; if (bblock == 0) diff --git a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c index 1f865665ad9..7b48e76247e 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_intra_sched_inplace.c @@ -29,8 +29,7 @@ int MPIR_Ialltoallw_intra_sched_inplace(const void *sendbuf, const MPI_Aint send MPI_Aint recvtype_sz; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* The regular MPI_Alltoallw handles MPI_IN_PLACE using pairwise * sendrecv_replace calls. 
We don't have a sendrecv_replace, so just diff --git a/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c b/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c index 9eaf904d2bf..9675c9ffe3b 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c @@ -25,8 +25,7 @@ int MPIR_TSP_Ialltoallw_sched_intra_blocked(const void *sendbuf, const MPI_Aint MPIR_Assert(sendbuf != MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); if (bblock == 0) bblock = nranks; diff --git a/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c b/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c index e3984434c71..7af7c386b4b 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c @@ -27,8 +27,7 @@ int MPIR_TSP_Ialltoallw_sched_intra_inplace(const void *sendbuf, const MPI_Aint MPIR_Assert(sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ diff --git a/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c b/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c index eafee4f7a5f..8fba9c192d5 100644 --- a/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ibarrier/ibarrier_intra_sched_recursive_doubling.c @@ -26,8 +26,7 @@ int MPIR_Ibarrier_intra_sched_recursive_doubling(MPIR_Comm * comm_ptr, int coll_ MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, size); mask = 0x1; while (mask < size) { diff --git a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c index 
2dd5ceb3668..e59e162ab65 100644 --- a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c +++ b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c @@ -20,8 +20,7 @@ int MPIR_TSP_Ibarrier_sched_intra_k_dissemination(MPIR_Comm * comm, int coll_gro MPIR_FUNC_ENTER; - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); mpi_errno = MPIR_Sched_next_tag(comm, &tag); if (mpi_errno) diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c b/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c index 32c8ade7e6f..0b4d309bf27 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_binomial.c @@ -25,8 +25,7 @@ int MPIR_Ibcast_intra_sched_binomial(void *buffer, MPI_Aint count, MPI_Datatype struct MPII_Ibcast_state *ibcast_state; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_is_contig(datatype, &is_contig); MPIR_Datatype_get_size_macro(datatype, type_size); diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c index e088a7d0932..3ece3d2a562 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_recursive_doubling_allgather.c @@ -61,8 +61,8 @@ int MPIR_Ibcast_intra_sched_scatter_recursive_doubling_allgather(void *buffer, M void *tmp_buf; struct MPII_Ibcast_state *ibcast_state; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + relative_rank = (rank >= root) ? 
rank - root : rank - root + comm_size; #ifdef HAVE_ERROR_CHECKING diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c index 3d78dff1851..3dc2d9bb6bf 100644 --- a/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_scatter_ring_allgather.c @@ -38,8 +38,7 @@ int MPIR_Ibcast_intra_sched_scatter_ring_allgather(void *buffer, MPI_Aint count, struct MPII_Ibcast_state *ibcast_state; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); if (HANDLE_IS_BUILTIN(datatype)) is_contig = 1; diff --git a/src/mpi/coll/ibcast/ibcast_tsp_auto.c b/src/mpi/coll/ibcast/ibcast_tsp_auto.c index 4128a82372f..edc9f4aad85 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_auto.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_auto.c @@ -17,14 +17,16 @@ static int MPIR_Ibcast_sched_intra_tsp_flat_auto(void *buffer, MPI_Aint count, MPIR_TSP_sched_t sched) { int mpi_errno = MPI_SUCCESS; - int comm_size; + int comm_size, rank; MPI_Aint type_size, nbytes; int tree_type = MPIR_TREE_TYPE_KNOMIAL_1; int radix = 2, scatterv_k = 2, allgatherv_k = 2, block_size = 0; MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - comm_size = comm_ptr->local_size; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warning */ + MPIR_Datatype_get_size_macro(datatype, type_size); nbytes = type_size * count; diff --git a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c index 7324a96100e..5bece20ed41 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c @@ -37,14 +37,13 @@ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count MPIR_FUNC_ENTER; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); + lrank = 
(rank - root + size) % size; /* logical rank when root is non-zero */ + MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "Scheduling scatter followed by recursive exchange allgather based broadcast on %d ranks, root=%d\n", - MPIR_Comm_size(comm), root)); - - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); - lrank = (rank - root + size) % size; /* logical rank when root is non-zero */ + size, root)); MPIR_Datatype_get_size_macro(datatype, type_size); MPIR_Datatype_get_extent_macro(datatype, extent); diff --git a/src/mpi/coll/ibcast/ibcast_tsp_tree.c b/src/mpi/coll/ibcast/ibcast_tsp_tree.c index 44e76e4caef..836ffbed01c 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_tree.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_tree.c @@ -29,8 +29,7 @@ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype MPIR_FUNC_ENTER; - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); MPIR_Datatype_get_size_macro(datatype, type_size); MPIR_Datatype_get_extent_macro(datatype, extent); diff --git a/src/mpi/coll/ibcast/ibcast_utils.c b/src/mpi/coll/ibcast/ibcast_utils.c index a600a638cc1..09d035ac87a 100644 --- a/src/mpi/coll/ibcast/ibcast_utils.c +++ b/src/mpi/coll/ibcast/ibcast_utils.c @@ -76,8 +76,8 @@ int MPII_Iscatter_for_bcast_sched(void *tmp_buf, int root, MPIR_Comm * comm_ptr, int relative_rank, mask; MPI_Aint scatter_size, curr_size, recv_size, send_size; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + relative_rank = (rank >= root) ? 
rank - root : rank - root + comm_size; /* The scatter algorithm divides the buffer into nprocs pieces and diff --git a/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c b/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c index ac42822de90..799ad718131 100644 --- a/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iexscan/iexscan_intra_sched_recursive_doubling.c @@ -59,8 +59,7 @@ int MPIR_Iexscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvb MPI_Aint true_extent, true_lb, extent; void *partial_scan, *tmp_buf; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); diff --git a/src/mpi/coll/igather/igather_intra_sched_binomial.c b/src/mpi/coll/igather/igather_intra_sched_binomial.c index 42b0f6d8550..b0a0e15c957 100644 --- a/src/mpi/coll/igather/igather_intra_sched_binomial.c +++ b/src/mpi/coll/igather/igather_intra_sched_binomial.c @@ -42,8 +42,7 @@ int MPIR_Igather_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, int copy_offset = 0, copy_blks = 0; MPI_Datatype types[2], tmp_type; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); diff --git a/src/mpi/coll/igather/igather_tsp_tree.c b/src/mpi/coll/igather/igather_tsp_tree.c index b08d2021384..7799da9bdc2 100644 --- a/src/mpi/coll/igather/igather_tsp_tree.c +++ b/src/mpi/coll/igather/igather_tsp_tree.c @@ -32,8 +32,7 @@ int MPIR_TSP_Igather_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, MPIR_FUNC_ENTER; - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); lrank = (rank - root + size) % size; /* logical rank when root is non-zero */ if (rank == root) diff --git a/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c 
b/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c index 47148214fa6..9660596ca43 100644 --- a/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c +++ b/src/mpi/coll/igatherv/igatherv_allcomm_sched_linear.c @@ -22,15 +22,17 @@ int MPIR_Igatherv_allcomm_sched_linear(const void *sendbuf, MPI_Aint sendcount, int comm_size, rank; MPI_Aint extent; - rank = comm_ptr->rank; + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } /* If rank == root, then I recv lots, otherwise I send */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) - comm_size = comm_ptr->local_size; - else - comm_size = comm_ptr->remote_size; MPIR_Datatype_get_extent_macro(recvtype, extent); diff --git a/src/mpi/coll/igatherv/igatherv_tsp_linear.c b/src/mpi/coll/igatherv/igatherv_tsp_linear.c index ffa878af186..324cebfb441 100644 --- a/src/mpi/coll/igatherv/igatherv_tsp_linear.c +++ b/src/mpi/coll/igatherv/igatherv_tsp_linear.c @@ -30,7 +30,12 @@ int MPIR_TSP_Igatherv_sched_allcomm_linear(const void *sendbuf, MPI_Aint sendcou int tag; MPIR_Errflag_t errflag ATTRIBUTE((unused)) = MPIR_ERR_NONE; - rank = comm_ptr->rank; + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); MPIR_ERR_CHECK(mpi_errno); @@ -38,11 +43,6 @@ int MPIR_TSP_Igatherv_sched_allcomm_linear(const void *sendbuf, MPI_Aint sendcou /* If rank == root, then I recv lots, otherwise I send */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == 
MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) - comm_size = comm_ptr->local_size; - else - comm_size = comm_ptr->remote_size; - MPIR_Datatype_get_extent_macro(recvtype, extent); for (i = 0; i < comm_size; i++) { diff --git a/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c b/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c index 3ab93c5c664..d803296b1be 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_binomial.c @@ -17,8 +17,7 @@ int MPIR_Ireduce_intra_sched_binomial(const void *sendbuf, void *recvbuf, MPI_Ai MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Create a temporary buffer */ diff --git a/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c b/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c index 2fc41a8bd9f..501a45bd077 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_reduce_scatter_gather.c @@ -45,8 +45,7 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re MPI_Aint true_lb, true_extent, extent; MPIR_CHKLMEM_DECL(2); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* NOTE: this algorithm is currently only correct for commutative operations */ is_commutative = MPIR_Op_is_commutative(op); @@ -65,7 +64,7 @@ int MPIR_Ireduce_intra_sched_reduce_scatter_gather(const void *sendbuf, void *re tmp_buf = (void *) ((char *) tmp_buf - true_lb); /* get nearest power-of-two less than or equal to comm_size */ - pof2 = MPL_pof2(comm_ptr->local_size); + pof2 = MPL_pof2(comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(HANDLE_IS_BUILTIN(op)); diff --git a/src/mpi/coll/ireduce/ireduce_tsp_auto.c 
b/src/mpi/coll/ireduce/ireduce_tsp_auto.c index 13d7aae7181..f86966ccaf6 100644 --- a/src/mpi/coll/ireduce/ireduce_tsp_auto.c +++ b/src/mpi/coll/ireduce/ireduce_tsp_auto.c @@ -61,10 +61,16 @@ int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MP MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); + int rank, comm_size; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + if (comm_size == 1) { + goto fn_exit; + } + switch (MPIR_CVAR_IREDUCE_INTRA_ALGORITHM) { case MPIR_CVAR_IREDUCE_INTRA_ALGORITHM_tsp_tree: /*Only knomial_1 tree supports non-commutative operations */ - MPII_COLLECTIVE_FALLBACK_CHECK(comm_ptr->rank, MPIR_Op_is_commutative(op) || + MPII_COLLECTIVE_FALLBACK_CHECK(rank, MPIR_Op_is_commutative(op) || MPIR_Ireduce_tree_type == MPIR_TREE_TYPE_KNOMIAL_1, mpi_errno, "Ireduce gentran_tree cannot be applied.\n"); mpi_errno = diff --git a/src/mpi/coll/ireduce/ireduce_tsp_tree.c b/src/mpi/coll/ireduce/ireduce_tsp_tree.c index 47f9c0abd19..664dae78341 100644 --- a/src/mpi/coll/ireduce/ireduce_tsp_tree.c +++ b/src/mpi/coll/ireduce/ireduce_tsp_tree.c @@ -42,8 +42,7 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai MPIR_FUNC_ENTER; - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); is_root = (rank == root); /* main algorithm */ diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c index 3f0ef676834..5d23863a6e6 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_noncommutative.c @@ -27,8 +27,7 @@ int MPIR_Ireduce_scatter_intra_sched_noncommutative(const void *sendbuf, void *r MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int comm_size = comm_ptr->local_size; - int rank = comm_ptr->rank; + int comm_size, rank; int 
log2_comm_size; int i, k; MPI_Aint true_extent, true_lb; @@ -37,6 +36,8 @@ int MPIR_Ireduce_scatter_intra_sched_noncommutative(const void *sendbuf, void *r void *tmp_buf1; void *result_ptr; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); #ifdef HAVE_ERROR_CHECKING diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c index 41b5dbdec9b..71a2756c623 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_pairwise.c @@ -24,8 +24,7 @@ int MPIR_Ireduce_scatter_intra_sched_pairwise(const void *sendbuf, void *recvbuf void *tmp_recvbuf; int src, dst; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c index c65bf91d165..2f1dc4de1db 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_doubling.c @@ -32,8 +32,7 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_doubling(const void *sendbuf, voi MPI_Datatype sendtype, recvtype; int nprocs_completed, tmp_mask, tree_root, is_commutative; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c 
b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c index 20460dc2fc1..aa5b1c2ad09 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_intra_sched_recursive_halving.c @@ -48,8 +48,7 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void int rem, newdst, send_idx, recv_idx, last_idx; int pof2, old_i, newrank; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); @@ -96,7 +95,7 @@ int MPIR_Ireduce_scatter_intra_sched_recursive_halving(const void *sendbuf, void MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); - pof2 = MPL_pof2(comm_ptr->local_size); + pof2 = MPL_pof2(comm_size); rem = comm_size - pof2; diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c index b49136672d2..f71afdf58cf 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c @@ -162,8 +162,7 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv mpi_errno = MPIR_Sched_next_tag(comm, &tag); is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c index fef8400f794..080e957245d 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c +++ 
b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_noncommutative.c @@ -15,8 +15,7 @@ int MPIR_Ireduce_scatter_block_intra_sched_noncommutative(const void *sendbuf, v int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int comm_size = comm_ptr->local_size; - int rank = comm_ptr->rank; + int comm_size, rank; int log2_comm_size; int i, k; MPI_Aint true_extent, true_lb; @@ -25,6 +24,8 @@ int MPIR_Ireduce_scatter_block_intra_sched_noncommutative(const void *sendbuf, v void *tmp_buf1; void *result_ptr; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); #ifdef HAVE_ERROR_CHECKING diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c index 339268ddbb5..0f03670c6a5 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_pairwise.c @@ -20,8 +20,7 @@ int MPIR_Ireduce_scatter_block_intra_sched_pairwise(const void *sendbuf, void *r int src, dst; MPI_Aint total_count; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c index 0e7785ef242..3554a64d6a5 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_doubling.c @@ -23,8 +23,7 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_doubling(const void *sendbu MPI_Datatype sendtype, 
recvtype; int nprocs_completed, tmp_mask, tree_root, is_commutative; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c index b6673cd19de..34c3b68db34 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_intra_sched_recursive_halving.c @@ -22,8 +22,7 @@ int MPIR_Ireduce_scatter_block_intra_sched_recursive_halving(const void *sendbuf int rem, newdst, send_idx, recv_idx, last_idx, send_cnt, recv_cnt; int pof2, old_i, newrank; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c index 14a25188b0b..5e4107de0b4 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c @@ -35,8 +35,7 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void mpi_errno = MPIR_Sched_next_tag(comm, &tag); is_inplace = (sendbuf == MPI_IN_PLACE); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent); diff --git a/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c 
b/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c index dc742296782..be5eb268dae 100644 --- a/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c +++ b/src/mpi/coll/iscan/iscan_intra_sched_recursive_doubling.c @@ -16,8 +16,7 @@ int MPIR_Iscan_intra_sched_recursive_doubling(const void *sendbuf, void *recvbuf void *partial_scan = NULL; void *tmp_buf = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); is_commutative = MPIR_Op_is_commutative(op); diff --git a/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c b/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c index eba2405a9aa..2307607c8e1 100644 --- a/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c +++ b/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c @@ -31,8 +31,7 @@ int MPIR_TSP_Iscan_sched_intra_recursive_doubling(const void *sendbuf, void *rec mpi_errno = MPIR_Sched_next_tag(comm, &tag); MPIR_ERR_CHECK(mpi_errno); - nranks = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); is_commutative = MPIR_Op_is_commutative(op); diff --git a/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c b/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c index e6d0866ed09..5f17deed25c 100644 --- a/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c +++ b/src/mpi/coll/iscatter/iscatter_intra_sched_binomial.c @@ -77,8 +77,7 @@ int MPIR_Iscatter_intra_sched_binomial(const void *sendbuf, MPI_Aint sendcount, void *tmp_buf = NULL; struct shared_state *ss = NULL; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); ss = MPIR_Sched_alloc_state(s, sizeof(struct shared_state)); MPIR_ERR_CHKANDJUMP(!ss, mpi_errno, MPI_ERR_OTHER, "**nomem"); diff --git a/src/mpi/coll/iscatter/iscatter_tsp_tree.c b/src/mpi/coll/iscatter/iscatter_tsp_tree.c index e063be91d87..858f988050d 100644 --- 
a/src/mpi/coll/iscatter/iscatter_tsp_tree.c +++ b/src/mpi/coll/iscatter/iscatter_tsp_tree.c @@ -34,8 +34,7 @@ int MPIR_TSP_Iscatter_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, MPIR_Errflag_t errflag ATTRIBUTE((unused)) = MPIR_ERR_NONE; MPIR_CHKLMEM_DECL(2); - size = MPIR_Comm_size(comm); - rank = MPIR_Comm_rank(comm); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); lrank = (rank - root + size) % size; /* logical rank when root is non-zero */ if (rank == root) diff --git a/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c b/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c index 32cfe4f9320..4f19b399152 100644 --- a/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c +++ b/src/mpi/coll/iscatterv/iscatterv_allcomm_sched_linear.c @@ -26,15 +26,17 @@ int MPIR_Iscatterv_allcomm_sched_linear(const void *sendbuf, const MPI_Aint send MPI_Aint extent; int i; - rank = comm_ptr->rank; + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } /* If I'm the root, then scatter */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) - comm_size = comm_ptr->local_size; - else - comm_size = comm_ptr->remote_size; MPIR_Datatype_get_extent_macro(sendtype, extent); diff --git a/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c b/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c index 883dcef4d1b..1ede379982d 100644 --- a/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c +++ b/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c @@ -22,7 +22,13 @@ int MPIR_TSP_Iscatterv_sched_allcomm_linear(const void *sendbuf, const MPI_Aint MPIR_FUNC_ENTER; - rank = comm_ptr->rank; + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + 
MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ @@ -32,10 +38,6 @@ int MPIR_TSP_Iscatterv_sched_allcomm_linear(const void *sendbuf, const MPI_Aint /* If I'm the root, then scatter */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) - comm_size = comm_ptr->local_size; - else - comm_size = comm_ptr->remote_size; MPIR_Datatype_get_extent_macro(sendtype, extent); /* We need a check to ensure extent will fit in a diff --git a/src/mpi/coll/mpir_coll_sched_auto.c b/src/mpi/coll/mpir_coll_sched_auto.c index 2caad70c96c..8ad4e037144 100644 --- a/src/mpi/coll/mpir_coll_sched_auto.c +++ b/src/mpi/coll/mpir_coll_sched_auto.c @@ -36,7 +36,6 @@ int MPIR_Ibcast_intra_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype data MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int comm_size; MPI_Aint type_size, nbytes; MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); @@ -50,7 +49,10 @@ int MPIR_Ibcast_intra_sched_auto(void *buffer, MPI_Aint count, MPI_Datatype data goto fn_exit; } - comm_size = comm_ptr->local_size; + int comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ + MPIR_Datatype_get_size_macro(datatype, type_size); nbytes = type_size * count; @@ -290,10 +292,11 @@ int MPIR_Iallgather_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MP MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int comm_size; MPI_Aint recvtype_size, tot_bytes; - comm_size = comm_ptr->local_size; + int 
comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ MPIR_Datatype_get_size_macro(recvtype, recvtype_size); tot_bytes = (MPI_Aint) recvcount *comm_size * recvtype_size; @@ -342,10 +345,13 @@ int MPIR_Iallgatherv_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int i, comm_size; + int i; MPI_Aint total_count, recvtype_size; - comm_size = comm_ptr->local_size; + int comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ + MPIR_Datatype_get_size_macro(recvtype, recvtype_size); total_count = 0; @@ -407,10 +413,11 @@ int MPIR_Ialltoall_intra_sched_auto(const void *sendbuf, MPI_Aint sendcount, MPI MPIR_Comm * comm_ptr, int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int comm_size; MPI_Aint nbytes, sendtype_size; - comm_size = comm_ptr->local_size; + int comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ MPIR_Datatype_get_size_macro(sendtype, sendtype_size); nbytes = sendtype_size * sendcount; @@ -556,7 +563,7 @@ int MPIR_Ireduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Aint c MPIR_Datatype_get_size_macro(datatype, type_size); /* get nearest power-of-two less than or equal to number of ranks in the communicator */ - pof2 = MPL_pof2(comm_ptr->local_size); + pof2 = MPL_pof2(MPIR_Coll_size(comm_ptr, coll_group)); if ((count * type_size > MPIR_CVAR_REDUCE_SHORT_MSG_SIZE) && (HANDLE_IS_BUILTIN(op)) && (count >= pof2)) { @@ -615,7 +622,7 @@ int MPIR_Iallreduce_intra_sched_auto(const void *sendbuf, void *recvbuf, MPI_Ain MPIR_Datatype_get_size_macro(datatype, type_size); /* get nearest power-of-two less than or equal to number of ranks in the communicator */ - pof2 = MPL_pof2(comm_ptr->local_size); + pof2 = MPL_pof2(MPIR_Coll_size(comm_ptr, coll_group)); 
/* If op is user-defined or count is less than pof2, use * recursive doubling algorithm. Otherwise do a reduce-scatter @@ -669,11 +676,13 @@ int MPIR_Ireduce_scatter_intra_sched_auto(const void *sendbuf, void *recvbuf, int i; int is_commutative; MPI_Aint total_count, type_size, nbytes; - int comm_size; is_commutative = MPIR_Op_is_commutative(op); - comm_size = comm_ptr->local_size; + int comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ + total_count = 0; for (i = 0; i < comm_size; i++) { total_count += recvcounts[i]; @@ -752,11 +761,13 @@ int MPIR_Ireduce_scatter_block_intra_sched_auto(const void *sendbuf, void *recvb int mpi_errno = MPI_SUCCESS; int is_commutative; MPI_Aint total_count, type_size, nbytes; - int comm_size; is_commutative = MPIR_Op_is_commutative(op); - comm_size = comm_ptr->local_size; + int comm_size, rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ + total_count = recvcount * comm_size; if (total_count == 0) { goto fn_exit; diff --git a/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c b/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c index 452564d3d1e..a5a2b1c4f98 100644 --- a/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c +++ b/src/mpi/coll/reduce/reduce_intra_reduce_scatter_gather.c @@ -50,8 +50,7 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, MPIR_CHKLMEM_DECL(4); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); /* Create a temporary buffer */ @@ -78,7 +77,7 @@ int MPIR_Reduce_intra_reduce_scatter_gather(const void *sendbuf, } /* get nearest power-of-two less than or equal to comm_size */ - pof2 = MPL_pof2(comm_ptr->local_size); + pof2 = MPL_pof2(comm_size); #ifdef HAVE_ERROR_CHECKING MPIR_Assert(HANDLE_IS_BUILTIN(op)); diff --git 
a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c index 40c6230304e..f2f169d5c0e 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_noncommutative.c @@ -37,6 +37,8 @@ int MPIR_Reduce_scatter_intra_noncommutative(const void *sendbuf, void *recvbuf, void *result_ptr; MPIR_CHKLMEM_DECL(3); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); #ifdef HAVE_ERROR_CHECKING diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c index bedaa05e980..a3367dd6705 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_pairwise.c @@ -26,8 +26,7 @@ int MPIR_Reduce_scatter_intra_pairwise(const void *sendbuf, void *recvbuf, int src, dst; MPIR_CHKLMEM_DECL(5); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); diff --git a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c index 807588fef84..972d1a13fbe 100644 --- a/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c +++ b/src/mpi/coll/reduce_scatter/reduce_scatter_intra_recursive_doubling.c @@ -34,8 +34,7 @@ int MPIR_Reduce_scatter_intra_recursive_doubling(const void *sendbuf, void *recv int nprocs_completed, tmp_mask, tree_root, is_commutative; MPIR_CHKLMEM_DECL(5); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); 
MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c index db1a88edc70..6451bcb4cee 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_noncommutative.c @@ -40,6 +40,8 @@ int MPIR_Reduce_scatter_block_intra_noncommutative(const void *sendbuf, void *result_ptr; MPIR_CHKLMEM_DECL(3); + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); #ifdef HAVE_ERROR_CHECKING diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c index 7cbb663b0ac..132d16218bc 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_pairwise.c @@ -35,8 +35,7 @@ int MPIR_Reduce_scatter_block_intra_pairwise(const void *sendbuf, int src, dst; MPIR_CHKLMEM_DECL(5); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); diff --git a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c index b8ed30050c7..82d2d20c692 100644 --- a/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c +++ b/src/mpi/coll/reduce_scatter_block/reduce_scatter_block_intra_recursive_doubling.c @@ -39,8 +39,7 @@ int MPIR_Reduce_scatter_block_intra_recursive_doubling(const void *sendbuf, int nprocs_completed, tmp_mask, tree_root, is_commutative; 
MPIR_CHKLMEM_DECL(5); - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); MPIR_Datatype_get_extent_macro(datatype, extent); MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); diff --git a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c index da5be5a0b1f..223027e6128 100644 --- a/src/mpi/coll/scatterv/scatterv_allcomm_linear.c +++ b/src/mpi/coll/scatterv/scatterv_allcomm_linear.c @@ -29,14 +29,17 @@ int MPIR_Scatterv_allcomm_linear(const void *sendbuf, const MPI_Aint * sendcount MPI_Status *starray; MPIR_CHKLMEM_DECL(2); - MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); + } else { + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + rank = comm_ptr->rank; + comm_size = comm_ptr->remote_size; + } /* If I'm the root, then scatter */ if (((comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) && (root == rank)) || ((comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) && (root == MPI_ROOT))) { - if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) - comm_size = comm_ptr->remote_size; - MPIR_Datatype_get_extent_macro(sendtype, extent); MPIR_CHKLMEM_MALLOC(reqarray, MPIR_Request **, comm_size * sizeof(MPIR_Request *), diff --git a/src/mpi/coll/src/csel.c b/src/mpi/coll/src/csel.c index e0bb9dbf4ce..4422c78dea0 100644 --- a/src/mpi/coll/src/csel.c +++ b/src/mpi/coll/src/csel.c @@ -924,7 +924,10 @@ static inline MPI_Aint get_count(MPIR_Csel_coll_sig_s coll_info) { MPI_Aint count = 0; int i = 0; - int comm_size = coll_info.comm_ptr->local_size; + + int comm_size, rank; + MPIR_COLL_RANK_SIZE(coll_info.comm_ptr, coll_info.coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ switch (coll_info.coll_type) { case MPIR_CSEL_COLL_TYPE__BCAST: @@ -981,7 +984,10 @@ static inline MPI_Aint 
get_count(MPIR_Csel_coll_sig_s coll_info) static inline MPI_Aint get_total_msgsize(MPIR_Csel_coll_sig_s coll_info) { MPI_Aint total_bytes = 0, i = 0, count = 0, typesize = 0; - int comm_size = coll_info.comm_ptr->local_size; + + int comm_size, rank; + MPIR_COLL_RANK_SIZE(coll_info.comm_ptr, coll_info.coll_group, rank, comm_size); + (void) rank; /* silence unused variable warnings */ switch (coll_info.coll_type) { case MPIR_CSEL_COLL_TYPE__ALLREDUCE: @@ -1285,7 +1291,8 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) break; case CSEL_NODE_TYPE__OPERATOR__COUNT_LT_POW2: - if (get_count(coll_info) < MPL_pof2(coll_info.comm_ptr->local_size)) + if (get_count(coll_info) < + MPL_pof2(MPIR_Coll_size(coll_info.comm_ptr, coll_info.coll_group))) node = node->success; else node = node->failure; From bdd4532467fb70a9b9b92823689e15505efeb693 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 18 Aug 2024 10:22:19 -0500 Subject: [PATCH 16/27] coll: modify bcast_intra_smp to use subgroups Replace the usage of subcomms with subgroups. 
--- src/mpi/coll/bcast/bcast_intra_smp.c | 67 ++++++++++++++++------------ 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/src/mpi/coll/bcast/bcast_intra_smp.c b/src/mpi/coll/bcast/bcast_intra_smp.c index 04deb237d2f..04d3ec02fdb 100644 --- a/src/mpi/coll/bcast/bcast_intra_smp.c +++ b/src/mpi/coll/bcast/bcast_intra_smp.c @@ -27,6 +27,21 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in #ifdef HAVE_ERROR_CHECKING MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); #endif + int comm_size = comm_ptr->local_size; + + int node_group = 0, node_roots_group = 0; + int local_rank, local_size, local_root, inter_root = -1; + + node_group = MPIR_SUBGROUP_NODE; +#define NODEGROUP(field) comm_ptr->subgroups[node_group].field + + local_rank = NODEGROUP(rank); + local_size = NODEGROUP(size); + local_root = MPIR_Get_intranode_rank(comm_ptr, root); + if (local_rank == 0) { + node_roots_group = MPIR_SUBGROUP_NODE_CROSS; + inter_root = MPIR_Get_internode_rank(comm_ptr, root); + } MPIR_Datatype_get_size_macro(datatype, type_size); @@ -34,18 +49,17 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in if (nbytes == 0) goto fn_exit; /* nothing to do */ - if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || - (comm_ptr->local_size < MPIR_CVAR_BCAST_MIN_PROCS)) { + if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (comm_size < MPIR_CVAR_BCAST_MIN_PROCS)) { /* send to intranode-rank 0 on the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) > 0) { /* is not the node root (0) and is on our node (!-1) */ - if (root == comm_ptr->rank) { + if (local_root > 0) { /* is not the node root (0) and is on our node (!-1) */ + if (local_rank == local_root) { mpi_errno = MPIC_Send(buffer, count, datatype, 0, - MPIR_BCAST_TAG, comm_ptr->node_comm, coll_group, errflag); + MPIR_BCAST_TAG, comm_ptr, node_group, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (0 == 
comm_ptr->node_comm->rank) { - mpi_errno = - MPIC_Recv(buffer, count, datatype, MPIR_Get_intranode_rank(comm_ptr, root), - MPIR_BCAST_TAG, comm_ptr->node_comm, coll_group, status_p); + mpi_errno = MPIC_Recv(buffer, count, datatype, + MPIR_Get_intranode_rank(comm_ptr, root), + MPIR_BCAST_TAG, comm_ptr, node_group, status_p); MPIR_ERR_CHECK(mpi_errno); #ifdef HAVE_ERROR_CHECKING /* check that we received as much as we expected */ @@ -59,54 +73,49 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in } - /* perform the internode broadcast */ - if (comm_ptr->node_roots_comm != NULL) { - mpi_errno = MPIR_Bcast(buffer, count, datatype, - MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, coll_group, errflag); + /* local roots perform the internode broadcast */ + if (local_rank == 0) { + mpi_errno = MPIR_Bcast(buffer, count, datatype, inter_root, + comm_ptr, node_roots_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* perform the intranode broadcast on all except for the root's node */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr->node_comm, - coll_group, errflag); + if (local_size > 1) { + mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr, node_group, errflag); MPIR_ERR_CHECK(mpi_errno); } - } else { /* (nbytes > MPIR_CVAR_BCAST_SHORT_MSG_SIZE) && (comm_ptr->size >= MPIR_CVAR_BCAST_MIN_PROCS) */ + } else { /* (nbytes > MPIR_CVAR_BCAST_SHORT_MSG_SIZE) && (comm_size >= MPIR_CVAR_BCAST_MIN_PROCS) */ /* supposedly... 
* smp+doubling good for pof2 * reg+ring better for non-pof2 */ - if (nbytes < MPIR_CVAR_BCAST_LONG_MSG_SIZE && MPL_is_pof2(comm_ptr->local_size)) { + if (nbytes < MPIR_CVAR_BCAST_LONG_MSG_SIZE && MPL_is_pof2(comm_size)) { /* medium-sized msg and pof2 np */ /* perform the intranode broadcast on the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) > 0) { /* is not the node root (0) and is on our node (!-1) */ + if (local_size > 1 && local_root > 0) { /* is not the node root (0) and is on our node (!-1) */ /* FIXME binomial may not be the best algorithm for on-node * bcast. We need a more comprehensive system for selecting the * right algorithms here. */ - mpi_errno = MPIR_Bcast(buffer, count, datatype, - MPIR_Get_intranode_rank(comm_ptr, root), - comm_ptr->node_comm, coll_group, errflag); + mpi_errno = MPIR_Bcast(buffer, count, datatype, local_root, + comm_ptr, node_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* perform the internode broadcast */ - if (comm_ptr->node_roots_comm != NULL) { - mpi_errno = MPIR_Bcast(buffer, count, datatype, - MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, coll_group, errflag); + if (local_rank == 0) { + mpi_errno = MPIR_Bcast(buffer, count, datatype, inter_root, + comm_ptr, node_roots_group, errflag); MPIR_ERR_CHECK(mpi_errno); } /* perform the intranode broadcast on all except for the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) <= 0) { /* 0 if root was local root too, -1 if different node than root */ + if (local_size > 1 && local_root <= 0) { /* 0 if root was local root too, -1 if different node than root */ /* FIXME binomial may not be the best algorithm for on-node * bcast. We need a more comprehensive system for selecting the * right algorithms here. 
*/ mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr, node_group, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { /* large msg or non-pof2 */ From 066e586682687185b1f5670a4b60c83c36008bec Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 18 Aug 2024 10:40:02 -0500 Subject: [PATCH 17/27] coll: avoid extra intra bcast in bcast_intra_smp When root is not local rank 0, instead of adding an extra intra-node send/recv or bcast, construct an inter group that includes the root process. --- src/mpi/coll/bcast/bcast_intra_smp.c | 128 ++++++++++----------------- 1 file changed, 49 insertions(+), 79 deletions(-) diff --git a/src/mpi/coll/bcast/bcast_intra_smp.c b/src/mpi/coll/bcast/bcast_intra_smp.c index 04d3ec02fdb..d875772bf56 100644 --- a/src/mpi/coll/bcast/bcast_intra_smp.c +++ b/src/mpi/coll/bcast/bcast_intra_smp.c @@ -5,6 +5,34 @@ #include "mpiimpl.h" +/* TODO: move this to commutil.c */ +static void MPIR_Comm_construct_internode_roots_group(MPIR_Comm * comm, int root, + int *group_p, int *root_rank_p) +{ + int inter_size = comm->num_external; + int inter_rank = comm->internode_table[comm->rank]; + + int inter_group, *proc_table; + MPIR_COMM_PUSH_SUBGROUP(comm, inter_size, inter_rank, inter_group, proc_table); + + for (int i = 0; i < inter_size; i++) { + proc_table[i] = -1; + } + for (int i = 0; i < comm->local_size; i++) { + int r = comm->internode_table[i]; + if (proc_table[r] == -1) { + proc_table[r] = i; + } + } + int inter_root_rank = comm->internode_table[root]; + proc_table[inter_root_rank] = root; + + comm->subgroups[inter_group].proc_table = proc_table; + + *group_p = inter_group; + *root_rank_p = inter_root_rank; +} + /* FIXME This function uses some heuristsics based off of some testing on a * cluster at Argonne. We need a better system for detrmining and controlling * the cutoff points for these algorithms.
If I've done this right, you should @@ -15,14 +43,6 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in { int mpi_errno = MPI_SUCCESS; MPI_Aint type_size, nbytes = 0; - MPI_Status *status_p; -#ifdef HAVE_ERROR_CHECKING - MPI_Status status; - status_p = &status; - MPI_Aint recvd_size; -#else - status_p = MPI_STATUS_IGNORE; -#endif #ifdef HAVE_ERROR_CHECKING MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); @@ -38,9 +58,13 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in local_rank = NODEGROUP(rank); local_size = NODEGROUP(size); local_root = MPIR_Get_intranode_rank(comm_ptr, root); - if (local_rank == 0) { - node_roots_group = MPIR_SUBGROUP_NODE_CROSS; - inter_root = MPIR_Get_internode_rank(comm_ptr, root); + if (local_root < 0) { + /* non-root node use local rank 0 as local root */ + local_root = 0; + } + if (local_rank == local_root) { + MPIR_Comm_construct_internode_roots_group(comm_ptr, root, &node_roots_group, &inter_root); + MPIR_Assert(node_roots_group > 0); } MPIR_Datatype_get_size_macro(datatype, type_size); @@ -49,87 +73,33 @@ int MPIR_Bcast_intra_smp(void *buffer, MPI_Aint count, MPI_Datatype datatype, in if (nbytes == 0) goto fn_exit; /* nothing to do */ - if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (comm_size < MPIR_CVAR_BCAST_MIN_PROCS)) { - /* send to intranode-rank 0 on the root's node */ - if (local_root > 0) { /* is not the node root (0) and is on our node (!-1) */ - if (local_rank == local_root) { - mpi_errno = MPIC_Send(buffer, count, datatype, 0, - MPIR_BCAST_TAG, comm_ptr, node_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - } else if (0 == comm_ptr->node_comm->rank) { - mpi_errno = MPIC_Recv(buffer, count, datatype, - MPIR_Get_intranode_rank(comm_ptr, root), - MPIR_BCAST_TAG, comm_ptr, node_group, status_p); - MPIR_ERR_CHECK(mpi_errno); -#ifdef HAVE_ERROR_CHECKING - /* check that we received as much as we expected */ - MPIR_Get_count_impl(status_p, MPI_BYTE, 
&recvd_size); - MPIR_ERR_CHKANDJUMP2(recvd_size != nbytes, mpi_errno, MPI_ERR_OTHER, - "**collective_size_mismatch", - "**collective_size_mismatch %d %d", - (int) recvd_size, (int) nbytes); -#endif - } - - } - + if ((nbytes < MPIR_CVAR_BCAST_SHORT_MSG_SIZE) || (comm_size < MPIR_CVAR_BCAST_MIN_PROCS) || + (nbytes < MPIR_CVAR_BCAST_LONG_MSG_SIZE && MPL_is_pof2(comm_size))) { /* local roots perform the internode broadcast */ - if (local_rank == 0) { + if (local_rank == local_root) { mpi_errno = MPIR_Bcast(buffer, count, datatype, inter_root, comm_ptr, node_roots_group, errflag); MPIR_ERR_CHECK(mpi_errno); } - /* perform the intranode broadcast on all except for the root's node */ + /* perform the intranode broadcast */ if (local_size > 1) { mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr, node_group, errflag); MPIR_ERR_CHECK(mpi_errno); } - } else { /* (nbytes > MPIR_CVAR_BCAST_SHORT_MSG_SIZE) && (comm_size >= MPIR_CVAR_BCAST_MIN_PROCS) */ - - /* supposedly... - * smp+doubling good for pof2 - * reg+ring better for non-pof2 */ - if (nbytes < MPIR_CVAR_BCAST_LONG_MSG_SIZE && MPL_is_pof2(comm_size)) { - /* medium-sized msg and pof2 np */ - - /* perform the intranode broadcast on the root's node */ - if (local_size > 1 && local_root > 0) { /* is not the node root (0) and is on our node (!-1) */ - /* FIXME binomial may not be the best algorithm for on-node - * bcast. We need a more comprehensive system for selecting the - * right algorithms here. 
*/ - mpi_errno = MPIR_Bcast(buffer, count, datatype, local_root, - comm_ptr, node_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - } - - /* perform the internode broadcast */ - if (local_rank == 0) { - mpi_errno = MPIR_Bcast(buffer, count, datatype, inter_root, - comm_ptr, node_roots_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - } - - /* perform the intranode broadcast on all except for the root's node */ - if (local_size > 1 && local_root <= 0) { /* 0 if root was local root too, -1 if different node than root */ - /* FIXME binomial may not be the best algorithm for on-node - * bcast. We need a more comprehensive system for selecting the - * right algorithms here. */ - mpi_errno = MPIR_Bcast(buffer, count, datatype, 0, comm_ptr, node_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - } - } else { /* large msg or non-pof2 */ - - /* FIXME It would be good to have an SMP-aware version of this - * algorithm that (at least approximately) minimized internode - * communication. */ - mpi_errno = MPIR_Bcast_intra_scatter_ring_allgather(buffer, count, datatype, root, - comm_ptr, coll_group, errflag); - MPIR_ERR_CHECK(mpi_errno); - } + } else { + /* FIXME It would be good to have an SMP-aware version of this + * algorithm that (at least approximately) minimized internode + * communication. 
*/ + mpi_errno = MPIR_Bcast_intra_scatter_ring_allgather(buffer, count, datatype, root, comm_ptr, + coll_group, errflag); + MPIR_ERR_CHECK(mpi_errno); } fn_exit: + if (node_roots_group) { + MPIR_COMM_POP_SUBGROUP(comm_ptr); + } return mpi_errno; fn_fail: goto fn_exit; From e3749209f5444bfa794fdbc273f1401fc4ada1fb Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 18 Aug 2024 23:48:37 -0500 Subject: [PATCH 18/27] coll: modify smp algorithms to use MPIR_Subgroup --- src/mpi/coll/allreduce/allreduce_intra_smp.c | 30 ++++---- src/mpi/coll/barrier/barrier_intra_smp.c | 14 ++-- .../iallreduce/iallreduce_intra_sched_smp.c | 34 ++++----- src/mpi/coll/ibcast/ibcast_intra_sched_smp.c | 34 ++++----- .../coll/ireduce/ireduce_intra_sched_smp.c | 46 ++++++------ src/mpi/coll/iscan/iscan_intra_sched_smp.c | 70 ++++++++----------- src/mpi/coll/reduce/reduce_intra_smp.c | 38 +++++----- src/mpi/coll/scan/scan_intra_smp.c | 56 ++++++++------- 8 files changed, 156 insertions(+), 166 deletions(-) diff --git a/src/mpi/coll/allreduce/allreduce_intra_smp.c b/src/mpi/coll/allreduce/allreduce_intra_smp.c index 4ad6920dcbf..08818cf1900 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_smp.c +++ b/src/mpi/coll/allreduce/allreduce_intra_smp.c @@ -11,24 +11,26 @@ int MPIR_Allreduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, { int mpi_errno = MPI_SUCCESS; + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + /* on each node, do a reduce to the local root */ - if (comm_ptr->node_comm != NULL) { + if (local_size > 1) { /* take care of the MPI_IN_PLACE case. For reduce, * MPI_IN_PLACE is specified only on the root; * for allreduce it is specified on all processes. */ - if ((sendbuf == MPI_IN_PLACE) && (comm_ptr->node_comm->rank != 0)) { + if ((sendbuf == MPI_IN_PLACE) && (local_rank != 0)) { /* IN_PLACE and not root of reduce. 
Data supplied to this * allreduce is in recvbuf. Pass that as the sendbuf to reduce. */ - mpi_errno = - MPIR_Reduce(recvbuf, NULL, count, datatype, op, 0, comm_ptr->node_comm, - coll_group, errflag); + mpi_errno = MPIR_Reduce(recvbuf, NULL, count, datatype, op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } else { - mpi_errno = - MPIR_Reduce(sendbuf, recvbuf, count, datatype, op, 0, comm_ptr->node_comm, - coll_group, errflag); + mpi_errno = MPIR_Reduce(sendbuf, recvbuf, count, datatype, op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } } else { @@ -40,17 +42,15 @@ int MPIR_Allreduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, } /* now do an IN_PLACE allreduce among the local roots of all nodes */ - if (comm_ptr->node_roots_comm != NULL) { - mpi_errno = - MPIR_Allreduce(MPI_IN_PLACE, recvbuf, count, datatype, op, comm_ptr->node_roots_comm, - coll_group, errflag); + if (local_rank == 0) { + mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, recvbuf, count, datatype, op, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); } /* now broadcast the result among local processes */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, comm_ptr->node_comm, - coll_group, errflag); + if (local_size > 1) { + mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } goto fn_exit; diff --git a/src/mpi/coll/barrier/barrier_intra_smp.c b/src/mpi/coll/barrier/barrier_intra_smp.c index bafb2d34600..cef13c91059 100644 --- a/src/mpi/coll/barrier/barrier_intra_smp.c +++ b/src/mpi/coll/barrier/barrier_intra_smp.c @@ -10,25 +10,27 @@ int MPIR_Barrier_intra_smp(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t int mpi_errno = MPI_SUCCESS; MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = 
comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; /* do the intranode barrier on all nodes */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Barrier(comm_ptr->node_comm, coll_group, errflag); + if (local_size > 1) { + mpi_errno = MPIR_Barrier(comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } /* do the barrier across roots of all nodes */ - if (comm_ptr->node_roots_comm != NULL) { - mpi_errno = MPIR_Barrier(comm_ptr->node_roots_comm, coll_group, errflag); + if (local_rank == 0) { + mpi_errno = MPIR_Barrier(comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); } /* release the local processes on each node with a 1-byte * broadcast (0-byte broadcast just returns without doing * anything) */ - if (comm_ptr->node_comm != NULL) { + if (local_size > 1) { int i = 0; - mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, 0, comm_ptr->node_comm, coll_group, errflag); + mpi_errno = MPIR_Bcast(&i, 1, MPI_BYTE, 0, comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c b/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c index ad363bc8e54..31d4767108a 100644 --- a/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c +++ b/src/mpi/coll/iallreduce/iallreduce_intra_sched_smp.c @@ -12,13 +12,10 @@ int MPIR_Iallreduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint { int mpi_errno = MPI_SUCCESS; int is_commutative; - MPIR_Comm *nc; - MPIR_Comm *nrc; MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); - - nc = comm_ptr->node_comm; - nrc = comm_ptr->node_roots_comm; + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; is_commutative = MPIR_Op_is_commutative(op); @@ -33,21 +30,19 @@ int MPIR_Iallreduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint } /* on each node, do a reduce to the local root */ - if (nc != NULL) { + if (local_size > 1) { /* take care of the 
MPI_IN_PLACE case. For reduce, * MPI_IN_PLACE is specified only on the root; * for allreduce it is specified on all processes. */ - if ((sendbuf == MPI_IN_PLACE) && (comm_ptr->node_comm->rank != 0)) { + if ((sendbuf == MPI_IN_PLACE) && (local_rank != 0)) { /* IN_PLACE and not root of reduce. Data supplied to this * allreduce is in recvbuf. Pass that as the sendbuf to reduce. */ - mpi_errno = - MPIR_Ireduce_intra_sched_auto(recvbuf, NULL, count, datatype, op, 0, nc, - MPIR_SUBGROUP_NONE, s); + mpi_errno = MPIR_Ireduce_intra_sched_auto(recvbuf, NULL, count, datatype, op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); } else { - mpi_errno = - MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, 0, nc, - MPIR_SUBGROUP_NONE, s); + mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); } MPIR_SCHED_BARRIER(s); @@ -61,18 +56,17 @@ int MPIR_Iallreduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint } /* now do an IN_PLACE allreduce among the local roots of all nodes */ - if (nrc != NULL) { - mpi_errno = - MPIR_Iallreduce_intra_sched_auto(MPI_IN_PLACE, recvbuf, count, datatype, op, nrc, - MPIR_SUBGROUP_NONE, s); + if (local_rank == 0) { + mpi_errno = MPIR_Iallreduce_intra_sched_auto(MPI_IN_PLACE, recvbuf, count, datatype, op, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } /* now broadcast the result among local processes */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = - MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, nc, MPIR_SUBGROUP_NONE, s); + if (local_size > 1) { + mpi_errno = MPIR_Ibcast_intra_sched_auto(recvbuf, count, datatype, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c b/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c index 33b7dd5fa64..70ae42fb5ab 100644 --- 
a/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c +++ b/src/mpi/coll/ibcast/ibcast_intra_sched_smp.c @@ -37,6 +37,10 @@ int MPIR_Ibcast_intra_sched_smp(void *buffer, MPI_Aint count, MPI_Datatype datat #ifdef HAVE_ERROR_CHECKING MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); #endif + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + int local_root = MPIR_Get_intranode_rank(comm_ptr, root); + ibcast_state = MPIR_Sched_alloc_state(s, sizeof(struct MPII_Ibcast_state)); MPIR_ERR_CHKANDJUMP(!ibcast_state, mpi_errno, MPI_ERR_OTHER, "**nomem"); @@ -46,16 +50,15 @@ int MPIR_Ibcast_intra_sched_smp(void *buffer, MPI_Aint count, MPI_Datatype datat /* TODO insert packing here */ /* send to intranode-rank 0 on the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) > 0) { /* is not the node root (0) *//* and is on our node (!-1) */ - if (root == comm_ptr->rank) { + if (local_size > 1 && local_root > 0) { /* is not the node root (0) *//* and is on our node (!-1) */ + if (local_rank == local_root) { mpi_errno = - MPIR_Sched_send(buffer, count, datatype, 0, comm_ptr->node_comm, coll_group, s); + MPIR_Sched_send(buffer, count, datatype, 0, comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); - } else if (0 == comm_ptr->node_comm->rank) { - mpi_errno = - MPIR_Sched_recv_status(buffer, count, datatype, - MPIR_Get_intranode_rank(comm_ptr, root), comm_ptr->node_comm, - coll_group, &ibcast_state->status, s); + } else if (local_rank == 0) { + mpi_errno = MPIR_Sched_recv_status(buffer, count, datatype, local_root, + comm_ptr, MPIR_SUBGROUP_NODE, &ibcast_state->status, + s); MPIR_ERR_CHECK(mpi_errno); #ifdef HAVE_ERROR_CHECKING MPIR_SCHED_BARRIER(s); @@ -67,20 +70,19 @@ int MPIR_Ibcast_intra_sched_smp(void *buffer, MPI_Aint count, MPI_Datatype datat } /* perform the internode broadcast */ - if (comm_ptr->node_roots_comm != NULL) { - mpi_errno = 
MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, - MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, coll_group, s); + if (local_rank == 0) { + int inter_root = MPIR_Get_internode_rank(comm_ptr, root); + mpi_errno = MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, inter_root, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); /* don't allow the local ops for the intranode phase to start until this has completed */ MPIR_SCHED_BARRIER(s); } /* perform the intranode broadcast on all except for the root's node */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = - MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, 0, comm_ptr->node_comm, - coll_group, s); + if (local_size > 1) { + mpi_errno = MPIR_Ibcast_intra_sched_auto(buffer, count, datatype, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c b/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c index 13915416320..1e00f6ea3f8 100644 --- a/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c +++ b/src/mpi/coll/ireduce/ireduce_intra_sched_smp.c @@ -13,14 +13,12 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co int is_commutative; MPI_Aint true_lb, true_extent, extent; void *tmp_buf = NULL; - MPIR_Comm *nc; - MPIR_Comm *nrc; MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM); - - nc = comm_ptr->node_comm; - nrc = comm_ptr->node_roots_comm; + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + int local_root = MPIR_Get_intranode_rank(comm_ptr, root); /* is the op commutative? We do SMP optimizations only if it is. 
*/ is_commutative = MPIR_Op_is_commutative(op); @@ -33,7 +31,7 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co } /* Create a temporary buffer on local roots of all nodes */ - if (nrc != NULL) { + if (local_rank == 0) { MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -44,34 +42,32 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co } /* do the intranode reduce on all nodes other than the root's node */ - if (nc != NULL && MPIR_Get_intranode_rank(comm_ptr, root) == -1) { - mpi_errno = - MPIR_Ireduce_intra_sched_auto(sendbuf, tmp_buf, count, datatype, op, 0, nc, - MPIR_SUBGROUP_NONE, s); + if (local_size > 1 && local_root == -1) { + mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, tmp_buf, count, datatype, op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } /* do the internode reduce to the root's node */ - if (nrc != NULL) { - if (nrc->rank != MPIR_Get_internode_rank(comm_ptr, root)) { + if (local_rank == 0) { + int inter_root = MPIR_Get_internode_rank(comm_ptr, root); + if (local_root < 0) { /* I am not on root's node. Use tmp_buf if we * participated in the first reduce, otherwise use sendbuf */ - const void *buf = (nc == NULL ? sendbuf : tmp_buf); - mpi_errno = MPIR_Ireduce_intra_sched_auto(buf, NULL, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, root), - nrc, MPIR_SUBGROUP_NONE, s); + const void *buf = (local_size > 1 ? tmp_buf : sendbuf); + mpi_errno = MPIR_Ireduce_intra_sched_auto(buf, NULL, count, datatype, op, inter_root, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else { /* I am on root's node. I have not participated in the earlier reduce. */ - if (comm_ptr->rank != root) { + if (local_rank != local_root) { /* I am not the root though. I don't have a valid recvbuf. * Use tmp_buf as recvbuf. 
*/ mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, tmp_buf, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, - root), nrc, - MPIR_SUBGROUP_NONE, s); + op, inter_root, comm_ptr, + MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -81,9 +77,8 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co /* I am the root. in_place is automatically handled. */ mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, - root), nrc, - MPIR_SUBGROUP_NONE, s); + op, inter_root, comm_ptr, + MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); @@ -94,10 +89,9 @@ int MPIR_Ireduce_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint co } /* do the intranode reduce on the root's node */ - if (nc != NULL && MPIR_Get_intranode_rank(comm_ptr, root) != -1) { + if (local_size > 1 && MPIR_Get_intranode_rank(comm_ptr, root) != -1) { mpi_errno = MPIR_Ireduce_intra_sched_auto(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_intranode_rank(comm_ptr, root), nc, - MPIR_SUBGROUP_NONE, s); + op, local_root, comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/iscan/iscan_intra_sched_smp.c b/src/mpi/coll/iscan/iscan_intra_sched_smp.c index 9e3b4ed0b86..20e5bbdc445 100644 --- a/src/mpi/coll/iscan/iscan_intra_sched_smp.c +++ b/src/mpi/coll/iscan/iscan_intra_sched_smp.c @@ -11,14 +11,16 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun int coll_group, MPIR_Sched_t s) { int mpi_errno = MPI_SUCCESS; - int rank = comm_ptr->rank; - MPIR_Comm *node_comm; - MPIR_Comm *roots_comm; MPI_Aint true_extent, true_lb, extent; void *tempbuf = NULL; void *prefulldata = NULL; void *localfulldata = NULL; + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = 
comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + int inter_rank = MPIR_Get_internode_rank(comm_ptr, comm_ptr->rank); + /* In order to use the SMP-aware algorithm, the "op" can be * either commutative or non-commutative, but we require a * communicator in which all the nodes contain processes with @@ -31,9 +33,6 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun comm_ptr, coll_group, s); } - node_comm = comm_ptr->node_comm; - roots_comm = comm_ptr->node_roots_comm; - MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -42,12 +41,12 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun tempbuf = (void *) ((char *) tempbuf - true_lb); /* Create prefulldata and localfulldata on local roots of all nodes */ - if (comm_ptr->node_roots_comm != NULL) { + if (local_rank == 0) { prefulldata = MPIR_Sched_alloc_state(s, count * (MPL_MAX(extent, true_extent))); MPIR_ERR_CHKANDJUMP(!prefulldata, mpi_errno, MPI_ERR_OTHER, "**nomem"); prefulldata = (void *) ((char *) prefulldata - true_lb); - if (node_comm != NULL) { + if (local_size > 1) { localfulldata = MPIR_Sched_alloc_state(s, count * (MPL_MAX(extent, true_extent))); MPIR_ERR_CHKANDJUMP(!localfulldata, mpi_errno, MPI_ERR_OTHER, "**nomem"); localfulldata = (void *) ((char *) localfulldata - true_lb); @@ -56,10 +55,9 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun /* perform intranode scan to get temporary result in recvbuf. if there is only * one process, just copy the raw data. 
*/ - if (node_comm != NULL) { - mpi_errno = - MPIR_Iscan_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, node_comm, - coll_group, s); + if (local_size > 1) { + mpi_errno = MPIR_Iscan_intra_sched_auto(sendbuf, recvbuf, count, datatype, op, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } else if (sendbuf != MPI_IN_PLACE) { @@ -72,17 +70,16 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun * contains the reduce result of the whole node. Name it as * localfulldata. For example, localfulldata from node 1 contains * reduced data of rank 1,2,3. */ - if (roots_comm != NULL && node_comm != NULL) { - mpi_errno = MPIR_Sched_recv(localfulldata, count, datatype, - (node_comm->local_size - 1), node_comm, coll_group, s); + if (local_rank == 0 && local_size > 1) { + mpi_errno = MPIR_Sched_recv(localfulldata, count, datatype, (local_size - 1), + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); - } else if (roots_comm == NULL && node_comm != NULL && - node_comm->rank == node_comm->local_size - 1) { - mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, 0, node_comm, coll_group, s); + } else if (local_rank != 0 && local_size > 1 && local_rank == local_size - 1) { + mpi_errno = MPIR_Sched_send(recvbuf, count, datatype, 0, comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); - } else if (roots_comm != NULL) { + } else if (local_rank == 0) { localfulldata = recvbuf; } @@ -90,28 +87,23 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun * prefulldata on rank 4 contains reduce result of ranks * 1,2,3,4,5,6. it will be sent to rank 7 which is the * main process of node 3. 
*/ - if (roots_comm != NULL) { - /* FIXME just use roots_comm->rank instead */ - int roots_rank = MPIR_Get_internode_rank(comm_ptr, rank); - MPIR_Assert(roots_rank == roots_comm->rank); - - mpi_errno = - MPIR_Iscan_intra_sched_auto(localfulldata, prefulldata, count, datatype, op, roots_comm, - coll_group, s); + if (local_rank == 0) { + int inter_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE_CROSS].size; + + mpi_errno = MPIR_Iscan_intra_sched_auto(localfulldata, prefulldata, count, datatype, op, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); - if (roots_rank != roots_comm->local_size - 1) { - mpi_errno = - MPIR_Sched_send(prefulldata, count, datatype, (roots_rank + 1), roots_comm, - coll_group, s); + if (inter_rank != inter_size - 1) { + mpi_errno = MPIR_Sched_send(prefulldata, count, datatype, (inter_rank + 1), + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } - if (roots_rank != 0) { - mpi_errno = - MPIR_Sched_recv(tempbuf, count, datatype, (roots_rank - 1), roots_comm, coll_group, - s); + if (inter_rank != 0) { + mpi_errno = MPIR_Sched_recv(tempbuf, count, datatype, (inter_rank - 1), + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } @@ -123,13 +115,13 @@ int MPIR_Iscan_intra_sched_smp(const void *sendbuf, void *recvbuf, MPI_Aint coun * then we should broadcast this result in the local node, and * reduce it with recvbuf to get final result if necessary. 
*/ - if (MPIR_Get_internode_rank(comm_ptr, rank) != 0) { + if (inter_rank != 0) { /* we aren't on "node 0", so our node leader (possibly us) received * "prefulldata" from another leader into "tempbuf" */ - if (node_comm != NULL) { - mpi_errno = - MPIR_Ibcast_intra_sched_auto(tempbuf, count, datatype, 0, node_comm, coll_group, s); + if (local_size > 1) { + mpi_errno = MPIR_Ibcast_intra_sched_auto(tempbuf, count, datatype, 0, + comm_ptr, MPIR_SUBGROUP_NODE, s); MPIR_ERR_CHECK(mpi_errno); MPIR_SCHED_BARRIER(s); } diff --git a/src/mpi/coll/reduce/reduce_intra_smp.c b/src/mpi/coll/reduce/reduce_intra_smp.c index e932611afc1..a22fc8ed956 100644 --- a/src/mpi/coll/reduce/reduce_intra_smp.c +++ b/src/mpi/coll/reduce/reduce_intra_smp.c @@ -19,11 +19,17 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, int is_commutative; is_commutative = MPIR_Op_is_commutative(op); MPIR_Assertp(is_commutative); + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); } #endif /* HAVE_ERROR_CHECKING */ + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + int local_root = MPIR_Get_intranode_rank(comm_ptr, root); + int inter_root = MPIR_Get_internode_rank(comm_ptr, root); + /* Create a temporary buffer on local roots of all nodes */ - if (comm_ptr->node_roots_comm != NULL) { + if (local_rank == 0) { MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -35,30 +41,29 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, } /* do the intranode reduce on all nodes other than the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) == -1) { - mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, - op, 0, comm_ptr->node_comm, coll_group, errflag); + if (local_size > 1 && local_root == -1) { + mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, 
op, 0, + comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } /* do the internode reduce to the root's node */ - if (comm_ptr->node_roots_comm != NULL) { - if (comm_ptr->node_roots_comm->rank != MPIR_Get_internode_rank(comm_ptr, root)) { + if (local_rank == 0) { + if (local_root == -1) { /* I am not on root's node. Use tmp_buf if we * participated in the first reduce, otherwise use sendbuf */ - const void *buf = (comm_ptr->node_comm == NULL ? sendbuf : tmp_buf); + const void *buf = (local_size > 1 ? tmp_buf : sendbuf); mpi_errno = MPIR_Reduce(buf, NULL, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, coll_group, errflag); + op, inter_root, comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); } else { /* I am on root's node. I have not participated in the earlier reduce. */ - if (comm_ptr->rank != root) { + if (local_root != 0) { /* I am not the root though. I don't have a valid recvbuf. * Use tmp_buf as recvbuf. */ mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, coll_group, errflag); + op, inter_root, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); /* point sendbuf at tmp_buf to make final intranode reduce easy */ @@ -67,8 +72,8 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, /* I am the root. in_place is automatically handled. */ mpi_errno = MPIR_Reduce(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_internode_rank(comm_ptr, root), - comm_ptr->node_roots_comm, coll_group, errflag); + op, inter_root, + comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); /* set sendbuf to MPI_IN_PLACE to make final intranode reduce easy. 
*/ @@ -79,10 +84,9 @@ int MPIR_Reduce_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, } /* do the intranode reduce on the root's node */ - if (comm_ptr->node_comm != NULL && MPIR_Get_intranode_rank(comm_ptr, root) != -1) { + if (local_size > 1 && local_root != -1) { mpi_errno = MPIR_Reduce(sendbuf, recvbuf, count, datatype, - op, MPIR_Get_intranode_rank(comm_ptr, root), - comm_ptr->node_comm, coll_group, errflag); + op, local_root, comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/coll/scan/scan_intra_smp.c b/src/mpi/coll/scan/scan_intra_smp.c index efd5e09ac68..1b798979181 100644 --- a/src/mpi/coll/scan/scan_intra_smp.c +++ b/src/mpi/coll/scan/scan_intra_smp.c @@ -12,13 +12,16 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, { int mpi_errno = MPI_SUCCESS; MPIR_CHKLMEM_DECL(3); - int rank = comm_ptr->rank; MPI_Status status; void *tempbuf = NULL, *localfulldata = NULL, *prefulldata = NULL; MPI_Aint true_lb, true_extent, extent; int noneed = 1; /* noneed=1 means no need to bcast tempbuf and * reduce tempbuf & recvbuf */ + MPIR_Assert(MPIR_Comm_is_parent_comm(comm_ptr, coll_group)); + int local_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE].size; + MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); MPIR_Datatype_get_extent_macro(datatype, extent); @@ -28,12 +31,12 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, tempbuf = (void *) ((char *) tempbuf - true_lb); /* Create prefulldata and localfulldata on local roots of all nodes */ - if (comm_ptr->node_roots_comm != NULL) { + if (local_rank == 0) { MPIR_CHKLMEM_MALLOC(prefulldata, void *, count * (MPL_MAX(extent, true_extent)), mpi_errno, "prefulldata for scan", MPL_MEM_BUFFER); prefulldata = (void *) ((char *) prefulldata - true_lb); - if (comm_ptr->node_comm != NULL) { + if (local_size > 1) { MPIR_CHKLMEM_MALLOC(localfulldata, 
void *, count * (MPL_MAX(extent, true_extent)), mpi_errno, "localfulldata for scan", MPL_MEM_BUFFER); localfulldata = (void *) ((char *) localfulldata - true_lb); @@ -42,9 +45,9 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, /* perform intranode scan to get temporary result in recvbuf. if there is only * one process, just copy the raw data. */ - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Scan(sendbuf, recvbuf, count, datatype, op, comm_ptr->node_comm, - coll_group, errflag); + if (local_size > 1) { + mpi_errno = MPIR_Scan(sendbuf, recvbuf, count, datatype, op, + comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } else if (sendbuf != MPI_IN_PLACE) { mpi_errno = MPIR_Localcopy(sendbuf, count, datatype, recvbuf, count, datatype); @@ -55,18 +58,15 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, * contains the reduce result of the whole node. Name it as * localfulldata. For example, localfulldata from node 1 contains * reduced data of rank 1,2,3. 
*/ - if (comm_ptr->node_roots_comm != NULL && comm_ptr->node_comm != NULL) { + if (local_rank == 0 && local_size > 1) { mpi_errno = MPIC_Recv(localfulldata, count, datatype, - comm_ptr->node_comm->local_size - 1, MPIR_SCAN_TAG, - comm_ptr->node_comm, coll_group, &status); + local_size - 1, MPIR_SCAN_TAG, comm_ptr, MPIR_SUBGROUP_NODE, &status); MPIR_ERR_CHECK(mpi_errno); - } else if (comm_ptr->node_roots_comm == NULL && - comm_ptr->node_comm != NULL && - MPIR_Get_intranode_rank(comm_ptr, rank) == comm_ptr->node_comm->local_size - 1) { + } else if (local_rank > 0 && local_size > 1 && local_rank == local_size - 1) { mpi_errno = MPIC_Send(recvbuf, count, datatype, - 0, MPIR_SCAN_TAG, comm_ptr->node_comm, coll_group, errflag); + 0, MPIR_SCAN_TAG, comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); - } else if (comm_ptr->node_roots_comm != NULL) { + } else if (local_rank == 0) { localfulldata = recvbuf; } @@ -74,21 +74,23 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, * prefulldata on rank 4 contains reduce result of ranks * 1,2,3,4,5,6. it will be sent to rank 7 which is the * main process of node 3. 
*/ - if (comm_ptr->node_roots_comm != NULL) { + if (local_rank == 0) { mpi_errno = MPIR_Scan(localfulldata, prefulldata, count, datatype, - op, comm_ptr->node_roots_comm, coll_group, errflag); + op, comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); - if (MPIR_Get_internode_rank(comm_ptr, rank) != comm_ptr->node_roots_comm->local_size - 1) { + int inter_rank = comm_ptr->subgroups[MPIR_SUBGROUP_NODE_CROSS].rank; + int inter_size = comm_ptr->subgroups[MPIR_SUBGROUP_NODE_CROSS].size; + if (inter_rank != inter_size - 1) { mpi_errno = MPIC_Send(prefulldata, count, datatype, - MPIR_Get_internode_rank(comm_ptr, rank) + 1, - MPIR_SCAN_TAG, comm_ptr->node_roots_comm, coll_group, errflag); + inter_rank + 1, + MPIR_SCAN_TAG, comm_ptr, MPIR_SUBGROUP_NODE_CROSS, errflag); MPIR_ERR_CHECK(mpi_errno); } - if (MPIR_Get_internode_rank(comm_ptr, rank) != 0) { + if (inter_rank != 0) { mpi_errno = MPIC_Recv(tempbuf, count, datatype, - MPIR_Get_internode_rank(comm_ptr, rank) - 1, - MPIR_SCAN_TAG, comm_ptr->node_roots_comm, coll_group, &status); + inter_rank - 1, + MPIR_SCAN_TAG, comm_ptr, MPIR_SUBGROUP_NODE_CROSS, &status); noneed = 0; MPIR_ERR_CHECK(mpi_errno); } @@ -100,15 +102,15 @@ int MPIR_Scan_intra_smp(const void *sendbuf, void *recvbuf, MPI_Aint count, * then we should broadcast this result in the local node, and * reduce it with recvbuf to get final result if necessary. 
*/ - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(&noneed, 1, MPI_INT, 0, comm_ptr->node_comm, coll_group, errflag); + if (local_size > 1) { + mpi_errno = MPIR_Bcast(&noneed, 1, MPI_INT, 0, comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } if (noneed == 0) { - if (comm_ptr->node_comm != NULL) { - mpi_errno = MPIR_Bcast(tempbuf, count, datatype, 0, comm_ptr->node_comm, - coll_group, errflag); + if (local_size > 1) { + mpi_errno = MPIR_Bcast(tempbuf, count, datatype, 0, + comm_ptr, MPIR_SUBGROUP_NODE, errflag); MPIR_ERR_CHECK(mpi_errno); } From f98f2e745878e7b14233a5daa966c0374b767949 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Tue, 20 Aug 2024 10:59:07 -0500 Subject: [PATCH 19/27] mpir: replace subcomm usage with subgroups Directly use information from MPIR_Process rather than from nodecomm in MPIR_Process. One step toward removing subcomms. --- src/mpi/comm/comm_split_type_nbhd.c | 7 +------ src/mpi/init/init_async.c | 9 +++------ src/util/mpir_nodemap.c | 21 ++++++++++----------- 3 files changed, 14 insertions(+), 23 deletions(-) diff --git a/src/mpi/comm/comm_split_type_nbhd.c b/src/mpi/comm/comm_split_type_nbhd.c index 07c0e6d288c..9fbaa730ca0 100644 --- a/src/mpi/comm/comm_split_type_nbhd.c +++ b/src/mpi/comm/comm_split_type_nbhd.c @@ -474,12 +474,7 @@ static int network_split_by_min_memsize(MPIR_Comm * comm_ptr, int key, long min_ if (min_mem_size == 0 || topo_type == MPIR_NETTOPO_TYPE__INVALID) { *newcomm_ptr = NULL; } else { - int num_ranks_node; - if (MPIR_Process.comm_world->node_comm != NULL) { - num_ranks_node = MPIR_Comm_size(MPIR_Process.comm_world->node_comm); - } else { - num_ranks_node = 1; - } + int num_ranks_node = MPIR_Process.local_size; memory_per_process = total_memory_size / num_ranks_node; mpi_errno = network_split_by_minsize(comm_ptr, key, min_mem_size / memory_per_process, newcomm_ptr); diff --git a/src/mpi/init/init_async.c b/src/mpi/init/init_async.c index a7f4e1f4807..2e8c82de94a 100644 --- 
a/src/mpi/init/init_async.c +++ b/src/mpi/init/init_async.c @@ -179,17 +179,14 @@ static int get_thread_affinity(bool * apply_affinity, int **p_thread_affinity, i } global_rank = MPIR_Process.rank; - local_rank = - (MPIR_Process.comm_world->node_comm) ? MPIR_Process.comm_world->node_comm->rank : 0; + local_rank = MPIR_Process.local_rank; if (have_cliques) { - /* If local cliques > 1, using local_size from node_comm will have conflict on thread idx. + /* If local cliques > 1, using local_size will have conflict on thread idx. * In multiple nodes case, this would cost extra memory for allocating thread affinity on every * node, but it is okay to solve progress thread oversubscription. */ local_size = MPIR_Process.comm_world->local_size; } else { - local_size = - (MPIR_Process.comm_world->node_comm) ? MPIR_Process.comm_world-> - node_comm->local_size : 1; + local_size = MPIR_Process.local_size; } async_threads_per_node = local_size; diff --git a/src/util/mpir_nodemap.c b/src/util/mpir_nodemap.c index cf8eaca7efa..e8f6d1ee6ba 100644 --- a/src/util/mpir_nodemap.c +++ b/src/util/mpir_nodemap.c @@ -436,14 +436,14 @@ int MPIR_nodeid_init(void) utarray_resize(MPIR_Process.node_hostnames, MPIR_Process.num_nodes, MPL_MEM_OTHER); char *allhostnames = (char *) utarray_eltptr(MPIR_Process.node_hostnames, 0); - if (MPIR_Process.local_rank == 0) { - MPIR_Comm *node_roots_comm = MPIR_Process.comm_world->node_roots_comm; - if (node_roots_comm == NULL) { - /* num_external == comm->remote_size */ - node_roots_comm = MPIR_Process.comm_world; - } + MPIR_Comm *world_comm = MPIR_Process.comm_world; + int local_rank = world_comm->subgroups[MPIR_SUBGROUP_NODE].rank; + int local_size = world_comm->subgroups[MPIR_SUBGROUP_NODE].size; + + if (local_rank == 0) { + int inter_rank = world_comm->subgroups[MPIR_SUBGROUP_NODE_CROSS].rank; - char *my_hostname = allhostnames + MAX_HOSTNAME_LEN * node_roots_comm->rank; + char *my_hostname = allhostnames + MAX_HOSTNAME_LEN * inter_rank; int ret = 
gethostname(my_hostname, MAX_HOSTNAME_LEN); char strerrbuf[MPIR_STRERROR_BUF_SIZE] ATTRIBUTE((unused)); MPIR_ERR_CHKANDJUMP2(ret == -1, mpi_errno, MPI_ERR_OTHER, @@ -453,14 +453,13 @@ int MPIR_nodeid_init(void) mpi_errno = MPIR_Allgather_impl(MPI_IN_PLACE, MAX_HOSTNAME_LEN, MPI_CHAR, allhostnames, MAX_HOSTNAME_LEN, MPI_CHAR, - node_roots_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + world_comm, MPIR_SUBGROUP_NODE_CROSS, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } - MPIR_Comm *node_comm = MPIR_Process.comm_world->node_comm; - if (node_comm) { + if (local_size > 1) { mpi_errno = MPIR_Bcast_impl(allhostnames, MAX_HOSTNAME_LEN * MPIR_Process.num_nodes, - MPI_CHAR, 0, node_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE); + MPI_CHAR, 0, world_comm, MPIR_SUBGROUP_NODE, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); } From 1a16de908c4f34ca7b7829a7fb0bb8329fd9469d Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Wed, 21 Aug 2024 23:51:02 -0500 Subject: [PATCH 20/27] coll/csel: omit pruning on communicator size Now that we may run collectives on subgroups, we can't pre-prune the csel trees based on communicator size or topology since that may change for subgroups. I don't think the performance from the tree pruning is significant -- it only saves a couple levels of tree descent. But if we later decide the efficiency from pruning is important, we can easily prune the trees at subgroup level and save the pruned trees to the MPIR_Group structure.
--- src/mpi/coll/src/csel.c | 61 ++--------------------------------------- 1 file changed, 3 insertions(+), 58 deletions(-) diff --git a/src/mpi/coll/src/csel.c b/src/mpi/coll/src/csel.c index 4422c78dea0..284bdada995 100644 --- a/src/mpi/coll/src/csel.c +++ b/src/mpi/coll/src/csel.c @@ -615,63 +615,6 @@ static csel_node_s *prune_tree(csel_node_s * root, MPIR_Comm * comm_ptr) node = node->failure; break; - case CSEL_NODE_TYPE__OPERATOR__COMM_SIZE_LE: - if (comm_ptr->local_size <= node->u.comm_size_le.val) - node = node->success; - else - node = node->failure; - break; - - case CSEL_NODE_TYPE__OPERATOR__COMM_SIZE_LT: - if (comm_ptr->local_size < node->u.comm_size_lt.val) - node = node->success; - else - node = node->failure; - break; - - case CSEL_NODE_TYPE__OPERATOR__COMM_SIZE_NODE_COMM_SIZE: - /* comm_size equal to node_comm_size just mean the size inter-node is 1 */ - if (comm_ptr->num_external == 1) - node = node->success; - else - node = node->failure; - break; - - case CSEL_NODE_TYPE__OPERATOR__COMM_SIZE_POW2: - if (comm_ptr->local_size & (comm_ptr->local_size - 1)) - node = node->failure; - else - node = node->success; - break; - - case CSEL_NODE_TYPE__OPERATOR__COMM_HIERARCHY: - if (comm_ptr->hierarchy_kind == node->u.comm_hierarchy.val) - node = node->success; - else - node = node->failure; - break; - - case CSEL_NODE_TYPE__OPERATOR__IS_NODE_CONSECUTIVE: - if (MPII_Comm_is_node_consecutive(comm_ptr) == node->u.is_node_consecutive.val) - node = node->success; - else - node = node->failure; - break; - - case CSEL_NODE_TYPE__OPERATOR__COMM_AVG_PPN_LE: - if (comm_ptr->local_size <= node->u.comm_avg_ppn_le.val * comm_ptr->num_external) - node = node->success; - else - node = node->failure; - break; - - case CSEL_NODE_TYPE__OPERATOR__COMM_AVG_PPN_LT: - if (comm_ptr->local_size < node->u.comm_avg_ppn_le.val * comm_ptr->num_external) - node = node->success; - else - node = node->failure; - break; - case CSEL_NODE_TYPE__OPERATOR__ANY: node = node->success; break; @@ 
-1188,6 +1131,7 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) csel_s *csel = (csel_s *) csel_; csel_node_s *node = NULL; MPIR_Comm *comm_ptr = coll_info.comm_ptr; + int coll_group = coll_info.coll_group; MPIR_Assert(csel_); @@ -1349,7 +1293,8 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) break; case CSEL_NODE_TYPE__OPERATOR__COMM_HIERARCHY: - if (coll_info.comm_ptr->hierarchy_kind == node->u.comm_hierarchy.val) + if (node->u.comm_hierarchy.val == MPIR_COMM_HIERARCHY_KIND__PARENT && + MPIR_Comm_is_parent_comm(comm_ptr, coll_group)) node = node->success; else node = node->failure; From 718d868eec68af0bb2f86136087425d1d22dcfb9 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 22 Aug 2024 17:33:36 -0500 Subject: [PATCH 21/27] coll: refactor caching tree in the comm struct Use a single "cached_tree" rather than 3 different fields for each tree type. --- src/include/mpir_comm.h | 15 +-- .../coll/algorithms/recexchalgo/recexchalgo.c | 34 +---- src/mpi/coll/algorithms/treealgo/treealgo.c | 123 ++++++++++++------ .../coll/algorithms/treealgo/treealgo_types.h | 26 ++++ src/mpi/coll/include/coll_types.h | 10 -- 5 files changed, 116 insertions(+), 92 deletions(-) diff --git a/src/include/mpir_comm.h b/src/include/mpir_comm.h index 40d533237b0..a1b84b706b8 100644 --- a/src/include/mpir_comm.h +++ b/src/include/mpir_comm.h @@ -280,18 +280,9 @@ struct MPIR_Comm { int **step2_nbrs[MAX_RADIX - 1]; int nbrs_defined[MAX_RADIX - 1]; void **recexch_allreduce_nbr_buffer; - int topo_aware_tree_root; - int topo_aware_tree_k; - MPIR_Treealgo_tree_t *topo_aware_tree; - int topo_aware_k_tree_root; - int topo_aware_k_tree_k; - MPIR_Treealgo_tree_t *topo_aware_k_tree; - int topo_wave_tree_root; - int topo_wave_tree_overhead; - int topo_wave_tree_lat_diff_groups; - int topo_wave_tree_lat_diff_switches; - int topo_wave_tree_lat_same_switches; - MPIR_Treealgo_tree_t *topo_wave_tree; + + MPIR_Treealgo_tree_t *cached_tree; + MPIR_Treealgo_param_t 
cached_tree_param; } coll; void *csel_comm; /* collective selector handle */ diff --git a/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c b/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c index 61042c75671..759be0cd6e2 100644 --- a/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c +++ b/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c @@ -26,19 +26,7 @@ int MPII_Recexchalgo_comm_init(MPIR_Comm * comm) } comm->coll.recexch_allreduce_nbr_buffer = NULL; - comm->coll.topo_aware_tree_root = -1; - comm->coll.topo_aware_tree_k = 0; - comm->coll.topo_aware_tree = NULL; - comm->coll.topo_aware_k_tree_root = -1; - comm->coll.topo_aware_k_tree_k = 0; - comm->coll.topo_aware_k_tree = NULL; - comm->coll.topo_wave_tree_root = -1; - comm->coll.topo_wave_tree = NULL; - comm->coll.topo_wave_tree_overhead = 0; - comm->coll.topo_wave_tree_lat_diff_groups = 0; - comm->coll.topo_wave_tree_lat_diff_switches = 0; - comm->coll.topo_wave_tree_lat_same_switches = 0; - + comm->coll.cached_tree = NULL; return mpi_errno; } @@ -66,22 +54,10 @@ int MPII_Recexchalgo_comm_cleanup(MPIR_Comm * comm) MPL_free(comm->coll.recexch_allreduce_nbr_buffer); } - if (comm->coll.topo_aware_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_aware_tree); - MPL_free(comm->coll.topo_aware_tree); - comm->coll.topo_aware_tree = NULL; - } - - if (comm->coll.topo_aware_k_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_aware_k_tree); - MPL_free(comm->coll.topo_aware_k_tree); - comm->coll.topo_aware_k_tree = NULL; - } - - if (comm->coll.topo_wave_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_wave_tree); - MPL_free(comm->coll.topo_wave_tree); - comm->coll.topo_wave_tree = NULL; + if (comm->coll.cached_tree) { + MPIR_Treealgo_tree_free(comm->coll.cached_tree); + MPL_free(comm->coll.cached_tree); + comm->coll.cached_tree = NULL; } return mpi_errno; diff --git a/src/mpi/coll/algorithms/treealgo/treealgo.c b/src/mpi/coll/algorithms/treealgo/treealgo.c index 25a291c5f1c..da5ac83eef9 100644 --- 
a/src/mpi/coll/algorithms/treealgo/treealgo.c +++ b/src/mpi/coll/algorithms/treealgo/treealgo.c @@ -33,6 +33,56 @@ int MPII_Treealgo_comm_cleanup(MPIR_Comm * comm) return mpi_errno; } +static bool match_param_topo_aware(MPIR_Treealgo_param_t * param, int root, int k) +{ + return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE && + param->root == root && param->u.topo_aware.k == k); +} + +static void set_param_topo_aware(MPIR_Treealgo_param_t * param, int root, int k) +{ + param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE; + param->root = root; + param->u.topo_aware.k = k; +} + +static bool match_param_topo_aware_k(MPIR_Treealgo_param_t * param, int root, int k) +{ + return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K && + param->root == root && param->u.topo_aware.k == k); +} + +static void set_param_topo_aware_k(MPIR_Treealgo_param_t * param, int root, int k) +{ + param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE_K; + param->root = root; + param->u.topo_aware.k = k; +} + +static inline bool match_param_topo_wave(MPIR_Treealgo_param_t * param, + int root, int overhead, int lat_diff_groups, + int lat_diff_switches, int lat_same_switches) +{ + return (param->type == MPIR_TREE_TYPE_TOPOLOGY_WAVE && + param->root == root && + param->u.topo_wave.overhead == overhead && + param->u.topo_wave.lat_diff_groups == lat_diff_groups && + param->u.topo_wave.lat_diff_switches == lat_diff_switches && + param->u.topo_wave.lat_same_switches == lat_same_switches); +} + +static inline void set_param_topo_wave(MPIR_Treealgo_param_t * param, + int root, int overhead, int lat_diff_groups, + int lat_diff_switches, int lat_same_switches) +{ + param->type = MPIR_TREE_TYPE_TOPOLOGY_WAVE; + param->root = root; + param->u.topo_wave.overhead = overhead; + param->u.topo_wave.lat_diff_groups = lat_diff_groups; + param->u.topo_wave.lat_diff_switches = lat_diff_switches; + param->u.topo_wave.lat_same_switches = lat_same_switches; +} + int MPIR_Treealgo_tree_create(int rank, int nranks, int tree_type, int k, 
int root, MPIR_Treealgo_tree_t * ct) @@ -84,56 +134,52 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, switch (tree_type) { case MPIR_TREE_TYPE_TOPOLOGY_AWARE: - if (!comm->coll.topo_aware_tree || root != comm->coll.topo_aware_tree_root - || k != comm->coll.topo_aware_tree_k) { - if (comm->coll.topo_aware_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_aware_tree); + if (!comm->coll.cached_tree || + !match_param_topo_aware(&comm->coll.cached_tree_param, root, k)) { + if (comm->coll.cached_tree) { + MPIR_Treealgo_tree_free(comm->coll.cached_tree); } else { - comm->coll.topo_aware_tree = + comm->coll.cached_tree = (MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t), MPL_MEM_BUFFER); } mpi_errno = MPII_Treeutil_tree_topology_aware_init(comm, k, root, enable_reorder, - comm->coll.topo_aware_tree); + comm->coll.cached_tree); MPIR_ERR_CHECK(mpi_errno); - *ct = *comm->coll.topo_aware_tree; - comm->coll.topo_aware_tree_root = root; - comm->coll.topo_aware_tree_k = k; + *ct = *comm->coll.cached_tree; + set_param_topo_aware(&comm->coll.cached_tree_param, root, k); } - *ct = *comm->coll.topo_aware_tree; + *ct = *comm->coll.cached_tree; utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL); for (int i = 0; i < ct->num_children; i++) { utarray_push_back(ct->children, - &ut_int_array(comm->coll.topo_aware_tree->children)[i], - MPL_MEM_COLL); + &ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL); } break; case MPIR_TREE_TYPE_TOPOLOGY_AWARE_K: - if (!comm->coll.topo_aware_k_tree || root != comm->coll.topo_aware_k_tree_root - || k != comm->coll.topo_aware_k_tree_k) { - if (comm->coll.topo_aware_k_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_aware_k_tree); + if (!comm->coll.cached_tree || + !match_param_topo_aware_k(&comm->coll.cached_tree_param, root, k)) { + if (comm->coll.cached_tree) { + MPIR_Treealgo_tree_free(comm->coll.cached_tree); } else { - comm->coll.topo_aware_k_tree = + comm->coll.cached_tree = 
(MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t), MPL_MEM_BUFFER); } mpi_errno = MPII_Treeutil_tree_topology_aware_k_init(comm, k, root, enable_reorder, - comm->coll.topo_aware_k_tree); + comm->coll.cached_tree); MPIR_ERR_CHECK(mpi_errno); - *ct = *comm->coll.topo_aware_k_tree; - comm->coll.topo_aware_k_tree_root = root; - comm->coll.topo_aware_k_tree_k = k; + *ct = *comm->coll.cached_tree; + set_param_topo_aware_k(&comm->coll.cached_tree_param, root, k); } - *ct = *comm->coll.topo_aware_k_tree; + *ct = *comm->coll.cached_tree; utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL); for (int i = 0; i < ct->num_children; i++) { utarray_push_back(ct->children, - &ut_int_array(comm->coll.topo_aware_k_tree->children)[i], - MPL_MEM_COLL); + &ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL); } break; @@ -164,34 +210,29 @@ int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root, MPIR_FUNC_ENTER; - if (!comm->coll.topo_wave_tree || root != comm->coll.topo_wave_tree_root - || overhead != comm->coll.topo_wave_tree_overhead - || lat_diff_groups != comm->coll.topo_wave_tree_lat_diff_groups - || lat_diff_switches != comm->coll.topo_wave_tree_lat_diff_switches - || lat_same_switches != comm->coll.topo_wave_tree_lat_same_switches) { - if (comm->coll.topo_wave_tree) { - MPIR_Treealgo_tree_free(comm->coll.topo_wave_tree); + if (!comm->coll.cached_tree || + !match_param_topo_wave(&comm->coll.cached_tree_param, root, overhead, + lat_diff_groups, lat_diff_switches, lat_same_switches)) { + if (comm->coll.cached_tree) { + MPIR_Treealgo_tree_free(comm->coll.cached_tree); } else { - comm->coll.topo_wave_tree = + comm->coll.cached_tree = (MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t), MPL_MEM_BUFFER); } mpi_errno = MPII_Treeutil_tree_topology_wave_init(comm, k, root, enable_reorder, overhead, lat_diff_groups, lat_diff_switches, lat_same_switches, - comm->coll.topo_wave_tree); + comm->coll.cached_tree); 
MPIR_ERR_CHECK(mpi_errno); - *ct = *comm->coll.topo_wave_tree; - comm->coll.topo_wave_tree_root = root; - comm->coll.topo_wave_tree_overhead = overhead; - comm->coll.topo_wave_tree_lat_diff_groups = lat_diff_groups; - comm->coll.topo_wave_tree_lat_diff_switches = lat_diff_switches; - comm->coll.topo_wave_tree_lat_same_switches = lat_same_switches; + *ct = *comm->coll.cached_tree; + set_param_topo_wave(&comm->coll.cached_tree_param, root, overhead, + lat_diff_groups, lat_diff_switches, lat_same_switches); } - *ct = *comm->coll.topo_wave_tree; + *ct = *comm->coll.cached_tree; utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL); for (int i = 0; i < ct->num_children; i++) { utarray_push_back(ct->children, - &ut_int_array(comm->coll.topo_wave_tree->children)[i], MPL_MEM_COLL); + &ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL); } MPIR_FUNC_EXIT; diff --git a/src/mpi/coll/algorithms/treealgo/treealgo_types.h b/src/mpi/coll/algorithms/treealgo/treealgo_types.h index 5db2c5ae931..bb15947046e 100644 --- a/src/mpi/coll/algorithms/treealgo/treealgo_types.h +++ b/src/mpi/coll/algorithms/treealgo/treealgo_types.h @@ -8,6 +8,16 @@ #include +/* enumerator for different tree types */ +typedef enum MPIR_Tree_type_t { + MPIR_TREE_TYPE_KARY = 0, + MPIR_TREE_TYPE_KNOMIAL_1, + MPIR_TREE_TYPE_KNOMIAL_2, + MPIR_TREE_TYPE_TOPOLOGY_AWARE, + MPIR_TREE_TYPE_TOPOLOGY_AWARE_K, + MPIR_TREE_TYPE_TOPOLOGY_WAVE, +} MPIR_Tree_type_t; + typedef struct { int rank; int nranks; @@ -16,4 +26,20 @@ typedef struct { UT_array *children; } MPIR_Treealgo_tree_t; +typedef struct { + MPIR_Tree_type_t type; + int root; + union { + struct { + int k; + } topo_aware; + struct { + int overhead; + int lat_diff_groups; + int lat_diff_switches; + int lat_same_switches; + } topo_wave; + } u; +} MPIR_Treealgo_param_t; + #endif /* TREEALGO_TYPES_H_INCLUDED */ diff --git a/src/mpi/coll/include/coll_types.h b/src/mpi/coll/include/coll_types.h index a32ce6c551d..22fbad4716b 100644 --- 
a/src/mpi/coll/include/coll_types.h +++ b/src/mpi/coll/include/coll_types.h @@ -13,16 +13,6 @@ #define MPIR_COLL_FLAG_REDUCE_L 1 #define MPIR_COLL_FLAG_REDUCE_R 0 -/* enumerator for different tree types */ -typedef enum MPIR_Tree_type_t { - MPIR_TREE_TYPE_KARY = 0, - MPIR_TREE_TYPE_KNOMIAL_1, - MPIR_TREE_TYPE_KNOMIAL_2, - MPIR_TREE_TYPE_TOPOLOGY_AWARE, - MPIR_TREE_TYPE_TOPOLOGY_AWARE_K, - MPIR_TREE_TYPE_TOPOLOGY_WAVE, -} MPIR_Tree_type_t; - /* enumerator for different recexch types */ enum { MPIR_IALLREDUCE_RECEXCH_TYPE_SINGLE_BUFFER = 0, From 1d79bebc10902943219254383c85e358c7ce86dd Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 22 Aug 2024 17:47:35 -0500 Subject: [PATCH 22/27] coll: add coll_group to treealgo routines The topology-aware tree utilities need to check coll_group for correct world ranks. --- src/mpi/coll/algorithms/treealgo/treealgo.c | 52 ++++++++++-------- src/mpi/coll/algorithms/treealgo/treealgo.h | 5 +- .../coll/algorithms/treealgo/treealgo_types.h | 1 + src/mpi/coll/algorithms/treealgo/treeutil.c | 55 ++++++++++++------- src/mpi/coll/algorithms/treealgo/treeutil.h | 15 ++--- src/mpi/coll/allreduce/allreduce_intra_tree.c | 4 +- .../coll/bcast/bcast_intra_pipelined_tree.c | 5 +- src/mpi/coll/bcast/bcast_intra_tree.c | 5 +- src/mpi/coll/ireduce/ireduce_tsp_tree.c | 4 +- 9 files changed, 86 insertions(+), 60 deletions(-) diff --git a/src/mpi/coll/algorithms/treealgo/treealgo.c b/src/mpi/coll/algorithms/treealgo/treealgo.c index da5ac83eef9..d806e3b8f51 100644 --- a/src/mpi/coll/algorithms/treealgo/treealgo.c +++ b/src/mpi/coll/algorithms/treealgo/treealgo.c @@ -33,37 +33,42 @@ int MPII_Treealgo_comm_cleanup(MPIR_Comm * comm) return mpi_errno; } -static bool match_param_topo_aware(MPIR_Treealgo_param_t * param, int root, int k) +static bool match_param_topo_aware(MPIR_Treealgo_param_t * param, int coll_group, int root, int k) { return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE && + param->coll_group == coll_group && coll_group < 
MPIR_SUBGROUP_NUM_RESERVED && param->root == root && param->u.topo_aware.k == k); } -static void set_param_topo_aware(MPIR_Treealgo_param_t * param, int root, int k) +static void set_param_topo_aware(MPIR_Treealgo_param_t * param, int coll_group, int root, int k) { param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE; + param->coll_group = coll_group; param->root = root; param->u.topo_aware.k = k; } -static bool match_param_topo_aware_k(MPIR_Treealgo_param_t * param, int root, int k) +static bool match_param_topo_aware_k(MPIR_Treealgo_param_t * param, int coll_group, int root, int k) { return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K && + param->coll_group == coll_group && coll_group < MPIR_SUBGROUP_NUM_RESERVED && param->root == root && param->u.topo_aware.k == k); } -static void set_param_topo_aware_k(MPIR_Treealgo_param_t * param, int root, int k) +static void set_param_topo_aware_k(MPIR_Treealgo_param_t * param, int coll_group, int root, int k) { param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE_K; + param->coll_group = coll_group; param->root = root; param->u.topo_aware.k = k; } -static inline bool match_param_topo_wave(MPIR_Treealgo_param_t * param, +static inline bool match_param_topo_wave(MPIR_Treealgo_param_t * param, int coll_group, int root, int overhead, int lat_diff_groups, int lat_diff_switches, int lat_same_switches) { return (param->type == MPIR_TREE_TYPE_TOPOLOGY_WAVE && + param->coll_group == coll_group && coll_group < MPIR_SUBGROUP_NUM_RESERVED && param->root == root && param->u.topo_wave.overhead == overhead && param->u.topo_wave.lat_diff_groups == lat_diff_groups && @@ -71,11 +76,12 @@ static inline bool match_param_topo_wave(MPIR_Treealgo_param_t * param, param->u.topo_wave.lat_same_switches == lat_same_switches); } -static inline void set_param_topo_wave(MPIR_Treealgo_param_t * param, +static inline void set_param_topo_wave(MPIR_Treealgo_param_t * param, int coll_group, int root, int overhead, int lat_diff_groups, int lat_diff_switches, int 
lat_same_switches) { param->type = MPIR_TREE_TYPE_TOPOLOGY_WAVE; + param->coll_group = coll_group; param->root = root; param->u.topo_wave.overhead = overhead; param->u.topo_wave.lat_diff_groups = lat_diff_groups; @@ -125,7 +131,8 @@ int MPIR_Treealgo_tree_create(int rank, int nranks, int tree_type, int k, int ro } -int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, int root, +int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int coll_group, int tree_type, + int k, int root, bool enable_reorder, MPIR_Treealgo_tree_t * ct) { int mpi_errno = MPI_SUCCESS; @@ -135,7 +142,7 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, switch (tree_type) { case MPIR_TREE_TYPE_TOPOLOGY_AWARE: if (!comm->coll.cached_tree || - !match_param_topo_aware(&comm->coll.cached_tree_param, root, k)) { + !match_param_topo_aware(&comm->coll.cached_tree_param, coll_group, root, k)) { if (comm->coll.cached_tree) { MPIR_Treealgo_tree_free(comm->coll.cached_tree); } else { @@ -144,11 +151,11 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, MPL_MEM_BUFFER); } mpi_errno = - MPII_Treeutil_tree_topology_aware_init(comm, k, root, enable_reorder, - comm->coll.cached_tree); + MPII_Treeutil_tree_topology_aware_init(comm, coll_group, k, root, + enable_reorder, comm->coll.cached_tree); MPIR_ERR_CHECK(mpi_errno); *ct = *comm->coll.cached_tree; - set_param_topo_aware(&comm->coll.cached_tree_param, root, k); + set_param_topo_aware(&comm->coll.cached_tree_param, coll_group, root, k); } *ct = *comm->coll.cached_tree; utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL); @@ -160,7 +167,7 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, case MPIR_TREE_TYPE_TOPOLOGY_AWARE_K: if (!comm->coll.cached_tree || - !match_param_topo_aware_k(&comm->coll.cached_tree_param, root, k)) { + !match_param_topo_aware_k(&comm->coll.cached_tree_param, coll_group, root, k)) { if 
(comm->coll.cached_tree) { MPIR_Treealgo_tree_free(comm->coll.cached_tree); } else { @@ -169,11 +176,12 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, MPL_MEM_BUFFER); } mpi_errno = - MPII_Treeutil_tree_topology_aware_k_init(comm, k, root, enable_reorder, + MPII_Treeutil_tree_topology_aware_k_init(comm, coll_group, k, root, + enable_reorder, comm->coll.cached_tree); MPIR_ERR_CHECK(mpi_errno); *ct = *comm->coll.cached_tree; - set_param_topo_aware_k(&comm->coll.cached_tree_param, root, k); + set_param_topo_aware_k(&comm->coll.cached_tree_param, coll_group, root, k); } *ct = *comm->coll.cached_tree; utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL); @@ -201,7 +209,7 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, } -int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root, +int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int coll_group, int k, int root, bool enable_reorder, int overhead, int lat_diff_groups, int lat_diff_switches, int lat_same_switches, MPIR_Treealgo_tree_t * ct) @@ -211,21 +219,21 @@ int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root, MPIR_FUNC_ENTER; if (!comm->coll.cached_tree || - !match_param_topo_wave(&comm->coll.cached_tree_param, root, overhead, - lat_diff_groups, lat_diff_switches, lat_same_switches)) { + !match_param_topo_wave(&comm->coll.cached_tree_param, coll_group, root, + overhead, lat_diff_groups, lat_diff_switches, lat_same_switches)) { if (comm->coll.cached_tree) { MPIR_Treealgo_tree_free(comm->coll.cached_tree); } else { comm->coll.cached_tree = (MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t), MPL_MEM_BUFFER); } - mpi_errno = MPII_Treeutil_tree_topology_wave_init(comm, k, root, enable_reorder, overhead, - lat_diff_groups, lat_diff_switches, - lat_same_switches, - comm->coll.cached_tree); + mpi_errno = + MPII_Treeutil_tree_topology_wave_init(comm, coll_group, k, root, enable_reorder, + 
overhead, lat_diff_groups, lat_diff_switches, + lat_same_switches, comm->coll.cached_tree); MPIR_ERR_CHECK(mpi_errno); *ct = *comm->coll.cached_tree; - set_param_topo_wave(&comm->coll.cached_tree_param, root, overhead, + set_param_topo_wave(&comm->coll.cached_tree_param, coll_group, root, overhead, lat_diff_groups, lat_diff_switches, lat_same_switches); } *ct = *comm->coll.cached_tree; diff --git a/src/mpi/coll/algorithms/treealgo/treealgo.h b/src/mpi/coll/algorithms/treealgo/treealgo.h index 60bac96806d..50e473f2b94 100644 --- a/src/mpi/coll/algorithms/treealgo/treealgo.h +++ b/src/mpi/coll/algorithms/treealgo/treealgo.h @@ -13,9 +13,10 @@ int MPII_Treealgo_comm_init(MPIR_Comm * comm); int MPII_Treealgo_comm_cleanup(MPIR_Comm * comm); int MPIR_Treealgo_tree_create(int rank, int nranks, int tree_type, int k, int root, MPIR_Treealgo_tree_t * ct); -int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k, int root, +int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int coll_group, int tree_type, + int k, int root, bool enable_reorder, MPIR_Treealgo_tree_t * ct); -int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root, +int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int coll_group, int k, int root, bool enable_reorder, int overhead, int lat_diff_groups, int lat_diff_switches, int lat_same_switches, MPIR_Treealgo_tree_t * ct); diff --git a/src/mpi/coll/algorithms/treealgo/treealgo_types.h b/src/mpi/coll/algorithms/treealgo/treealgo_types.h index bb15947046e..646c63f866b 100644 --- a/src/mpi/coll/algorithms/treealgo/treealgo_types.h +++ b/src/mpi/coll/algorithms/treealgo/treealgo_types.h @@ -28,6 +28,7 @@ typedef struct { typedef struct { MPIR_Tree_type_t type; + int coll_group; int root; union { struct { diff --git a/src/mpi/coll/algorithms/treealgo/treeutil.c b/src/mpi/coll/algorithms/treealgo/treeutil.c index 7b522332b49..1c55d713e41 100644 --- a/src/mpi/coll/algorithms/treealgo/treeutil.c +++ 
b/src/mpi/coll/algorithms/treealgo/treeutil.c @@ -472,8 +472,8 @@ static void MPII_Treeutil_hierarchy_reorder(UT_array * hierarchy, int rank) } /* tree init function is for building hierarchy of MPIR_Process::coords_dims */ -static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int rank, int nranks, int root, - bool enable_reorder, UT_array * hierarchy) +static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int coll_group, int rank, int nranks, + int root, bool enable_reorder, UT_array * hierarchy) { int mpi_errno = MPI_SUCCESS; @@ -504,8 +504,12 @@ static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int rank, int nran MPIR_Assert(upper_level != NULL); /* Get wrank from the communicator as the coords are stored with wrank */ + int comm_rank = r; + if (coll_group > 0) { + comm_rank = comm->subgroups[coll_group].proc_table[r]; + } uint64_t temp = 0; - MPID_Comm_get_lpid(comm, r, &temp, FALSE); + MPID_Comm_get_lpid(comm, comm_rank, &temp, FALSE); int wrank = (int) temp; if (wrank < 0) goto fn_fail; @@ -600,12 +604,13 @@ static int MPII_Treeutil_hierarchy_populate(MPIR_Comm * comm, int rank, int nran * build the hierarchy of the topology-aware tree. * For the mentioned cases see tags 'goto fn_fallback;'. 
*/ -int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - MPIR_Treealgo_tree_t * ct) +int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, MPIR_Treealgo_tree_t * ct) { int mpi_errno = MPI_SUCCESS; - int rank = comm->rank; - int nranks = comm->local_size; + + int rank, nranks; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); UT_array hierarchy[MAX_HIERARCHY_DEPTH]; int dim = MPIR_Process.coords_dims - 1; @@ -613,7 +618,8 @@ int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int k, int root, bo tree_ut_hierarchy_init(&hierarchy[dim]); if (k <= 0 || - 0 != MPII_Treeutil_hierarchy_populate(comm, rank, nranks, root, enable_reorder, hierarchy)) + 0 != MPII_Treeutil_hierarchy_populate(comm, coll_group, rank, nranks, root, enable_reorder, + hierarchy)) goto fn_fallback; ct->rank = rank; @@ -695,16 +701,18 @@ int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int k, int root, bo } /* Implementation of 'Topology aware' algorithm with the branching factor k */ -int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - MPIR_Treealgo_tree_t * ct) +int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, MPIR_Treealgo_tree_t * ct) { int mpi_errno = MPI_SUCCESS; - int rank = comm->rank; - int nranks = comm->local_size; + + int rank, nranks; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); /* fall back to MPII_Treeutil_tree_topology_aware_init if k is less or equal to 2 */ if (k <= 2) { - return MPII_Treeutil_tree_topology_aware_init(comm, k, root, enable_reorder, ct); + return MPII_Treeutil_tree_topology_aware_init(comm, coll_group, k, root, enable_reorder, + ct); } int *num_childrens = NULL; @@ -719,7 +727,9 @@ int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root, for (dim = MPIR_Process.coords_dims - 1; dim 
>= 0; --dim) tree_ut_hierarchy_init(&hierarchy[dim]); - if (0 != MPII_Treeutil_hierarchy_populate(comm, rank, nranks, root, enable_reorder, hierarchy)) + if (0 != + MPII_Treeutil_hierarchy_populate(comm, coll_group, rank, nranks, root, enable_reorder, + hierarchy)) goto fn_fallback; ct->rank = rank; @@ -758,7 +768,7 @@ int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root, /* Do an allgather to know the current num_children on each rank */ MPIR_Errflag_t errflag = MPIR_ERR_NONE; MPIR_Allgather_impl(&(ct->num_children), 1, MPI_INT, num_childrens, 1, MPI_INT, - comm, MPIR_SUBGROUP_NONE, errflag); + comm, coll_group, errflag); if (mpi_errno) { goto fn_fail; } @@ -1111,13 +1121,12 @@ static int init_root_switch(const UT_array * hierarchy, heap_vector * minHeaps, } /* 'Topology Wave' implementation */ -int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - int overhead, int lat_diff_groups, int lat_diff_switches, - int lat_same_switches, MPIR_Treealgo_tree_t * ct) +int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, int overhead, int lat_diff_groups, + int lat_diff_switches, int lat_same_switches, + MPIR_Treealgo_tree_t * ct) { int mpi_errno = MPI_SUCCESS; - int rank = comm->rank; - int nranks = comm->local_size; int root_gr_sorted_idx = 0; int root_sw_sorted_idx = 0; int group_offset = 0; @@ -1126,6 +1135,9 @@ int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, boo UT_array hierarchy[MAX_HIERARCHY_DEPTH]; UT_array *unv_set = NULL; + int rank, nranks; + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); + heap_vector minHeaps; heap_vector_init(&minHeaps); @@ -1135,7 +1147,8 @@ int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, boo tree_ut_hierarchy_init(&hierarchy[dim]); if (overhead <= 0 || lat_diff_groups <= 0 || lat_diff_switches <= 0 || lat_same_switches <= 0 || - 0 != 
MPII_Treeutil_hierarchy_populate(comm, rank, nranks, root, enable_reorder, hierarchy)) + 0 != MPII_Treeutil_hierarchy_populate(comm, coll_group, rank, nranks, root, enable_reorder, + hierarchy)) goto fn_fallback; UT_icd intpair_icd = { sizeof(pair), NULL, NULL, NULL }; diff --git a/src/mpi/coll/algorithms/treealgo/treeutil.h b/src/mpi/coll/algorithms/treealgo/treeutil.h index c628f162ca6..51864938f4d 100644 --- a/src/mpi/coll/algorithms/treealgo/treeutil.h +++ b/src/mpi/coll/algorithms/treealgo/treeutil.h @@ -123,15 +123,16 @@ int MPII_Treeutil_tree_knomial_2_init(int rank, int nranks, int k, int root, MPIR_Treealgo_tree_t * ct); /* Generate topology_aware tree information */ -int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - MPIR_Treealgo_tree_t * ct); +int MPII_Treeutil_tree_topology_aware_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, MPIR_Treealgo_tree_t * ct); -int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - MPIR_Treealgo_tree_t * ct); +int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, MPIR_Treealgo_tree_t * ct); /* Generate topology_wave tree information */ -int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, bool enable_reorder, - int overhead, int lat_diff_groups, int lat_diff_switches, - int lat_same_switches, MPIR_Treealgo_tree_t * ct); +int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int coll_group, int k, int root, + bool enable_reorder, int overhead, int lat_diff_groups, + int lat_diff_switches, int lat_same_switches, + MPIR_Treealgo_tree_t * ct); #endif /* TREEUTIL_H_INCLUDED */ diff --git a/src/mpi/coll/allreduce/allreduce_intra_tree.c b/src/mpi/coll/allreduce/allreduce_intra_tree.c index 226e40d6bfc..03bcad74cb3 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_tree.c +++ 
b/src/mpi/coll/allreduce/allreduce_intra_tree.c @@ -66,7 +66,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, /* initialize the tree */ if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE || tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K) { mpi_errno = - MPIR_Treealgo_tree_create_topo_aware(comm_ptr, tree_type, k, root, + MPIR_Treealgo_tree_create_topo_aware(comm_ptr, coll_group, tree_type, k, root, MPIR_CVAR_ALLREDUCE_TOPO_REORDER_ENABLE, &my_tree); } else if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_WAVE) { MPIR_Csel_coll_sig_s coll_sig = { @@ -96,7 +96,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, } mpi_errno = - MPIR_Treealgo_tree_create_topo_wave(comm_ptr, k, root, + MPIR_Treealgo_tree_create_topo_wave(comm_ptr, coll_group, k, root, MPIR_CVAR_ALLREDUCE_TOPO_REORDER_ENABLE, overhead, lat_diff_groups, lat_diff_switches, lat_same_switches, &my_tree); diff --git a/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c b/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c index 98498d7ca77..f551de11ded 100644 --- a/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c +++ b/src/mpi/coll/bcast/bcast_intra_pipelined_tree.c @@ -74,11 +74,12 @@ int MPIR_Bcast_intra_pipelined_tree(void *buffer, if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE || tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K) { mpi_errno = - MPIR_Treealgo_tree_create_topo_aware(comm_ptr, tree_type, branching_factor, root, + MPIR_Treealgo_tree_create_topo_aware(comm_ptr, coll_group, tree_type, + branching_factor, root, MPIR_CVAR_BCAST_TOPO_REORDER_ENABLE, &my_tree); } else if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_WAVE) { mpi_errno = - MPIR_Treealgo_tree_create_topo_wave(comm_ptr, branching_factor, root, + MPIR_Treealgo_tree_create_topo_wave(comm_ptr, coll_group, branching_factor, root, MPIR_CVAR_BCAST_TOPO_REORDER_ENABLE, MPIR_CVAR_BCAST_TOPO_OVERHEAD, MPIR_CVAR_BCAST_TOPO_DIFF_GROUPS, diff --git a/src/mpi/coll/bcast/bcast_intra_tree.c b/src/mpi/coll/bcast/bcast_intra_tree.c index eaf48370014..5cc2261b33a 100644 --- 
a/src/mpi/coll/bcast/bcast_intra_tree.c +++ b/src/mpi/coll/bcast/bcast_intra_tree.c @@ -76,7 +76,8 @@ int MPIR_Bcast_intra_tree(void *buffer, if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE || tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K) { mpi_errno = - MPIR_Treealgo_tree_create_topo_aware(comm_ptr, tree_type, branching_factor, root, + MPIR_Treealgo_tree_create_topo_aware(comm_ptr, coll_group, tree_type, + branching_factor, root, MPIR_CVAR_BCAST_TOPO_REORDER_ENABLE, &my_tree); } else if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_WAVE) { MPIR_Csel_coll_sig_s coll_sig = { @@ -105,7 +106,7 @@ int MPIR_Bcast_intra_tree(void *buffer, } mpi_errno = - MPIR_Treealgo_tree_create_topo_wave(comm_ptr, branching_factor, root, + MPIR_Treealgo_tree_create_topo_wave(comm_ptr, coll_group, branching_factor, root, MPIR_CVAR_BCAST_TOPO_REORDER_ENABLE, overhead, lat_diff_groups, lat_diff_switches, lat_same_switches, &my_tree); diff --git a/src/mpi/coll/ireduce/ireduce_tsp_tree.c b/src/mpi/coll/ireduce/ireduce_tsp_tree.c index 664dae78341..efa8682bb40 100644 --- a/src/mpi/coll/ireduce/ireduce_tsp_tree.c +++ b/src/mpi/coll/ireduce/ireduce_tsp_tree.c @@ -69,7 +69,7 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai my_tree.children = NULL; if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE || tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K) { mpi_errno = - MPIR_Treealgo_tree_create_topo_aware(comm, tree_type, k, tree_root, + MPIR_Treealgo_tree_create_topo_aware(comm, coll_group, tree_type, k, tree_root, MPIR_CVAR_IREDUCE_TOPO_REORDER_ENABLE, &my_tree); } else if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_WAVE) { MPIR_Csel_coll_sig_s coll_sig = { @@ -100,7 +100,7 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai } mpi_errno = - MPIR_Treealgo_tree_create_topo_wave(comm, k, tree_root, + MPIR_Treealgo_tree_create_topo_wave(comm, coll_group, k, tree_root, MPIR_CVAR_IREDUCE_TOPO_REORDER_ENABLE, overhead, lat_diff_groups, lat_diff_switches, 
lat_same_switches, &my_tree); From b8c1f54b5e33fd188a10500b8f03a20ba41f343c Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 23 Aug 2024 07:31:14 -0500 Subject: [PATCH 23/27] coll: add nogroup restriction to certain algorithms Some algorithms, e.g. Allgather recexch, cache comm size-related info in the communicator and thus won't work with a non-trivial coll_group. Add a restriction so they will fall back when coll_group != MPIR_SUBGROUP_NONE. --- maint/gen_coll.py | 2 ++ src/mpi/coll/allgather/allgather_intra_recexch.c | 3 +++ .../allreduce/allreduce_intra_k_reduce_scatter_allgather.c | 2 ++ src/mpi/coll/allreduce/allreduce_intra_recexch.c | 3 +++ src/mpi/coll/coll_algorithms.txt | 5 ++++- 5 files changed, 14 insertions(+), 1 deletion(-) diff --git a/maint/gen_coll.py b/maint/gen_coll.py index 83323a951da..afd46f026e6 100644 --- a/maint/gen_coll.py +++ b/maint/gen_coll.py @@ -570,6 +570,8 @@ def dump_fallback(algo): elif a == "displs-ordered": # assume it's allgatherv cond_list.append("MPII_Iallgatherv_is_displs_ordered(comm_ptr->local_size, recvcounts, displs)") + elif a == "nogroup": + cond_list.append("coll_group == MPIR_SUBGROUP_NONE") else: raise Exception("Unsupported restrictions - %s" % a) (func_name, commkind) = algo['func-commkind'].split('-') diff --git a/src/mpi/coll/allgather/allgather_intra_recexch.c b/src/mpi/coll/allgather/allgather_intra_recexch.c index 5f8fe672e44..ff6c566de69 100644 --- a/src/mpi/coll/allgather/allgather_intra_recexch.c +++ b/src/mpi/coll/allgather/allgather_intra_recexch.c @@ -36,6 +36,9 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPIR_Request *rreqs[MAX_RADIX * 2], *sreqs[MAX_RADIX * 2]; MPIR_Request **recv_reqs = NULL, **send_reqs = NULL; + /* it caches data in comm */ + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + is_inplace = (sendbuf == MPI_IN_PLACE); MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); diff --git a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c 
b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c index 842e0914005..3428178baa8 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c +++ b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c @@ -36,6 +36,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, MPIR_CHKLMEM_DECL(2); MPIR_Assert(k > 1); + /* This algorithm uses cached data in comm, thus it won't work with coll_group */ + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); diff --git a/src/mpi/coll/allreduce/allreduce_intra_recexch.c b/src/mpi/coll/allreduce/allreduce_intra_recexch.c index 9d54f4524c0..a7739e2a1b8 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recexch.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recexch.c @@ -34,6 +34,9 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, MPIR_Request **send_reqs = NULL, **recv_reqs = NULL; int send_nreq = 0, recv_nreq = 0, total_phases = 0; + /* uses cached data in comm */ + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); is_commutative = MPIR_Op_is_commutative(op); diff --git a/src/mpi/coll/coll_algorithms.txt b/src/mpi/coll/coll_algorithms.txt index 9556b82554a..ce95383946b 100644 --- a/src/mpi/coll/coll_algorithms.txt +++ b/src/mpi/coll/coll_algorithms.txt @@ -174,10 +174,12 @@ allgather-intra: func_name: recexch extra_params: recexch_type=MPIR_ALLGATHER_RECEXCH_TYPE_DISTANCE_DOUBLING, k, single_phase_recv cvar_params: -, RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV + restrictions: nogroup recexch_halving func_name: recexch extra_params: recexch_type=MPIR_ALLGATHER_RECEXCH_TYPE_DISTANCE_HALVING, k, single_phase_recv cvar_params: -, RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV + restrictions: nogroup allgather-inter: local_gather_remote_bcast iallgather-intra: @@ -350,10 +352,11 @@ allreduce-intra: recexch extra_params: k, single_phase_recv cvar_params: 
RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV + restrictions: nogroup ring restrictions: commutative k_reduce_scatter_allgather - restrictions: commutative + restrictions: commutative, nogroup extra_params: k, single_phase_recv cvar_params: RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV allreduce-inter: From ef79319066eb4fb5ee709276b9698bde6fc2cea0 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 24 Aug 2024 12:39:08 -0500 Subject: [PATCH 24/27] coll: check coll_group in MPIR_Sched_next_tag All subgroup collectives should use the same tag within the parent collective. This is because all processes in the communicator have to agree on the tag to use, but group collectives may not involve all processes. It is okay to use the same tag as long as the group collectives are always issued in order. This is the case since all group collectives are spawned under a parent collective, which has to obey the non-overlapping rule. --- maint/gen_coll.py | 4 ++++ src/include/mpir_nbc.h | 2 +- src/mpi/coll/allreduce/allreduce_intra_ring.c | 2 +- src/mpi/coll/allreduce/allreduce_intra_tree.c | 2 +- src/mpi/coll/iallgather/iallgather_tsp_brucks.c | 2 +- src/mpi/coll/iallgather/iallgather_tsp_recexch.c | 2 +- src/mpi/coll/iallgather/iallgather_tsp_ring.c | 2 +- src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c | 2 +- src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c | 2 +- src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c | 2 +- src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c | 2 +- ...e_tsp_recexch_reduce_scatter_recexch_allgatherv.c | 2 +- src/mpi/coll/iallreduce/iallreduce_tsp_ring.c | 2 +- src/mpi/coll/iallreduce/iallreduce_tsp_tree.c | 2 +- src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c | 2 +- src/mpi/coll/ialltoall/ialltoall_tsp_ring.c | 2 +- src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c | 2 +- src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c | 2 +- src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c | 2 +- src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c | 2 +- 
src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c | 2 +- src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c | 2 +- src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c | 2 +- src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c | 2 +- src/mpi/coll/ibcast/ibcast_tsp_tree.c | 2 +- src/mpi/coll/igather/igather_tsp_tree.c | 2 +- src/mpi/coll/igatherv/igatherv_tsp_linear.c | 2 +- src/mpi/coll/include/coll_impl.h | 2 +- .../ineighbor_allgather_tsp_linear.c | 2 +- .../ineighbor_allgatherv_tsp_linear.c | 2 +- .../ineighbor_alltoall_tsp_linear.c | 2 +- .../ineighbor_alltoallv_tsp_linear.c | 2 +- .../ineighbor_alltoallw_tsp_linear.c | 2 +- src/mpi/coll/ireduce/ireduce_tsp_tree.c | 2 +- .../ireduce_scatter/ireduce_scatter_tsp_recexch.c | 2 +- .../ireduce_scatter_block_tsp_recexch.c | 2 +- src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c | 2 +- src/mpi/coll/iscatter/iscatter_tsp_tree.c | 2 +- src/mpi/coll/iscatterv/iscatterv_tsp_linear.c | 2 +- src/mpi/comm/contextid.c | 6 +++--- src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h | 12 ++++++------ src/mpid/ch4/netmod/ofi/ofi_events.c | 1 + .../posix/release_gather/nb_bcast_release_gather.h | 3 ++- .../posix/release_gather/nb_reduce_release_gather.h | 3 ++- src/mpid/common/sched/mpidu_sched.c | 12 +++++++++--- src/mpid/common/sched/mpidu_sched.h | 2 +- 46 files changed, 66 insertions(+), 53 deletions(-) diff --git a/maint/gen_coll.py b/maint/gen_coll.py index afd46f026e6..cf77d6fea5a 100644 --- a/maint/gen_coll.py +++ b/maint/gen_coll.py @@ -165,6 +165,8 @@ def dump_allcomm_sched_auto(name): dump_split(0, "int MPIR_%s_allcomm_sched_auto(%s)" % (Name, func_params)) dump_open('{') G.out.append("int mpi_errno = MPI_SUCCESS;") + if re.match(r'Ineighbor_', Name): + G.out.append("int coll_group = MPIR_SUBGROUP_NONE;") G.out.append("") # -- Csel_search @@ -367,6 +369,8 @@ def dump_cases(commkind): dump_split(0, "int MPIR_%s_sched_impl(%s)" % (Name, func_params)) dump_open('{') G.out.append("int mpi_errno = MPI_SUCCESS;") + if 
re.match(r'Ineighbor_', Name): + G.out.append("int coll_group = MPIR_SUBGROUP_NONE;") G.out.append("") dump_open("if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {") diff --git a/src/include/mpir_nbc.h b/src/include/mpir_nbc.h index 9320bf5b633..710521d9a5d 100644 --- a/src/include/mpir_nbc.h +++ b/src/include/mpir_nbc.h @@ -45,7 +45,7 @@ /* Open question: should tag allocation be rolled into Sched_start? Keeping it * separate potentially allows more parallelism in the future, but it also * pushes more work onto the clients of this interface. */ -int MPIR_Sched_next_tag(MPIR_Comm * comm_ptr, int *tag); +int MPIR_Sched_next_tag(MPIR_Comm * comm_ptr, int coll_group, int *tag); void MPIR_Sched_set_tag(MPIR_Sched_t s, int tag); /* the device must provide a typedef for MPIR_Sched_t in mpidpre.h */ diff --git a/src/mpi/coll/allreduce/allreduce_intra_ring.c b/src/mpi/coll/allreduce/allreduce_intra_ring.c index fadd14d6b98..88e2c2e695d 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_ring.c +++ b/src/mpi/coll/allreduce/allreduce_intra_ring.c @@ -74,7 +74,7 @@ int MPIR_Allreduce_intra_ring(const void *sendbuf, void *recvbuf, MPI_Aint count send_rank = (nranks + rank - 1 - i) % nranks; /* get a new tag to prevent out of order messages */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIC_Irecv(tmpbuf, cnts[recv_rank], datatype, src, tag, diff --git a/src/mpi/coll/allreduce/allreduce_intra_tree.c b/src/mpi/coll/allreduce/allreduce_intra_tree.c index 03bcad74cb3..73915a72484 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_tree.c +++ b/src/mpi/coll/allreduce/allreduce_intra_tree.c @@ -139,7 +139,7 @@ int MPIR_Allreduce_intra_tree(const void *sendbuf, void *reduce_address = (char *) reduce_buffer + offset * extent; MPIR_ERR_CHKANDJUMP(!reduce_address, mpi_errno, MPI_ERR_OTHER, "**nomem"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = 
MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); for (i = 0; i < num_children; i++) { diff --git a/src/mpi/coll/iallgather/iallgather_tsp_brucks.c b/src/mpi/coll/iallgather/iallgather_tsp_brucks.c index d30e85ceca8..72ac971069f 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_brucks.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_brucks.c @@ -40,7 +40,7 @@ MPIR_TSP_Iallgather_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); MPIR_FUNC_ENTER; diff --git a/src/mpi/coll/iallgather/iallgather_tsp_recexch.c b/src/mpi/coll/iallgather/iallgather_tsp_recexch.c index 6663ba9c96c..bb3dfcacdc1 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_recexch.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_recexch.c @@ -255,7 +255,7 @@ int MPIR_TSP_Iallgather_sched_intra_recexch(const void *sendbuf, MPI_Aint sendco /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); is_inplace = (sendbuf == MPI_IN_PLACE); diff --git a/src/mpi/coll/iallgather/iallgather_tsp_ring.c b/src/mpi/coll/iallgather/iallgather_tsp_ring.c index 700f972a083..7b5e34ce881 100644 --- a/src/mpi/coll/iallgather/iallgather_tsp_ring.c +++ b/src/mpi/coll/iallgather/iallgather_tsp_ring.c @@ -83,7 +83,7 @@ int MPIR_TSP_Iallgather_sched_intra_ring(const void *sendbuf, MPI_Aint sendcount int recv_id[3] = { 0 }; /* warning fix: icc: maybe used before set */ for (i = 0; i < size - 1; i++) { /* Get new tag for each cycle so that the send-recv pairs are matched correctly */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = 
MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); int vtcs[3], nvtcs; diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c index 4eb3f04579d..0a72891cf7d 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_brucks.c @@ -64,7 +64,7 @@ MPIR_TSP_Iallgatherv_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); is_inplace = (sendbuf == MPI_IN_PLACE); diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c index a22694e0c6a..1751ac4c091 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_recexch.c @@ -273,7 +273,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch(const void *sendbuf, MPI_Aint sendc /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* get the neighbors, the function allocates the required memory */ diff --git a/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c b/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c index bc5cdd96978..d014fa777a2 100644 --- a/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c +++ b/src/mpi/coll/iallgatherv/iallgatherv_tsp_ring.c @@ -84,7 +84,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, MPI_Aint sendcoun send_rank = (rank - i + nranks) % nranks; /* Rank whose data you're sending */ /* New tag for each send-recv pair. 
*/ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); int nvtcs, vtcs[3]; diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c index 85646e3d637..38679c4907a 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch.c @@ -50,7 +50,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch(const void *sendbuf, void *recvbuf, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); /* get the neighbors, the function allocates the required memory */ MPII_Recexchalgo_get_neighbors(rank, nranks, &k, &step1_sendto, diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c index 284ca9e5b19..2ecde2af9d9 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_recexch_reduce_scatter_recexch_allgatherv.c @@ -54,7 +54,7 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* get the neighbors, the function allocates the required memory */ diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c b/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c index a461635964c..5906a710024 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_ring.c @@ -81,7 +81,7 @@ int 
MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI send_rank = (nranks + rank - 1 - i) % nranks; /* get a new tag to prevent out of order messages */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); nvtcs = (i == 0) ? 0 : 1; diff --git a/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c b/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c index 4c5d3d79b50..a0ea3ddd7dd 100644 --- a/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c +++ b/src/mpi/coll/iallreduce/iallreduce_tsp_tree.c @@ -117,7 +117,7 @@ int MPIR_TSP_Iallreduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); for (i = 0; i < num_children; i++) { diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c b/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c index 9ab18d690e9..aabd475f413 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_brucks.c @@ -146,7 +146,7 @@ MPIR_TSP_Ialltoall_sched_intra_brucks(const void *sendbuf, MPI_Aint sendcount, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); MPIR_CHKLMEM_MALLOC(pack_invtcs, int *, sizeof(int) * k, mpi_errno, "pack_invtcs", diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c b/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c index 989769feb15..29e34d080d6 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_ring.c @@ -117,7 +117,7 @@ int MPIR_TSP_Ialltoall_sched_intra_ring(const void *sendbuf, MPI_Aint 
sendcount, for (i = 0; i < size - 1; i++) { /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); int vtcs[3], nvtcs; diff --git a/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c b/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c index 1fbf7d2b383..506a5546d5c 100644 --- a/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c +++ b/src/mpi/coll/ialltoall/ialltoall_tsp_scattered.c @@ -58,7 +58,7 @@ int MPIR_TSP_Ialltoall_sched_intra_scattered(const void *sendbuf, MPI_Aint sendc /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); MPIR_COLL_RANK_SIZE(comm, coll_group, rank, size); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c index 2e3b4d31829..1bb6902f2ba 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_blocked.c @@ -28,7 +28,7 @@ int MPIR_TSP_Ialltoallv_sched_intra_blocked(const void *sendbuf, const MPI_Aint /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c index 80feead1da9..039e1e4c693 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_inplace.c @@ -27,7 +27,7 @@ int MPIR_TSP_Ialltoallv_sched_intra_inplace(const void *sendbuf, 
const MPI_Aint /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); diff --git a/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c b/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c index 57ec5f136d9..8bf0b6cedfa 100644 --- a/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c +++ b/src/mpi/coll/ialltoallv/ialltoallv_tsp_scattered.c @@ -55,7 +55,7 @@ int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const MPI_Ain MPIR_Type_get_true_extent_impl(sendtype, &sendtype_lb, &sendtype_true_extent); sendtype_extent = MPL_MAX(sendtype_extent, sendtype_true_extent); - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* First, post bblock number of sends/recvs */ diff --git a/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c b/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c index 9675c9ffe3b..922a7cd4a00 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_tsp_blocked.c @@ -32,7 +32,7 @@ int MPIR_TSP_Ialltoallw_sched_intra_blocked(const void *sendbuf, const MPI_Aint /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* post only bblock isends/irecvs at a time as suggested by Tony Ladd */ diff --git a/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c b/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c index 7af7c386b4b..8bebe075c03 100644 --- a/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c +++ b/src/mpi/coll/ialltoallw/ialltoallw_tsp_inplace.c @@ -31,7 +31,7 @@ int 
MPIR_TSP_Ialltoallw_sched_intra_inplace(const void *sendbuf, const MPI_Aint /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* FIXME: Here we allocate tmp_buf using extent and send/recv with datatype directly, diff --git a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c index e59e162ab65..323c41eb3cb 100644 --- a/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c +++ b/src/mpi/coll/ibarrier/ibarrier_intra_tsp_dissem.c @@ -22,7 +22,7 @@ int MPIR_TSP_Ibarrier_sched_intra_k_dissemination(MPIR_Comm * comm, int coll_gro MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); diff --git a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c index 5bece20ed41..30cff266f12 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv.c @@ -32,7 +32,7 @@ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); MPIR_FUNC_ENTER; diff --git a/src/mpi/coll/ibcast/ibcast_tsp_tree.c b/src/mpi/coll/ibcast/ibcast_tsp_tree.c index 836ffbed01c..4ec7336c063 100644 --- a/src/mpi/coll/ibcast/ibcast_tsp_tree.c +++ b/src/mpi/coll/ibcast/ibcast_tsp_tree.c @@ -61,7 +61,7 @@ int MPIR_TSP_Ibcast_sched_intra_tree(void *buffer, MPI_Aint count, MPI_Datatype /* For correctness, transport based collectives need to 
get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* Receive message from parent */ diff --git a/src/mpi/coll/igather/igather_tsp_tree.c b/src/mpi/coll/igather/igather_tsp_tree.c index 7799da9bdc2..bbb5dee7059 100644 --- a/src/mpi/coll/igather/igather_tsp_tree.c +++ b/src/mpi/coll/igather/igather_tsp_tree.c @@ -45,7 +45,7 @@ int MPIR_TSP_Igather_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); if (rank == root && is_inplace) { diff --git a/src/mpi/coll/igatherv/igatherv_tsp_linear.c b/src/mpi/coll/igatherv/igatherv_tsp_linear.c index 324cebfb441..f3a941a22da 100644 --- a/src/mpi/coll/igatherv/igatherv_tsp_linear.c +++ b/src/mpi/coll/igatherv/igatherv_tsp_linear.c @@ -37,7 +37,7 @@ int MPIR_TSP_Igatherv_sched_allcomm_linear(const void *sendbuf, MPI_Aint sendcou comm_size = comm_ptr->remote_size; } - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* If rank == root, then I recv lots, otherwise I send */ diff --git a/src/mpi/coll/include/coll_impl.h b/src/mpi/coll/include/coll_impl.h index b5b576b1b33..cef96e59aeb 100644 --- a/src/mpi/coll/include/coll_impl.h +++ b/src/mpi/coll/include/coll_impl.h @@ -75,7 +75,7 @@ int MPII_Coll_finalize(void); mpi_errno = MPIR_Sched_create(&s, sched_kind); \ MPIR_ERR_CHECK(mpi_errno); \ int tag = -1; \ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); \ + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); \ MPIR_ERR_CHECK(mpi_errno); \ MPIR_Sched_set_tag(s, tag); \ *sched_type_p = MPIR_SCHED_NORMAL; \ diff --git 
a/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c b/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c index 6f3218c2c62..c0717d70ddd 100644 --- a/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c +++ b/src/mpi/coll/ineighbor_allgather/ineighbor_allgather_tsp_linear.c @@ -38,7 +38,7 @@ int MPIR_TSP_Ineighbor_allgather_sched_allcomm_linear(const void *sendbuf, MPI_A /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { diff --git a/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c b/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c index c8a3d70867e..863b86e4973 100644 --- a/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c +++ b/src/mpi/coll/ineighbor_allgatherv/ineighbor_allgatherv_tsp_linear.c @@ -39,7 +39,7 @@ int MPIR_TSP_Ineighbor_allgatherv_sched_allcomm_linear(const void *sendbuf, MPI_ /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { diff --git a/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c b/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c index 74e3caa9ca6..4b427bfbe73 100644 --- a/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c +++ b/src/mpi/coll/ineighbor_alltoall/ineighbor_alltoall_tsp_linear.c @@ -39,7 +39,7 @@ int MPIR_TSP_Ineighbor_alltoall_sched_allcomm_linear(const void *sendbuf, MPI_Ai /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based 
collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c b/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c index 6ed728dc8a5..87eead1a457 100644 --- a/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c +++ b/src/mpi/coll/ineighbor_alltoallv/ineighbor_alltoallv_tsp_linear.c @@ -43,7 +43,7 @@ int MPIR_TSP_Ineighbor_alltoallv_sched_allcomm_linear(const void *sendbuf, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { diff --git a/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c b/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c index bc6501a51ba..711f95a220d 100644 --- a/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c +++ b/src/mpi/coll/ineighbor_alltoallw/ineighbor_alltoallw_tsp_linear.c @@ -39,7 +39,7 @@ int MPIR_TSP_Ineighbor_alltoallw_sched_allcomm_linear(const void *sendbuf, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); for (k = 0; k < outdegree; ++k) { diff --git a/src/mpi/coll/ireduce/ireduce_tsp_tree.c b/src/mpi/coll/ireduce/ireduce_tsp_tree.c index efa8682bb40..4fbd0fe2ae5 100644 --- a/src/mpi/coll/ireduce/ireduce_tsp_tree.c +++ b/src/mpi/coll/ireduce/ireduce_tsp_tree.c @@ -193,7 +193,7 @@ int MPIR_TSP_Ireduce_sched_intra_tree(const void *sendbuf, void *recvbuf, MPI_Ai /* For correctness, transport based collectives need to 
get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); for (i = 0; i < num_children; i++) { diff --git a/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c b/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c index f71afdf58cf..1d0e23bf409 100644 --- a/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c +++ b/src/mpi/coll/ireduce_scatter/ireduce_scatter_tsp_recexch.c @@ -159,7 +159,7 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); is_inplace = (sendbuf == MPI_IN_PLACE); MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); diff --git a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c index 5e4107de0b4..70aa2681896 100644 --- a/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c +++ b/src/mpi/coll/ireduce_scatter_block/ireduce_scatter_block_tsp_recexch.c @@ -32,7 +32,7 @@ int MPIR_TSP_Ireduce_scatter_block_sched_intra_recexch(const void *sendbuf, void /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); is_inplace = (sendbuf == MPI_IN_PLACE); MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); diff --git a/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c b/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c index 2307607c8e1..34a24b05a5c 100644 --- a/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c +++ b/src/mpi/coll/iscan/iscan_tsp_recursive_doubling.c @@ 
-28,7 +28,7 @@ int MPIR_TSP_Iscan_sched_intra_recursive_doubling(const void *sendbuf, void *rec /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); diff --git a/src/mpi/coll/iscatter/iscatter_tsp_tree.c b/src/mpi/coll/iscatter/iscatter_tsp_tree.c index 858f988050d..1c8aa9121e0 100644 --- a/src/mpi/coll/iscatter/iscatter_tsp_tree.c +++ b/src/mpi/coll/iscatter/iscatter_tsp_tree.c @@ -47,7 +47,7 @@ int MPIR_TSP_Iscatter_sched_intra_tree(const void *sendbuf, MPI_Aint sendcount, /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm, &tag); + mpi_errno = MPIR_Sched_next_tag(comm, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); if (rank == root && is_inplace) { diff --git a/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c b/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c index 1ede379982d..39fece505dc 100644 --- a/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c +++ b/src/mpi/coll/iscatterv/iscatterv_tsp_linear.c @@ -32,7 +32,7 @@ int MPIR_TSP_Iscatterv_sched_allcomm_linear(const void *sendbuf, const MPI_Aint /* For correctness, transport based collectives need to get the * tag from the same pool as schedule based collectives */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); MPIR_ERR_CHECK(mpi_errno); /* If I'm the root, then scatter */ diff --git a/src/mpi/comm/contextid.c b/src/mpi/comm/contextid.c index d1d2833a435..d10b9833d6f 100644 --- a/src/mpi/comm/contextid.c +++ b/src/mpi/comm/contextid.c @@ -762,7 +762,7 @@ static int sched_cb_gcn_allocate_cid(MPIR_Comm * comm, int tag, void *state) * are not necessarily completed in the same order as they are 
issued, also on the * same communicator. To avoid deadlocks, we cannot add the elements to the * list bevfore the first iallreduce is completed. The "tag" is created for the - * scheduling - by calling MPIR_Sched_next_tag(comm_ptr, &tag) - and the same + * scheduling - by calling MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag) - and the same * for a idup operation on all processes. So we use it here. */ /* FIXME I'm not sure if there can be an overflows for this tag */ st->tag = (uint64_t) tag + MPIR_Process.attrs.tag_ub; @@ -945,7 +945,7 @@ int MPIR_Get_contextid_nonblock(MPIR_Comm * comm_ptr, MPIR_Comm * newcommp, MPIR MPIR_FUNC_ENTER; /* now create a schedule */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_create(&s, MPIR_SCHED_KIND_GENERALIZED); MPIR_ERR_CHECK(mpi_errno); @@ -986,7 +986,7 @@ int MPIR_Get_intercomm_contextid_nonblock(MPIR_Comm * comm_ptr, MPIR_Comm * newc } /* now create a schedule */ - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, MPIR_SUBGROUP_NONE, &tag); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIR_Sched_create(&s, MPIR_SCHED_KIND_GENERALIZED); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h b/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h index 957a67d11ef..21761833075 100644 --- a/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h +++ b/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h @@ -90,7 +90,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf sizeof(struct fi_deferred_work), MPL_MEM_BUFFER); MPIR_ERR_CHKANDSTMT(*works == NULL, mpi_errno, MPI_ERR_NO_MEM, goto fn_fail, "**nomem"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &rtr_tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &rtr_tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -110,7 +110,7 @@ MPL_STATIC_INLINE_PREFIX int 
MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf } i = i + j; - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -239,7 +239,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer MPIR_ERR_CHKANDSTMT(*works == NULL, mpi_errno, MPI_ERR_NO_MEM, goto fn_fail, "**nomem"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &rtr_tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &rtr_tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -260,7 +260,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer } i = i + j; - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -419,7 +419,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_rma(void *buffer MPIR_ERR_CHKANDJUMP1(*works == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s", "Triggered bcast deferred work alloc"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &rtr_tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &rtr_tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); @@ -564,7 +564,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_rma(void *buffer, i sizeof(struct fi_deferred_work), MPL_MEM_BUFFER); MPIR_ERR_CHKANDSTMT(*works == NULL, mpi_errno, MPI_ERR_NO_MEM, goto fn_fail, "**nomem"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &rtr_tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &rtr_tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); diff --git a/src/mpid/ch4/netmod/ofi/ofi_events.c b/src/mpid/ch4/netmod/ofi/ofi_events.c index d3459d21b54..f51541f296b 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_events.c +++ b/src/mpid/ch4/netmod/ofi/ofi_events.c @@ -186,6 +186,7 @@ static int pipeline_recv_event(struct fi_cq_tagged_entry *wc, MPIR_Request * r, chunk_req->offset = chunk_sz * i; int ret = 0; if 
(!MPIDI_OFI_global.gpu_recv_queue && host_buf) { + /* FIXME: error handling */ ret = fi_trecv (MPIDI_OFI_global.ctx [MPIDI_OFI_REQUEST(rreq, pipeline_info.ctx_idx)].rx, diff --git a/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h b/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h index b8e9197ce73..139d90a346a 100644 --- a/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h +++ b/src/mpid/ch4/shm/posix/release_gather/nb_bcast_release_gather.h @@ -355,6 +355,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_nb_release_gather_ibcast_impl(void *loc MPI_Aint type_size, nbytes, true_lb, true_extent; void *ori_local_buf = local_buf; MPI_Datatype ori_datatype = datatype; + int coll_group = MPIR_SUBGROUP_NONE; MPIR_CHKLMEM_DECL(1); /* Register the vertices */ @@ -425,7 +426,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_nb_release_gather_ibcast_impl(void *loc MPIR_TSP_sched_malloc(sizeof(MPIDI_POSIX_per_call_ibcast_info_t), sched); MPIR_ERR_CHKANDJUMP(!data, mpi_errno, MPI_ERR_OTHER, "**nomem"); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); diff --git a/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h b/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h index a3b7144f9f6..e03a7a02e0a 100644 --- a/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h +++ b/src/mpid/ch4/shm/posix/release_gather/nb_reduce_release_gather.h @@ -364,6 +364,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_nb_release_gather_ireduce_impl(void *se MPI_Aint num_chunks, chunk_count_floor, chunk_count_ceil; MPI_Aint true_extent, type_size, lb, extent; int offset = 0, is_contig; + int coll_group = MPIR_SUBGROUP_NONE; /* Register the vertices */ reserve_buf_type_id = MPIR_TSP_sched_new_type(sched, MPIDI_POSIX_NB_RG_rank0_hold_buf_issue, @@ -418,7 +419,7 @@ MPL_STATIC_INLINE_PREFIX int 
MPIDI_POSIX_nb_release_gather_ireduce_impl(void *se data->seq_no = MPIDI_POSIX_COMM(comm_ptr, nb_reduce_seq_no); - mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); + mpi_errno = MPIR_Sched_next_tag(comm_ptr, coll_group, &tag); if (mpi_errno) MPIR_ERR_POP(mpi_errno); diff --git a/src/mpid/common/sched/mpidu_sched.c b/src/mpid/common/sched/mpidu_sched.c index e2fa385330f..0298b928c41 100644 --- a/src/mpid/common/sched/mpidu_sched.c +++ b/src/mpid/common/sched/mpidu_sched.c @@ -148,7 +148,7 @@ int MPIDU_Sched_are_pending(void) return (all_schedules.head != NULL); } -int MPIDU_Sched_next_tag(MPIR_Comm * comm_ptr, int *tag) +int MPIDU_Sched_next_tag(MPIR_Comm * comm_ptr, int coll_group, int *tag) { int mpi_errno = MPI_SUCCESS; /* TODO there should be an internal accessor/utility macro for getting the @@ -162,6 +162,10 @@ int MPIDU_Sched_next_tag(MPIR_Comm * comm_ptr, int *tag) MPIR_FUNC_ENTER; *tag = comm_ptr->next_sched_tag; + if (coll_group != MPIR_SUBGROUP_NONE) { + /* subgroup collectives use the same tag within a parent collective */ + goto fn_exit; + } ++comm_ptr->next_sched_tag; #if defined(HAVE_ERROR_CHECKING) @@ -191,11 +195,13 @@ int MPIDU_Sched_next_tag(MPIR_Comm * comm_ptr, int *tag) if (comm_ptr->next_sched_tag == tag_ub) { comm_ptr->next_sched_tag = MPIR_FIRST_NBC_TAG; } + fn_exit: + MPIR_FUNC_EXIT; + return mpi_errno; #if defined(HAVE_ERROR_CHECKING) fn_fail: + goto fn_exit; #endif - MPIR_FUNC_EXIT; - return mpi_errno; } void MPIDU_Sched_set_tag(struct MPIDU_Sched *s, int tag) diff --git a/src/mpid/common/sched/mpidu_sched.h b/src/mpid/common/sched/mpidu_sched.h index a61e64c6b69..90d43b75392 100644 --- a/src/mpid/common/sched/mpidu_sched.h +++ b/src/mpid/common/sched/mpidu_sched.h @@ -134,7 +134,7 @@ struct MPIDU_Sched { /* prototypes */ int MPIDU_Sched_progress(int vci, int *made_progress); int MPIDU_Sched_are_pending(void); -int MPIDU_Sched_next_tag(struct MPIR_Comm *comm_ptr, int *tag); +int MPIDU_Sched_next_tag(struct MPIR_Comm *comm_ptr, int 
coll_group, int *tag); void MPIDU_Sched_set_tag(MPIR_Sched_t s, int tag); int MPIDU_Sched_create(MPIR_Sched_t * sp, enum MPIR_Sched_kind kind); int MPIDU_Sched_clone(MPIR_Sched_t orig, MPIR_Sched_t * cloned); From 138c7609af8f190aa021b810b607be0efb13ac8f Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 24 Aug 2024 16:45:28 -0500 Subject: [PATCH 25/27] coll: refactor barrier_intra_k_dissemination MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Because the compiler cannot follow the index arithmetic on the receive-request array, it emits the warning: ‘MPIC_Waitall’ accessing 8 bytes in a region of size 0 [-Wstringop-overflow=] Refactor to suppress the warning and to improve readability. --- .../barrier/barrier_intra_k_dissemination.c | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c index 5e71475c466..13c54dccc3d 100644 --- a/src/mpi/coll/barrier/barrier_intra_k_dissemination.c +++ b/src/mpi/coll/barrier/barrier_intra_k_dissemination.c @@ -51,7 +51,7 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int coll_group, int k, int p_of_k; /* minimum power of k that is greater than or equal to number of ranks */ int shift, to, from; int nphases = 0; - MPIR_Request *sreqs[MAX_RADIX], *rreqs[MAX_RADIX * 2]; + MPIR_Request *static_sreqs[MAX_RADIX], *static_rreqs[MAX_RADIX * 2]; MPIR_Request **send_reqs = NULL, **recv_reqs = NULL; MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); @@ -76,8 +76,8 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int coll_group, int k, send_reqs = (MPIR_Request **) MPL_malloc((k - 1) * sizeof(MPIR_Request *), MPL_MEM_BUFFER); MPIR_ERR_CHKANDJUMP(!send_reqs, mpi_errno, MPI_ERR_OTHER, "**nomem"); } else { - send_reqs = sreqs; - recv_reqs = rreqs; + send_reqs = static_sreqs; + recv_reqs = static_rreqs; } p_of_k = 1; @@ -86,6 +86,8 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm
* comm, int coll_group, int k, nphases++; } + MPIR_Request **rreqs = recv_reqs; + MPIR_Request **prev_rreqs = recv_reqs + (k - 1); shift = 1; for (i = 0; i < nphases; i++) { for (j = 1; j < k; j++) { @@ -97,14 +99,12 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int coll_group, int k, MPIR_Assert(to >= 0 && to < nranks); /* recv from (k-1) nbrs */ - mpi_errno = - MPIC_Irecv(NULL, 0, MPI_BYTE, from, MPIR_BARRIER_TAG, comm, coll_group, - &recv_reqs[(j - 1) + ((k - 1) * (i & 1))]); + mpi_errno = MPIC_Irecv(NULL, 0, MPI_BYTE, from, MPIR_BARRIER_TAG, comm, coll_group, + &rreqs[j - 1]); MPIR_ERR_CHECK(mpi_errno); /* wait on recvs from prev phase */ if (i > 0 && j == 1) { - mpi_errno = - MPIC_Waitall(k - 1, &recv_reqs[((k - 1) * ((i - 1) & 1))], MPI_STATUSES_IGNORE); + mpi_errno = MPIC_Waitall(k - 1, prev_rreqs, MPI_STATUSES_IGNORE); MPIR_ERR_CHECK(mpi_errno); } @@ -115,10 +115,13 @@ int MPIR_Barrier_intra_k_dissemination(MPIR_Comm * comm, int coll_group, int k, mpi_errno = MPIC_Waitall(k - 1, send_reqs, MPI_STATUSES_IGNORE); MPIR_ERR_CHECK(mpi_errno); shift *= k; + + MPIR_Request **tmp = rreqs; + rreqs = prev_rreqs; + prev_rreqs = tmp; } - mpi_errno = - MPIC_Waitall(k - 1, recv_reqs + ((k - 1) * ((nphases - 1) & 1)), MPI_STATUSES_IGNORE); + mpi_errno = MPIC_Waitall(k - 1, prev_rreqs, MPI_STATUSES_IGNORE); MPIR_ERR_CHECK(mpi_errno); fn_exit: From 7001a71f972aed52386fcf960e192d8d4a027a52 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 12 Sep 2024 18:03:45 -0500 Subject: [PATCH 26/27] coll/allreduce: remove a leftover empty branch Commit ba1b4dd1e52c30a5a72502ff8d20fc7112b57337 left an empty branch that should be removed. 
--- src/mpi/coll/allreduce/allreduce_intra_recexch.c | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/mpi/coll/allreduce/allreduce_intra_recexch.c b/src/mpi/coll/allreduce/allreduce_intra_recexch.c index a7739e2a1b8..f304d8e349d 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recexch.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recexch.c @@ -204,8 +204,6 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, mpi_errno = MPIC_Isend(recvbuf, count, datatype, nbr, MPIR_ALLREDUCE_TAG, comm, coll_group, &send_reqs[send_nreq++], errflag); MPIR_ERR_CHECK(mpi_errno); - if (rank > nbr) { - } } mpi_errno = MPIC_Waitall(send_nreq, send_reqs, MPI_STATUSES_IGNORE); From 37fc447b983106a848ca80d24727d09b38716436 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 8 Nov 2024 12:02:35 -0600 Subject: [PATCH 27/27] coll: patch allreduce_intra_recursive_multiplying.c Update this code to use coll_group and apply some whitespace changes. --- .../allreduce_intra_recursive_multiplying.c | 124 +++++++++--------- 1 file changed, 61 insertions(+), 63 deletions(-) diff --git a/src/mpi/coll/allreduce/allreduce_intra_recursive_multiplying.c b/src/mpi/coll/allreduce/allreduce_intra_recursive_multiplying.c index 30be5bc893b..1cdd7a99d12 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recursive_multiplying.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recursive_multiplying.c @@ -12,33 +12,31 @@ * * This algorithm is a generalization of the recursive doubling algorithm, * and it has three stages. In the first stage, ranks above the nearest - * power of k less than or equal to comm_size collapse their data to the + * power of k less than or equal to comm_size collapse their data to the * lower ranks. The main stage proceeds with power-of-k ranks. In the main - * stage, ranks exchange data within groups of size k in rounds with - * increasing distance (k, k^2, ...). Lastly, those in the main stage - * disperse the result back to the excluded ranks. 
Setting k according - * to the network hierarchy (e.g., the number of NICs in a node) can + * stage, ranks exchange data within groups of size k in rounds with + * increasing distance (k, k^2, ...). Lastly, those in the main stage + * disperse the result back to the excluded ranks. Setting k according + * to the network hierarchy (e.g., the number of NICs in a node) can * improve performance. */ int MPIR_Allreduce_intra_recursive_multiplying(const void *sendbuf, - void *recvbuf, - MPI_Aint count, - MPI_Datatype datatype, - MPI_Op op, - MPIR_Comm * comm_ptr, - const int k, - MPIR_Errflag_t errflag) + void *recvbuf, + MPI_Aint count, + MPI_Datatype datatype, + MPI_Op op, + MPIR_Comm * comm_ptr, + int coll_group, const int k, MPIR_Errflag_t errflag) { int mpi_errno = MPI_SUCCESS; /* Ensure the op is commutative */ int comm_size, rank, virt_rank; - comm_size = comm_ptr->local_size; - rank = comm_ptr->rank; + MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size); virt_rank = rank; - + /* get nearest power-of-two less than or equal to comm_size */ int power = (int) (log(comm_size) / log(k)); int pofk = (int) lround(pow(k, power)); @@ -46,21 +44,21 @@ int MPIR_Allreduce_intra_recursive_multiplying(const void *sendbuf, MPIR_CHKLMEM_DECL(2); void *tmp_buf; - /*Allocate for nb requests*/ + /*Allocate for nb requests */ MPIR_Request **reqs; int num_reqs = 0; MPIR_CHKLMEM_MALLOC(reqs, MPIR_Request **, (2 * (k - 1) * sizeof(MPIR_Request *)), mpi_errno, - "reqs", MPL_MEM_BUFFER); + "reqs", MPL_MEM_BUFFER); /* need to allocate temporary buffer to store incoming data */ MPI_Aint true_extent, true_lb, extent; MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent); MPIR_Datatype_get_extent_macro(datatype, extent); MPI_Aint single_node_data_size = extent * count - (extent - true_extent); - + MPIR_CHKLMEM_MALLOC(tmp_buf, void *, (k - 1) * count * single_node_data_size, mpi_errno, "temporary buffer", MPL_MEM_BUFFER); - + /* adjust for potential negative lower bound in 
datatype */ tmp_buf = (void *) ((char *) tmp_buf - true_lb); @@ -82,34 +80,33 @@ int MPIR_Allreduce_intra_recursive_multiplying(const void *sendbuf, int pre_dst = rank % pofk; /* This is follower so send data */ mpi_errno = MPIC_Send(recvbuf, count, datatype, - pre_dst, MPIR_ALLREDUCE_TAG, comm_ptr, errflag); + pre_dst, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, errflag); MPIR_ERR_CHECK(mpi_errno); /* Set virtual rank so this rank is not used in main stage */ virt_rank = -1; } else { /* Receive data from all those greater than pofk */ for (int pre_src = (rank % pofk) + pofk; pre_src < comm_size; pre_src += pofk) { - mpi_errno = MPIC_Irecv(((char *)tmp_buf) + num_reqs * count * extent, count, - datatype, pre_src, MPIR_ALLREDUCE_TAG, comm_ptr, - &reqs[num_reqs]); + mpi_errno = MPIC_Irecv(((char *) tmp_buf) + num_reqs * count * extent, count, + datatype, pre_src, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, + &reqs[num_reqs]); MPIR_ERR_CHECK(mpi_errno); num_reqs++; } - - /* Wait for asynchronous operations to complete */ + + /* Wait for asynchronous operations to complete */ MPIC_Waitall(num_reqs, reqs, MPI_STATUSES_IGNORE); /* Reduce locally */ - for(int i = 0; i < num_reqs; i++) { - if(i == (num_reqs - 1)) { - mpi_errno = MPIR_Reduce_local(((char *)tmp_buf) + i * count * extent, - recvbuf, count, datatype, op); + for (int i = 0; i < num_reqs; i++) { + if (i == (num_reqs - 1)) { + mpi_errno = MPIR_Reduce_local(((char *) tmp_buf) + i * count * extent, + recvbuf, count, datatype, op); MPIR_ERR_CHECK(mpi_errno); - } - else { - mpi_errno = MPIR_Reduce_local(((char *)tmp_buf) + i * count * extent, - ((char *)tmp_buf) + (i + 1) * count * extent, - count, datatype, op); + } else { + mpi_errno = MPIR_Reduce_local(((char *) tmp_buf) + i * count * extent, + ((char *) tmp_buf) + (i + 1) * count * extent, + count, datatype, op); MPIR_ERR_CHECK(mpi_errno); } } @@ -119,60 +116,61 @@ int MPIR_Allreduce_intra_recursive_multiplying(const void *sendbuf, /*MAIN-STAGE: Ranks exchange data 
with groups size k over increasing * distances */ if (virt_rank != -1) { - /*Do exchanges*/ + /*Do exchanges */ num_reqs = 0; int exchanges = 0; int distance = 1; int next_distance = k; while (distance < pofk) { /* Asynchronous sends */ - - int starting_rank = rank/next_distance * next_distance; + + int starting_rank = rank / next_distance * next_distance; int rank_offset = starting_rank + rank % distance; - for(int dst = rank_offset; dst < starting_rank + next_distance; dst += distance) { - if(dst != rank) { - mpi_errno = MPIC_Isend(recvbuf, count, datatype, dst, MPIR_ALLREDUCE_TAG, - comm_ptr, &reqs[num_reqs++], errflag); + for (int dst = rank_offset; dst < starting_rank + next_distance; dst += distance) { + if (dst != rank) { + mpi_errno = MPIC_Isend(recvbuf, count, datatype, dst, MPIR_ALLREDUCE_TAG, + comm_ptr, coll_group, &reqs[num_reqs++], errflag); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIC_Irecv(((char *)tmp_buf) + exchanges * count * extent, - count, datatype, dst, MPIR_ALLREDUCE_TAG, comm_ptr, - &reqs[num_reqs++]); + mpi_errno = MPIC_Irecv(((char *) tmp_buf) + exchanges * count * extent, + count, datatype, dst, MPIR_ALLREDUCE_TAG, + comm_ptr, coll_group, &reqs[num_reqs++]); MPIR_ERR_CHECK(mpi_errno); exchanges++; } } - /* Wait for asynchronous operations to complete */ + /* Wait for asynchronous operations to complete */ MPIC_Waitall(num_reqs, reqs, MPI_STATUSES_IGNORE); num_reqs = 0; exchanges = 0; /* Perform reduction on the received values */ int recvbuf_last = 0; - for(int dst = rank_offset; dst < starting_rank + next_distance - distance; dst += distance) { - void *dst_buf = ((char *)tmp_buf) + exchanges * count * extent; - if(dst == rank - distance) { + for (int dst = rank_offset; dst < starting_rank + next_distance - distance; + dst += distance) { + void *dst_buf = ((char *) tmp_buf) + exchanges * count * extent; + if (dst == rank - distance) { mpi_errno = MPIR_Reduce_local(dst_buf, recvbuf, count, datatype, op); MPIR_ERR_CHECK(mpi_errno); 
recvbuf_last = 1; exchanges++; - } - else if(dst == rank){ + } else if (dst == rank) { mpi_errno = MPIR_Reduce_local(recvbuf, dst_buf, count, datatype, op); MPIR_ERR_CHECK(mpi_errno); recvbuf_last = 0; - } - else { - mpi_errno = MPIR_Reduce_local(dst_buf, (char *)dst_buf + count * extent, count, datatype, op); + } else { + mpi_errno = + MPIR_Reduce_local(dst_buf, (char *) dst_buf + count * extent, count, + datatype, op); MPIR_ERR_CHECK(mpi_errno); recvbuf_last = 0; exchanges++; } } - - if(!recvbuf_last) { - mpi_errno = MPIR_Localcopy((char *)tmp_buf + exchanges * count * extent, - count, datatype, recvbuf, count, datatype); + + if (!recvbuf_last) { + mpi_errno = MPIR_Localcopy((char *) tmp_buf + exchanges * count * extent, + count, datatype, recvbuf, count, datatype); MPIR_ERR_CHECK(mpi_errno); } @@ -183,23 +181,23 @@ int MPIR_Allreduce_intra_recursive_multiplying(const void *sendbuf, } /* POST-STAGE: Send result to ranks outside main algorithm */ - if(pofk < comm_size) { + if (pofk < comm_size) { num_reqs = 0; - if(rank >= pofk) { + if (rank >= pofk) { int post_src = rank % pofk; /* This process is outside the core algorithm, so receive data */ mpi_errno = MPIC_Recv(recvbuf, count, datatype, post_src, - MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE); + MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE); MPIR_ERR_CHECK(mpi_errno); } else { /* This is process is in the algorithm, so send data */ for (int post_dst = (rank % pofk) + pofk; post_dst < comm_size; post_dst += pofk) { - mpi_errno = MPIC_Isend(recvbuf, count, datatype, post_dst, MPIR_ALLREDUCE_TAG, comm_ptr, - &reqs[num_reqs++], errflag); + mpi_errno = MPIC_Isend(recvbuf, count, datatype, post_dst, MPIR_ALLREDUCE_TAG, + comm_ptr, coll_group, &reqs[num_reqs++], errflag); MPIR_ERR_CHECK(mpi_errno); } - + MPIC_Waitall(num_reqs, reqs, MPI_STATUSES_IGNORE); } }