diff --git a/ompi/mca/osc/rdma/osc_rdma_component.c b/ompi/mca/osc/rdma/osc_rdma_component.c index a5d06cb7916..92a32f45285 100644 --- a/ompi/mca/osc/rdma/osc_rdma_component.c +++ b/ompi/mca/osc/rdma/osc_rdma_component.c @@ -69,6 +69,8 @@ #include "ompi/mca/bml/base/base.h" #include "ompi/mca/mtl/base/base.h" +static int ompi_osc_rdma_shared_query(struct ompi_win_t *win, int rank, size_t *size, + ptrdiff_t *disp_unit, void *baseptr); static int ompi_osc_rdma_component_register (void); static int ompi_osc_rdma_component_init (bool enable_progress_threads, bool enable_mpi_threads); static int ompi_osc_rdma_component_finalize (void); @@ -113,6 +115,7 @@ ompi_osc_rdma_component_t mca_osc_rdma_component = { MCA_BASE_COMPONENT_INIT(ompi, osc, rdma) ompi_osc_base_module_t ompi_osc_rdma_module_rdma_template = { + .osc_win_shared_query = ompi_osc_rdma_shared_query, .osc_win_attach = ompi_osc_rdma_attach, .osc_win_detach = ompi_osc_rdma_detach, .osc_free = ompi_osc_rdma_free, @@ -898,7 +901,7 @@ static void ompi_osc_rdma_ensure_local_add_procs (void) /* this will cause add_proc to get called if it has not already been called */ (void) mca_bml_base_get_endpoint (proc); } - } + } free(procs); } @@ -1632,3 +1635,60 @@ ompi_osc_rdma_set_no_lock_info(opal_infosubscriber_t *obj, const char *key, cons */ return module->no_locks ? "true" : "false"; } + +int ompi_osc_rdma_shared_query( + struct ompi_win_t *win, int rank, size_t *size, + ptrdiff_t *disp_unit, void *baseptr) +{ + int rc = OMPI_ERR_NOT_SUPPORTED; + ompi_osc_rdma_peer_t *peer; + int actual_rank = rank; + ompi_osc_rdma_module_t *module = GET_MODULE(win); + + peer = ompi_osc_rdma_module_peer (module, actual_rank); + if (NULL == peer) { + return OMPI_ERR_BAD_PARAM; + } + + /* currently only supported for allocated windows */ + if (MPI_WIN_FLAVOR_ALLOCATE != module->flavor) { + return OMPI_ERR_NOT_SUPPORTED; + } + + if (!ompi_osc_rdma_peer_local_base(peer)) { + return OMPI_ERR_NOT_SUPPORTED; + } + + if (MPI_PROC_NULL == rank) { + /* iterate until we find a rank that has a non-zero size */ + for (int i = 0 ; i < ompi_comm_size(module->comm) ; ++i) { + peer = ompi_osc_rdma_module_peer (module, i); + ompi_osc_rdma_peer_extended_t *ex_peer = (ompi_osc_rdma_peer_extended_t *) peer; + if (!ompi_osc_rdma_peer_local_base(peer)) { + continue; + } else if (module->same_size && ex_peer->super.base) { + break; + } else if (ex_peer->size > 0) { + break; + } + } + } + + if (module->same_size && module->same_disp_unit) { + *size = module->size; + *disp_unit = module->disp_unit; + ompi_osc_rdma_peer_basic_t *ex_peer = (ompi_osc_rdma_peer_basic_t *) peer; + *((void**) baseptr) = (void *) (intptr_t)ex_peer->base; + rc = OMPI_SUCCESS; + } else { + ompi_osc_rdma_peer_extended_t *ex_peer = (ompi_osc_rdma_peer_extended_t *) peer; + if (ex_peer->super.base != 0) { + /* we know the base of the peer */ + *((void**) baseptr) = (void *) (intptr_t)ex_peer->super.base; + *size = ex_peer->size; + *disp_unit = ex_peer->disp_unit; + rc = OMPI_SUCCESS; + } + } + return rc; +} diff --git a/ompi/mca/osc/sm/osc_sm_component.c b/ompi/mca/osc/sm/osc_sm_component.c index 87ed6a1431b..5d4510e36f1 100644 --- a/ompi/mca/osc/sm/osc_sm_component.c +++ b/ompi/mca/osc/sm/osc_sm_component.c @@ -481,10 +481,6 @@ ompi_osc_sm_shared_query(struct ompi_win_t *win, int rank, size_t *size, ptrdiff ompi_osc_sm_module_t *module = (ompi_osc_sm_module_t*) win->w_osc_module; - if (module->flavor != MPI_WIN_FLAVOR_SHARED) { - return MPI_ERR_WIN; - } - if (MPI_PROC_NULL != rank) { *size = module->sizes[rank]; *((void**) baseptr) = module->bases[rank]; diff --git a/ompi/mca/osc/ucx/osc_ucx.h b/ompi/mca/osc/ucx/osc_ucx.h index 1c349f30592..2f2294797e7 100644 --- a/ompi/mca/osc/ucx/osc_ucx.h +++ b/ompi/mca/osc/ucx/osc_ucx.h @@ -116,6 +116,19 @@ typedef struct ompi_osc_ucx_mem_ranges { uint64_t tail; } ompi_osc_ucx_mem_ranges_t; +/** + * Structure to hold information about shared memory regions. + * We store the rank, it's address, and the size of the window region. + * We don't store the disp_unit here, as that is stored elsewhere already. + */ +struct ompi_osc_ucx_shmem_info_s { + int peer; /* rank of the peer this information belongs to */ + char *addr; /* address of the shared memory region */ + size_t size; /* size of the shared memory region */ +}; + +typedef struct ompi_osc_ucx_shmem_info_s ompi_osc_ucx_shmem_info_t; + typedef struct ompi_osc_ucx_module { ompi_osc_base_module_t super; struct ompi_communicator_t *comm; @@ -128,6 +141,7 @@ typedef struct ompi_osc_ucx_module { * disp unit size; if disp_unit == -1, then we * need to look at disp_units */ ptrdiff_t *disp_units; + ompi_osc_ucx_shmem_info_t *shmem_info; /* shared memory info */ ompi_osc_ucx_state_t state; /* remote accessible flags */ ompi_osc_local_dynamic_win_info_t local_dynamic_win_info[OMPI_OSC_UCX_ATTACH_MAX]; diff --git a/ompi/mca/osc/ucx/osc_ucx_component.c b/ompi/mca/osc/ucx/osc_ucx_component.c index 27201eae8ff..22bba2a48f3 100644 --- a/ompi/mca/osc/ucx/osc_ucx_component.c +++ b/ompi/mca/osc/ucx/osc_ucx_component.c @@ -468,17 +468,60 @@ static const char* ompi_osc_ucx_set_no_lock_info(opal_infosubscriber_t *obj, con return module->no_locks ? "true" : "false"; } +static int ompi_osc_ucx_shared_query_peer(ompi_osc_ucx_module_t *module, int peer, size_t *size, + ptrdiff_t *disp_unit, void *baseptr) { + + int rc; + ucp_ep_h *dflt_ep; + ucp_ep_h ep; // ignored + ucp_rkey_h rkey; + OSC_UCX_GET_DEFAULT_EP(dflt_ep, module, peer); + ucs_status_t status; + opal_common_ucx_winfo_t *winfo; // ignored + rc = opal_common_ucx_tlocal_fetch(module->mem, peer, &ep, &rkey, &winfo, dflt_ep); + if (OMPI_SUCCESS != rc) { + return rc; + } + uint64_t raddr; + void *addr_p; + if (UCS_OK != ucp_rkey_ptr(rkey, module->addrs[peer], &addr_p)) { + return OMPI_ERR_NOT_AVAILABLE; + } + *size = module->sizes[peer]; + *((void**) baseptr) = (void *)module->shmem_addrs[peer]; + *disp_unit = module->disp_units[peer]; + + return OMPI_SUCCESS; +} + int ompi_osc_ucx_shared_query(struct ompi_win_t *win, int rank, size_t *size, ptrdiff_t *disp_unit, void *baseptr) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; + *size = 0; + *((void**) baseptr) = NULL; + *disp_unit = 0; + if (module->flavor != MPI_WIN_FLAVOR_SHARED) { - return MPI_ERR_WIN; - } - if (MPI_PROC_NULL != rank) { + if (MPI_PROC_NULL == rank) { + for (int i = 0 ; i < ompi_comm_size(module->comm) ; ++i) { + if (0 != module->sizes[i]) { + if (OMPI_SUCCESS == ompi_osc_ucx_shared_query_peer(module, i, size, disp_unit, baseptr)) { + return OMPI_SUCCESS; + } + } + } + } else { + if (0 != module->sizes[rank]) { + return ompi_osc_ucx_shared_query_peer(module, rank, size, disp_unit, baseptr); + } + } + return OMPI_ERR_NOT_SUPPORTED; + + } else if (MPI_PROC_NULL != rank) { // shared memory window with given rank *size = module->sizes[rank]; *((void**) baseptr) = (void *)module->shmem_addrs[rank]; if (module->disp_unit == -1) { @@ -486,12 +529,9 @@ int ompi_osc_ucx_shared_query(struct ompi_win_t *win, int rank, size_t *size, } else { *disp_unit = module->disp_unit; } - } else { + } else { // shared memory window with MPI_PROC_NULL int i = 0; - *size = 0; - *((void**) baseptr) = NULL; - *disp_unit = 0; for (i = 0 ; i < ompi_comm_size(module->comm) ; ++i) { if (0 != module->sizes[i]) { *size = module->sizes[i]; @@ -639,6 +679,13 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt module->acc_single_intrinsic = check_config_value_bool ("acc_single_intrinsic", info); module->skip_sync_check = false; + /** + * TODO: we need to collect the shared memory information from all processes + * on the same node. This includes the size and base address, which needs + * to be passed to ucp_rkey_ptr(). + */ + module->shmem_info = NULL; + /* share everyone's displacement units. Only do an allgather if strictly necessary, since it requires O(p) state. */ values[0] = disp_unit; @@ -807,6 +854,18 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt module->size = module->sizes[ompi_comm_rank(module->comm)]; *base = (void *)module->shmem_addrs[ompi_comm_rank(module->comm)]; + } else { + /* non-shared memory: exchange sizes and addresses so they can be queried for shared memory */ + for (i = 0; i < comm_size; i++) { + ompi_proc_t *peer = ompi_comm_peer_lookup(module->comm, i); + peer-> + if (ompi_comm_peer_lookup(module->comm, i) == NULL) { + OSC_UCX_ERROR("Failed to lookup peer %d in communicator %s", i, ompi_comm_print_cid(module->comm)); + ret = OMPI_ERR_COMM_FAILURE; + goto error; + } + } + } void **mem_base = base; diff --git a/ompi/mpi/c/win_shared_query.c.in b/ompi/mpi/c/win_shared_query.c.in index ad88189428f..0616a9366aa 100644 --- a/ompi/mpi/c/win_shared_query.c.in +++ b/ompi/mpi/c/win_shared_query.c.in @@ -26,9 +26,9 @@ PROTOTYPE ERROR_CLASS win_shared_query(WIN win, INT rank, AINT_OUT size, INT_AINT_OUT disp_unit, BUFFER_OUT baseptr) { - int rc; size_t tsize; ptrdiff_t du; + int rc = OMPI_SUCCESS; if (MPI_PARAM_CHECK) { OMPI_ERR_INIT_FINALIZE(FUNC_NAME); @@ -40,12 +40,23 @@ PROTOTYPE ERROR_CLASS win_shared_query(WIN win, INT rank, AINT_OUT size, INT_AIN } } + rc = OMPI_ERR_NOT_SUPPORTED; + if (NULL != win->w_osc_module->osc_win_shared_query) { rc = win->w_osc_module->osc_win_shared_query(win, rank, &tsize, &du, baseptr); - *size = tsize; - *disp_unit = du; - } else { - rc = MPI_ERR_RMA_FLAVOR; + if (OMPI_SUCCESS == rc) { + *size = tsize; + *disp_unit = du; + } + } + + if (OMPI_ERR_NOT_SUPPORTED == rc) { + /* gracefully bail out */ + *size = 0; + *disp_unit = 0; + *(void**) baseptr = NULL; + rc = MPI_SUCCESS; // don't raise an error if the function is not supported } + OMPI_ERRHANDLER_RETURN(rc, win, rc, FUNC_NAME); }