-
Notifications
You must be signed in to change notification settings - Fork 60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
api: Update internal send/recv function signatures #820
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,26 +104,27 @@ static ncclResult_t nccl_net_ofi_retval_translate_impl(int retval) | |
|
||
|
||
/** | ||
* @brief Verifies if a message length is within the maximum allowed size | ||
* @brief Convert message length from int type to size_t type. | ||
* | ||
* @return ncclSuccess, if size is valid | ||
* ncclInternalError, if exceeded | ||
* @return ncclSuccess, if conversion is successful | ||
* ncclInvalidArgument, if input is NULL | ||
* ncclInternalError, if message length doesn't fit into size_t range | ||
*/ | ||
static inline ncclResult_t convert_int_msg_sizes_to_size_t(const int* sizes, size_t *sizes_size_t, const size_t n) { | ||
if (OFI_UNLIKELY(sizes == NULL)) { | ||
NCCL_OFI_WARN("Invalid argument: NULL pointer provided for sizes array"); | ||
return ncclInvalidArgument; | ||
} | ||
|
||
static inline ncclResult_t msg_length_verify_max_size(const size_t *sizes, const size_t len) { | ||
if (OFI_UNLIKELY(sizes == NULL)) { | ||
NCCL_OFI_WARN("Invalid argument: NULL pointer provided for sizes array"); | ||
return ncclInvalidArgument; | ||
} | ||
|
||
for (size_t i = 0; i < len; i++) { | ||
if (OFI_UNLIKELY(sizes[i] > INT_MAX)) { | ||
NCCL_OFI_WARN("Message size %zu exceeds maximum allowed size %d at index %zu", sizes[i], INT_MAX, i); | ||
return ncclInternalError; | ||
} | ||
} | ||
return ncclSuccess; | ||
} | ||
for (size_t i = 0; i < n; i++) { | ||
if (OFI_UNLIKELY(sizes[i] < 0)) { | ||
NCCL_OFI_WARN("Message size %d can't be negative at index %zu", sizes[i], i); | ||
return ncclInternalError; | ||
} | ||
sizes_size_t[i] = (size_t)sizes[i]; | ||
} | ||
return ncclSuccess; | ||
} | ||
|
||
|
||
static void nccl_net_ofi_fini(void) | ||
|
@@ -565,7 +566,7 @@ ncclResult_t nccl_net_ofi_accept_v4(void* listenComm, void** recvComm) | |
} | ||
|
||
|
||
ncclResult_t nccl_net_ofi_isend(void *sComm, void* data, int size, | ||
ncclResult_t nccl_net_ofi_isend(void *sComm, void* data, size_t size, | ||
int tag, void *mhandle, void** req) | ||
{ | ||
nccl_net_ofi_send_comm_t *send_comm = | ||
|
@@ -709,23 +710,37 @@ ncclResult_t nccl_net_ofi_iread(void* rComm, void* dest, size_t size, void* mhan | |
ncclResult_t nccl_net_ofi_isend_v4(void* sendComm, void* data, int size, | ||
void* mhandle, void** request) | ||
{ | ||
return nccl_net_ofi_isend(sendComm, data, size, 0, mhandle, request); | ||
size_t size_size_t; | ||
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(&size, &size_size_t, 1); | ||
if (validation_result != ncclSuccess) { | ||
return check_return(validation_result); | ||
} | ||
|
||
return nccl_net_ofi_isend(sendComm, data, size_size_t, 0, mhandle, request); | ||
} | ||
|
||
|
||
ncclResult_t nccl_net_ofi_isend_v9(void* sendComm, void* data, size_t size, | ||
ncclResult_t nccl_net_ofi_isend_v8(void* sendComm, void* data, int size, | ||
int tag, void* mhandle, void** request) | ||
{ | ||
ncclResult_t validation_result = msg_length_verify_max_size(&size, 1); | ||
size_t size_size_t; | ||
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(&size, &size_size_t, 1); | ||
if (validation_result != ncclSuccess) { | ||
return check_return(validation_result); | ||
} | ||
|
||
return nccl_net_ofi_isend(sendComm, data, (int)size, tag, mhandle, request); | ||
return nccl_net_ofi_isend(sendComm, data, size_size_t, tag, mhandle, request); | ||
} | ||
|
||
|
||
ncclResult_t nccl_net_ofi_isend_v9(void* sendComm, void* data, size_t size, | ||
int tag, void* mhandle, void** request) | ||
{ | ||
return nccl_net_ofi_isend(sendComm, data, size, tag, mhandle, request); | ||
} | ||
Comment on lines
+736
to
740
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: The isend_v9 interface looks the same as the isend interface. Do we really need both? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One of the two we shouldn't need. We have confusion around APIs every time we need to add / change an interface function. What we have been doing (mostly, we've screwed this up in the past) is that the latest version API we support gets the un-versioned names. When a function interface changes, we add a function suffixed with I think we should change our operations. Every API function should have a version suffix of the API version in which the function was first used. When a function's prototype or behavior changes, we add a new version of the api with a version suffix of the API version in which the change occurred. Then we only have to copy the v[N-1] block to a vN block, change the few functions in the new version that changed, and not touch the past at all. I think this is more intuitive, but also means we don't change the past, which gives me a bit of comfort. What I think we should do with this patch is not touch the unversioned interface functions. Just have this be changing the core interface from int to size_t, removing the overflow check for the size_t -> int cast, and adding the handling to old functions to pass a size_t array instead of an int array into the internal recv function. In another patch, we should rename all the old functions to follow the "first time added" behavior, so we can get rid of some of this version madness. @rajachan thoughts? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
|
||
ncclResult_t nccl_net_ofi_irecv(void* rComm, int n, void** buffers, int* sizes, | ||
ncclResult_t nccl_net_ofi_irecv(void* rComm, int n, void** buffers, size_t* sizes, | ||
int *tags, void** mhandles, void** req) | ||
{ | ||
nccl_net_ofi_recv_comm_t *recv_comm = | ||
|
@@ -769,17 +784,22 @@ ncclResult_t nccl_net_ofi_irecv_v4(void* recvComm, void* data, int size, | |
void* mhandle, void** request) | ||
{ | ||
int tag = 0; | ||
size_t size_size_t = 0; | ||
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(&size, &size_size_t, 1); | ||
if (validation_result != ncclSuccess) { | ||
return check_return(validation_result); | ||
} | ||
|
||
return nccl_net_ofi_irecv(recvComm, 1, &data, &size, &tag, &mhandle, request); | ||
return nccl_net_ofi_irecv(recvComm, 1, &data, &size_size_t, &tag, &mhandle, request); | ||
} | ||
|
||
|
||
ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data, | ||
size_t* sizes, int* tags, void** mhandles, void** request) | ||
ncclResult_t nccl_net_ofi_irecv_v8(void* recvComm, int n, void** data, | ||
int* sizes, int* tags, void** mhandles, void** request) | ||
{ | ||
if (OFI_UNLIKELY(recvComm == NULL || data == NULL || | ||
sizes == NULL || tags == NULL || | ||
mhandles == NULL || request == NULL)) { | ||
sizes == NULL || tags == NULL || | ||
mhandles == NULL || request == NULL)) { | ||
NCCL_OFI_WARN("Invalid argument: NULL pointer detected"); | ||
return check_return(ncclInvalidArgument); | ||
} | ||
|
@@ -789,17 +809,20 @@ ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data, | |
return check_return(ncclInvalidArgument); | ||
} | ||
|
||
ncclResult_t validation_result = msg_length_verify_max_size(sizes, n); | ||
size_t sizes_size_t[NCCL_OFI_MAX_RECVS] = {0}; | ||
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(sizes, sizes_size_t, n); | ||
if (validation_result != ncclSuccess) { | ||
return check_return(validation_result); | ||
} | ||
|
||
int sizesInt[NCCL_OFI_MAX_RECVS] = {0}; | ||
for (int i = 0; i < n; i++) { | ||
sizesInt[i] = (int)sizes[i]; | ||
} | ||
return nccl_net_ofi_irecv(recvComm, n, data, sizes_size_t, tags, mhandles, request); | ||
} | ||
|
||
|
||
return nccl_net_ofi_irecv(recvComm, n, data, sizesInt, tags, mhandles, request); | ||
ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data, | ||
size_t* sizes, int* tags, void** mhandles, void** request) | ||
{ | ||
return nccl_net_ofi_irecv(recvComm, n, data, sizes, tags, mhandles, request); | ||
} | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3464,7 +3464,7 @@ static int process_cq_if_pending(nccl_net_ofi_rdma_ep_t *ep) | |
} | ||
|
||
static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, | ||
int *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles, | ||
size_t *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles, | ||
nccl_net_ofi_req_t **base_req) | ||
{ | ||
int ret = 0; | ||
|
@@ -5773,7 +5773,7 @@ static inline int check_post_rx_buff_req(nccl_net_ofi_rdma_req_t *rx_buff_req) | |
* @brief Send a message. This "interface function" is called, indirectly, from | ||
* the application | ||
*/ | ||
static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int tag, | ||
static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just glancing at uses of size, it looks like we pass it into tracing and so into LTTNG (for send, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This patch should update LTTNG as well, in that case. |
||
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **base_req) | ||
{ | ||
int ret = 0; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't bozo check in the critical path. You shouldn't need this check function any more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Today, the user is technically allowed to provide a negative value, and we implicitly typecast it to size_t, for example, when calling `alloc_rdma_send_req` in `send()`. Shouldn't that be checked and disallowed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Length is always a non-negative value. NCCL isn't going to pass a negative number. It might pass a very large number, which is why I was worried about truncation. But a bunch of unnecessary checks just add latency.