Skip to content

MPI 4.0: Allow MPI_WIN_SHARED_QUERY on regular windows #13330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion ompi/mca/osc/rdma/osc_rdma_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
#include "ompi/mca/bml/base/base.h"
#include "ompi/mca/mtl/base/base.h"

static int ompi_osc_rdma_shared_query(struct ompi_win_t *win, int rank, size_t *size,
ptrdiff_t *disp_unit, void *baseptr);
static int ompi_osc_rdma_component_register (void);
static int ompi_osc_rdma_component_init (bool enable_progress_threads, bool enable_mpi_threads);
static int ompi_osc_rdma_component_finalize (void);
Expand Down Expand Up @@ -113,6 +115,7 @@ ompi_osc_rdma_component_t mca_osc_rdma_component = {
MCA_BASE_COMPONENT_INIT(ompi, osc, rdma)

ompi_osc_base_module_t ompi_osc_rdma_module_rdma_template = {
.osc_win_shared_query = ompi_osc_rdma_shared_query,
.osc_win_attach = ompi_osc_rdma_attach,
.osc_win_detach = ompi_osc_rdma_detach,
.osc_free = ompi_osc_rdma_free,
Expand Down Expand Up @@ -898,7 +901,7 @@ static void ompi_osc_rdma_ensure_local_add_procs (void)
/* this will cause add_proc to get called if it has not already been called */
(void) mca_bml_base_get_endpoint (proc);
}
}
}

free(procs);
}
Expand Down Expand Up @@ -1632,3 +1635,60 @@ ompi_osc_rdma_set_no_lock_info(opal_infosubscriber_t *obj, const char *key, cons
*/
return module->no_locks ? "true" : "false";
}

int ompi_osc_rdma_shared_query(
struct ompi_win_t *win, int rank, size_t *size,
ptrdiff_t *disp_unit, void *baseptr)
{
int rc = OMPI_ERR_NOT_SUPPORTED;
ompi_osc_rdma_peer_t *peer;
int actual_rank = rank;
ompi_osc_rdma_module_t *module = GET_MODULE(win);

peer = ompi_osc_rdma_module_peer (module, actual_rank);
if (NULL == peer) {
return OMPI_ERR_BAD_PARAM;
}

/* currently only supported for allocated windows */
if (MPI_WIN_FLAVOR_ALLOCATE != module->flavor) {
return OMPI_ERR_NOT_SUPPORTED;
}

if (!ompi_osc_rdma_peer_local_base(peer)) {
return OMPI_ERR_NOT_SUPPORTED;
}

if (MPI_PROC_NULL == rank) {
/* iterate until we find a rank that has a non-zero size */
for (int i = 0 ; i < ompi_comm_size(module->comm) ; ++i) {
peer = ompi_osc_rdma_module_peer (module, i);
ompi_osc_rdma_peer_extended_t *ex_peer = (ompi_osc_rdma_peer_extended_t *) peer;
if (!ompi_osc_rdma_peer_local_base(peer)) {
continue;
} else if (module->same_size && ex_peer->super.base) {
break;
} else if (ex_peer->size > 0) {
break;
}
}
}

if (module->same_size && module->same_disp_unit) {
*size = module->size;
*disp_unit = module->disp_unit;
ompi_osc_rdma_peer_basic_t *ex_peer = (ompi_osc_rdma_peer_basic_t *) peer;
*((void**) baseptr) = (void *) (intptr_t)ex_peer->base;
rc = OMPI_SUCCESS;
} else {
ompi_osc_rdma_peer_extended_t *ex_peer = (ompi_osc_rdma_peer_extended_t *) peer;
if (ex_peer->super.base != 0) {
/* we know the base of the peer */
*((void**) baseptr) = (void *) (intptr_t)ex_peer->super.base;
*size = ex_peer->size;
*disp_unit = ex_peer->disp_unit;
rc = OMPI_SUCCESS;
}
}
return rc;
}
4 changes: 0 additions & 4 deletions ompi/mca/osc/sm/osc_sm_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -481,10 +481,6 @@ ompi_osc_sm_shared_query(struct ompi_win_t *win, int rank, size_t *size, ptrdiff
ompi_osc_sm_module_t *module =
(ompi_osc_sm_module_t*) win->w_osc_module;

if (module->flavor != MPI_WIN_FLAVOR_SHARED) {
return MPI_ERR_WIN;
}

if (MPI_PROC_NULL != rank) {
*size = module->sizes[rank];
*((void**) baseptr) = module->bases[rank];
Expand Down
14 changes: 14 additions & 0 deletions ompi/mca/osc/ucx/osc_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ typedef struct ompi_osc_ucx_mem_ranges {
uint64_t tail;
} ompi_osc_ucx_mem_ranges_t;

/**
* Structure to hold information about shared memory regions.
* We store the rank, it's address, and the size of the window region.
* We don't store the disp_unit here, as that is stored elsewhere already.
*/
struct ompi_osc_ucx_shmem_info_s {
int peer; /* rank of the peer this information belongs to */
char *addr; /* address of the shared memory region */
size_t size; /* size of the shared memory region */
};

typedef struct ompi_osc_ucx_shmem_info_s ompi_osc_ucx_shmem_info_t;

typedef struct ompi_osc_ucx_module {
ompi_osc_base_module_t super;
struct ompi_communicator_t *comm;
Expand All @@ -128,6 +141,7 @@ typedef struct ompi_osc_ucx_module {
* disp unit size; if disp_unit == -1, then we
* need to look at disp_units */
ptrdiff_t *disp_units;
ompi_osc_ucx_shmem_info_t *shmem_info; /* shared memory info */

ompi_osc_ucx_state_t state; /* remote accessible flags */
ompi_osc_local_dynamic_win_info_t local_dynamic_win_info[OMPI_OSC_UCX_ATTACH_MAX];
Expand Down
73 changes: 66 additions & 7 deletions ompi/mca/osc/ucx/osc_ucx_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -468,30 +468,70 @@ static const char* ompi_osc_ucx_set_no_lock_info(opal_infosubscriber_t *obj, con
return module->no_locks ? "true" : "false";
}

static int ompi_osc_ucx_shared_query_peer(ompi_osc_ucx_module_t *module, int peer, size_t *size,
ptrdiff_t *disp_unit, void *baseptr) {

int rc;
ucp_ep_h *dflt_ep;
ucp_ep_h ep; // ignored
ucp_rkey_h rkey;
OSC_UCX_GET_DEFAULT_EP(dflt_ep, module, peer);
ucs_status_t status;
opal_common_ucx_winfo_t *winfo; // ignored
rc = opal_common_ucx_tlocal_fetch(module->mem, peer, &ep, &rkey, &winfo, dflt_ep);
if (OMPI_SUCCESS != rc) {
return rc;
}
uint64_t raddr;
void *addr_p;
if (UCS_OK != ucp_rkey_ptr(rkey, module->addrs[peer], &addr_p)) {
return OMPI_ERR_NOT_AVAILABLE;
}
*size = module->sizes[peer];
*((void**) baseptr) = (void *)module->shmem_addrs[peer];
*disp_unit = module->disp_units[peer];

return OMPI_SUCCESS;
}

int ompi_osc_ucx_shared_query(struct ompi_win_t *win, int rank, size_t *size,
ptrdiff_t *disp_unit, void *baseptr)
{
ompi_osc_ucx_module_t *module =
(ompi_osc_ucx_module_t*) win->w_osc_module;

*size = 0;
*((void**) baseptr) = NULL;
*disp_unit = 0;

if (module->flavor != MPI_WIN_FLAVOR_SHARED) {
return MPI_ERR_WIN;
}

if (MPI_PROC_NULL != rank) {
if (MPI_PROC_NULL == rank) {
for (int i = 0 ; i < ompi_comm_size(module->comm) ; ++i) {
if (0 != module->sizes[i]) {
if (OMPI_SUCCESS == ompi_osc_ucx_shared_query_peer(module, i, size, disp_unit, baseptr)) {
return OMPI_SUCCESS;
}
}
}
} else {
if (0 != module->sizes[rank]) {
return ompi_osc_ucx_shared_query_peer(module, rank, size, disp_unit, baseptr);
}
}
return OMPI_ERR_NOT_SUPPORTED;

} else if (MPI_PROC_NULL != rank) { // shared memory window with given rank
*size = module->sizes[rank];
*((void**) baseptr) = (void *)module->shmem_addrs[rank];
if (module->disp_unit == -1) {
*disp_unit = module->disp_units[rank];
} else {
*disp_unit = module->disp_unit;
}
} else {
} else { // shared memory window with MPI_PROC_NULL
int i = 0;

*size = 0;
*((void**) baseptr) = NULL;
*disp_unit = 0;
for (i = 0 ; i < ompi_comm_size(module->comm) ; ++i) {
if (0 != module->sizes[i]) {
*size = module->sizes[i];
Expand Down Expand Up @@ -639,6 +679,13 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
module->acc_single_intrinsic = check_config_value_bool ("acc_single_intrinsic", info);
module->skip_sync_check = false;

/**
* TODO: we need to collect the shared memory information from all processes
* on the same node. This includes the size and base address, which needs
* to be passed to ucp_rkey_ptr().
*/
module->shmem_info = NULL;

/* share everyone's displacement units. Only do an allgather if
strictly necessary, since it requires O(p) state. */
values[0] = disp_unit;
Expand Down Expand Up @@ -807,6 +854,18 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt

module->size = module->sizes[ompi_comm_rank(module->comm)];
*base = (void *)module->shmem_addrs[ompi_comm_rank(module->comm)];
} else {
/* non-shared memory: exchange sizes and addresses so they can be queried for shared memory */
for (i = 0; i < comm_size; i++) {
ompi_proc_t *peer = ompi_comm_peer_lookup(module->comm, i);
peer->
if (ompi_comm_peer_lookup(module->comm, i) == NULL) {
OSC_UCX_ERROR("Failed to lookup peer %d in communicator %s", i, ompi_comm_print_cid(module->comm));
ret = OMPI_ERR_COMM_FAILURE;
goto error;
}
}

}

void **mem_base = base;
Expand Down
21 changes: 16 additions & 5 deletions ompi/mpi/c/win_shared_query.c.in
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@

PROTOTYPE ERROR_CLASS win_shared_query(WIN win, INT rank, AINT_OUT size, INT_AINT_OUT disp_unit, BUFFER_OUT baseptr)
{
int rc;
size_t tsize;
ptrdiff_t du;
int rc = OMPI_SUCCESS;

if (MPI_PARAM_CHECK) {
OMPI_ERR_INIT_FINALIZE(FUNC_NAME);
Expand All @@ -40,12 +40,23 @@ PROTOTYPE ERROR_CLASS win_shared_query(WIN win, INT rank, AINT_OUT size, INT_AIN
}
}

rc = OMPI_ERR_NOT_SUPPORTED;

if (NULL != win->w_osc_module->osc_win_shared_query) {
rc = win->w_osc_module->osc_win_shared_query(win, rank, &tsize, &du, baseptr);
*size = tsize;
*disp_unit = du;
} else {
rc = MPI_ERR_RMA_FLAVOR;
if (OMPI_SUCCESS == rc) {
*size = tsize;
*disp_unit = du;
}
}

if (OMPI_ERR_NOT_SUPPORTED == rc) {
/* gracefully bail out */
*size = 0;
*disp_unit = 0;
*(void**) baseptr = NULL;
rc = MPI_SUCCESS; // don't raise an error if the function is not supported
}

OMPI_ERRHANDLER_RETURN(rc, win, rc, FUNC_NAME);
}
Loading