From e07760fada4a11ccd3a85c958070f375cacff7d8 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Tue, 17 Dec 2024 21:54:15 -0600 Subject: [PATCH] ch4: refactor MPID_Comm_connect/accept Simply establish remote_lpid and call MPIR_Intercomm_create. --- src/mpid/ch4/src/ch4_spawn.c | 99 ++++++++---------------------------- 1 file changed, 22 insertions(+), 77 deletions(-) diff --git a/src/mpid/ch4/src/ch4_spawn.c b/src/mpid/ch4/src/ch4_spawn.c index 10241261336..7f51ec5897a 100644 --- a/src/mpid/ch4/src/ch4_spawn.c +++ b/src/mpid/ch4/src/ch4_spawn.c @@ -273,81 +273,48 @@ int MPID_Close_port(const char *port_name) /* MPID_Comm_accept, MPID_Comm_connect */ -static int peer_intercomm_create(char *remote_addrname, int len, int tag, int timeout, - bool is_sender, MPIR_Comm ** newcomm); -static int dynamic_intercomm_create(const char *port_name, MPIR_Info * info, int root, - MPIR_Comm * comm_ptr, int timeout, bool is_sender, - MPIR_Comm ** newcomm); - -struct dynproc_conn_hdr { - int context_id; - int addrname_len; - char addrname[MPIDI_DYNPROC_NAME_MAX]; -}; - -static int peer_intercomm_create(char *remote_addrname, int len, int tag, - int timeout, bool is_sender, MPIR_Comm ** newcomm) +static int establish_peer_conn(char *remote_addrname, int remote_addrname_len, int tag, + int timeout, bool is_sender, MPIR_Lpid * remote_lpid_out) { int mpi_errno = MPI_SUCCESS; - int context_id, recvcontext_id; MPIR_Lpid remote_lpid; - mpi_errno = MPIR_Get_contextid_sparse(MPIR_Process.comm_self, &recvcontext_id, FALSE); - MPIR_ERR_CHECK(mpi_errno); + struct dynproc_conn_hdr { + int addrname_len; + char addrname[MPIDI_DYNPROC_NAME_MAX]; + } hdr; - struct dynproc_conn_hdr hdr; if (is_sender) { /* insert remote address */ - int addrname_len = len; - MPIR_Lpid *remote_lpids = &remote_lpid; - mpi_errno = MPIDIU_upids_to_lpids(1, &addrname_len, remote_addrname, remote_lpids); + mpi_errno = MPIDIU_insert_dynamic_upid(&remote_lpid, remote_addrname, remote_addrname_len); MPIR_ERR_CHECK(mpi_errno); - /* fill hdr with context_id and addrname */ - hdr.context_id = recvcontext_id; - - char *addrname; - int *addrname_size; - mpi_errno = MPIDI_NM_get_local_upids(MPIR_Process.comm_self, &addrname_size, &addrname); + /* get my addrname and send it to remote */ + char *my_addrname; + int *my_addrname_len; + mpi_errno = MPIDI_NM_get_local_upids(MPIR_Process.comm_self, + &my_addrname_len, &my_addrname); MPIR_ERR_CHECK(mpi_errno); - MPIR_Assert(addrname_size[0] <= MPIDI_DYNPROC_NAME_MAX); - memcpy(hdr.addrname, addrname, addrname_size[0]); - hdr.addrname_len = addrname_size[0]; - - /* send remote context_id + addrname */ + MPIR_Assert(my_addrname_len[0] <= MPIDI_DYNPROC_NAME_MAX); + memcpy(hdr.addrname, my_addrname, my_addrname_len[0]); + hdr.addrname_len = my_addrname_len[0]; + /* send it to remote */ int hdr_sz = sizeof(hdr) - MPIDI_DYNPROC_NAME_MAX + hdr.addrname_len; mpi_errno = MPIDI_NM_dynamic_send(remote_lpid, tag, &hdr, hdr_sz, timeout); MPL_free(addrname); MPL_free(addrname_size); MPIR_ERR_CHECK(mpi_errno); - - mpi_errno = MPIDI_NM_dynamic_recv(tag, &hdr, sizeof(hdr), timeout); - MPIR_ERR_CHECK(mpi_errno); - context_id = hdr.context_id; } else { /* recv remote address */ mpi_errno = MPIDI_NM_dynamic_recv(tag, &hdr, sizeof(hdr), timeout); MPIR_ERR_CHECK(mpi_errno); - context_id = hdr.context_id; /* insert remote address */ - int addrname_len = hdr.addrname_len; - MPIR_Lpid *remote_lpids = &remote_lpid; - mpi_errno = MPIDIU_upids_to_lpids(1, &addrname_len, hdr.addrname, remote_lpids); - MPIR_ERR_CHECK(mpi_errno); - - /* send remote context_id */ - hdr.context_id = recvcontext_id; - mpi_errno = MPIDI_NM_dynamic_send(remote_lpid, tag, &hdr, sizeof(hdr.context_id), timeout); + mpi_errno = MPIDIU_insert_dynamic_upid(&remote_lpid, hdr.addrname, hdr.addrname_len); MPIR_ERR_CHECK(mpi_errno); } - /* create peer intercomm */ - mpi_errno = MPIR_peer_intercomm_create(context_id, recvcontext_id, - remote_lpid, is_sender, newcomm); - MPIR_ERR_CHECK(mpi_errno); - fn_exit: return mpi_errno; fn_fail: @@ -362,15 +329,14 @@ static int dynamic_intercomm_create(const char *port_name, MPIR_Info * info, int MPIR_Comm ** newcomm) { int mpi_errno = MPI_SUCCESS; + MPIR_Lpid remote_lpid; - MPIR_Comm *peer_intercomm = NULL; int tag; int bcast_ints[2]; /* used to bcast tag and errno */ if (comm_ptr->rank == root) { /* NOTE: do not goto fn_fail on error, or it will leave children hanging */ mpi_errno = get_tag_from_port(port_name, &tag); - if (mpi_errno) - goto bcast_tag_and_errno; + MPIR_ERR_CHECK(mpi_errno); char remote_addrname[MPIDI_DYNPROC_NAME_MAX]; char *addrname; @@ -379,40 +345,19 @@ static int dynamic_intercomm_create(const char *port_name, MPIR_Info * info, int addrname = remote_addrname; mpi_errno = get_conn_name_from_port(port_name, remote_addrname, MPIDI_DYNPROC_NAME_MAX, &len); - if (mpi_errno) - goto bcast_tag_and_errno; + MPIR_ERR_CHECK(mpi_errno); } else { - /* Use NULL for better error behavior */ addrname = NULL; len = 0; } - mpi_errno = peer_intercomm_create(addrname, len, tag, timeout, is_sender, &peer_intercomm); - - bcast_tag_and_errno: - bcast_ints[0] = tag; - bcast_ints[1] = mpi_errno; - mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); + mpi_errno = establish_peer_conn(addrname, len, tag, timeout, is_sender, &remote_lpid); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = bcast_ints[1]; - MPIR_ERR_CHECK(mpi_errno); - } else { - mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE); - MPIR_ERR_CHECK(mpi_errno); - if (bcast_ints[1]) { - /* errno from root cannot be directly returned */ - MPIR_ERR_SET(mpi_errno, MPI_ERR_PORT, "**comm_connect_fail"); - goto fn_fail; - } - tag = bcast_ints[0]; } - mpi_errno = MPIR_Intercomm_create_impl(comm_ptr, root, peer_intercomm, 0, tag, newcomm); + mpi_errno = MPIR_Intercomm_create(comm_ptr, root,, 0, tag, newcomm); MPIR_ERR_CHECK(mpi_errno); fn_exit: - if (peer_intercomm) { - MPIR_Comm_free_impl(peer_intercomm); - } return mpi_errno; fn_fail: goto fn_exit;