diff --git a/maint/local_python/binding_c.py b/maint/local_python/binding_c.py index f219ee4e194..04728fff5eb 100644 --- a/maint/local_python/binding_c.py +++ b/maint/local_python/binding_c.py @@ -570,7 +570,7 @@ def process_func_parameters(func): if p['length']: length = p['length'] if length == '*': - if RE.match(r'MPI_(Test|Wait|Request_get_status_)all', func_name, re.IGNORECASE): + if RE.match(r'MPI_(Test|Wait|Request_get_status_|Continue)all', func_name, re.IGNORECASE): length = "count" elif RE.match(r'MPI_(Test|Wait|Request_get_status_)some', func_name, re.IGNORECASE): length = "incount" @@ -595,7 +595,7 @@ def process_func_parameters(func): if kind == "REQUEST": if RE.match(r'mpi_startall', func_name, re.IGNORECASE): do_handle_ptr = 3 - elif RE.match(r'mpix?_(wait|test|request_get_status)', func_name, re.IGNORECASE): + elif RE.match(r'mpix?_(wait|test|request_get_status|continue)', func_name, re.IGNORECASE): do_handle_ptr = 3 elif kind == "RANK": validation_list.append({'kind': "RANK-ARRAY", 'name': name}) @@ -652,6 +652,8 @@ def process_func_parameters(func): p['can_be_null'] = "MPI_INFO_NULL" elif kind == "REQUEST" and RE.match(r'mpix?_(wait|test|request_get_status|parrived)', func_name, re.IGNORECASE): p['can_be_null'] = "MPI_REQUEST_NULL" + elif kind == "REQUEST" and RE.match(r'mpix_(continue|continueall)', func_name, re.IGNORECASE) and name == "cont_request": + p['can_be_null'] = "MPI_REQUEST_NULL" elif kind == "STREAM" and RE.match(r'mpix?_(stream_(comm_create|progress)|async_(start|spawn))', func_name, re.IGNORECASE): p['can_be_null'] = "MPIX_STREAM_NULL" elif kind == "COMMUNICATOR" and RE.match(r'mpi_comm_get_name', func_name, re.IGNORECASE): @@ -740,7 +742,6 @@ def process_func_parameters(func): validation_list.append({'kind': "ARGNULL", 'name': name}) else: print("Missing error checking: func=%s, name=%s, kind=%s" % (func_name, name, kind), file=sys.stderr) - if do_handle_ptr == 1: if p['param_direction'] == 'inout': # assume only one such parameter @@ -762,7 +763,7 @@ def process_func_parameters(func): if kind == "REQUEST": ptrs_name = "request_ptrs" p['_ptrs_name'] = ptrs_name - if RE.match(r'mpi_startall', func['name'], re.IGNORECASE): + if RE.match(r'mpix?_(start|continue)all', func['name'], re.IGNORECASE): impl_arg_list.append(ptrs_name) impl_param_list.append("MPIR_Request **%s" % ptrs_name) else: diff --git a/src/binding/c/continue_api.txt b/src/binding/c/continue_api.txt new file mode 100644 index 00000000000..1c3dbfae878 --- /dev/null +++ b/src/binding/c/continue_api.txt @@ -0,0 +1,46 @@ +# vim: set ft=c: + +MPIX_Continue_cb_function: + .return: NOTHING + error_code: ERROR_CODE + user_data: BUFFER + +MPIX_Continue_init: + .desc: Creates a new continuation request + flags: ARRAY_LENGTH [flags] + max_poll: ARRAY_LENGTH_NNI [maximum number of continuations to execute when + testing, or 0 for no limit] + info: INFO, [info argument] + cont_req: REQUEST, direction=out, [continuation request created] + +MPIX_Continue: + .desc: Attach a continuation to the operation represented by the request + op_request and register it with the continuation request cont_request + op_request: REQUEST, direction=inout, [the request associated with the active operation] + cb: FUNCTION, func_type=MPIX_Continue_cb_function, [callback to be invoked once the operation is complete] + cb_data: BUFFER, [pointer to a user-controlled buffer] + flags: ARRAY_LENGTH, [flags controlling aspects of the continuation] + status: STATUS, direction=inout, [status object] + cont_request: REQUEST, [continuation request] + +MPIX_Continueall: + .desc: Attach a continuation callback to a set of operation requests + count: ARRAY_LENGTH_NNI, [lists length] + array_of_op_requests: REQUEST, direction=inout, length=count, [array of requests] + cb: FUNCTION, func_type=MPIX_Continue_cb_function, [the continuation callback function] + cb_data: BUFFER, [the argument passed to the callback] + flags: ARRAY_LENGTH, [flags controlling aspects of the continuation] + array_of_statuses: STATUS, direction=inout, length=*, pointer=False, [array of status objects] + cont_request: REQUEST, [the continuation request] +{ + mpi_errno = MPIR_Continueall_impl(count, request_ptrs, cb, cb_data, flags, array_of_statuses, + cont_request_ptr); + if (mpi_errno) { + goto fn_fail; + } + if (!(flags & MPIX_CONT_REQBUF_VOLATILE)) { + for (int i = 0; i < count; ++i) { + array_of_op_requests[i] = MPI_REQUEST_NULL; + } + } +} \ No newline at end of file diff --git a/src/include/Makefile.mk b/src/include/Makefile.mk index 30cb3c1d911..844cd95fde4 100644 --- a/src/include/Makefile.mk +++ b/src/include/Makefile.mk @@ -27,7 +27,7 @@ noinst_HEADERS += \ src/include/mpir_refcount_global.h \ src/include/mpir_refcount_vci.h \ src/include/mpir_refcount_single.h \ - src/include/mpir_refcount.h \ + src/include/mpir_atomic_flag.h \ src/include/mpir_assert.h \ src/include/mpir_misc_post.h \ src/include/mpir_type_defs.h \ diff --git a/src/include/mpi.h.in b/src/include/mpi.h.in index 1180c3f17a1..e046df9cc6e 100644 --- a/src/include/mpi.h.in +++ b/src/include/mpi.h.in @@ -763,6 +763,14 @@ enum MPIR_Combiner_enum { #define MPIX_GPU_SUPPORT_ZE (1) #define MPIX_GPU_SUPPORT_HIP (2) +/* Continue flags */ +#define MPIX_CONT_REQBUF_VOLATILE 1<<0 +#define MPIX_CONT_PERSISTENT 1<<1 +#define MPIX_CONT_POLL_ONLY 1<<2 +#define MPIX_CONT_DEFER_COMPLETE 1<<3 +#define MPIX_CONT_INVOKE_FAILED 1<<4 +#define MPIX_CONT_IMMEDIATE 1<<5 + /* feature advertisement */ #define MPIIMPL_ADVERTISES_FEATURES 1 #define MPIIMPL_HAVE_MPI_INFO 1 @@ -843,6 +851,9 @@ typedef int (MPI_Datarep_extent_function)(MPI_Datatype datatype, MPI_Aint *, typedef int (MPI_Datarep_conversion_function_c)(void *, MPI_Datatype, MPI_Count, void *, MPI_Offset, void *); +/* Typedefs for continuation callback */ +typedef int (MPIX_Continue_cb_function)(int error_code, void *user_data); + /* Make the C names for the dup function mixed case. This is required for systems that use all uppercase names for Fortran externals. */ diff --git a/src/include/mpiimpl.h b/src/include/mpiimpl.h index d665d3e2eff..98d1890b8c1 100644 --- a/src/include/mpiimpl.h +++ b/src/include/mpiimpl.h @@ -163,6 +163,7 @@ typedef struct MPIR_Stream MPIR_Stream; #include "mpir_assert.h" #include "mpir_pointers.h" #include "mpir_refcount.h" +#include "mpir_atomic_flag.h" #include "mpir_mem.h" #include "mpir_info.h" #include "mpir_errcodes.h" diff --git a/src/include/mpir_atomic_flag.h b/src/include/mpir_atomic_flag.h new file mode 100644 index 00000000000..dad47075912 --- /dev/null +++ b/src/include/mpir_atomic_flag.h @@ -0,0 +1,69 @@ +/* + * Copyright (C) by Argonne National Laboratory + * See COPYRIGHT in top-level directory + */ + +#ifndef MPIR_ATOMIC_FLAG_H_INCLUDED +#define MPIR_ATOMIC_FLAG_H_INCLUDED + +#include "mpi.h" +#include "mpichconf.h" + +#if MPICH_THREAD_LEVEL == MPI_THREAD_MULTIPLE && \ + MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__VCI + +typedef MPL_atomic_int_t MPIR_atomic_flag_t; + +static inline void MPIR_atomic_flag_set(MPIR_atomic_flag_t * flag_ptr, int val) +{ + MPL_atomic_relaxed_store_int(flag_ptr, val); +} + +static inline int MPIR_atomic_flag_get(MPIR_atomic_flag_t * flag_ptr) +{ + return MPL_atomic_relaxed_load_int(flag_ptr); +} + +static inline int MPIR_atomic_flag_swap(MPIR_atomic_flag_t * flag_ptr, int val) +{ + return MPL_atomic_swap_int(flag_ptr, val); +} + +static inline int MPIR_atomic_flag_cas(MPIR_atomic_flag_t * flag_ptr, int old_val, int new_val) +{ + return MPL_atomic_cas_int(flag_ptr, old_val, new_val); +} + +#else + +typedef int MPIR_atomic_flag_t; + +static inline void MPIR_atomic_flag_set(MPIR_atomic_flag_t * flag_ptr, int val) +{ + *flag_ptr = val; +} + +static inline int MPIR_atomic_flag_get(MPIR_atomic_flag_t * flag_ptr) +{ + return *flag_ptr; +} + +static inline int MPIR_atomic_flag_swap(MPIR_atomic_flag_t * flag_ptr, int val) +{ + int ret = *flag_ptr; + *flag_ptr = val; + return ret; +} + +static inline int MPIR_atomic_flag_cas(MPIR_atomic_flag_t * flag_ptr, int old_val, int new_val) +{ + int ret = *flag_ptr; + if (*flag_ptr == old_val) { + *flag_ptr = new_val; + } + return ret; +} + +#endif + +#endif /* MPIR_ATOMIC_FLAG_H_INCLUDED */ diff --git a/src/include/mpir_err.h b/src/include/mpir_err.h index e6902e2c20f..ab9a63496e4 100644 --- a/src/include/mpir_err.h +++ b/src/include/mpir_err.h @@ -371,9 +371,7 @@ void MPIR_Handle_fatal_error(struct MPIR_Comm *comm_ptr, const char fcname[], in } #define MPIR_ERRTEST_STARTREQ(reqp,err) \ - if ((reqp)->kind != MPIR_REQUEST_KIND__PREQUEST_SEND && (reqp)->kind != MPIR_REQUEST_KIND__PREQUEST_RECV \ - && (reqp)->kind != MPIR_REQUEST_KIND__PREQUEST_COLL \ - && (reqp)->kind != MPIR_REQUEST_KIND__PART_SEND && (reqp)->kind != MPIR_REQUEST_KIND__PART_RECV) { \ + if (!MPIR_Request_is_persistent(reqp)) { \ err = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE, __func__, __LINE__, \ MPI_ERR_REQUEST, "**requestinvalidstart", 0); \ goto fn_fail; \ @@ -394,6 +392,10 @@ void MPIR_Handle_fatal_error(struct MPIR_Comm *comm_ptr, const char fcname[], in err = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE, __func__, __LINE__, \ MPI_ERR_REQUEST, "**requestpartactive", 0); \ goto fn_fail; \ + } else if (((reqp)->kind == MPIR_REQUEST_KIND__CONTINUE) && MPIR_Cont_request_is_active(reqp)) { \ + err = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE, __func__, __LINE__, \ + MPI_ERR_REQUEST, "**requestpartactive", 0); \ + goto fn_fail; \ } #define MPIR_ERRTEST_PREADYREQ(reqp,err) \ diff --git a/src/include/mpir_request.h b/src/include/mpir_request.h index 3294ca3239f..3af36b560e1 100644 --- a/src/include/mpir_request.h +++ b/src/include/mpir_request.h @@ -36,6 +36,7 @@ typedef enum MPIR_Request_kind_t { MPIR_REQUEST_KIND__COLL, MPIR_REQUEST_KIND__MPROBE, /* see NOTE-R1 */ MPIR_REQUEST_KIND__RMA, + MPIR_REQUEST_KIND__CONTINUE, MPIR_REQUEST_KIND__LAST #ifdef MPID_REQUEST_KIND_DECL , MPID_REQUEST_KIND_DECL @@ -60,6 +61,15 @@ typedef void (MPIR_Grequest_f77_cancel_function) (void *, MPI_Fint *, MPI_Fint * typedef void (MPIR_Grequest_f77_free_function) (void *, MPI_Fint *); typedef void (MPIR_Grequest_f77_query_function) (void *, MPI_Fint *, MPI_Fint *); +/* Typedefs for request callback */ +typedef void (MPIR_Request_callback_function) (MPIR_Request *, bool, void *); +struct MPIR_Request_cb_t { + MPIR_Request_callback_function *fn; + void *arg; + bool is_persistent; + struct MPIR_Request_cb_t *next; +}; + /* vtable-ish structure holding generalized request function pointers and other * state. Saves ~48 bytes in pt2pt requests on many platforms. */ struct MPIR_Grequest_fns { @@ -139,6 +149,13 @@ enum MPIR_sched_type { MPIR_SCHED_GENTRAN }; +/* Declaration for continue */ +struct MPIR_Continue_context; +struct MPIR_Continue; +int MPIR_Continue_start(MPIR_Request * request); +void MPIR_Continue_progress(MPIR_Request *request); +int MPIR_Continue_progress_tls(); + /*S MPIR_Request - Description of the Request data structure @@ -173,6 +190,13 @@ struct MPIR_Request { MPIR_Comm *comm; /* Status is needed for wait/test/recv */ MPI_Status status; + /* Callback */ + MPID_Thread_mutex_t cbs_lock; + bool cbs_invoked; + struct { + struct MPIR_Request_cb_t *head; + struct MPIR_Request_cb_t *tail; + } cbs; union { struct { @@ -207,6 +231,21 @@ struct MPIR_Request { struct { MPIR_Win *win; } rma; /* kind : MPIR_REQUEST_KIND__RMA */ + struct { + MPL_atomic_int_t active_flag; /* flag indicating whether in a start-complete active period. + * Value is 0 or 1. */ + struct { + struct MPIR_Continue_context *head, *tail; + MPID_Thread_mutex_t lock; + } cont_context_on_hold_list; + bool is_pool_only; + int max_poll; + struct { + struct MPIR_Continue *head, *tail; + MPID_Thread_mutex_t lock; + } ready_poll_only_cont_list; + MPID_Progress_state_cnt *state; + } cont; /* Reserve space for local usages. For example, threadcomm, the actual struct * is defined locally and is used via casting */ char dummy[MPIR_REQUEST_UNION_SIZE]; @@ -359,7 +398,8 @@ static inline int MPIR_Request_is_persistent(MPIR_Request * req_ptr) req_ptr->kind == MPIR_REQUEST_KIND__PREQUEST_RECV || req_ptr->kind == MPIR_REQUEST_KIND__PREQUEST_COLL || req_ptr->kind == MPIR_REQUEST_KIND__PART_SEND || - req_ptr->kind == MPIR_REQUEST_KIND__PART_RECV); + req_ptr->kind == MPIR_REQUEST_KIND__PART_RECV || + req_ptr->kind == MPIR_REQUEST_KIND__CONTINUE); } static inline int MPIR_Part_request_is_active(MPIR_Request * req_ptr) @@ -377,6 +417,21 @@ static inline void MPIR_Part_request_activate(MPIR_Request * req_ptr) MPL_atomic_store_int(&req_ptr->u.part.active_flag, 1); } +static inline int MPIR_Cont_request_is_active(MPIR_Request * req_ptr) +{ + return MPL_atomic_load_int(&req_ptr->u.cont.active_flag); +} + +static inline void MPIR_Cont_request_inactivate(MPIR_Request * req_ptr) +{ + MPL_atomic_store_int(&req_ptr->u.cont.active_flag, 0); +} + +static inline void MPIR_Cont_request_activate(MPIR_Request * req_ptr) +{ + MPL_atomic_store_int(&req_ptr->u.cont.active_flag, 1); +} + /* Return whether a request is active. * A persistent request and the handle to it are "inactive" * if the request is not associated with any ongoing communication. @@ -395,6 +450,8 @@ static inline int MPIR_Request_is_active(MPIR_Request * req_ptr) case MPIR_REQUEST_KIND__PART_SEND: case MPIR_REQUEST_KIND__PART_RECV: return MPIR_Part_request_is_active(req_ptr); + case MPIR_REQUEST_KIND__CONTINUE: + return MPIR_Cont_request_is_active(req_ptr); default: return 1; /* regular request is always active */ } @@ -408,6 +465,7 @@ static inline int MPIR_Request_is_active(MPIR_Request * req_ptr) | MPIR_REQUESTS_PROPERTY__NO_GREQUESTS \ | MPIR_REQUESTS_PROPERTY__SEND_RECV_ONLY) +MPL_STATIC_INLINE_PREFIX void MPIR_Request_cb_init(MPIR_Request * req); /* NOTE: Pool-specific request creation is unsafe unless under global thread granularity. */ static inline MPIR_Request *MPIR_Request_create_from_pool(MPIR_Request_kind_t kind, int pool, @@ -456,6 +514,7 @@ static inline MPIR_Request *MPIR_Request_create_from_pool(MPIR_Request_kind_t ki MPIR_STATUS_SET_CANCEL_BIT(req->status, FALSE); req->comm = NULL; + MPIR_Request_cb_init(req); switch (kind) { case MPIR_REQUEST_KIND__COLL: @@ -530,6 +589,10 @@ MPL_STATIC_INLINE_PREFIX MPIR_Request *MPIR_Request_create_null_recv(void) return get_builtin_req(HANDLE_INDEX(MPIR_REQUEST_NULL_RECV), MPIR_REQUEST_KIND__RECV); } +MPL_STATIC_INLINE_PREFIX void MPIR_Invoke_callback(MPIR_Request * req, bool in_cs); +MPL_STATIC_INLINE_PREFIX void MPIR_Request_cb_free(MPIR_Request * req); +void MPIR_Continue_destroy_impl(MPIR_Request *cont_req); + int MPIR_Grequest_free(MPIR_Request * request_ptr); static inline void MPIR_Request_free_with_safety(MPIR_Request * req, int need_safety, int *errno_out) @@ -600,6 +663,9 @@ static inline void MPIR_Request_free_with_safety(MPIR_Request * req, *errno_out = mpi_errno; } break; + case MPIR_REQUEST_KIND__CONTINUE: + MPIR_Continue_destroy_impl(req); + break; default: break; } @@ -609,6 +675,7 @@ static inline void MPIR_Request_free_with_safety(MPIR_Request * req, * when we destroy a request */ /* FIXME: We need a way to call these routines ONLY when the * related ref count has become zero. */ + MPIR_Request_cb_free(req); if (req->comm != NULL) { if (MPIR_Request_is_persistent(req)) { MPIR_Comm_delete_inactive_request(req->comm, req); @@ -652,6 +719,76 @@ MPL_STATIC_INLINE_PREFIX int MPIR_Request_free_return(MPIR_Request * req) return mpi_errno; } +MPL_STATIC_INLINE_PREFIX void MPIR_Request_cb_init(MPIR_Request * req) +{ + req->cbs_invoked = false; + req->cbs.head = NULL; + req->cbs.tail = NULL; + int err; + MPID_Thread_mutex_create(&req->cbs_lock, &err); + MPIR_Assert(!err); +} + +MPL_STATIC_INLINE_PREFIX void MPIR_Request_cb_free(MPIR_Request * req) +{ + int err; + MPID_Thread_mutex_destroy(&req->cbs_lock, &err); + MPIR_Assert(!err); + // free all the persistent callbacks + while (req->cbs.head) { + struct MPIR_Request_cb_t *cb = req->cbs.head; + MPIR_Assert(cb->is_persistent); + LL_DELETE(req->cbs.head, req->cbs.tail, cb); + MPL_free(cb); + } +} + +MPL_STATIC_INLINE_PREFIX bool MPIR_Register_callback(MPIR_Request * req, + MPIR_Request_callback_function *cb_fn, + void *cb_arg, + bool is_persistent) +{ + if (MPIR_Request_is_complete(req)) + return false; + MPID_THREAD_CS_ENTER(VCI, req->cbs_lock); + bool succeed = !req->cbs_invoked; + if (succeed) { + struct MPIR_Request_cb_t *cb = MPL_malloc(sizeof(struct MPIR_Request_cb_t), MPL_MEM_OTHER); + cb->fn = cb_fn; + cb->arg = cb_arg; + cb->is_persistent = is_persistent; + LL_APPEND(req->cbs.head, req->cbs.tail, cb); + } + MPID_THREAD_CS_EXIT(VCI, req->cbs_lock); + return succeed; +} + +MPL_STATIC_INLINE_PREFIX void MPIR_Invoke_callback(MPIR_Request * req, bool in_cs) +{ + /* At this point, for each req, only one thread should execute this code. */ + MPIR_Assert(!req->cbs_invoked); + MPID_THREAD_CS_ENTER(VCI, req->cbs_lock); + req->cbs_invoked = true; + MPID_THREAD_CS_EXIT(VCI, req->cbs_lock); + /* So we do not need to protect this check */ + if (!req->cbs.head) { + return; + } + while (req->cbs.head) { + struct MPIR_Request_cb_t *cb = req->cbs.head; + cb->fn(req, in_cs, cb->arg); + if (!cb->is_persistent) { + LL_DELETE(req->cbs.head, req->cbs.tail, cb); + MPL_free(cb); + } + } +} + +MPL_STATIC_INLINE_PREFIX void MPIR_Request_start(MPIR_Request * req) +{ + req->cbs_invoked = false; +} + /* Requests that are not created inside device (general requests, nonblocking collective * requests such as sched, tsp, hcoll) should call MPIR_Request_complete. * MPID_Request_complete are called inside device critical section, therefore, potentially @@ -660,8 +797,14 @@ MPL_STATIC_INLINE_PREFIX int MPIR_Request_free_return(MPIR_Request * req) MPL_STATIC_INLINE_PREFIX void MPIR_Request_complete(MPIR_Request * req) { MPIR_cc_set(&req->cc, 0); + MPIR_Invoke_callback(req, false); MPIR_Request_free(req); } +MPL_STATIC_INLINE_PREFIX void MPIR_Request_complete_nofree(MPIR_Request * req) +{ + MPIR_cc_set(&req->cc, 0); + MPIR_Invoke_callback(req, false); +} /* The "fastpath" version of MPIR_Request_completion_processing. It only handles * MPIR_REQUEST_KIND__SEND and MPIR_REQUEST_KIND__RECV kinds, and it does not attempt to @@ -801,6 +944,9 @@ int MPIR_Waitany(int count, MPIR_Request * request_ptrs[], int *indx, MPI_Status int MPIR_Waitsome(int incount, MPIR_Request * request_ptrs[], int *outcount, int array_of_indices[], MPI_Status array_of_statuses[]); int MPIR_Parrived(MPIR_Request * request_ptr, int partition, int *flag); +int MPIR_Continueall_impl(int count, MPIR_Request *request_ptrs[], + MPIX_Continue_cb_function *cb, void *cb_data, int flags, + MPI_Status *array_of_statuses, MPIR_Request *cont_request_ptr); void MPIR_Request_debug(void); diff --git a/src/mpi/Makefile.mk b/src/mpi/Makefile.mk index 8d7c4236ee6..4a54d8040c2 100644 --- a/src/mpi/Makefile.mk +++ b/src/mpi/Makefile.mk @@ -6,6 +6,7 @@ include $(top_srcdir)/src/mpi/attr/Makefile.mk include $(top_srcdir)/src/mpi/coll/Makefile.mk include $(top_srcdir)/src/mpi/comm/Makefile.mk +include $(top_srcdir)/src/mpi/continue/Makefile.mk include $(top_srcdir)/src/mpi/datatype/Makefile.mk include $(top_srcdir)/src/mpi/debugger/Makefile.mk include $(top_srcdir)/src/mpi/errhan/Makefile.mk diff --git a/src/mpi/continue/Makefile.mk b/src/mpi/continue/Makefile.mk new file mode 100644 index 00000000000..aacdee09418 --- /dev/null +++ b/src/mpi/continue/Makefile.mk @@ -0,0 +1,7 @@ +## +## Copyright (C) by Argonne National Laboratory +## See COPYRIGHT in top-level directory +## + +mpi_core_sources += \ + src/mpi/continue/continue_impl.c diff --git a/src/mpi/continue/continue_impl.c b/src/mpi/continue/continue_impl.c new file mode 100644 index 00000000000..4e1d3cd774a --- /dev/null +++ b/src/mpi/continue/continue_impl.c @@ -0,0 +1,358 @@ +/* + * Copyright (C) by Argonne National Laboratory + * See COPYRIGHT in top-level directory + */ + +#include "mpiimpl.h" + +/* Continue object: a wrapper to a continue callback */ +struct MPIR_Continue { + MPIR_Request *cont_req; + MPIX_Continue_cb_function *cb; + void *cb_data; + MPIR_cc_t pending_request_count; + struct MPIR_Continue *next; + bool is_immediate; +}; +typedef struct MPIR_Continue MPIR_Continue; + +/* Continue context object: carrying data for each op request */ +struct MPIR_Continue_context { + struct MPIR_Continue* continue_ptr; + MPI_Status *status_ptr; + /* Used by the on-hold list */ + struct MPIR_Continue_context *next; + MPIR_Request *op_request; +}; +typedef struct MPIR_Continue_context MPIR_Continue_context; + +struct { + struct MPIR_Continue *head, *tail; + MPID_Thread_mutex_t lock; +} g_deferred_cont_list = {NULL, NULL}; + +__thread struct { + struct MPIR_Continue *head, *tail; +} tls_deferred_cont_list = {NULL, NULL}; + +void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context, bool defer_complete, bool in_request_callback); +void MPIR_Continue_callback(MPIR_Request *op_request, bool in_cs, void *cb_context); +void attach_continue_context(MPIR_Continue_context *context_ptr, bool defer_complete); + +void MPIR_Continue_global_init() +{ + g_deferred_cont_list.head = NULL; + g_deferred_cont_list.tail = NULL; + int err; + MPID_Thread_mutex_create(&g_deferred_cont_list.lock, &err); + MPIR_Assert(err == 0); +} + +void MPIR_Continue_global_finalize() +{ + MPIR_Assert(g_deferred_cont_list.head == NULL); + MPIR_Assert(g_deferred_cont_list.tail == NULL); + int err; + MPID_Thread_mutex_destroy(&g_deferred_cont_list.lock, &err); + MPIR_Assert(err == 0); +} + +int MPIR_Continue_init_impl(int flags, int max_poll, + MPIR_Info *info_ptr, + MPIR_Request **cont_req_ptr) +{ + MPIR_Request *cont_req = MPIR_Request_create(MPIR_REQUEST_KIND__CONTINUE); + MPIR_Cont_request_inactivate(cont_req); + /* We use cc to track how many continue object has been attached to this continuation request. */ + MPIR_cc_set(&cont_req->cc, 0); + { + int err; + MPID_Thread_mutex_create(&cont_req->u.cont.cont_context_on_hold_list.lock, &err); + MPIR_Assert(err == 0); + MPID_Thread_mutex_create(&cont_req->u.cont.ready_poll_only_cont_list.lock, &err); + MPIR_Assert(err == 0); + } + /* Initialize the on-hold context list */ + cont_req->u.cont.cont_context_on_hold_list.head = NULL; + cont_req->u.cont.cont_context_on_hold_list.tail = NULL; + /* Initialize the poll-only continue list */ + cont_req->u.cont.ready_poll_only_cont_list.head = NULL; + cont_req->u.cont.ready_poll_only_cont_list.tail = NULL; + cont_req->u.cont.is_pool_only = flags & MPIX_CONT_POLL_ONLY; + cont_req->u.cont.max_poll = max_poll; + cont_req->u.cont.state = (MPID_Progress_state_cnt *) MPL_malloc(sizeof(MPID_Progress_state_cnt), MPL_MEM_OTHER); + for (int i = 0; i < MPIDI_CH4_MAX_VCIS; ++i) { + MPL_atomic_release_store_uint64(&cont_req->u.cont.state->vci_refcount[i].val, 0); + } + *cont_req_ptr = cont_req; + return MPI_SUCCESS; +} + +void MPIR_Continue_destroy_impl(MPIR_Request *cont_req) +{ + MPL_free(cont_req->u.cont.state); + MPIR_Assert(cont_req->kind == MPIR_REQUEST_KIND__CONTINUE); + { + int err; + MPID_Thread_mutex_destroy(&cont_req->u.cont.cont_context_on_hold_list.lock, &err); + MPIR_Assert(err == 0); + MPID_Thread_mutex_destroy(&cont_req->u.cont.ready_poll_only_cont_list.lock, &err); + MPIR_Assert(err == 0); + } + MPIR_Assert(cont_req->u.cont.cont_context_on_hold_list.head == NULL); + MPIR_Assert(cont_req->u.cont.cont_context_on_hold_list.tail == NULL); + MPIR_Assert(cont_req->u.cont.ready_poll_only_cont_list.head == NULL); + MPIR_Assert(cont_req->u.cont.ready_poll_only_cont_list.tail == NULL); +} + +int MPIR_Continue_start(MPIR_Request * cont_request_ptr) +{ + struct MPIR_Continue_context *tmp_head = NULL, *tmp_tail = NULL; + MPID_THREAD_CS_ENTER(VCI, cont_request_ptr->u.cont.cont_context_on_hold_list.lock); + MPIR_Cont_request_activate(cont_request_ptr); + if (cont_request_ptr->u.cont.cont_context_on_hold_list.head) { + tmp_head = cont_request_ptr->u.cont.cont_context_on_hold_list.head; + tmp_tail = cont_request_ptr->u.cont.cont_context_on_hold_list.tail; + cont_request_ptr->u.cont.cont_context_on_hold_list.head = NULL; + cont_request_ptr->u.cont.cont_context_on_hold_list.tail = NULL; + } + MPID_THREAD_CS_EXIT(VCI, cont_request_ptr->u.cont.cont_context_on_hold_list.lock); + /* Attach those on-hold continue context */ + while (tmp_head) { + MPIR_Continue_context *context_ptr = tmp_head; + LL_DELETE(tmp_head, tmp_tail, context_ptr); + attach_continue_context(context_ptr, false); + } + return MPI_SUCCESS; +} + +void attach_continue_context(MPIR_Continue_context *context_ptr, bool defer_complete) { + /* record the corresponding VCI for the continuation request to progress */ + if (context_ptr->continue_ptr->cont_req) { + int vci = MPIDI_Request_get_vci(context_ptr->op_request); + MPL_atomic_fetch_add_int(&context_ptr->continue_ptr->cont_req->u.cont.state->vci_refcount[vci].val, 1); + } + /* Attach the continue context to the op request */ + if (!MPIR_Register_callback(context_ptr->op_request, MPIR_Continue_callback, context_ptr, false)) { + /* the request has already been completed. */ + complete_op_request(context_ptr->op_request, false, context_ptr, defer_complete, false); + } +} + +int MPIR_Continue_impl(MPIR_Request *op_request_ptr, + MPIX_Continue_cb_function *cb, void *cb_data, + int flags, MPI_Status *status, + MPIR_Request *cont_request_ptr) +{ + return MPIR_Continueall_impl(1, &op_request_ptr, cb, cb_data, flags, status, cont_request_ptr); +} + +int MPIR_Continueall_impl(int count, MPIR_Request *request_ptrs[], + MPIX_Continue_cb_function *cb, void *cb_data, int flags, + MPI_Status *array_of_statuses, MPIR_Request *cont_request_ptr) +{ + if (cont_request_ptr) { + /* Add one continue to the continuation request */ + int was_incompleted; + MPIR_cc_incr(cont_request_ptr->cc_ptr, &was_incompleted); + if (!was_incompleted) { + MPIR_Request_add_ref(cont_request_ptr); + /* A hack for now since continuation request can jump + * between complete and incomplete multiple times */ + cont_request_ptr->cbs_invoked = false; + } + } + /* Set various condition variables */ + bool defer_complete = flags & MPIX_CONT_DEFER_COMPLETE; + /* Create the continue object for every continue callback */ + MPIR_Continue *continue_ptr = (MPIR_Continue *) MPL_malloc(sizeof(MPIR_Continue), MPL_MEM_OTHER); + continue_ptr->cont_req = cont_request_ptr; + continue_ptr->cb = cb; + continue_ptr->cb_data = cb_data; + MPIR_cc_set(&continue_ptr->pending_request_count, count); + if (flags & MPIX_CONT_IMMEDIATE) + continue_ptr->is_immediate = true; + else + continue_ptr->is_immediate = false; + + for (int i = 0; i < count; i++) { + /* Create the continue context object for every op request */ + MPIR_Continue_context *context_ptr = (MPIR_Continue_context *) MPL_malloc(sizeof(MPIR_Continue_context), MPL_MEM_OTHER); + context_ptr->continue_ptr = continue_ptr; + MPIR_Assert(MPI_STATUS_IGNORE == MPI_STATUSES_IGNORE); + if (array_of_statuses != MPI_STATUS_IGNORE) { + context_ptr->status_ptr = &array_of_statuses[i]; + } else { + context_ptr->status_ptr = MPI_STATUS_IGNORE; + } + context_ptr->op_request = request_ptrs[i]; + /* if the continue request is not activated yet, do not attach */ + bool is_on_hold = false; + if (cont_request_ptr && !MPIR_Cont_request_is_active(cont_request_ptr)) { + MPID_THREAD_CS_ENTER(VCI, cont_request_ptr->u.cont.cont_context_on_hold_list.lock); + if (!MPIR_Cont_request_is_active(cont_request_ptr)) { + /* The continuation request is inactive. Do not attach yet. */ + LL_APPEND(cont_request_ptr->u.cont.cont_context_on_hold_list.head, + cont_request_ptr->u.cont.cont_context_on_hold_list.tail, + context_ptr); + is_on_hold = true; + } + MPID_THREAD_CS_EXIT(VCI, cont_request_ptr->u.cont.cont_context_on_hold_list.lock); + } + /* attach the continue context to op request */ + if (!is_on_hold) { + attach_continue_context(context_ptr, defer_complete); + } + } + return MPI_SUCCESS; +} + +void execute_continue(MPIR_Continue *continue_ptr, bool in_cs, int which_cs) +{ + MPIR_Request *cont_req_ptr = continue_ptr->cont_req; + /* Invoke the continue callback */ + continue_ptr->cb(MPI_SUCCESS, continue_ptr->cb_data); + MPL_free(continue_ptr); + /* Signal the continuation request */ + /* TODO: Find a suitable request complete function for continuation requests */ + if (cont_req_ptr) { + int incomplete; + MPIR_cc_decr(cont_req_ptr->cc_ptr, &incomplete); + if (!incomplete) { + /* TODO: reason about the safety of invoking the callback for continuation request here*/ +// MPIR_Invoke_callback(cont_req_ptr, false); + MPIR_Request_free_with_safety(cont_req_ptr, !(in_cs && MPIR_REQUEST_POOL(cont_req_ptr) == which_cs), NULL); + } + } +} + +void complete_op_request(MPIR_Request *op_request, bool in_cs, void *cb_context, bool defer_complete, bool in_request_callback) +{ + MPIR_Continue_context *context_ptr = (MPIR_Continue_context *) cb_context; + MPIR_Continue *continue_ptr = context_ptr->continue_ptr; + /* Decrease the continuation request VCI counter */ + MPIR_Request *cont_req_ptr = continue_ptr->cont_req; + if (cont_req_ptr) { + int vci = MPIDI_Request_get_vci(op_request); + MPL_atomic_fetch_sub_int(&cont_req_ptr->u.cont.state->vci_refcount[vci].val, 1); + } + /* Complete this operation request */ + /* FIXME: MPIR_Request_completion_processing can call MPIR_Request_free, + * which might lead to deadlock */ + int rc = MPIR_Request_completion_processing( + op_request, context_ptr->status_ptr); + if (context_ptr->status_ptr != MPI_STATUS_IGNORE) + context_ptr->status_ptr->MPI_ERROR = rc; + if (!MPIR_Request_is_persistent(op_request)) { + MPIR_Request_free_with_safety(op_request, !in_cs, NULL); + } + MPL_free(context_ptr); + /* Signal the continue callback */ + int incomplete; + MPIR_cc_decr(&continue_ptr->pending_request_count, &incomplete); + if (!incomplete) { + /* All the op requests associated with this continue callback have completed */ + MPIR_Request *cont_req_ptr = continue_ptr->cont_req; + if (cont_req_ptr && cont_req_ptr->u.cont.is_pool_only) { + // Pool-only continuation request + // Push to the continuation request local ready list + MPID_THREAD_CS_ENTER(VCI, cont_req_ptr->u.cont.ready_poll_only_cont_list.lock); + LL_APPEND(cont_req_ptr->u.cont.ready_poll_only_cont_list.head, + cont_req_ptr->u.cont.ready_poll_only_cont_list.tail, + continue_ptr); + MPID_THREAD_CS_EXIT(VCI, cont_req_ptr->u.cont.ready_poll_only_cont_list.lock); + } else if (defer_complete) { + // Deferred completion. + MPID_THREAD_CS_ENTER(VCI, g_deferred_cont_list.lock); + LL_APPEND(g_deferred_cont_list.head, + g_deferred_cont_list.tail, + continue_ptr); + MPID_THREAD_CS_EXIT(VCI, g_deferred_cont_list.lock); + } else if (in_cs && !continue_ptr->is_immediate) { + // General-purpose continuation request. We are in a VCI CS + // Push to the tls ready list + LL_APPEND(tls_deferred_cont_list.head, + tls_deferred_cont_list.tail, + continue_ptr); + } else { + execute_continue(continue_ptr, in_cs, MPIR_REQUEST_POOL(op_request)); + } + } + +} + +void MPIR_Continue_callback(MPIR_Request *op_request, bool in_cs, void *cb_context) +{ + complete_op_request(op_request, in_cs, cb_context, false, true); +} + +int MPIR_Continue_progress_tls() +{ + int count = 0; + while (tls_deferred_cont_list.head) { + /* We have to poll all the things to ensure progress */ + MPIR_Continue *continue_ptr = tls_deferred_cont_list.head; + LL_DELETE(tls_deferred_cont_list.head, tls_deferred_cont_list.tail, continue_ptr); + execute_continue(continue_ptr, false, 0 /* Does not matter */); + ++count; + } + return count; +} + +int MPIR_Continue_progress_request(MPIR_Request *cont_request_ptr) +{ + MPIR_Assert(cont_request_ptr && cont_request_ptr->kind == MPIR_REQUEST_KIND__CONTINUE); + int count = 0; + if (cont_request_ptr->u.cont.ready_poll_only_cont_list.head) { + struct MPIR_Continue *local_head = NULL, *local_tail = NULL; + MPID_THREAD_CS_ENTER(VCI, cont_request_ptr->u.cont.ready_poll_only_cont_list.lock); + /* TODO: use a more efficient way to pop this list */ + while (cont_request_ptr->u.cont.ready_poll_only_cont_list.head) { + MPIR_Continue *continue_ptr = cont_request_ptr->u.cont.ready_poll_only_cont_list.head; + LL_DELETE(cont_request_ptr->u.cont.ready_poll_only_cont_list.head, + cont_request_ptr->u.cont.ready_poll_only_cont_list.tail, + continue_ptr); + LL_APPEND(local_head, local_tail, continue_ptr); + if (cont_request_ptr->u.cont.max_poll && ++count >= cont_request_ptr->u.cont.max_poll) + break; + } + MPID_THREAD_CS_EXIT(VCI, cont_request_ptr->u.cont.ready_poll_only_cont_list.lock); + while (local_head) { + MPIR_Continue *continue_ptr = local_head; + LL_DELETE(local_head, local_tail, continue_ptr); + execute_continue(continue_ptr, false, 0 /* Does not matter */); + } + } + return count; +} + +void MPIR_Continue_progress(MPIR_Request *request) +{ + int count = 0; + int max_poll = 0; // By default we poll unlimited time + if (request && request->kind == MPIR_REQUEST_KIND__CONTINUE) { + // This is a continuation request + count += MPIR_Continue_progress_request(request); + max_poll = request->u.cont.max_poll; + } + // make progress on the global list + if (g_deferred_cont_list.head) { + struct MPIR_Continue *local_head = NULL, *local_tail = NULL; + MPID_THREAD_CS_ENTER(VCI, g_deferred_cont_list.lock); + /* TODO: use a more efficient way to pop this list */ + while ((!max_poll || count < max_poll) && g_deferred_cont_list.head) { + MPIR_Continue *continue_ptr = g_deferred_cont_list.head; + LL_DELETE(g_deferred_cont_list.head, + g_deferred_cont_list.tail, + continue_ptr); + LL_APPEND(local_head, local_tail, continue_ptr); + ++count; + } + MPID_THREAD_CS_EXIT(VCI, g_deferred_cont_list.lock); + while (local_head) { + MPIR_Continue *continue_ptr = local_head; + LL_DELETE(local_head, local_tail, continue_ptr); + execute_continue(continue_ptr, false, 0 /* Does not matter */); + } + } +} \ No newline at end of file diff --git a/src/mpi/init/mpir_init.c b/src/mpi/init/mpir_init.c index 2f1c115aa13..e4e02b46a5a 100644 --- a/src/mpi/init/mpir_init.c +++ b/src/mpi/init/mpir_init.c @@ -120,6 +120,7 @@ static int init_counter; */ /* ------------ Init ------------------- */ +void MPIR_Continue_global_init(); int MPIR_Init_impl(int *argc, char ***argv) { @@ -207,6 +208,7 @@ int MPII_Init_thread(int *argc, char ***argv, int user_required, int *provided, MPII_nettopo_init(); MPII_init_windows(); MPII_init_binding_cxx(); + MPIR_Continue_global_init(); mpi_errno = MPII_init_local_proc_attrs(&required); MPIR_ERR_CHECK(mpi_errno); @@ -377,6 +379,8 @@ int MPIR_Init_thread_impl(int *argc, char ***argv, int user_required, int *provi /* ------------ Finalize ------------------- */ +void MPIR_Continue_global_finalize(); + int MPII_Finalize(MPIR_Session * session_ptr) { int mpi_errno = MPI_SUCCESS; @@ -480,6 +484,7 @@ int MPII_Finalize(MPIR_Session * session_ptr) MPIR_Process.memory_alloc_kinds = NULL; /* All memory should be freed at this point */ + MPIR_Continue_global_finalize(); MPII_finalize_memory_tracing(); MPII_thread_mutex_destroy(); diff --git a/src/mpi/request/mpir_request.c b/src/mpi/request/mpir_request.c index 8ea25f98ce1..e2192f1060f 100644 --- a/src/mpi/request/mpir_request.c +++ b/src/mpi/request/mpir_request.c @@ -27,6 +27,9 @@ static void init_builtin_request(MPIR_Request * req, int handle, MPIR_Request_ki req->status.MPI_TAG = MPI_ANY_TAG; } req->comm = NULL; + MPIR_Request_cb_init(req); + // built-in request are in the completed state + req->cbs_invoked = true; } void MPII_init_request(void) @@ -202,6 +205,15 @@ int MPIR_Request_completion_processing(MPIR_Request * request_ptr, MPI_Status * break; } + case MPIR_REQUEST_KIND__CONTINUE: + { + MPIR_Cont_request_inactivate(request_ptr); + + MPIR_Request_extract_status(request_ptr, status); + mpi_errno = request_ptr->status.MPI_ERROR; + break; + } + case MPIR_REQUEST_KIND__GREQUEST: { mpi_errno = MPIR_Grequest_query(request_ptr); diff --git a/src/mpi/request/request_impl.c b/src/mpi/request/request_impl.c index f9bccb93747..17edcf0398e 100644 --- a/src/mpi/request/request_impl.c +++ b/src/mpi/request/request_impl.c @@ -194,6 +194,8 @@ int MPIR_Request_free_impl(MPIR_Request * request_ptr) MPIR_Request_free(request_ptr->u.persist_coll.real_request); } break; + case MPIR_REQUEST_KIND__CONTINUE: + break; default: break; } @@ -212,6 +214,7 @@ int MPIR_Test_state(MPIR_Request * request_ptr, int *flag, MPI_Status * status, if (!MPIR_Request_is_complete(request_ptr)) { mpi_errno = MPID_Progress_test(state); MPIR_ERR_CHECK(mpi_errno); + MPIR_Continue_progress(request_ptr); } fn_exit: @@ -274,6 +277,7 @@ int MPIR_Testall_state(int count, MPIR_Request * request_ptrs[], int *flag, if (request_ptrs[i] == NULL || MPIR_Request_is_complete(request_ptrs[i])) { n_completed++; } else { + MPIR_Continue_progress(request_ptrs[i]); break; } } @@ -286,6 +290,8 @@ int MPIR_Testall_state(int count, MPIR_Request * request_ptrs[], int *flag, } if (MPIR_Request_is_complete(request_ptrs[i])) { n_completed++; + } else { + MPIR_Continue_progress(request_ptrs[i]); } } else { n_completed++; @@ -309,6 +315,21 @@ int MPIR_Testall_state(int count, MPIR_Request * request_ptrs[], int *flag, *flag = FALSE; fn_exit: + if (n_completed == count) { + *flag = TRUE; + goto fn_exit; + } + + if (need_progress > 0) { + mpi_errno = MPID_Progress_test(state); + MPIR_ERR_CHECK(mpi_errno); + + need_progress--; + goto fn_check_requests; + } + + *flag = FALSE; + return mpi_errno; fn_fail: goto fn_exit; @@ -464,6 +485,7 @@ int MPIR_Testany_state(int count, MPIR_Request * request_ptrs[], if (!request_ptrs[i]) { continue; } + if (MPIR_Request_has_poll_fn(request_ptrs[i])) { mpi_errno = MPIR_Grequest_poll(request_ptrs[i], status); MPIR_ERR_CHECK(mpi_errno); @@ -473,6 +495,7 @@ int MPIR_Testany_state(int count, MPIR_Request * request_ptrs[], *indx = i; goto fn_exit; } + MPIR_Continue_progress(request_ptrs[i]); } if (need_progress > 0) { @@ -590,6 +613,7 @@ int MPIR_Testsome_state(int incount, MPIR_Request * request_ptrs[], if (!request_ptrs[i]) { continue; } + if (MPIR_Request_has_poll_fn(request_ptrs[i])) { mpi_errno = MPIR_Grequest_poll(request_ptrs[i], &array_of_statuses[i]); MPIR_ERR_CHECK(mpi_errno); @@ -597,7 +621,9 @@ int MPIR_Testsome_state(int incount, MPIR_Request * request_ptrs[], if (MPIR_Request_is_complete(request_ptrs[i])) { array_of_indices[*outcount] = i; *outcount += 1; + continue; } + MPIR_Continue_progress(request_ptrs[i]); } if (*outcount) { @@ -706,6 +732,8 @@ int MPIR_Wait_state(MPIR_Request * request_ptr, MPI_Status * status, MPID_Progre while (!MPIR_Request_is_complete(request_ptr)) { mpi_errno = MPID_Progress_wait(state); MPIR_ERR_CHECK(mpi_errno); + + MPIR_Continue_progress(request_ptr); DEBUG_PROGRESS_CHECK; if (unlikely(MPIR_Request_is_anysrc_mismatched(request_ptr))) { @@ -773,6 +801,7 @@ int MPIR_Waitall_state(int count, MPIR_Request * request_ptrs[], MPI_Status arra while (!MPIR_Request_is_complete(request_ptrs[i])) { mpi_errno = MPID_Progress_wait(state); MPIR_ERR_CHECK(mpi_errno); + MPIR_Continue_progress(request_ptrs[i]); DEBUG_PROGRESS_CHECK; } } @@ -789,6 +818,7 @@ int MPIR_Waitall_state(int count, MPIR_Request * request_ptrs[], MPI_Status arra mpi_errno = MPID_Progress_wait(state); MPIR_ERR_CHECK(mpi_errno); + MPIR_Continue_progress(request_ptrs[i]); DEBUG_PROGRESS_CHECK; } } @@ -984,6 +1014,7 @@ int MPIR_Waitany_state(int count, MPIR_Request * request_ptrs[], int *indx, MPI_ *indx = i; goto fn_exit; } + MPIR_Continue_progress(request_ptrs[i]); } mpi_errno = MPID_Progress_test(state); @@ -1113,7 +1144,9 @@ int MPIR_Waitsome_state(int incount, MPIR_Request * request_ptrs[], if (MPIR_Request_is_complete(request_ptrs[i])) { array_of_indices[n_active] = i; n_active += 1; + continue; } + MPIR_Continue_progress(request_ptrs[i]); } } diff --git a/src/mpid/ch3/src/ch3u_request.c b/src/mpid/ch3/src/ch3u_request.c index 0390a8c4545..8db580ab456 100644 --- a/src/mpid/ch3/src/ch3u_request.c +++ b/src/mpid/ch3/src/ch3u_request.c @@ -545,6 +545,7 @@ int MPID_Request_complete(MPIR_Request *req) MPIDI_CH3U_Request_decrement_cc(req, &incomplete); if (!incomplete) { + MPIR_Invoke_callback(req, false); MPIR_Request_free(req); } diff --git a/src/mpid/ch3/src/mpid_startall.c b/src/mpid/ch3/src/mpid_startall.c index cba93847a43..2746d5ba116 100644 --- a/src/mpid/ch3/src/mpid_startall.c +++ b/src/mpid/ch3/src/mpid_startall.c @@ -54,6 +54,7 @@ int MPID_Startall(int count, MPIR_Request * requests[]) for (i = 0; i < count; i++) { MPIR_Request * const preq = requests[i]; + MPIR_Request_start(preq); if (preq->kind == MPIR_REQUEST_KIND__PREQUEST_COLL) { mpi_errno = MPIR_Persist_coll_start(preq); @@ -61,6 +62,12 @@ int MPID_Startall(int count, MPIR_Request * requests[]) continue; } + if (preq->kind == MPIR_REQUEST_KIND__CONTINUE) { + mpi_errno = MPIR_Continue_start(preq); + MPIR_ERR_CHECK(mpi_errno); + continue; + } + /* only pt2pt requests should reach here */ MPIR_Assert(preq->kind == MPIR_REQUEST_KIND__PREQUEST_SEND || preq->kind == MPIR_REQUEST_KIND__PREQUEST_RECV); @@ -69,6 +76,7 @@ int MPID_Startall(int count, MPIR_Request * requests[]) if (preq->dev.match.parts.rank == MPI_PROC_NULL) continue; + /* FIXME: The odd 7th arg (match.context_id - comm->context_id) is probably to get the context offset. Do we really need the context offset? Is there any case where the offset isn't zero? */ diff --git a/src/mpid/ch4/include/mpidpre.h b/src/mpid/ch4/include/mpidpre.h index 0b9ec6e1df9..dca858f7cde 100644 --- a/src/mpid/ch4/include/mpidpre.h +++ b/src/mpid/ch4/include/mpidpre.h @@ -57,6 +57,15 @@ typedef struct { uint8_t vci[MPIDI_CH4_MAX_VCIS]; /* list of vcis that need progress */ } MPID_Progress_state; +typedef struct { + MPL_atomic_int64_t val; + char padding[56]; +} MPL_padded_atomic_int64_t; + +typedef struct { + MPL_padded_atomic_int64_t vci_refcount[MPIDI_CH4_MAX_VCIS]; /* list of vcis that need progress */ +} MPID_Progress_state_cnt; + typedef enum { MPIDI_PTYPE_RECV, MPIDI_PTYPE_SEND, diff --git a/src/mpid/ch4/src/ch4_progress.h b/src/mpid/ch4/src/ch4_progress.h index be35a9bafb0..5d0dbbf2da0 100644 --- a/src/mpid/ch4/src/ch4_progress.h +++ b/src/mpid/ch4/src/ch4_progress.h @@ -143,6 +143,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_progress_test(MPID_Progress_state * state) } } #endif + MPIR_Continue_progress_tls(); fn_exit: MPIR_FUNC_EXIT; @@ -166,10 +167,10 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_progress_state_init(MPID_Progress_state * st state->vci_count = 1; } else { /* global progress by default */ - for (int i = 0; i < MPIDI_global.n_vcis; i++) { - state->vci[i] = i; - } - state->vci_count = MPIDI_global.n_vcis; + for (int i = 0; i < MPIDI_global.n_vcis; i++) { + state->vci[i] = i; + } + state->vci_count = MPIDI_global.n_vcis; } } diff --git a/src/mpid/ch4/src/ch4_request.h b/src/mpid/ch4/src/ch4_request.h index c75f4187e35..acd1eea3af7 100644 --- a/src/mpid/ch4/src/ch4_request.h +++ b/src/mpid/ch4/src/ch4_request.h @@ -96,6 +96,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Request_complete(MPIR_Request * req) MPIDI_SHM_am_request_finalize(req); #endif } + MPIR_Invoke_callback(req, true); MPIDI_CH4_REQUEST_FREE(req); } @@ -111,6 +112,7 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_Request_complete_fast(MPIR_Request * req) if (req->dev.completion_notification) { MPIR_cc_dec(req->dev.completion_notification); } + MPIR_Invoke_callback(req, true); MPIDI_CH4_REQUEST_FREE(req); } } diff --git a/src/mpid/ch4/src/ch4_self.c b/src/mpid/ch4/src/ch4_self.c index 5b836d0a84e..6762994a5fc 100644 --- a/src/mpid/ch4/src/ch4_self.c +++ b/src/mpid/ch4/src/ch4_self.c @@ -167,7 +167,7 @@ int MPIDI_Self_irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int rank, /* comm will be released by MPIR_Request_free(sreq) */ MPIR_Datatype_release_if_not_builtin(sreq->dev.ch4.self.datatype); MPIR_Request_complete(sreq); - MPIR_cc_set(&rreq->cc, 0); + MPIR_Request_complete_nofree(rreq); MPII_UNEXPQ_FORGET(sreq); } else { ENQUEUE_SELF_RECV(rreq, buf, count, datatype, tag, comm->context_id); @@ -257,7 +257,7 @@ int MPIDI_Self_imrecv(char *buf, MPI_Aint count, MPI_Datatype datatype, /* comm will be released by MPIR_Request_free(sreq) */ MPIR_Datatype_release_if_not_builtin(sreq->dev.ch4.self.datatype); MPIR_Request_complete(sreq); - MPIR_cc_set(&rreq->cc, 0); + MPIR_Request_complete_nofree(rreq); *request = rreq; MPID_THREAD_CS_EXIT(VCI, MPIDIU_THREAD_SELF_MUTEX); diff --git a/src/mpid/ch4/src/ch4_startall.h b/src/mpid/ch4/src/ch4_startall.h index e3407ceef98..b9693741ebf 100644 --- a/src/mpid/ch4/src/ch4_startall.h +++ b/src/mpid/ch4/src/ch4_startall.h @@ -89,6 +89,7 @@ MPL_STATIC_INLINE_PREFIX int MPID_Startall(int count, MPIR_Request * requests[]) for (i = 0; i < count; i++) { MPIR_Request *const preq = requests[i]; + MPIR_Request_start(preq); switch (preq->kind) { case MPIR_REQUEST_KIND__PREQUEST_SEND: case MPIR_REQUEST_KIND__PREQUEST_RECV: @@ -104,6 +105,10 @@ MPL_STATIC_INLINE_PREFIX int MPID_Startall(int count, MPIR_Request * requests[]) mpi_errno = MPIDI_part_start(preq); break; + case MPIR_REQUEST_KIND__CONTINUE: + mpi_errno = MPIR_Continue_start(preq); + break; + default: mpi_errno = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_FATAL, __FUNCTION__, __LINE__, MPI_ERR_INTERN, "**ch4|badstartreq", diff --git a/src/mpid/ch4/src/ch4_wait.h b/src/mpid/ch4/src/ch4_wait.h index d6456a26c2b..36cc450d286 100644 --- a/src/mpid/ch4/src/ch4_wait.h +++ b/src/mpid/ch4/src/ch4_wait.h @@ -8,15 +8,44 @@ #include "ch4_impl.h" +MPL_STATIC_INLINE_PREFIX void MPIDI_add_vci_to_state(int vci, + MPID_Progress_state * state) +{ + MPIR_Assert(vci < MPIDI_CH4_MAX_VCIS); + for (int i = 0; i < state->vci_count; ++i) { + if (state->vci[i] == vci) { + return; + } + } + MPIR_Assert(state->vci_count < MPIDI_CH4_MAX_VCIS); + state->vci[state->vci_count++] = vci; +} + +MPL_STATIC_INLINE_PREFIX void MPIDI_add_progress_vci_cont(MPIR_Request * req, + MPID_Progress_state * state) +{ + MPIR_Assert(req->kind == MPIR_REQUEST_KIND__CONTINUE); + for (int i = 0; i < MPIDI_CH4_MAX_VCIS; ++i) { + if (MPL_atomic_relaxed_load_int64(&req->u.cont.state->vci_refcount[i].val) > 0) { + MPIDI_add_vci_to_state(i, state); + } + } +} + MPL_STATIC_INLINE_PREFIX void MPIDI_set_progress_vci(MPIR_Request * req, MPID_Progress_state * state) { state->flag = MPIDI_PROGRESS_ALL; /* TODO: check request is_local/anysource */ - int vci = MPIDI_Request_get_vci(req); + state->vci_count = 0; + if (req->kind == MPIR_REQUEST_KIND__CONTINUE) { + MPIDI_add_progress_vci_cont(req, state); + } else { + int vci = MPIDI_Request_get_vci(req); - state->vci_count = 1; - state->vci[0] = vci; + state->vci_count = 1; + state->vci[0] = vci; + } } MPL_STATIC_INLINE_PREFIX void MPIDI_set_progress_vci_n(int n, MPIR_Request ** reqs, @@ -24,6 +53,7 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_set_progress_vci_n(int n, MPIR_Request ** re { state->flag = MPIDI_PROGRESS_ALL; /* TODO: check request is_local/anysource */ + state->vci_count = 0; int idx = 0; for (int i = 0; i < n; i++) { if (!MPIR_Request_is_active(reqs[i])) { @@ -34,16 +64,11 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_set_progress_vci_n(int n, MPIR_Request ** re continue; } - int vci = MPIDI_Request_get_vci(reqs[i]); - int found = 0; - for (int j = 0; j < idx; j++) { - if (state->vci[j] == vci) { - found = 1; - break; - } - } - if (!found) { - state->vci[idx++] = vci; + if (reqs[i]->kind == MPIR_REQUEST_KIND__CONTINUE) { + MPIDI_add_progress_vci_cont(reqs[i], state); + } else { + int vci = MPIDI_Request_get_vci(reqs[i]); + MPIDI_add_vci_to_state(vci, state); } } state->vci_count = idx;