Skip to content

Commit

Permalink
coll: add collattr to function parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhou committed Jan 10, 2023
1 parent bb9389e commit 5246c1a
Show file tree
Hide file tree
Showing 194 changed files with 1,316 additions and 1,015 deletions.
13 changes: 8 additions & 5 deletions src/include/mpir_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,19 @@ int MPIC_Waitall(int numreq, MPIR_Request * requests[], MPI_Status statuses[]);
int MPIR_Reduce_local(const void *inbuf, void *inoutbuf, MPI_Aint count, MPI_Datatype datatype,
MPI_Op op);

int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int collattr);

/* TSP auto */
int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count,
MPI_Datatype datatype, MPI_Op op,
MPIR_Comm * comm, MPIR_TSP_sched_t sched);
MPIR_Comm * comm, int collattr,
MPIR_TSP_sched_t sched);
int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype,
int root, MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched);
int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sched);
int root, MPIR_Comm * comm_ptr, int collattr,
MPIR_TSP_sched_t sched);
int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, int collattr, MPIR_TSP_sched_t sched);
int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count,
MPI_Datatype datatype, MPI_Op op, int root,
MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched);
MPIR_Comm * comm_ptr, int collattr,
MPIR_TSP_sched_t sched);
#endif /* MPIR_COLL_H_INCLUDED */
4 changes: 2 additions & 2 deletions src/mpi/coll/algorithms/recexchalgo/recexchalgo.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ int MPIR_TSP_Iallgatherv_sched_intra_recexch_step2(int step1_sendto, int step2_n
size_t recv_extent, const MPI_Aint * recvcounts,
const MPI_Aint * displs, MPI_Datatype recvtype,
int is_dist_halving, MPIR_Comm * comm,
MPIR_TSP_sched_t sched);
int collattr, MPIR_TSP_sched_t sched);
int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void *tmp_recvbuf,
const MPI_Aint * recvcounts,
MPI_Aint * displs, MPI_Datatype datatype,
Expand All @@ -36,6 +36,6 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch_step2(void *tmp_results, void *
int step2_nphases, int **step2_nbrs,
int rank, int nranks, int sink_id,
int is_out_vtcs, int *reduce_id_,
MPIR_TSP_sched_t sched);
int collattr, MPIR_TSP_sched_t sched);

#endif /* RECEXCHALGO_H_INCLUDED */
7 changes: 3 additions & 4 deletions src/mpi/coll/allgather/allgather_allcomm_nb.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@

int MPIR_Allgather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype,
void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPIR_Comm * comm_ptr, int collattr)
{
int mpi_errno = MPI_SUCCESS;
MPIR_Request *req_ptr = NULL;

/* just call the nonblocking version and wait on it */
mpi_errno =
MPIR_Iallgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm_ptr,
&req_ptr);
mpi_errno = MPIR_Iallgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype,
comm_ptr, collattr, &req_ptr);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIC_Wait(req_ptr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint sendcount,
MPI_Datatype sendtype, void *recvbuf,
MPI_Aint recvcount, MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPIR_Comm * comm_ptr, int collattr)
{
int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS, root;
int mpi_errno_ret = MPI_SUCCESS;
int errflag = 0;
MPI_Aint sendtype_sz;
void *tmp_buf = NULL;
MPIR_Comm *newcomm_ptr = NULL;
Expand Down Expand Up @@ -48,7 +49,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint

if (sendcount != 0) {
mpi_errno = MPIR_Gather(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz,
MPI_BYTE, 0, newcomm_ptr, errflag);
MPI_BYTE, 0, newcomm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}

Expand All @@ -59,31 +60,31 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint
if (sendcount != 0) {
root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size,
MPI_BYTE, root, comm_ptr, errflag);
MPI_BYTE, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}

/* receive bcast from right */
if (recvcount != 0) {
root = 0;
mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size,
recvtype, root, comm_ptr, errflag);
recvtype, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}
} else {
/* receive bcast from left */
if (recvcount != 0) {
root = 0;
mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size,
recvtype, root, comm_ptr, errflag);
recvtype, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}

/* bcast to left */
if (sendcount != 0) {
root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size,
MPI_BYTE, root, comm_ptr, errflag);
MPI_BYTE, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}
}
Expand Down
9 changes: 6 additions & 3 deletions src/mpi/coll/allgather/allgather_intra_brucks.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf,
MPI_Datatype sendtype,
void *recvbuf,
MPI_Aint recvcount,
MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int collattr)
{
int comm_size, rank;
int mpi_errno = MPI_SUCCESS;
int mpi_errno_ret = MPI_SUCCESS;
int errflag = 0;
MPI_Aint recvtype_extent, recvtype_sz;
int pof2, src, rem;
void *tmp_buf = NULL;
Expand Down Expand Up @@ -68,7 +69,8 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf,
MPIR_ALLGATHER_TAG,
((char *) tmp_buf + curr_cnt * recvtype_sz),
curr_cnt * recvtype_sz, MPI_BYTE,
src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag);
src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE,
collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
curr_cnt *= 2;
pof2 *= 2;
Expand All @@ -85,7 +87,8 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf,
dst, MPIR_ALLGATHER_TAG,
((char *) tmp_buf + curr_cnt * recvtype_sz),
rem * recvcount * recvtype_sz, MPI_BYTE,
src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag);
src, MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE,
collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}

Expand Down
7 changes: 4 additions & 3 deletions src/mpi/coll/allgather/allgather_intra_k_brucks.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ int
MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount,
MPI_Datatype sendtype, void *recvbuf,
MPI_Aint recvcount, MPI_Datatype recvtype, MPIR_Comm * comm, int k,
MPIR_Errflag_t errflag)
int collattr)
{
int mpi_errno = MPI_SUCCESS;
int mpi_errno_ret = MPI_SUCCESS;
int errflag = 0;
int i, j;
int nphases = 0;
int src, dst, p_of_k = 0; /* Largest power of k that is smaller than 'size' */
Expand Down Expand Up @@ -142,7 +143,7 @@ MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount,
/* Receive at the exact location. */
mpi_errno = MPIC_Irecv((char *) tmp_recvbuf + j * recvcount * delta * recvtype_extent,
count, recvtype, src, MPIR_ALLGATHER_TAG, comm,
&reqs[num_reqs++]);
collattr, &reqs[num_reqs++]);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);

MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST,
Expand All @@ -154,7 +155,7 @@ MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount,
/* Send from the start of recv till `count` amount of data. */
mpi_errno =
MPIC_Isend(tmp_recvbuf, count, recvtype, dst, MPIR_ALLGATHER_TAG, comm,
&reqs[num_reqs++], errflag);
&reqs[num_reqs++], collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);

MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST,
Expand Down
11 changes: 7 additions & 4 deletions src/mpi/coll/allgather/allgather_intra_recursive_doubling.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf,
void *recvbuf,
MPI_Aint recvcount,
MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPIR_Comm * comm_ptr, int collattr)
{
int comm_size, rank;
int mpi_errno = MPI_SUCCESS;
int mpi_errno_ret = MPI_SUCCESS;
int errflag = 0;
MPI_Aint recvtype_extent;
int j, i;
MPI_Aint curr_cnt, last_recv_cnt = 0;
Expand Down Expand Up @@ -82,7 +83,7 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf,
((char *) recvbuf + recv_offset),
(comm_size - dst_tree_root) * recvcount,
recvtype, dst,
MPIR_ALLGATHER_TAG, comm_ptr, &status, errflag);
MPIR_ALLGATHER_TAG, comm_ptr, &status, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
if (mpi_errno) {
last_recv_cnt = 0;
Expand Down Expand Up @@ -140,7 +141,8 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf,
&& (dst >= tree_root + nprocs_completed)) {
mpi_errno = MPIC_Send(((char *) recvbuf + offset),
last_recv_cnt,
recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr, errflag);
recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr,
collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}
/* recv only if this proc. doesn't have data and sender
Expand All @@ -150,7 +152,8 @@ int MPIR_Allgather_intra_recursive_doubling(const void *sendbuf,
(rank >= tree_root + nprocs_completed)) {
mpi_errno = MPIC_Recv(((char *) recvbuf + offset),
(comm_size - (my_tree_root + mask)) * recvcount,
recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr, &status);
recvtype, dst, MPIR_ALLGATHER_TAG, comm_ptr, collattr,
&status);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
/* nprocs_completed is also equal to the
* no. of processes whose data we don't have */
Expand Down
6 changes: 4 additions & 2 deletions src/mpi/coll/allgather/allgather_intra_ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ int MPIR_Allgather_intra_ring(const void *sendbuf,
MPI_Datatype sendtype,
void *recvbuf,
MPI_Aint recvcount,
MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int collattr)
{
int comm_size, rank;
int mpi_errno = MPI_SUCCESS;
int mpi_errno_ret = MPI_SUCCESS;
int errflag = 0;
MPI_Aint recvtype_extent;
int j, i;
int left, right, jnext;
Expand Down Expand Up @@ -64,7 +65,8 @@ int MPIR_Allgather_intra_ring(const void *sendbuf,
((char *) recvbuf +
jnext * recvcount * recvtype_extent),
recvcount, recvtype, left,
MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag);
MPIR_ALLGATHER_TAG, comm_ptr, MPI_STATUS_IGNORE,
collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
j = jnext;
jnext = (comm_size + jnext - 1) % comm_size;
Expand Down
4 changes: 2 additions & 2 deletions src/mpi/coll/allgatherv/allgatherv_allcomm_nb.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@

int MPIR_Allgatherv_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype,
void *recvbuf, const MPI_Aint * recvcounts, const MPI_Aint * displs,
MPI_Datatype recvtype, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int collattr)
{
int mpi_errno = MPI_SUCCESS;
MPIR_Request *req_ptr = NULL;

/* just call the nonblocking version and wait on it */
mpi_errno =
MPIR_Iallgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype,
comm_ptr, &req_ptr);
comm_ptr, collattr, &req_ptr);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIC_Wait(req_ptr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ int MPIR_Allgatherv_inter_remote_gather_local_bcast(const void *sendbuf, MPI_Ain
MPI_Datatype sendtype, void *recvbuf,
const MPI_Aint * recvcounts, const MPI_Aint
* displs, MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPIR_Comm * comm_ptr, int collattr)
{
int remote_size, mpi_errno, root, rank;
int mpi_errno_ret = MPI_SUCCESS;
int errflag = 0;
MPIR_Comm *newcomm_ptr = NULL;
MPI_Datatype newtype = MPI_DATATYPE_NULL;

Expand All @@ -35,23 +36,23 @@ int MPIR_Allgatherv_inter_remote_gather_local_bcast(const void *sendbuf, MPI_Ain
/* gatherv from right group */
root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf,
recvcounts, displs, recvtype, root, comm_ptr, errflag);
recvcounts, displs, recvtype, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
/* gatherv to right group */
root = 0;
mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf,
recvcounts, displs, recvtype, root, comm_ptr, errflag);
recvcounts, displs, recvtype, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
} else {
/* gatherv to left group */
root = 0;
mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf,
recvcounts, displs, recvtype, root, comm_ptr, errflag);
recvcounts, displs, recvtype, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
/* gatherv from left group */
root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
mpi_errno = MPIR_Gatherv(sendbuf, sendcount, sendtype, recvbuf,
recvcounts, displs, recvtype, root, comm_ptr, errflag);
recvcounts, displs, recvtype, root, comm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}

Expand All @@ -72,7 +73,7 @@ int MPIR_Allgatherv_inter_remote_gather_local_bcast(const void *sendbuf, MPI_Ain
mpi_errno = MPIR_Type_commit_impl(&newtype);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIR_Bcast_allcomm_auto(recvbuf, 1, newtype, 0, newcomm_ptr, errflag);
mpi_errno = MPIR_Bcast_allcomm_auto(recvbuf, 1, newtype, 0, newcomm_ptr, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);

MPIR_Type_free_impl(&newtype);
Expand Down
9 changes: 5 additions & 4 deletions src/mpi/coll/allgatherv/allgatherv_intra_brucks.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf,
void *recvbuf,
const MPI_Aint * recvcounts,
const MPI_Aint * displs,
MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPI_Datatype recvtype, MPIR_Comm * comm_ptr, int collattr)
{
int comm_size, rank, j, i;
int mpi_errno = MPI_SUCCESS;
int mpi_errno_ret = MPI_SUCCESS;
int errflag = 0;
MPI_Status status;
MPI_Aint recvtype_extent, recvtype_sz;
int pof2, src, rem, send_cnt;
Expand Down Expand Up @@ -79,7 +79,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf,
MPIR_ALLGATHERV_TAG,
((char *) tmp_buf + curr_cnt * recvtype_sz),
(total_count - curr_cnt) * recvtype_sz, MPI_BYTE,
src, MPIR_ALLGATHERV_TAG, comm_ptr, &status, errflag);
src, MPIR_ALLGATHERV_TAG, comm_ptr, &status, collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
if (mpi_errno) {
recv_cnt = 0;
Expand All @@ -106,7 +106,8 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf,
dst, MPIR_ALLGATHERV_TAG,
((char *) tmp_buf + curr_cnt * recvtype_sz),
(total_count - curr_cnt) * recvtype_sz, MPI_BYTE,
src, MPIR_ALLGATHERV_TAG, comm_ptr, MPI_STATUS_IGNORE, errflag);
src, MPIR_ALLGATHERV_TAG, comm_ptr, MPI_STATUS_IGNORE,
collattr | errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
}

Expand Down
Loading

0 comments on commit 5246c1a

Please sign in to comment.