coll: add coll_group to collective interfaces #7103

Open. Wants to merge 27 commits into base: main.

Commits (27):
9340ed5
comm: store num_local and num_external in MPIR_Comm
hzhou Aug 13, 2024
417c2c5
comm: remove node_count
hzhou Aug 13, 2024
a24834b
comm/csel: remove reference to subcomms in csel prune_tree
hzhou Aug 19, 2024
132c188
coll: remove coll.pof2 field
hzhou Aug 21, 2024
c941fbe
comm: add MPIR_Subgroup
hzhou Aug 11, 2024
d9bf21b
coll: add macros to get rank/size with coll_group
hzhou Aug 22, 2024
105c7e0
coll: add coll_group argument to coll interfaces
hzhou Aug 16, 2024
3501b57
continue: add coll_group to collective interfaces
hzhou Aug 16, 2024
1c40a12
coll: add coll_group argument to MPIC/sched/TSP routines
hzhou Aug 16, 2024
b410435
continue: add coll_group in MPIC/sched/TSP routines
hzhou Aug 16, 2024
ca5917d
ch4: fallback to mpir if coll_group is non-zero
hzhou Aug 19, 2024
a683ec6
coll: add coll_group to csel signature
hzhou Aug 18, 2024
287aeb4
coll: threadcomm coll to use MPIR_SUBGROUP_THREADCOMM
hzhou Aug 18, 2024
f7f6ae1
coll: check coll_group in MPIR_Comm_is_parent_comm
hzhou Aug 18, 2024
75a0e67
coll: make non-compositional algorithm coll_group aware
hzhou Aug 18, 2024
bdd4532
coll: modify bcast_intra_smp to use subgroups
hzhou Aug 18, 2024
066e586
coll: avoid extra intra bcast in bcast_intra_smp
hzhou Aug 18, 2024
e374920
coll: modify smp algorithms to use MPIR_Subgroup
hzhou Aug 19, 2024
f98f2e7
mpir: replace subcomm usage with subgroups
hzhou Aug 20, 2024
1a16de9
coll/csel: omit prunning on communicator size
hzhou Aug 22, 2024
718d868
coll: refactor caching tree in the comm struct
hzhou Aug 22, 2024
1d79beb
coll: add coll_group to treealgo routines
hzhou Aug 22, 2024
b8c1f54
coll: add nogroup restriction to certain algorithms
hzhou Aug 23, 2024
ef79319
coll: check coll_group in MPIR_Sched_next_tag
hzhou Aug 24, 2024
138c760
coll: refactor barrier_intra_k_dissemination
hzhou Aug 24, 2024
7001a71
coll/allreduce: remove a leftover empty branch
hzhou Sep 12, 2024
37fc447
coll: patch allreduce_intra_recursive_multiplying.c
hzhou Nov 8, 2024
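
Taken together, the series threads a new coll_group argument through the collective, MPIC, sched, and TSP layers and (per the commit titles) replaces per-subcommunicator state with MPIR_Subgroup lookups. As a rough before/after sketch of the calling pattern, adapted from the diff below — the macro's definition lives in commit d9bf21b and is not shown here, and the reading of coll_group == 0 as "whole communicator" is inferred from commit ca5917d's rule that ch4 falls back to MPIR when coll_group is non-zero:

    /* Before: rank/size come straight from the communicator, and the MPIC
     * point-to-point helpers take only the communicator. */
    comm_size = comm_ptr->local_size;
    rank = comm_ptr->rank;
    mpi_errno = MPIC_Send(recvbuf, count, datatype, dst,
                          MPIR_ALLREDUCE_TAG, comm_ptr, errflag);

    /* After: every collective and MPIC helper also carries coll_group, and
     * rank/size are resolved against the selected subgroup (zero appears to
     * select the whole communicator). */
    MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size);
    mpi_errno = MPIC_Send(recvbuf, count, datatype, dst,
                          MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, errflag);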
Changes from 1 commit:

src/mpi/coll/allreduce/allreduce_intra_recursive_multiplying.c
124 changes: 61 additions & 63 deletions
@@ -12,55 +12,53 @@
*
* This algorithm is a generalization of the recursive doubling algorithm,
* and it has three stages. In the first stage, ranks above the nearest
- * power of k less than or equal to comm_size collapse their data to the
+ * power of k less than or equal to comm_size collapse their data to the
* lower ranks. The main stage proceeds with power-of-k ranks. In the main
- * stage, ranks exchange data within groups of size k in rounds with
- * increasing distance (k, k^2, ...). Lastly, those in the main stage
- * disperse the result back to the excluded ranks. Setting k according
- * to the network hierarchy (e.g., the number of NICs in a node) can

Review comment from hzhou (Contributor, Author): I wonder how these escaped the whitespace checker in the original PR.

+ * stage, ranks exchange data within groups of size k in rounds with
+ * increasing distance (k, k^2, ...). Lastly, those in the main stage
+ * disperse the result back to the excluded ranks. Setting k according
+ * to the network hierarchy (e.g., the number of NICs in a node) can
* improve performance.
*/


int MPIR_Allreduce_intra_recursive_multiplying(const void *sendbuf,
- void *recvbuf,
- MPI_Aint count,
- MPI_Datatype datatype,
- MPI_Op op,
- MPIR_Comm * comm_ptr,
- const int k,
- MPIR_Errflag_t errflag)
+ void *recvbuf,
+ MPI_Aint count,
+ MPI_Datatype datatype,
+ MPI_Op op,
+ MPIR_Comm * comm_ptr,
+ int coll_group, const int k, MPIR_Errflag_t errflag)
{
int mpi_errno = MPI_SUCCESS;
/* Ensure the op is commutative */

int comm_size, rank, virt_rank;
- comm_size = comm_ptr->local_size;
- rank = comm_ptr->rank;
+ MPIR_COLL_RANK_SIZE(comm_ptr, coll_group, rank, comm_size);
virt_rank = rank;

/* get nearest power-of-two less than or equal to comm_size */
int power = (int) (log(comm_size) / log(k));
int pofk = (int) lround(pow(k, power));

MPIR_CHKLMEM_DECL(2);
void *tmp_buf;

- /*Allocate for nb requests*/
+ /*Allocate for nb requests */
MPIR_Request **reqs;
int num_reqs = 0;
MPIR_CHKLMEM_MALLOC(reqs, MPIR_Request **, (2 * (k - 1) * sizeof(MPIR_Request *)), mpi_errno,
"reqs", MPL_MEM_BUFFER);
"reqs", MPL_MEM_BUFFER);

/* need to allocate temporary buffer to store incoming data */
MPI_Aint true_extent, true_lb, extent;
MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
MPIR_Datatype_get_extent_macro(datatype, extent);
MPI_Aint single_node_data_size = extent * count - (extent - true_extent);

MPIR_CHKLMEM_MALLOC(tmp_buf, void *, (k - 1) * count * single_node_data_size, mpi_errno,
"temporary buffer", MPL_MEM_BUFFER);

/* adjust for potential negative lower bound in datatype */
tmp_buf = (void *) ((char *) tmp_buf - true_lb);

@@ -82,34 +80,33 @@ int MPIR_Allreduce_intra_recursive_multiplying(const void *sendbuf,
int pre_dst = rank % pofk;
/* This is follower so send data */
mpi_errno = MPIC_Send(recvbuf, count, datatype,
- pre_dst, MPIR_ALLREDUCE_TAG, comm_ptr, errflag);
+ pre_dst, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, errflag);
MPIR_ERR_CHECK(mpi_errno);
/* Set virtual rank so this rank is not used in main stage */
virt_rank = -1;
} else {
/* Receive data from all those greater than pofk */
for (int pre_src = (rank % pofk) + pofk; pre_src < comm_size; pre_src += pofk) {
- mpi_errno = MPIC_Irecv(((char *)tmp_buf) + num_reqs * count * extent, count,
- datatype, pre_src, MPIR_ALLREDUCE_TAG, comm_ptr,
- &reqs[num_reqs]);
+ mpi_errno = MPIC_Irecv(((char *) tmp_buf) + num_reqs * count * extent, count,
+ datatype, pre_src, MPIR_ALLREDUCE_TAG, comm_ptr, coll_group,
+ &reqs[num_reqs]);
MPIR_ERR_CHECK(mpi_errno);
num_reqs++;
}
- /* Wait for asynchronous operations to complete */

+ /* Wait for asynchronous operations to complete */
MPIC_Waitall(num_reqs, reqs, MPI_STATUSES_IGNORE);

/* Reduce locally */
- for(int i = 0; i < num_reqs; i++) {
- if(i == (num_reqs - 1)) {
- mpi_errno = MPIR_Reduce_local(((char *)tmp_buf) + i * count * extent,
- recvbuf, count, datatype, op);
+ for (int i = 0; i < num_reqs; i++) {
+ if (i == (num_reqs - 1)) {
+ mpi_errno = MPIR_Reduce_local(((char *) tmp_buf) + i * count * extent,
+ recvbuf, count, datatype, op);
MPIR_ERR_CHECK(mpi_errno);
- }
- else {
- mpi_errno = MPIR_Reduce_local(((char *)tmp_buf) + i * count * extent,
- ((char *)tmp_buf) + (i + 1) * count * extent,
- count, datatype, op);
+ } else {
+ mpi_errno = MPIR_Reduce_local(((char *) tmp_buf) + i * count * extent,
+ ((char *) tmp_buf) + (i + 1) * count * extent,
+ count, datatype, op);
MPIR_ERR_CHECK(mpi_errno);
}
}
@@ -119,60 +116,61 @@ int MPIR_Allreduce_intra_recursive_multiplying(const void *sendbuf,
/*MAIN-STAGE: Ranks exchange data with groups size k over increasing
* distances */
if (virt_rank != -1) {
- /*Do exchanges*/
+ /*Do exchanges */
num_reqs = 0;
int exchanges = 0;
int distance = 1;
int next_distance = k;
while (distance < pofk) {
/* Asynchronous sends */
- int starting_rank = rank/next_distance * next_distance;
+ int starting_rank = rank / next_distance * next_distance;
int rank_offset = starting_rank + rank % distance;
- for(int dst = rank_offset; dst < starting_rank + next_distance; dst += distance) {
- if(dst != rank) {
- mpi_errno = MPIC_Isend(recvbuf, count, datatype, dst, MPIR_ALLREDUCE_TAG,
- comm_ptr, &reqs[num_reqs++], errflag);
+ for (int dst = rank_offset; dst < starting_rank + next_distance; dst += distance) {
+ if (dst != rank) {
+ mpi_errno = MPIC_Isend(recvbuf, count, datatype, dst, MPIR_ALLREDUCE_TAG,
+ comm_ptr, coll_group, &reqs[num_reqs++], errflag);
MPIR_ERR_CHECK(mpi_errno);
- mpi_errno = MPIC_Irecv(((char *)tmp_buf) + exchanges * count * extent,
- count, datatype, dst, MPIR_ALLREDUCE_TAG, comm_ptr,
- &reqs[num_reqs++]);
+ mpi_errno = MPIC_Irecv(((char *) tmp_buf) + exchanges * count * extent,
+ count, datatype, dst, MPIR_ALLREDUCE_TAG,
+ comm_ptr, coll_group, &reqs[num_reqs++]);
MPIR_ERR_CHECK(mpi_errno);
exchanges++;
}
}

- /* Wait for asynchronous operations to complete */
+ /* Wait for asynchronous operations to complete */
MPIC_Waitall(num_reqs, reqs, MPI_STATUSES_IGNORE);
num_reqs = 0;
exchanges = 0;

/* Perform reduction on the received values */
int recvbuf_last = 0;
- for(int dst = rank_offset; dst < starting_rank + next_distance - distance; dst += distance) {
- void *dst_buf = ((char *)tmp_buf) + exchanges * count * extent;
- if(dst == rank - distance) {
+ for (int dst = rank_offset; dst < starting_rank + next_distance - distance;
+ dst += distance) {
+ void *dst_buf = ((char *) tmp_buf) + exchanges * count * extent;
+ if (dst == rank - distance) {
mpi_errno = MPIR_Reduce_local(dst_buf, recvbuf, count, datatype, op);
MPIR_ERR_CHECK(mpi_errno);
recvbuf_last = 1;
exchanges++;
- }
- else if(dst == rank){
+ } else if (dst == rank) {
mpi_errno = MPIR_Reduce_local(recvbuf, dst_buf, count, datatype, op);
MPIR_ERR_CHECK(mpi_errno);
recvbuf_last = 0;
- }
- else {
- mpi_errno = MPIR_Reduce_local(dst_buf, (char *)dst_buf + count * extent, count, datatype, op);
+ } else {
+ mpi_errno =
+ MPIR_Reduce_local(dst_buf, (char *) dst_buf + count * extent, count,
+ datatype, op);
MPIR_ERR_CHECK(mpi_errno);
recvbuf_last = 0;
exchanges++;
}
}
- if(!recvbuf_last) {
- mpi_errno = MPIR_Localcopy((char *)tmp_buf + exchanges * count * extent,
- count, datatype, recvbuf, count, datatype);
+
+ if (!recvbuf_last) {
+ mpi_errno = MPIR_Localcopy((char *) tmp_buf + exchanges * count * extent,
+ count, datatype, recvbuf, count, datatype);
MPIR_ERR_CHECK(mpi_errno);
}

@@ -183,23 +181,23 @@ int MPIR_Allreduce_intra_recursive_multiplying(const void *sendbuf,
}

/* POST-STAGE: Send result to ranks outside main algorithm */
- if(pofk < comm_size) {
+ if (pofk < comm_size) {
num_reqs = 0;
- if(rank >= pofk) {
+ if (rank >= pofk) {
int post_src = rank % pofk;
/* This process is outside the core algorithm, so receive data */
mpi_errno = MPIC_Recv(recvbuf, count,
datatype, post_src,
- MPIR_ALLREDUCE_TAG, comm_ptr, MPI_STATUS_IGNORE);
+ MPIR_ALLREDUCE_TAG, comm_ptr, coll_group, MPI_STATUS_IGNORE);
MPIR_ERR_CHECK(mpi_errno);
} else {
/* This is process is in the algorithm, so send data */
for (int post_dst = (rank % pofk) + pofk; post_dst < comm_size; post_dst += pofk) {
- mpi_errno = MPIC_Isend(recvbuf, count, datatype, post_dst, MPIR_ALLREDUCE_TAG, comm_ptr,
- &reqs[num_reqs++], errflag);
+ mpi_errno = MPIC_Isend(recvbuf, count, datatype, post_dst, MPIR_ALLREDUCE_TAG,
+ comm_ptr, coll_group, &reqs[num_reqs++], errflag);
MPIR_ERR_CHECK(mpi_errno);
}

MPIC_Waitall(num_reqs, reqs, MPI_STATUSES_IGNORE);
}
}
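
For a concrete feel of the three stages in the patched algorithm, here is a standalone sketch (illustrative only, not part of the PR) that reproduces the pre/main/post role assignment and the main-stage partner arithmetic from the code above. The per-round update (distance = next_distance; next_distance *= k) is inferred from the algorithm comment, since that part of the loop falls outside the shown hunks:

    #include <stdio.h>
    #include <math.h>

    int main(void)
    {
        const int comm_size = 10;   /* example: 10 ranks, k = 3 */
        const int k = 3;

        /* same power-of-k computation as the patched function */
        int power = (int) (log(comm_size) / log(k));
        int pofk = (int) lround(pow(k, power));     /* 9 here */

        for (int rank = 0; rank < comm_size; rank++) {
            if (rank >= pofk) {
                /* PRE-STAGE: ranks at or above pofk collapse their data onto
                 * rank % pofk, then wait for the POST-STAGE result. */
                printf("rank %d: pre-sends to %d, post-receives result\n",
                       rank, rank % pofk);
                continue;
            }
            /* MAIN-STAGE: exchange within groups of k at distances 1, k, ... */
            printf("rank %d partners:", rank);
            int distance = 1;
            int next_distance = k;
            while (distance < pofk) {
                /* k ranks spaced 'distance' apart inside a window of
                 * 'next_distance' ranks, exactly as in the patched loop */
                int starting_rank = rank / next_distance * next_distance;
                int rank_offset = starting_rank + rank % distance;
                for (int dst = rank_offset; dst < starting_rank + next_distance;
                     dst += distance) {
                    if (dst != rank)
                        printf(" %d", dst);
                }
                printf(" |");
                distance = next_distance;
                next_distance *= k;
            }
            printf("\n");
        }
        return 0;
    }

With comm_size = 10 and k = 3 this gives pofk = 9: rank 9 collapses onto rank 0 in the pre-stage, ranks 0-8 exchange in triples at distance 1 and then distance 3, and rank 0 finally sends the result back to rank 9.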