Skip to content

Commit

Permalink
ch4: use am_tag_{send,recv} in MPIDIG put
Browse files Browse the repository at this point in the history
In the MPIDIG_PUT_DT_REQ protocol, use am_tag_{send,recv} when
available.
  • Loading branch information
hzhou committed Nov 8, 2024
1 parent b0a5b5d commit 53704b2
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/mpid/ch4/src/ch4_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ typedef struct MPIDIG_put_msg_t {

typedef struct MPIDIG_put_dt_ack_msg_t {
int src_rank;
int am_tag;
MPIR_Request *target_preq_ptr;
MPIR_Request *origin_preq_ptr;
} MPIDIG_put_dt_ack_msg_t;
Expand Down
1 change: 1 addition & 0 deletions src/mpid/ch4/src/mpidig.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ enum {
enum {
MPIDIG_TAG_RECV_COMPLETE = 0,
MPIDIG_TAG_GET_COMPLETE,
MPIDIG_TAG_PUT_COMPLETE,

MPIDIG_TAG_RECV_STATIC_MAX
};
Expand Down
1 change: 1 addition & 0 deletions src/mpid/ch4/src/mpidig_init.c
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ int MPIDIG_am_init(void)
MPIDIG_am_rndv_reg_cb(MPIDIG_RNDV_GENERIC, &MPIDIG_do_cts);
MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_RECV_COMPLETE, &MPIDIG_tag_recv_complete);
MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_GET_COMPLETE, &MPIDIG_tag_get_complete);
MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_PUT_COMPLETE, &MPIDIG_tag_put_complete);

MPIDIG_am_comm_abort_init();

Expand Down
55 changes: 48 additions & 7 deletions src/mpid/ch4/src/mpidig_rma_callbacks.c
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,26 @@ static int put_dt_target_cmpl_cb(MPIR_Request * rreq)

int local_vci = MPIDIG_REQUEST(rreq, req->local_vci);
int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci);
MPIR_Comm *comm = rreq->u.rma.win->comm_ptr;

bool is_local;
#ifndef MPIDI_CH4_DIRECT_NETMOD
is_local = MPIDI_REQUEST(rreq, is_local);
#else
is_local = 0;
#endif
if (MPIDIG_can_do_tag(is_local)) {
ack_msg.am_tag = MPIDIG_get_next_am_tag(comm);
CH4_CALL(am_tag_recv(ack_msg.src_rank, comm, MPIDIG_TAG_PUT_COMPLETE, ack_msg.am_tag,
MPIDIG_REQUEST(rreq, buffer),
MPIDIG_REQUEST(rreq, count),
MPIDIG_REQUEST(rreq, datatype),
local_vci, remote_vci, rreq), is_local, mpi_errno);
MPIR_ERR_CHECK(mpi_errno);
} else {
ack_msg.am_tag = -1;
}

CH4_CALL(am_send_hdr_reply
(rreq->u.rma.win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank),
MPIDIG_PUT_DT_ACK, &ack_msg, sizeof(ack_msg), local_vci, remote_vci),
Expand Down Expand Up @@ -1605,13 +1625,25 @@ int MPIDIG_put_dt_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_s
/* origin datatype to be released in MPIDIG_put_data_origin_cb */
MPIDIG_REQUEST(rreq, datatype) = MPIDIG_REQUEST(origin_req, datatype);

CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(origin_req, u.origin.target_rank),
MPIDIG_PUT_DAT_REQ, &dat_msg, sizeof(dat_msg),
MPIDIG_REQUEST(origin_req, buffer),
MPIDIG_REQUEST(origin_req, count),
MPIDIG_REQUEST(origin_req, datatype),
local_vci, remote_vci, rreq),
(attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno);
int target_rank = MPIDIG_REQUEST(origin_req, u.origin.target_rank);
if (msg_hdr->am_tag >= 0) {
CH4_CALL(am_tag_send(target_rank, win->comm_ptr, MPIDIG_PUT_DAT_REQ,
msg_hdr->am_tag,
MPIDIG_REQUEST(origin_req, buffer),
MPIDIG_REQUEST(origin_req, count),
MPIDIG_REQUEST(origin_req, datatype),
local_vci, remote_vci, rreq),
(attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno);

} else {
CH4_CALL(am_isend_reply(win->comm_ptr, target_rank,
MPIDIG_PUT_DAT_REQ, &dat_msg, sizeof(dat_msg),
MPIDIG_REQUEST(origin_req, buffer),
MPIDIG_REQUEST(origin_req, count),
MPIDIG_REQUEST(origin_req, datatype),
local_vci, remote_vci, rreq),
(attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno);
}
MPIR_ERR_CHECK(mpi_errno);

if (attr & MPIDIG_AM_ATTR__IS_ASYNC) {
Expand Down Expand Up @@ -2177,3 +2209,12 @@ int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status)

return mpi_errno;
}

int MPIDIG_tag_put_complete(MPIR_Request * req, MPI_Status * status)
{
int mpi_errno = MPI_SUCCESS;

mpi_errno = put_target_cmpl_cb(req);

return mpi_errno;
}
1 change: 1 addition & 0 deletions src/mpid/ch4/src/mpidig_rma_callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,6 @@ int MPIDIG_get_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz,
int MPIDIG_get_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz,
uint32_t attr, MPIR_Request ** req);
int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status);
int MPIDIG_tag_put_complete(MPIR_Request * req, MPI_Status * status);

#endif /* MPIDIG_RMA_CALLBACKS_H_INCLUDED */

0 comments on commit 53704b2

Please sign in to comment.