Skip to content

api: Update internal API function interface and versioning convention #820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ struct nccl_net_ofi_send_comm {
*/
int (*deregMr)(nccl_net_ofi_send_comm_t *send_comm, nccl_net_ofi_mr_handle_t *mhandle);

int (*send)(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int tag,
int (*send)(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag,
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **req);

int (*close)(nccl_net_ofi_send_comm_t *send_comm);
Expand Down Expand Up @@ -554,7 +554,7 @@ struct nccl_net_ofi_recv_comm {
*/
int (*deregMr)(nccl_net_ofi_recv_comm_t *recv_comm, nccl_net_ofi_mr_handle_t *mhandle);

int (*recv)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, int *sizes, int *tags,
int (*recv)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, size_t *sizes, int *tags,
nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **req);

int (*flush)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, int *sizes,
Expand Down
20 changes: 11 additions & 9 deletions include/nccl_ofi_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,28 @@ ncclResult_t nccl_net_ofi_init(ncclDebugLogger_t logFunction);
ncclResult_t nccl_net_ofi_devices(int *ndev);
ncclResult_t nccl_net_ofi_get_properties(int dev, struct nccl_ofi_properties *ofi_properties);
ncclResult_t nccl_net_ofi_listen(int dev, void *handle, void **listenComm);
ncclResult_t nccl_net_ofi_listen_v4(int dev, void* handle, void** listenComm);
ncclResult_t nccl_net_ofi_listen_v2(int dev, void* handle, void** listenComm);
ncclResult_t nccl_net_ofi_connect(int dev, void* handle, void** sendComm);
ncclResult_t nccl_net_ofi_connect_v4(int dev, void* handle, void** sendComm);
ncclResult_t nccl_net_ofi_connect_v2(int dev, void* handle, void** sendComm);
ncclResult_t nccl_net_ofi_accept(void *listenComm, void **recvComm);
ncclResult_t nccl_net_ofi_accept_v4(void* listenComm, void** recvComm);
ncclResult_t nccl_net_ofi_regMr_v7(void *comm, void *data, int size, int type,
ncclResult_t nccl_net_ofi_accept_v2(void* listenComm, void** recvComm);
ncclResult_t nccl_net_ofi_regMr_v2(void *comm, void *data, int size, int type,
void **mhandle);
ncclResult_t nccl_net_ofi_regMr(void *comm, void *data, size_t size, int type,
void **mhandle);
ncclResult_t nccl_net_ofi_regMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle);
ncclResult_t nccl_net_ofi_deregMr(void *comm, void *mhandle);
ncclResult_t nccl_net_ofi_isend(void *sendComm, void* data, int size, int tag, void *mhandle, void** request);
ncclResult_t nccl_net_ofi_isend_v4(void* sendComm, void* data, int size, void* mhandle, void** request);
ncclResult_t nccl_net_ofi_isend(void *sendComm, void* data, size_t size, int tag, void *mhandle, void** request);
ncclResult_t nccl_net_ofi_isend_v2(void* sendComm, void* data, int size, void* mhandle, void** request);
ncclResult_t nccl_net_ofi_isend_v5(void *sendComm, void* data, int size, int tag, void *mhandle, void** request);
ncclResult_t nccl_net_ofi_isend_v9(void *sendComm, void* data, size_t size, int tag, void *mhandle, void** request);
ncclResult_t nccl_net_ofi_irecv(void* recvComm, int n, void** buffers, int* sizes, int *tags, void** mhandles, void** request);
ncclResult_t nccl_net_ofi_irecv_v4(void* recvComm, void* data, int size, void* mhandle, void** request);
ncclResult_t nccl_net_ofi_irecv(void* recvComm, int n, void** buffers, size_t* sizes, int *tags, void** mhandles, void** request);
ncclResult_t nccl_net_ofi_irecv_v2(void* recvComm, void* data, int size, void* mhandle, void** request);
ncclResult_t nccl_net_ofi_irecv_v5(void* recvComm, int n, void** buffers, int* sizes, int *tags, void** mhandles, void** request);
ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** buffers, size_t* sizes, int *tags, void** mhandles, void** request);
ncclResult_t nccl_net_ofi_test(void *request, int *done, int *size);
ncclResult_t nccl_net_ofi_iflush(void* recvComm, int n, void** buffers, int* sizes, void** mhandles, void** request);
ncclResult_t nccl_net_ofi_flush_v3(void* recvComm, void* data, int size, void* mhandle);
ncclResult_t nccl_net_ofi_iflush_v2(void* recvComm, void* data, int size, void* mhandle);
ncclResult_t nccl_net_ofi_iflush_v4(void* recvComm, void* data, int size, void* mhandle, void** request);
ncclResult_t nccl_net_ofi_closeSend(void *sendComm);
ncclResult_t nccl_net_ofi_closeRecv(void *recvComm);
Expand Down
8 changes: 4 additions & 4 deletions include/tracing_impl/lttng.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ LTTNG_UST_TRACEPOINT_EVENT(
Send,
LTTNG_UST_TP_ARGS(
int, dev,
int, size,
size_t, size,
void *, comm,
uint16_t, msg_seq_num,
void *, request,
void *, nccl_req
),
LTTNG_UST_TP_FIELDS(
lttng_ust_field_integer(int, dev, dev)
lttng_ust_field_integer(int, size, size)
lttng_ust_field_integer(size_t, size, size)
lttng_ust_field_integer_hex(uint64_t, comm, (uint64_t)comm)
lttng_ust_field_integer(uint16_t, msg_seq_num, msg_seq_num)
lttng_ust_field_integer_hex(uint64_t, request, (uint64_t)request)
Expand Down Expand Up @@ -225,14 +225,14 @@ LTTNG_UST_TRACEPOINT_EVENT(
LTTNG_UST_TP_ARGS(
int, dev,
int, comm_id,
int, size,
size_t, size,
void *, request,
void *, nccl_req
),
LTTNG_UST_TP_FIELDS(
lttng_ust_field_integer(int, dev, dev)
lttng_ust_field_integer(int, comm_id, comm_id)
lttng_ust_field_integer(int, size, size)
lttng_ust_field_integer(size_t, size, size)
lttng_ust_field_integer_hex(uint64_t, request, (uint64_t)request)
lttng_ust_field_integer_hex(uint64_t, nccl_req, (uint64_t)nccl_req)
)
Expand Down
100 changes: 59 additions & 41 deletions src/nccl_ofi_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,26 +104,22 @@ static ncclResult_t nccl_net_ofi_retval_translate_impl(int retval)


/**
* @brief Verifies if a message length is within the maximum allowed size
* @brief Convert message length from int type to size_t type.
*
* @return ncclSuccess, if size is valid
* ncclInternalError, if exceeded
* @return ncclSuccess, if conversion is successful
* ncclInvalidArgument, if input is NULL
*/
static inline ncclResult_t convert_int_msg_sizes_to_size_t(const int* sizes, size_t *sizes_size_t, const size_t n) {
if (OFI_UNLIKELY(sizes == NULL)) {
NCCL_OFI_WARN("Invalid argument: NULL pointer provided for sizes array");
return ncclInvalidArgument;
}

/*
 * Walk an array of message sizes and confirm each one fits in an int.
 *
 * sizes: array of message lengths to check; must not be NULL.
 * len:   number of entries in sizes.
 *
 * Returns ncclInvalidArgument when sizes is NULL, ncclInternalError at the
 * first entry larger than INT_MAX, and ncclSuccess otherwise (including
 * when len is 0).
 */
static inline ncclResult_t msg_length_verify_max_size(const size_t *sizes, const size_t len) {
	if (OFI_UNLIKELY(sizes == NULL)) {
		NCCL_OFI_WARN("Invalid argument: NULL pointer provided for sizes array");
		return ncclInvalidArgument;
	}

	size_t idx = 0;
	while (idx < len) {
		const size_t msg_size = sizes[idx];

		/* Stop at the first oversized entry and report its index. */
		if (OFI_UNLIKELY(msg_size > INT_MAX)) {
			NCCL_OFI_WARN("Message size %zu exceeds maximum allowed size %d at index %zu", msg_size, INT_MAX, idx);
			return ncclInternalError;
		}
		idx++;
	}

	return ncclSuccess;
}
for (size_t i = 0; i < n; i++) {
sizes_size_t[i] = (size_t)sizes[i];
}
return ncclSuccess;
}


static void nccl_net_ofi_fini(void)
Expand Down Expand Up @@ -248,7 +244,7 @@ ncclResult_t nccl_net_ofi_listen(int dev_id, void *handle, void **lComm)
}


ncclResult_t nccl_net_ofi_listen_v4(int dev, void* handle, void** listenComm)
ncclResult_t nccl_net_ofi_listen_v2(int dev, void* handle, void** listenComm)
{
nccl_net_ofi_conn_handle_t nccl_net_ofi_handle = {};
ncclResult_t ret;
Expand Down Expand Up @@ -341,7 +337,7 @@ ncclResult_t nccl_net_ofi_connect(int dev_id, void *handle, void **sComm)
}


ncclResult_t nccl_net_ofi_connect_v4(int dev, void* handle, void** sendComm)
ncclResult_t nccl_net_ofi_connect_v2(int dev, void* handle, void** sendComm)
{
ncclResult_t ret = ncclSuccess;
nccl_net_ofi_conn_handle_t nccl_net_ofi_handle = {};
Expand All @@ -362,7 +358,7 @@ ncclResult_t nccl_net_ofi_connect_v4(int dev, void* handle, void** sendComm)
return ret;
}

ncclResult_t nccl_net_ofi_regMr_v7(void *comm, void *data, int size, int type,
ncclResult_t nccl_net_ofi_regMr_v2(void *comm, void *data, int size, int type,
void **mhandle)
{
return nccl_net_ofi_regMr(comm, data, (size_t)size, type, mhandle);
Expand Down Expand Up @@ -544,7 +540,7 @@ ncclResult_t nccl_net_ofi_accept(void *lComm, void **rComm)
}


ncclResult_t nccl_net_ofi_accept_v4(void* listenComm, void** recvComm)
ncclResult_t nccl_net_ofi_accept_v2(void* listenComm, void** recvComm)
{
ncclResult_t ret = ncclInvalidArgument;

Expand All @@ -565,7 +561,7 @@ ncclResult_t nccl_net_ofi_accept_v4(void* listenComm, void** recvComm)
}


ncclResult_t nccl_net_ofi_isend(void *sComm, void* data, int size,
ncclResult_t nccl_net_ofi_isend(void *sComm, void* data, size_t size,
int tag, void *mhandle, void** req)
{
nccl_net_ofi_send_comm_t *send_comm =
Expand Down Expand Up @@ -706,26 +702,40 @@ ncclResult_t nccl_net_ofi_iread(void* rComm, void* dest, size_t size, void* mhan
}


ncclResult_t nccl_net_ofi_isend_v4(void* sendComm, void* data, int size,
ncclResult_t nccl_net_ofi_isend_v2(void* sendComm, void* data, int size,
void* mhandle, void** request)
{
return nccl_net_ofi_isend(sendComm, data, size, 0, mhandle, request);
size_t size_size_t;
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(&size, &size_size_t, 1);
if (validation_result != ncclSuccess) {
return check_return(validation_result);
}

return nccl_net_ofi_isend(sendComm, data, size_size_t, 0, mhandle, request);
}


ncclResult_t nccl_net_ofi_isend_v9(void* sendComm, void* data, size_t size,
ncclResult_t nccl_net_ofi_isend_v5(void* sendComm, void* data, int size,
int tag, void* mhandle, void** request)
{
ncclResult_t validation_result = msg_length_verify_max_size(&size, 1);
size_t size_size_t;
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(&size, &size_size_t, 1);
if (validation_result != ncclSuccess) {
return check_return(validation_result);
}

return nccl_net_ofi_isend(sendComm, data, (int)size, tag, mhandle, request);
return nccl_net_ofi_isend(sendComm, data, size_size_t, tag, mhandle, request);
}


ncclResult_t nccl_net_ofi_irecv(void* rComm, int n, void** buffers, int* sizes,
/*
 * isend variant that takes the message size as a size_t.
 *
 * Every argument is handed to nccl_net_ofi_isend() unchanged and that
 * call's result is returned as-is.
 */
ncclResult_t nccl_net_ofi_isend_v9(void* sendComm, void* data, size_t size,
				   int tag, void* mhandle, void** request)
{
	const ncclResult_t status = nccl_net_ofi_isend(sendComm, data, size, tag,
						       mhandle, request);

	return status;
}


ncclResult_t nccl_net_ofi_irecv(void* rComm, int n, void** buffers, size_t* sizes,
int *tags, void** mhandles, void** req)
{
nccl_net_ofi_recv_comm_t *recv_comm =
Expand Down Expand Up @@ -765,21 +775,26 @@ ncclResult_t nccl_net_ofi_irecv(void* rComm, int n, void** buffers, int* sizes,
}


ncclResult_t nccl_net_ofi_irecv_v4(void* recvComm, void* data, int size,
ncclResult_t nccl_net_ofi_irecv_v2(void* recvComm, void* data, int size,
void* mhandle, void** request)
{
int tag = 0;
size_t size_size_t = 0;
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(&size, &size_size_t, 1);
if (validation_result != ncclSuccess) {
return check_return(validation_result);
}

return nccl_net_ofi_irecv(recvComm, 1, &data, &size, &tag, &mhandle, request);
return nccl_net_ofi_irecv(recvComm, 1, &data, &size_size_t, &tag, &mhandle, request);
}


ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data,
size_t* sizes, int* tags, void** mhandles, void** request)
ncclResult_t nccl_net_ofi_irecv_v5(void* recvComm, int n, void** data,
int* sizes, int* tags, void** mhandles, void** request)
{
if (OFI_UNLIKELY(recvComm == NULL || data == NULL ||
sizes == NULL || tags == NULL ||
mhandles == NULL || request == NULL)) {
sizes == NULL || tags == NULL ||
mhandles == NULL || request == NULL)) {
NCCL_OFI_WARN("Invalid argument: NULL pointer detected");
return check_return(ncclInvalidArgument);
}
Expand All @@ -789,17 +804,20 @@ ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data,
return check_return(ncclInvalidArgument);
}

ncclResult_t validation_result = msg_length_verify_max_size(sizes, n);
size_t sizes_size_t[NCCL_OFI_MAX_RECVS] = {0};
ncclResult_t validation_result = convert_int_msg_sizes_to_size_t(sizes, sizes_size_t, n);
if (validation_result != ncclSuccess) {
return check_return(validation_result);
}

int sizesInt[NCCL_OFI_MAX_RECVS] = {0};
for (int i = 0; i < n; i++) {
sizesInt[i] = (int)sizes[i];
}
return nccl_net_ofi_irecv(recvComm, n, data, sizes_size_t, tags, mhandles, request);
}

return nccl_net_ofi_irecv(recvComm, n, data, sizesInt, tags, mhandles, request);

/*
 * irecv variant that takes the per-buffer sizes as a size_t array.
 *
 * Every argument is handed to nccl_net_ofi_irecv() unchanged and that
 * call's result is returned as-is.
 */
ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data,
				   size_t* sizes, int* tags, void** mhandles, void** request)
{
	const ncclResult_t status = nccl_net_ofi_irecv(recvComm, n, data, sizes,
						       tags, mhandles, request);

	return status;
}


Expand Down Expand Up @@ -856,7 +874,7 @@ ncclResult_t nccl_net_ofi_iflush(void* rComm, int n, void** buffers, int* sizes,
}


ncclResult_t nccl_net_ofi_flush_v3(void* recvComm, void* data, int size, void* mhandle)
ncclResult_t nccl_net_ofi_iflush_v2(void* recvComm, void* data, int size, void* mhandle)
{
void *req = NULL;
ncclResult_t ret = ncclSuccess;
Expand Down
18 changes: 9 additions & 9 deletions src/nccl_ofi_interface_neuron.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v5_t ncclNetPlugin_v5 = {
.regMr = nccl_net_ofi_regMr,
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend,
.irecv = nccl_net_ofi_irecv,
.isend = nccl_net_ofi_isend_v5,
.irecv = nccl_net_ofi_irecv_v5,
.iflush = nccl_net_ofi_iflush,
.test = nccl_net_ofi_test,
.closeSend = nccl_net_ofi_closeSend,
Expand All @@ -110,7 +110,7 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v5_t ncclNetPlugin_v5 = {
.iread = nccl_net_ofi_iread,
};

static ncclResult_t getProperties_v4(int dev_id, ncclNetProperties_v4_t *props)
static ncclResult_t getProperties_v3(int dev_id, ncclNetProperties_v4_t *props)
{
nccl_ofi_properties_t ofi_properties;
ncclResult_t ret = nccl_net_ofi_get_properties(dev_id, &ofi_properties);
Expand All @@ -136,14 +136,14 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v4_t ncclNetPlugin_v4 = {
.name = "AWS Libfabric",
.init = init_v4,
.devices = nccl_net_ofi_devices,
.getProperties = getProperties_v4,
.listen = nccl_net_ofi_listen_v4,
.connect = nccl_net_ofi_connect_v4,
.accept = nccl_net_ofi_accept_v4,
.getProperties = getProperties_v3,
.listen = nccl_net_ofi_listen_v2,
.connect = nccl_net_ofi_connect_v2,
.accept = nccl_net_ofi_accept_v2,
.regMr = nccl_net_ofi_regMr,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend_v4,
.irecv = nccl_net_ofi_irecv_v4,
.isend = nccl_net_ofi_isend_v2,
.irecv = nccl_net_ofi_irecv_v2,
.iflush = nccl_net_ofi_iflush_v4,
.test = nccl_net_ofi_test,
.closeSend = nccl_net_ofi_closeSend,
Expand Down
Loading
Loading