From 19292d69e90e3f0ed0ff80b480d1856e3e65470e Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Tue, 10 Dec 2024 13:14:43 -0600 Subject: [PATCH] group: refactor MPIR_Group * add option to use stride to describe group composition * remove the linked list design --- src/include/mpir_group.h | 102 ++++--- src/mpi/comm/comm_impl.c | 3 - src/mpi/group/group_impl.c | 559 ++++++++++++------------------------- src/mpi/group/grouputil.c | 284 +++++-------------- 4 files changed, 317 insertions(+), 631 deletions(-) diff --git a/src/include/mpir_group.h b/src/include/mpir_group.h index 520ab8f8d02..aefb68fcd70 100644 --- a/src/include/mpir_group.h +++ b/src/include/mpir_group.h @@ -12,24 +12,6 @@ * MPI_Group_intersection) and for the scalable RMA synchronization *---------------------------------------------------------------------------*/ -/* Abstract the integer type for lpid (process id). It is possible to use 32-bit - * in principle, but 64-bit is simpler since we can trivially combine - * (world_idx, world_rank). - */ -typedef uint64_t MPIR_Lpid; - -/* This structure is used to implement the group operations such as - MPI_Group_translate_ranks */ -/* note: next_lpid (with idx_of_first_lpid in MPIR_Group) gives a linked list - * in a sorted lpid ascending order */ -typedef struct MPII_Group_pmap_t { - MPIR_Lpid lpid; /* local process id, from VCONN */ - int next_lpid; /* Index of next lpid (in lpid order) */ -} MPII_Group_pmap_t; - -/* Any changes in the MPIR_Group structure must be made to the - predefined value in MPIR_Group_builtin for MPI_GROUP_EMPTY in - src/mpi/group/grouputil.c */ /*S MPIR_Group - Description of the Group data structure @@ -60,22 +42,75 @@ typedef struct MPII_Group_pmap_t { Group-DS S*/ + +/* Abstract the integer type for lpid (process id). It is possible to use 32-bit + * in principle, but 64-bit is simpler since we can trivially combine + * (world_idx, world_rank). + */ +typedef uint64_t MPIR_Lpid; + +struct MPIR_Pmap { + int size; /* same as group->size, duplicate here so Pmap is logically complete */ + bool use_map; + union { + MPIR_Lpid *map; + struct { + MPIR_Lpid offset; + MPIR_Lpid stride; + MPIR_Lpid blocksize; + } stride; + } u; +}; + +MPL_STATIC_INLINE_PREFIX MPIR_Lpid MPIR_Pmap_rank_to_lpid(struct MPIR_Pmap *pmap, int rank) +{ + if (rank >= 0 && rank < pmap->size) { + return MPI_UNDEFINED; + } + + if (pmap->use_map) { + return pmap->u.map[rank]; + } else { + MPIR_Lpid i_blk = rank / pmap->u.stride.blocksize; + MPIR_Lpid r_blk = rank % pmap->u.stride.blocksize; + return pmap->u.stride.offset + i_blk * pmap->u.stride.stride + r_blk; + } +} + +MPL_STATIC_INLINE_PREFIX int MPIR_Pmap_lpid_to_rank(struct MPIR_Pmap *pmap, MPIR_Lpid lpid) +{ + if (pmap->use_map) { + /* linear search */ + for (int rank = 0; rank < pmap->size; rank++) { + if (pmap->u.map[rank] == lpid) { + return rank; + } + } + return MPI_UNDEFINED; + } else { + lpid -= pmap->u.stride.offset; + MPIR_Lpid i_blk = lpid / pmap->u.stride.stride; + MPIR_Lpid r_blk = lpid % pmap->u.stride.stride; + + if (r_blk >= pmap->u.stride.blocksize) { + return MPI_UNDEFINED; + } + + int rank = i_blk * pmap->u.stride.blocksize + r_blk; + if (rank >= 0 && rank < pmap->size) { + return rank; + } else { + return MPI_UNDEFINED; + } + } +} + struct MPIR_Group { MPIR_OBJECT_HEADER; /* adds handle and ref_count fields */ int size; /* Size of a group */ - int rank; /* rank of this process relative to this - * group */ - int idx_of_first_lpid; - MPII_Group_pmap_t *lrank_to_lpid; /* Array mapping a local rank to local - * process number */ - int is_local_dense_monotonic; /* see NOTE-G1 */ - - /* We may want some additional data for the RMA syncrhonization calls */ - /* Other, device-specific information */ -#ifdef MPID_DEV_GROUP_DECL - MPID_DEV_GROUP_DECL -#endif - MPIR_Session * session_ptr; /* Pointer to session to which this group belongs */ + int rank; /* rank of this process relative to this group */ + struct MPIR_Pmap pmap; + MPIR_Session *session_ptr; /* Pointer to session to which this group belongs */ }; /* NOTE-G1: is_local_dense_monotonic will be true iff the group meets the @@ -104,10 +139,8 @@ extern MPIR_Group *const MPIR_Group_empty; #define MPIR_Group_release_ref(_group, _inuse) \ do { MPIR_Object_release_ref(_group, _inuse); } while (0) -void MPII_Group_setup_lpid_list(MPIR_Group *); int MPIR_Group_check_valid_ranks(MPIR_Group *, const int[], int); int MPIR_Group_check_valid_ranges(MPIR_Group *, int[][3], int); -void MPIR_Group_setup_lpid_pairs(MPIR_Group *, MPIR_Group *); int MPIR_Group_create(int, MPIR_Group **); int MPIR_Group_release(MPIR_Group * group_ptr); @@ -122,7 +155,4 @@ int MPIR_Group_check_subset(MPIR_Group * group_ptr, MPIR_Comm * comm_ptr); void MPIR_Group_set_session_ptr(MPIR_Group * group_ptr, MPIR_Session * session_out); int MPIR_Group_init(void); -/* internal functions */ -void MPII_Group_setup_lpid_list(MPIR_Group *); - #endif /* MPIR_GROUP_H_INCLUDED */ diff --git a/src/mpi/comm/comm_impl.c b/src/mpi/comm/comm_impl.c index a91c40fc018..9eafdb1d223 100644 --- a/src/mpi/comm/comm_impl.c +++ b/src/mpi/comm/comm_impl.c @@ -198,9 +198,6 @@ int MPII_Comm_create_calculate_mapping(MPIR_Group * group_ptr, * exactly the same as the ranks in comm world. */ - /* we examine the group's lpids in both the intracomm and non-comm_world cases */ - MPII_Group_setup_lpid_list(group_ptr); - /* Optimize for groups contained within MPI_COMM_WORLD. */ if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) { int wsize; diff --git a/src/mpi/group/group_impl.c b/src/mpi/group/group_impl.c index a347dd0b3fc..347facee4db 100644 --- a/src/mpi/group/group_impl.c +++ b/src/mpi/group/group_impl.c @@ -34,10 +34,22 @@ int MPIR_Group_free_impl(MPIR_Group * group_ptr) goto fn_exit; } +static int compar(const void *_a, const void *_b) +{ + MPIR_Lpid a = *(MPIR_Lpid *) _a; + MPIR_Lpid b = *(MPIR_Lpid *) _b; + if (a == b) { + return 0; + } else if (a < b) { + return -1; + } else { + return 1; + } +} + int MPIR_Group_compare_impl(MPIR_Group * group_ptr1, MPIR_Group * group_ptr2, int *result) { int mpi_errno = MPI_SUCCESS; - int g1_idx, g2_idx, size, i; /* See if their sizes are equal */ if (group_ptr1->size != group_ptr2->size) { @@ -45,39 +57,52 @@ int MPIR_Group_compare_impl(MPIR_Group * group_ptr1, MPIR_Group * group_ptr2, in goto fn_exit; } - /* Run through the lrank to lpid lists of each group in lpid order - * to see if the same processes are involved */ - g1_idx = group_ptr1->idx_of_first_lpid; - g2_idx = group_ptr2->idx_of_first_lpid; - /* If the lpid list hasn't been created, do it now */ - if (g1_idx < 0) { - MPII_Group_setup_lpid_list(group_ptr1); - g1_idx = group_ptr1->idx_of_first_lpid; - } - if (g2_idx < 0) { - MPII_Group_setup_lpid_list(group_ptr2); - g2_idx = group_ptr2->idx_of_first_lpid; - } - while (g1_idx >= 0 && g2_idx >= 0) { - if (group_ptr1->lrank_to_lpid[g1_idx].lpid != group_ptr2->lrank_to_lpid[g2_idx].lpid) { + int n = group_ptr1->size; + struct MPIR_Pmap *pmap1 = &group_ptr1->pmap; + struct MPIR_Pmap *pmap2 = &group_ptr2->pmap; + if (!pmap1->use_map && !pmap2->use_map) { + /* just compare the stride parameters */ + if (pmap1->u.stride.offset != pmap2->u.stride.offset || + pmap1->u.stride.stride != pmap2->u.stride.stride || + pmap1->u.stride.blocksize != pmap2->u.stride.blocksize) { *result = MPI_UNEQUAL; + } else { + *result = MPI_IDENT; + } + } else { + /* check whether it's identical first */ + bool is_ident = true; + for (int rank = 0; rank < n; rank++) { + if (MPIR_Pmap_rank_to_lpid(pmap1, rank) != MPIR_Pmap_rank_to_lpid(pmap2, rank)) { + is_ident = false; + break; + } + } + if (is_ident) { + *result = MPI_IDENT; goto fn_exit; } - g1_idx = group_ptr1->lrank_to_lpid[g1_idx].next_lpid; - g2_idx = group_ptr2->lrank_to_lpid[g2_idx].next_lpid; - } - /* See if the processes are in the same order by rank */ - size = group_ptr1->size; - for (i = 0; i < size; i++) { - if (group_ptr1->lrank_to_lpid[i].lpid != group_ptr2->lrank_to_lpid[i].lpid) { - *result = MPI_SIMILAR; - goto fn_exit; + /* sort both pmaps and compare. O(n lg(n)) */ + MPIR_Lpid *map1 = MPL_malloc(n * sizeof(MPIR_Lpid), MPL_MEM_OTHER); + MPIR_Lpid *map2 = MPL_malloc(n * sizeof(MPIR_Lpid), MPL_MEM_OTHER); + for (int rank = 0; rank < n; rank++) { + map1[rank] = MPIR_Pmap_rank_to_lpid(pmap1, rank); + map2[rank] = MPIR_Pmap_rank_to_lpid(pmap2, rank); + } + qsort(map1, n, sizeof(MPIR_Lpid), compar); + + *result = MPI_SIMILAR; + for (int i = 0; i < n; i++) { + if (map1[i] != map2[i]) { + *result = MPI_UNEQUAL; + break; + } } - } - /* If we reach here, the groups are identical */ - *result = MPI_IDENT; + MPL_free(map1); + MPL_free(map2); + } fn_exit: return mpi_errno; @@ -87,71 +112,16 @@ int MPIR_Group_translate_ranks_impl(MPIR_Group * gp1, int n, const int ranks1[], MPIR_Group * gp2, int ranks2[]) { int mpi_errno = MPI_SUCCESS; - int i, g2_idx; - uint64_t l1_pid, l2_pid; - - MPL_DBG_MSG_S(MPIR_DBG_OTHER, VERBOSE, "gp2->is_local_dense_monotonic=%s", - (gp2->is_local_dense_monotonic ? "TRUE" : "FALSE")); - /* Initialize the output ranks */ - for (i = 0; i < n; i++) - ranks2[i] = MPI_UNDEFINED; - - if (gp2->size > 0 && gp2->is_local_dense_monotonic) { - /* g2 probably == group_of(MPI_COMM_WORLD); use fast, constant-time lookup */ - uint64_t lpid_offset = gp2->lrank_to_lpid[0].lpid; - - for (i = 0; i < n; ++i) { - uint64_t g1_lpid; - - if (ranks1[i] == MPI_PROC_NULL) { - ranks2[i] = MPI_PROC_NULL; - continue; - } - /* "adjusted" lpid from g1 */ - g1_lpid = gp1->lrank_to_lpid[ranks1[i]].lpid - lpid_offset; - if (g1_lpid < gp2->size) { - ranks2[i] = g1_lpid; - } - /* else leave UNDEFINED */ - } - } else { - /* general, slow path; lookup time is dependent on the user-provided rank values! */ - g2_idx = gp2->idx_of_first_lpid; - if (g2_idx < 0) { - MPII_Group_setup_lpid_list(gp2); - g2_idx = gp2->idx_of_first_lpid; - } - if (g2_idx >= 0) { - /* g2_idx can be < 0 if the g2 group is empty */ - l2_pid = gp2->lrank_to_lpid[g2_idx].lpid; - for (i = 0; i < n; i++) { - if (ranks1[i] == MPI_PROC_NULL) { - ranks2[i] = MPI_PROC_NULL; - continue; - } - l1_pid = gp1->lrank_to_lpid[ranks1[i]].lpid; - /* Search for this l1_pid in group2. Use the following - * optimization: start from the last position in the lpid list - * if possible. A more sophisticated version could use a - * tree based or even hashed search to speed the translation. */ - if (l1_pid < l2_pid || g2_idx < 0) { - /* Start over from the beginning */ - g2_idx = gp2->idx_of_first_lpid; - l2_pid = gp2->lrank_to_lpid[g2_idx].lpid; - } - while (g2_idx >= 0 && l1_pid > l2_pid) { - g2_idx = gp2->lrank_to_lpid[g2_idx].next_lpid; - if (g2_idx >= 0) - l2_pid = gp2->lrank_to_lpid[g2_idx].lpid; - else - l2_pid = (uint64_t) - 1; - } - if (l1_pid == l2_pid) - ranks2[i] = g2_idx; - } + for (int i = 0; i < n; i++) { + if (ranks1[i] == MPI_PROC_NULL) { + ranks2[i] = MPI_PROC_NULL; + continue; } + MPIR_Lpid lpid = MPIR_Pmap_rank_to_lpid(&gp1->pmap, ranks1[i]); + ranks2[i] = MPIR_Pmap_lpid_to_rank(&gp2->pmap, lpid); } + return mpi_errno; } @@ -159,39 +129,36 @@ int MPIR_Group_excl_impl(MPIR_Group * group_ptr, int n, const int ranks[], MPIR_Group ** new_group_ptr) { int mpi_errno = MPI_SUCCESS; - int size, i, newi; - int *flags = NULL; - MPIR_FUNC_ENTER; - size = group_ptr->size; - - MPIR_Lpid *map = MPL_malloc((size - n) * sizeof(MPIR_Lpid), MPL_MEM_GROUP); + int nnew = 0; + MPIR_Lpid *map = MPL_malloc((group_ptr->size - n) * sizeof(MPIR_Lpid), MPL_MEM_OTHER); MPIR_ERR_CHKANDJUMP(!map, mpi_errno, MPI_ERR_OTHER, "**nomem"); /* Use flag fields to mark the members to *exclude* . */ - flags = MPL_calloc(size, sizeof(int), MPL_MEM_OTHER); - - for (i = 0; i < n; i++) { + int *flags; + flags = MPL_calloc(group_ptr->size, sizeof(int), MPL_MEM_OTHER); + for (int i = 0; i < n; i++) { flags[ranks[i]] = 1; } int myrank = MPI_UNDEFINED; - newi = 0; - for (i = 0; i < size; i++) { + for (int i = 0; i < group_ptr->size; i++) { if (flags[i] == 0) { - map[newi] = MPIR_Group_lookup(group_ptr, i); - if (group_ptr->rank == i) - myrank = newi; - newi++; + if (i == group_ptr->rank) { + myrank = nnew; + } + map[nnew++] = MPIR_Group_lookup(group_ptr, i); } } - mpi_errno = MPIR_Group_create_map(size - n, myrank, group_ptr->session_ptr, map, new_group_ptr); + MPL_free(flags); + MPIR_Assert(nnew == group_ptr->size - n); + + mpi_errno = MPIR_Group_create_map(nnew, myrank, group_ptr->session_ptr, map, new_group_ptr); MPIR_ERR_CHECK(mpi_errno); fn_exit: - MPL_free(flags); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: @@ -209,17 +176,21 @@ int MPIR_Group_incl_impl(MPIR_Group * group_ptr, int n, const int ranks[], goto fn_exit; } - MPIR_Lpid *map = MPL_malloc(n * sizeof(MPIR_Lpid), MPL_MEM_GROUP); + int nnew = 0; + MPIR_Lpid *map = MPL_malloc(n * sizeof(MPIR_Lpid), MPL_MEM_OTHER); MPIR_ERR_CHKANDJUMP(!map, mpi_errno, MPI_ERR_OTHER, "**nomem"); int myrank = MPI_UNDEFINED; for (int i = 0; i < n; i++) { - map[i] = MPIR_Group_lookup(group_ptr, i); - if (ranks[i] == group_ptr->rank) - myrank = i; + if (ranks[i] == group_ptr->rank) { + myrank = nnew; + } + map[nnew++] = MPIR_Group_lookup(group_ptr, ranks[i]); } - mpi_errno = MPIR_Group_create_map(n, myrank, group_ptr->session_ptr, map, new_group_ptr); + MPIR_Assert(nnew == n); + + mpi_errno = MPIR_Group_create_map(nnew, myrank, group_ptr->session_ptr, map, new_group_ptr); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -233,73 +204,56 @@ int MPIR_Group_range_excl_impl(MPIR_Group * group_ptr, int n, int ranges[][3], MPIR_Group ** new_group_ptr) { int mpi_errno = MPI_SUCCESS; - int size, i, j, k, nnew, first, last, stride; - int *flags = NULL; - MPIR_FUNC_ENTER; - /* Compute size, assuming that included ranks are valid (and distinct) */ - size = group_ptr->size; - nnew = 0; - for (i = 0; i < n; i++) { - first = ranges[i][0]; - last = ranges[i][1]; - stride = ranges[i][2]; + + int count = 0; + for (int i = 0; i < n; i++) { + int first = ranges[i][0]; + int last = ranges[i][1]; + int stride = ranges[i][2]; /* works for stride of either sign. Error checking above * has already guaranteed stride != 0 */ - nnew += 1 + (last - first) / stride; + count += 1 + (last - first) / stride; } - nnew = size - nnew; - if (nnew == 0) { - *new_group_ptr = MPIR_Group_empty; - goto fn_exit; - } - - MPIR_Lpid *map = MPL_malloc(nnew * sizeof(MPIR_Lpid), MPL_MEM_GROUP); + int nnew = 0; + MPIR_Lpid *map = MPL_malloc((group_ptr->size - count) * sizeof(MPIR_Lpid), MPL_MEM_OTHER); MPIR_ERR_CHKANDJUMP(!map, mpi_errno, MPI_ERR_OTHER, "**nomem"); - int myrank = MPI_UNDEFINED; - - /* Group members are taken in rank order from the original group, - * with the specified members removed. Use the flag array for that - * purpose. If this was a critical routine, we could use the - * flag values set in the error checking part, if the error checking - * was enabled *and* we are not MPI_THREAD_MULTIPLE, but since this - * is a low-usage routine, we haven't taken that optimization. */ - flags = MPL_calloc(size, sizeof(int), MPL_MEM_OTHER); - - for (i = 0; i < n; i++) { - first = ranges[i][0]; - last = ranges[i][1]; - stride = ranges[i][2]; + int *flags; + flags = MPL_calloc(group_ptr->size, sizeof(int), MPL_MEM_OTHER); + for (int i = 0; i < n; i++) { + int first = ranges[i][0]; + int last = ranges[i][1]; + int stride = ranges[i][2]; if (stride > 0) { - for (j = first; j <= last; j += stride) { + for (int j = first; j <= last; j += stride) { flags[j] = 1; } } else { - for (j = first; j >= last; j += stride) { + for (int j = first; j >= last; j += stride) { flags[j] = 1; } } } - /* Now, run through the group and pick up the members that were - * not excluded */ - k = 0; - for (i = 0; i < size; i++) { - if (!flags[i]) { - map[k] = MPIR_Group_lookup(group_ptr, i); - if (group_ptr->rank == i) { - myrank = k; + + int myrank = MPI_UNDEFINED; + for (int i = 0; i < group_ptr->size; i++) { + if (flags[i] == 0) { + if (i == group_ptr->rank) { + myrank = nnew; } - k++; + map[nnew++] = MPIR_Group_lookup(group_ptr, i); } } + MPL_free(flags); + MPIR_Assert(nnew == group_ptr->size - count); + mpi_errno = MPIR_Group_create_map(nnew, myrank, group_ptr->session_ptr, map, new_group_ptr); MPIR_ERR_CHECK(mpi_errno); fn_exit: - MPL_free(flags); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: @@ -310,55 +264,42 @@ int MPIR_Group_range_incl_impl(MPIR_Group * group_ptr, int n, int ranges[][3], MPIR_Group ** new_group_ptr) { int mpi_errno = MPI_SUCCESS; - int first, last, stride, nnew, i, j, k; - MPIR_FUNC_ENTER; - /* Compute size, assuming that included ranks are valid (and distinct) */ - nnew = 0; - for (i = 0; i < n; i++) { - first = ranges[i][0]; - last = ranges[i][1]; - stride = ranges[i][2]; + int count = 0; + for (int i = 0; i < n; i++) { + int first = ranges[i][0]; + int last = ranges[i][1]; + int stride = ranges[i][2]; /* works for stride of either sign. Error checking above * has already guaranteed stride != 0 */ - nnew += 1 + (last - first) / stride; - } - - if (nnew == 0) { - *new_group_ptr = MPIR_Group_empty; - goto fn_exit; + count += 1 + (last - first) / stride; } - MPIR_Lpid *map = MPL_malloc(nnew * sizeof(MPIR_Lpid), MPL_MEM_GROUP); - MPIR_ERR_CHKANDJUMP(!map, mpi_errno, MPI_ERR_OTHER, "**nomem"); + int nnew = 0; int myrank = MPI_UNDEFINED; + MPIR_Lpid *map = MPL_malloc(count * sizeof(MPIR_Lpid), MPL_MEM_OTHER); + MPIR_ERR_CHKANDJUMP(!map, mpi_errno, MPI_ERR_OTHER, "**nomem"); - /* Group members taken in order specified by the range array */ - /* This could be integrated with the error checking, but since this - * is a low-usage routine, we haven't taken that optimization */ - k = 0; - for (i = 0; i < n; i++) { - first = ranges[i][0]; - last = ranges[i][1]; - stride = ranges[i][2]; - if (stride > 0) { - for (j = first; j <= last; j += stride) { - (*new_group_ptr)->lrank_to_lpid[k].lpid = group_ptr->lrank_to_lpid[j].lpid; - if (j == group_ptr->rank) - (*new_group_ptr)->rank = k; - k++; - } - } else { - for (j = first; j >= last; j += stride) { - (*new_group_ptr)->lrank_to_lpid[k].lpid = group_ptr->lrank_to_lpid[j].lpid; - if (j == group_ptr->rank) - (*new_group_ptr)->rank = k; - k++; + for (int i = 0; i < n; i++) { + int first = ranges[i][0]; + int last = ranges[i][1]; + int stride = ranges[i][2]; + if (stride < 0) { + first = ranges[i][1]; + last = ranges[i][0]; + stride = -stride; + } + for (int j = first; j <= last; j += stride) { + if (j == group_ptr->rank) { + myrank = nnew; } + map[nnew++] = MPIR_Group_lookup(group_ptr, j); } } + MPIR_Assert(nnew == count); + mpi_errno = MPIR_Group_create_map(nnew, myrank, group_ptr->session_ptr, map, new_group_ptr); MPIR_ERR_CHECK(mpi_errno); @@ -373,66 +314,31 @@ int MPIR_Group_difference_impl(MPIR_Group * group_ptr1, MPIR_Group * group_ptr2, MPIR_Group ** new_group_ptr) { int mpi_errno = MPI_SUCCESS; - int size1, i, k, g1_idx, g2_idx, nnew; - uint64_t l1_pid, l2_pid; - int *flags = NULL; - MPIR_FUNC_ENTER; - /* Return a group consisting of the members of group1 that are *not* - * in group2 */ - size1 = group_ptr1->size; - /* Insure that the lpid lists are setup */ - MPIR_Group_setup_lpid_pairs(group_ptr1, group_ptr2); - - flags = MPL_calloc(size1, sizeof(int), MPL_MEM_OTHER); - - g1_idx = group_ptr1->idx_of_first_lpid; - g2_idx = group_ptr2->idx_of_first_lpid; - - nnew = size1; - while (g1_idx >= 0 && g2_idx >= 0) { - l1_pid = group_ptr1->lrank_to_lpid[g1_idx].lpid; - l2_pid = group_ptr2->lrank_to_lpid[g2_idx].lpid; - if (l1_pid < l2_pid) { - g1_idx = group_ptr1->lrank_to_lpid[g1_idx].next_lpid; - } else if (l1_pid > l2_pid) { - g2_idx = group_ptr2->lrank_to_lpid[g2_idx].next_lpid; - } else { - /* Equal */ - flags[g1_idx] = 1; - g1_idx = group_ptr1->lrank_to_lpid[g1_idx].next_lpid; - g2_idx = group_ptr2->lrank_to_lpid[g2_idx].next_lpid; - nnew--; - } - } - /* Create the group */ - if (nnew == 0) { - /* See 5.3.2, Group Constructors. For many group routines, - * the standard explicitly says to return MPI_GROUP_EMPTY; - * for others it is implied */ - *new_group_ptr = MPIR_Group_empty; - goto fn_exit; - } else { - MPIR_Lpid *map = MPL_malloc(nnew * sizeof(MPIR_Lpid), MPL_MEM_GROUP); - MPIR_ERR_CHKANDJUMP(!map, mpi_errno, MPI_ERR_OTHER, "**nomem"); - int myrank = MPI_UNDEFINED; - - k = 0; - for (i = 0; i < size1; i++) { - if (!flags[i]) { - map[k] = MPIR_Group_lookup(group_ptr1, i); - if (i == group_ptr1->rank) - myrank = k; - k++; + + /* FIXME: check session_ptr */ + + int nnew = 0; + MPIR_Lpid *map = MPL_malloc(group_ptr1->size * sizeof(MPIR_Lpid), MPL_MEM_OTHER); + + int myrank = MPI_UNDEFINED; + /* For each rank in group1, search it in group2. */ + for (int rank = 0; rank < group_ptr1->size; rank++) { + MPIR_Lpid lpid = MPIR_Group_lookup(group_ptr1, rank); + if (MPI_UNDEFINED == MPIR_Pmap_lpid_to_rank(&group_ptr2->pmap, lpid)) { + /* not found */ + if (rank == group_ptr1->rank) { + myrank = nnew; } + map[nnew++] = lpid; } - mpi_errno = MPIR_Group_create_map(nnew, myrank, group_ptr1->session_ptr, map, - new_group_ptr); - MPIR_ERR_CHECK(mpi_errno); } + /* Create the group */ + mpi_errno = MPIR_Group_create_map(nnew, myrank, group_ptr1->session_ptr, map, new_group_ptr); + MPIR_ERR_CHECK(mpi_errno); + fn_exit: - MPL_free(flags); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: @@ -443,63 +349,31 @@ int MPIR_Group_intersection_impl(MPIR_Group * group_ptr1, MPIR_Group * group_ptr MPIR_Group ** new_group_ptr) { int mpi_errno = MPI_SUCCESS; - int size1, i, k, g1_idx, g2_idx, nnew; - uint64_t l1_pid, l2_pid; - int *flags = NULL; - MPIR_FUNC_ENTER; - /* Return a group consisting of the members of group1 that are - * in group2 */ - size1 = group_ptr1->size; - /* Insure that the lpid lists are setup */ - MPIR_Group_setup_lpid_pairs(group_ptr1, group_ptr2); - - flags = MPL_calloc(size1, sizeof(int), MPL_MEM_OTHER); - - g1_idx = group_ptr1->idx_of_first_lpid; - g2_idx = group_ptr2->idx_of_first_lpid; - - nnew = 0; - while (g1_idx >= 0 && g2_idx >= 0) { - l1_pid = group_ptr1->lrank_to_lpid[g1_idx].lpid; - l2_pid = group_ptr2->lrank_to_lpid[g2_idx].lpid; - if (l1_pid < l2_pid) { - g1_idx = group_ptr1->lrank_to_lpid[g1_idx].next_lpid; - } else if (l1_pid > l2_pid) { - g2_idx = group_ptr2->lrank_to_lpid[g2_idx].next_lpid; - } else { - /* Equal */ - flags[g1_idx] = 1; - g1_idx = group_ptr1->lrank_to_lpid[g1_idx].next_lpid; - g2_idx = group_ptr2->lrank_to_lpid[g2_idx].next_lpid; - nnew++; - } - } - /* Create the group. Handle the trivial case first */ - if (nnew == 0) { - *new_group_ptr = MPIR_Group_empty; - goto fn_exit; - } - MPIR_Lpid *map = MPL_malloc(nnew * sizeof(MPIR_Lpid), MPL_MEM_GROUP); - MPIR_ERR_CHKANDJUMP(!map, mpi_errno, MPI_ERR_OTHER, "**nomem"); - int myrank = MPI_UNDEFINED; + /* FIXME: check session_ptr */ - k = 0; - for (i = 0; i < size1; i++) { - if (flags[i]) { - map[k] = MPIR_Group_lookup(group_ptr1, i); - if (i == group_ptr1->rank) - myrank = k; - k++; + /* For each rank in group1, search it in group2. */ + int nnew = 0; + MPIR_Lpid *map = MPL_malloc(group_ptr1->size * sizeof(MPIR_Lpid), MPL_MEM_OTHER); + + int myrank = MPI_UNDEFINED; + for (int rank = 0; rank < group_ptr1->size; rank++) { + MPIR_Lpid lpid = MPIR_Group_lookup(group_ptr1, rank); + if (MPI_UNDEFINED != MPIR_Pmap_lpid_to_rank(&group_ptr2->pmap, lpid)) { + /* found */ + if (rank == group_ptr1->rank) { + myrank = nnew; + } + map[nnew++] = lpid; } } + /* Create the group */ mpi_errno = MPIR_Group_create_map(nnew, myrank, group_ptr1->session_ptr, map, new_group_ptr); MPIR_ERR_CHECK(mpi_errno); fn_exit: - MPL_free(flags); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: @@ -510,94 +384,26 @@ int MPIR_Group_union_impl(MPIR_Group * group_ptr1, MPIR_Group * group_ptr2, MPIR_Group ** new_group_ptr) { int mpi_errno = MPI_SUCCESS; - int g1_idx, g2_idx, nnew, i, k, size1, size2; - uint64_t mylpid; - int *flags = NULL; - MPIR_FUNC_ENTER; - /* Determine the size of the new group. The new group consists of all - * members of group1 plus the members of group2 that are not in group1. - */ - g1_idx = group_ptr1->idx_of_first_lpid; - g2_idx = group_ptr2->idx_of_first_lpid; + /* FIXME: check session_ptr */ - /* If the lpid list hasn't been created, do it now */ - if (g1_idx < 0) { - MPII_Group_setup_lpid_list(group_ptr1); - g1_idx = group_ptr1->idx_of_first_lpid; - } - if (g2_idx < 0) { - MPII_Group_setup_lpid_list(group_ptr2); - g2_idx = group_ptr2->idx_of_first_lpid; - } - nnew = group_ptr1->size; - - /* Clear the flag bits on the second group. The flag is set if - * a member of the second group belongs to the union */ - size2 = group_ptr2->size; - flags = MPL_calloc(size2, sizeof(int), MPL_MEM_OTHER); - - /* Loop through the lists that are ordered by lpid (local process - * id) to detect which processes in group 2 are not in group 1 - */ - while (g1_idx >= 0 && g2_idx >= 0) { - uint64_t l1_pid, l2_pid; - l1_pid = group_ptr1->lrank_to_lpid[g1_idx].lpid; - l2_pid = group_ptr2->lrank_to_lpid[g2_idx].lpid; - if (l1_pid > l2_pid) { - nnew++; - flags[g2_idx] = 1; - g2_idx = group_ptr2->lrank_to_lpid[g2_idx].next_lpid; - } else if (l1_pid == l2_pid) { - g1_idx = group_ptr1->lrank_to_lpid[g1_idx].next_lpid; - g2_idx = group_ptr2->lrank_to_lpid[g2_idx].next_lpid; - } else { - /* l1 < l2 */ - g1_idx = group_ptr1->lrank_to_lpid[g1_idx].next_lpid; - } - } - /* If we hit the end of group1, add the remaining members of group 2 */ - while (g2_idx >= 0) { - nnew++; - flags[g2_idx] = 1; - g2_idx = group_ptr2->lrank_to_lpid[g2_idx].next_lpid; - } - - if (nnew == 0) { - *new_group_ptr = MPIR_Group_empty; - goto fn_exit; - } - - MPIR_Lpid *map = MPL_malloc(nnew * sizeof(MPIR_Lpid), MPL_MEM_GROUP); - MPIR_ERR_CHKANDJUMP(!map, mpi_errno, MPI_ERR_OTHER, "**nomem"); + /* For each rank in group2, search it in group1. Put those processes are + * unique in group2 in map2.*/ + int nnew = group_ptr1->size; + MPIR_Lpid *map = MPL_malloc((group_ptr1->size + group_ptr2->size) * sizeof(MPIR_Lpid), + MPL_MEM_OTHER); + MPIR_Lpid *map2 = map + group_ptr1->size; - /* If this process is in group1, then we can set the rank now. - * If we are not in this group, this assignment will set the - * current rank to MPI_UNDEFINED */ int myrank = group_ptr1->rank; - - /* Add group1 */ - size1 = group_ptr1->size; - for (i = 0; i < size1; i++) { - map[i] = MPIR_Group_lookup(group_ptr1, i); - } - - /* Add members of group2 that are not in group 1 */ - - if (group_ptr1->rank == MPI_UNDEFINED && group_ptr2->rank >= 0) { - mylpid = group_ptr2->lrank_to_lpid[group_ptr2->rank].lpid; - } else { - mylpid = (uint64_t) - 2; - } - k = size1; - for (i = 0; i < size2; i++) { - if (flags[i]) { - map[k] = MPIR_Group_lookup(group_ptr2, i); - if (myrank == MPI_UNDEFINED && i == group_ptr2->rank) { - myrank = k; + for (int rank = 0; rank < group_ptr2->size; rank++) { + MPIR_Lpid lpid = MPIR_Group_lookup(group_ptr2, rank); + if (MPI_UNDEFINED == MPIR_Pmap_lpid_to_rank(&group_ptr1->pmap, lpid)) { + /* not found */ + if (rank == group_ptr2->rank) { + myrank = nnew; } - k++; + map[nnew++] = lpid; } } @@ -605,7 +411,6 @@ int MPIR_Group_union_impl(MPIR_Group * group_ptr1, MPIR_Group * group_ptr2, MPIR_ERR_CHECK(mpi_errno); fn_exit: - MPL_free(flags); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: diff --git a/src/mpi/group/grouputil.c b/src/mpi/group/grouputil.c index d98487c7b29..1d8d4ee0607 100644 --- a/src/mpi/group/grouputil.c +++ b/src/mpi/group/grouputil.c @@ -28,10 +28,9 @@ int MPIR_Group_init(void) MPIR_Object_set_ref(&MPIR_Group_builtin[0], 1); MPIR_Group_builtin[0].size = 0; MPIR_Group_builtin[0].rank = MPI_UNDEFINED; - MPIR_Group_builtin[0].idx_of_first_lpid = -1; - MPIR_Group_builtin[0].lrank_to_lpid = NULL; + MPIR_Group_builtin[0].session_ptr = NULL; + memset(&MPIR_Group_builtin[0].pmap, 0, sizeof(struct MPIR_Pmap)); - /* TODO hook for device here? */ return mpi_errno; } @@ -44,7 +43,9 @@ int MPIR_Group_release(MPIR_Group * group_ptr) MPIR_Group_release_ref(group_ptr, &inuse); if (!inuse) { /* Only if refcount is 0 do we actually free. */ - MPL_free(group_ptr->lrank_to_lpid); + if (group_ptr->pmap.use_map) { + MPL_free(group_ptr->pmap.u.map); + } if (group_ptr->session_ptr != NULL) { /* Release session */ MPIR_Session_release(group_ptr->session_ptr); @@ -73,24 +74,14 @@ int MPIR_Group_create(int nproc, MPIR_Group ** new_group_ptr) } /* --END ERROR HANDLING-- */ MPIR_Object_set_ref(*new_group_ptr, 1); - (*new_group_ptr)->lrank_to_lpid = - (MPII_Group_pmap_t *) MPL_calloc(nproc, sizeof(MPII_Group_pmap_t), MPL_MEM_GROUP); - /* --BEGIN ERROR HANDLING-- */ - if (!(*new_group_ptr)->lrank_to_lpid) { - MPIR_Handle_obj_free(&MPIR_Group_mem, *new_group_ptr); - *new_group_ptr = NULL; - MPIR_CHKMEM_SETERR(mpi_errno, nproc * sizeof(MPII_Group_pmap_t), "newgroup->lrank_to_lpid"); - return mpi_errno; - } - /* --END ERROR HANDLING-- */ - (*new_group_ptr)->size = nproc; - /* Make sure that there is no question that the list of ranks sorted - * by pids is marked as uninitialized */ - (*new_group_ptr)->idx_of_first_lpid = -1; - - (*new_group_ptr)->is_local_dense_monotonic = FALSE; + /* initialize fields */ + (*new_group_ptr)->size = nproc; + (*new_group_ptr)->rank = MPI_UNDEFINED; (*new_group_ptr)->session_ptr = NULL; + memset(&(*new_group_ptr)->pmap, 0, sizeof(struct MPIR_Pmap)); + (*new_group_ptr)->pmap.size = nproc; + return mpi_errno; } @@ -98,24 +89,28 @@ int MPIR_Group_create_map(int size, int rank, MPIR_Session * session_ptr, MPIR_L MPIR_Group ** new_group_ptr) { int mpi_errno = MPI_SUCCESS; - MPIR_Group *newgrp; - - MPIR_Assert(size > 0); - mpi_errno = MPIR_Group_create(size, &newgrp); - MPIR_ERR_CHECK(mpi_errno); + if (size == 0) { + /* See 5.3.2, Group Constructors. For many group routines, + * the standard explicitly says to return MPI_GROUP_EMPTY; + * for others it is implied */ + MPL_free(map); + *new_group_ptr = MPIR_Group_empty; + goto fn_exit; + } else { + mpi_errno = MPIR_Group_create(size, new_group_ptr); + MPIR_ERR_CHECK(mpi_errno); - newgrp->rank = rank; - MPIR_Group_set_session_ptr(newgrp, session_ptr); + (*new_group_ptr)->rank = rank; + MPIR_Group_set_session_ptr((*new_group_ptr), session_ptr); - for (int i = 0; i < size; i++) { - newgrp->lrank_to_lpid[i].lpid = map[i]; + struct MPIR_Pmap *pmap = &(*new_group_ptr)->pmap; + /* TODO check whether it's strided, or resize map */ + pmap->use_map = true; + pmap->u.map = map; } - *new_group_ptr = newgrp; - fn_exit: - MPL_free(map); return mpi_errno; fn_fail: goto fn_exit; @@ -126,28 +121,27 @@ int MPIR_Group_create_stride(int size, int rank, MPIR_Session * session_ptr, MPIR_Group ** new_group_ptr) { int mpi_errno = MPI_SUCCESS; - MPIR_Group *newgrp; - MPIR_Assert(size > 0); - - mpi_errno = MPIR_Group_create(size, &newgrp); - MPIR_ERR_CHECK(mpi_errno); + if (size == 0) { + /* See 5.3.2, Group Constructors. For many group routines, + * the standard explicitly says to return MPI_GROUP_EMPTY; + * for others it is implied */ + *new_group_ptr = MPIR_Group_empty; + goto fn_exit; + } else { + mpi_errno = MPIR_Group_create(size, new_group_ptr); + MPIR_ERR_CHECK(mpi_errno); - newgrp->rank = rank; - MPIR_Group_set_session_ptr(newgrp, session_ptr); + (*new_group_ptr)->rank = rank; + MPIR_Group_set_session_ptr((*new_group_ptr), session_ptr); - MPIR_Lpid lpid = offset; - int i = 0; - while (i < size) { - for (int j = 0; j < blocksize; j++) { - newgrp->lrank_to_lpid[i + j].lpid = lpid + j; - } - i += blocksize; - lpid += stride; + struct MPIR_Pmap *pmap = &(*new_group_ptr)->pmap; + pmap->use_map = false; + pmap->u.stride.offset = offset; + pmap->u.stride.stride = stride; + pmap->u.stride.blocksize = blocksize; } - *new_group_ptr = newgrp; - fn_exit: return mpi_errno; fn_fail: @@ -156,133 +150,7 @@ int MPIR_Group_create_stride(int size, int rank, MPIR_Session * session_ptr, MPIR_Lpid MPIR_Group_lookup(MPIR_Group * group, int rank) { - return group->lrank_to_lpid[rank].lpid; -} - -/* - * return value is the first index in the list - * - * This "sorts" an lpid array by lpid value, using a simple merge sort - * algorithm. - * - * In actuality, it does not reorder the elements of maparray (these must remain - * in group rank order). Instead it builds the traversal order (in increasing - * lpid order) through the maparray given by the "next_lpid" fields. - */ -static int mergesort_lpidarray(MPII_Group_pmap_t maparray[], int n) -{ - int idx1, idx2, first_idx, cur_idx, next_lpid, idx2_offset; - - if (n == 2) { - if (maparray[0].lpid > maparray[1].lpid) { - first_idx = 1; - maparray[0].next_lpid = -1; - maparray[1].next_lpid = 0; - } else { - first_idx = 0; - maparray[0].next_lpid = 1; - maparray[1].next_lpid = -1; - } - return first_idx; - } - if (n == 1) { - maparray[0].next_lpid = -1; - return 0; - } - if (n == 0) - return -1; - - /* Sort each half */ - idx2_offset = n / 2; - idx1 = mergesort_lpidarray(maparray, n / 2); - idx2 = mergesort_lpidarray(maparray + idx2_offset, n - n / 2) + idx2_offset; - /* merge the results */ - /* There are three lists: - * first_idx - points to the HEAD of the sorted, merged list - * cur_idx - points to the LAST element of the sorted, merged list - * idx1 - points to the HEAD of one sorted list - * idx2 - points to the HEAD of the other sorted list - * - * We first identify the head element of the sorted list. We then - * take elements from the remaining lists. When one list is empty, - * we add the other list to the end of sorted list. - * - * The last wrinkle is that the next_lpid fields in maparray[idx2] - * are relative to n/2, not 0 (that is, a next_lpid of 1 is - * really 1 + n/2, relative to the beginning of maparray). - */ - /* Find the head element */ - if (maparray[idx1].lpid > maparray[idx2].lpid) { - first_idx = idx2; - idx2 = maparray[idx2].next_lpid + idx2_offset; - } else { - first_idx = idx1; - idx1 = maparray[idx1].next_lpid; - } - - /* Merge the lists until one is empty */ - cur_idx = first_idx; - while (idx1 >= 0 && idx2 >= 0) { - if (maparray[idx1].lpid > maparray[idx2].lpid) { - next_lpid = maparray[idx2].next_lpid; - if (next_lpid >= 0) - next_lpid += idx2_offset; - maparray[cur_idx].next_lpid = idx2; - cur_idx = idx2; - idx2 = next_lpid; - } else { - next_lpid = maparray[idx1].next_lpid; - maparray[cur_idx].next_lpid = idx1; - cur_idx = idx1; - idx1 = next_lpid; - } - } - /* Add whichever list remains */ - if (idx1 >= 0) { - maparray[cur_idx].next_lpid = idx1; - } else { - maparray[cur_idx].next_lpid = idx2; - /* Convert the rest of these next_lpid values to be - * relative to the beginning of maparray */ - while (idx2 >= 0) { - next_lpid = maparray[idx2].next_lpid; - if (next_lpid >= 0) { - next_lpid += idx2_offset; - maparray[idx2].next_lpid = next_lpid; - } - idx2 = next_lpid; - } - } - - return first_idx; -} - -/* - * Create a list of the lpids, in lpid order. - * - * Called by group_compare, group_translate_ranks, group_union - * - * In the case of a single main thread lock, the lock must - * be held on entry to this routine. This forces some of the routines - * noted above to hold the SINGLE_CS; which would otherwise not be required. - */ -void MPII_Group_setup_lpid_list(MPIR_Group * group_ptr) -{ - if (group_ptr->idx_of_first_lpid == -1) { - group_ptr->idx_of_first_lpid = - mergesort_lpidarray(group_ptr->lrank_to_lpid, group_ptr->size); - } -} - -void MPIR_Group_setup_lpid_pairs(MPIR_Group * group_ptr1, MPIR_Group * group_ptr2) -{ - /* If the lpid list hasn't been created, do it now */ - if (group_ptr1->idx_of_first_lpid < 0) { - MPII_Group_setup_lpid_list(group_ptr1); - } - if (group_ptr2->idx_of_first_lpid < 0) { - MPII_Group_setup_lpid_list(group_ptr2); - } + return MPIR_Pmap_rank_to_lpid(&group->pmap, rank); } #ifdef HAVE_ERROR_CHECKING @@ -420,54 +288,40 @@ int MPIR_Group_check_valid_ranges(MPIR_Group * group_ptr, int ranges[][3], int n int MPIR_Group_check_subset(MPIR_Group * group_ptr, MPIR_Comm * comm_ptr) { int mpi_errno = MPI_SUCCESS; - int g1_idx, g2_idx, l1_pid, l2_pid, i; - MPII_Group_pmap_t *vmap = 0; + int vsize = comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM ? comm_ptr->local_size : comm_ptr->remote_size; - MPIR_CHKLMEM_DECL(1); - - MPIR_Assert(group_ptr != NULL); - - MPIR_CHKLMEM_MALLOC(vmap, MPII_Group_pmap_t *, - vsize * sizeof(MPII_Group_pmap_t), mpi_errno, "", MPL_MEM_GROUP); /* Initialize the vmap */ - for (i = 0; i < vsize; i++) { - MPID_Comm_get_lpid(comm_ptr, i, &vmap[i].lpid, FALSE); - vmap[i].next_lpid = 0; + MPIR_Lpid *vmap = MPL_malloc(vsize * sizeof(MPIR_Lpid), MPL_MEM_GROUP); + for (int i = 0; i < vsize; i++) { + /* FIXME: MPID_Comm_get_lpid to be removed */ + uint64_t dev_lpid; + MPID_Comm_get_lpid(comm_ptr, i, &dev_lpid, FALSE); + MPIR_Assert((dev_lpid >> 32) == 0); + vmap[i] = dev_lpid; } - MPII_Group_setup_lpid_list(group_ptr); - g1_idx = group_ptr->idx_of_first_lpid; - g2_idx = mergesort_lpidarray(vmap, vsize); - MPL_DBG_MSG_FMT(MPIR_DBG_COMM, VERBOSE, (MPL_DBG_FDEST, - "initial indices: %d %d\n", g1_idx, g2_idx)); - while (g1_idx >= 0 && g2_idx >= 0) { - l1_pid = group_ptr->lrank_to_lpid[g1_idx].lpid; - l2_pid = vmap[g2_idx].lpid; - MPL_DBG_MSG_FMT(MPIR_DBG_COMM, VERBOSE, (MPL_DBG_FDEST, - "Lpids are %d, %d\n", l1_pid, l2_pid)); - if (l1_pid < l2_pid) { - /* If we have to advance g1, we didn't find a match, so - * that's an error. */ - break; - } else if (l1_pid > l2_pid) { - g2_idx = vmap[g2_idx].next_lpid; - } else { - /* Equal */ - g1_idx = group_ptr->lrank_to_lpid[g1_idx].next_lpid; - g2_idx = vmap[g2_idx].next_lpid; + for (int rank = 0; rank < group_ptr->size; rank++) { + MPIR_Lpid lpid = MPIR_Pmap_rank_to_lpid(&group_ptr->pmap, rank); + bool found = false; + for (int i = 0; i < vsize; i++) { + if (vmap[i] == lpid) { + found = true; + break; + } + } + if (!found) { + MPIR_ERR_SET1(mpi_errno, MPI_ERR_GROUP, "**groupnotincomm", + "**groupnotincomm %d", rank); + goto fn_fail; } - MPL_DBG_MSG_FMT(MPIR_DBG_COMM, VERBOSE, (MPL_DBG_FDEST, - "g1 = %d, g2 = %d\n", g1_idx, g2_idx)); - } - - if (g1_idx >= 0) { - MPIR_ERR_SET1(mpi_errno, MPI_ERR_GROUP, "**groupnotincomm", "**groupnotincomm %d", g1_idx); } - fn_fail: - MPIR_CHKLMEM_FREEALL(); + fn_exit: + MPL_free(vmap); return mpi_errno; + fn_fail: + goto fn_exit; } #endif /* HAVE_ERROR_CHECKING */