From 53704b2e7a42e3133c17edbae75c4a05dc060084 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 7 Nov 2024 22:58:58 -0600 Subject: [PATCH] ch4: use am_tag_{send,recv} in MPIDIG put In the MPIDIG_PUT_DT_REQ protocol, use am_tag_{send,recv} when available. --- src/mpid/ch4/src/ch4_types.h | 1 + src/mpid/ch4/src/mpidig.h | 1 + src/mpid/ch4/src/mpidig_init.c | 1 + src/mpid/ch4/src/mpidig_rma_callbacks.c | 55 +++++++++++++++++++++---- src/mpid/ch4/src/mpidig_rma_callbacks.h | 1 + 5 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/mpid/ch4/src/ch4_types.h b/src/mpid/ch4/src/ch4_types.h index 910e6ebaf4c..914b89c11a6 100644 --- a/src/mpid/ch4/src/ch4_types.h +++ b/src/mpid/ch4/src/ch4_types.h @@ -117,6 +117,7 @@ typedef struct MPIDIG_put_msg_t { typedef struct MPIDIG_put_dt_ack_msg_t { int src_rank; + int am_tag; MPIR_Request *target_preq_ptr; MPIR_Request *origin_preq_ptr; } MPIDIG_put_dt_ack_msg_t; diff --git a/src/mpid/ch4/src/mpidig.h b/src/mpid/ch4/src/mpidig.h index 3b7d934467b..18e846a9e27 100644 --- a/src/mpid/ch4/src/mpidig.h +++ b/src/mpid/ch4/src/mpidig.h @@ -78,6 +78,7 @@ enum { enum { MPIDIG_TAG_RECV_COMPLETE = 0, MPIDIG_TAG_GET_COMPLETE, + MPIDIG_TAG_PUT_COMPLETE, MPIDIG_TAG_RECV_STATIC_MAX }; diff --git a/src/mpid/ch4/src/mpidig_init.c b/src/mpid/ch4/src/mpidig_init.c index c09a0cd8c3e..a3b74cc3b8e 100644 --- a/src/mpid/ch4/src/mpidig_init.c +++ b/src/mpid/ch4/src/mpidig_init.c @@ -159,6 +159,7 @@ int MPIDIG_am_init(void) MPIDIG_am_rndv_reg_cb(MPIDIG_RNDV_GENERIC, &MPIDIG_do_cts); MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_RECV_COMPLETE, &MPIDIG_tag_recv_complete); MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_GET_COMPLETE, &MPIDIG_tag_get_complete); + MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_PUT_COMPLETE, &MPIDIG_tag_put_complete); MPIDIG_am_comm_abort_init(); diff --git a/src/mpid/ch4/src/mpidig_rma_callbacks.c b/src/mpid/ch4/src/mpidig_rma_callbacks.c index fabdb70ef36..3ce80bf0ef0 100644 --- a/src/mpid/ch4/src/mpidig_rma_callbacks.c +++ b/src/mpid/ch4/src/mpidig_rma_callbacks.c @@ -988,6 +988,26 @@ static int put_dt_target_cmpl_cb(MPIR_Request * rreq) int local_vci = MPIDIG_REQUEST(rreq, req->local_vci); int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci); + MPIR_Comm *comm = rreq->u.rma.win->comm_ptr; + + bool is_local; +#ifndef MPIDI_CH4_DIRECT_NETMOD + is_local = MPIDI_REQUEST(rreq, is_local); +#else + is_local = 0; +#endif + if (MPIDIG_can_do_tag(is_local)) { + ack_msg.am_tag = MPIDIG_get_next_am_tag(comm); + CH4_CALL(am_tag_recv(ack_msg.src_rank, comm, MPIDIG_TAG_PUT_COMPLETE, ack_msg.am_tag, + MPIDIG_REQUEST(rreq, buffer), + MPIDIG_REQUEST(rreq, count), + MPIDIG_REQUEST(rreq, datatype), + local_vci, remote_vci, rreq), is_local, mpi_errno); + MPIR_ERR_CHECK(mpi_errno); + } else { + ack_msg.am_tag = -1; + } + CH4_CALL(am_send_hdr_reply (rreq->u.rma.win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank), MPIDIG_PUT_DT_ACK, &ack_msg, sizeof(ack_msg), local_vci, remote_vci), @@ -1605,13 +1625,25 @@ int MPIDIG_put_dt_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_s /* origin datatype to be released in MPIDIG_put_data_origin_cb */ MPIDIG_REQUEST(rreq, datatype) = MPIDIG_REQUEST(origin_req, datatype); - CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(origin_req, u.origin.target_rank), - MPIDIG_PUT_DAT_REQ, &dat_msg, sizeof(dat_msg), - MPIDIG_REQUEST(origin_req, buffer), - MPIDIG_REQUEST(origin_req, count), - MPIDIG_REQUEST(origin_req, datatype), - local_vci, remote_vci, rreq), - (attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno); + int target_rank = MPIDIG_REQUEST(origin_req, u.origin.target_rank); + if (msg_hdr->am_tag >= 0) { + CH4_CALL(am_tag_send(target_rank, win->comm_ptr, MPIDIG_PUT_DAT_REQ, + msg_hdr->am_tag, + MPIDIG_REQUEST(origin_req, buffer), + MPIDIG_REQUEST(origin_req, count), + MPIDIG_REQUEST(origin_req, datatype), + local_vci, remote_vci, rreq), + (attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno); + + } else { + CH4_CALL(am_isend_reply(win->comm_ptr, target_rank, + MPIDIG_PUT_DAT_REQ, &dat_msg, sizeof(dat_msg), + MPIDIG_REQUEST(origin_req, buffer), + MPIDIG_REQUEST(origin_req, count), + MPIDIG_REQUEST(origin_req, datatype), + local_vci, remote_vci, rreq), + (attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno); + } MPIR_ERR_CHECK(mpi_errno); if (attr & MPIDIG_AM_ATTR__IS_ASYNC) { @@ -2177,3 +2209,12 @@ int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status) return mpi_errno; } + +int MPIDIG_tag_put_complete(MPIR_Request * req, MPI_Status * status) +{ + int mpi_errno = MPI_SUCCESS; + + mpi_errno = put_target_cmpl_cb(req); + + return mpi_errno; +} diff --git a/src/mpid/ch4/src/mpidig_rma_callbacks.h b/src/mpid/ch4/src/mpidig_rma_callbacks.h index ac8d7c6b887..87dc729f83d 100644 --- a/src/mpid/ch4/src/mpidig_rma_callbacks.h +++ b/src/mpid/ch4/src/mpidig_rma_callbacks.h @@ -113,5 +113,6 @@ int MPIDIG_get_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, int MPIDIG_get_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, uint32_t attr, MPIR_Request ** req); int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status); +int MPIDIG_tag_put_complete(MPIR_Request * req, MPI_Status * status); #endif /* MPIDIG_RMA_CALLBACKS_H_INCLUDED */