Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

coll: add coll_attr and comm subgroups #6590

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion maint/extracterrmsgs
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ sub ProcessFile
!($args[$errClassLoc] =~ /^MPIDI_CH3I_SOCK_ERR_/) &&
!($args[$errClassLoc] =~ /^MPIX_ERR_/) &&
!($args[$errClassLoc] =~ /^errclass/) &&
!($args[$errClassLoc] =~ /^errflag/) &&
!($args[$errClassLoc] =~ /^(errflag|coll_attr)/) &&
!($args[$errClassLoc] =~ /^\*errflag/)) {
$bad_syntax_in_file{$filename} = 1;
print STDERR "Invalid argument $args[$errClassLoc] for the MPI Error class in $routineName in $filename\n";
Expand Down
8 changes: 4 additions & 4 deletions maint/gen_coll.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def get_algo_args(args, algo, kind):
elif algo['func-commkind'].startswith('i'):
algo_args += ", *sched_p"
elif not algo['func-commkind'].startswith('neighbor_'):
algo_args += ", errflag"
algo_args += ", coll_attr"

return algo_args

Expand All @@ -666,7 +666,7 @@ def get_algo_params(params, algo):
elif algo['func-commkind'].startswith('i'):
algo_params += ", MPIR_Sched_t s"
elif not algo['func-commkind'].startswith('neighbor_'):
algo_params += ", MPIR_Errflag_t errflag"
algo_params += ", int coll_attr"

return algo_params

Expand All @@ -683,7 +683,7 @@ def get_func_params(params, name, kind):
func_params = params
if kind == "blocking":
if not name.startswith('neighbor_'):
func_params += ", MPIR_Errflag_t errflag"
func_params += ", int coll_attr"
elif kind == "nonblocking":
func_params += ", MPIR_Request ** request"
elif kind == "persistent":
Expand All @@ -703,7 +703,7 @@ def get_func_args(args, name, kind):
func_args = args
if kind == "blocking":
if not name.startswith('neighbor_'):
func_args += ", errflag"
func_args += ", coll_attr"
elif kind == "nonblocking":
func_args += ", request"
elif kind == "persistent":
Expand Down
2 changes: 1 addition & 1 deletion maint/local_python/binding_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -1691,7 +1691,7 @@ def push_impl_decl(func, impl_name=None):
if func['dir'] == 'coll':
# block collective use an extra errflag
if not RE.match(r'MPI_(I.*|Neighbor.*|.*_init)$', func['name']):
params = params + ", MPIR_Errflag_t errflag"
params = params + ", int coll_attr"
else:
params="void"

Expand Down
47 changes: 38 additions & 9 deletions src/include/mpir_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,40 @@
#include "coll_impl.h"
#include "coll_algos.h"

/* Define bit values for collective attributes. */

/* NOTE: the low 8 bits must be the same as pt2pt attr, ref. mpir_pt2pt.h */
#define MPIR_COLL_ATTR_GET_ERRFLAG(attr) ((attr) & 0x6)
#define MPIR_ERR_NONE 0
#define MPIR_ERR_PROC_FAILED 2
#define MPIR_ERR_OTHER 4

#define MPIR_ATTR_COLL_CONTEXT 1
#define MPIR_ATTR_SYNCFLAG 8

/* Subgroup is the index to comm->subgroups[], and resulting in group collectives.
* NOTE: cross reference MPIR_MAX_SUBGROUPS */
#define MPIR_COLL_ATTR_GET_SUBGROUP(attr) (((attr) & 0xf00) >> 8)
#define MPIR_COLL_ATTR_SUBGROUP(idx) (((idx) & 0xf) << 8)

#define MPIR_COLL_INTRA_RANK_SIZE(comm, coll_attr, rank_, size_) \
do { \
int grp = MPIR_COLL_ATTR_GET_SUBGROUP(coll_attr); \
if (grp == 0) { \
rank_ = (comm)->rank; \
size_ = (comm)->local_size; \
} else { \
rank_ = (comm)->subgroups[grp].rank; \
size_ = (comm)->subgroups[grp].size; \
} \
} while (0)

/* During init, not all algorithms are safe to use. For example, the csel
* may not have been initialized. We define a set of fallback routines that
* are safe to use during init. They are all intra algorithms.
*/
#define MPIR_Barrier_fallback MPIR_Barrier_intra_dissemination
#define MPIR_Bcast_fallback MPIR_Bcast_intra_binomial
#define MPIR_Allgather_fallback MPIR_Allgather_intra_brucks
#define MPIR_Allgatherv_fallback MPIR_Allgatherv_intra_brucks
#define MPIR_Allreduce_fallback MPIR_Allreduce_intra_recursive_doubling
Expand All @@ -28,31 +57,31 @@ int MPIC_Wait(MPIR_Request * request_ptr);
int MPIC_Probe(int source, int tag, MPI_Comm comm, MPI_Status * status);

int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, int coll_attr);
int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int tag,
MPIR_Comm * comm_ptr, MPI_Status * status);
MPIR_Comm * comm_ptr, int coll_attr, MPI_Status * status);
int MPIC_Ssend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, int coll_attr);
int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype,
int dest, int sendtag, void *recvbuf, MPI_Aint recvcount,
MPI_Datatype recvtype, int source, int recvtag,
MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, MPI_Status * status, int coll_attr);
int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int sendtag,
int source, int recvtag,
MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, MPI_Status * status, int coll_attr);
int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Request ** request, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, MPIR_Request ** request, int coll_attr);
int MPIC_Issend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Request ** request, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, MPIR_Request ** request, int coll_attr);
int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source,
int tag, MPIR_Comm * comm_ptr, MPIR_Request ** request);
int tag, MPIR_Comm * comm_ptr, int coll_attr, MPIR_Request ** request);
int MPIC_Waitall(int numreq, MPIR_Request * requests[], MPI_Status * statuses);

int MPIR_Reduce_local(const void *inbuf, void *inoutbuf, MPI_Aint count, MPI_Datatype datatype,
MPI_Op op);

int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_attr);

/* TSP auto */
int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count,
Expand Down
50 changes: 49 additions & 1 deletion src/include/mpir_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,51 @@ enum MPIR_COMM_HINT_PREDEFINED_t {
MPIR_COMM_HINT_PREDEFINED_COUNT
};

/* MPIR_Subgroup is similar to MPIR_Group, but only used to describe subgroups within
* an intra communicator. The proc_table refers to ranks within the communicator.
* It is only used internally for group collectives.
*/
enum MPIR_Subgoup_kind {
MPIR_SUBGROUP_SELF, /* Refers to the comm group itself */
MPIR_SUBGROUP_THREADCOMM, /* Supercomm includes threads in a threadcomm */
MPIR_SUBGROUP_NODE, /* i.e. nodecomm */
MPIR_SUBGROUP_NODE_CROSS, /* node_roots_comm, node_rank_1_comm, ... */
MPIR_SUBGROUP_NUMA1, /* 1-level below node in topology */
MPIR_SUBGROUP_NUMA1_CROSS, /* cross-link group at NUMA1 within NODE */
MPIR_SUBGROUP_NUMA2, /* and so on */
MPIR_SUBGROUP_NUMA2_CROSS,
MPIR_SUBGROUP_TEMP, /* Temporary anonymous group */
};

typedef struct MPIR_Subgroup {
enum MPIR_Subgoup_kind kind;
int size;
int rank;
int *proc_table; /* can be NULL if the group is trivial */
} MPIR_Subgroup;

/* NOTE: cross reference coll_attr bit patterns in mpir_coll.h */
#define MPIR_MAX_SUBGROUPS 16

#define MPIR_COMM_LAST_SUBGROUP(comm) ((comm)->num_subgroups - 1)

#define MPIR_COMM_NEW_SUBGROUP(comm, _kind, _size, _rank) \
do { \
int i = (comm)->num_subgroups++; \
MPIR_Assert((comm)->num_subgroups < MPIR_MAX_SUBGROUPS); \
(comm)->subgroups[i].kind = _kind; \
(comm)->subgroups[i].size = _size; \
(comm)->subgroups[i].rank = _rank; \
(comm)->subgroups[i].proc_table = NULL; \
} while (0)

#define MPIR_COMM_POP_SUBGROUP(comm) \
do { \
int i = --(comm)->num_subgroups; \
MPIR_Assert(i > 0); \
MPL_free((comm)->subgroups[i].proc_table); \
} while (0)

/*S
MPIR_Comm - Description of the Communicator data structure

Expand Down Expand Up @@ -187,7 +232,8 @@ struct MPIR_Comm {
int *internode_table; /* internode_table[i] gives the rank in
* node_roots_comm of rank i in this comm.
* It is of size 'local_size'. */
int node_count; /* number of nodes this comm is spread over */
int num_local; /* number of procs in this comm on local node */
int num_external; /* number of nodes this comm is spread over */

int is_low_group; /* For intercomms only, this boolean is
* set for all members of one of the
Expand All @@ -196,6 +242,8 @@ struct MPIR_Comm {
* intercommunicator collective operations
* that wish to use half-duplex operations
* to implement a full-duplex operation */
MPIR_Subgroup subgroups[MPIR_MAX_SUBGROUPS];
int num_subgroups;

struct MPIR_Comm *comm_next; /* Provides a chain through all active
* communicators */
Expand Down
10 changes: 0 additions & 10 deletions src/include/mpir_misc.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,6 @@
#define MPIR_FINALIZE_CALLBACK_DEFAULT_PRIO 0
#define MPIR_FINALIZE_CALLBACK_MAX_PRIO 10

/* Define a typedef for the errflag value used by many internal
* functions. If an error needs to be returned, these values can be
* used to signal such. More details can be found further down in the
* code with the bitmasking logic */
typedef enum {
MPIR_ERR_NONE = MPI_SUCCESS,
MPIR_ERR_PROC_FAILED = MPIX_ERR_PROC_FAILED,
MPIR_ERR_OTHER = MPI_ERR_OTHER
} MPIR_Errflag_t;

/*E
MPIR_Lang_t - Known language bindings for MPI

Expand Down
14 changes: 4 additions & 10 deletions src/include/mpir_pt2pt.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
* 24-31: reserved (must be 0)
*/
/* NOTE: All explicit vci (allocated) must be greater than 0 */
/* NOTE: MPIR_ERR_XXX flags defined in mpir_misc.h and must be
* consistent, i.e. 0x2 and 0x4 respectively. */

#define MPIR_PT2PT_ATTR_SRC_VCI_SHIFT 8
#define MPIR_PT2PT_ATTR_DST_VCI_SHIFT 16
Expand All @@ -33,19 +35,11 @@
#define MPIR_PT2PT_ATTR_SET_CONTEXT_OFFSET(attr, context_offset) (attr) |= (context_offset)

/* bit 1-2: errflag */
#define MPIR_PT2PT_ATTR_GET_ERRFLAG(attr) \
((!((attr) & 0x6)) ? MPIR_ERR_NONE : \
(((attr) & 0x2) ? MPIX_ERR_PROC_FAILED : MPI_ERR_OTHER))
#define MPIR_PT2PT_ATTR_GET_ERRFLAG(attr) ((attr) & 0x6)

#define MPIR_PT2PT_ATTR_SET_ERRFLAG(attr, errflag) \
do { \
if (errflag) { \
if (errflag == MPIR_ERR_PROC_FAILED) { \
(attr) |= 0x2; \
} else { \
(attr) |= 0x4; \
} \
} \
(attr) |= (errflag); \
} while (0)

/* bit 3: syncflag */
Expand Down
4 changes: 2 additions & 2 deletions src/include/mpir_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ struct MPIR_Request {
struct MPIR_Grequest_fns *greq_fns;
} ureq; /* kind : MPIR_REQUEST_KIND__GREQUEST */
struct {
MPIR_Errflag_t errflag;
int coll_attr;
MPII_Coll_req_t coll;
} nbc; /* kind : MPIR_REQUEST_KIND__COLL */
struct {
Expand Down Expand Up @@ -429,7 +429,7 @@ static inline MPIR_Request *MPIR_Request_create_from_pool(MPIR_Request_kind_t ki

switch (kind) {
case MPIR_REQUEST_KIND__COLL:
req->u.nbc.errflag = MPIR_ERR_NONE;
req->u.nbc.coll_attr = 0;
req->u.nbc.coll.host_sendbuf = NULL;
req->u.nbc.coll.host_recvbuf = NULL;
req->u.nbc.coll.datatype = MPI_DATATYPE_NULL;
Expand Down
11 changes: 5 additions & 6 deletions src/include/mpir_threadcomm.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,23 +111,22 @@ MPL_STATIC_INLINE_PREFIX
}

#ifdef ENABLE_THREADCOMM
#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \
#define MPIR_THREADCOMM_RANK_SIZE(comm, coll_attr, rank_, size_) do { \
MPIR_Threadcomm *threadcomm = (comm)->threadcomm; \
if (threadcomm) { \
MPIR_Assert(MPIR_COLL_ATTR_GET_SUBGROUP(coll_attr) == 0); /* for now */ \
int intracomm_size = (comm)->local_size; \
size_ = threadcomm->rank_offset_table[intracomm_size - 1]; \
rank_ = MPIR_THREADCOMM_TID_TO_RANK(threadcomm, MPIR_threadcomm_get_tid(threadcomm)); \
} else { \
rank_ = (comm)->rank; \
size_ = (comm)->local_size; \
MPIR_COLL_INTRA_RANK_SIZE(comm, coll_attr, rank_, size_); \
} \
} while (0)

#else
#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \
#define MPIR_THREADCOMM_RANK_SIZE(comm, coll_attr, rank_, size_) do { \
MPIR_Assert((comm)->threadcomm == NULL); \
rank_ = (comm)->rank; \
size_ = (comm)->local_size; \
MPIR_COLL_INTRA_RANK_SIZE(comm, coll_attr, rank_, size_); \
} while (0)

#endif
Expand Down
6 changes: 3 additions & 3 deletions src/mpi/coll/algorithms/treealgo/treeutil.c
Original file line number Diff line number Diff line change
Expand Up @@ -756,9 +756,9 @@ int MPII_Treeutil_tree_topology_aware_k_init(MPIR_Comm * comm, int k, int root,
} else {
/* rank level - build a tree on the ranks */
/* Do an allgather to know the current num_children on each rank */
MPIR_Errflag_t errflag = MPIR_ERR_NONE;
int coll_attr = MPIR_ERR_NONE;
MPIR_Allgather_impl(&(ct->num_children), 1, MPI_INT, num_childrens, 1, MPI_INT,
comm, errflag);
comm, coll_attr);
if (mpi_errno) {
goto fn_fail;
}
Expand Down Expand Up @@ -1129,7 +1129,7 @@ int MPII_Treeutil_tree_topology_wave_init(MPIR_Comm * comm, int k, int root, boo
heap_vector minHeaps;
heap_vector_init(&minHeaps);

/* To build hierarchy of ranks, swiches and groups */
/* To build hierarchy of ranks, switches and groups */
int dim = MPIR_Process.coords_dims - 1;
for (dim = MPIR_Process.coords_dims - 1; dim >= 0; --dim)
tree_ut_hierarchy_init(&hierarchy[dim]);
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/allgather/allgather_allcomm_nb.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

int MPIR_Allgather_allcomm_nb(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype,
void *recvbuf, MPI_Aint recvcount, MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPIR_Comm * comm_ptr, int coll_attr)
{
int mpi_errno = MPI_SUCCESS;
MPIR_Request *req_ptr = NULL;
Expand Down
22 changes: 11 additions & 11 deletions src/mpi/coll/allgather/allgather_inter_local_gather_remote_bcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint sendcount,
MPI_Datatype sendtype, void *recvbuf,
MPI_Aint recvcount, MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
MPIR_Comm * comm_ptr, int coll_attr)
{
int rank, local_size, remote_size, mpi_errno = MPI_SUCCESS, root;
int mpi_errno_ret = MPI_SUCCESS;
Expand Down Expand Up @@ -48,8 +48,8 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint

if (sendcount != 0) {
mpi_errno = MPIR_Gather(sendbuf, sendcount, sendtype, tmp_buf, sendcount * sendtype_sz,
MPI_BYTE, 0, newcomm_ptr, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
MPI_BYTE, 0, newcomm_ptr, coll_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}

/* first broadcast from left to right group, then from right to
Expand All @@ -59,32 +59,32 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint
if (sendcount != 0) {
root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size,
MPI_BYTE, root, comm_ptr, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
MPI_BYTE, root, comm_ptr, coll_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}

/* receive bcast from right */
if (recvcount != 0) {
root = 0;
mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size,
recvtype, root, comm_ptr, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
recvtype, root, comm_ptr, coll_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}
} else {
/* receive bcast from left */
if (recvcount != 0) {
root = 0;
mpi_errno = MPIR_Bcast(recvbuf, recvcount * remote_size,
recvtype, root, comm_ptr, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
recvtype, root, comm_ptr, coll_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}

/* bcast to left */
if (sendcount != 0) {
root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
mpi_errno = MPIR_Bcast(tmp_buf, sendcount * sendtype_sz * local_size,
MPI_BYTE, root, comm_ptr, errflag);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, errflag, mpi_errno_ret);
MPI_BYTE, root, comm_ptr, coll_attr);
MPIR_ERR_COLL_CHECKANDCONT(mpi_errno, coll_attr, mpi_errno_ret);
}
}

Expand Down
Loading