Skip to content

Commit

Permalink
comm: always set local_group and remote_group
Browse files Browse the repository at this point in the history
To make the MPI group a first-class citizen, we will always set the group
before creating communicators, so that when the device layer activates
communicators, e.g. in MPID_Comm_commit_pre_hook, it can rely on the
group to look up the involved processes. It also removes the necessity
to maintain any other process addressing schemes.
  • Loading branch information
hzhou committed Dec 12, 2024
1 parent e0c851e commit 0cf5832
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/include/mpir_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ struct MPIR_Comm {
int rank; /* Value of MPI_Comm_rank */
MPIR_Attribute *attributes; /* List of attributes */
int local_size; /* Value of MPI_Comm_size for local group */
MPIR_Group *local_group, /* Groups in communicator. */
*remote_group; /* The local and remote groups are the
* same for intra communicators */
MPIR_Group *local_group; /* Groups in communicator. */
MPIR_Group *remote_group; /* The remote group in a inter communicator.
* Must be NULL in a intra communicator. */
MPIR_Comm_kind_t comm_kind; /* MPIR_COMM_KIND__INTRACOMM or MPIR_COMM_KIND__INTERCOMM */
char name[MPI_MAX_OBJECT_NAME]; /* Required for MPI-2 */
MPIR_Errhandler *errhandler; /* Pointer to the error handler structure */
Expand Down
6 changes: 6 additions & 0 deletions src/mpi/comm/builtin_comms.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ int MPIR_init_comm_world(void)
MPIR_Process.comm_world->remote_size = MPIR_Process.size;
MPIR_Process.comm_world->local_size = MPIR_Process.size;

MPIR_Process.comm_world->local_group = MPIR_GROUP_WORLD_PTR;

mpi_errno = MPIR_Comm_commit(MPIR_Process.comm_world);
MPIR_ERR_CHECK(mpi_errno);

Expand Down Expand Up @@ -59,6 +61,8 @@ int MPIR_init_comm_self(void)
MPIR_Process.comm_self->remote_size = 1;
MPIR_Process.comm_self->local_size = 1;

MPIR_Process.comm_self->local_group = MPIR_GROUP_SELF_PTR;

mpi_errno = MPIR_Comm_commit(MPIR_Process.comm_self);
MPIR_ERR_CHECK(mpi_errno);

Expand Down Expand Up @@ -91,6 +95,8 @@ int MPIR_init_icomm_world(void)
MPIR_Process.icomm_world->remote_size = MPIR_Process.size;
MPIR_Process.icomm_world->local_size = MPIR_Process.size;

MPIR_Process.icomm_world->local_group = MPIR_GROUP_WORLD_PTR;

mpi_errno = MPIR_Comm_commit(MPIR_Process.icomm_world);
MPIR_ERR_CHECK(mpi_errno);

Expand Down
52 changes: 41 additions & 11 deletions src/mpi/comm/comm_impl.c
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,7 @@ int MPIR_Comm_create_intra(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co
(*newcomm_ptr)->local_group = group_ptr;
MPIR_Group_add_ref(group_ptr);

(*newcomm_ptr)->remote_group = group_ptr;
MPIR_Group_add_ref(group_ptr);
(*newcomm_ptr)->remote_group = NULL;
(*newcomm_ptr)->context_id = (*newcomm_ptr)->recvcontext_id;
(*newcomm_ptr)->remote_size = (*newcomm_ptr)->local_size = n;

Expand Down Expand Up @@ -381,16 +380,12 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co
{
int mpi_errno = MPI_SUCCESS;
int new_context_id;
int *mapping = NULL;
int *remote_mapping = NULL;
MPIR_Comm *mapping_comm = NULL;
int remote_size = -1;
int rinfo[2];
MPIR_CHKLMEM_DECL(1);

MPIR_FUNC_ENTER;

MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM);
MPIR_Session *session_ptr = comm_ptr->session_ptr;

/* Create a new communicator from the specified group members */

Expand All @@ -409,6 +404,8 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co
MPIR_Assert(new_context_id != 0);
MPIR_Assert(new_context_id != comm_ptr->recvcontext_id);

int *mapping; /* a list of local ranks */
MPIR_Comm *mapping_comm;
mpi_errno = MPII_Comm_create_calculate_mapping(group_ptr, comm_ptr, &mapping, &mapping_comm);
MPIR_ERR_CHECK(mpi_errno);

Expand All @@ -434,7 +431,7 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co

(*newcomm_ptr)->is_low_group = comm_ptr->is_low_group;

MPIR_Comm_set_session_ptr(*newcomm_ptr, comm_ptr->session_ptr);
MPIR_Comm_set_session_ptr(*newcomm_ptr, session_ptr);
}

/* There is an additional step. We must communicate the
Expand All @@ -445,6 +442,11 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co
* in the remote group, from which the remote network address
* mapping can be constructed. We need to use the "collective"
* context in the original intercommunicator */

int remote_size = -1;
int *remote_mapping; /* a list of remote ranks */
int rinfo[2];

if (comm_ptr->rank == 0) {
int info[2];
info[0] = new_context_id;
Expand Down Expand Up @@ -494,6 +496,24 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co

MPIR_Assert(remote_size >= 0);

/* create remote_group.
* FIXME: we can directly exchange group maps once we get rid of comm mappers */
MPIR_Group *remote_group;

MPIR_Lpid *remote_map;
remote_map = MPL_malloc(remote_size * sizeof(MPIR_Lpid), MPL_MEM_GROUP);
MPIR_ERR_CHKANDJUMP(!remote_map, mpi_errno, MPI_ERR_OTHER, "**nomem");

MPIR_Group *mapping_group = mapping_comm->remote_group;
MPIR_Assert(mapping_group);
for (int i = 0; i < remote_size; i++) {
remote_map[i] = MPIR_Group_rank_to_lpid(mapping_group, remote_mapping[i]);
}
mpi_errno = MPIR_Group_create_map(remote_size, MPI_UNDEFINED, session_ptr, remote_map,
&remote_group);
(*newcomm_ptr)->remote_group = remote_group;


if (group_ptr->rank != MPI_UNDEFINED) {
(*newcomm_ptr)->remote_size = remote_size;
/* Now, everyone has the remote_mapping, and can apply that to
Expand Down Expand Up @@ -605,8 +625,7 @@ int MPIR_Comm_create_group_impl(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, in
(*newcomm_ptr)->local_group = group_ptr;
MPIR_Group_add_ref(group_ptr);

(*newcomm_ptr)->remote_group = group_ptr;
MPIR_Group_add_ref(group_ptr);
(*newcomm_ptr)->remote_group = NULL;
(*newcomm_ptr)->context_id = (*newcomm_ptr)->recvcontext_id;
(*newcomm_ptr)->remote_size = (*newcomm_ptr)->local_size = n;

Expand Down Expand Up @@ -913,6 +932,9 @@ int MPIR_Comm_remote_group_impl(MPIR_Comm * comm_ptr, MPIR_Group ** group_ptr)
int mpi_errno = MPI_SUCCESS;
MPIR_FUNC_ENTER;

/* FIXME: remove the following remote_group creation once this assertion passes */
MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM && comm_ptr->remote_group);

/* Create a group and populate it with the local process ids */
if (!comm_ptr->remote_group) {
int n = comm_ptr->remote_size;
Expand Down Expand Up @@ -965,6 +987,7 @@ int MPIR_Intercomm_create_impl(MPIR_Comm * local_comm_ptr, int local_leader,
uint64_t *remote_lpids = NULL;
int comm_info[3];
int is_low_group = 0;
MPIR_Session *session_ptr = local_comm_ptr->session_ptr;

MPIR_FUNC_ENTER;

Expand Down Expand Up @@ -1042,7 +1065,14 @@ int MPIR_Intercomm_create_impl(MPIR_Comm * local_comm_ptr, int local_leader,
(*new_intercomm_ptr)->local_comm = 0;
(*new_intercomm_ptr)->is_low_group = is_low_group;

MPIR_Comm_set_session_ptr(*new_intercomm_ptr, local_comm_ptr->session_ptr);
(*new_intercomm_ptr)->local_group = local_comm_ptr->local_group;
MPIR_Group_add_ref(local_comm_ptr->local_group);

/* construct remote_group */
mpi_errno = MPIR_Group_create_map(remote_size, MPI_UNDEFINED, session_ptr, remote_lpids,
&(*new_intercomm_ptr)->remote_group);

MPIR_Comm_set_session_ptr(*new_intercomm_ptr, session_ptr);

mpi_errno = MPID_Create_intercomm_from_lpids(*new_intercomm_ptr, remote_size, remote_lpids);
if (mpi_errno)
Expand Down
13 changes: 13 additions & 0 deletions src/mpi/comm/comm_split.c
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm **
(*newcomm_ptr)->rank = i;
}

mpi_errno = MPIR_Group_incl_impl(comm_ptr->local_group, new_size, mapper->src_mapping,
&(*newcomm_ptr)->local_group);
MPIR_ERR_CHECK(mpi_errno);

/* For the remote group, the situation is more complicated.
* We need to find the size of our "partner" group in the
* remote comm. The easiest way (in terms of code) is for
Expand All @@ -313,6 +317,11 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm **
for (i = 0; i < new_remote_size; i++)
mapper->src_mapping[i] = remotekeytable[i].color;

mpi_errno = MPIR_Group_incl_impl(comm_ptr->remote_group,
new_remote_size, mapper->src_mapping,
&(*newcomm_ptr)->remote_group);
MPIR_ERR_CHECK(mpi_errno);

(*newcomm_ptr)->context_id = remote_context_id;
(*newcomm_ptr)->remote_size = new_remote_size;
(*newcomm_ptr)->local_comm = 0;
Expand All @@ -331,6 +340,10 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm **
if (keytable[i].color == comm_ptr->rank)
(*newcomm_ptr)->rank = i;
}

mpi_errno = MPIR_Group_incl_impl(comm_ptr->local_group, new_size, mapper->src_mapping,
&(*newcomm_ptr)->local_group);
MPIR_ERR_CHECK(mpi_errno);
}

/* Inherit the error handler (if any) */
Expand Down
34 changes: 34 additions & 0 deletions src/mpi/comm/commutil.c
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ int MPII_Setup_intercomm_localcomm(MPIR_Comm * intercomm_ptr)
mpi_errno = MPII_Comm_init(localcomm_ptr);
MPIR_ERR_CHECK(mpi_errno);

MPIR_Assert(intercomm_ptr->local_group);
localcomm_ptr->local_group = intercomm_ptr->local_group;
MPIR_Group_add_ref(intercomm_ptr->local_group);

MPIR_Comm_set_session_ptr(localcomm_ptr, intercomm_ptr->session_ptr);

/* use the parent intercomm's recv ctx as the basis for our ctx */
Expand Down Expand Up @@ -687,6 +691,14 @@ int MPIR_Comm_create_subcomms(MPIR_Comm * comm)
/* Copy relevant hints to node_comm */
propagate_hints_to_subcomm(comm, comm->node_comm);

/* construct local_group */
MPIR_Group *parent_group = comm->local_group;
MPIR_Assert(parent_group);
mpi_errno = MPIR_Group_incl_impl(parent_group, num_local, local_procs,
&comm->node_comm->local_group);
MPIR_ERR_CHECK(mpi_errno);

/* mapper */
MPIR_Comm_map_irregular(comm->node_comm, comm, local_procs, num_local,
MPIR_COMM_MAP_DIR__L2L, NULL);
mpi_errno = MPIR_Comm_commit_internal(comm->node_comm);
Expand Down Expand Up @@ -714,6 +726,14 @@ int MPIR_Comm_create_subcomms(MPIR_Comm * comm)
/* Copy relevant hints to node_roots_comm */
propagate_hints_to_subcomm(comm, comm->node_roots_comm);

/* construct local_group */
MPIR_Group *parent_group = comm->local_group;
MPIR_Assert(parent_group);
mpi_errno = MPIR_Group_incl_impl(parent_group, num_external, external_procs,
&comm->node_roots_comm->local_group);
MPIR_ERR_CHECK(mpi_errno);

/* mapper */
MPIR_Comm_map_irregular(comm->node_roots_comm, comm, external_procs, num_external,
MPIR_COMM_MAP_DIR__L2L, NULL);
mpi_errno = MPIR_Comm_commit_internal(comm->node_roots_comm);
Expand Down Expand Up @@ -961,6 +981,13 @@ int MPII_Comm_copy(MPIR_Comm * comm_ptr, int size, MPIR_Info * info, MPIR_Comm *
newcomm_ptr->comm_kind = comm_ptr->comm_kind;
newcomm_ptr->local_comm = 0;

newcomm_ptr->local_group = comm_ptr->local_group;
MPIR_Group_add_ref(comm_ptr->local_group);
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) {
newcomm_ptr->remote_group = comm_ptr->remote_group;
MPIR_Group_add_ref(comm_ptr->remote_group);
}

MPIR_Comm_set_session_ptr(newcomm_ptr, comm_ptr->session_ptr);

/* There are two cases here - size is the same as the old communicator,
Expand Down Expand Up @@ -1059,6 +1086,13 @@ int MPII_Comm_copy_data(MPIR_Comm * comm_ptr, MPIR_Info * info, MPIR_Comm ** out
newcomm_ptr->comm_kind = comm_ptr->comm_kind;
newcomm_ptr->local_comm = 0;

newcomm_ptr->local_group = comm_ptr->local_group;
MPIR_Group_add_ref(comm_ptr->local_group);
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) {
newcomm_ptr->remote_group = comm_ptr->remote_group;
MPIR_Group_add_ref(comm_ptr->remote_group);
}

if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM)
MPIR_Comm_map_dup(newcomm_ptr, comm_ptr, MPIR_COMM_MAP_DIR__L2L);
else
Expand Down
24 changes: 22 additions & 2 deletions src/mpid/ch3/src/ch3u_port.c
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,13 @@ static int MPIDI_CH3I_Initialize_tmp_comm(MPIR_Comm **comm_pptr,

MPIR_Coll_comm_init(tmp_comm);

MPIR_Lpid local_lpid = tmp_comm->dev.local_vcrt->vcr_table[0]->lpid;
MPIR_Lpid remote_lpid = tmp_comm->dev.vcrt->vcr_table[0]->lpid;
mpi_errno = MPIR_Group_create_stride(1, 0, commself_ptr->session_ptr, local_lpid, 1, 1,
&tmp_comm->local_group);
mpi_errno = MPIR_Group_create_stride(1, 0, commself_ptr->session_ptr, remote_lpid, 1, 1,
&tmp_comm->remote_group);

/* Even though this is a tmp comm and we don't call
MPI_Comm_commit, we still need to call the creation hook
because the destruction hook will be called in comm_release */
Expand Down Expand Up @@ -1337,8 +1344,6 @@ static int SetupNewIntercomm( MPIR_Comm *comm_ptr, int remote_comm_size,
intercomm->remote_size = remote_comm_size;
intercomm->local_size = comm_ptr->local_size;
intercomm->rank = comm_ptr->rank;
intercomm->local_group = NULL;
intercomm->remote_group = NULL;
intercomm->comm_kind = MPIR_COMM_KIND__INTERCOMM;
intercomm->local_comm = NULL;

Expand All @@ -1356,6 +1361,21 @@ static int SetupNewIntercomm( MPIR_Comm *comm_ptr, int remote_comm_size,
remote_translation[i].pg_rank, &intercomm->dev.vcrt->vcr_table[i]);
}

intercomm->local_group = comm_ptr->local_group;
MPIR_Group_add_ref(comm_ptr->local_group);

MPIR_Lpid *remote_map;
remote_map = MPL_malloc(remote_comm_size * sizeof(MPIR_Lpid), MPL_MEM_GROUP);
MPIR_ERR_CHKANDJUMP(!remote_map, mpi_errno, MPI_ERR_OTHER, "**nomem");
for (i=0; i < intercomm->remote_size; i++) {
MPIDI_PG_t *pg = remote_pg[remote_translation[i].pg_index];
int rank = remote_translation[i].pg_rank;
remote_map[i] = pg->vct[rank].lpid;
}
mpi_errno = MPIR_Group_create_map(remote_comm_size, MPI_UNDEFINED, comm_ptr->session_ptr,
remote_map, &intercomm->remote_group);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIR_Comm_commit(intercomm);
MPIR_ERR_CHECK(mpi_errno);

Expand Down
4 changes: 4 additions & 0 deletions src/mpid/ch4/src/ch4_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,10 @@ int MPIDI_Comm_create_multi_leaders(MPIR_Comm * comm)
MPIR_Comm_map_irregular(MPIDI_COMM(comm, multi_leads_comm), comm,
external_procs, num_external, MPIR_COMM_MAP_DIR__L2L, NULL);

mpi_errno = MPIR_Group_incl_impl(comm->local_group, num_external, external_procs,
&MPIDI_COMM(comm, multi_leads_comm)->local_group);
MPIR_ERR_CHECK(mpi_errno);

/* Notify device of communicator creation */
mpi_errno = MPID_Comm_commit_pre_hook(MPIDI_COMM(comm, multi_leads_comm));
if (mpi_errno)
Expand Down
15 changes: 13 additions & 2 deletions src/mpid/ch4/src/init_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ int MPIDI_create_init_comm(MPIR_Comm ** comm)
init_comm->remote_size = node_roots_comm_size;
init_comm->local_size = node_roots_comm_size;
init_comm->coll.pof2 = MPL_pof2(node_roots_comm_size);

MPIR_Lpid *map;
map = MPL_malloc(node_roots_comm_size * sizeof(MPIR_Lpid), MPL_MEM_GROUP);
MPIR_ERR_CHKANDJUMP(!map, mpi_errno, MPI_ERR_OTHER, "**nomem");
for (i = 0; i < node_roots_comm_size; ++i) {
map[i] = MPIR_Process.node_root_map[i];
}
mpi_errno = MPIR_Group_create_map(node_roots_comm_size, node_roots_comm_rank, NULL,
map, &init_comm->local_group);
MPIR_ERR_CHECK(mpi_errno);

MPIDI_COMM(init_comm, map).mode = MPIDI_RANK_MAP_LUT_INTRA;
mpi_errno = MPIDIU_alloc_lut(&lut, node_roots_comm_size);
MPIR_ERR_CHECK(mpi_errno);
Expand All @@ -47,8 +58,8 @@ int MPIDI_create_init_comm(MPIR_Comm ** comm)
mpi_errno = MPIDIG_init_comm(init_comm);
MPIR_ERR_CHECK(mpi_errno);
/* hacky, consider a separate MPIDI_{NM,SHM}_init_comm_hook
* to initialize the init_comm, e.g. to eliminate potential
* runtime features for stability during init */
* to initialize the init_comm, e.g. to eliminate potential
* runtime features for stability during init */
mpi_errno = MPIDI_NM_mpi_comm_commit_pre_hook(init_comm);
MPIR_ERR_CHECK(mpi_errno);

Expand Down

0 comments on commit 0cf5832

Please sign in to comment.