Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: Update internal send/recv function signatures #820

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ struct nccl_net_ofi_send_comm {
*/
int (*deregMr)(nccl_net_ofi_send_comm_t *send_comm, nccl_net_ofi_mr_handle_t *mhandle);

int (*send)(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int tag,
int (*send)(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag,
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **req);

int (*close)(nccl_net_ofi_send_comm_t *send_comm);
Expand Down Expand Up @@ -554,7 +554,7 @@ struct nccl_net_ofi_recv_comm {
*/
int (*deregMr)(nccl_net_ofi_recv_comm_t *recv_comm, nccl_net_ofi_mr_handle_t *mhandle);

int (*recv)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, int *sizes, int *tags,
int (*recv)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, size_t *sizes, int *tags,
nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **req);

int (*flush)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, int *sizes,
Expand Down
6 changes: 4 additions & 2 deletions include/nccl_ofi_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ ncclResult_t nccl_net_ofi_regMr(void *comm, void *data, size_t size, int type,
void **mhandle);
ncclResult_t nccl_net_ofi_regMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle);
ncclResult_t nccl_net_ofi_deregMr(void *comm, void *mhandle);
ncclResult_t nccl_net_ofi_isend(void *sendComm, void* data, int size, int tag, void *mhandle, void** request);
ncclResult_t nccl_net_ofi_isend(void *sendComm, void* data, size_t size, int tag, void *mhandle, void** request);
ncclResult_t nccl_net_ofi_isend_v4(void* sendComm, void* data, int size, void* mhandle, void** request);
ncclResult_t nccl_net_ofi_isend_v8(void *sendComm, void* data, int size, int tag, void *mhandle, void** request);
ncclResult_t nccl_net_ofi_isend_v9(void *sendComm, void* data, size_t size, int tag, void *mhandle, void** request);
ncclResult_t nccl_net_ofi_irecv(void* recvComm, int n, void** buffers, int* sizes, int *tags, void** mhandles, void** request);
ncclResult_t nccl_net_ofi_irecv(void* recvComm, int n, void** buffers, size_t* sizes, int *tags, void** mhandles, void** request);
ncclResult_t nccl_net_ofi_irecv_v4(void* recvComm, void* data, int size, void* mhandle, void** request);
ncclResult_t nccl_net_ofi_irecv_v8(void* recvComm, int n, void** buffers, int* sizes, int *tags, void** mhandles, void** request);
ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** buffers, size_t* sizes, int *tags, void** mhandles, void** request);
ncclResult_t nccl_net_ofi_test(void *request, int *done, int *size);
ncclResult_t nccl_net_ofi_iflush(void* recvComm, int n, void** buffers, int* sizes, void** mhandles, void** request);
Expand Down
91 changes: 57 additions & 34 deletions src/nccl_ofi_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,26 +104,27 @@ static ncclResult_t nccl_net_ofi_retval_translate_impl(int retval)


/**
* @brief Verifies if a message length is within the maximum allowed size
* @brief Convert message length from int type to size_t type.
*
* @return ncclSuccess, if size is valid
* ncclInternalError, if exceeded
* @return ncclSuccess, if convertion is successful
* ncclInvalidArgument, if input is NULL
* ncclInternalError, if message length doesn't fit into size_t range
*/
static inline ncclResult_t convert_int_msg_sizes_to_size_t(const int* sizes, size_t *sizes_size_t, const size_t n) {
if (OFI_UNLIKELY(sizes == NULL)) {
NCCL_OFI_WARN("Invalid argument: NULL pointer provided for sizes array");
return ncclInvalidArgument;
}

static inline ncclResult_t msg_length_verify_max_size(const size_t *sizes, const size_t len) {
if (OFI_UNLIKELY(sizes == NULL)) {
NCCL_OFI_WARN("Invalid argument: NULL pointer provided for sizes array");
return ncclInvalidArgument;
}

for (size_t i = 0; i < len; i++) {
if (OFI_UNLIKELY(sizes[i] > INT_MAX)) {
NCCL_OFI_WARN("Message size %zu exceeds maximum allowed size %d at index %zu", sizes[i], INT_MAX, i);
return ncclInternalError;
}
}
return ncclSuccess;
}
for (size_t i = 0; i < n; i++) {
if (OFI_UNLIKELY(sizes[i] < 0)) {
NCCL_OFI_WARN("Message size %d can't be negative at index %zu", sizes[i], i);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't bozo check in the critical path. You shouldn't need this check function any more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

today, user is technically allowed to provide a negative value, and we implicitly type case to size_t, for example, when calling alloc_rdma_send_req in send(). Shouldn't that be checked and dis-allowed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Length is always a non-negative value. NCCL isn't going to pass a negative number. It might pass a very large number, which is why I was worried about truncation. But a bunch of unnecessary checks just add latency.

return ncclInternalError;
}
sizes_size_t[i] = (size_t)sizes[i];
}
return ncclSuccess;
}


static void nccl_net_ofi_fini(void)
Expand Down Expand Up @@ -565,7 +566,7 @@ ncclResult_t nccl_net_ofi_accept_v4(void* listenComm, void** recvComm)
}


ncclResult_t nccl_net_ofi_isend(void *sComm, void* data, int size,
ncclResult_t nccl_net_ofi_isend(void *sComm, void* data, size_t size,
int tag, void *mhandle, void** req)
{
nccl_net_ofi_send_comm_t *send_comm =
Expand Down Expand Up @@ -709,23 +710,37 @@ ncclResult_t nccl_net_ofi_iread(void* rComm, void* dest, size_t size, void* mhan
ncclResult_t nccl_net_ofi_isend_v4(void* sendComm, void* data, int size,
void* mhandle, void** request)
{
return nccl_net_ofi_isend(sendComm, data, size, 0, mhandle, request);
size_t size_size_t;
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(&size, &size_size_t, 1);
if (validation_result != ncclSuccess) {
return check_return(validation_result);
}

return nccl_net_ofi_isend(sendComm, data, size_size_t, 0, mhandle, request);
}


ncclResult_t nccl_net_ofi_isend_v9(void* sendComm, void* data, size_t size,
ncclResult_t nccl_net_ofi_isend_v8(void* sendComm, void* data, int size,
int tag, void* mhandle, void** request)
{
ncclResult_t validation_result = msg_length_verify_max_size(&size, 1);
size_t size_size_t;
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(&size, &size_size_t, 1);
if (validation_result != ncclSuccess) {
return check_return(validation_result);
}

return nccl_net_ofi_isend(sendComm, data, (int)size, tag, mhandle, request);
return nccl_net_ofi_isend(sendComm, data, size_size_t, tag, mhandle, request);
}


ncclResult_t nccl_net_ofi_isend_v9(void* sendComm, void* data, size_t size,
int tag, void* mhandle, void** request)
{
return nccl_net_ofi_isend(sendComm, data, size, tag, mhandle, request);
}
Comment on lines +736 to 740
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The isend_v9 interface looks the same as the isend interface. Do we really need both?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the two we shouldn't need. We have confusion around APIs every time we need to add / change an interface function.

What we have been doing (mostly, we've screwed this up in the past) is that the latest version API we support gets the un-versioned names. When a function interface changes, we add a function suffixed with _v[N-1] and update all the previous code blocks. To me, this is not intuitive, and I don't love that we have to update all the old interfaces in an error-prone process.

I think we should change our operations. Every API function should have a version suffix of the API version in which the function was first used. When a function's prototype or behavior changes, we add a new version of the api with a version suffix of the API version in which the change occurred. Then we only have to copy the v[N-1] block to a vN block, change the few functions in the new version that changed, and not touch the past at all. I think this is more intuitive, but also means we don't change the past, which gives me a bit of comfort.

What I think we should do with this patch is not touch the unversioned interface functions. Just have this be changing the core interface from int to size_t, removing the overflow check for the size_t -> int cast, and adding the handling to old functions to pass a size_t array instead of an int array into the internal recv function.

In another patch, we should rename all the old functions to follow the "first time added" behavior, so we can get rid of some of this version madness.

@rajachan thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. agreed with the point of the naming confusion, suffixing with version firstly used makes more sense to me.
  2. irrelevant to what we suffix the old version interfaces with, if there's an internal function signature change(like this patch), won't we need to touch old versions in an error-prone process any ways? e.g. the _v4 function takes in an int, whether it's called _v4, or _v2, we still need an type conversion to size_t



ncclResult_t nccl_net_ofi_irecv(void* rComm, int n, void** buffers, int* sizes,
ncclResult_t nccl_net_ofi_irecv(void* rComm, int n, void** buffers, size_t* sizes,
int *tags, void** mhandles, void** req)
{
nccl_net_ofi_recv_comm_t *recv_comm =
Expand Down Expand Up @@ -769,17 +784,22 @@ ncclResult_t nccl_net_ofi_irecv_v4(void* recvComm, void* data, int size,
void* mhandle, void** request)
{
int tag = 0;
size_t size_size_t = 0;
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(&size, &size_size_t, 1);
if (validation_result != ncclSuccess) {
return check_return(validation_result);
}

return nccl_net_ofi_irecv(recvComm, 1, &data, &size, &tag, &mhandle, request);
return nccl_net_ofi_irecv(recvComm, 1, &data, &size_size_t, &tag, &mhandle, request);
}


ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data,
size_t* sizes, int* tags, void** mhandles, void** request)
ncclResult_t nccl_net_ofi_irecv_v8(void* recvComm, int n, void** data,
int* sizes, int* tags, void** mhandles, void** request)
{
if (OFI_UNLIKELY(recvComm == NULL || data == NULL ||
sizes == NULL || tags == NULL ||
mhandles == NULL || request == NULL)) {
sizes == NULL || tags == NULL ||
mhandles == NULL || request == NULL)) {
NCCL_OFI_WARN("Invalid argument: NULL pointer detected");
return check_return(ncclInvalidArgument);
}
Expand All @@ -789,17 +809,20 @@ ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data,
return check_return(ncclInvalidArgument);
}

ncclResult_t validation_result = msg_length_verify_max_size(sizes, n);
size_t sizes_size_t[NCCL_OFI_MAX_RECVS] = {0};
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(sizes, sizes_size_t, n);
if (validation_result != ncclSuccess) {
return check_return(validation_result);
}

int sizesInt[NCCL_OFI_MAX_RECVS] = {0};
for (int i = 0; i < n; i++) {
sizesInt[i] = (int)sizes[i];
}
return nccl_net_ofi_irecv(recvComm, n, data, sizes_size_t, tags, mhandles, request);
}


return nccl_net_ofi_irecv(recvComm, n, data, sizesInt, tags, mhandles, request);
ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data,
size_t* sizes, int* tags, void** mhandles, void** request)
{
return nccl_net_ofi_irecv(recvComm, n, data, sizes, tags, mhandles, request);
}


Expand Down
4 changes: 2 additions & 2 deletions src/nccl_ofi_interface_neuron.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v5_t ncclNetPlugin_v5 = {
.regMr = nccl_net_ofi_regMr,
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend,
.irecv = nccl_net_ofi_irecv,
.isend = nccl_net_ofi_isend_v8,
.irecv = nccl_net_ofi_irecv_v8,
.iflush = nccl_net_ofi_iflush,
.test = nccl_net_ofi_test,
.closeSend = nccl_net_ofi_closeSend,
Expand Down
16 changes: 8 additions & 8 deletions src/nccl_ofi_interface_nvidia.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v5_t ncclNetPlugin_v5 = {
.accept = nccl_net_ofi_accept,
.regMr = nccl_net_ofi_regMr_v7,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend,
.irecv = nccl_net_ofi_irecv,
.isend = nccl_net_ofi_isend_v8,
.irecv = nccl_net_ofi_irecv_v8,
.iflush = nccl_net_ofi_iflush,
.test = nccl_net_ofi_test,
.closeSend = nccl_net_ofi_closeSend,
Expand All @@ -324,8 +324,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v6_t ncclNetPlugin_v6 = {
.regMr = nccl_net_ofi_regMr_v7,
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend,
.irecv = nccl_net_ofi_irecv,
.isend = nccl_net_ofi_isend_v8,
.irecv = nccl_net_ofi_irecv_v8,
.iflush = nccl_net_ofi_iflush,
.test = nccl_net_ofi_test,
.closeSend = nccl_net_ofi_closeSend,
Expand All @@ -344,8 +344,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v7_t ncclNetPlugin_v7 = {
.regMr = nccl_net_ofi_regMr_v7,
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend,
.irecv = nccl_net_ofi_irecv,
.isend = nccl_net_ofi_isend_v8,
.irecv = nccl_net_ofi_irecv_v8,
.iflush = nccl_net_ofi_iflush,
.test = nccl_net_ofi_test,
.closeSend = nccl_net_ofi_closeSend,
Expand All @@ -366,8 +366,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v8_t ncclNetPlugin_v8 = {
.regMr = nccl_net_ofi_regMr,
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend,
.irecv = nccl_net_ofi_irecv,
.isend = nccl_net_ofi_isend_v8,
.irecv = nccl_net_ofi_irecv_v8,
.iflush = nccl_net_ofi_iflush,
.test = nccl_net_ofi_test,
.closeSend = nccl_net_ofi_closeSend,
Expand Down
4 changes: 2 additions & 2 deletions src/nccl_ofi_rdma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3464,7 +3464,7 @@ static int process_cq_if_pending(nccl_net_ofi_rdma_ep_t *ep)
}

static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
int *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles,
size_t *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles,
nccl_net_ofi_req_t **base_req)
{
int ret = 0;
Expand Down Expand Up @@ -5773,7 +5773,7 @@ static inline int check_post_rx_buff_req(nccl_net_ofi_rdma_req_t *rx_buff_req)
* @brief Send a message. This "interface function" is called, indirectly, from
* the application
*/
static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int tag,
static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just glancing at uses of size, it looks like we pass it into tracing and so into LTTNG (for send, NCCL_OFI_TRACE_SEND), so we should make sure this patch works with LTTNG enabled (--with-lttng). I'm guessing we'll get some complaint about trying to fit a size_t into an int, since we haven't updated LTTNG yet.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This patch should update LTTNG as well, in that case.

nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **base_req)
{
int ret = 0;
Expand Down
4 changes: 2 additions & 2 deletions src/nccl_ofi_sendrecv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,7 @@ static inline nccl_net_ofi_sendrecv_req_t *sendrecv_allocate_req(nccl_ofi_freeli
}

static int sendrecv_recv_comm_recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
int *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles,
size_t *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles,
nccl_net_ofi_req_t **base_req)
{
int ret = 0;
Expand Down Expand Up @@ -1752,7 +1752,7 @@ static int sendrecv_send_comm_dereg_mr(nccl_net_ofi_send_comm_t *send_comm,
domain->base.mr_cache);
}

static int sendrecv_send_comm_send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int tag,
static int sendrecv_send_comm_send(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag,
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **base_req)
{
int ret = 0;
Expand Down
Loading