Skip to content

Commit

Permalink
ch4: refactor MPID_Comm_connect/accept
Browse files Browse the repository at this point in the history
Simply establish remote_lpid and call MPIR_Intercomm_create.
  • Loading branch information
hzhou committed Dec 18, 2024
1 parent db5cfae commit e07760f
Showing 1 changed file with 22 additions and 77 deletions.
99 changes: 22 additions & 77 deletions src/mpid/ch4/src/ch4_spawn.c
Original file line number Diff line number Diff line change
Expand Up @@ -273,81 +273,48 @@ int MPID_Close_port(const char *port_name)

/* MPID_Comm_accept, MPID_Comm_connect */

static int peer_intercomm_create(char *remote_addrname, int len, int tag, int timeout,
bool is_sender, MPIR_Comm ** newcomm);
static int dynamic_intercomm_create(const char *port_name, MPIR_Info * info, int root,
MPIR_Comm * comm_ptr, int timeout, bool is_sender,
MPIR_Comm ** newcomm);

struct dynproc_conn_hdr {
int context_id;
int addrname_len;
char addrname[MPIDI_DYNPROC_NAME_MAX];
};

static int peer_intercomm_create(char *remote_addrname, int len, int tag,
int timeout, bool is_sender, MPIR_Comm ** newcomm)
static int establish_peer_conn(char *remote_addrname, int remote_addrname_len, int tag,
int timeout, bool is_sender, MPIR_Lpid * remote_lpid_out)
{
int mpi_errno = MPI_SUCCESS;
int context_id, recvcontext_id;
MPIR_Lpid remote_lpid;

mpi_errno = MPIR_Get_contextid_sparse(MPIR_Process.comm_self, &recvcontext_id, FALSE);
MPIR_ERR_CHECK(mpi_errno);
struct dynproc_conn_hdr {
int addrname_len;
char addrname[MPIDI_DYNPROC_NAME_MAX];
} hdr;

struct dynproc_conn_hdr hdr;
if (is_sender) {
/* insert remote address */
int addrname_len = len;
MPIR_Lpid *remote_lpids = &remote_lpid;
mpi_errno = MPIDIU_upids_to_lpids(1, &addrname_len, remote_addrname, remote_lpids);
mpi_errno = MPIDIU_insert_dynamic_upid(&remote_lpid, remote_addrname, remote_addrname_len);
MPIR_ERR_CHECK(mpi_errno);

/* fill hdr with context_id and addrname */
hdr.context_id = recvcontext_id;

char *addrname;
int *addrname_size;
mpi_errno = MPIDI_NM_get_local_upids(MPIR_Process.comm_self, &addrname_size, &addrname);
/* get my addrname and send it to remote */
char *my_addrname;
int *my_addrname_len;
mpi_errno = MPIDI_NM_get_local_upids(MPIR_Process.comm_self,
&my_addrname_len, &my_addrname);
MPIR_ERR_CHECK(mpi_errno);
MPIR_Assert(addrname_size[0] <= MPIDI_DYNPROC_NAME_MAX);
memcpy(hdr.addrname, addrname, addrname_size[0]);
hdr.addrname_len = addrname_size[0];

/* send remote context_id + addrname */
MPIR_Assert(my_addrname_len[0] <= MPIDI_DYNPROC_NAME_MAX);
memcpy(hdr.addrname, my_addrname, my_addrname_len[0]);
hdr.addrname_len = my_addrname_len[0];

/* send it to remote */
int hdr_sz = sizeof(hdr) - MPIDI_DYNPROC_NAME_MAX + hdr.addrname_len;
mpi_errno = MPIDI_NM_dynamic_send(remote_lpid, tag, &hdr, hdr_sz, timeout);
MPL_free(addrname);
MPL_free(addrname_size);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIDI_NM_dynamic_recv(tag, &hdr, sizeof(hdr), timeout);
MPIR_ERR_CHECK(mpi_errno);
context_id = hdr.context_id;
} else {
/* recv remote address */
mpi_errno = MPIDI_NM_dynamic_recv(tag, &hdr, sizeof(hdr), timeout);
MPIR_ERR_CHECK(mpi_errno);
context_id = hdr.context_id;

/* insert remote address */
int addrname_len = hdr.addrname_len;
MPIR_Lpid *remote_lpids = &remote_lpid;
mpi_errno = MPIDIU_upids_to_lpids(1, &addrname_len, hdr.addrname, remote_lpids);
MPIR_ERR_CHECK(mpi_errno);

/* send remote context_id */
hdr.context_id = recvcontext_id;
mpi_errno = MPIDI_NM_dynamic_send(remote_lpid, tag, &hdr, sizeof(hdr.context_id), timeout);
mpi_errno = MPIDIU_insert_dynamic_upid(&remote_lpid, hdr.addrname, hdr.addrname_len);
MPIR_ERR_CHECK(mpi_errno);
}

/* create peer intercomm */
mpi_errno = MPIR_peer_intercomm_create(context_id, recvcontext_id,
remote_lpid, is_sender, newcomm);
MPIR_ERR_CHECK(mpi_errno);

fn_exit:
return mpi_errno;
fn_fail:
Expand All @@ -362,15 +329,14 @@ static int dynamic_intercomm_create(const char *port_name, MPIR_Info * info, int
MPIR_Comm ** newcomm)
{
int mpi_errno = MPI_SUCCESS;
MPIR_Lpid remote_lpid;

MPIR_Comm *peer_intercomm = NULL;
int tag;
int bcast_ints[2]; /* used to bcast tag and errno */
if (comm_ptr->rank == root) {
/* NOTE: do not goto fn_fail on error, or it will leave children hanging */
mpi_errno = get_tag_from_port(port_name, &tag);
if (mpi_errno)
goto bcast_tag_and_errno;
MPIR_ERR_CHECK(mpi_errno);

char remote_addrname[MPIDI_DYNPROC_NAME_MAX];
char *addrname;
Expand All @@ -379,40 +345,19 @@ static int dynamic_intercomm_create(const char *port_name, MPIR_Info * info, int
addrname = remote_addrname;
mpi_errno = get_conn_name_from_port(port_name, remote_addrname, MPIDI_DYNPROC_NAME_MAX,
&len);
if (mpi_errno)
goto bcast_tag_and_errno;
MPIR_ERR_CHECK(mpi_errno);
} else {
/* Use NULL for better error behavior */
addrname = NULL;
len = 0;
}
mpi_errno = peer_intercomm_create(addrname, len, tag, timeout, is_sender, &peer_intercomm);

bcast_tag_and_errno:
bcast_ints[0] = tag;
bcast_ints[1] = mpi_errno;
mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE);
mpi_errno = establish_peer_conn(addrname, len, tag, timeout, is_sender, &remote_lpid);
MPIR_ERR_CHECK(mpi_errno);
mpi_errno = bcast_ints[1];
MPIR_ERR_CHECK(mpi_errno);
} else {
mpi_errno = MPIR_Bcast_allcomm_auto(bcast_ints, 2, MPI_INT, root, comm_ptr, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);
if (bcast_ints[1]) {
/* errno from root cannot be directly returned */
MPIR_ERR_SET(mpi_errno, MPI_ERR_PORT, "**comm_connect_fail");
goto fn_fail;
}
tag = bcast_ints[0];
}

mpi_errno = MPIR_Intercomm_create_impl(comm_ptr, root, peer_intercomm, 0, tag, newcomm);
mpi_errno = MPIR_Intercomm_create(comm_ptr, root,, 0, tag, newcomm);
MPIR_ERR_CHECK(mpi_errno);

fn_exit:
if (peer_intercomm) {
MPIR_Comm_free_impl(peer_intercomm);
}
return mpi_errno;
fn_fail:
goto fn_exit;
Expand Down

0 comments on commit e07760f

Please sign in to comment.