Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(tree): move declarations to top of function (aws#563)
Browse files Browse the repository at this point in the history
c++ suppoorts initializers anywhere in the function, but one must not
jump over an initializer with any goto usage. Given the lack of RAII in
C, this becomes a significant painpoint.

In large to-be-eventually-refactored functions contain gotos or use
switch statements, split declaration and initialization, and move all
declarations to the top of the function. This makes switch statements
and gotos safe in both languages.

Signed-off-by: Nicholas Sielicki <[email protected]>
aws-nslick committed Sep 26, 2024

Verified

This commit was signed with the committer’s verified signature.
aws-nslick Nicholas Sielicki
1 parent b525063 commit 2dae7eb
Showing 6 changed files with 109 additions and 71 deletions.
25 changes: 13 additions & 12 deletions src/nccl_ofi_api.c
Original file line number Diff line number Diff line change
@@ -365,15 +365,16 @@ ncclResult_t nccl_net_ofi_regMrDmaBuf(void* comm, void* data, size_t size,
const nccl_ofi_mr_ckey_t cache_key = nccl_ofi_mr_ckey_mk_vec(data, size);
#endif

nccl_net_ofi_send_comm_t *send_comm = NULL;
nccl_net_ofi_recv_comm_t *recv_comm = NULL;

switch (base_comm->type) {
case NCCL_NET_OFI_SEND_COMM:;
nccl_net_ofi_send_comm_t *send_comm =
(nccl_net_ofi_send_comm_t *)base_comm;
case NCCL_NET_OFI_SEND_COMM:
send_comm = (nccl_net_ofi_send_comm_t *)base_comm;
ret = send_comm->regMr(send_comm, &cache_key, type, mhandle);
break;
case NCCL_NET_OFI_RECV_COMM:;
nccl_net_ofi_recv_comm_t *recv_comm =
(nccl_net_ofi_recv_comm_t *)base_comm;
case NCCL_NET_OFI_RECV_COMM:
recv_comm = (nccl_net_ofi_recv_comm_t *)base_comm;
ret = recv_comm->regMr(recv_comm, &cache_key, type, mhandle);
break;
default:
@@ -397,16 +398,16 @@ ncclResult_t nccl_net_ofi_deregMr(void *comm, void *mhandle)
}

int ret = 0;
nccl_net_ofi_send_comm_t *send_comm = NULL;
nccl_net_ofi_recv_comm_t *recv_comm = NULL;

switch (base_comm->type) {
case NCCL_NET_OFI_SEND_COMM:;
nccl_net_ofi_send_comm_t *send_comm =
(nccl_net_ofi_send_comm_t *)base_comm;
case NCCL_NET_OFI_SEND_COMM:
send_comm = (nccl_net_ofi_send_comm_t *)base_comm;
ret = send_comm->deregMr(send_comm, (nccl_net_ofi_mr_handle_t *)mhandle);
break;
case NCCL_NET_OFI_RECV_COMM:;
nccl_net_ofi_recv_comm_t *recv_comm =
(nccl_net_ofi_recv_comm_t *)base_comm;
case NCCL_NET_OFI_RECV_COMM:
recv_comm = (nccl_net_ofi_recv_comm_t *)base_comm;
ret = recv_comm->deregMr(recv_comm, (nccl_net_ofi_mr_handle_t *)mhandle);
break;
default:
4 changes: 2 additions & 2 deletions src/nccl_ofi_ep_addr_list.c
Original file line number Diff line number Diff line change
@@ -134,6 +134,7 @@ int nccl_ofi_ep_addr_list_insert(nccl_ofi_ep_addr_list_t *ep_list,
size_t addr_size)
{
int ret = 0;
ep_pair_list_elem_t *new_pair = NULL;

if (addr_size > ep_list->max_addr_size) {
NCCL_OFI_WARN("Address size %zu > max size (%zu)", addr_size,
@@ -155,8 +156,7 @@ int nccl_ofi_ep_addr_list_insert(nccl_ofi_ep_addr_list_t *ep_list,
memcpy(new_addr->addr, addr_in, addr_size);
zero_pad_address(new_addr->addr, addr_size, ep_list->max_addr_size);

ep_pair_list_elem_t *new_pair = (ep_pair_list_elem_t *)
malloc(sizeof(*new_pair));
new_pair = (ep_pair_list_elem_t *)malloc(sizeof(*new_pair));
if (!new_pair) {
NCCL_OFI_WARN("Failed to allocate new ep list element");
free(new_addr);
5 changes: 3 additions & 2 deletions src/nccl_ofi_net.c
Original file line number Diff line number Diff line change
@@ -136,6 +136,8 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
int ret = 0;
const char *provider_filter = NULL;
nccl_net_ofi_plugin_t *plugin;
nccl_net_ofi_ep_t *base_ep = NULL;
nccl_net_ofi_device_t *device = NULL;

NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Initializing " PACKAGE_STRING);

@@ -275,8 +277,7 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
* resources. This initialization happens once per process, and thus it
* does not matter which device is used to create the endpoint.
*/
nccl_net_ofi_device_t *device = plugin->get_device(plugin, 0);
nccl_net_ofi_ep_t *base_ep = NULL;
device = plugin->get_device(*plugin_p, 0);

ret = device->get_ep(device, &base_ep);
if (ret != 0) {
103 changes: 67 additions & 36 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
@@ -3165,7 +3165,13 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
int ret = 0;
nccl_net_ofi_rdma_req_t *req = NULL;
nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm;
rdma_req_recv_data_t *recv_data = NULL;
nccl_net_ofi_rdma_ep_t *ep = NULL;
nccl_net_ofi_rdma_device_t *device = NULL;
int dev_id = 0;
nccl_net_ofi_rdma_mr_handle_t **mr_handles = (nccl_net_ofi_rdma_mr_handle_t **)mhandles;
uint16_t msg_seq_num = 0;
bool eager = false;

assert(r_comm != NULL);

@@ -3182,12 +3188,12 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
goto error;
}

int dev_id = r_comm->base.base.dev_id;
dev_id = r_comm->base.base.dev_id;

nccl_net_ofi_rdma_ep_t * ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
assert(ep != NULL);

nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t*)ep->base.device;
device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
assert(device != NULL);

ret = process_cq_if_pending(ep);
@@ -3196,13 +3202,14 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
*base_req = NULL;
ret = 0;
goto error;
} else if (ret != 0) {
}
if (ret != 0) {
goto error;
}

uint16_t msg_seq_num = r_comm->next_msg_seq_num;
msg_seq_num = r_comm->next_msg_seq_num;

bool eager = false;
eager = false;
void *elem;
nccl_ofi_msgbuff_elemtype_t type;
nccl_ofi_msgbuff_status_t msg_stat;
@@ -3241,7 +3248,7 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
goto error;
}

rdma_req_recv_data_t *recv_data = get_recv_data(req);
recv_data = get_recv_data(req);

if (eager) {
nccl_net_ofi_rdma_req_t *bounce_req = (nccl_net_ofi_rdma_req_t *)elem;
@@ -3428,6 +3435,7 @@ static int alloc_and_reg_flush_buff(nccl_net_ofi_rdma_recv_comm_t *r_comm, int d

static int recv_comm_destroy(nccl_net_ofi_rdma_recv_comm_t *r_comm)
{
nccl_net_ofi_rdma_device_t *device = NULL;
int ret = 0;

/* Retrieve and validate endpoint */
@@ -3438,7 +3446,7 @@ static int recv_comm_destroy(nccl_net_ofi_rdma_recv_comm_t *r_comm)
return ret;
}

nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t*)base_ep->device;
device = (nccl_net_ofi_rdma_device_t *)base_ep->device;

if (r_comm->send_close_req != NULL) {
ret = r_comm->send_close_req->free(r_comm->send_close_req, false);
@@ -3811,6 +3819,12 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
nccl_net_ofi_req_t **base_req)
{
int ret = 0;
int flush_n = 0;
bool network_busy = false;
int dev_id = 0;
nccl_net_ofi_rdma_ep_t *ep = NULL;
nccl_net_ofi_rdma_device_t *device = NULL;
nccl_net_ofi_scheduler_t *scheduler = NULL;
nccl_net_ofi_rdma_recv_comm_t *r_comm =
(nccl_net_ofi_rdma_recv_comm_t *)recv_comm;

@@ -3827,19 +3841,19 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
goto error;
}

int dev_id = recv_comm->base.dev_id;
dev_id = recv_comm->base.dev_id;

nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
assert(ep != NULL);

nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
assert(device != NULL);

nccl_net_ofi_scheduler_t *scheduler = device->scheduler;
scheduler = device->scheduler;
assert(scheduler != NULL);

/* Process any pending requests */
bool network_busy = false;
network_busy = false;
rc = process_cq_if_pending(ep);
if (rc == -EAGAIN) {
/* Network is still busy. */
@@ -3869,7 +3883,7 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
* Find the non-zero request for which we will issue flush.
* A single operation can flush all request at once.
*/
int flush_n = -1;
flush_n = -1;
for (int recv_n = 0; recv_n < n; recv_n++) {
if (sizes[recv_n] != 0) {
flush_n = recv_n;
@@ -4033,6 +4047,7 @@ static int rma_read(nccl_net_ofi_recv_comm_t *recv_comm, void* dest, size_t size
nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm;
nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle;
nccl_net_ofi_rdma_req_t *req = NULL;
nccl_net_ofi_rdma_ep_t *ep = NULL;

assert(r_comm != NULL);
/* Support only NCCL_OFI_MAX_REQUESTS inflight requests. */
@@ -4043,7 +4058,7 @@ static int rma_read(nccl_net_ofi_recv_comm_t *recv_comm, void* dest, size_t size
goto error;
}

nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
assert(ep != NULL);

ret = process_cq_if_pending(ep);
@@ -4115,7 +4130,9 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen
{
int ret = 0;

int comm_id = 0;
nccl_net_ofi_rdma_recv_comm_t *r_comm = NULL;
nccl_net_ofi_rdma_ep_t *ep = NULL;
int dev_id = device->base.dev_id;
int num_rails = l_comm_ep->num_rails;

@@ -4154,7 +4171,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen
r_comm->n_ctrl_delivered = 0;

/* Allocate recv communicator ID */
int comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool);
comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool);
if (OFI_UNLIKELY(comm_id < 0)) {
r_comm->local_comm_id = ~0;
goto error;
@@ -4213,7 +4230,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen
r_comm->base.base.ep = &l_comm_ep->base;
}

nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;

/* Add ourselves to ep's lookup array */
set_comm(device, r_comm->local_comm_id, &r_comm->base.base);
@@ -4706,6 +4723,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
{
int ret = 0;
nccl_net_ofi_rdma_listen_comm_t *l_comm = NULL;
int comm_id = 0;
nccl_net_ofi_rdma_ep_t *ep =
(nccl_net_ofi_rdma_ep_t *)base_ep;

@@ -4744,7 +4762,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
l_comm->leader_local_ep = ep->control_rail.ofi_ep;

/* Allocate listen communicator ID */
int comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool);
comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool);
if (OFI_UNLIKELY(comm_id < 0)) {
l_comm->comm_id = ~0;
ret = comm_id;
@@ -5340,10 +5358,13 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
int ret = 0;
nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)send_comm;
nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle;
nccl_net_ofi_rdma_ep_t *ep = NULL;
nccl_net_ofi_rdma_req_t *req = NULL;
uint16_t msg_seq_num = s_comm->next_msg_seq_num;
bool polled_cq = false;
bool have_ctrl = false;
bool eager = false;
int dev_id = 0;

assert(s_comm != NULL);

@@ -5361,9 +5382,9 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
goto error;
}

int dev_id = s_comm->base.base.dev_id;
dev_id = s_comm->base.base.dev_id;

nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep;
ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep;
assert(ep != NULL);

ret = process_cq_if_pending(ep);
@@ -5372,14 +5393,19 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
*base_req = NULL;
ret = 0;
goto error;
} else if (ret != 0) {
}
if (ret != 0) {
goto error;
}

/*
* TODO: Use NCCL provided tags when using grouped receives aka
* props->maxRecvs > 1.
*/

have_ctrl = false;
msg_seq_num = s_comm->next_msg_seq_num;

void *elem;
nccl_ofi_msgbuff_elemtype_t type;
nccl_ofi_msgbuff_status_t msg_stat;
@@ -5434,7 +5460,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
}

/* Determine if this should be sent eagerly. */
bool eager = false;
eager = false;
if ((!have_ctrl && size <= eager_max_size) ||
(size == 0)) {
eager = true;
@@ -5742,6 +5768,7 @@ static int rma_write_impl(nccl_net_ofi_send_comm_t *send_comm, void* src, size_t
int ret = 0;
nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)send_comm;
nccl_net_ofi_rdma_req_t *req = NULL;
nccl_net_ofi_rdma_ep_t *ep = NULL;

assert(s_comm != NULL);

@@ -5753,7 +5780,7 @@ static int rma_write_impl(nccl_net_ofi_send_comm_t *send_comm, void* src, size_t
goto error;
}

nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep;
ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep;
assert(ep != NULL);

ret = process_cq_if_pending(ep);
@@ -5861,6 +5888,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
nccl_net_ofi_rdma_send_comm_t **s_comm)
{
int ret = 0;
int comm_id = 0;
fi_addr_t remote_addr;
nccl_net_ofi_rdma_send_comm_t *ret_s_comm = NULL;
int num_rails = ep->num_rails;
@@ -5917,7 +5945,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
ret_s_comm->remote_comm_id = handle->comm_id;

/* Allocate send communicator ID */
int comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool);
comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool);
if (OFI_UNLIKELY(comm_id < 0)) {
ret_s_comm->local_comm_id = ~0;
ret = comm_id;
@@ -6112,6 +6140,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep,
nccl_net_ofi_send_comm_t **send_comm)
{
int ret = 0;
nccl_net_ofi_rdma_req_state_t conn_resp_req_state;
nccl_net_ofi_rdma_req_state_t conn_msg_state;
*send_comm = NULL;
nccl_net_ofi_rdma_ep_t *ep =
@@ -6239,7 +6268,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep,
}

nccl_net_ofi_mutex_lock(&s_comm->conn_resp_req->req_lock);
nccl_net_ofi_rdma_req_state_t conn_resp_req_state = s_comm->conn_resp_req->state;
conn_resp_req_state = s_comm->conn_resp_req->state;
nccl_net_ofi_mutex_unlock(&s_comm->conn_resp_req->req_lock);

/* Wait until conn resp message is received */
@@ -6415,19 +6444,19 @@ static int init_rail_ofi_resources(nccl_net_ofi_rdma_device_t *device,
static int release_ep(nccl_net_ofi_ep_t *base_ep)
{
int ret = 0;
nccl_net_ofi_rdma_ep_t *ep = NULL;
nccl_net_ofi_rdma_device_t *device = NULL;

/* Validate device */
nccl_net_ofi_rdma_ep_t *ep =
(nccl_net_ofi_rdma_ep_t*)base_ep;
ep = (nccl_net_ofi_rdma_ep_t *)base_ep;
if (OFI_UNLIKELY(ep == NULL)) {
ret = -EINVAL;
NCCL_OFI_WARN("Invalid endpoint provided");
goto exit;
}

/* Validate device */
nccl_net_ofi_rdma_device_t *device =
(nccl_net_ofi_rdma_device_t*)ep->base.device;
device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
if (OFI_UNLIKELY(device == NULL)) {
ret = -EINVAL;
NCCL_OFI_WARN("Invalid device provided");
@@ -6607,10 +6636,10 @@ static inline int get_ep(nccl_net_ofi_device_t *base_dev, nccl_net_ofi_ep_t **ba
int ret = 0;
long thread_id;
nccl_net_ofi_rdma_ep_t *ep = NULL;
nccl_net_ofi_rdma_device_t *device = NULL;

/* Retrieve and validate device */
nccl_net_ofi_rdma_device_t *device =
(nccl_net_ofi_rdma_device_t*)base_dev;
device = (nccl_net_ofi_rdma_device_t *)base_dev;
if (OFI_UNLIKELY(device == NULL)) {
ret = -EINVAL;
NCCL_OFI_WARN("Invalid device provided");
@@ -6931,8 +6960,9 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin,
int dev_id, struct fi_info *info_list,
nccl_ofi_topo_t *topo, size_t rr_threshold)
{
int ret;

int ret = 0;
bool provide_own_mr_key = false;
int length = 0;
nccl_net_ofi_rdma_device_t *device =
(nccl_net_ofi_rdma_device_t *)calloc(1, sizeof(nccl_net_ofi_rdma_device_t));
if (device == NULL) {
@@ -6963,7 +6993,7 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin,
}

/* Ensure that number of rails are the same across devices */
int length = ofi_info_list_length(info_list);
length = ofi_info_list_length(info_list);
if (topo->max_group_size != length) {
NCCL_OFI_WARN("Wrong number of NICs for device %i. Expected %i but got %i",
dev_id, topo->max_group_size, length);
@@ -7040,7 +7070,7 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin,
}

/* Initialize mr key pool */
bool provide_own_mr_key = true;
provide_own_mr_key = true;
ret = nccl_ofi_mr_keys_need_own_key(info_list, &provide_own_mr_key);
if (ret != 0) {
NCCL_OFI_WARN("MR key config parsing failed: %s",
@@ -7261,6 +7291,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
nccl_net_ofi_rdma_plugin_t *plugin = NULL;
nccl_ofi_topo_t *topo = NULL;
struct fi_info *hints;
uint32_t api_version = 0;

*found_multiple_rails = false;

@@ -7272,7 +7303,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
}

get_hints(hints);
uint32_t api_version = nccl_ofi_dmabuf_viable() ? FI_VERSION(1, 20) : FI_VERSION(1, 18);
api_version = nccl_ofi_dmabuf_viable() ? FI_VERSION(1, 20) : FI_VERSION(1, 18);
ret = nccl_ofi_ofiutils_get_providers(provider_filter, api_version, hints,
&provider_list, &num_providers);
if (ret == 0) {
34 changes: 18 additions & 16 deletions src/nccl_ofi_sendrecv.c
Original file line number Diff line number Diff line change
@@ -360,6 +360,8 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size)
{
int ret = 0;
nccl_net_ofi_sendrecv_req_t *req = (nccl_net_ofi_sendrecv_req_t *)base_req;
nccl_net_ofi_sendrecv_device_t *device = NULL;
nccl_net_ofi_sendrecv_ep_t *ep = NULL;

/* Retrieve and validate comm */
nccl_net_ofi_comm_t *base_comm = req->comm;
@@ -370,17 +372,15 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size)
}

/* Retrieve and validate endpoint */
nccl_net_ofi_sendrecv_ep_t *ep =
(nccl_net_ofi_sendrecv_ep_t *)base_comm->ep;
ep = (nccl_net_ofi_sendrecv_ep_t *)base_comm->ep;
if (OFI_UNLIKELY(ep == NULL)) {
ret = -EINVAL;
NCCL_OFI_WARN("Invalid endpoint provided");
goto exit;
}

/* Retrieve and validate device */
nccl_net_ofi_sendrecv_device_t *device =
(nccl_net_ofi_sendrecv_device_t*)ep->base.device;
device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device;
if (OFI_UNLIKELY(device == NULL)) {
ret = -EINVAL;
NCCL_OFI_WARN("Invalid device provided");
@@ -760,6 +760,7 @@ static int reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm,
/* Retrieve and validate endpoint */
nccl_net_ofi_sendrecv_ep_t *ep =
(nccl_net_ofi_sendrecv_ep_t *)base_comm->ep;
nccl_ofi_idpool_t *key_pool = NULL;
if (OFI_UNLIKELY(ep == NULL)) {
NCCL_OFI_WARN("Invalid endpoint provided");
return -EINVAL;
@@ -795,7 +796,7 @@ static int reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm,
}
/* Cache miss */

nccl_ofi_idpool_t *key_pool = &device->key_pool;
key_pool = &device->key_pool;
struct fid_domain *domain;
domain = get_domain_from_endpoint(ep);
ret = reg_mr_base(domain, ep->ofi_ep, key_pool,
@@ -889,23 +890,23 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
int ret = 0;
ssize_t rc = 0;
nccl_net_ofi_sendrecv_req_t *req = NULL;
nccl_net_ofi_sendrecv_ep_t *ep = NULL;
nccl_net_ofi_sendrecv_device_t *device = NULL;
nccl_net_ofi_sendrecv_recv_comm_t *r_comm =
(nccl_net_ofi_sendrecv_recv_comm_t *)recv_comm;
int dev_id = r_comm->base.base.dev_id;
struct fid_mr **mr_handles = (struct fid_mr **)mhandles;

/* Retrieve and validate endpoint */
nccl_net_ofi_sendrecv_ep_t * ep =
(nccl_net_ofi_sendrecv_ep_t *)r_comm->base.base.ep;
ep = (nccl_net_ofi_sendrecv_ep_t *)r_comm->base.base.ep;
if (OFI_UNLIKELY(ep == NULL)) {
ret = -EINVAL;
NCCL_OFI_WARN("Invalid endpoint provided");
goto error;
}

/* Retrieve and validate device */
nccl_net_ofi_sendrecv_device_t *device =
(nccl_net_ofi_sendrecv_device_t*)ep->base.device;
device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device;
if (OFI_UNLIKELY(device == NULL)) {
ret = -EINVAL;
NCCL_OFI_WARN("Invalid device provided");
@@ -1051,6 +1052,7 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
void *data = NULL;
void *flush_mr_desc = NULL;
int dev_id = recv_comm->base.dev_id;
int flush_n = -1;
struct fid_mr **mr_handles = (struct fid_mr **)mhandles;

if (ofi_nccl_gdr_flush_disable() || support_gdr == GDR_UNSUPPORTED)
@@ -1073,7 +1075,6 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
* Find the non-zero request for which we will issue flush.
* A single operation can flush all request at once.
*/
int flush_n = -1;
for (int recv_n = 0; recv_n < n; recv_n++) {
if (sizes[recv_n] != 0) {
flush_n = recv_n;
@@ -1552,6 +1553,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
fi_addr_t local_ep_addr;
nccl_net_ofi_sendrecv_listen_comm_t *l_comm = NULL;
uint64_t tag;
int dev_id = 0;
int num_addrs;
nccl_net_ofi_sendrecv_ep_t *ep =
(nccl_net_ofi_sendrecv_ep_t *)base_ep;
@@ -1565,7 +1567,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
goto exit;
}

int dev_id = device->base.dev_id;
dev_id = device->base.dev_id;

/* Zero-out the handle */
memset(handle, 0, sizeof(nccl_net_ofi_conn_handle_t));
@@ -1666,6 +1668,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
ssize_t rc = 0;
nccl_net_ofi_sendrecv_req_t *req = NULL;
void *desc = NULL;
nccl_net_ofi_sendrecv_device_t *device = NULL;
int dev_id = s_comm->base.base.dev_id;
struct fid_mr *mr_handle = (struct fid_mr *)mhandle;

@@ -1679,8 +1682,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
}

/* Retrieve and validate device */
nccl_net_ofi_sendrecv_device_t *device =
(nccl_net_ofi_sendrecv_device_t*)ep->base.device;
device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device;
if (OFI_UNLIKELY(device == NULL)) {
ret = -EINVAL;
NCCL_OFI_WARN("Invalid device provided");
@@ -2114,6 +2116,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep,
static int release_ep(nccl_net_ofi_ep_t *base_ep)
{
int ret = 0;
nccl_net_ofi_sendrecv_device_t *device = NULL;

/* Validate device */
nccl_net_ofi_sendrecv_ep_t *ep =
@@ -2125,8 +2128,7 @@ static int release_ep(nccl_net_ofi_ep_t *base_ep)
}

/* Validate device */
nccl_net_ofi_sendrecv_device_t *device =
(nccl_net_ofi_sendrecv_device_t*)ep->base.device;
device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device;
if (OFI_UNLIKELY(device == NULL)) {
ret = -EINVAL;
NCCL_OFI_WARN("Invalid device provided");
@@ -2385,6 +2387,7 @@ nccl_net_ofi_sendrecv_device_create(nccl_net_ofi_plugin_t *plugin,
int dev_id, struct fi_info *info)
{
int ret;
bool provide_own_mr_key = true;

nccl_net_ofi_sendrecv_device_t *device =
(nccl_net_ofi_sendrecv_device_t *)calloc(1, sizeof(nccl_net_ofi_sendrecv_device_t));
@@ -2433,7 +2436,6 @@ nccl_net_ofi_sendrecv_device_create(nccl_net_ofi_plugin_t *plugin,
}

/* Indicates if the provider selects MR keys */
bool provide_own_mr_key = true;
ret = nccl_ofi_mr_keys_need_own_key(info, &provide_own_mr_key);
if (ret != 0) {
NCCL_OFI_WARN("MR key config parsing failed: %s",
9 changes: 6 additions & 3 deletions src/platform-aws.c
Original file line number Diff line number Diff line change
@@ -604,6 +604,10 @@ int platform_init(const char **provider_filter)

int platform_config_endpoint(struct fi_info *info, struct fid_ep* endpoint) {
int ret = 0;
#if HAVE_CUDA
const char *optname_name = "none";
int optname = -1;
#endif

if (endpoint == NULL) {
NCCL_OFI_WARN("Unable to configure invalid endpoint");
@@ -645,8 +649,6 @@ int platform_config_endpoint(struct fi_info *info, struct fid_ep* endpoint) {
static bool nccl_proto_configured = false;
static bool need_ordering = false;
static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
int optname = -1;
const char *optname_name = "none";

/* During initialization, try to set
* FI_OPT_EFA_{SENDRECV,WRTIE}_IN_ORDER_ALIGNED_128_BYTES to
@@ -785,6 +787,7 @@ void platform_sort_rails(struct fi_info **info_list, int num_rails)
{
struct fi_info *info_list_in = *info_list;
struct fi_info **sorted_info_array = (struct fi_info **)alloca(num_rails*sizeof(struct fi_info *));
struct fi_info *info_ptr = NULL;

if (num_rails <= 0) {
return;
@@ -824,7 +827,7 @@ void platform_sort_rails(struct fi_info **info_list, int num_rails)

/* Update info_list references to match sorted order */
*info_list = sorted_info_array[0];
struct fi_info *info_ptr = *info_list;
info_ptr = *info_list;
for (int i = 0; i < num_rails; ++i) {
assert(info_ptr);
assert(sorted_info_array[i]);

0 comments on commit 2dae7eb

Please sign in to comment.