Skip to content

Commit

Permalink
ch3: use group to build vcrt instead of mapper
Browse files Browse the repository at this point in the history
Replace the usage of mapper with comm->local_group and
comm->remote_group in MPIDI_CH3I_Comm_commit_pre_hook.
  • Loading branch information
hzhou committed Dec 20, 2024
1 parent 749c30d commit 9cc9027
Showing 1 changed file with 75 additions and 147 deletions.
222 changes: 75 additions & 147 deletions src/mpid/ch3/src/ch3u_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -111,77 +111,71 @@ int MPIDI_CH3I_Comm_init(void)
goto fn_exit;
}


static void dup_vcrt(struct MPIDI_VCRT *src_vcrt, struct MPIDI_VCRT **dest_vcrt,
MPIR_Comm_map_t *mapper, int src_comm_size, int vcrt_size,
int vcrt_offset)
static int create_vcrt_from_group(MPIR_Group *group, struct MPIDI_VCRT **vcrt_out)
{
int flag, i;

/* try to find the simple case where the new comm is a simple
* duplicate of the previous comm. in that case, we simply add a
* reference to the previous VCRT instead of recreating it. */
if (mapper->type == MPIR_COMM_MAP_TYPE__DUP && src_comm_size == vcrt_size) {
*dest_vcrt = src_vcrt;
MPIDI_VCRT_Add_ref(src_vcrt);
return;
}
else if (mapper->type == MPIR_COMM_MAP_TYPE__IRREGULAR &&
mapper->src_mapping_size == vcrt_size) {
/* if the mapping array is exactly the same as the original
* comm's VC list, there is no need to create a new VCRT.
* instead simply point to the original comm's VCRT and bump
* up it's reference count */
flag = 1;
for (i = 0; i < mapper->src_mapping_size; i++)
if (mapper->src_mapping[i] != i)
flag = 0;
int mpi_errno = MPI_SUCCESS;

if (flag) {
*dest_vcrt = src_vcrt;
MPIDI_VCRT_Add_ref(src_vcrt);
return;
}
if (group->ch3_vcrt) {
MPIDI_VCRT_Add_ref(group->ch3_vcrt);
*vcrt_out = group->ch3_vcrt;
goto fn_exit;
}

/* we are in the more complex case where we need to allocate a new
* VCRT */
struct MPIDI_VCRT *vcrt;
mpi_errno = MPIDI_VCRT_Create(group->size, &vcrt);
MPIR_ERR_CHECK(mpi_errno);

if (!vcrt_offset)
MPIDI_VCRT_Create(vcrt_size, dest_vcrt);
*vcrt_out = vcrt;

if (mapper->type == MPIR_COMM_MAP_TYPE__DUP) {
for (i = 0; i < src_comm_size; i++)
MPIDI_VCR_Dup(src_vcrt->vcr_table[i],
&((*dest_vcrt)->vcr_table[i + vcrt_offset]));
}
else {
for (i = 0; i < mapper->src_mapping_size; i++)
MPIDI_VCR_Dup(src_vcrt->vcr_table[mapper->src_mapping[i]],
&((*dest_vcrt)->vcr_table[i + vcrt_offset]));
for (int i = 0; i < group->size; i++) {
MPIR_Lpid lpid = MPIR_Group_rank_to_lpid(group, i);
/* Currently ch3 does not synchronize pg with MPIR_worlds. All lpid are contiguous
* with world_idx = 0. We can tell whether it is a spawned process by checking whether
* it is >= world size.
*/
if (lpid < MPIR_Process.size) {
MPIDI_VCR_Dup(&MPIDI_Process.my_pg->vct[lpid], &vcrt->vcr_table[i]);
} else {
/* search PGs to find the vc. Not particularly efficient, but likely not critical */
MPIDI_PG_iterator iter;
MPIDI_PG_Get_iterator(&iter);
bool found_it = false;
while (MPIDI_PG_Has_next(&iter)) {
MPIDI_PG_t *pg;
MPIDI_PG_Get_next(&iter, &pg);
for (int j = 0; j < pg->size; j++) {
if (pg->vct[j].lpid == lpid) {
MPIDI_VCR_Dup(&pg->vct[j], &vcrt->vcr_table[i]);
found_it = true;
break;
}
}
if (found_it) {
break;
}
pg = pg->next;
}
MPIR_Assert(found_it);
}
}
}

static inline int map_size(MPIR_Comm_map_t map)
{
if (map.type == MPIR_COMM_MAP_TYPE__IRREGULAR)
return map.src_mapping_size;
else if (map.dir == MPIR_COMM_MAP_DIR__L2L || map.dir == MPIR_COMM_MAP_DIR__L2R)
return map.src_comm->local_size;
else
return map.src_comm->remote_size;
fn_exit:
return mpi_errno;
fn_fail:
goto fn_exit;

}

int MPIDI_CH3I_Comm_commit_pre_hook(MPIR_Comm *comm)
{
int mpi_errno = MPI_SUCCESS;
hook_elt *elt;
MPIR_Comm_map_t *mapper;
MPIR_Comm *src_comm;
int vcrt_size, vcrt_offset;

MPIR_FUNC_ENTER;

/* initialize the is_disconnected variable to FALSE. this will be
* set to TRUE if the communicator is freed by an
* MPI_COMM_DISCONNECT call. */
comm->dev.is_disconnected = 0;

if (comm == MPIR_Process.comm_world) {
comm->rank = MPIR_Process.rank;
comm->remote_size = MPIR_Process.size;
Expand All @@ -198,6 +192,7 @@ int MPIDI_CH3I_Comm_commit_pre_hook(MPIR_Comm *comm)
for (int p = 0; p < MPIR_Process.size; p++) {
MPIDI_VCR_Dup(&MPIDI_Process.my_pg->vct[p], &comm->dev.vcrt->vcr_table[p]);
}
goto done_vcrt;
} else if (comm == MPIR_Process.comm_self) {
comm->rank = 0;
comm->remote_size = 1;
Expand All @@ -211,111 +206,43 @@ int MPIDI_CH3I_Comm_commit_pre_hook(MPIR_Comm *comm)
}

MPIDI_VCR_Dup(&MPIDI_Process.my_pg->vct[MPIR_Process.rank], &comm->dev.vcrt->vcr_table[0]);
goto done_vcrt;
} else if (comm == MPIR_Process.icomm_world) {
comm->rank = MPIR_Process.rank;
comm->remote_size = MPIR_Process.size;
comm->local_size = MPIR_Process.size;

MPIDI_VCRT_Add_ref(MPIR_Process.comm_world->dev.vcrt );
comm->dev.vcrt = MPIR_Process.comm_world->dev.vcrt;
goto done_vcrt;
}

/* initialize the is_disconnected variable to FALSE. this will be
* set to TRUE if the communicator is freed by an
* MPI_COMM_DISCONNECT call. */
comm->dev.is_disconnected = 0;

/* do some sanity checks */
LL_FOREACH(comm->mapper_head, mapper) {
if (mapper->src_comm->comm_kind == MPIR_COMM_KIND__INTRACOMM)
MPIR_Assertp(mapper->dir == MPIR_COMM_MAP_DIR__L2L ||
mapper->dir == MPIR_COMM_MAP_DIR__L2R);
if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM)
MPIR_Assertp(mapper->dir == MPIR_COMM_MAP_DIR__L2L ||
mapper->dir == MPIR_COMM_MAP_DIR__R2L);
}

/* First, handle all the mappers that contribute to the local part
* of the comm */
vcrt_size = 0;
LL_FOREACH(comm->mapper_head, mapper) {
if (mapper->dir == MPIR_COMM_MAP_DIR__L2R ||
mapper->dir == MPIR_COMM_MAP_DIR__R2R)
continue;

vcrt_size += map_size(*mapper);
if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
mpi_errno = create_vcrt_from_group(comm->local_group, &comm->dev.vcrt);
MPIR_ERR_CHECK(mpi_errno);
} else {
mpi_errno = create_vcrt_from_group(comm->local_group, &comm->dev.local_vcrt);
MPIR_ERR_CHECK(mpi_errno);
mpi_errno = create_vcrt_from_group(comm->remote_group, &comm->dev.vcrt);
MPIR_ERR_CHECK(mpi_errno);
}
vcrt_offset = 0;
LL_FOREACH(comm->mapper_head, mapper) {
src_comm = mapper->src_comm;

if (mapper->dir == MPIR_COMM_MAP_DIR__L2R ||
mapper->dir == MPIR_COMM_MAP_DIR__R2R)
continue;

if (mapper->dir == MPIR_COMM_MAP_DIR__L2L) {
if (src_comm->comm_kind == MPIR_COMM_KIND__INTRACOMM && comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
dup_vcrt(src_comm->dev.vcrt, &comm->dev.vcrt, mapper, mapper->src_comm->local_size,
vcrt_size, vcrt_offset);
}
else if (src_comm->comm_kind == MPIR_COMM_KIND__INTRACOMM && comm->comm_kind == MPIR_COMM_KIND__INTERCOMM)
dup_vcrt(src_comm->dev.vcrt, &comm->dev.local_vcrt, mapper, mapper->src_comm->local_size,
vcrt_size, vcrt_offset);
else if (src_comm->comm_kind == MPIR_COMM_KIND__INTERCOMM && comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
dup_vcrt(src_comm->dev.local_vcrt, &comm->dev.vcrt, mapper, mapper->src_comm->local_size,
vcrt_size, vcrt_offset);
}
else
dup_vcrt(src_comm->dev.local_vcrt, &comm->dev.local_vcrt, mapper,
mapper->src_comm->local_size, vcrt_size, vcrt_offset);
done_vcrt:
/* add vcrt to the comm groups if they are not there */
if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
if (comm->local_group->ch3_vcrt == NULL) {
MPIDI_VCRT_Add_ref(comm->dev.vcrt);
comm->local_group->ch3_vcrt = comm->dev.vcrt;
}
else { /* mapper->dir == MPIR_COMM_MAP_DIR__R2L */
MPIR_Assert(src_comm->comm_kind == MPIR_COMM_KIND__INTERCOMM);
if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
dup_vcrt(src_comm->dev.vcrt, &comm->dev.vcrt, mapper, mapper->src_comm->remote_size,
vcrt_size, vcrt_offset);
}
else
dup_vcrt(src_comm->dev.vcrt, &comm->dev.local_vcrt, mapper, mapper->src_comm->remote_size,
vcrt_size, vcrt_offset);
}
vcrt_offset += map_size(*mapper);
}

/* Next, handle all the mappers that contribute to the remote part
* of the comm (only valid for intercomms) */
vcrt_size = 0;
LL_FOREACH(comm->mapper_head, mapper) {
if (mapper->dir == MPIR_COMM_MAP_DIR__L2L ||
mapper->dir == MPIR_COMM_MAP_DIR__R2L)
continue;

vcrt_size += map_size(*mapper);
}
vcrt_offset = 0;
LL_FOREACH(comm->mapper_head, mapper) {
src_comm = mapper->src_comm;

if (mapper->dir == MPIR_COMM_MAP_DIR__L2L ||
mapper->dir == MPIR_COMM_MAP_DIR__R2L)
continue;

MPIR_Assert(comm->comm_kind == MPIR_COMM_KIND__INTERCOMM);

if (mapper->dir == MPIR_COMM_MAP_DIR__L2R) {
if (src_comm->comm_kind == MPIR_COMM_KIND__INTRACOMM)
dup_vcrt(src_comm->dev.vcrt, &comm->dev.vcrt, mapper, mapper->src_comm->local_size,
vcrt_size, vcrt_offset);
else
dup_vcrt(src_comm->dev.local_vcrt, &comm->dev.vcrt, mapper,
mapper->src_comm->local_size, vcrt_size, vcrt_offset);
} else {
if (comm->local_group->ch3_vcrt == NULL) {
MPIDI_VCRT_Add_ref(comm->dev.local_vcrt);
comm->local_group->ch3_vcrt = comm->dev.local_vcrt;
}
else { /* mapper->dir == MPIR_COMM_MAP_DIR__R2R */
MPIR_Assert(src_comm->comm_kind == MPIR_COMM_KIND__INTERCOMM);
dup_vcrt(src_comm->dev.vcrt, &comm->dev.vcrt, mapper, mapper->src_comm->remote_size,
vcrt_size, vcrt_offset);
if (comm->remote_group->ch3_vcrt == NULL) {
MPIDI_VCRT_Add_ref(comm->dev.vcrt);
comm->remote_group->ch3_vcrt = comm->dev.vcrt;
}
vcrt_offset += map_size(*mapper);
}

if (comm->comm_kind == MPIR_COMM_KIND__INTERCOMM) {
Expand All @@ -326,6 +253,7 @@ int MPIDI_CH3I_Comm_commit_pre_hook(MPIR_Comm *comm)
}
}

hook_elt *elt;
LL_FOREACH(create_hooks_head, elt) {
mpi_errno = elt->hook_fn(comm, elt->param);
if (mpi_errno) MPIR_ERR_POP(mpi_errno);;
Expand Down

0 comments on commit 9cc9027

Please sign in to comment.